[Bug] VecNormalize fails on SAC/TD3

🐛 Bug

I find the VecNormalize wrapper unusable when training off-policy algorithms such as SAC and TD3.
I think the issue is located in _store_transition() in off_policy_algorithm.py:

...
        if self._vec_normalize_env is not None:
            new_obs_ = self._vec_normalize_env.get_original_obs()
            reward_ = self._vec_normalize_env.get_original_reward()
        else:
            # Avoid changing the original ones
            self._last_original_obs, new_obs_, reward_ = self._last_obs, new_obs, reward

        # Avoid modification by reference
        next_obs = deepcopy(new_obs_)
...

For example, get_original_obs() returns the unnormalized obs in its original, untransposed shape (1,96,96,3), whereas the replay buffer expects (1,3,96,96).
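
The shape mismatch can be reproduced with plain NumPy, independent of SB3. This is a minimal sketch that assumes a buffer allocated channel-first, as the traceback suggests; the variable names and the buffer size of 4 are illustrative only.

```python
import numpy as np

# Buffer allocated channel-first: (buffer_size, n_envs, C, H, W)
observations = np.zeros((4, 1, 3, 96, 96), dtype=np.uint8)

# Observation in the env's native channel-last layout: (n_envs, H, W, C)
original_obs = np.zeros((1, 96, 96, 3), dtype=np.uint8)

try:
    # Same kind of assignment as in buffers.py: the shapes do not match
    observations[0] = np.array(original_obs).copy()
    error_message = None
except ValueError as exc:
    error_message = str(exc)

# Moving the channel axis to the front makes the shapes agree
transposed_obs = np.transpose(original_obs, (0, 3, 1, 2))
observations[0] = transposed_obs
```

Here the first assignment raises a ValueError, while the transposed observation is stored without complaint.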

I might not fully understand the purpose of storing unnormalized obs and reward in the replay_buffer when the VecNormalize wrapper is used on purpose.

To Reproduce

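from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize

env = make_vec_env('CarRacing-v0', 1)
env = VecNormalize(env, norm_obs=False)   # same for norm_obs=True, even though image input should be scaled by default
model = SAC('CnnPolicy', env, verbose=1)    # or TD3
model.learn(total_timesteps=int(2e5))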

I get the error below. The same error also occurs with the newest version of SB3 on the master branch. The error is gone when using PPO or A2C.

Traceback (most recent call last):

  File "/workspace/repos_dev/stable-baselines3/stable_baselines3/td3/td3.py", line 214, in learn
    reset_num_timesteps=reset_num_timesteps,
  File "/workspace/repos_dev/stable-baselines3/stable_baselines3/common/off_policy_algorithm.py", line 366, in learn
    log_interval=log_interval,
  File "/workspace/repos_dev/stable-baselines3/stable_baselines3/common/off_policy_algorithm.py", line 616, in collect_rollouts
    self._store_transition(replay_buffer, buffer_actions, new_obs, rewards, dones, infos)
  File "/workspace/repos_dev/stable-baselines3/stable_baselines3/common/off_policy_algorithm.py", line 534, in _store_transition
    infos,
  File "/workspace/repos_dev/stable-baselines3/stable_baselines3/common/buffers.py", line 246, in add
    self.observations[self.pos] = np.array(obs).copy()
ValueError: could not broadcast input array from shape (1,96,96,3) into shape (1,3,96,96)

Expected behavior

I expect the VecNormalize wrapper to work with all algorithms in the ‘CarRacing’ environment; the documentation doesn’t mention any constraint on the usage of the VecNormalize wrapper.

System Info

I’m using SB3 1.4.0, gym 0.21.0 and Python 3.7.11.

Checklist

I found a related issue, but unfortunately it doesn’t solve my problem.

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
araffin commented, Mar 15, 2022

Hello,

To Reproduce

If you use a VecTransposeImage wrapper before the VecNormalize env, this solves your issue:

from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecNormalize, VecTransposeImage

env = make_vec_env("CarRacing-v0", 1)
# Make it channel first to comply with PyTorch default
env = VecTransposeImage(env)
# Do not normalize image (done in SB3 automatically)
env = VecNormalize(env, norm_obs=False)
model = SAC("CnnPolicy", env, buffer_size=100, verbose=1)
model.learn(total_timesteps=int(2e5))

“The error is gone when using PPO or A2C.”

Yes, the issue is similar to #693.

1 reaction
Miffyli commented, Mar 15, 2022

Hmm, normalizing image pixels individually might hinder performance (or not, I haven’t tested 😄), but it is definitely not something people do. For images, you should not use the VecNormalize wrapper. Images (of type uint8) are automatically normalized to [0, 1] by dividing by 255.
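
The automatic scaling described above amounts to a single division. A minimal NumPy sketch, not SB3's actual preprocessing code; the pixel values are made up for illustration:

```python
import numpy as np

# A tiny stand-in "image" holding the darkest, a middle and the brightest pixel value
image = np.array([[0, 128, 255]], dtype=np.uint8)

# Cast to float first, then divide so that every value lands in [0, 1]
scaled = image.astype(np.float32) / 255.0
```

Because the divisor is a fixed constant, no running statistics are needed, which is why a wrapper such as VecNormalize adds nothing for uint8 images.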

Which part of the docs misled you to use VecNormalize with an image environment? The doc could be updated with a note/warning that one should only use VecNormalize with non-image envs 😃

“I might not understand exactly what’s the purpose of storing unnormalized obs and reward into the replay_buffer, when VecNormalize wrapper is used on purpose.”

To sate your curiosity: the replay buffer stores the original samples so that when the VecNormalize statistics change (which they do, constantly), you can re-normalize the replay buffer samples and use them with the new normalization parameters.
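
The idea of keeping raw samples and normalizing them at sampling time can be sketched as follows. This is an illustration under assumptions, not SB3's implementation: `RunningStats` is a hypothetical helper written for this example, and the observation values are invented.

```python
import numpy as np


class RunningStats:
    """Hypothetical running mean/variance tracker (not SB3's class)."""

    def __init__(self, shape):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = 0

    def update(self, batch):
        # Merge the batch statistics into the running statistics
        batch = np.asarray(batch, dtype=np.float64)
        n = batch.shape[0]
        total = self.count + n
        delta = batch.mean(axis=0) - self.mean
        m2 = (self.var * self.count + batch.var(axis=0) * n
              + delta ** 2 * self.count * n / total)
        self.mean = self.mean + delta * n / total
        self.var = m2 / total
        self.count = total

    def normalize(self, x, eps=1e-8):
        return (x - self.mean) / np.sqrt(self.var + eps)


stats = RunningStats(shape=(2,))
raw_buffer = []  # holds raw (unnormalized) observations only

first_batch = np.array([[0.0, 10.0], [2.0, 14.0]])
stats.update(first_batch)
raw_buffer.extend(first_batch)
early = stats.normalize(np.array(raw_buffer))

second_batch = np.array([[10.0, 30.0], [12.0, 34.0]])
stats.update(second_batch)
raw_buffer.extend(second_batch)
# The old samples are re-normalized with the updated statistics
late = stats.normalize(np.array(raw_buffer))
```

Had the buffer stored the values in `early`, they would be stale after the second update; storing raw values keeps every sample consistent with the current statistics.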
