[Feature] [RLlib] PyTorch JIT as framework
Search before asking
- I had searched in the issues and found no similar feature requirement.
Description
According to Sven, the torch framework is roughly 50% slower than the tf framework. We should look into `torch.jit` to see if we can get any speedups.
We have several options:
`torch.jit.trace`: This traces a model using dummy tensors, but it can miss unused portions of the code path (i.e. conditionals). I believe the surface for errors here is large, and we should stay away from it.
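For illustration, a minimal sketch of that pitfall with a made-up toy module (not an RLlib model): the traced graph keeps only the branch the dummy input happened to take.

```python
import torch
import torch.nn as nn


class Gated(nn.Module):
    """Toy module with a data-dependent branch."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        # Tracing records only whichever branch the example input takes.
        if x.sum() > 0:
            return self.fc(x)
        return torch.zeros_like(x)


model = Gated()
traced = torch.jit.trace(model, torch.ones(1, 4))  # dummy input hits the fc branch

x = -torch.ones(1, 4)
print(model(x))   # zeros: the eager model takes the else branch
print(traced(x))  # fc(x): the traced graph silently ignores the conditional
```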
`torch.jit.script`: This does static analysis on the Python source and compiles it into TorchScript. This is likely the way forward, but to do it correctly, it requires all class members to be TorchScript-compatible, which is not feasible given the current `TorchModelV2` (e.g. any code utilizing `numpy` arrays is not compilable).
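As a rough illustration of that limitation, scripting a toy stand-in (hypothetical, not the actual `TorchModelV2`) fails as soon as the forward path calls into `numpy`:

```python
import numpy as np
import torch
import torch.nn as nn


class UsesNumpy(nn.Module):
    """Toy stand-in for a model whose code path touches numpy."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        scale = float(np.sqrt(2.0))  # numpy call: not TorchScript-compatible
        return self.fc(x) * scale


try:
    torch.jit.script(UsesNumpy())  # static analysis of the Python source
except Exception as exc:
    print("scripting failed:", exc)
```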
`torch.jit.script` submodules: Rather than JIT the entire `TorchModelV2`, we can JIT all `torch.nn.Module` children of the `TorchModelV2`. This allows us to use the JIT without too many changes. The downside is that we lose a few optimizations between distinct networks (e.g. with MLP -> value net, we cannot fuse the last MLP operation with the first value-network operation). I think this is the best option, because the number of optimizations we lose is very small.
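A sketch of how this could look; `script_submodules` is a hypothetical helper, not an existing RLlib API:

```python
import torch
import torch.nn as nn


def script_submodules(model: nn.Module) -> nn.Module:
    """Hypothetical helper: JIT-script each direct nn.Module child in place.

    Children that fail to script are left in eager mode, so the surrounding
    model (e.g. a TorchModelV2 subclass) does not need to change.
    """
    for name, child in model.named_children():
        try:
            setattr(model, name, torch.jit.script(child))
        except Exception:
            pass  # keep the eager child if scripting fails
    return model
```

Because only the children are compiled, fusion across module boundaries (the MLP-to-value-net handoff mentioned above) is not available, which is the trade-off described in this option.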
`torch.jit.script` submodules with TensorRT: Like above, but utilize TensorRT from NVIDIA to further lower TorchScript into native TensorRT CUDA code. This would presumably only work on NVIDIA GPUs with Tensor Cores, and would require using reduced-precision (fp16) floats, but could provide significant speedups for the env sampling workers: https://developer.nvidia.com/blog/accelerating-inference-up-to-6x-faster-in-pytorch-with-torch-tensorrt/
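A sketch along the lines of the linked Torch-TensorRT blog post, assuming a TorchScript module `scripted_model` already on the GPU and a fixed input shape; the exact `torch_tensorrt` arguments may differ between versions:

```python
import torch
import torch_tensorrt  # separate package; requires an NVIDIA GPU

# Lower the TorchScript module to TensorRT with fp16 kernels enabled.
trt_model = torch_tensorrt.compile(
    scripted_model,
    inputs=[torch_tensorrt.Input((1, 4), dtype=torch.half)],
    enabled_precisions={torch.half},  # allow fp16 kernels on Tensor Cores
)

obs = torch.randn(1, 4, device="cuda", dtype=torch.half)
logits = trt_model(obs)
```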
`torch.jit.script` submodules with CUDA graphs: Like above, but utilize torch to construct the graph on a CUDA device by capturing the CUDA stream. Rather than call into the model normally, we can load new arguments into GPU buffers and replay the stream with the new arguments: https://pytorch.org/docs/master/notes/cuda.html#cuda-graphs
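A sketch of the capture/replay pattern from the linked docs, assuming `model` is already a CUDA module with a fixed `(batch, obs)` input shape:

```python
import torch

static_input = torch.randn(32, 4, device="cuda")

# Warm up on a side stream before capture, as the docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass into a CUDA graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay: copy new observations into the static buffer and relaunch the
# captured kernels; static_output is overwritten in place.
static_input.copy_(torch.randn(32, 4, device="cuda"))
g.replay()
```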
`torch.jit.script` submodules with `torch.amp` (automatic mixed precision): Like above, but automatically cast all tensors to fp16. This utilizes CUDA Tensor Cores for a 4x speedup (and a CPU equivalent for a more modest speedup), but requires scaling gradients using `GradScaler`: https://pytorch.org/docs/stable/amp.html
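For reference, the autocast/`GradScaler` training-step pattern from the linked docs, with `model`, `optimizer`, `loss_fn`, and `loader` as placeholders:

```python
import torch

scaler = torch.cuda.amp.GradScaler()

for obs, target in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # run ops in fp16 where it is safe
        loss = loss_fn(model(obs), target)
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)            # unscales gradients, then steps the optimizer
    scaler.update()
```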
Use case
Apply `torch.jit` to the default models.
Related issues
No response
Are you willing to submit a PR?
- Yes I am willing to submit a PR!
Top GitHub Comments
Ah right. Looking at everything under /sampler_perf/, `mean_inference_ms` is still the dominating component, although it only takes about 1.7 ms, and there is no big difference between the jit and non-jit versions. Try a really big network? The first step probably should be to observe some performance gain for the inference step alone.
With the new `torch.compile` present in the upcoming torch 2.0, I think we can close this.
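For reference, a minimal sketch of that torch 2.0 API (the model here is a placeholder, not RLlib's default model):

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2))

# torch.compile (torch >= 2.0) compiles the module lazily on the first call.
compiled_model = torch.compile(model)
out = compiled_model(torch.randn(8, 4))
```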