Text Classification on GLUE on TPU using Jax/Flax: BigBird
The notebook for Text Classification on GLUE tasks, provided here, could do with some updates:
- It is missing `from flax import traverse_util`, which can be added.
- Replace `gradient_transformation` with `adamw(1e-7)` in `TrainState.create()`.
- Probably add a note to `pip install sentencepiece`, since RoBERTa and similar models available in Flax at this point use it.

A minimal sketch of these fixes is shown below.
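For reference, a rough sketch of what the fixed cells might look like, assuming the notebook's optax-based setup (the learning rate is a placeholder, and the `1e-7` value is assumed to be the weight decay):

```python
# Sketch only: mirrors the fixes suggested above, not the notebook's exact code.
# (In Colab, `!pip install sentencepiece` is also needed for the tokenizer.)
import optax
from flax import traverse_util  # the missing import (used later for the weight-decay mask)
from flax.training.train_state import TrainState
from transformers import FlaxAutoModelForSequenceClassification

model = FlaxAutoModelForSequenceClassification.from_pretrained(
    "google/bigbird-roberta-base", num_labels=2
)

# Replace the undefined `gradient_transformation` with an AdamW optimizer.
optimizer = optax.adamw(learning_rate=2e-5, weight_decay=1e-7)

state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optimizer,
)
```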
Also, I couldn’t get the notebook to run for `google/bigbird-roberta-base` with `batch_size=1` and the default task (`cola`).
I got the following error:
Shapes of inputs: (8, 1, 128) (8, 1, 128) (8, 1)
---------------------------------------------------------------------------
UnfilteredStackTrace Traceback (most recent call last)
<ipython-input-32-cb9e92a30675> in <module>()
7 print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['labels'].shape)
----> 8 state, train_metrics, dropout_rngs = parallel_train_step(state, batch, dropout_rngs)
9 progress_bar_train.update(1)
67 frames
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
142 try:
--> 143 return fun(*args, **kwargs)
144 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in f_pmapped(*args, **kwargs)
1658 name=flat_fun.__name__, donated_invars=tuple(donated_invars),
-> 1659 global_arg_shapes=tuple(global_arg_shapes_flat))
1660 return tree_unflatten(out_tree(), out)
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, fun, *args, **params)
1623 assert len(params['in_axes']) == len(args)
-> 1624 return call_bind(self, fun, *args, **params)
1625
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1555 with maybe_new_sublevel(top_trace):
-> 1556 outs = primitive.process(top_trace, fun, tracers, params)
1557 return map(full_lower, apply_todos(env_trace_todo(), outs))
/usr/local/lib/python3.7/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
1626 def process(self, trace, fun, tracers, params):
-> 1627 return trace.process_map(self, fun, tracers, params)
1628
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
608 def process_call(self, primitive, f, tracers, params):
--> 609 return primitive.impl(f, *tracers, **params)
610 process_map = process_call
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in xla_pmap_impl(fun, backend, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *args)
622 donated_invars, global_arg_shapes,
--> 623 *abstract_args)
624 # Don't re-abstractify args unless logging is enabled for performance.
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
261 else:
--> 262 ans = call(fun, *args)
263 cache[key] = (ans, fun.stores)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/pxla.py in parallel_callable(fun, backend_name, axis_name, axis_size, global_axis_size, devices, name, in_axes, out_axes_thunk, donated_invars, global_arg_shapes, *avals)
698 with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
--> 699 jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap")
700 jaxpr = xla.apply_outfeed_rewriter(jaxpr)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals, transform_name)
1208 main.jaxpr_stack = () # type: ignore
-> 1209 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1210 del fun, main
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1187 in_tracers = map(trace.new_arg, in_avals)
-> 1188 ans = fun.call_wrapped(*in_tracers)
1189 out_tracers = map(trace.full_raise, ans)
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
<ipython-input-24-f522714f3451> in train_step(state, batch, dropout_rng)
10 grad_function = jax.value_and_grad(loss_function)
---> 11 loss, grad = grad_function(state.params)
12 grad = jax.lax.pmean(grad, "batch")
/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
142 try:
--> 143 return fun(*args, **kwargs)
144 except Exception as e:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in value_and_grad_f(*args, **kwargs)
886 if not has_aux:
--> 887 ans, vjp_py = _vjp(f_partial, *dyn_args)
888 else:
/usr/local/lib/python3.7/dist-packages/jax/_src/api.py in _vjp(fun, has_aux, *primals)
1965 flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1966 out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
1967 out_tree = out_tree()
/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
113 if not has_aux:
--> 114 out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
115 else:
/usr/local/lib/python3.7/dist-packages/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
100 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 101 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
102 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, instantiate)
497 fun = trace_to_subjaxpr(fun, main, instantiate)
--> 498 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
499 assert not env
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
165 try:
--> 166 ans = self.f(*args, **dict(self.params, **kwargs))
167 except:
<ipython-input-24-f522714f3451> in loss_function(params)
5 def loss_function(params):
----> 6 logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
7 loss = state.loss_function(logits, targets)
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, input_ids, attention_mask, token_type_ids, position_ids, params, dropout_rng, train, output_attentions, output_hidden_states, return_dict)
1420 return_dict,
-> 1421 rngs=rngs,
1422 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in apply(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)
938 mutable=mutable, capture_intermediates=capture_intermediates
--> 939 )(variables, *args, **kwargs, rngs=rngs)
940
/usr/local/lib/python3.7/dist-packages/flax/core/scope.py in wrapper(variables, rngs, *args, **kwargs)
686 with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
--> 687 y = fn(root, *args, **kwargs)
688 if mutable is not False:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in scope_fn(scope, *args, **kwargs)
1177 try:
-> 1178 return fn(module.clone(parent=scope), *args, **kwargs)
1179 finally:
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic, output_attentions, output_hidden_states, return_dict)
1697 output_hidden_states=output_hidden_states,
-> 1698 return_dict=return_dict,
1699 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, input_ids, attention_mask, token_type_ids, position_ids, deterministic, output_attentions, output_hidden_states, return_dict)
1458 output_hidden_states=output_hidden_states,
-> 1459 return_dict=return_dict,
1460 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions, output_hidden_states, return_dict)
1265 output_hidden_states=output_hidden_states,
-> 1266 return_dict=return_dict,
1267 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions, output_hidden_states, return_dict)
1221 layer_outputs = layer(
-> 1222 hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
1223 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions)
1179 attention_outputs = self.attention(
-> 1180 hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
1181 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions)
1113 attn_outputs = self.self(
-> 1114 hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions
1115 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in __call__(self, hidden_states, attention_mask, deterministic, output_attentions)
360 plan_num_rand_blocks=None,
--> 361 output_attentions=output_attentions,
362 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in bigbird_block_sparse_attention(self, query_layer, key_layer, value_layer, band_mask, from_mask, to_mask, from_blocked_mask, to_blocked_mask, n_heads, head_size, plan_from_length, plan_num_rand_blocks, output_attentions)
476 plan_from_length=plan_from_length,
--> 477 plan_num_rand_blocks=plan_num_rand_blocks,
478 )
/usr/local/lib/python3.7/dist-packages/flax/linen/module.py in wrapped_module_method(*args, **kwargs)
274 try:
--> 275 y = fun(self, *args, **kwargs)
276 if _context.capture_stack:
/usr/local/lib/python3.7/dist-packages/transformers/models/big_bird/modeling_flax_big_bird.py in _bigbird_block_rand_mask_with_head(self, from_seq_length, to_seq_length, from_block_size, to_block_size, num_heads, plan_from_length, plan_num_rand_blocks, window_block_left, window_block_right, global_block_top, global_block_bottom, global_block_left, global_block_right)
1004 global_block_left=global_block_left,
-> 1005 global_block_right=global_block_right,
1006 )
UnfilteredStackTrace: ValueError: could not broadcast input array from shape (0) into shape (3)
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
Top GitHub Comments
I think it is probably not valid to use `block_sparse` attention with `num_random_blocks=3` and `block_size=64` at a sequence length of just `128` (which is the length used for `cola`). The notebook works fine after changing the `attention_type` to `original_full`, as it should be for such a short sequence length. Sorry, my bad. A minimal sketch of that change is shown below.

It could probably raise a better error, though, when the model is run with an invalid combination of a small sequence length and the default `num_random_blocks` and `block_size`.

Shall I open a PR for the notebook or for the better error message, in case you don’t have time for one?
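For example, a minimal sketch of that change, assuming the model is loaded with `from_pretrained` (the exact class used in the notebook may differ):

```python
# Sketch only: load BigBird with full attention for short (128-token) sequences.
from transformers import FlaxBigBirdForSequenceClassification

model = FlaxBigBirdForSequenceClassification.from_pretrained(
    "google/bigbird-roberta-base",
    attention_type="original_full",  # block_sparse is not valid at such short lengths
    num_labels=2,
)
```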
Or perhaps use dynamic padding with length-grouped batching, so that elements with similar sequence lengths end up in the same batch; a rough sketch of that idea follows.
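A sketch of length-grouped batching, assuming a 🤗 Datasets-style dataset with an `input_ids` column (the column name and the batching loop are assumptions, not the notebook's code):

```python
# Sketch only: group examples of similar length so per-batch padding stays small.
import numpy as np

def length_grouped_batches(dataset, batch_size):
    # Sort example indices by tokenized length, then yield consecutive slices.
    lengths = [len(ids) for ids in dataset["input_ids"]]
    order = np.argsort(lengths)
    for start in range(0, len(order), batch_size):
        yield dataset.select(order[start:start + batch_size].tolist())
```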