Possibly related to #6959, but there seem to be additional regressions in xmap unrelated to the multi-host case. A simplified example is provided in this colab: https://colab.research.google.com/drive/18nU5rz8CF7YDYPSuPqJH38Dc7HkroYOn?usp=sharing

The example works with jax == 0.2.12, but not with the latest version.

Full stack trace:

---------------------------------------------------------------------------

TypeError                                 Traceback (most recent call last)

<ipython-input-2-2f2533018f84> in <module>()
     65                                             axis_resources={'shard': 'mp', 'batch': 'dp'})
     66 
---> 67     run_xmap(params, x)

8 frames

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in fun_mapped(*args)
    577       backend=backend,
    578       spmd_in_axes=None,
--> 579       spmd_out_axes_thunk=None)
    580     if has_output_rank_assertions:
    581       for out, spec in zip(out_flat, out_axes_thunk()):

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in bind(self, fun, *args, **params)
    764   def bind(self, fun, *args, **params):
    765     assert len(params['in_axes']) == len(args)
--> 766     return core.call_bind(self, fun, *args, **params)  # type: ignore
    767 
    768   def process(self, trace, fun, tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1549       params_tuple, out_axes_transforms)
   1550   tracers = map(top_trace.full_raise, args)
-> 1551   outs = primitive.process(top_trace, fun, tracers, params)
   1552   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1553 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in process(self, trace, fun, tracers, params)
    767 
    768   def process(self, trace, fun, tracers, params):
--> 769     return trace.process_xmap(self, fun, tracers, params)
    770 
    771   def post_process(self, trace, out_tracers, params):

/usr/local/lib/python3.7/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    604 
    605   def process_call(self, primitive, f, tracers, params):
--> 606     return primitive.impl(f, *tracers, **params)
    607   process_map = process_call
    608 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in xmap_impl(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, *args)
    596       axis_resources, resource_env, backend,
    597       spmd_in_axes, spmd_out_axes_thunk,
--> 598       *in_avals)
    599   distributed_debug_log(("Running xmapped function", name),
    600                         ("python function", fun.f),

/usr/local/lib/python3.7/dist-packages/jax/linear_util.py in memoized_fun(fun, *args)
    260       fun.populate_stores(stores)
    261     else:
--> 262       ans = call(fun, *args)
    263       cache[key] = (ans, fun.stores)
    264 

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in make_xmap_callable(fun, name, in_axes, out_axes_thunk, donated_invars, global_axis_sizes, axis_resources, resource_env, backend, spmd_in_axes, spmd_out_axes_thunk, *in_avals)
    619     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
    620   out_axes = out_axes_thunk()
--> 621   _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
    622   # NOTE: We don't use avals and all params, so only pass in the relevant parts (too lazy...)
    623   _resource_typing_xmap([], dict(axis_resources=axis_resources,

/usr/local/lib/python3.7/dist-packages/jax/experimental/maps.py in _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
   1401     if undeclared_axes:
   1402       undeclared_axes_str = sorted([str(axis) for axis in undeclared_axes])
-> 1403       raise TypeError(f"One of xmap results has an out_axes specification of "
   1404                       f"{axes.user_repr}, but is actually mapped along more axes "
   1405                       f"defined by this xmap call: {', '.join(undeclared_axes_str)}")

TypeError: One of xmap results has an out_axes specification of ['shard', ...], but is actually mapped along more axes defined by this xmap call: batch
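
In other words, the function’s result is still mapped along the 'batch' axis, but out_axes only declares 'shard'. Here is a minimal hypothetical sketch (not the code from the colab; axis names and shapes are made up for illustration) that triggers the same TypeError on jax >= 0.2.13:

import jax.numpy as jnp
from jax.experimental.maps import xmap

def f(w, x):
    # jnp.sum only reduces the positional axes, so the result stays
    # mapped along every named axis of its inputs: 'shard' and 'batch'.
    return jnp.sum(w * x)

run = xmap(f,
           in_axes=(['shard', ...], ['batch', 'shard', ...]),
           out_axes=['shard', ...])  # 'batch' is not declared here

w = jnp.ones((4, 2))     # axis 0 is named 'shard'
x = jnp.ones((8, 4, 2))  # axis 0 is 'batch', axis 1 is 'shard'
run(w, x)                # raises the TypeError above on jax >= 0.2.13

Per the discussion below, jax 0.2.12 silently accepted this and returned a piece of the sharded result; newer versions require out_axes to account for every named axis the result is still mapped along.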

Top GitHub Comments

shawwn commented, Sep 18, 2021

Hi! Wonderful comment, and you’re not alone. I’m sending this mostly as a reminder to myself not to let you fall through the cracks – it’s 11pm right now, so I won’t be able to get to this until some other time. But if it goes too long without you getting a reply, please ping me or DM me on twitter (https://twitter.com/theshawwn).

The JAX team really is awesome, and I suspect they’ll be along with an answer, but this darn issue is a tricky one. Basically, it was a situation where there was a breaking API change (yuck!) for good reason (oh…) and it happened to break the only popular thing using xmap (GPT-J).

I actually complained a lot about the breaking change, but was dragged kicking and screaming to the conclusion that it was a good thing. I hate when developers take the attitude of “correctness matters more than dev experience,” but in this case the correctness was sort of important.

Anyway, I’ll try to make sure someone comes along with a detailed answer about what to do next. Maybe it’ll be me.

On Fri, Sep 17, 2021 at 8:02 PM Matthew Piziak wrote:

Hej @apaszke (and others 👋),

I’m interested in experimenting with text-generation models, and to that end I’m trying to run GPT-J-6B inference on a TPU Research Cloud machine. I’m a bit out of my depth, but if you don’t mind, here are my observations regarding how this issue impacts the GPT-J-6B Inference Demo https://colab.research.google.com/github/kingoflolz/mesh-transformer-jax/blob/master/colab_demo.ipynb hosted on Colab.

Observations

First, I see a reference to a so-called “regression” with xmap in jax version 0.2.13.

jax 0.2.12 is required due to a regression with xmap in 0.2.13

!pip install mesh-transformer-jax/ jax==0.2.12 tensorflow==2.5.0

Second, I see a seeming incompatibility between the 0.2.12 version and the TPU environment:

TpuTransferManager_ReadDynamicShapes not available in this library. Aborted (core dumped)

That issue is also reported in #7334 https://github.com/google/jax/issues/7334.

Catch-22

Based on my reading of this issue, xmap has stopped letting one return a piece of a sharded tensor, and the old behavior should not be depended upon. That makes sense and indeed I wouldn’t expect any further mitigation from the jax maintenance team, since xmap warns the user that it is experimental. This is fully an issue of downstream usage, but nevertheless if you’d offer me any wisdom here I’d appreciate it.
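
For concreteness, here is a minimal sketch of what the two usual repairs look like, with hypothetical axis names rather than the actual mesh-transformer-jax code: either declare every named axis the result is still mapped along, or reduce over it explicitly.

import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap

def f(w, x):
    return jnp.sum(w * x)  # result still mapped along 'batch' and 'shard'

# Repair 1: declare the extra axis in out_axes and keep every piece.
run = xmap(f,
           in_axes=(['shard', ...], ['batch', 'shard', ...]),
           out_axes=['batch', 'shard', ...])

# Repair 2: reduce over the named axis inside the function, instead of
# relying on the old behavior of returning a single piece.
def g(w, x):
    return jax.lax.psum(jnp.sum(w * x), 'batch')

run2 = xmap(g,
            in_axes=(['shard', ...], ['batch', 'shard', ...]),
            out_axes=['shard', ...])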

The issue:

  • As far as I know GPT-J-6B is the most powerful public autoregressive text-generation weight set.
  • GPT-J-6B only works on model v1 of mesh-transformer-jax.
  • v1 only works with JAX 0.2.12.
  • Dependency set mesh-transformer-jax+jax==0.2.12+tensorflow==2.5.0 aborts (as above, tested in a TPU v3 environment).
  • When used with jax>=0.2.13, it fails as above.

In summary, the model I planned to use is semi-deprecated. This comment https://github.com/kingoflolz/mesh-transformer-jax/issues/67#issuecomment-886395688 says it’s not likely to be fixed by the mesh-transformer-jax team.

Gosh, I don’t know if this is a mesh-transformer-jax issue or not. I feel a bit embarrassed because there doesn’t seem to be a right place to post this. I suppose I’m really looking for big-picture advice; this is where it all began, and you have the big-picture context.

Ultimately, what would you recommend for an experimenter who does not have access to proprietary models?

Options:

  • Find and fix the xmap usage in model v1 of mesh-transformer-jax, restoring GPT-J-6B.
  • Use an alternative model.
  • Attempt to train a new model on v2 of mesh-transformer-jax.
  • Something else?

I appreciate your patience with my pre-JAX-familiarity notes. You have good documentation; I ought to read through it and I will do that shortly. Thanks and best wishes.

Eichhof commented, Nov 22, 2022

I’m now running into the same issue. I would like to fine-tune GPT-J using mesh-transformer-jax on Google Cloud TPU, but it does not work: mesh-transformer-jax requires jax 0.2.12, which no longer works on Google Cloud TPU.

Does anybody have a solution?
