BasePolicy does not place all sub-modules on the requested device
Describe the bug
When BasePolicy.__init__() is supplied with a device= keyword argument, it correctly places its mlp_extractor on the given device, but it does not place other modules, such as action_net and value_net, on the same device. This leads to exceptions when trying to use the network later on.
Code example
Minimal example, which should work on a GPU-equipped machine:
from stable_baselines3.common import policies, utils
import gym
env = gym.make('CartPole-v1')
policy = policies.ActorCriticPolicy(
    env.observation_space, env.action_space, lr_schedule=utils.get_schedule_fn(1e-3), device='cuda:0')
policy.predict(env.reset())
Output:
RuntimeError Traceback (most recent call last)
<ipython-input-1-6d099de2364b> in <module>
4 policy = policies.ActorCriticPolicy(
5 env.observation_space, env.action_space, lr_schedule=utils.get_schedule_fn(1e-3), device='cuda:0')
----> 6 policy.predict(env.reset())
~/repos/stable-baselines3/stable_baselines3/common/policies.py in predict(self, observation, state, mask, deterministic)
237 observation = th.as_tensor(observation).to(self.device)
238 with th.no_grad():
--> 239 actions = self._predict(observation, deterministic=deterministic)
240 # Convert to numpy
241 actions = actions.cpu().numpy()
~/repos/stable-baselines3/stable_baselines3/common/policies.py in _predict(self, observation, deterministic)
549 """
550 latent_pi, _, latent_sde = self._get_latent(observation)
--> 551 distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
552 return distribution.get_actions(deterministic=deterministic)
553
~/repos/stable-baselines3/stable_baselines3/common/policies.py in _get_action_dist_from_latent(self, latent_pi, latent_sde)
522 :return: (Distribution) Action distribution
523 """
--> 524 mean_actions = self.action_net(latent_pi)
525
526 if isinstance(self.action_dist, DiagGaussianDistribution):
~/anaconda3/envs/imitation-sb3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
~/anaconda3/envs/imitation-sb3/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
85
86 def forward(self, input):
---> 87 return F.linear(input, self.weight, self.bias)
88
89 def extra_repr(self):
~/anaconda3/envs/imitation-sb3/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1608 if input.dim() == 2 and bias is not None:
1609 # fused op is marginally faster
-> 1610 ret = torch.addmm(bias, input, weight.t())
1611 else:
1612 output = input.matmul(weight.t())
RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_addmm
System Info
Relevant details:
- SB3 version: v0.8.0a3 via pip
- Torch version: 1.5.1 via pip
Additional context
The lowest-breakage way to fix this would be to simply add .to(self.device) calls after constructing each module within BasePolicy (sketched below). However, it may be more elegant to remove the device keyword argument from BasePolicy entirely. Instead, the user could be responsible for device placement by calling policy.to(device) after policy construction. This is how most modules in torch.nn work. If the device identity is needed in predict(), then it could be inferred from policy.parameters().
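As a minimal sketch of the first option (a toy module, not the real BasePolicy; the attribute names are only borrowed from the description above):

import torch as th
import torch.nn as nn

class ToyPolicy(nn.Module):
    def __init__(self, device: str = 'cpu'):
        super().__init__()
        self.device = th.device(device)
        # What BasePolicy already does for the extractor:
        self.mlp_extractor = nn.Linear(4, 64).to(self.device)
        # The missing step: place the other heads on the same device.
        self.action_net = nn.Linear(64, 2).to(self.device)
        self.value_net = nn.Linear(64, 1).to(self.device)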
Issue Analytics
- Created 3 years ago
- Comments: 7 (7 by maintainers)
Top GitHub Comments
To clarify, I’m suggesting that policies should be instantiated like this:
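Something along these lines (a sketch: the constructor call mirrors the repro above but without a device argument, and the caller does the placement):

from stable_baselines3.common import policies, utils
import gym

env = gym.make('CartPole-v1')
policy = policies.ActorCriticPolicy(
    env.observation_space, env.action_space, lr_schedule=utils.get_schedule_fn(1e-3))
# Device placement happens explicitly, as with any other torch.nn module:
policy = policy.to('cuda:0')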
By “user”, I meant “whatever code is responsible for instantiating the policy”. I don’t have any problem with algorithms performing this conversion for actual end-users of Stable Baselines algorithms! (e.g. by inferring the device from get_device(), or however it works at the moment.)

I have two concrete objections to the current strategy: (1) it is prone to errors if someone forgets to place some sub-modules on the right device (as is currently the case), and (2) it doesn’t keep .device in sync with .to(device) calls. It’s possible to fix both of those problems with something like this:
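Sketch (assuming SB3’s get_device helper; the class itself is made up):

import torch as th
import torch.nn as nn
from stable_baselines3.common.utils import get_device

class SyncedDevicePolicy(nn.Module):
    def __init__(self):
        super().__init__()
        self.action_net = nn.Linear(64, 2)

    @property
    def device(self) -> th.device:
        # Derived from the parameters, so it stays in sync with .to(device).
        for param in self.parameters():
            return param.device
        # No parameters: fall back to the "auto" device (GPU when available).
        return get_device('auto')

    def predict(self, observation):
        obs = th.as_tensor(observation, dtype=th.float32).to(self.device)
        with th.no_grad():
            return self.action_net(obs).cpu().numpy()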
This is better, but still not perfect, since it might unexpectedly place a tensor on the auto device when the caller was expecting it to stay where it is. For example, if we added this to Torch’s ReLU module (which has no parameters), then it would start silently moving all of its arguments to the GPU on GPU-equipped machines.

I think the most idiomatic approach (which doesn’t have any of these problems) would be to make the calling code responsible for placing both models and tensors on the same device before doing forward-propagation. The downside is that it would require moving methods like .predict(), which do numpy-to-Torch conversion, out of the Policy class. For example, I think rlpyt solves this by having separate classes for policies, which are just Torch modules, and “agents”, which are responsible for evaluating policies (roughly as sketched below). IMO that would create a better separation of concerns, but I would understand if you wanted to keep those things together.
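A rough sketch of that split (not rlpyt’s actual API; the class names are invented for illustration):

import numpy as np
import torch as th
import torch.nn as nn

class PolicyNet(nn.Module):
    # A plain Torch module: no numpy conversion, no device bookkeeping.
    def __init__(self):
        super().__init__()
        self.action_net = nn.Linear(4, 2)

    def forward(self, obs: th.Tensor) -> th.Tensor:
        return self.action_net(obs)

class Agent:
    # Owns the device and does the numpy <-> Torch conversion around the policy.
    def __init__(self, policy: PolicyNet, device: str = 'cpu'):
        self.device = th.device(device)
        self.policy = policy.to(self.device)

    def predict(self, observation: np.ndarray) -> np.ndarray:
        obs = th.as_tensor(observation, dtype=th.float32).to(self.device)
        with th.no_grad():
            actions = self.policy(obs)
        return actions.cpu().numpy()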