Unexpected behaviour when predicting with batch norm and dropout
🐛 Bug
A neural network with dropout or batch normalisation layers should behave differently when training and when predicting. In PyTorch, we can switch between train and evaluation mode with the methods model.train() and model.eval(), where model is an instance of torch.nn.Module.
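As a quick refresher, here is a standalone PyTorch sketch (not SB3 code) of what the two modes do:

from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.BatchNorm1d(8))

model.train()          # training mode: dropout active, batch norm uses and updates batch statistics
print(model.training)  # True

model.eval()           # evaluation mode: dropout disabled, batch norm uses its running statistics
print(model.training)  # False; the flag propagates to every submodule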
Stable-Baselines3 does not appear to switch to evaluation mode when predicting. For example, if you call the predict method of the DQN class and follow it through, there is nowhere in the code that switches the policy network to evaluation mode.
This causes problems if you are using a custom policy network with either dropout or batch normalisation. If you are using batch normalisation and you run predict on a single observation, it throws an error, because in train mode batch normalisation needs more than one data point per channel to compute batch statistics. If you run predict on multiple observations, there is no error, but the behaviour is still wrong. Dropout will zero activations at random instead of being disabled, so predictions become non-deterministic. Batch normalisation will normalise each batch with its own statistics and update the layer's running statistics, instead of simply scaling the observations by the statistics learned during training.
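Both failure modes are easy to demonstrate in plain PyTorch, independent of SB3 (a minimal sketch):

import torch
from torch import nn

bn = nn.BatchNorm1d(num_features=8)
bn.train()
# bn(torch.randn(1, 8))  # raises ValueError: a single sample has no batch statistics
bn.eval()
bn(torch.randn(1, 8))    # fine: normalised with the stored running statistics

dropout = nn.Dropout(p=0.5)
x = torch.randn(32, 8)
dropout.train()
print(dropout(x).equal(dropout(x)))  # False: activations are zeroed at random
dropout.eval()
print(dropout(x).equal(dropout(x)))  # True: dropout is the identity in eval mode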
To Reproduce
Steps to reproduce the behaviour when predicting with a single observation (note that predicting with multiple observations raises no error, but the behaviour is still incorrect):
from stable_baselines3 import DQN
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from torch.nn import BatchNorm1d


class BatchNorm(BaseFeaturesExtractor):
    def __init__(self, observation_space):
        features_dim = observation_space.shape[0]
        super(BatchNorm, self).__init__(observation_space, features_dim)
        self.batch_norm = BatchNorm1d(num_features=features_dim)

    def forward(self, observations):
        return self.batch_norm(observations)


dqn = DQN(
    policy='MlpPolicy',
    env='LunarLander-v2',
    policy_kwargs=dict(features_extractor_class=BatchNorm),
)

obs = dqn.env.reset()
dqn.predict(obs)
This results in the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_26633/1972183594.py in <module>
23
24 obs = dqn.env.reset()
---> 25 dqn.predict(obs)
~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/dqn.py in predict(self, observation, state, mask, deterministic)
209 action = np.array(self.action_space.sample())
210 else:
--> 211 action, state = self.policy.predict(observation, state, mask, deterministic)
212 return action, state
213
~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/common/policies.py in predict(self, observation, state, mask, deterministic)
274 observation = th.as_tensor(observation).to(self.device)
275 with th.no_grad():
--> 276 actions = self._predict(observation, deterministic=deterministic)
277 # Convert to numpy
278 actions = actions.cpu().numpy()
~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/policies.py in _predict(self, obs, deterministic)
167
168 def _predict(self, obs: th.Tensor, deterministic: bool = True) -> th.Tensor:
--> 169 return self.q_net._predict(obs, deterministic=deterministic)
170
171 def _get_constructor_parameters(self) -> Dict[str, Any]:
~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/policies.py in _predict(self, observation, deterministic)
61
62 def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
---> 63 q_values = self.forward(observation)
64 # Greedy action
65 action = q_values.argmax(dim=1).reshape(-1)
~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/dqn/policies.py in forward(self, obs)
58 :return: The estimated Q-Value for each action.
59 """
---> 60 return self.q_net(self.extract_features(obs))
61
62 def _predict(self, observation: th.Tensor, deterministic: bool = True) -> th.Tensor:
~/python-environments/crmrlplayground/lib/python3.8/site-packages/stable_baselines3/common/policies.py in extract_features(self, obs)
117 assert self.features_extractor is not None, "No features extractor was set"
118 preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images)
--> 119 return self.features_extractor(preprocessed_obs)
120
121 def _get_constructor_parameters(self) -> Dict[str, Any]:
~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_26633/1972183594.py in forward(self, observations)
12
13 def forward(self, observations):
---> 14 return self.batch_norm(observations)
15
16
~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py in forward(self, input)
165 used for normalization (i.e. in eval mode when buffers are not None).
166 """
--> 167 return F.batch_norm(
168 input,
169 # If buffers are not to be tracked, ensure that they won't be updated
~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/functional.py in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
2277 )
2278 if training:
-> 2279 _verify_batch_size(input.size())
2280
2281 return torch.batch_norm(
~/python-environments/crmrlplayground/lib/python3.8/site-packages/torch/nn/functional.py in _verify_batch_size(size)
2245 size_prods *= size[i + 2]
2246 if size_prods == 1:
-> 2247 raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
2248
2249
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 8])
Expected behavior
I would expect the policy network to be put into evaluation mode when the predict method is called. In this case, the observation would simply be scaled using the batch normalisation statistics computed on the training data, and there would be no error.
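In the meantime, one workaround (my own sketch, relying only on the fact that SB3 policies are torch.nn.Module subclasses) is to switch the policy to evaluation mode by hand around calls to predict:

dqn.policy.eval()             # propagates training=False down to the features extractor
action, _ = dqn.predict(obs)  # no longer raises
dqn.policy.train()            # restore training mode before calling learn() again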
System Info
Describe the characteristic of your environment:
- Describe how the library was installed (pip, docker, source, …): pip
- GPU models and configuration: none
- Python version: 3.8.5
- PyTorch version: 1.9.0
- Gym version: 0.17.3
- Stable-Baselines3 version: 1.0
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
Top GitHub Comments
I will try to write a longer answer later, but in short, yes, I knew about that (most RL agents do not use BN or dropout), and it would make sense to call .train() in the train and .eval() in the predict + collect rollout (as BN would crash there anyway with only one env).

Adding this support is important, in my opinion. I believe that at the minute this will only affect batch norm and dropout, but other modules may be added to PyTorch in future whose behaviour also depends on whether the model is in train or evaluation mode, and those modules may be important for RL.
There may be people out there who are using dropout in their RL models and do not know that they are getting the wrong behaviour, because in the case of dropout you don’t get an error. I just assumed that the model would automatically be put into evaluation mode. Making predictions while in train mode doesn’t seem right.
I don’t think this would be very difficult to implement. It’s just a case of calling eval() and train() in the right places.
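To illustrate the placement described above, here is a rough, hypothetical sketch (not the actual SB3 source) of where those calls could go:

class DQN(OffPolicyAlgorithm):  # hypothetical sketch, not the real class body
    def predict(self, observation, state=None, mask=None, deterministic=False):
        self.policy.eval()   # dropout disabled, batch norm frozen to running statistics
        return self.policy.predict(observation, state, mask, deterministic)

    def train(self, gradient_steps, batch_size=100):
        self.policy.train()  # dropout and batch-norm statistics updates re-enabled
        # ... existing gradient-update logic unchanged ...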