[Bug] PPO fails with `batch_size=1`
🐛 Bug
After tracking NaNs for a while, I found that they originate at line 196 of ppo/ppo.py, in the call to advantages.std().
This is because rollout_data has length 1, so the standard deviation cannot be computed and comes out as NaN.
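A minimal sketch of the failure, independent of PPO. It assumes the advantages are normalized as `(x - mean) / (std + eps)`; the `1e-8` epsilon and the variable names are illustrative assumptions, not taken from the report.

```python
import torch

# A mini-batch holding a single advantage, as happens with batch_size=1.
advantages = torch.tensor([0.5])

# The default (unbiased) estimator divides by n - 1 = 0, so the result is NaN.
std = advantages.std()

# Normalizing with a NaN std turns the advantage itself into NaN.
# The 1e-8 epsilon is an assumed value for illustration.
normalized = (advantages - advantages.mean()) / (std + 1e-8)

print(torch.isnan(std).item(), torch.isnan(normalized).any().item())
```

Once the normalized advantage is NaN, it propagates through the loss and the gradient step, which would explain why the error only surfaces later as invalid logits.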
To Reproduce
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.utils import set_random_seed

if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    env.seed(0)
    set_random_seed(0)
    # batch_size=1 is what triggers the NaNs
    model = PPO("MlpPolicy", env, verbose=1, n_steps=20, batch_size=1, n_epochs=1)
    model.learn(total_timesteps=1) # 000)
    obs = env.reset()
    for _ in range(1):
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        env.render()
Traceback (most recent call last):
  File "d:/Onebox/as/code/exploration/distort-and-recover/environment.py", line 194, in <module>
    model.learn(total_timesteps=1) # 000)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\ppo\ppo.py", line 281, in learn
    return super(PPO, self).learn(
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\on_policy_algorithm.py", line 249, in learn
    self.train()
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\ppo\ppo.py", line 192, in train
    values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\policies.py", line 610, in evaluate_actions
    distribution = self._get_action_dist_from_latent(latent_pi, latent_sde)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\policies.py", line 575, in _get_action_dist_from_latent
    return self.action_dist.proba_distribution(action_logits=mean_actions)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\stable_baselines3\common\distributions.py", line 275, in proba_distribution
    self.distribution = Categorical(logits=action_logits)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\torch\distributions\categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "C:\Users\myuser\Miniconda3.9\envs\sandbox\lib\site-packages\torch\distributions\distribution.py", line 53, in __init__
    raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter logits has invalid values
Expected behavior
Fail an assertion with a clear error message, or implement a workaround.
### System Info
conda list
# packages in environment at C:\Users\myuser\Miniconda3.9\envs\sandbox:
#
# Name Version Build Channel
appdirs 1.4.4 py_0
astroid 2.5 py38haa95532_1
attrs 20.3.0 pyhd3eb1b0_0
backcall 0.2.0 pyhd3eb1b0_0
black 19.10b0 py_0
blas 1.0 mkl
blosc 1.21.0 h19a0ad4_0
boost-cpp 1.74.0 h54f0996_1 conda-forge
brotli 1.0.9 ha925a31_2
bzip2 1.0.8 he774522_0
ca-certificates 2020.12.5 h5b45459_0 conda-forge
certifi 2020.12.5 py38haa244fe_1 conda-forge
charls 2.2.0 h6c2663c_0
click 7.1.2 pyhd3eb1b0_0
cloudpickle 1.6.0 py_0
colorama 0.4.4 pyhd3eb1b0_0
colorio 0.7.3 pypi_0 pypi
colour-science 0.3.16 pyh44b312d_1 conda-forge
cpuonly 1.0 0 pytorch
curl 7.68.0 h4496350_0 conda-forge
cycler 0.10.0 py38_0
cytoolz 0.11.0 py38he774522_0
dask-core 2021.4.0 pyhd3eb1b0_0
decorator 5.0.6 pyhd3eb1b0_0
exifread 2.1.2 py_1 conda-forge
expat 2.3.0 h39d44d4_0 conda-forge
ffmpeg 4.3.1 ha925a31_0 conda-forge
freetype 2.10.4 hd328e21_0
freexl 1.0.6 ha8e266a_0 conda-forge
fsspec 0.9.0 pyhd3eb1b0_0
future 0.18.2 py38haa244fe_3 conda-forge
gdal 2.3.3 py38hdf43c64_0
geos 3.7.1 h33f27b4_0
giflib 5.2.1 h62dcd97_0
gym 0.18.0 py38h43734a8_1 conda-forge
hdf4 4.2.13 h712560f_2
hdf5 1.10.4 h7ebc959_0
icc_rt 2019.0.0 h0cc432a_1
icu 58.2 ha925a31_3
imagecodecs 2021.3.31 py38h5da4933_0
imageio 2.9.0 pyhd3eb1b0_0
intel-openmp 2021.2.0 haa95532_616
ipykernel 5.3.4 py38h5ca1d4c_0
ipython 7.22.0 py38hd4e2768_0
ipython_genutils 0.2.0 pyhd3eb1b0_1
isort 5.8.0 pyhd3eb1b0_0
jedi 0.17.0 py38_0
joblib 1.0.1 pyhd3eb1b0_0
jpeg 9b hb83a4c4_2
jupyter_client 6.1.12 pyhd3eb1b0_0
jupyter_core 4.7.1 py38haa95532_0
kealib 1.4.7 h07cbb95_6
kiwisolver 1.3.1 py38hd77b12b_0
krb5 1.16.4 hdd46e55_0 conda-forge
lazy-object-proxy 1.6.0 py38h2bbff1b_0
lcms2 2.12 h83e58a3_0
lerc 2.2.1 hd77b12b_0
libaec 1.0.4 h33f27b4_1
libcurl 7.68.0 h4496350_0 conda-forge
libdeflate 1.7 h2bbff1b_5
libgdal 2.3.3 h10f50ba_0
libiconv 1.16 he774522_0 conda-forge
libkml 1.3.0 h9859afa_1013 conda-forge
libnetcdf 4.6.1 h411e497_2
libopencv 4.0.1 hbb9e17c_0
libpng 1.6.37 h2a8f88b_0
libpq 11.5 hb0bdaea_2 conda-forge
libsodium 1.0.18 h62dcd97_0
libspatialite 4.3.0a h6a0152f_1026 conda-forge
libssh2 1.9.0 h680486a_6 conda-forge
libtiff 4.2.0 hd0e1b90_0
libuv 1.40.0 he774522_0
libxml2 2.9.10 hf5bbc77_4 conda-forge
libzopfli 1.0.3 ha925a31_0
llvmlite 0.36.0 py38h34b8924_4
locket 0.2.1 py38haa95532_1
lz4-c 1.9.3 h2bbff1b_0
matplotlib 3.3.4 py38haa95532_0
matplotlib-base 3.3.4 py38h49ac443_0
mccabe 0.6.1 py38_1
meshio 4.4.3 pypi_0 pypi
meshzoo 0.7.3 pypi_0 pypi
mkl 2021.2.0 haa95532_296
mkl-service 2.3.0 py38h2bbff1b_1
mkl_fft 1.3.0 py38h277e83a_2
mkl_random 1.2.1 py38hf11a4ad_2
mypy_extensions 0.4.3 py38_0
networkx 2.5 py_0
ninja 1.10.2 h6d14046_1
numba 0.53.1 py38hf11a4ad_0
numpy 1.20.1 py38h34a8a5c_0
numpy-base 1.20.1 py38haf7ebc8_0
olefile 0.46 py_0
opencv 4.0.1 py38h2a7c758_0
openjpeg 2.3.0 h5ec785f_1
openssl 1.1.1k h8ffe710_0 conda-forge
pandas 1.2.4 pypi_0 pypi
parso 0.8.2 pyhd3eb1b0_0
partd 1.2.0 pyhd3eb1b0_0
pathspec 0.7.0 py_0
pcre 8.44 ha925a31_0 conda-forge
pickleshare 0.7.5 pyhd3eb1b0_1003
piexif 1.1.3 py_2 conda-forge
pillow 8.2.0 py38h4fa10fc_0
pip 21.0.1 py38haa95532_0
proj4 5.2.0 h6538335_1006 conda-forge
prompt-toolkit 3.0.17 pyh06a4308_0
py-opencv 4.0.1 py38he44ac1e_0
pyglet 1.5.16 py38haa244fe_0 conda-forge
pygments 2.8.1 pyhd3eb1b0_0
pygmsh 7.1.9 pypi_0 pypi
pylibtiff 0.4.2 py38hdbb2d2f_3 conda-forge
pylint 2.7.4 py38haa95532_1
pyparsing 2.4.7 pyhd3eb1b0_0
pyqt 5.9.2 py38ha925a31_4
python 3.8.8 hdbf39b2_5
python-dateutil 2.8.1 pyhd3eb1b0_0
python_abi 3.8 1_cp38 conda-forge
pytorch 1.8.1 py3.8_cpu_0 [cpuonly] pytorch
pytz 2021.1 pypi_0 pypi
pyvista 0.30.0 pypi_0 pypi
pywavelets 1.1.1 py38he774522_2
pywin32 227 py38he774522_1
pyyaml 5.4.1 py38h2bbff1b_1
pyzmq 20.0.0 py38hd77b12b_1
qt 5.9.7 vc14h73c81de_0
rawpy 0.16.0 pypi_0 pypi
regex 2021.4.4 py38h2bbff1b_0
scikit-image 0.18.1 py38hf11a4ad_0
scikit-learn 0.24.1 py38hf11a4ad_0
scipy 1.6.2 py38h66253e8_1
scooby 0.5.7 pypi_0 pypi
setuptools 52.0.0 py38haa95532_0
sip 4.19.13 py38ha925a31_0
six 1.15.0 py38haa95532_0
snappy 1.1.8 h33f27b4_0
sqlite 3.35.4 h2bbff1b_0
stable-baselines3 1.0 pypi_0 pypi
tbb 2020.3 h74a9793_0
threadpoolctl 2.1.0 pyh5ca1d4c_0
tifffile 2021.4.8 pyhd3eb1b0_2
tk 8.6.10 he774522_0
toml 0.10.2 pyhd3eb1b0_0
toolz 0.11.1 pyhd3eb1b0_0
torchaudio 0.8.1 py38 pytorch
torchvision 0.9.1 py38_cpu [cpuonly] pytorch
tornado 6.1 py38h2bbff1b_0
tqdm 4.59.0 pyhd3eb1b0_1
traitlets 5.0.5 pyhd3eb1b0_0
transforms3d 0.3.1 pypi_0 pypi
typed-ast 1.4.2 py38h2bbff1b_1
typing_extensions 3.7.4.3 pyha847dfd_0
vc 14.2 h21ff451_1
vs2015_runtime 14.27.29016 h5e58377_2
vtk 9.0.1 pypi_0 pypi
wcwidth 0.2.5 py_0
wheel 0.36.2 pyhd3eb1b0_0
wincertstore 0.2 py38_0
wrapt 1.12.1 py38he774522_1
xerces-c 3.2.3 h0e60522_2 conda-forge
xz 5.2.5 h62dcd97_0
yaml 0.2.5 he774522_0
zeromq 4.3.3 ha925a31_3
zfp 0.5.5 hd77b12b_6
zlib 1.2.11 h62dcd97_4
zstd 1.4.5 h04227a9_0
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)
Issue Analytics
- State:
- Created 2 years ago
- Comments:5 (4 by maintainers)
Top GitHub Comments
For the sake of ease of use and a bit of hand-holding, I’d say we throw an error when batch size == 1 (the advantage cannot be normalized, and the gradient is very noisy). I do not see when this would be a useful or non-harmful option.
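One way the proposed error could look, as a rough sketch only. The function `check_batch_size` and its `normalize_advantage` flag are hypothetical names introduced for illustration; they are not claimed to be part of the library.

```python
def check_batch_size(batch_size: int, normalize_advantage: bool = True) -> None:
    """Hypothetical guard: reject a mini-batch size that cannot be normalized."""
    if normalize_advantage and batch_size <= 1:
        raise ValueError(
            "batch_size must be greater than 1 to normalize advantages, "
            f"got batch_size={batch_size}"
        )


check_batch_size(64)  # a usual mini-batch size passes silently

try:
    check_batch_size(1)
except ValueError as err:
    print(err)  # fails early with a readable message instead of silent NaNs
```

Failing at construction time would surface the problem before any rollout is collected, rather than deep inside the distribution code.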
From @externalsupplierstaff: “I see. Why not thing.std(unbiased=len(thing)>1), or a warning, or something? This is a silent and hard-to-track bug; it’s not nice that people keep stumbling upon it.”
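The suggestion quoted above can be sketched as follows. The helper name `safe_std` and the sample tensors are assumptions made for illustration.

```python
import torch


def safe_std(thing: torch.Tensor) -> torch.Tensor:
    # Apply Bessel's correction only when there is more than one sample,
    # so a single-element tensor yields 0 instead of NaN.
    return thing.std(unbiased=len(thing) > 1)


single = torch.tensor([0.5])
batch = torch.tensor([1.0, 2.0, 3.0])

print(safe_std(single).item())  # 0.0, no NaN
print(safe_std(batch).item())   # 1.0, the usual unbiased estimate
```

This avoids the NaN, but with a single sample the centered advantage `x - mean` is exactly zero, so the normalized advantage would be zero as well. That is one reason a hard error or a warning may still be preferable.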
Any thoughts on that, @Miffyli? I’m hesitating between the three options: