xmap regression
See original GitHub issue.
Possibly related to #6959, but there seem to be additional regressions in xmap unrelated to the multi-host case. A simplified example is provided in the colab below: https://colab.research.google.com/drive/18nU5rz8CF7YDYPSuPqJH38Dc7HkroYOn?usp=sharing
The example works with jax == 0.2.12, but not with the latest version.
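A minimal sketch of the failing pattern (hypothetical function and shapes; the actual colab code differs): the function's result stays mapped along both named axes, but out_axes only declares one of them.

```python
import jax.numpy as jnp
from jax.experimental.maps import xmap

def f(params, x):
    # Named axes broadcast: the product carries both 'batch' and 'shard'.
    return params * x

params = jnp.ones((4, 8))   # axis 0 -> 'shard', axis 1 positional
x = jnp.ones((2, 4, 8))     # axis 0 -> 'batch', axis 1 -> 'shard'

run_xmap = xmap(f,
                in_axes=(['shard', ...], ['batch', 'shard', ...]),
                out_axes=['shard', ...])  # 'batch' is not declared

run_xmap(params, x)  # raises the TypeError in the stack trace below on recent jax
```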
Full stack trace:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-2-2f2533018f84> in <module>()
65 axis_resources={'shard': 'mp', 'batch': 'dp'})
66
---> 67 run_xmap(params, x)
8 frames
/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in fun_mapped(*args)
577 backend=backend,
578 spmd_in_axes=None,
--> 579 spmd_out_axes_thunk=None)
580 if has_output_rank_assertions:
581 for out, spec in zip(out_flat, out_axes_thunk()):
/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in bind(self, fun, *args, **params)
764 def bind(self, fun, *args, **params):
765 assert len(params['in_axes']) == len(args)
--> 766 return core.call_bind(self, fun, *args, **params) # type: ignore
767
768 def process(self, trace, fun, tracers, params):
/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1549 params_tuple, out_axes_transforms)
1550 tracers = map(top_trace.full_raise, args)
-> 1551 outs = primitive.process(top_trace, fun, tracers, params)
1552 return map(full_lower, apply_todos(env_trace_todo(), outs))
1553
/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in process(self, trace, fun, tracers, params)
767
768 def process(self, trace, fun, tracers, params):
--> 769 return trace.process_xmap(self, fun, tracers, params)
770
771 def post_process(self, trace, out_tracers, params):
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
604
605 def process_call(self, primitive, f, tracers, params):
--> 606 return primitive.impl(f, *tracers, **params)
607 process_map = process_call
608
/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in xmap_impl(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, *args)
596 axis_resources, resource_env, backend,
597 spmd_in_axes, spmd_out_axes_thunk,
--> 598 *in_avals)
599 distributed_debug_log(("Running xmapped function", name),
600 ("python function", fun.f),
/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
260 fun.populate_stores(stores)
261 else:
--> 262 ans = call(fun, *args)
263 cache[key] = (ans, fun.stores)
264
/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, *in_avals)
619 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
620 out_axes = out_axes_thunk()
--> 621 _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
622 # NOTE: We don't use avals and all params, so only pass in the relevant parts (too lazy...)
623 _resource_typing_xmap([], dict(axis_resources=axis_resources,
/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
1401 if undeclared_axes:
1402 undeclared_axes_str = sorted([str(axis) for axis in undeclared_axes])
-> 1403 raise TypeError(f"One of xmap results has an out_axes specification of "
1404 f"{axes.user_repr}, but is actually mapped along more axes "
1405 f"defined by this xmap call: {', '.join(undeclared_axes_str)}")
TypeError: One of xmap results has an out_axes specification of ['shard', ...], but is actually mapped along more axes defined by this xmap call: batch
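In other words, newer jax versions enforce that every named axis the result is still mapped along must appear in its out_axes entry. Continuing the hypothetical sketch above, two ways to satisfy the check: declare the missing axis, or reduce over it inside the function so the result is no longer mapped along it.

```python
# Option 1: declare every named axis the result is still mapped along.
run_xmap = xmap(f,
                in_axes=(['shard', ...], ['batch', 'shard', ...]),
                out_axes=['batch', 'shard', ...])

# Option 2: reduce over 'batch' inside the function (here with a
# named-axis mean) so the result is no longer mapped along it.
from jax import lax

def f_mean(params, x):
    return lax.pmean(params * x, 'batch')

run_xmap_mean = xmap(f_mean,
                     in_axes=(['shard', ...], ['batch', 'shard', ...]),
                     out_axes=['shard', ...])
```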
Issue Analytics
- Created: 2 years ago
- Reactions: 1
- Comments: 19 (8 by maintainers)
Top GitHub Comments
Hi! Wonderful comment, and you’re not alone. I’m sending this mostly as a reminder to myself not to let you fall through the cracks – it’s 11pm right now, so I won’t be able to get to this until some other time. But if it goes too long without you getting a reply, please ping me or DM me on twitter (https://twitter.com/theshawwn).
The JAX team really is awesome, and I suspect they’ll be along with an answer, but this darn issue is a tricky one. Basically, it was a situation where there was a breaking API change (yuck!) for good reason (oh…) and it happened to break the only popular thing using xmap (GPT-J).
I actually complained a lot about the breaking change, but was dragged kicking and screaming to the conclusion that it was a good thing. I hate when developers take the attitude of “correctness matters more than dev experience,” but in this case the correctness was sort of important.
Anyway, I’ll try to make sure someone comes along with a detailed answer about what to do next. Maybe it’ll be me.
On Fri, Sep 17, 2021 at 8:02 PM Matthew Piziak wrote:
I am now running into the same issue. I would like to fine-tune GPT-J using mesh-transformer-jax on Google Cloud TPU. It does not work, as mesh-transformer-jax requires jax 0.2.12, which no longer works on Google Cloud TPU.
Does anybody have a solution?