
Guide on how to use the LSTM version of DDPG on gym environments

See original GitHub issue

I am trying to run DDPG with the gym Pendulum-v0 environment. However, I am getting this error:

TypeError: The batch size of x must be equal to or less thanthe size of the previous state h.

This is my code:

# imports inferred from the calls below; adjust the module aliases to your chainerrl version
import gym
import numpy as np
import chainer
from chainer import optimizers
import chainerrl
from chainerrl import explorers
from chainerrl.agents.ddpg import DDPG, DDPGModel
from chainerrl.policies import deterministic_policy as policy
from chainerrl.q_functions import state_action_q_functions as q_func_

env = gym.make('Pendulum-v0')
obs_size = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]

q_func = q_func_.FCLSTMSAQFunction(obs_size, n_actions, n_hidden_channels=50, n_hidden_layers=2)
pi = policy.FCLSTMDeterministicPolicy(n_input_channels=obs_size, n_hidden_channels=50, n_hidden_layers=2, 
                                      action_size=n_actions, 
                                      min_action=env.action_space.low, 
                                      max_action=env.action_space.high, 
                                      bound_action=True
                                     )
model = DDPGModel(policy=pi, q_func=q_func)
opt_a = optimizers.Adam(alpha=1e-4)
opt_c = optimizers.Adam(alpha=1e-3)
opt_a.setup(model['policy'])
opt_c.setup(model['q_function'])
opt_a.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_a')
opt_c.add_hook(chainer.optimizer.GradientClipping(1.0), 'hook_c')

ou_sigma = (env.action_space.high - env.action_space.low) * 0.2
explorer = explorers.AdditiveOU(sigma=ou_sigma)

replay_buffer = chainerrl.replay_buffer.ReplayBuffer(capacity=5 * 10 ** 5)

phi = lambda x: x.astype(np.float32, copy=False)

agent = DDPG(model, opt_a, opt_c, replay_buffer, gamma=0.995, explorer=explorer, 
             replay_start_size=5000, target_update_method='soft', 
             target_update_interval=1, update_interval=1,
             soft_update_tau=1e-2, n_times_update=1, 
             gpu=0, minibatch_size=200, phi=phi)

n_episodes = 200
max_episode_len = 200
for i in range(1, n_episodes + 1):
    obs = env.reset()
    reward = 0
    done = False
    R = 0  # return (sum of rewards)
    t = 0  # time step
    while not done and t < max_episode_len:
        # Uncomment to watch the behaviour
#         env.render()
        action = agent.act_and_train(obs, reward)
        obs, reward, done, _ = env.step(action)
        R += reward
        t += 1
    if i % 10 == 0:
        print('episode:', i,
              '\nR:', R,
              '\nstatistics:', agent.get_statistics())
    agent.stop_episode_and_train(obs, reward, done)
print('Finished.')

Here is the initial output and the full traceback:

episode: 10
R: -1069.3354146961874
statistics: [('average_q', -0.1465160510604003), ('average_actor_loss', 0.0), ('average_critic_loss', 0.0)]
episode: 20
R: -1583.6140918088897
statistics: [('average_q', -0.16802258113631832), ('average_actor_loss', 0.0), ('average_critic_loss', 0.0)]

TypeError                                 Traceback (most recent call last)
<ipython-input-11-222c13d7cf2a> in <module>
     10         # Uncomment to watch the behaviour
     11 #         env.render()
---> 12         action = agent.act_and_train(obs, reward)
     13         obs, reward, done, _ = env.step(action)
     14         R += reward

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainerrl\agents\ddpg.py in act_and_train(self, obs, reward)
    335         self.last_action = action
    336
--> 337         self.replay_updater.update_if_necessary(self.t)
    338
    339         return self.last_action

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainerrl\replay_buffer.py in update_if_necessary(self, iteration)
    543         else:
    544             transitions = self.replay_buffer.sample(self.batchsize)
--> 545         self.update_func(transitions)

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainerrl\agents\ddpg.py in update(self, experiences, errors_out)
    263
    264         batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
--> 265         self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
    266         self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))
    267

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\optimizer.py in update(self, lossfun, *args, **kwds)
    862         if lossfun is not None:
    863             use_cleargrads = getattr(self, '_use_cleargrads', True)
--> 864             loss = lossfun(*args, **kwds)
    865             if use_cleargrads:
    866                 self.target.cleargrads()

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainerrl\agents\ddpg.py in <lambda>()
    263
    264         batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
--> 265         self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
    266         self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))
    267

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainerrl\agents\ddpg.py in compute_critic_loss(self, batch)
    208         # Estimated Q-function observes s_t and a_t
    209         predict_q = F.reshape(
--> 210             self.q_function(batch_state, batch_actions),
    211             (batchsize,))
    212

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainerrl\q_functions\state_action_q_functions.py in __call__(self, x, a)
    105         h = F.concat((x, a), axis=1)
    106         h = self.nonlinearity(self.fc(h))
--> 107         h = self.lstm(h)
    108         return self.out(h)
    109

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\link.py in __call__(self, *args, **kwargs)
    292         # forward is implemented in the child classes
    293         forward = self.forward  # type: ignore
--> 294         out = forward(*args, **kwargs)
    295
    296         # Call forward_postprocess hook

~\AppData\Local\Continuum\anaconda3\envs\chainer\lib\site-packages\chainer\links\connection\lstm.py in forward(self, x)
    296             msg = ('The batch size of x must be equal to or less than'
    297                    'the size of the previous state h.')
--> 298             raise TypeError(msg)
    299         elif h_size > batch:
    300             h_update, h_rest = split_axis.split_axis(

TypeError: The batch size of x must be equal to or less thanthe size of the previous state h.
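
For context, the exception comes from Chainer's stateful LSTM link: once the link has cached a hidden state for a small batch (presumably batch size 1 here, from acting one step at a time), feeding it the larger update minibatch of 200 trips this check. A minimal sketch of that behaviour, assuming only chainer and numpy:

import numpy as np
import chainer.links as L

lstm = L.LSTM(3, 4)                           # stateful LSTM link
lstm(np.zeros((1, 3), dtype=np.float32))      # caches a hidden state with batch size 1
lstm(np.zeros((200, 3), dtype=np.float32))    # raises the same TypeError (batch of x > previous state h)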

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 10 (4 by maintainers)

Top GitHub Comments

1 reaction
muupan commented, Sep 27, 2019

Correct, it is always in order.

1 reaction
muupan commented, Sep 26, 2019

When you use a recurrent model with DDPG, you need to

  • pass episodic_update=True and
  • use chainerrl.replay_buffers.EpisodicReplayBuffer instead of ReplayBuffer

so that it uses a batch of sequences, not a batch of transitions, for updates. You can also specify the maximum length of the sequences by episodic_update_len.
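
Applied to the code in the question, these changes would look roughly like this (a sketch, not tested; episodic_update_len=16 is just an example value, and in the chainerrl version shown in the traceback EpisodicReplayBuffer lives in chainerrl.replay_buffer):

# store and sample whole episodes instead of independent transitions
replay_buffer = chainerrl.replay_buffer.EpisodicReplayBuffer(capacity=5 * 10 ** 5)

agent = DDPG(model, opt_a, opt_c, replay_buffer, gamma=0.995, explorer=explorer,
             replay_start_size=5000, target_update_method='soft',
             target_update_interval=1, update_interval=1,
             soft_update_tau=1e-2, n_times_update=1,
             gpu=0, minibatch_size=200, phi=phi,
             episodic_update=True,        # update on batches of sequences
             episodic_update_len=16)      # cap the sampled sequence length (example value)

Note that with episodic_update=True the minibatch is made of episodes rather than single transitions, so a smaller minibatch_size than 200 may be more practical.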
