
BasePolicy does not place all sub-modules on the requested device

See original GitHub issue

Describe the bug

When BasePolicy.__init__() is supplied with a device= keyword argument, it correctly places its mlp_extractor on the given device, but does not place the other sub-modules (action_net, value_net, etc.) on the same device. This leads to exceptions when the network is used later on.

Code example

Minimal example; reproducing it requires a GPU-equipped machine:

from stable_baselines3.common import policies, utils
import gym

env = gym.make('CartPole-v1')
# device='cuda:0' only moves mlp_extractor; action_net and value_net
# are left on the CPU.
policy = policies.ActorCriticPolicy(
    env.observation_space, env.action_space,
    lr_schedule=utils.get_schedule_fn(1e-3), device='cuda:0')
policy.predict(env.reset())  # raises RuntimeError (traceback below)

Output:

RuntimeError                              Traceback (most recent call last)
<ipython-input-1-6d099de2364b> in <module>
      4 policy = policies.ActorCriticPolicy(
      5     env.observation_space, env.action_space, lr_schedule=utils.get_schedule_fn(1e-3), device='cuda:0')
----> 6 policy.predict(env.reset())

~/repos/stable-baselines3/stable_baselines3/common/policies.py in predict(self, observation, state, mask, deterministic)
    237         observation = th.as_tensor(observation).to(self.device)
    238         with th.no_grad():
--> 239             actions = self._predict(observation, deterministic=deterministic)
    240         # Convert to numpy
    241         actions = actions.cpu().numpy()

~/repos/stable-baselines3/stable_baselines3/common/policies.py in _predict(self, observation, deterministic)
    549         """
    550         latent_pi, _, latent_sde = self._get_latent(observation)
--> 551         distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
    552         return distribution.get_actions(deterministic=deterministic)
    553

~/repos/stable-baselines3/stable_baselines3/common/policies.py in _get_action_dist_from_latent(self, latent_pi, latent_sde)
    522         :return: (Distribution) Action distribution
    523         """
--> 524         mean_actions = self.action_net(latent_pi)
    525
    526         if isinstance(self.action_dist, DiagGaussianDistribution):

~/anaconda3/envs/imitation-sb3/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~/anaconda3/envs/imitation-sb3/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
     85
     86     def forward(self, input):
---> 87         return F.linear(input, self.weight, self.bias)
     88
     89     def extra_repr(self):

~/anaconda3/envs/imitation-sb3/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1608     if input.dim() == 2 and bias is not None:
   1609         # fused op is marginally faster
-> 1610         ret = torch.addmm(bias, input, weight.t())
   1611     else:
   1612         output = input.matmul(weight.t())

RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_addmm
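
The mismatch is easy to confirm by inspecting each sub-module's parameters. A quick diagnostic sketch, assuming the policy object from the example above:

# Print the device of every top-level sub-module of the policy.
for name, module in policy.named_children():
    try:
        device = next(module.parameters()).device
    except StopIteration:
        device = 'no parameters'
    print(name, device)
# On the affected version, mlp_extractor reports cuda:0 while
# action_net and value_net report cpu.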

System Info

Relevant details:

  • SB3 version: v0.8.0a3 via pip
  • Torch version: 1.5.1 via pip

Additional context

The lowest-breakage way to fix this would be to simply add .to(self.device) calls after constructing each module within BasePolicy. However, it may be more elegant to remove the device keyword argument from BasePolicy entirely. Instead, the user could be responsible for device placement by calling policy.to(device) after policy construction. This is how most modules in torch.nn work. If the device identity is needed in predict(), then it could be inferred from policy.parameters().
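
To make the first option concrete, here is a minimal sketch of the constructor-side fix. TinyPolicy and its layer sizes are hypothetical stand-ins for the real BasePolicy internals; only the device-placement pattern is the point:

import torch as th
import torch.nn as nn

class TinyPolicy(nn.Module):
    def __init__(self, obs_dim, act_dim, device='cpu'):
        super().__init__()
        self.device = th.device(device)
        # Moving each sub-module immediately after construction is the
        # "lowest-breakage" fix; currently only mlp_extractor gets it.
        self.mlp_extractor = nn.Linear(obs_dim, 64).to(self.device)
        self.action_net = nn.Linear(64, act_dim).to(self.device)
        self.value_net = nn.Linear(64, 1).to(self.device)

The second option, dropping device= and leaving placement to the caller, is what the comments below converge on.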

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

1 reaction
qxcv commented, Jul 16, 2020

> This cannot be done automatically? (we need to enable GPU automatically when the user is using algorithms and not policy alone)

To clarify, I’m suggesting that policies should be instantiated like this:

policy = SomePolicy(…).to(desired_device)

By “user”, I meant “whatever code is responsible for instantiating the policy”. I don’t have any problem with algorithms performing this conversion for actual end-users of Stable Baselines algorithms! (e.g. by inferring device from get_device(), or however it works at the moment)
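
For reference, this is the plain torch.nn idiom being proposed here (standard PyTorch behavior, nothing SB3-specific; running it as written requires a GPU-equipped machine):

import torch
from torch import nn

net = nn.Linear(4, 2)     # parameters are created on the CPU
net = net.to('cuda:0')    # a single call moves every parameter and buffer
device = next(net.parameters()).device
print(device)             # cuda:0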

0 reactions
qxcv commented, Aug 7, 2020

I have two concrete objections to the current strategy: (1) it is error-prone when someone forgets to place a sub-module on the right device (as is currently the case), and (2) the stored .device attribute falls out of sync with subsequent .to(device) calls. Both problems can be fixed with something like this:

@property
def device(self):
    """Infer the device from this module's own parameters."""
    try:
        return next(self.parameters()).device
    except StopIteration:
        # No parameters to infer from; fall back to the default
        # (get_device is stable_baselines3.common.utils.get_device).
        return get_device('auto')

This is better, but still not perfect, since it might unexpectedly place a tensor on the auto device when the caller was expecting it to stay where it is. For example, if we added this to Torch’s ReLU module (which has no parameters), then it would start silently moving all of its arguments to the GPU on GPU-equipped machines.
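
A small sketch of that failure mode (hypothetical; torch's real ReLU does nothing like this):

import torch
from torch import nn
from stable_baselines3.common.utils import get_device

class AutoDeviceReLU(nn.Module):
    # Hypothetical parameter-free module using the fallback above.
    @property
    def device(self):
        try:
            return next(self.parameters()).device
        except StopIteration:
            return get_device('auto')  # no parameters -> 'auto'

    def forward(self, x):
        # On a GPU machine, get_device('auto') is cuda, so a CPU input
        # would be silently copied to the GPU here.
        return torch.relu(x.to(self.device))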

I think the most idiomatic approach (which doesn’t have any of these problems) would be to make the calling code responsible for placing both models and tensors on the same device before doing forward-propagation. The downside is that it would require moving methods like .predict(), which do numpy-to-Torch conversion, out of the Policy class. For example, I think rlpyt solves this by having separate classes for policies, which are just Torch modules, and “agents”, which are responsible for evaluating policies. IMO that would create a better separation of concerns, but I would understand if you wanted to keep those things together.
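
A minimal sketch of that separation, with hypothetical names (this is not rlpyt's actual API):

import numpy as np
import torch
from torch import nn

class Agent:
    # The agent owns numpy<->torch conversion and device placement,
    # so the policy stays a plain nn.Module.
    def __init__(self, policy: nn.Module, device: str = 'cpu'):
        self.device = torch.device(device)
        self.policy = policy.to(self.device)

    def predict(self, observation: np.ndarray) -> np.ndarray:
        obs = torch.as_tensor(observation, dtype=torch.float32,
                              device=self.device)
        with torch.no_grad():
            action = self.policy(obs)
        return action.cpu().numpy()

Usage would then look like agent = Agent(SomePolicy(...), device='cuda:0') followed by agent.predict(obs), leaving the policy itself device-agnostic.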

Read more comments on GitHub >
