Implement sampling and training asynchronously using the SAC algorithm


Question

I’m trying to implement sampling and training asynchronously using the SAC algorithm, with the attempt shown in the code below. But I always get an error, apparently because the training and evaluation modes get mixed up: the training mode (True or False) is set on the policy, which is shared between the train and collect_rollouts methods. Is it possible to run collect_rollouts asynchronously?
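For readers unfamiliar with what “training mode” means here, a minimal illustration of the shared flag, assuming stable-baselines3’s set_training_mode API; the environment name is just a placeholder and is not from the original issue:

from stable_baselines3 import SAC

# Illustrative only: both train() and collect_rollouts() flip this flag on the
# same shared policy object, which is what the two threads end up fighting over.
sac = SAC("MlpPolicy", "Pendulum-v1")

sac.policy.set_training_mode(True)    # roughly what train() does on entry
print(sac.policy.actor.training)      # True

sac.policy.set_training_mode(False)   # roughly what collect_rollouts() does on entry
print(sac.policy.actor.training)      # False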

Additional context

Reference code:

import queue
import threading
import time

_rollouts_ = queue.Queue()

def train_async(sac, max_steps):
	# Consume rollouts from the queue and run the corresponding gradient updates.
	global _rollouts_
	iteration = 0
	while True:
		if not _rollouts_.empty():
			rollout = _rollouts_.get()
			if rollout is not None:
				iteration += 1
				gradient_steps = sac.gradient_steps if sac.gradient_steps >= 0 else rollout.episode_timesteps
				if gradient_steps > 0:
					sac.train(gradient_steps, sac.batch_size)
			else:
				# A None sentinel signals that collection has finished.
				print("Training ending")
				break
		else:
			print("waiting for rollouts....")
			time.sleep(1)

def rollouts_async(sac, max_steps, callback, log_interval=None):
	# Collect rollouts and push them onto the queue for the training thread.
	global _rollouts_
	steps = 0
	while True:
		rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
		if rollout.continue_training is False:
			_rollouts_.put(None)
			break
		else:
			_rollouts_.put(rollout)
			steps += 1
			if steps >= max_steps:
				_rollouts_.put(None)
				callback.on_training_end()
				break

def learn_async(sac, total_timesteps=1000000, callback=None, log_interval=None, tb_log_name="run", reset_num_timesteps=True):
	# Run data collection and training in two threads that share the same SAC instance.
	total_timesteps, callback = sac._setup_learn(total_timesteps, None, callback, 0, 0, None, reset_num_timesteps, tb_log_name)
	callback.on_training_start(locals(), globals())
	t1 = threading.Thread(target=rollouts_async, args=(sac, total_timesteps, callback, log_interval))
	t1.start()
	t2 = threading.Thread(target=train_async, args=(sac, total_timesteps))
	t2.start()
	t1.join()
	t2.join()
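For completeness, a minimal driver for these helpers might look like the following; the environment name and timestep budget are placeholders, not taken from the issue:

from stable_baselines3 import SAC

# Hypothetical setup: any continuous-control environment would do here.
sac = SAC("MlpPolicy", "Pendulum-v1", verbose=1)
learn_async(sac, total_timesteps=10_000)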

Error:

Traceback (most recent call last):
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 973, in _bootstrap_inner
    self.run()
  File "C:\Users\gilza\anaconda3\lib\threading.py", line 910, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\gilza\doc\lab\nav\NavProAI4U\scripts\sb3sacutils.py", line 134, in rollouts_async
    rollout = sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer, sac.action_noise, sac.learning_starts, log_interval)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 589, in collect_rollouts
    self._store_transition(replay_buffer, buffer_action, new_obs, reward, done, infos)
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\off_policy_algorithm.py", line 498, in _store_transition
    replay_buffer.add(
  File "C:\Users\gilza\anaconda3\lib\site-packages\stable_baselines3\common\buffers.py", line 562, in add
    self.actions[self.pos] = np.array(action).copy()
ValueError: could not broadcast input array from shape (256,4) into shape (4,)
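One plausible reading of this error (not confirmed in the thread) is that the two threads call into the same policy object concurrently: SB3 policies keep shared state such as the actor’s action distribution, so a batched forward pass from train() (SAC’s default batch size is 256) can leak into the single-environment action sampled by collect_rollouts, which then tries to store a (256, 4) array where the buffer expects (4,). A crude way to test that hypothesis is to serialize access to the model with a lock, for example by routing both calls through wrappers like the ones below, which of course removes most of the benefit of using two threads:

import threading

_model_lock = threading.Lock()

def train_step_locked(sac, gradient_steps):
	# Only one thread may touch the shared SAC policy at a time.
	with _model_lock:
		sac.train(gradient_steps, sac.batch_size)

def collect_locked(sac, callback, log_interval=None):
	with _model_lock:
		return sac.collect_rollouts(
			sac.env, callback, sac.train_freq, sac.replay_buffer,
			sac.action_noise, sac.learning_starts, log_interval,
		)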

Checklist

  • I have read the documentation (required) OK
  • I have checked that there is no similar issue in the repo (required) OK

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 6 (4 by maintainers)

Top GitHub Comments

1 reaction
araffin commented, Jan 5, 2022

Hello, I’m actively using the callback. The important thing to check is comparing training with the same number of gradient updates, and also comparing how long a gradient update takes versus how long it takes to collect data.
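To make that comparison concrete, one could time the two phases separately; a rough sketch (the number of gradient steps here is arbitrary):

import time

t0 = time.perf_counter()
sac.train(gradient_steps=64, batch_size=sac.batch_size)
train_time = time.perf_counter() - t0

t0 = time.perf_counter()
sac.collect_rollouts(sac.env, callback, sac.train_freq, sac.replay_buffer,
                     sac.action_noise, sac.learning_starts)
collect_time = time.perf_counter() - t0

print(f"64 gradient steps: {train_time:.2f}s, one collection phase: {collect_time:.2f}s")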

I just used the callback yesterday and it worked fine.

Please note that we do not do technical support (so unless you provide a minimal example to reproduce the issue without a custom env, we won’t give further answers).

1 reaction
araffin commented, Jan 2, 2022

Hello, you can find a working proof of concept here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/87001ed8a40f817d46c950e283d1ca29e405ad71/utils/callbacks.py#L95

(it is not polished but it works)
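For readers who do not follow the link: the general pattern there is to let the live model keep collecting data without running its own gradient steps, while a second copy of the model trains in a background thread and its weights are copied back at rollout boundaries. The sketch below is a rough, simplified rendition of that idea, not the actual rl-baselines3-zoo code; the class name, temp-file path, and defaults are made up for illustration:

import threading

from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import BaseCallback

class AsyncTrainSketch(BaseCallback):
	"""Rough idea only: train a copy of the model in a background thread
	while the live model keeps collecting data, then sync weights back."""

	def __init__(self, gradient_steps=64, verbose=0):
		super().__init__(verbose)
		self.gradient_steps = gradient_steps
		self._copy = None
		self._thread = None

	def _on_training_start(self) -> None:
		# Work on a separate SAC instance that shares the replay buffer,
		# so the live model's policy is never touched by train().
		self.model.save("tmp_sac_copy")  # hypothetical temp path
		self._copy = SAC.load("tmp_sac_copy", device=self.model.device)
		self._copy.replay_buffer = self.model.replay_buffer
		self._copy.set_logger(self.model.logger)
		# The live model should not run gradient steps itself.
		self.model.gradient_steps = 0

	def _on_rollout_end(self) -> None:
		# Wait for the previous round of updates, copy weights back to the
		# live model, then start the next round in the background.
		if self._thread is not None:
			self._thread.join()
			self.model.policy.load_state_dict(self._copy.policy.state_dict())
		if self._copy.replay_buffer.size() > self._copy.learning_starts:
			self._thread = threading.Thread(
				target=self._copy.train,
				args=(self.gradient_steps, self._copy.batch_size),
			)
			self._thread.start()

	def _on_step(self) -> bool:
		return True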
