[Feature request] Adding multiprocessing support for off policy algorithms
I am in the process of adding multiprocessing (vectorized envs) support for off-policy algorithms (TD3, SAC, DDPG, etc.). I've added support for sampling multiple actions and for updating timesteps according to the number of vectorized environments. The modified code runs without throwing an error, but the algorithms no longer converge.
I tried OpenAI Gym's Pendulum-v0: a single-instance env made with make_vec_env('Pendulum-v0', n_envs=1, vec_env_cls=DummyVecEnv)
trains fine. If I specify multiple instances, such as make_vec_env('Pendulum-v0', n_envs=2, vec_env_cls=DummyVecEnv)
or make_vec_env('Pendulum-v0', n_envs=2, vec_env_cls=SubprocVecEnv)
, then the algorithms don't converge at all.
Here’s a warning message that I get, which I suspect is closely related to the non-convergence.
/home/me/code/stable-baselines3/stable_baselines3/sac/sac.py:237: UserWarning: Using a target size (torch.Size([256, 2])) that is different
to the input size (torch.Size([256, 1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size. critic_loss = 0.5 * sum([F.mse_loss(current_q, q_backup) for current_q in current_q_estimates])
It appears to me that the replay buffer wasn't retrieving n_envs
samples, so the loss target had to rely on broadcasting.
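The shape mismatch in the warning can be reproduced in isolation. Below is a small NumPy sketch (the shapes are taken from the warning above; the variable names mirror the SAC code but this is not the stable-baselines3 implementation) showing how a (256, 1) estimate silently broadcasts against a (256, 2) target:

```python
import numpy as np

# Shapes from the warning: the critic outputs a (256, 1) Q-estimate,
# while the target built from a 2-env rollout has shape (256, 2).
rng = np.random.default_rng(0)
current_q = rng.normal(size=(256, 1))  # "input size" in the warning
q_backup = rng.normal(size=(256, 2))   # "target size" in the warning

# NumPy (like torch) broadcasts (256, 1) against (256, 2) to (256, 2):
# every Q-estimate gets compared against BOTH envs' targets, which is
# why the MSE is computed over the wrong pairs.
diff = current_q - q_backup
assert diff.shape == (256, 2)
```

The fix is to make sure the sampled batch and the critic output have identical shapes, e.g. both (256, 1).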
Some pointers on modifying the replay buffer so it supports multiprocessing would be much appreciated! If the authors would like, I can create a PR:
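One possible direction, sketched below, is to store the transitions of all n_envs per step but sample them as flat rows, so the batch never carries a trailing n_envs dimension. This is a minimal illustration with hypothetical names (VecReplayBuffer), not the stable-baselines3 implementation:

```python
import numpy as np

class VecReplayBuffer:
    """Sketch: buffer arrays are (size, n_envs, dim); sampling picks a
    (time, env) pair per row so returned batches are flat (batch, dim)."""

    def __init__(self, size, n_envs, obs_dim, act_dim):
        self.size, self.n_envs = size, n_envs
        self.pos, self.full = 0, False
        self.obs = np.zeros((size, n_envs, obs_dim), dtype=np.float32)
        self.actions = np.zeros((size, n_envs, act_dim), dtype=np.float32)
        self.rewards = np.zeros((size, n_envs), dtype=np.float32)
        self.next_obs = np.zeros((size, n_envs, obs_dim), dtype=np.float32)
        self.dones = np.zeros((size, n_envs), dtype=np.float32)

    def add(self, obs, action, reward, next_obs, done):
        # obs/next_obs: (n_envs, obs_dim), reward/done: (n_envs,)
        self.obs[self.pos] = obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.next_obs[self.pos] = next_obs
        self.dones[self.pos] = done
        self.pos = (self.pos + 1) % self.size
        self.full = self.full or self.pos == 0

    def sample(self, batch_size, rng=np.random):
        upper = self.size if self.full else self.pos
        # One (time, env) index pair per batch row -> flat batches.
        t = rng.randint(0, upper, size=batch_size)
        e = rng.randint(0, self.n_envs, size=batch_size)
        return (self.obs[t, e], self.actions[t, e],
                self.rewards[t, e, None],   # (batch, 1), matches critic output
                self.next_obs[t, e],
                self.dones[t, e, None])
```

Returning rewards and dones as (batch, 1) keeps the TD target the same shape as the critic output, which would avoid the broadcasting warning above.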
https://github.com/yonkshi/stable-baselines3/commit/157971357a28be8435d09cfccec1d4258b220a6a
Issue Analytics
- State:
- Created 3 years ago
- Comments: 13 (7 by maintainers)
Top GitHub Comments
I’ll be working on that in the coming weeks (I need to implement it for a personal project)
Please read the documentation: you are using
train_freq=(1, "episode")
(episodic training); to use multiple envs, you must use "step" as the unit (or train_freq=1
for short). We recommend using TD3/SAC anyway (improved versions of DDPG). You need to install the master version (cf. doc), as this is not yet released.