Assertion assert_equal_shape failed for MultiDiscrete action space

First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently running a baseline study with my customized gym environment and VanillaPG, but I encountered the bug shown below and could not figure it out. My understanding is that the assertion is complaining that log_pi should not have shape (4,). However, I do have a MultiDiscrete action space, so I would expect its log_pi to have a shape like (4,) or (1, 4). Below I have also attached the output of coax.Policy.example_data(env) and my policy function definition, in case that helps explain the situation.

So my questions are:

  1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
  2. Did I declare my policy function properly?
  3. Any general ideas on how to debug JAX functions?

I would appreciate any feedback. Thank you!
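On the third question, one general approach to debugging shape problems in JAX (a hedged sketch, not advice from the coax maintainers): array shapes are static under `jax.jit`, so a plain `print()` inside a jitted function runs once at trace time and reports them, while `jax.debug.print` (available in reasonably recent JAX releases) prints concrete runtime values. `jax.disable_jit()` runs the same code eagerly, op by op, so a debugger can step through it. The function `vanilla_pg_objective` is a hypothetical toy stand-in that mirrors the `W * Adv * log_pi` product visible in the traceback; it does not call coax.

```python
import jax
import jax.numpy as jnp


@jax.jit
def vanilla_pg_objective(W, Adv, log_pi):
    # Hypothetical toy stand-in for the product W * Adv * log_pi in VanillaPG.
    # Shapes are static under jax.jit, so this print() runs once at trace
    # time and reports them (the values are abstract tracers at that point).
    print("shapes:", W.shape, Adv.shape, log_pi.shape)
    # jax.debug.print is staged into the compiled function and prints the
    # concrete runtime values on every call.
    jax.debug.print("log_pi = {}", log_pi)
    # Without an explicit shape check, (1,) * (4,) silently broadcasts to (4,).
    return jnp.sum(W * Adv * log_pi)


# Normal (jitted) call with the mismatched shapes from the error message
vanilla_pg_objective(jnp.ones(1), jnp.ones(1), jnp.ones(4))

# Eager, op-by-op execution: useful for stepping through with pdb
with jax.disable_jit():
    vanilla_pg_objective(jnp.ones(1), jnp.ones(1), jnp.ones(1))
```

The silent broadcast in the first call is why explicit checks such as `chex.assert_equal_shape` are useful: they turn a wrong-but-running computation into an immediate error.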

Error message

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Input In [25], in <cell line: 5>()
     13     transition_batch = tracer.pop()
     14     Gn = transition_batch.Rn
---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
     16     env.record_metrics(metrics)
     17 if done:

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
    127 def update(self, transition_batch, Adv):
    128     r"""
    129 
    130     Update the model parameters (weights) of the underlying function approximator.
   (...)
    147 
    148     """
--> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
    150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
    151         raise RuntimeError(f"found nan's in grads: {grads}")

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
    212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
    213     warnings.warn(
    214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
    215         "should be non-zero. Please sample actions with their propensities: "
    216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
    217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
--> 218 return self._grad_and_metrics_func(
    219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
    220     transition_batch, Adv)

File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
     58 def __call__(self, *args, **kwargs):
---> 59     return self._jitted_func(*args, **kwargs)

    [... skipping hidden 14 frame]

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
     77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
     78     grads_func = jax.grad(loss_func, has_aux=True)
     79     grads, (metrics, state_new) = \
---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
     82     # add some diagnostics of the gradients
     83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))

    [... skipping hidden 10 frame]

File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
     45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
     46     objective, (dist_params, log_pi, state_new) = \
---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
     49     # flip sign to turn objective into loss
     50     loss = -objective

File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
     49 W = jnp.clip(transition_batch.W, 0.1, 10.)
     51 # some consistency checks
---> 52 chex.assert_equal_shape([W, Adv, log_pi])
     53 chex.assert_rank([W, Adv, log_pi], 1)
     54 objective = W * Adv * log_pi

File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
    195 else:
    196   try:
--> 197     host_assertion(*args, **kwargs)
    198   except jax.errors.ConcretizationTypeError as exc:
    199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
    200            "likely that it tried to access tensors' values during tracing. "
    201            "Make sure that you defined a jittable version of this Chex "
    202            "assertion.")

File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
    154     custom_message = custom_message.format(*custom_message_format_vars)
    155   error_msg = f"{error_msg} [{custom_message}]"
--> 157 raise exception_type(error_msg)

AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
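The failing check is `chex.assert_equal_shape([W, Adv, log_pi])` followed by `chex.assert_rank([W, Adv, log_pi], 1)`, i.e. VanillaPG expects exactly one log-propensity per transition. With a batch of one, `W` and `Adv` have shape `(1,)`, while `log_pi` arrived with shape `(4,)`, one entry per MultiDiscrete component. For a factorized policy with an independent categorical distribution per component, the joint log-probability of an action is the sum of the component log-probabilities, which reduces a `(batch, 4)` array to `(batch,)`. The NumPy sketch below illustrates that shape contract only; `joint_log_prob` is a hypothetical helper, not coax's internal code and not the fix that was eventually merged. The four heads of ten categories are inferred from the example data in this issue.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over the last (category) axis
    z = logits - logits.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def joint_log_prob(logits_per_component, actions):
    """Hypothetical helper: log pi(a|s) for a factorized MultiDiscrete policy.

    logits_per_component: sequence of arrays, each of shape (batch, n_k)
    actions: int array of shape (batch, num_components)
    returns: array of shape (batch,), one log-propensity per transition
    """
    per_component = [
        # pick log pi_k(a_k|s) for each row; result has shape (batch,)
        np.take_along_axis(log_softmax(logits), actions[:, [k]], axis=-1)[:, 0]
        for k, logits in enumerate(logits_per_component)
    ]
    # (batch, num_components) -> sum over components -> (batch,)
    return np.stack(per_component, axis=-1).sum(axis=-1)


# Shapes mirroring the issue: batch of 1, four heads with 10 categories each
rng = np.random.default_rng(0)
logits = [rng.normal(size=(1, 10)) for _ in range(4)]
actions = rng.integers(0, 10, size=(1, 4))
W, Adv = np.ones(1), np.ones(1)

log_pi = joint_log_prob(logits, actions)
print(W.shape, Adv.shape, log_pi.shape)  # (1,) (1,) (1,)
```

Summing in log space corresponds to multiplying the component probabilities, which is valid only under the independence assumption stated above.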

Example data

ExampleData(
  inputs=Inputs(
    args=ArgsType2(
      S={
        'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)},
      is_training=True)
    static_argnums=(
      1))
  output=(
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))
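Regarding the second question: the example output above is a tuple of four dicts, each holding `logits` of shape `(1, 10)`, which is the structure one would expect for a MultiDiscrete space with four sub-actions of ten categories each. The issue does not show `env.action_space`, so `nvec = [10, 10, 10, 10]` is an assumption inferred from these shapes. A small hypothetical checker like the one below can catch mismatches between policy heads and the action space before training. Note that the maintainer later confirmed the assertion failure was a coax bug in how variates are pre/post-processed, not a problem with this policy definition.

```python
import numpy as np


def check_multidiscrete_heads(policy_output, nvec, batch_size=1):
    """Hypothetical helper: verify a tuple-of-dicts policy output against nvec.

    policy_output: tuple with one {"logits": array of shape (batch, n_k)} per sub-action
    nvec: category count n_k for each sub-action (e.g. action_space.nvec)
    """
    if len(policy_output) != len(nvec):
        raise ValueError(f"expected {len(nvec)} heads, got {len(policy_output)}")
    for k, (head, n) in enumerate(zip(policy_output, nvec)):
        expected = (batch_size, int(n))
        actual = np.shape(head["logits"])
        if actual != expected:
            raise ValueError(f"head {k}: expected logits shape {expected}, got {actual}")
    return True


# Structure matching the example output: four heads, ten logits each, batch of 1
output = tuple({"logits": np.zeros((1, 10))} for _ in range(4))
print(check_multidiscrete_heads(output, nvec=[10, 10, 10, 10]))  # True
```

In practice one might pass the `output` field of `coax.Policy.example_data(env)` (shown in the repr above) together with `env.action_space.nvec` from the gym MultiDiscrete space.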

Policy function

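def pi(S, is_training):
    # CustomizedModule is the reporter's own network (definition not shown);
    # per the example data, it returns one logits array per MultiDiscrete sub-action.
    module = CustomizedModule()
    # One {"logits": ...} dict per sub-action, collected into a tuple
    return tuple({"logits": item} for item in module(S["features"]))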

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

2 reactions
KristianHolsheimer commented, Jul 21, 2022

Thanks for reporting this!

I can confirm that this is a proper bug. It has to do with the way variates are pre/post-processed.

I’ll have a closer look at it as soon as I have time, which is either tonight or tomorrow.

1 reaction
KristianHolsheimer commented, Jul 22, 2022

Hi @xiangyuy, thanks for your patience. PR #22 is merged. I bumped the version number, because I also fixed a few other little things in the same PR.
