'linear/w' does not match shape


I’ve started learning about RL and have been trying to get coax up and running, but I’ve run into an issue that I’m not sure how to resolve. I’m doing Q-learning on a custom gym environment, and I can run the following pieces successfully:

q = coax.Q(func_q, env)
pi = coax.Policy(func_pi, env)

qlearning = coax.td_learning.QLearning(q, pi_targ=pi, optimizer=optax.adam(0.001))
cache = coax.reward_tracing.NStep(n=1, gamma=0.9)

My setup also passes these simple checks:

data = coax.Q.example_data(env) # Looks good
...
s = env.observation_space.sample()
a = env.action_space.sample()
print(q(s,a)) # 0.0
...
a = pi(s)
print(a) # [0, 0, 0, 0, 0] as I have a MultiDiscrete action space

However, once I get to actually running the training loop:

for ep in range(50):
  pi.epsilon = 0.1
  s = env.reset()

  for t in range(env.maxGuesses):
    a = pi(s)
    s_next, r, done, info = env.step(a)

    # update
    cache.add(s, a, r, done)

    while cache:
      transition_batch = cache.pop()
      metrics = qlearning.update(transition_batch)
      env.record_metrics(metrics)

    if done:
      break

    s = s_next

    # early stopping
    if env.avg_G > env.reward_threshold:
      break

I get a long chain of errors, the most human-readable of which says:

ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')

By adjusting the parameters of the environment, I can change which numbers are mismatched, but I can’t get them to match. Either way, that seems like the wrong fix, since something more fundamental appears to be the issue.

For reference, here are my functions for q and pi:

def func_pi(S, is_training):
  logits = hk.Sequential((
    hk.Linear(30), jax.nn.relu, 
    hk.Linear(30), jax.nn.relu, 
    hk.Linear(30), jax.nn.relu,
    hk.Linear(Wordle.wordLength*len(alphabet), w_init=jnp.zeros) # This many possible actions
  ))
  # First, convert to a vector:
  sVec = state_to_vec(S)

  # Now get the output:
  logitVec = logits(sVec)

  # Now chunk the output into alphabet-sized pieces (definitionally an integral
  # number of them). There will be Wordle.wordLength chunks of this length
  chunks = jnp.split(logitVec, Wordle.wordLength)

  # Now format our output array:
  ret = []
  for chunk in chunks:
    ret.append({'logits': jnp.reshape(chunk,(1,len(alphabet)))})

  return tuple(ret)

# and for actual state:
def func_q(S, A, is_training):
  value = hk.Sequential((
    hk.Linear(30), jax.nn.relu, 
    hk.Linear(30), jax.nn.relu,
    hk.Linear(30), jax.nn.relu,
    hk.Linear(1, w_init=jnp.zeros), jnp.ravel
  ))

  sVec = state_to_vec(S)
  aVec = action_to_vec(A)

  X = jnp.concatenate((sVec, aVec))
  return value(X)

Note that state_to_vec(S) and action_to_vec(A) just convert my internal types to jnp.array objects for use with Haiku.
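One thing worth checking with helpers like these is how they treat the leading batch axis, since coax passes S and A to the function as batches. The sketch below uses made-up sizes (`batch_size`, `state_dim`, `action_dim` are assumptions, as are the array contents) to show how flattening before concatenating makes the feature count depend on the batch size. Whether this applies to the helpers in this issue is not known from the post.

```python
import jax.numpy as jnp

batch_size = 4   # hypothetical batch size
state_dim = 10   # hypothetical number of state features per sample
action_dim = 5   # hypothetical number of action features per sample

s_vec = jnp.zeros((batch_size, state_dim))
a_vec = jnp.zeros((batch_size, action_dim))

# Concatenating along the last axis keeps one row per sample.
x = jnp.concatenate((s_vec, a_vec), axis=-1)
print(x.shape)  # (4, 15)

# Flattening first folds the batch into the features, so the first
# layer's input size would change whenever the batch size changes.
x_flat = jnp.concatenate((s_vec.ravel(), a_vec.ravel()))
print(x_flat.shape)  # (60,)
```

Keeping the batch axis intact means the first hk.Linear layer always sees the same number of input features, regardless of how many transitions are in the batch.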

I’m quite new to coax/JAX/Haiku, so it’s entirely possible I’ve set something up wrong. For completeness, here’s the full text of the error:

Traceback (most recent call last):
  File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
    metrics = qlearning.update(transition_batch)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
    grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
    return self._grads_and_metrics_func(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
    out_flat = xla.xla_call(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 596, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 142, in _xla_call_impl
    compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
    ans = call(fun, *args)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 169, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars,
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/dispatch.py", line 197, in lower_xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1623, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
    grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 996, in grad_f_aux
    (_, aux), g = value_and_grad_f(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 1067, in value_and_grad_f
    ans, vjp_py, aux = _vjp(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 2478, in _vjp
    out_primal, out_vjp, aux = ad.vjp(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 118, in vjp
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 103, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 520, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
    Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/_src/api.py", line 426, in cache_miss
    out_flat = xla.xla_call(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/ad.py", line 324, in process_call
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 204, in process_call
    jaxpr, out_pvals, consts, env_tracers = self.partial_eval(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 317, in partial_eval
    out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1671, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/core.py", line 1683, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1364, in process_call
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 1594, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
    out = f(*args, **kwargs)
  File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
    return value(X)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 314, in <module>
    metrics = qlearning.update(transition_batch)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 87, in update
    grads, function_state, metrics, td_error = self.grads_and_metrics(transition_batch)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 149, in grads_and_metrics
    return self._grads_and_metrics_func(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 462, in grads_and_metrics_func
    grads, (td_error, state_new, metrics) = jax.grad(loss_func, has_aux=True)(
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/td_learning/_base.py", line 436, in loss_func
    Q, state_new = self.q.function_type1(params, state, next(rngs), S, A, True)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/coax/utils/_jit.py", line 59, in __call__
    return self._jitted_func(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/transform.py", line 383, in apply_fn
    out = f(*args, **kwargs)
  File "/home/user/wordle/wordle_ai/wordle-game-rl.py", line 264, in func_1
    return value(X)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 125, in __call__
    out = layer(out, *args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 428, in wrapped
    out = f(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/module.py", line 279, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/basic.py", line 178, in __call__
    w = hk.get_parameter("w", [input_size, output_size], dtype, init=w_init)
  File "/home/user/miniconda3/envs/wordle/lib/python3.10/site-packages/haiku/_src/base.py", line 319, in get_parameter
    raise ValueError(
ValueError: 'linear/w' with retrieved shape (420, 30) does not match shape=[940, 30] dtype=dtype('float32')

Please let me know if other information would be useful or relevant (or let me know if this isn’t actually a coax issue…).

Thanks for your help and the neat package.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

bcerjan commented, Jan 31, 2022

Kristian, thanks for your help; that totally resolved it.

KristianHolsheimer commented, Jan 30, 2022

The default preprocessor has now been updated in the main branch. If you want to try it out, you can install coax from the main branch:

pip install git+https://github.com/coax-dev/coax.git@main