Problem with trax/models/research/bert.py
Description
There are some problems when using the BERT class from trax/models/research/bert.py. The method new_weights(self, input_signature) in the PretrainedBERT class uses super().new_weights(input_signature) to set the weights, when it should call super().init_weights_and_state(input_signature) and assign self.weights = weights instead of return weights. Even after that change, I'm still having issues.
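As a minimal sketch, the change described above would look roughly like this, assuming tl.Serial as the base class and reducing the checkpoint-loading code to a hypothetical load_checkpoint helper (neither is verified against trax's actual bert.py):

import trax.layers as tl

class PretrainedBERT(tl.Serial):
    def new_weights(self, input_signature):
        # Was: weights = super().new_weights(input_signature) ... return weights
        super().init_weights_and_state(input_signature)
        if self.init_checkpoint is not None:
            weights = load_checkpoint(self.init_checkpoint)  # hypothetical helper for the checkpoint-loading code
            self.weights = weights  # assign in place instead of returning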
BERT's input_signature requires a tuple of length 3. I thought it should be (batch_size, sequence_length, d_model), but the model contains a PaddingMask layer that needs to receive a rank-2 tensor with shape (batch_size, sequence_length).
Environment information
OS: Ubuntu 18.04
$ pip freeze | grep trax
trax==1.3.5
$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-addons==0.11.2
tensorflow-datasets==4.0.1
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.9.0
tensorflow-metadata==0.24.0
tensorflow-probability==0.7.0
tensorflow-text==2.3.0
$ pip freeze | grep jax
jax==0.2.1
jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl
$ python -V
Python 3.6.9
For bugs: reproduction and error logs
Steps to reproduce:
from trax.models.research.bert import BERT

model = BERT(init_checkpoint=BERT_MODEL_PATH + 'bert_model.ckpt')
model.new_weights(input_signature=((24, 128, 768)))  # note: the doubled parentheses do not create a nested tuple
# Error logs:
...
LayerError: Exception passing through layer Parallel (in init):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 34
  layer input shapes: (ShapeDtype{shape:(), dtype:int32}, ShapeDtype{shape:(), dtype:int32}, ShapeDtype{shape:(), dtype:int32})
  File [...]/trax/layers/combinators.py, line 223, in init_weights_and_state
    in zip(self.sublayers, sublayer_signatures)]
  File [...]/trax/layers/combinators.py, line 222, in <listcomp>
    for layer, signature
LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 34
  layer input shapes: ShapeDtype{shape:(), dtype:int32}
  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)
LayerError: Exception passing through layer PaddingMask(0) (in _forward_abstract):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 33
  layer input shapes: ShapeDtype{shape:(), dtype:int32}
  File [...]/jax/interpreters/partial_eval.py, line 304, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)
  File [...]/jax/interpreters/partial_eval.py, line 1009, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File [...]/jax/interpreters/partial_eval.py, line 1019, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File [...]/dist-packages/jax/linear_util.py, line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File [...]/dist-packages/jax/linear_util.py, line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
LayerError: Exception passing through layer PaddingMask(0) (in pure_fn):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 33
  layer input shapes: ShapeDtype{shape:(), dtype:int32}
  File [...]/trax/layers/base.py, line 660, in forward
    raw_output = self._forward_fn(inputs)
  File [...]/trax/layers/base.py, line 702, in _forward
    return f(*xs)
  File [...]/trax/layers/attention.py, line 383, in f
    f'Input to PaddingMask must be a rank 2 tensor with shape '
ValueError: Input to PaddingMask must be a rank 2 tensor with shape (batch_size, sequence_length); instead got shape ().
I don't know if it's correct, but I got it initialized with the following signature (after fixing the new_weights issue).
Thanks @hepaajan. I was using tuples for the signature instead of trax.shapes.ShapeDtype instances. It's now working correctly for me.
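For reference, since the original snippet from the comment above was not preserved, here is a minimal sketch of a signature built from trax.shapes.ShapeDtype instances, assuming the three rank-2 int32 inputs implied by the error log (the input count and the init call are inferred from the traceback, not verified against bert.py):

import numpy as np
from trax.models.research.bert import BERT
from trax.shapes import ShapeDtype

batch_size, seq_len = 24, 128
# One rank-2 int32 signature per input, matching the
# (batch_size, sequence_length) shape that PaddingMask expects.
sig = ShapeDtype((batch_size, seq_len), dtype=np.int32)
input_signature = (sig, sig, sig)

model = BERT(init_checkpoint=BERT_MODEL_PATH + 'bert_model.ckpt')  # BERT_MODEL_PATH as in the repro above
model.init(input_signature)  # Layer.init, usable once new_weights is fixed as described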