Implement sampling and training asynchronously using the SAC algorithm
Question
I’m trying to run sampling and training asynchronously with the SAC algorithm, using the attempt shown in the code below, but I always get an error that seems to come from a conflict between training and evaluation modes: the training mode flag (True or False) is set on the policy, and the same policy is shared between the train and collect_rollouts methods. Is it possible to run collect_rollouts asynchronously?
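To make the shared-state problem concrete, the rough sketch below shows how access to the shared policy could be serialized with a lock; collect_step and train_step are just illustrative names, and sac is the same SAC instance used in the reference code further down:

import threading

# Rough sketch (not the code I am running): serialize access to the shared
# policy so the sampling thread and the training thread never toggle its
# training/evaluation mode at the same time.
policy_lock = threading.Lock()

def collect_step(sac, callback, log_interval=None):
    with policy_lock:
        # collect_rollouts() internally switches the shared policy to eval mode
        return sac.collect_rollouts(
            sac.env, callback, sac.train_freq, sac.replay_buffer,
            sac.action_noise, sac.learning_starts, log_interval,
        )

def train_step(sac, gradient_steps):
    with policy_lock:
        # train() internally switches the shared policy back to training mode
        sac.train(gradient_steps, sac.batch_size)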
Additional context
Reference code:
import queue
import threading
import time

# Rollouts are produced by the sampling thread and consumed by the training thread
_rollouts_ = queue.Queue()

def train_async(sac, max_steps):
    global _rollouts_
    iteration = 0
    while True:
        if not _rollouts_.empty():
            rollout = _rollouts_.get()
            if rollout is not None:
                iteration += 1
                gradient_steps = sac.gradient_steps if sac.gradient_steps >= 0 else rollout.episode_timesteps
                if gradient_steps > 0:
                    sac.train(gradient_steps, sac.batch_size)
            else:
                # A None rollout is the sentinel that sampling has finished
                print("Training ending")
                break
        else:
            print("waiting for rollouts....")
            time.sleep(1)

def rollouts_async(sac, max_steps, callback, log_interval=None):
    steps = 0
    global _rollouts_
    while True:
        rollout = sac.collect_rollouts(
            sac.env, callback, sac.train_freq, sac.replay_buffer,
            sac.action_noise, sac.learning_starts, log_interval,
        )
        if rollout.continue_training is False:
            _rollouts_.put(None)
            break
        else:
            _rollouts_.put(rollout)
            steps += 1
            if steps >= max_steps:
                _rollouts_.put(None)
                callback.on_training_end()
                break

def learn_async(sac, total_timesteps=1000000, callback=None, log_interval=None, tb_log_name="run", reset_num_timesteps=True):
    total_timesteps, callback = sac._setup_learn(
        total_timesteps, None, callback, 0, 0, None, reset_num_timesteps, tb_log_name
    )
    callback.on_training_start(locals(), globals())
    # Both threads share the same SAC instance (policy and replay buffer)
    t1 = threading.Thread(target=rollouts_async, args=(sac, total_timesteps, callback, log_interval))
    t1.start()
    t2 = threading.Thread(target=train_async, args=(sac, total_timesteps))
    t2.start()
    t1.join()
    t2.join()
Error:
Traceback (most recent call last):
File "C:\Users\gilza\anaconda3\lib\threading.py", line 973, in _bootstrap_inner
self.run()
File "C:\Users\gilza\anaconda3\lib\threading.py", line 910, in run
self._target(*self._args, **self._kwargs)
File "C:\Users\gilza\doc\lab\nav\NavProAI4U\scripts\sb3sacutils.py", line 134, in rollouts_async
rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 589, in collect_rollouts
self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)
File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 498, in _store_transition
replay_buffer.add(
File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\buffers.py", line 562, in add
self.actions[self.pos] = np.array(action).copy()
ValueError: could not broadcast input array from shape (256,4) into shape (4,)
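For what it’s worth, the broadcast error itself is just a shape mismatch and can be reproduced with plain NumPy (shapes taken from the traceback above): each buffer slot expects a single action of shape (4,), but an array shaped like a training batch, (256, 4), is being written into it.

import numpy as np

# Minimal reproduction of the broadcast error, with shapes from the traceback:
# a buffer slot that holds one action of shape (4,) cannot accept a
# batch-shaped array of shape (256, 4).
actions = np.zeros((1000, 4), dtype=np.float32)
actions[0] = np.zeros((256, 4), dtype=np.float32)
# ValueError: could not broadcast input array from shape (256,4) into shape (4,)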
Checklist
- I have read the documentation (required) OK
- I have checked that there is no similar issue in the repo (required) OK
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Hello, I’m actively using the callback. The important thing is to compare training runs with the same number of gradient updates, and to compare how long a gradient update takes versus how long it takes to collect data.
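A rough way to measure that (just a sketch, assuming the same sac and an initialized callback as in the snippet above):

import time

# Time one gradient update against one call to collect_rollouts()
start = time.perf_counter()
sac.train(gradient_steps=1, batch_size=sac.batch_size)
train_time = time.perf_counter() - start

start = time.perf_counter()
sac.collect_rollouts(
    sac.env, callback, sac.train_freq, sac.replay_buffer,
    sac.action_noise, sac.learning_starts,
)
collect_time = time.perf_counter() - start

print(f"one gradient update: {train_time:.4f}s, one rollout: {collect_time:.4f}s")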
I just used the callback yesterday and it worked fine:
Please note that we do not do technical support (so unless you provide a minimal example to reproduce the issue without a custom env, we won’t give further answers).
Hello, you can find a working proof of concept here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/87001ed8a40f817d46c950e283d1ca29e405ad71/utils/callbacks.py#L95
(it is not polished but it works)