Provide tiny wrapper over pytorch ThroughputBenchmark


🚀 Feature

The PyTorch utils module has provided ThroughputBenchmark since version 1.2.0:

>>> from torch.utils import ThroughputBenchmark
>>> bench = ThroughputBenchmark(my_module)
>>> # Pre-populate benchmark's data set with the inputs
>>> for input in inputs:
...     # Both args and kwargs work, same as any PyTorch Module / ScriptModule
...     bench.add_input(input[0], x2=input[1])
>>> # Inputs supplied above are randomly used during the execution
>>> stats = bench.benchmark(
...     num_calling_threads=4,
...     num_warmup_iters=100,
...     num_iters=1000,
... )
>>> print("Avg latency (ms): {}".format(stats.latency_avg_ms))
>>> print("Number of iterations: {}".format(stats.num_iters))
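A self-contained version of the docstring example, for trying it locally. MyModule, inputs, and the iteration counts are hypothetical stand-ins for my_module and inputs. Because the module is a plain nn.Module, the sketch uses a single calling thread: PyTorch's own ThroughputBenchmark docstring notes that an nn.Module may pay GIL overhead on every call.

```python
import torch
from torch.utils import ThroughputBenchmark


class MyModule(torch.nn.Module):
    """Hypothetical stand-in for my_module: combines two inputs."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x1, x2):
        return self.linear(x1) + x2


my_module = MyModule().eval()
# Hypothetical data set: pairs of (x1, x2) tensors
inputs = [(torch.randn(4, 8), torch.randn(4, 8)) for _ in range(10)]

bench = ThroughputBenchmark(my_module)
for x1, x2 in inputs:
    # x1 is passed positionally, x2 as a keyword argument of forward
    bench.add_input(x1, x2=x2)

# One calling thread, since a Python module needs the GIL for every call
stats = bench.benchmark(num_calling_threads=1, num_warmup_iters=10, num_iters=100)
print("Avg latency (ms): {}".format(stats.latency_avg_ms))
print("Number of iterations: {}".format(stats.num_iters))
```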

It would be useful to provide a tiny wrapper around this to simplify its use with Ignite.

Issue Analytics

  • State: open
  • Created 4 years ago
  • Comments: 11 (9 by maintainers)

Top GitHub Comments

2 reactions
kai-tub commented, Mar 24, 2020

I’ve used a custom event_filter (mostly because I wanted to get more familiar with the code):

import contextlib
from typing import Callable

import torch
from torch.utils import ThroughputBenchmark
from ignite.engine import Engine, Events


class ThroughputBenchmarkWrapper:
    def __init__(
        self,
        model: torch.nn.Module,
        num_calling_threads: int = 1,
        num_warmup_iters: int = 10,
        num_iters: int = 100,
    ):
        # The benchmark needs the module it will call with the recorded inputs
        self._bench = ThroughputBenchmark(model)
        self._num_calling_threads = num_calling_threads
        self._num_warmup_iters = num_warmup_iters
        self._num_iters = num_iters
        self._stats = None

    def _batch_logger(self, engine: Engine, input_transform: Callable):
        # Add the model input of the current batch to the benchmark's data set
        input_data = input_transform(engine.state.batch)
        self._bench.add_input(input_data)

    def _run(self, engine: Engine):
        # Once the engine has finished, benchmark the module on the recorded inputs
        self._stats = self._bench.benchmark(
            num_calling_threads=self._num_calling_threads,
            num_warmup_iters=self._num_warmup_iters,
            num_iters=self._num_iters,
        )

    def _detach(self, engine: Engine):
        if engine.has_event_handler(self._batch_logger, Events.ITERATION_STARTED):
            engine.remove_event_handler(self._batch_logger, Events.ITERATION_STARTED)
        if engine.has_event_handler(self._run, Events.COMPLETED):
            engine.remove_event_handler(self._run, Events.COMPLETED)

    @contextlib.contextmanager
    def attach(
        self,
        engine: Engine,
        max_batches: int = 10,
        input_transform: Callable = lambda input_batch: input_batch[0],
    ):
        def under_max_batches(engine: Engine, event: int) -> bool:
            # Iteration counters start at 1, so this keeps the first max_batches batches
            return event <= max_batches

        if not engine.has_event_handler(self._run):
            engine.add_event_handler(
                Events.ITERATION_STARTED(event_filter=under_max_batches),
                self._batch_logger,
                input_transform,
            )
            engine.add_event_handler(Events.COMPLETED, self._run)

        try:
            yield engine
        finally:
            # Remove the handlers even if the engine raised inside the with-block
            self._detach(engine)

    @property
    def stats(self):
        if self._stats is None:
            raise RuntimeError(
                "Benchmark wrapper hasn't run yet so results can't be retrieved."
            )
        return self._stats

I hope it’s fine that I haven’t included the docstrings and input tests here; I left them out to keep the comment short and because we are still discussing the design. (These are my first contributions to open-source projects, so sorry if I sometimes ask trivial questions.)

1 reaction
kai-tub commented, Mar 22, 2020

Hi, I would like to contribute. 😃 But the documentation of throughput_benchmark is rather short. For example, I can’t figure out what the point of x2=input[1] is. Judging by the documentation, nothing seems to be done with this value.

Looking at the C++ binding (which I have very little experience with): https://github.com/pytorch/pytorch/blob/master/torch/csrc/utils/throughput_benchmark.cpp#L108

I can’t quite tell whether x2 is used or not. I would assume that the label of the data point is simply discarded by the benchmark?
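One way to check what happens to x2 is run_once, which calls the benchmarked module once with the arguments it is given. The two-input module below is hypothetical; the sketch relies on add_input and run_once accepting the same positional and keyword arguments as the module's forward, as the docstring comment "Both args and kwargs work" suggests. If the output depends on x2, the keyword argument is passed through to forward rather than discarded.

```python
import torch
from torch.utils import ThroughputBenchmark


class AddModule(torch.nn.Module):
    # Hypothetical two-input module: the second input arrives as the keyword x2
    def forward(self, x1, x2):
        return x1 + x2


bench = ThroughputBenchmark(AddModule())
x1 = torch.ones(3)
x2 = torch.full((3,), 2.0)

# add_input stores positional and keyword arguments for the benchmark run;
# run_once calls the module with the given arguments and returns its output
bench.add_input(x1, x2=x2)
out = bench.run_once(x1, x2=x2)
print(out)  # tensor([3., 3., 3.])
```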

But aside from the internal workings, how would you like to structure the wrapper? What value should the Ignite wrapper add? Should it automatically move the model to a given device, similar to create_supervised_*? Should it optionally create a JIT trace? Or should it just attach to an engine with mostly the same code?
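On the JIT question: ThroughputBenchmark accepts a ScriptModule as well as a plain nn.Module, and PyTorch's docstring recommends switching to a ScriptModule once one is available, to avoid holding the GIL on every call. An optional trace step could therefore be as small as this sketch. It assumes a CPU model and a hypothetical fixed input shape of (16, 4); moving to a device would mean calling .to(device) on both the model and the inputs before tracing.

```python
import torch
from torch.utils import ThroughputBenchmark

# Hypothetical model and example input used for tracing
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2)
).eval()
example = torch.randn(16, 4)

with torch.no_grad():
    traced = torch.jit.trace(model, example)  # returns a ScriptModule

bench = ThroughputBenchmark(traced)
for _ in range(4):
    bench.add_input(torch.randn(16, 4))

# A scripted module runs without the GIL, so several calling threads are meaningful
stats = bench.benchmark(num_calling_threads=2, num_warmup_iters=5, num_iters=20)
print("Avg latency (ms): {}".format(stats.latency_avg_ms))
```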

Thanks, Kai
