Assertion assert_equal_shape failed for MultiDiscrete action space

Issue Description

First of all, thank you for developing this package; I really like the modular design. I am fairly new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently running a baseline study with my custom gym environment and VanillaPG, but I hit the error shown below and could not figure it out. My understanding is that it complains that log_pi should not have shape (4,). However, I do have a MultiDiscrete action space, so I would expect its log_pi to have a shape like (4,) or (1, 4). I have also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case that helps explain the situation.

So my questions are:

  1. Do you think this error is related to the fact that I have a MultiDiscrete action space?
  2. Did I declare my policy function properly?
  3. Any general ideas on how to debug JAX functions?

I would appreciate any feedback. Thank you!
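
On question 3, here is a minimal sketch of three general-purpose JAX debugging aids: printing the shapes of a pytree, `jax.eval_shape`, and `jax.disable_jit`. The names `dist_params`, `per_sample` and `weighted_objective` are hypothetical stand-ins chosen to mirror the shapes in this report; they are not coax functions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the policy output: a tuple of four logits heads,
# mirroring the (1, 10) shapes shown under "Example data".
dist_params = tuple({"logits": jnp.zeros((1, 10))} for _ in range(4))

# 1. Inspect the shape of every leaf in a nested structure.
shapes = jax.tree_util.tree_map(lambda x: x.shape, dist_params)
print(shapes)

# Arrays with the shapes reported by the failing assertion.
W = jnp.ones((1,))
Adv = jnp.ones((1,))
log_pi = jnp.zeros((4,))


def per_sample(W, Adv, log_pi):
    # Without a shape check, broadcasting silently accepts (1,) * (4,).
    return W * Adv * log_pi


# 2. jax.eval_shape traces the function abstractly and reports the output
#    shape and dtype without running the actual computation.
out = jax.eval_shape(per_sample, W, Adv, log_pi)
print(out.shape)


@jax.jit
def weighted_objective(W, Adv, log_pi):
    return jnp.mean(per_sample(W, Adv, log_pi))


# 3. Inside jax.disable_jit(), jitted functions run eagerly on concrete
#    arrays, so print statements and pdb breakpoints show real values.
with jax.disable_jit():
    value = weighted_objective(W, Adv, log_pi)
print(value)
```

Checking shapes first is usually the cheapest step, since most errors raised from inside a jitted function (like the one below) are shape or dtype mismatches rather than value problems.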

Error message

AssertionError                            Traceback (most recent call last)
Input In [25], in <cell line: 5>()
     13     transition_batch = tracer.pop()
     14     Gn = transition_batch.Rn
---> 15     metrics = vanilla_pg.update(transition_batch, Adv=Gn)
     16     env.record_metrics(metrics)
     17 if done:

File ~/opt/python3.9/site-packages/coax/policy_objectives/, in PolicyObjective.update(self, transition_batch, Adv)
    127 def update(self, transition_batch, Adv):
    128     r"""
    130     Update the model parameters (weights) of the underlying function approximator.
    148     """
--> 149     grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
    150     if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
    151         raise RuntimeError(f"found nan's in grads: {grads}")

File ~/opt/python3.9/site-packages/coax/policy_objectives/, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
    212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
    213     warnings.warn(
    214         f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
    215         "should be non-zero. Please sample actions with their propensities: "
    216         "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
    217         "e.g. nstep_tracer.add(s, a, r, done, logp)")
--> 218 return self._grad_and_metrics_func(
    219     self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
    220     transition_batch, Adv)

File ~/opt/python3.9/site-packages/coax/utils/, in JittedFunc.__call__(self, *args, **kwargs)
     58 def __call__(self, *args, **kwargs):
---> 59     return self._jitted_func(*args, **kwargs)

    [... skipping hidden 14 frame]

File ~/opt/python3.9/site-packages/coax/policy_objectives/, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
     77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
     78     grads_func = jax.grad(loss_func, has_aux=True)
     79     grads, (metrics, state_new) = \
---> 80         grads_func(params, state, hyperparams, rng, transition_batch, Adv)
     82     # add some diagnostics of the gradients
     83     metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))

    [... skipping hidden 10 frame]

File ~/opt/python3.9/site-packages/coax/policy_objectives/, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
     45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
     46     objective, (dist_params, log_pi, state_new) = \
---> 47         self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
     49     # flip sign to turn objective into loss
     50     loss = -objective

File ~/opt/python3.9/site-packages/coax/policy_objectives/, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
     49 W = jnp.clip(transition_batch.W, 0.1, 10.)
     51 # some consistency checks
---> 52 chex.assert_equal_shape([W, Adv, log_pi])
     53 chex.assert_rank([W, Adv, log_pi], 1)
     54 objective = W * Adv * log_pi

File ~/opt/python3.9/site-packages/chex/_src/, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
    195 else:
    196   try:
--> 197     host_assertion(*args, **kwargs)
    198   except jax.errors.ConcretizationTypeError as exc:
    199     msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
    200            "likely that it tried to access tensors' values during tracing. "
    201            "Make sure that you defined a jittable version of this Chex "
    202            "assertion.")

File ~/opt/python3.9/site-packages/chex/_src/, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
    154     custom_message = custom_message.format(*custom_message_format_vars)
    155   error_msg = f"{error_msg} [{custom_message}]"
--> 157 raise exception_type(error_msg)

AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
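
The shapes in the assertion message can be reproduced in isolation. The sketch below is not coax's implementation; it only assumes a batch of size 1 and four independent categorical components with 10 classes each, as in the example data. All names (`component_log_prob`, `per_component`, the sample `actions`) are hypothetical. It shows that laying the four per-component log-probabilities side by side gives shape (4,), while summing them into a joint log-probability gives the per-sample shape (1,) that matches W and Adv.

```python
import jax
import jax.numpy as jnp

batch_size, num_components, num_classes = 1, 4, 10

# Hypothetical logits for the four components of a MultiDiscrete action.
keys = jax.random.split(jax.random.PRNGKey(0), num_components)
logits = [jax.random.normal(k, (batch_size, num_classes)) for k in keys]

# One sampled action index per component, each of shape (batch_size,).
actions = [jnp.array([3]), jnp.array([0]), jnp.array([7]), jnp.array([9])]


def component_log_prob(logits_i, a_i):
    """Log-probability of the chosen class, shape (batch_size,)."""
    log_p = jax.nn.log_softmax(logits_i, axis=-1)
    return jnp.take_along_axis(log_p, a_i[:, None], axis=-1)[:, 0]


per_component = [component_log_prob(l, a) for l, a in zip(logits, actions)]

# Concatenating the components mixes the component axis into the batch axis.
flattened = jnp.concatenate(per_component)
print(flattened.shape)  # (4,)

# Summing over components (independent factors) keeps one value per sample.
joint = jnp.stack(per_component, axis=0).sum(axis=0)
print(joint.shape)  # (1,)
```

Under the independence assumption, the joint log-probability of a MultiDiscrete action is the sum of its component log-probabilities, which is why one value per transition is the shape the objective expects.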

Example data

        'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)
      'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
      'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
      'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
      'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))

Policy function

def pi(S, is_training):
    module = CustomizedModule()
    res = tuple([{"logits": item} for item in module(S["features"])])
    return res

KristianHolsheimer commented, Jul 21, 2022

Thanks for reporting this!

I can confirm that this is a proper bug. It has to do with the way variates are pre/post-processed.

I’ll have a closer look at it as soon as I have time, which is either tonight or tomorrow.

KristianHolsheimer commented, Jul 22, 2022

Hi @xiangyuy, thanks for your patience. PR #22 is merged. I bumped the version number, because I also fixed a few other little things in the same PR.

