[Bug] Training a loaded model fails: If capturable=False, state_steps should not be CUDA tensors.
### 🐛 Bug
Training a loaded model fails when, and only when, the model runs on CUDA.
### To Reproduce

```python
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import DQN
import gym

device = "cuda"  # no bug if device = "cpu"

env = DummyVecEnv([lambda: gym.make("MountainCar-v0")])
model = DQN("MlpPolicy", env, learning_starts=50, device=device)
model.learn(total_timesteps=64)
model.save("model.zip")
del model

model = DQN.load("model.zip", env=env, device=device)
model.learn(total_timesteps=64)  # raises the AssertionError below
```
```
Traceback (most recent call last):
  File "../error.py", line 15, in <module>
    model.learn(total_timesteps=64)
  File "/home/qgallouedec/stable-baselines3/stable_baselines3/dqn/dqn.py", line 264, in learn
    return super().learn(
  File "/home/qgallouedec/stable-baselines3/stable_baselines3/common/off_policy_algorithm.py", line 346, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/qgallouedec/stable-baselines3/stable_baselines3/dqn/dqn.py", line 213, in train
    self.policy.optimizer.step()
  File "/home/qgallouedec/stable-baselines3/env/lib/python3.8/site-packages/torch/optim/optimizer.py", line 109, in wrapper
    return func(*args, **kwargs)
  File "/home/qgallouedec/stable-baselines3/env/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/qgallouedec/stable-baselines3/env/lib/python3.8/site-packages/torch/optim/adam.py", line 157, in step
    adam(params_with_grad,
  File "/home/qgallouedec/stable-baselines3/env/lib/python3.8/site-packages/torch/optim/adam.py", line 213, in adam
    func(params,
  File "/home/qgallouedec/stable-baselines3/env/lib/python3.8/site-packages/torch/optim/adam.py", line 255, in _single_tensor_adam
    assert not step_t.is_cuda, "If capturable=False, state_steps should not be CUDA tensors."
AssertionError: If capturable=False, state_steps should not be CUDA tensors.
```
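The assertion indicates that, after `DQN.load(...)` restores the optimizer, the Adam state's `step` entries end up on the GPU while `capturable` is still `False`. Below is a minimal diagnostic sketch. It assumes the PyTorch 1.12 layout, in which `optimizer.state` maps each parameter to a dict holding a `step` tensor. The helper name `find_cuda_step_tensors` is hypothetical and not part of Stable-Baselines3 or PyTorch.

```python
from typing import Any, List


def find_cuda_step_tensors(optimizer: Any) -> List[int]:
    """Return the indices of optimizer state entries whose 'step' lives on CUDA.

    Hypothetical helper: assumes `optimizer.state` is a mapping from each
    parameter to a per-parameter dict, as in torch.optim.Optimizer.
    """
    offenders = []
    for index, param_state in enumerate(optimizer.state.values()):
        step = param_state.get("step")
        # Older torch versions store 'step' as a plain int, which has no
        # `is_cuda` attribute, so it is treated as "not on CUDA".
        if getattr(step, "is_cuda", False):
            offenders.append(index)
    return offenders
```

Calling `find_cuda_step_tensors(model.policy.optimizer)` right after loading should return a non-empty list in the failing configuration (PyTorch 1.12.0 with `device="cuda"`). On a freshly created model it should return an empty list.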
### System Info
Describe the characteristics of your environment:
- OS: Linux-5.15.0-41-generic-x86_64-with-glibc2.29 # 44~20.04.1-Ubuntu SMP Fri Jun 24 13:27:29 UTC 2022
- Python: 3.8.10
- Stable-Baselines3: 1.5.1a9
- PyTorch: 1.12.0+cu102
- GPU Enabled: True
- Numpy: 1.22.2
- Gym: 0.21.0
### Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
### Comments
This seems to be a PyTorch-specific issue that should be fixed in torch 1.12.1: https://github.com/pytorch/pytorch/issues/80809#issuecomment-1175211598
I suggest we keep this issue open in the meantime.
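Until 1.12.1 is available, one workaround reported in the linked PyTorch issue is to switch the loaded optimizer's parameter groups to `capturable=True`, which makes Adam accept CUDA `step` tensors. Verify this against that thread before relying on it. The sketch below assumes every parameter group of `model.policy.optimizer` belongs to an Adam or AdamW optimizer. The helper name `enable_capturable` is hypothetical. Capturable mode exists for CUDA graph capture and may add some overhead.

```python
from typing import Any


def enable_capturable(optimizer: Any) -> int:
    """Set capturable=True on every param group; return how many were changed.

    Hypothetical helper for the PyTorch 1.12.0 Adam/AdamW load issue. With
    capturable=True, Adam expects both params and 'step' tensors on CUDA,
    so only apply it when the model was loaded with device="cuda".
    """
    changed = 0
    for group in optimizer.param_groups:
        if not group.get("capturable", False):
            group["capturable"] = True
            changed += 1
    return changed


# Usage right after loading (requires stable_baselines3 and a CUDA device):
# model = DQN.load("model.zip", env=env, device="cuda")
# enable_capturable(model.policy.optimizer)
# model.learn(total_timesteps=64)
```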
Hmm, peculiar; this does not happen on my Windows 10 machine, but I am running PyTorch 1.11. I wonder if 1.12 is breaking things. I currently don't have the bandwidth to download other torch versions ^^'.
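Taken together, these comments suggest the failure is specific to PyTorch 1.12.0: 1.11 is reported unaffected, and a fix is expected in 1.12.1. The small guard below assumes exactly that range, based only on this thread. The helper name `is_affected_torch_version` is hypothetical.

```python
import re


def is_affected_torch_version(version: str) -> bool:
    """Return True if `version` looks like PyTorch 1.12.0 (any local suffix).

    Assumption drawn only from this thread: 1.11.x works and 1.12.1 carries
    the fix. Local build tags such as '+cu102' are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return (major, minor, patch) == (1, 12, 0)
```

For example, `is_affected_torch_version(torch.__version__)` returns `True` for the `1.12.0+cu102` build listed under System Info.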