Assertion assert_equal_shape failed for MultiDiscrete action space
Issue Description
First of all, thank you for developing this package; I really like the modular design. I am a bit new to RL and the JAX ecosystem, so my question may be a bit naive. I am currently doing a baseline study with my customized gym environment and VanillaPG, but I ran into the error shown below and could not figure it out. My understanding is that it is complaining that the shape of log_pi should not be (4,). But I do have a MultiDiscrete action space, and its corresponding log_pi should be something like (4,) or (1, 4). I also attached the output of coax.Policy.example_data(env) and my policy function definition below, in case they help explain the situation.
So my questions are:
- Do you think this error is related to the fact that I have a MultiDiscrete action space?
- Did I declare my policy function properly?
- Any general ideas on how to debug JAX functions?
I would appreciate any feedback. Thank you!
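As an aside on the last question: one general way to inspect shapes inside jitted JAX code is to temporarily disable JIT, so that print shows concrete arrays instead of abstract tracers. A minimal sketch, not tied to coax:

import jax
import jax.numpy as jnp

@jax.jit
def loss(log_pi, adv):
    # Under jit, `print` shows abstract tracers while tracing; inside
    # jax.disable_jit() it shows concrete arrays, which makes shape bugs
    # much easier to spot.
    print("log_pi:", log_pi)
    return -(adv * log_pi).sum()

with jax.disable_jit():
    loss(jnp.zeros((4,)), jnp.ones((1,)))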
Error message
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
Input In [25], in <cell line: 5>()
13 transition_batch = tracer.pop()
14 Gn = transition_batch.Rn
---> 15 metrics = vanilla_pg.update(transition_batch, Adv=Gn)
16 env.record_metrics(metrics)
17 if done:
File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:149, in PolicyObjective.update(self, transition_batch, Adv)
127 def update(self, transition_batch, Adv):
128 r"""
129
130 Update the model parameters (weights) of the underlying function approximator.
(...)
147
148 """
--> 149 grads, function_state, metrics = self.grads_and_metrics(transition_batch, Adv)
150 if any(jnp.any(jnp.isnan(g)) for g in jax.tree_leaves(grads)):
151 raise RuntimeError(f"found nan's in grads: {grads}")
File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:218, in PolicyObjective.grads_and_metrics(self, transition_batch, Adv)
212 if self.REQUIRES_PROPENSITIES and jnp.all(transition_batch.logP == 0):
213 warnings.warn(
214 f"In order for {self.__class__.__name__} to work properly, transition_batch.logP "
215 "should be non-zero. Please sample actions with their propensities: "
216 "a, logp = pi(s, return_logp=True) and then add logp to your reward tracer, "
217 "e.g. nstep_tracer.add(s, a, r, done, logp)")
--> 218 return self._grad_and_metrics_func(
219 self._pi.params, self._pi.function_state, self.hyperparams, self._pi.rng,
220 transition_batch, Adv)
File ~/opt/python3.9/site-packages/coax/utils/_jit.py:59, in JittedFunc.__call__(self, *args, **kwargs)
58 def __call__(self, *args, **kwargs):
---> 59 return self._jitted_func(*args, **kwargs)
[... skipping hidden 14 frame]
File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:80, in PolicyObjective.__init__.<locals>.grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv)
77 def grads_and_metrics_func(params, state, hyperparams, rng, transition_batch, Adv):
78 grads_func = jax.grad(loss_func, has_aux=True)
79 grads, (metrics, state_new) = \
---> 80 grads_func(params, state, hyperparams, rng, transition_batch, Adv)
82 # add some diagnostics of the gradients
83 metrics.update(get_grads_diagnostics(grads, f'{self.__class__.__name__}/grads_'))
[... skipping hidden 10 frame]
File ~/opt/python3.9/site-packages/coax/policy_objectives/_base.py:47, in PolicyObjective.__init__.<locals>.loss_func(params, state, hyperparams, rng, transition_batch, Adv)
45 def loss_func(params, state, hyperparams, rng, transition_batch, Adv):
46 objective, (dist_params, log_pi, state_new) = \
---> 47 self.objective_func(params, state, hyperparams, rng, transition_batch, Adv)
49 # flip sign to turn objective into loss
50 loss = -objective
File ~/opt/python3.9/site-packages/coax/policy_objectives/_vanilla_pg.py:52, in VanillaPG.objective_func(self, params, state, hyperparams, rng, transition_batch, Adv)
49 W = jnp.clip(transition_batch.W, 0.1, 10.)
51 # some consistency checks
---> 52 chex.assert_equal_shape([W, Adv, log_pi])
53 chex.assert_rank([W, Adv, log_pi], 1)
54 objective = W * Adv * log_pi
File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:197, in chex_assertion.<locals>._chex_assert_fn(*args, **kwargs)
195 else:
196 try:
--> 197 host_assertion(*args, **kwargs)
198 except jax.errors.ConcretizationTypeError as exc:
199 msg = ("Chex assertion detected `ConcretizationTypeError`: it is very "
200 "likely that it tried to access tensors' values during tracing. "
201 "Make sure that you defined a jittable version of this Chex "
202 "assertion.")
File ~/opt/python3.9/site-packages/chex/_src/asserts_internal.py:157, in make_static_assertion.<locals>._static_assert(custom_message, custom_message_format_vars, include_default_message, exception_type, *args, **kwargs)
154 custom_message = custom_message.format(*custom_message_format_vars)
155 error_msg = f"{error_msg} [{custom_message}]"
--> 157 raise exception_type(error_msg)
AssertionError: [Chex] Assertion assert_equal_shape failed: Arrays have different shapes: [(1,), (1,), (4,)].
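For reference, the check that fails here is chex.assert_equal_shape, which raises whenever the given arrays' shapes differ. A minimal stand-alone reproduction with the shapes from the traceback above:

import jax.numpy as jnp
import chex

W = jnp.ones((1,))        # importance weights: one entry per transition
Adv = jnp.ones((1,))      # advantages: one entry per transition
log_pi = jnp.zeros((4,))  # here: one log-prob per MultiDiscrete component

# Raises AssertionError: Arrays have different shapes: [(1,), (1,), (4,)]
chex.assert_equal_shape([W, Adv, log_pi])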
Example data
ExampleData(
  inputs=Inputs(
    args=ArgsType2(
      S={
        'features': array(shape=(1, 1000), dtype=float32, min=0.008, median=2.13, max=2.77)},
      is_training=True),
    static_argnums=(
      1)),
  output=(
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-2.31, median=0.152, max=0.732)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-1.54, median=-0.138, max=0.994)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-0.984, median=0.0808, max=1.73)},
    {
      'logits': array(shape=(1, 10), dtype=float32, min=-2.74, median=-0.289, max=1.74)}))
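The four separate (1, 10) logits heads in the output correspond to a MultiDiscrete action space with four sub-actions of 10 choices each, i.e. something along these lines (the actual environment is not shown, so this is an assumption):

from gym.spaces import MultiDiscrete

# Assumed action space: four sub-actions with 10 choices each, matching the
# four (1, 10) logits arrays in the example data above.
action_space = MultiDiscrete([10, 10, 10, 10])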
Policy function
def pi(S, is_training):
module = CustomizedModule()
res = tuple([{"logits": item} for item in module(S["features"])])
return res
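CustomizedModule itself is not shown; a minimal haiku-based sketch of what a policy body with this output structure could look like (the layer sizes and names are assumptions, purely for illustration):

import haiku as hk
import jax

def pi(S, is_training):
    x = S["features"]                    # (batch, 1000), as in the example data
    x = jax.nn.relu(hk.Linear(256)(x))   # hidden size 256 is an arbitrary choice
    # one 10-way logits head per MultiDiscrete component
    return tuple({"logits": hk.Linear(10)(x)} for _ in range(4))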
Issue Analytics
- Created 7 months ago
- Comments: 5 (2 by maintainers)
Thanks for reporting this!
I can confirm that this is a proper bug. It has to do with the way variates are pre/post-processed.
I’ll have a closer look at it as soon as I have time, which is either tonight or tomorrow.
Hi @xiangyuy, thanks for your patience. PR #22 is merged. I bumped the version number, because I also fixed a few other little things in the same PR.
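For anyone hitting the same assertion: the objective expects a single scalar log-probability per transition, and for a factorised (MultiDiscrete) policy that scalar is the sum of the per-component log-probabilities. A library-agnostic sketch of that reduction (this is not the actual coax fix, just the underlying idea):

import jax
import jax.numpy as jnp

def joint_log_prob(logits_per_component, actions):
    # logits_per_component: tuple of arrays, each of shape (batch, n_i)
    # actions:              integer array of shape (batch, n_components)
    # returns:              array of shape (batch,), one scalar per transition
    total = 0.0
    for i, logits in enumerate(logits_per_component):
        logp = jax.nn.log_softmax(logits, axis=-1)                            # (batch, n_i)
        total = total + jnp.take_along_axis(logp, actions[:, i:i + 1], axis=-1)[:, 0]
    return total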