
Unexpected behaviour when predicting with batch norm and dropout


Important Note: We do not provide technical support or consulting, and we do not answer personal questions by email. Please post your question on the RL Discord, Reddit or Stack Overflow instead.

If your issue is related to a custom gym environment, please use the custom gym env template.

🐛 Bug

A neural network with dropout or batch normalisation layers should behave differently when training and when predicting. In PyTorch, we can switch between train and evaluation mode using the methods model.train() and model.eval(), where model is an instance of torch.nn.Module.

Stable-Baselines3 does not appear to switch to evaluation mode when predicting. For example, if you call the predict method of the DQN class and follow it through, nowhere in the code is the policy network switched to evaluation mode.

This causes problems if you are using a custom policy network with either dropout or batch normalisation. If you are using batch normalisation and you try to run predict on a single observation, it will throw an error, because the model is in train mode and batch normalisation then requires more than one value per channel to compute the batch statistics. If you are using dropout or batch normalisation and you run predict on multiple observations, there is no error, but the behaviour is wrong: for dropout, the network activations will be dropped at random instead of being passed through unchanged, and for batch normalisation, the observations will be used to update the layer’s running statistics instead of simply being normalised by the existing statistics.
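The difference is easy to demonstrate in plain PyTorch, independent of Stable-Baselines3 (a minimal sketch):

import torch
from torch import nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(4, 8)

dropout.train()        # training mode: activations are zeroed at random
print(dropout(x))      # about half the entries are 0, the rest are scaled to 1 / (1 - p) = 2
dropout.eval()         # evaluation mode: dropout is the identity
print(dropout(x))      # exactly x

bn = nn.BatchNorm1d(num_features=8)
bn.train()
bn(torch.randn(2, 8))  # training mode: needs more than one value per channel
bn.eval()
bn(torch.randn(1, 8))  # evaluation mode: a single observation is fine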

To Reproduce

Steps to reproduce the behavior when predicting with a single observation (note that when predicting with multiple observations there is no error, but the behaviour will not be correct):

from stable_baselines3 import DQN
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from torch.nn import BatchNorm1d


class BatchNorm(BaseFeaturesExtractor):
    
    def __init__(self, observation_space):
        features_dim = observation_space.shape[0]
        super(BatchNorm, self).__init__(observation_space, features_dim)
        self.batch_norm = BatchNorm1d(num_features=features_dim)

    def forward(self, observations):
        return self.batch_norm(observations)


dqn = DQN(
    policy='MlpPolicy',
    env='LunarLander-v2',
    policy_kwargs=dict(features_extractor_class=BatchNorm),
)


obs = dqn.env.reset()
dqn.predict(obs)

This results in the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_26633/1972183594.py in <module>
     23 
     24 obs = dqn.env.reset()
---> 25 dqn.predict(obs)

~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/dqn.py in predict(self, observation, state, mask, deterministic)
    209                 action = np.array(self.action_space.sample())
    210         else:
--> 211             action, state = self.policy.predict(observation, state, mask, deterministic)
    212         return action, state
    213 

~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/common/policies.py in predict(self, observation, state, mask, deterministic)
    274         observation = th.as_tensor(observation).to(self.device)
    275         with th.no_grad():
--> 276             actions = self._predict(observation, deterministic=deterministic)
    277         # Convert to numpy
    278         actions = actions.cpu().numpy()

~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/policies.py in _predict(self, obs, deterministic)
    167 
    168     def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
--> 169         return self.q_net._predict(obs, deterministic=deterministic)
    170 
    171     def _get_constructor_parameters(self) -> Dict[str, Any]:

~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/policies.py in _predict(self, observation, deterministic)
     61 
     62     def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
---> 63         q_values = self.forward(observation)
     64         # Greedy action
     65         action = q_values.argmax(dim=1).reshape(-1)

~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/policies.py in forward(self, obs)
     58         :return: The estimated Q-Value for each action.
     59         """
---> 60         return self.q_net(self.extract_features(obs))
     61 
     62     def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:

~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/common/policies.py in extract_features(self, obs)
    117         assert self.features_extractor is not None, "No features extractor was set"
    118         preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
--> 119         return self.features_extractor(preprocessed_obs)
    120 
    121     def _get_constructor_parameters(self) -> Dict[str, Any]:

~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

/tmp/ipykernel_26633/1972183594.py in forward(self, observations)
     12 
     13     def forward(self, observations):
---> 14         return self.batch_norm(observations)
     15 
     16 

~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1050                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051             return forward_call(*input, **kwargs)
   1052         # Do not call functions when jit is used
   1053         full_backward_hooks, non_full_backward_hooks = [], []

~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py in forward(self, input)
    165         used for normalization (i.e. in eval mode when buffers are not None).
    166         """
--> 167         return F.batch_norm(
    168             input,
    169             # If buffers are not to be tracked, ensure that they won't be updated

~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/functional.py in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
   2277         )
   2278     if training:
-> 2279         _verify_batch_size(input.size())
   2280 
   2281     return torch.batch_norm(

~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/functional.py in _verify_batch_size(size)
   2245         size_prods *= size[i + 2]
   2246     if size_prods == 1:
-> 2247         raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
   2248 
   2249 

ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 8])

Expected behavior

I would expect the policy network to be put into evaluation mode when the predict method is called. In this case, the observation would simply be scaled using the batch normalisation statistics computed on the training data and there would be no error.
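
In the meantime, a workaround is to switch the policy to evaluation mode by hand. This is a sketch that relies only on SB3 policies being torch.nn.Module subclasses; it is not a dedicated SB3 API:

dqn.policy.eval()                 # dropout/batch norm layers now use evaluation behaviour
action, state = dqn.predict(obs)  # succeeds, even with a single observation
dqn.policy.train()                # restore training mode before calling learn() again

This also fixes the silent case: with the policy in evaluation mode, a batch of observations is normalised by the stored running statistics instead of updating them.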

System Info

Describe the characteristic of your environment:

  • Describe how the library was installed (pip, docker, source, …): pip
  • GPU models and configuration: none
  • Python version: 3.8.5
  • PyTorch version: 1.9.0
  • Gym version: 0.17.3
  • Stable-Baselines3 version: 1.0


Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

1 reaction
araffin commented, Aug 12, 2021

I will try to write a longer answer later, but in short: yes, I knew about that (though most RL agents do not use batch norm or dropout), and it would make sense to call .train() in train() and .eval() in predict() and during rollout collection (batch norm would crash there anyway with only one env).

1 reaction
ayeright commented, Aug 12, 2021

Good point, I am surprised this has not come up earlier.

Am I right in understanding that all this would need is calling .eval() on the PyTorch modules during predict, and then .train() during learn? Off the top of my head this should not interfere with anything (it might also speed things up, if gradients are not computed in eval() mode), but I wonder if there would be downsides. @araffin, comments?

There is also the fact that SB3 does not ship with batch norm or dropout by default (they are rarely used in RL algorithms, outside some generalisation experiments such as those on ProcGen), so the benefit of adding this support is debatable: it does not fix anything built into SB3, it only supports user modifications.

Adding this support is important, in my opinion. I believe that at the minute only batch norm and dropout are affected, but PyTorch may later add other modules whose behaviour also depends on whether the model is in train or evaluation mode, and those modules may turn out to be important for RL.

There may also be people out there using dropout in their RL models without knowing that they are getting the wrong behaviour, because with dropout you don’t get an error. I simply assumed the model would be put into evaluation mode automatically; making predictions in train mode doesn’t seem right.

I don’t think this would be very difficult to implement. It’s just a case of calling eval() and train() in the right places.
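
To make this concrete, here is a hypothetical user-side helper that wraps predict in the way proposed above. It is not part of Stable-Baselines3 and only assumes that model.policy is a torch.nn.Module; the library-side fix would make the equivalent eval()/train() calls inside predict and the training loop:

def predict_in_eval_mode(model, observation, **kwargs):
    # Remember the current mode so we can restore it afterwards.
    was_training = model.policy.training
    model.policy.eval()
    try:
        # Dropout is now a no-op and batch norm uses its running statistics.
        return model.predict(observation, **kwargs)
    finally:
        # Put the policy back the way we found it.
        model.policy.train(was_training)

For example, predict_in_eval_mode(dqn, obs) runs the snippet from the bug report without raising a ValueError.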
