[Bug] PPO fails with `batch_size=1`

🐛 Bug

After tracking NaNs for a while, I found that they originate at line 196 of `ppo/ppo.py`, in the call to `advantages.std()`. This is because `rollout_data` has length 1: `std()` is unbiased by default and divides by n - 1 = 0, so it returns NaN.
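The failure can be reproduced without stable-baselines3. The sketch below assumes the per-minibatch normalization has the form `(advantages - advantages.mean()) / (advantages.std() + 1e-8)`; that is our reading of the PPO training loop, not a verbatim copy, and `normalize_advantages` is a hypothetical helper name.

```python
import math

import torch


def normalize_advantages(advantages: torch.Tensor) -> torch.Tensor:
    # Assumed form of PPO's per-minibatch advantage normalization.
    return (advantages - advantages.mean()) / (advantages.std() + 1e-8)


# A minibatch holding one sample, as produced by batch_size=1.
single = torch.tensor([0.7])
# Tensor.std() is unbiased by default: it divides by n - 1 = 0, giving NaN.
print(math.isnan(single.std().item()))  # True
print(normalize_advantages(single))     # every entry is NaN

# With two or more samples the statistic is well defined.
pair = torch.tensor([0.7, -0.3])
print(normalize_advantages(pair))
```

A single NaN advantage is enough to make the whole minibatch loss NaN.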

To Reproduce

Minimal example:

```python
import gym
import stable_baselines3.common.utils
from stable_baselines3 import PPO

if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    env.seed(0)
    stable_baselines3.common.utils.set_random_seed(0)

    model = PPO("MlpPolicy", env, verbose=1, n_steps=20, batch_size=1, n_epochs=1)
    model.learn(total_timesteps=1)  # 000)
    obs = env.reset()
    for _ in range(1):
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
```

```
Traceback (most recent call last):
  File "d:/Onebox/as/code/exploration/distort-and-recover/environment.py", line 194, in <module>
    model.learn(total_timesteps=1)  # 000)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\ppo\ppo.py", line 281, in learn
    return super(PPO, self).learn(
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 249, in learn
    self.train()
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\ppo\ppo.py", line 192, in train
    values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\policies.py", line 610, in evaluate_actions
    distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\policies.py", line 575, in _get_action_dist_from_latent
    return self.action_dist.proba_distribution(action_logits=mean_actions)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\distributions.py", line 275, in proba_distribution
    self.distribution = Categorical(logits=action_logits)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\torch\distributions\categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\torch\distributions\distribution.py", line 53, in __init__
    raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter logits has invalid values
```
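Our reading of the traceback is that the error surfaces later than the NaN itself: once the loss is NaN, the optimizer step writes NaN into the policy weights, the next forward pass produces NaN logits, and `torch.distributions.Categorical` rejects them. The last frame can be reproduced in isolation. The all-NaN `nan_logits` tensor is a hypothetical stand-in for the output of the corrupted network, and `validate_args=True` is passed explicitly so the check runs regardless of the PyTorch version's default.

```python
import torch
from torch.distributions import Categorical

# Hypothetical stand-in for the policy output once NaN has reached the weights:
# one observation, two CartPole actions, all logits NaN.
nan_logits = torch.full((1, 2), float("nan"))

raised = False
try:
    Categorical(logits=nan_logits, validate_args=True)
except ValueError as err:
    # Same exception type as the last line of the traceback; the exact message
    # wording differs between PyTorch versions.
    raised = True
    print("ValueError:", err)
```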

Expected behavior

Fail with an explicit error (e.g. an assertion), or implement a workaround.
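One way to get the explicit error is to validate the configuration before training starts. The sketch below is hypothetical (`smallest_minibatch` and `check_ppo_batch_size` are not stable-baselines3 functions) and assumes the rollout buffer of `n_steps * n_envs` samples is split into consecutive minibatches of `batch_size`, with a shorter final one when the sizes do not divide evenly. Under that assumption, a single-sample minibatch can also appear with `batch_size > 1`, for example `n_steps=21, batch_size=4`.

```python
def smallest_minibatch(buffer_size: int, batch_size: int) -> int:
    """Size of the last minibatch when buffer_size samples are split into
    consecutive chunks of batch_size (the last chunk may be shorter)."""
    remainder = buffer_size % batch_size
    return remainder if remainder else batch_size


def check_ppo_batch_size(batch_size: int, n_steps: int, n_envs: int = 1) -> None:
    """Hypothetical pre-training check: refuse settings that would create a
    minibatch whose advantages cannot be normalized (fewer than 2 samples)."""
    buffer_size = n_steps * n_envs
    smallest = smallest_minibatch(buffer_size, batch_size)
    if smallest < 2:
        raise ValueError(
            f"batch_size={batch_size} with n_steps * n_envs={buffer_size} "
            f"produces a minibatch of {smallest} sample; "
            "its advantages cannot be normalized (std of one sample is NaN)"
        )


check_ppo_batch_size(batch_size=64, n_steps=2048)  # passes: every minibatch has 64 samples
```

With the settings from the reproduction (`batch_size=1, n_steps=20`), `check_ppo_batch_size` raises before any rollout is collected instead of failing deep inside `Categorical`.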

### System Info

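Output of `conda list` for the environment at C:\Users\myuser\Miniconda3.9\envs\sandbox (columns: Name, Version, Build, Channel). The versions most relevant to this bug are python 3.8.8, pytorch 1.8.1 (CPU), gym 0.18.0 and stable-baselines3 1.0.
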
appdirs                   1.4.4                      py_0
astroid                   2.5              py38haa95532_1
attrs                     20.3.0             pyhd3eb1b0_0
backcall                  0.2.0              pyhd3eb1b0_0
black                     19.10b0                    py_0
blas                      1.0                         mkl
blosc                     1.21.0               h19a0ad4_0
boost-cpp                 1.74.0               h54f0996_1    conda-forge
brotli                    1.0.9                ha925a31_2
bzip2                     1.0.8                he774522_0
ca-certificates           2020.12.5            h5b45459_0    conda-forge
certifi                   2020.12.5        py38haa244fe_1    conda-forge
charls                    2.2.0                h6c2663c_0
click                     7.1.2              pyhd3eb1b0_0
cloudpickle               1.6.0                      py_0
colorama                  0.4.4              pyhd3eb1b0_0
colorio                   0.7.3                    pypi_0    pypi
colour-science            0.3.16             pyh44b312d_1    conda-forge
cpuonly                   1.0                           0    pytorch
curl                      7.68.0               h4496350_0    conda-forge
cycler                    0.10.0                   py38_0
cytoolz                   0.11.0           py38he774522_0
dask-core                 2021.4.0           pyhd3eb1b0_0
decorator                 5.0.6              pyhd3eb1b0_0
exifread                  2.1.2                      py_1    conda-forge
expat                     2.3.0                h39d44d4_0    conda-forge
ffmpeg                    4.3.1                ha925a31_0    conda-forge
freetype                  2.10.4               hd328e21_0
freexl                    1.0.6                ha8e266a_0    conda-forge
fsspec                    0.9.0              pyhd3eb1b0_0
future                    0.18.2           py38haa244fe_3    conda-forge
gdal                      2.3.3            py38hdf43c64_0
geos                      3.7.1                h33f27b4_0
giflib                    5.2.1                h62dcd97_0
gym                       0.18.0           py38h43734a8_1    conda-forge
hdf4                      4.2.13               h712560f_2
hdf5                      1.10.4               h7ebc959_0
icc_rt                    2019.0.0             h0cc432a_1
icu                       58.2                 ha925a31_3
imagecodecs               2021.3.31        py38h5da4933_0
imageio                   2.9.0              pyhd3eb1b0_0
intel-openmp              2021.2.0           haa95532_616
ipykernel                 5.3.4            py38h5ca1d4c_0
ipython                   7.22.0           py38hd4e2768_0
ipython_genutils          0.2.0              pyhd3eb1b0_1
isort                     5.8.0              pyhd3eb1b0_0
jedi                      0.17.0                   py38_0
joblib                    1.0.1              pyhd3eb1b0_0
jpeg                      9b                   hb83a4c4_2
jupyter_client            6.1.12             pyhd3eb1b0_0
jupyter_core              4.7.1            py38haa95532_0
kealib                    1.4.7                h07cbb95_6
kiwisolver                1.3.1            py38hd77b12b_0
krb5                      1.16.4               hdd46e55_0    conda-forge
lazy-object-proxy         1.6.0            py38h2bbff1b_0
lcms2                     2.12                 h83e58a3_0
lerc                      2.2.1                hd77b12b_0
libaec                    1.0.4                h33f27b4_1
libcurl                   7.68.0               h4496350_0    conda-forge
libdeflate                1.7                  h2bbff1b_5
libgdal                   2.3.3                h10f50ba_0
libiconv                  1.16                 he774522_0    conda-forge
libkml                    1.3.0             h9859afa_1013    conda-forge
libnetcdf                 4.6.1                h411e497_2
libopencv                 4.0.1                hbb9e17c_0
libpng                    1.6.37               h2a8f88b_0
libpq                     11.5                 hb0bdaea_2    conda-forge
libsodium                 1.0.18               h62dcd97_0
libspatialite             4.3.0a            h6a0152f_1026    conda-forge
libssh2                   1.9.0                h680486a_6    conda-forge
libtiff                   4.2.0                hd0e1b90_0
libuv                     1.40.0               he774522_0
libxml2                   2.9.10               hf5bbc77_4    conda-forge
libzopfli                 1.0.3                ha925a31_0
llvmlite                  0.36.0           py38h34b8924_4
locket                    0.2.1            py38haa95532_1
lz4-c                     1.9.3                h2bbff1b_0
matplotlib                3.3.4            py38haa95532_0
matplotlib-base           3.3.4            py38h49ac443_0
mccabe                    0.6.1                    py38_1
meshio                    4.4.3                    pypi_0    pypi
meshzoo                   0.7.3                    pypi_0    pypi
mkl                       2021.2.0           haa95532_296
mkl-service               2.3.0            py38h2bbff1b_1
mkl_fft                   1.3.0            py38h277e83a_2
mkl_random                1.2.1            py38hf11a4ad_2
mypy_extensions           0.4.3                    py38_0
networkx                  2.5                        py_0
ninja                     1.10.2               h6d14046_1
numba                     0.53.1           py38hf11a4ad_0
numpy                     1.20.1           py38h34a8a5c_0
numpy-base                1.20.1           py38haf7ebc8_0
olefile                   0.46                       py_0
opencv                    4.0.1            py38h2a7c758_0
openjpeg                  2.3.0                h5ec785f_1
openssl                   1.1.1k               h8ffe710_0    conda-forge
pandas                    1.2.4                    pypi_0    pypi
parso                     0.8.2              pyhd3eb1b0_0
partd                     1.2.0              pyhd3eb1b0_0
pathspec                  0.7.0                      py_0
pcre                      8.44                 ha925a31_0    conda-forge
pickleshare               0.7.5           pyhd3eb1b0_1003
piexif                    1.1.3                      py_2    conda-forge
pillow                    8.2.0            py38h4fa10fc_0
pip                       21.0.1           py38haa95532_0
proj4                     5.2.0             h6538335_1006    conda-forge
prompt-toolkit            3.0.17             pyh06a4308_0
py-opencv                 4.0.1            py38he44ac1e_0
pyglet                    1.5.16           py38haa244fe_0    conda-forge
pygments                  2.8.1              pyhd3eb1b0_0
pygmsh                    7.1.9                    pypi_0    pypi
pylibtiff                 0.4.2            py38hdbb2d2f_3    conda-forge
pylint                    2.7.4            py38haa95532_1
pyparsing                 2.4.7              pyhd3eb1b0_0
pyqt                      5.9.2            py38ha925a31_4
python                    3.8.8                hdbf39b2_5
python-dateutil           2.8.1              pyhd3eb1b0_0
python_abi                3.8                      1_cp38    conda-forge
pytorch                   1.8.1               py3.8_cpu_0  [cpuonly]  pytorch
pytz                      2021.1                   pypi_0    pypi
pyvista                   0.30.0                   pypi_0    pypi
pywavelets                1.1.1            py38he774522_2
pywin32                   227              py38he774522_1
pyyaml                    5.4.1            py38h2bbff1b_1
pyzmq                     20.0.0           py38hd77b12b_1
qt                        5.9.7            vc14h73c81de_0
rawpy                     0.16.0                   pypi_0    pypi
regex                     2021.4.4         py38h2bbff1b_0
scikit-image              0.18.1           py38hf11a4ad_0
scikit-learn              0.24.1           py38hf11a4ad_0
scipy                     1.6.2            py38h66253e8_1
scooby                    0.5.7                    pypi_0    pypi
setuptools                52.0.0           py38haa95532_0
sip                       4.19.13          py38ha925a31_0
six                       1.15.0           py38haa95532_0
snappy                    1.1.8                h33f27b4_0
sqlite                    3.35.4               h2bbff1b_0
stable-baselines3         1.0                      pypi_0    pypi
tbb                       2020.3               h74a9793_0
threadpoolctl             2.1.0              pyh5ca1d4c_0
tifffile                  2021.4.8           pyhd3eb1b0_2
tk                        8.6.10               he774522_0
toml                      0.10.2             pyhd3eb1b0_0
toolz                     0.11.1             pyhd3eb1b0_0
torchaudio                0.8.1                      py38    pytorch
torchvision               0.9.1                  py38_cpu  [cpuonly]  pytorch
tornado                   6.1              py38h2bbff1b_0
tqdm                      4.59.0             pyhd3eb1b0_1
traitlets                 5.0.5              pyhd3eb1b0_0
transforms3d              0.3.1                    pypi_0    pypi
typed-ast                 1.4.2            py38h2bbff1b_1
typing_extensions         3.7.4.3            pyha847dfd_0
vc                        14.2                 h21ff451_1
vs2015_runtime            14.27.29016          h5e58377_2
vtk                       9.0.1                    pypi_0    pypi
wcwidth                   0.2.5                      py_0
wheel                     0.36.2             pyhd3eb1b0_0
wincertstore              0.2                      py38_0
wrapt                     1.12.1           py38he774522_1
xerces-c                  3.2.3                h0e60522_2    conda-forge
xz                        5.2.5                h62dcd97_0
yaml                      0.2.5                he774522_0
zeromq                    4.3.3                ha925a31_3
zfp                       0.5.5                hd77b12b_6
zlib                      1.2.11               h62dcd97_4
zstd                      1.4.5                h04227a9_0

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

Miffyli commented, May 18, 2021

For the sake of ease of use and a bit of hand-holding, I’d say we throw an error when batch size == 1 (the advantage cannot be normalized, and the gradient is very noisy). I do not see when this would be a useful or harmless option.

araffin commented, May 18, 2021

From @externalsupplierstaff: “I see. Why not `thing.std(unbiased=len(thing) > 1)`, or a warning, or something? This is a silent and hard-to-track bug; it’s not nice that people keep stumbling upon it.”

Any thoughts on that, @Miffyli? I’m hesitating between three options:

  • throw an error when batch_size=1
  • allow to deactivate advantage normalization
  • use the biased estimate (`unbiased=False`) when batch_size == 1
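The three options can be sketched side by side as small helpers. All function names here are hypothetical and none of them is the stable-baselines3 API; the normalization formula `(adv - mean) / (std + 1e-8)` is assumed to match the one PPO uses.

```python
import torch

EPS = 1e-8


def normalize_or_raise(advantages: torch.Tensor) -> torch.Tensor:
    # Option 1: fail loudly instead of silently producing NaN.
    if advantages.numel() < 2:
        raise ValueError("cannot normalize the advantages of a single-sample minibatch")
    return (advantages - advantages.mean()) / (advantages.std() + EPS)


def maybe_normalize(advantages: torch.Tensor, normalize_advantage: bool = True) -> torch.Tensor:
    # Option 2: let the user switch normalization off; the advantages are then
    # used unchanged (with normalization on, one sample still gives NaN).
    if not normalize_advantage:
        return advantages
    return (advantages - advantages.mean()) / (advantages.std() + EPS)


def normalize_biased_fallback(advantages: torch.Tensor) -> torch.Tensor:
    # Option 3: unbiased std when n > 1, biased (population) std when n == 1.
    # The biased std of one sample is 0, so the normalized advantage is 0, not NaN.
    std = advantages.std(unbiased=advantages.numel() > 1)
    return (advantages - advantages.mean()) / (std + EPS)
```

Note that option 3 turns every single-sample advantage into exactly 0, so that minibatch contributes no policy-gradient signal (only the value and entropy terms remain), which is in line with Miffyli's point that `batch_size=1` is not a useful setting.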

Top Results From Across the Web

[Bug] Silent NaNs in PPO Loss Calculation if n_steps=1 and ...
@Dylan-Kerler There should be an error if n_envs * n_steps = 1 , and a warning if total number of samples is not...
Agent trains great with PPO but terrible with SAC - Reddit
Using PPO works fantastically, but my SAC agent doesnt really improve after 2M steps. Does anyone see an obvious error in my hyperparameters ......
tianshou.policy — Tianshou 0.4.10 documentation
Categorical to calculate the log_prob, please be careful about the shape: Categorical distribution gives “[batch_size]” shape while Normal distribution gives “[ ...
Episode Length and train_batch_size compatibility with RLLib ...
I have used the default PPO parameters from RLLib. ... ERROR trial_runner.py:958 -- Trial trial-db0c7_00000: Error processing event.
Playing CartPole with the Actor-Critic method | TensorFlow Core
The problem is considered "solved" when the average total reward for the episode reaches 195 over ... Convert state into a batched tensor...
