
Text Classification on GLUE on TPU using Jax/Flax: BigBird

See original GitHub issue

The notebook for Text Classification on GLUE tasks, provided here, could do with some updates:

  1. It is missing the import from flax import traverse_util, which should be added.
  2. gradient_transformation should be replaced with adamw(1e-7) in TrainState.create().
  3. A note to pip install sentencepiece would help, since RoBERTa and other models available in Flax at this point use it (a combined sketch of these fixes follows this list).
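
Taken together, the fixes look roughly like this. This is a minimal sketch, not the notebook's exact code: model is assumed to be the Flax model already loaded in the notebook, and the notebook's own TrainState may carry extra fields beyond the stock flax.training one.

# pip install sentencepiece        # fix 3: for sentencepiece-based tokenizers
import optax
from flax import traverse_util     # fix 1: the previously missing import
from flax.training.train_state import TrainState

# fix 2: pass an explicit optax optimizer rather than gradient_transformation
state = TrainState.create(
    apply_fn=model.__call__,       # model: assumed already loaded in the notebook
    params=model.params,
    tx=optax.adamw(learning_rate=1e-7),
)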

Also, I couldn’t get the notebook to run for google/bigbird-roberta-base with batch_size=1 and the default task (cola). I got the following error:

Shapes of inputs: (8, 1, 128) (8, 1, 128) (8, 1)
---------------------------------------------------------------------------
UnfilteredStackTrace                      Traceback (most recent call last)
<ipython-input-32-cb9e92a30675> in <module>()
      7         print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['labels'].shape)
----> 8         state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
      9         progress_bar_train.update(1)

67 frames
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    142     try:
--> 143       return fun(*args, **kwargs)
    144     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in f_pmapped(*args, **kwargs)
   1658         name=flat_fun.__name__, donated_invars=tuple(donated_invars),
-> 1659         global_arg_shapes=tuple(global_arg_shapes_flat))
   1660     return tree_unflatten(out_tree(), out)

/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
   1623     assert len(params['in_axes']) == len(args)
-> 1624     return call_bind(self, fun, *args, **params)
   1625 

/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1555   with maybe_new_sublevel(top_trace):
-> 1556     outs = primitive.process(top_trace, fun, tracers, params)
   1557   return map(full_lower, apply_todos(env_trace_todo(), outs))

/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1626   def process(self, trace, fun, tracers, params):
-> 1627     return trace.process_map(self, fun, tracers, params)
   1628 

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    608   def process_call(self, primitive, f, tracers, params):
--> 609     return primitive.impl(f, *tracers, **params)
    610   process_map = process_call

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)
    622                                    donated_invars, global_arg_shapes,
--> 623                                    *abstract_args)
    624   # Don't re-abstractify args unless logging is enabled for performance.

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    261     else:
--> 262       ans = call(fun, *args)
    263       cache[key] = (ans, fun.stores)

/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in parallel_callable(fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *avals)
    698   with core.extend_axis_env(axis_name, global_axis_size, None):  # type: ignore
--> 699     jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap")
    700   jaxpr = xla.apply_outfeed_rewriter(jaxpr)

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, transform_name)
   1208     main.jaxpr_stack = ()  # type: ignore
-> 1209     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1210     del fun, main

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
   1187     in_tracers = map(trace.new_arg, in_avals)
-> 1188     ans = fun.call_wrapped(*in_tracers)
   1189     out_tracers = map(trace.full_raise, ans)

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

<ipython-input-24-f522714f3451> in train_step(state, batch, dropout_rng)
     10     grad_function = jax.value_and_grad(loss_function)
---> 11     loss, grad = grad_function(state.params)
     12     grad = jax.lax.pmean(grad, "batch")

/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    142     try:
--> 143       return fun(*args, **kwargs)
    144     except Exception as e:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
    886     if not has_aux:
--> 887       ans, vjp_py = _vjp(f_partial, *dyn_args)
    888     else:

/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
   1965     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1966     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1967     out_tree = out_tree()

/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    113   if not has_aux:
--> 114     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    115   else:

/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
    100   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
    102   out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)

/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
    497     fun = trace_to_subjaxpr(fun, main, instantiate)
--> 498     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    499     assert not env

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    165     try:
--> 166       ans = self.f(*args, **dict(self.params, **kwargs))
    167     except:

<ipython-input-24-f522714f3451> in loss_function(params)
      5     def loss_function(params):
----> 6         logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
      7         loss = state.loss_function(logits, targets)

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, input_ids, attention_mask, token_type_ids, position_ids, params, dropout_rng, train, output_attentions, output_hidden_states, return_dict)
   1420             return_dict,
-> 1421             rngs=rngs,
   1422         )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
    938         mutable=mutable, capture_intermediates=capture_intermediates
--> 939     )(variables, *args, **kwargs, rngs=rngs)
    940 

/usr/local/lib/python3.7/dist-packages/flax/core/scope.py in wrapper(variables, rngs, *args, **kwargs)
    686     with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 687       y = fn(root, *args, **kwargs)
    688     if mutable is not False:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in scope_fn(scope, *args, **kwargs)
   1177     try:
-> 1178       return fn(module.clone(parent=scope), *args, **kwargs)
   1179     finally:

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic, output_attentions, output_hidden_states, return_dict)
   1697             output_hidden_states=output_hidden_states,
-> 1698             return_dict=return_dict,
   1699         )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic, output_attentions, output_hidden_states, return_dict)
   1458             output_hidden_states=output_hidden_states,
-> 1459             return_dict=return_dict,
   1460         )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions, output_hidden_states, return_dict)
   1265             output_hidden_states=output_hidden_states,
-> 1266             return_dict=return_dict,
   1267         )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions, output_hidden_states, return_dict)
   1221             layer_outputs = layer(
-> 1222                 hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
   1223             )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions)
   1179         attention_outputs = self.attention(
-> 1180             hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
   1181         )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions)
   1113         attn_outputs = self.self(
-> 1114             hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
   1115         )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions)
    360             plan_num_rand_blocks=None,
--> 361             output_attentions=output_attentions,
    362         )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in bigbird_block_sparse_attention(self, query_layer, key_layer, value_layer, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, n_heads, head_size, plan_from_length, plan_num_rand_blocks, output_attentions)
    476                 plan_from_length=plan_from_length,
--> 477                 plan_num_rand_blocks=plan_num_rand_blocks,
    478             )

/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
    274     try:
--> 275       y = fun(self, *args, **kwargs)
    276       if _context.capture_stack:

/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in _bigbird_block_rand_mask_with_head(self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_heads, plan_from_length, plan_num_rand_blocks, window_block_left, window_block_right, global_block_top, global_block_bottom, global_block_left, global_block_right)
   1004                         global_block_left=global_block_left,
-> 1005                         global_block_right=global_block_right,
   1006                     )

UnfilteredStackTrace: ValueError: could not broadcast input array from shape (0) into shape (3)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

Who can help

@patrickvonplaten @vasudevgupta7 @patil-suraj

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 11 (10 by maintainers)

Top GitHub Comments

3 reactions
Jeevesh8 commented, Jul 3, 2021

I think it is probably not valid to use block_sparse attention with num_random_blocks=3 and block_size=64 for a sequence length of just 128 (which is the length used for cola).

The notebook works fine after changing attention_type to original_full, which is the right choice for such a short sequence length anyway. Sorry, my bad.
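
For reference, here is a minimal sketch of that change when loading the model; passing the config overrides through from_pretrained is an assumption about how the notebook sets things up.

from transformers import FlaxBigBirdForSequenceClassification

# Use full attention for short sequences such as CoLA's 128 tokens;
# block_sparse with block_size=64 and num_random_blocks=3 expects longer inputs.
model = FlaxBigBirdForSequenceClassification.from_pretrained(
    "google/bigbird-roberta-base",
    attention_type="original_full",   # the default is "block_sparse"
    num_labels=2,                     # CoLA is a binary acceptability task
)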

A better error could probably be raised, though, when the model is run with an invalid combination of a short sequence length and the default num_random_blocks and block_size.

Shall I open a PR for the notebook and/or the better error message, in case you don’t have time for one?

1 reaction
Jeevesh8 commented, Jul 8, 2021

Or perhaps use dynamic padding, batching examples of similar sequence length together, as sketched below.
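
A rough sketch of that idea, assuming a transformers version where DataCollatorWithPadding accepts return_tensors="np"; the length_sorted_batches helper and the encodings layout (a dict of unpadded token lists) are illustrative, not part of the notebook.

import numpy as np
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
collator = DataCollatorWithPadding(tokenizer, return_tensors="np")

def length_sorted_batches(encodings, batch_size):
    # Sort examples by length so each batch is padded only to its longest member.
    order = np.argsort([len(ids) for ids in encodings["input_ids"]])
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        features = [{k: encodings[k][i] for k in encodings} for i in idx]
        yield collator(features)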
