
Problem with trax/models/research/bert.py

See original GitHub issue

Description

There are some problems when using the BERT class from trax/models/research/bert.py. The method new_weights(self, input_signature) in the PretrainedBERT class calls super().new_weights(input_signature) to set the weights, when it should call super().init_weights_and_state(input_signature) and assign self.weights = weights instead of return weights. After that fix, I’m still having issues. BERT’s input_signature requires a tuple of length 3, which I think should be (batch_size, sequence_length, d_model), but the model contains a PaddingMask layer that needs to receive a rank-2 tensor with shape (batch_size, sequence_length).
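For reference, the proposed fix follows this pattern. The sketch below uses hypothetical mock classes, not the real trax internals, to show why the override must assign self.weights rather than return the computed weights:

```python
# Minimal mock of the pattern described above (illustrative only; the
# class and method names follow the issue, not the real trax source).
class Layer:
    def __init__(self):
        self.weights = None
        self.state = None

    def init_weights_and_state(self, input_signature):
        # The base class computes and *assigns* weights/state in place.
        self.weights = {"base": input_signature}
        self.state = {}


class PretrainedBERT(Layer):
    def init_weights_and_state(self, input_signature):
        # Correct pattern: delegate to the base implementation, then
        # overwrite self.weights with the checkpoint values.
        super().init_weights_and_state(input_signature)
        checkpoint_weights = {"from_checkpoint": True}  # hypothetical stand-in
        self.weights = checkpoint_weights  # assign, don't return


model = PretrainedBERT()
model.init_weights_and_state(input_signature=(1, 512))
print(model.weights)  # {'from_checkpoint': True}
```

Returning the weights (as the current new_weights does) would leave self.weights unset, so the layer stays uninitialized.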

Environment information

OS: Ubuntu 18.04

$ pip freeze | grep trax
trax==1.3.5

$ pip freeze | grep tensor
mesh-tensorflow==0.1.16
tensor2tensor==1.15.7
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-addons==0.11.2
tensorflow-datasets==4.0.1
tensorflow-estimator==2.3.0
tensorflow-gan==2.0.0
tensorflow-hub==0.9.0
tensorflow-metadata==0.24.0
tensorflow-probability==0.7.0
tensorflow-text==2.3.0

$ pip freeze | grep jax
jax==0.2.1
jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.55-cp36-none-manylinux2010_x86_64.whl

$ python -V
Python 3.6.9

For bugs: reproduction and error logs

Steps to reproduce:

model = BERT(init_checkpoint=BERT_MODEL_PATH + 'bert_model.ckpt')
model.new_weights(input_signature=((24, 128, 768)))

Error logs:

...
LayerError: Exception passing through layer Parallel (in init):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 34
  layer input shapes: (ShapeDtype{shape:(), dtype:int32}, ShapeDtype{shape:(), dtype:int32}, ShapeDtype{shape:(), dtype:int32})

  File [...]/trax/layers/combinators.py, line 223, in init_weights_and_state
    in zip(self.sublayers, sublayer_signatures)]

  File [...]/trax/layers/combinators.py, line 222, in <listcomp>
    for layer, signature

LayerError: Exception passing through layer Serial (in init):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 34
  layer input shapes: ShapeDtype{shape:(), dtype:int32}

  File [...]/trax/layers/combinators.py, line 105, in init_weights_and_state
    outputs, _ = sublayer._forward_abstract(inputs)

LayerError: Exception passing through layer PaddingMask(0) (in _forward_abstract):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 33
  layer input shapes: ShapeDtype{shape:(), dtype:int32}

  File [...]/jax/interpreters/partial_eval.py, line 304, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals)

  File [...]/jax/interpreters/partial_eval.py, line 1009, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)

  File [...]/jax/interpreters/partial_eval.py, line 1019, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)

  File [...]/dist-packages/jax/linear_util.py, line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/dist-packages/jax/linear_util.py, line 151, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

LayerError: Exception passing through layer PaddingMask(0) (in pure_fn):
  layer created in file [...]/<ipython-input-73-2fa2cbdc2efd>, line 33
  layer input shapes: ShapeDtype{shape:(), dtype:int32}

  File [...]/trax/layers/base.py, line 660, in forward
    raw_output = self._forward_fn(inputs)

  File [...]/trax/layers/base.py, line 702, in _forward
    return f(*xs)

  File [...]/trax/layers/attention.py, line 383, in f
    f'Input to PaddingMask must be a rank 2 tensor with shape '

ValueError: Input to PaddingMask must be a rank 2 tensor with shape (batch_size, sequence_length); instead got shape ().
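Incidentally, the trace's layer input shapes (three scalar ShapeDtype{shape:(), dtype:int32} entries) follow directly from how Python parses the reproduction call: the extra parentheses in ((24, 128, 768)) do not create a nested tuple, so trax receives three bare ints and treats each one as a rank-0 signature. A quick sketch:

```python
# Extra parentheses do not nest a tuple in Python:
sig = ((24, 128, 768))
print(sig == (24, 128, 768))  # True: still a flat 3-tuple of ints

# A one-element tuple needs a trailing comma:
nested = ((24, 128, 768),)
print(len(nested), len(sig))  # 1 3
```

Each of the three ints then reaches PaddingMask as a scalar, which is exactly the shape () in the ValueError above.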

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
hepaajan commented, Oct 21, 2020

I don’t know if it’s correct, but I got it initialized with the following signature (after fixing the new_weights issue):

import numpy as np
import trax

input_signature = (
    trax.shapes.ShapeDtype(shape=(1, 512), dtype=np.int32),
    trax.shapes.ShapeDtype(shape=(1, 512), dtype=np.int32),
    trax.shapes.ShapeDtype(shape=(1, 512), dtype=np.float32),
)
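These shapes line up with what the error message demands: each input is rank 2, (batch_size, sequence_length). The check below is a hypothetical re-implementation based on the error message above, not the actual trax source, to show why a (1, 512) array passes while a scalar does not:

```python
import numpy as np

def check_padding_mask_input(x):
    # Hypothetical stand-in for PaddingMask's input validation,
    # reconstructed from the ValueError text in the trace.
    if x.ndim != 2:
        raise ValueError(
            "Input to PaddingMask must be a rank 2 tensor with shape "
            f"(batch_size, sequence_length); instead got shape {x.shape}."
        )
    return x != 0  # True where the position is not padding

tokens = np.zeros((1, 512), dtype=np.int32)  # matches the signature above
mask = check_padding_mask_input(tokens)
print(mask.shape)  # (1, 512)
```

Passing a bare int (rank 0, shape ()) through the same check raises exactly the ValueError shown in the logs.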
0 reactions
kujaomega commented, Oct 21, 2020

Thanks @hepaajan. I was using plain tuples for the signature instead of trax.shapes.ShapeDtype instances. It’s now working correctly for me.


