PyTorch RuntimeError: expected scalar type Float but found Double
Description
I am using the CausalWorld env to train a policy. In their paper, the authors evaluate PPO, SAC and TD3 with the TensorFlow version of Stable Baselines, which suggests the env works with the SB API.
As far as I understand, the env uses a Box action space, which is supported by both SAC and PPO in SB3.
A PPO policy with SB3 starts training without any errors.
However, when I try to train a SAC policy (with SB3 / PyTorch), I run into the following error:
RuntimeError: expected scalar type Float but found Double
Code example
import os
import argparse
import gym
import numpy as np
from causal_world.task_generators import generate_task
from causal_world.envs import CausalWorld
from stable_baselines3 import PPO, SAC
from stable_baselines3.common.env_checker import check_env


def train(args):
    # env = gym.make("BipedalWalker-v3")
    task = generate_task(task_generator_id='picking')
    env = CausalWorld(task=task)

    ###### sb3 check env
    check_env(env)

    if args.rl_algo == 'ppo':
        model = PPO("MlpPolicy", env, verbose=1)
    if args.rl_algo == 'sac':
        model = SAC("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--rl_algo',
                        type=str,
                        default='sac',
                        help='rl algorithm: sac | ppo',
                        required=False)
    args = parser.parse_args()
    train(args)
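For reference, a quick way to confirm the dtype mismatch (assuming the env is constructed as in the script above and that reset() returns a plain numpy array; this snippet is mine, not part of the original report):

# Diagnostic only: check what dtype the env actually declares and produces.
# If either line prints float64, that explains the Double tensors in the error below.
print(env.observation_space.dtype)
print(env.reset().dtype)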
Stack trace:
pybullet build time: Aug 3 2020 20:48:16
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Traceback (most recent call last):
File "scripts/test_sb.py", line 47, in <module>
train(args)
File "scripts/test_sb.py", line 32, in train
model.learn(total_timesteps=10000, log_interval=4)
File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/sac/sac.py", line 300, in learn
reset_num_timesteps=reset_num_timesteps,
File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 371, in learn
self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/sac/sac.py", line 241, in train
current_q_values = self.critic(replay_data.observations, replay_data.actions)
File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/common/policies.py", line 884, in forward
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/common/policies.py", line 884, in <genexpr>
return tuple(q_net(qvalue_input) for q_net in self.q_networks)
File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/container.py", line 119, in forward
input = module(input)
File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 94, in forward
return F.linear(input, self.weight, self.bias)
File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/functional.py", line 1753, in linear
return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Float but found Double
System Info
Describe the characteristics of your environment:
- all libraries installed with pip in a virtual environment (cw_data_venv)
- OS: Linux-4.15.0-65-generic-x86_64-with-debian-buster-sid
- Python: 3.7.6
- Stable-Baselines3: 1.3.0
- PyTorch: 1.8.1+cu111
- GPU Enabled: False
- Numpy: 1.21.4
- Gym: 0.19.0
Additional context
The error persists on a GPU machine.
The error seems to originate in SAC's training step, when the current_q_values are computed with a forward pass through the critic: the replay buffer apparently returns float64 (Double) tensors while the critic's weights are float32 (Float).
It could potentially be solved by specifying the dtype in the to_torch function of the replay buffer.
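As a user-side workaround (a minimal sketch, not a fix inside SB3; the wrapper name Float32Observation is my own and assumes the observation space is a float64 Box), the observations can be cast to float32 before they ever reach the replay buffer:

import gym
import numpy as np


class Float32Observation(gym.ObservationWrapper):
    """Cast float64 observations to float32 so they match SB3's network weights."""

    def __init__(self, env):
        super().__init__(env)
        # Re-declare the observation space with float32 bounds/dtype
        # (assumes the original space is a gym.spaces.Box).
        self.observation_space = gym.spaces.Box(
            low=self.observation_space.low.astype(np.float32),
            high=self.observation_space.high.astype(np.float32),
            dtype=np.float32,
        )

    def observation(self, observation):
        # Cast every observation returned by the wrapped env.
        return np.asarray(observation, dtype=np.float32)

Wrapping the env as env = Float32Observation(CausalWorld(task=task)) before passing it to SAC should make the replay buffer store float32 data, matching the critic's weights.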
Checklist
- I have read the documentation (required)
- I have checked that there is no similar issue in the repo (required)
- I have checked my env using the env checker (required)
- I have provided a minimal working example to reproduce the bug (required)
Top GitHub Comments
After further discussion with @araffin, it might be better to delegate this issue to the users by including a specific check in the env_checker that verifies all floating-point observations are float32. The reasoning is that transforming float64 -> float32 would otherwise have to be done all around the SB3 code, which gets pretty messy.

Okay, in that case please include the fp32 precision requirement in the documentation for custom_gym_env.
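For illustration, a rough sketch of what such an env_checker warning could look like (this is not the actual SB3 implementation; the function name _check_float32_obs is hypothetical):

# Illustrative sketch only -- not the real SB3 env_checker code.
# Warn when a Box observation space uses float64, since SB3 networks default to float32.
import warnings

import numpy as np
from gym import spaces


def _check_float32_obs(observation_space):
    if isinstance(observation_space, spaces.Box) and observation_space.dtype == np.float64:
        warnings.warn(
            "The observation space has dtype float64, but SB3 models expect float32 "
            "inputs; consider declaring the space with dtype=np.float32 and casting "
            "observations accordingly."
        )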