PyTorch RuntimeError: expected scalar type Float but found Double

Description

I am using the CausalWorld env to train a policy. In their paper, the authors evaluated PPO, SAC and TD3 using the TensorFlow version of Stable Baselines, which suggests that the env works with the Stable Baselines API.

As far as I understand, the env uses a Box action space, which is supported by both SAC and PPO in SB3.

A PPO policy with SB3 starts training without any errors.

But when I tried to train a SAC policy (with SB3, PyTorch), I ran into the following error: RuntimeError: expected scalar type Float but found Double

Code example

import os
import argparse
import gym
import numpy as np
from causal_world.task_generators import generate_task
from causal_world.envs import CausalWorld
from stable_baselines3 import PPO, SAC
from stable_baselines3.common.env_checker import check_env

def train(args):
	
	#env = gym.make("BipedalWalker-v3")
	
	task = generate_task(task_generator_id='picking')
	env = CausalWorld(task=task)
	
	###### sb3 check env
	check_env(env) 
	
	if args.rl_algo == 'ppo':
		model = PPO("MlpPolicy", env, verbose=1)    
	if args.rl_algo == 'sac':	
		model = SAC("MlpPolicy", env, verbose=1)

	model.learn(total_timesteps=10000, log_interval=4)


if __name__ == "__main__":
	parser = argparse.ArgumentParser()
	parser.add_argument('--rl_algo',
		type=str,
		default='sac',
		help='rl algorithm sac | ppo',
		required=False)							
	args = parser.parse_args()				
	train(args)

Stack trace:

pybullet build time: Aug  3 2020 20:48:16
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Traceback (most recent call last):
  File "scripts/test_sb.py", line 47, in <module>
    train(args)
  File "scripts/test_sb.py", line 32, in train
    model.learn(total_timesteps=10000, log_interval=4)
  File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/sac/sac.py", line 300, in learn
    reset_num_timesteps=reset_num_timesteps,
  File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/common/off_policy_algorithm.py", line 371, in learn
    self.train(batch_size=self.batch_size, gradient_steps=gradient_steps)
  File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/sac/sac.py", line 241, in train
    current_q_values = self.critic(replay_data.observations, replay_data.actions)
  File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/common/policies.py", line 884, in forward
    return tuple(q_net(qvalue_input) for q_net in self.q_networks)
  File "/home/cw_data_venv/lib/python3.7/site-packages/stable_baselines3/common/policies.py", line 884, in <genexpr>
    return tuple(q_net(qvalue_input) for q_net in self.q_networks)
  File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 94, in forward
    return F.linear(input, self.weight, self.bias)
  File "/cvmfs/ai.mila.quebec/apps/x86_64/debian/pytorch/python3.7-cuda11.1-cudnn8.0-v1.8.1/lib/python3.7/site-packages/torch/nn/functional.py", line 1753, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: expected scalar type Float but found Double


System Info

Describe the characteristics of your environment:

  • all libraries installed with pip in a virtual environment (cw_data_venv)
  • OS: Linux-4.15.0-65-generic-x86_64-with-debian-buster-sid
  • Python: 3.7.6
  • Stable-Baselines3: 1.3.0
  • PyTorch: 1.8.1+cu111
  • GPU Enabled: False
  • Numpy: 1.21.4
  • Gym: 0.19.0

Additional context

The error persists on a GPU machine.

The error seems to originate in the SAC training step, when current_q_values are computed with a forward pass through the critic. It could potentially be solved by specifying the dtype in the to_torch function of the replay buffer.
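
To illustrate the mechanism described above, here is a minimal sketch in plain PyTorch, independent of SB3 and CausalWorld. It assumes only that the data reaching the critic is float64 (NumPy's default) while the network weights are float32 (PyTorch's default). The names `layer` and `batch64` are made up for the example, and the exact wording of the error message can differ between PyTorch versions.

```python
import numpy as np
import torch
from torch import nn

# A linear layer's parameters are float32 unless the default dtype is changed.
layer = nn.Linear(3, 1)

# np.zeros produces float64, and torch.as_tensor keeps that dtype.
batch64 = torch.as_tensor(np.zeros((2, 3)))

try:
    layer(batch64)
    mismatch_raised = False
except RuntimeError as err:
    # The forward pass refuses to mix float64 inputs with float32 weights.
    mismatch_raised = True
    print(f"float64 input rejected: {err}")

# Casting the input to float32 before the forward pass makes the dtypes agree.
out = layer(batch64.float())
print(out.dtype)
```

Casting at the point where NumPy data becomes a tensor is the kind of fix the paragraph above suggests. Whether that belongs in the library or in the environment is what the comments below discuss.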

Checklist

  • I have read the documentation (required)
  • I have checked that there is no similar issue in the repo (required)
  • I have checked my env using the env checker (required)
  • I have provided a minimal working example to reproduce the bug (required)

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

Miffyli commented, Nov 25, 2021 (2 reactions)

After further discussion with @araffin, it might be better to leave this to the users and add a specific check to env_checker that verifies all float types are float32. The reasoning is that converting float64 -> float32 would have to be done throughout the code, which gets pretty messy.
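
As a rough sketch of what such a check could look like on the user side, the helper below flags floating-point arrays that are not float32. It is not the env_checker implementation. `non_float32_fields` is a hypothetical name, and the `sample` dict only stands in for whatever an environment returns.

```python
import numpy as np


def non_float32_fields(sample):
    """Return the names of floating-point entries whose dtype is not float32.

    `sample` is either a single array-like or a dict of array-likes.
    (Hypothetical helper, not part of Stable-Baselines3.)
    """
    items = sample.items() if isinstance(sample, dict) else [("value", sample)]
    flagged = []
    for name, value in items:
        dtype = np.asarray(value).dtype
        # Only floating-point data is of interest; integers are left alone.
        if np.issubdtype(dtype, np.floating) and dtype != np.float32:
            flagged.append(name)
    return flagged


sample = {
    "observation": np.zeros(3),                # NumPy default: float64
    "action": np.zeros(2, dtype=np.float32),   # already float32
    "step_count": np.array([0]),               # integer, ignored by the check
}
print(non_float32_fields(sample))
```

Running a check like this on the values an environment produces, before training starts, would surface the dtype problem earlier than the critic's forward pass does.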

nikhilbarhate99 commented, Nov 25, 2021 (1 reaction)

Okay. In that case, include the fp32 precision specification in the documentation for custom_gym_env.
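
For a custom environment, meeting an fp32 requirement mostly means not relying on NumPy's float64 default. A minimal sketch, assuming the environment builds its outputs as NumPy arrays; `to_float32` is a hypothetical helper name, not an SB3 or gym function:

```python
import numpy as np


def to_float32(value):
    """Cast floating-point data to float32; leave other dtypes untouched.

    (Hypothetical helper for illustration.)
    """
    arr = np.asarray(value)
    if np.issubdtype(arr.dtype, np.floating):
        return arr.astype(np.float32, copy=False)
    return arr


raw_obs = np.array([0.1, 0.2, 0.3])   # float64 unless a dtype is given
obs = to_float32(raw_obs)
print(raw_obs.dtype, obs.dtype)

# Passing dtype=np.float32 at creation time avoids the cast entirely.
bounds = np.ones(3, dtype=np.float32)
```

The same idea applies to the bounds and dtype used when declaring the environment's spaces, so that the declared dtype and the returned data agree.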
