[Feature] [RLlib] Pytorch JIT as framework

Search before asking

  • I searched the issues and found no similar feature requests.

Description

According to Sven, the torch framework is roughly 50% slower than the tf framework. We should look into torch.jit to see if we can get any speedups.

We have a few different options.

torch.jit.trace: This traces a model using dummy tensors, but it can miss unused portions of the code path (e.g. conditionals). I believe the surface for errors here is large, and we should stay away from it.
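For illustration, a minimal sketch of the tracing pitfall (a toy model, not RLlib code): the trace records only the branch taken by the example input, so data-dependent conditionals are silently baked in.

```python
import torch
import torch.nn as nn

class Gated(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Data-dependent control flow: tracing only records one branch.
        if x.sum() > 0:
            return x * 2
        return x - 1

model = Gated()
traced = torch.jit.trace(model, torch.ones(3))  # positive example input -> `if` branch

neg = torch.full((3,), -5.0)
print(model(neg))   # eager: takes the `else` path, returns x - 1
print(traced(neg))  # traced: still computes x * 2 -- the conditional is gone
```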

torch.jit.script: This does static analysis on the Python source and compiles it into TorchScript. This is likely the way forward, but to do it correctly, it requires all class members to be TorchScript-compatible, which is not feasible given the current TorchModelV2 (e.g. any code utilizing numpy arrays is not compilable).
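A toy example of that incompatibility (a stand-in class, not the real TorchModelV2): any numpy call inside a scripted method makes compilation fail.

```python
import numpy as np
import torch
import torch.nn as nn

class NumpyModel(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mask = np.ones(3)  # numpy is opaque to the TorchScript compiler
        return x * torch.from_numpy(mask).float()

try:
    torch.jit.script(NumpyModel())
except Exception as e:
    print(f"scripting failed: {type(e).__name__}")
```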

torch.jit.script submodules: Rather than JIT the entire TorchModelV2, we can JIT all torch.nn.Module children of TorchModelV2. This allows us to use the JIT without too many changes. The downside is we lose a few optimizations between distinct networks (e.g. with MLP -> value net, we cannot fuse the last MLP operation with the first value-net operation). I think this is the best option, because the number of optimizations we lose is very small.
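A rough sketch of the submodule approach, using a hypothetical model class rather than RLlib's actual TorchModelV2 API: the wrapper stays a plain Python class (so numpy-using glue code is untouched), and only the nn.Module children get compiled.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Stand-in for a TorchModelV2 subclass with separate sub-networks.
    def __init__(self):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(4, 64), nn.ReLU())
        self.policy_head = nn.Linear(64, 2)
        self.value_net = nn.Linear(64, 1)

def script_submodules(model: nn.Module) -> nn.Module:
    # Replace each direct nn.Module child with its scripted version in place.
    for name, child in model.named_children():
        setattr(model, name, torch.jit.script(child))
    return model

model = script_submodules(ToyModel())
print(type(model.mlp))  # torch.jit.RecursiveScriptModule
```

Because `mlp` and `value_net` are compiled as separate graphs, the fuser cannot merge ops across their boundary, which is the small optimization loss noted above.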

torch.jit.script submodules with TensorRT: Like above, but utilize TensorRT from NVIDIA to further lower the TorchScript into native TensorRT CUDA code. This would presumably only work on NVIDIA GPUs with Tensor Cores, and would require reduced-precision (fp16) floats, but it could provide significant speedups for the env-sampling workers: https://developer.nvidia.com/blog/accelerating-inference-up-to-6x-faster-in-pytorch-with-torch-tensorrt/
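A hedged sketch of the Torch-TensorRT path from the linked blog post; this assumes the torch_tensorrt package is installed, a CUDA-capable NVIDIA GPU, and an illustrative (32, 4) observation batch.

```python
import torch
import torch.nn as nn
import torch_tensorrt  # NVIDIA's Torch-TensorRT bridge

# Toy stand-in for a policy network.
net = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)).cuda().eval()
scripted = torch.jit.script(net)

trt_model = torch_tensorrt.compile(
    scripted,
    inputs=[torch_tensorrt.Input((32, 4), dtype=torch.half)],
    enabled_precisions={torch.half},  # fp16 kernels to engage Tensor Cores
)
out = trt_model(torch.randn(32, 4, dtype=torch.half, device="cuda"))
```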

torch.jit.script submodules with CudaGraphs: Like above, but utilize torch to construct the graph on a CUDA device by capturing the CUDA stream. Rather than call into the model normally, we can load new arguments into GPU buffers and replay the stream with the new arguments: https://pytorch.org/docs/master/notes/cuda.html#cuda-graphs
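A minimal sketch following the linked CUDA-graphs note: warm up on a side stream, capture one forward pass into a graph, then replay it after copying new data into the captured input buffer (assumes torch >= 1.10 and a CUDA device).

```python
import torch

net = torch.jit.script(torch.nn.Linear(4, 2)).cuda().eval()
static_input = torch.zeros(32, 4, device="cuda")

# Warmup on a side stream is required before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        net(static_input)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = net(static_input)  # captured, not executed eagerly

# "Replay the stream with new arguments": refill the buffer, then replay.
static_input.copy_(torch.randn(32, 4, device="cuda"))
g.replay()
print(static_output)  # now holds the output for the new input
```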

torch.jit.script submodules with torch.amp: Like above, but automatically cast all tensors to fp16. This utilizes CUDA Tensor Cores for up to a 4x speedup (the CPU equivalent gives a more modest speedup), but requires scaling gradients using GradScaler: https://pytorch.org/docs/stable/amp.html
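The standard recipe from the linked AMP docs, with placeholder model and data: autocast runs fp16-safe ops in half precision, and GradScaler rescales the loss so small fp16 gradients do not flush to zero.

```python
import torch

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

obs = torch.randn(32, 4, device="cuda")     # placeholder batch
target = torch.randn(32, 2, device="cuda")

for _ in range(100):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():          # cast eligible ops to fp16
        loss = torch.nn.functional.mse_loss(model(obs), target)
    scaler.scale(loss).backward()            # scale the loss before backward
    scaler.step(optimizer)                   # unscales grads, then steps
    scaler.update()                          # adapt the scale factor
```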

Use case

Apply torch.jit to the default models.

Related issues

No response

Are you willing to submit a PR?

  • Yes, I am willing to submit a PR!

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 13 (11 by maintainers)

Top GitHub Comments

1 reaction
gjoliver commented, Mar 28, 2022

Ah right. Looking at everything under /sampler_perf/, mean_inference_ms is still the dominating component, although it only takes about 1.7 ms, and there's no big difference between the JIT and non-JIT versions. Try a really big network? The first step should probably be to observe some performance gain for the inference step alone.
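For reference, a rough sketch of that isolated inference benchmark (sizes are illustrative; the torch.cuda.synchronize calls keep the timing honest for async CUDA work):

```python
import time
import torch

net = torch.nn.Sequential(
    *[torch.nn.Linear(2048, 2048) for _ in range(8)]  # deliberately big network
).cuda().eval()
scripted = torch.jit.script(net)
x = torch.randn(256, 2048, device="cuda")

def bench_ms(m, iters=100):
    with torch.no_grad():
        for _ in range(10):              # warmup; also triggers JIT optimization passes
            m(x)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(iters):
            m(x)
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / iters * 1e3

print(f"eager:  {bench_ms(net):.2f} ms")
print(f"script: {bench_ms(scripted):.2f} ms")
```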

0 reactions
smorad commented, Dec 5, 2022

With the new torch.compile present in the upcoming torch 2.0, I think we can close this.
