optionally attach user source line info to tracers / jaxpr eqns
We’d like to be able to associate the code of transformed functions, and in particular the XLA HLO generated by jit, with the user source code that generated it. We can do that by using the inspect module to get the line info, and then attaching that information to Tracer instances and/or to the jaxpr eqns generated (for the specific case of partial_eval.py).
This is another instance of reviving a feature from an older version of JAX. Here’s a snippet we had to get the user line info, though stuffing it into the OpMetadata instance was something specific to plumbing it through to XLA HLO (whereas now we might want to attach it to JAX’s data structures first). We’d pass skip_files as a list of JAX source files.
There might be some overhead to inspecting stack frames, so we could hide this behind a flag.
@j-towns want to take a look at this?
import inspect
import os

# Note: xla_bridge is a JAX-internal module; the surrounding code must import it.

def _user_source_info(skip_files=()):
  """Retrieves the user source information as an OpMetadata instance.

  Args:
    skip_files: iterable of filenames to skip over, moving further up the stack.

  Returns:
    XLA OpMetadata representing the source information for the caller stack
    frame, or None if no "user" source code location can be found.
  """
  frame = inspect.currentframe()
  while frame:
    filename = os.path.basename(frame.f_code.co_filename)
    if filename not in skip_files:
      # First frame outside the skipped (JAX) files: treat it as user code.
      lineno = frame.f_lineno
      return xla_bridge.get_xla_client().OpMetadata(
          op_type=None, op_name=None, source_file=filename, source_line=lineno)
    else:
      # Still inside a skipped file; move one frame up the call stack.
      frame = frame.f_back
  return None
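For experimenting with the idea outside of XLA, here is a minimal self-contained sketch of the same stack walk that returns a plain (filename, lineno) pair instead of an OpMetadata. The names SourceInfo, user_source_info, and ENABLE_SOURCE_INFO are hypothetical and not part of JAX; the module-level flag only stands in for the opt-in switch suggested above.

```python
import inspect
import os
from typing import NamedTuple, Optional

# Hypothetical opt-in switch (not an actual JAX flag): frame inspection
# only happens when this is True.
ENABLE_SOURCE_INFO = True


class SourceInfo(NamedTuple):
    """Hypothetical container that could be attached to a tracer or eqn."""
    filename: str
    lineno: int


def user_source_info(skip_files=()) -> Optional[SourceInfo]:
    """Returns the first caller frame whose file is not in skip_files."""
    if not ENABLE_SOURCE_INFO:
        return None
    skip = frozenset(skip_files)
    current = inspect.currentframe()  # may be None on some interpreters
    # Start at the caller so this helper's own frame is never reported.
    frame = current.f_back if current is not None else None
    while frame is not None:
        filename = os.path.basename(frame.f_code.co_filename)
        if filename not in skip:
            return SourceInfo(filename, frame.f_lineno)
        frame = frame.f_back
    return None


def library_op():
    # Stand-in for a library function that records where it was invoked.
    return user_source_info()


info_direct = user_source_info()
info_nested = library_op()
```

Here info_nested points inside library_op because nothing is skipped; passing the library's own file names as skip_files is what would move the result up to the user's call site.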
Issue Analytics
- Created 4 years ago
- Comments: 5 (5 by maintainers)
This is related to this old Autograd issue. One thing that the suggestion above is lacking is the ability to inspect the frame as you might from a Python debugger (something I find extremely useful).
I’m wondering if it might be possible to reach an equivalent frame from the debugger by re-running the function and raising an exception when the frame corresponding to a problematic jaxpr is reached… Sorry this is a little vague, just putting down thoughts.
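One way to make the re-run idea concrete, in plain Python rather than JAX, is sketched below. It assumes a (filename, lineno) pair has already been recorded for the problematic operation, re-runs the function under sys.settrace, and raises as soon as that line is about to execute, so the frame can be inspected post mortem. ReachedTarget, run_until_line, and locals_of are hypothetical names used for illustration only.

```python
import sys


class ReachedTarget(Exception):
    """Raised when execution reaches the requested source line."""


def run_until_line(fn, filename, lineno, *args, **kwargs):
    """Calls fn and raises ReachedTarget just before (filename, lineno) runs."""
    def tracer(frame, event, arg):
        if frame.f_code.co_filename != filename:
            return None  # no line-level tracing for frames in other files
        if event == "line" and frame.f_lineno == lineno:
            raise ReachedTarget(f"{filename}:{lineno}")
        return tracer

    previous = sys.gettrace()
    sys.settrace(tracer)
    try:
        return fn(*args, **kwargs)
    finally:
        sys.settrace(previous)  # restore whatever tracer was installed before


def locals_of(exc, code):
    """Finds the frame running `code` in exc's traceback and copies its locals."""
    tb = exc.__traceback__
    while tb is not None:
        if tb.tb_frame.f_code is code:
            return dict(tb.tb_frame.f_locals)
        tb = tb.tb_next
    return None


def example(x):
    y = x + 1
    z = y * 2  # target line: two lines below the def
    return z


target_line = example.__code__.co_firstlineno + 2
captured = None
try:
    run_until_line(example, example.__code__.co_filename, target_line, 3)
except ReachedTarget as exc:
    captured = locals_of(exc, example.__code__)
```

Because the exception is raised before the target line executes, captured holds x and y but not z. In an interactive session, pdb.post_mortem on the exception's traceback would give debugger access to the same frame.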
Thanks for your comments here, @j-towns.
Unfortunately we explored this a bit and found that retrieving source info was too slow in Python (requiring a filesystem access for each line!). I think we should just close this issue.
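For anyone wanting to measure this kind of overhead themselves, a rough timing sketch follows. It compares inspect.stack(), which collects source context for every frame on the stack, with reading f_lineno directly from a frame object. This is an illustration only and does not reproduce whatever measurement was done for JAX; the function names are hypothetical and the numbers will vary by machine and stack depth.

```python
import inspect
import sys
import timeit


def lineno_via_inspect_stack():
    # inspect.stack() builds a FrameInfo, including source context lines,
    # for every frame currently on the stack.
    return inspect.stack()[1].lineno


def lineno_via_frame():
    # Reads an attribute already stored on the frame object.
    # sys._getframe is a CPython implementation detail.
    return sys._getframe(1).f_lineno


def both():
    # Both helpers are called from the same line, so they should agree.
    return lineno_via_inspect_stack(), lineno_via_frame()


timings = {
    "inspect.stack": timeit.timeit(lineno_via_inspect_stack, number=200),
    "frame.f_lineno": timeit.timeit(lineno_via_frame, number=200),
}
for name, seconds in timings.items():
    print(f"{name}: {seconds:.6f}s for 200 calls")
```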