
[Feature Request] Support NNC/torchinductor usage in gpytorch + GlobalNUTS integration


Issue Description

We have a prototype integration of GlobalNUTS with gpytorch that I would like to accelerate with NNC.
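
The prototype notebook itself is not attached to this issue; the sketch below only shows the general shape of the call, assuming Bean Machine's `GlobalNoUTurnSampler` with `nnc_compile=True` and a toy model standing in for the actual gpytorch-based one (the argument values are taken from the traceback further down).

```python
import torch
import torch.distributions as dist
import beanmachine.ppl as bm

# Toy model standing in for the gpytorch-based prototype (not shown in the issue).
@bm.random_variable
def theta():
    return dist.Normal(0.0, 1.0)

@bm.random_variable
def y():
    return dist.Normal(theta(), 0.5)

samples = bm.GlobalNoUTurnSampler(nnc_compile=True).infer(
    queries=[theta()],
    observations={y(): torch.tensor(1.2)},
    num_samples=10,
    num_adaptive_samples=10,
    num_chains=1,
)
```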

Steps to Reproduce

Execute the cell with `nnc_compile=True`. It raises the following error:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-136-2558f70fe3a8> in <module>
      7     num_samples=10,
      8     num_adaptive_samples=10,
----> 9     num_chains=1,
     10 )

9 frames
/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/base_inference.py in infer(self, queries, observations, num_samples, num_chains, num_adaptive_samples, show_progress_bar, initialize_fn, max_init_retries, run_in_parallel, mp_context, verbose)
    210                 chain_results = p.starmap(single_chain_infer, enumerate(seeds))
    211 
--> 212         all_samples, all_log_liklihoods = zip(*chain_results)
    213         # the hash of RVIdentifier can change when it is being sent to another process,
    214         # so we have to rely on the order of the returned list to determine which samples

/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/base_inference.py in _single_chain_infer(self, queries, observations, num_samples, num_adaptive_samples, show_progress_bar, initialize_fn, max_init_retries, chain_id, seed)
    109             desc="Samples collected",
    110             disable=not show_progress_bar,
--> 111             position=chain_id,
    112         ):
    113             for idx, obs in enumerate(observations):

/usr/local/lib/python3.7/dist-packages/tqdm/notebook.py in __iter__(self)
    257         try:
    258             it = super(tqdm_notebook, self).__iter__()
--> 259             for obj in it:
    260                 # return super(tqdm...) will not catch exception
    261                 yield obj

/usr/local/lib/python3.7/dist-packages/tqdm/std.py in __iter__(self)
   1193 
   1194         try:
-> 1195             for obj in iterable:
   1196                 yield obj
   1197                 # Update and possibly print the progressbar.

/usr/lib/python3.7/_collections_abc.py in __next__(self)
    315         When exhausted, raise StopIteration.
    316         """
--> 317         return self.send(None)
    318 
    319     @abstractmethod

/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/sampler.py in send(self, world)
     73         for proposer in proposers:
     74             try:
---> 75                 new_world, accept_log_prob = proposer.propose(world)
     76                 accept_log_prob = accept_log_prob.clamp(max=0.0)
     77                 accepted = torch.rand_like(accept_log_prob).log() < accept_log_prob

/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/proposer/nuts_proposer.py in propose(self, world)
    307             )
    308             if direction == -1:
--> 309                 new_tree = self._build_tree(tree.left, j, tree_args)
    310             else:
    311                 new_tree = self._build_tree(tree.right, j, tree_args)

/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/proposer/nuts_proposer.py in _build_tree(self, root, tree_depth, args)
    157         combine the two."""
    158         if tree_depth == 0:
--> 159             return self._build_tree_base_case(root, args)
    160 
    161         # build the first half of the tree

/usr/local/lib/python3.7/dist-packages/functorch/_src/aot_autograd.py in returned_function(*args, **kwargs)
    426         # Now flatten the tensor args
    427         if HAS_TREE:
--> 428             flat_tensor_args = tree.flatten((tensor_args, kwargs))
    429         else:
    430             flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs))

/usr/local/lib/python3.7/dist-packages/tree/__init__.py in flatten(structure)
    225     TypeError: If `structure` is or contains a mapping with non-sortable keys.
    226   """
--> 227   return _tree.flatten(structure)
    228 
    229 


TypeError: '<' not supported between instances of 'RVIdentifier' and 'RVIdentifier'
```
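
The failure can be reproduced outside Bean Machine: dm-tree (the `tree` module in the last frame) flattens mappings in sorted-key order, so any key type that is hashable but not orderable hits the same TypeError. A minimal standalone sketch, with a hypothetical `Key` class standing in for `RVIdentifier`:

```python
from dataclasses import dataclass

import tree  # dm-tree, the optional fast path used by functorch's aot_autograd


@dataclass(frozen=True)
class Key:
    """Hashable but not orderable, like RVIdentifier."""
    name: str


# dm-tree sorts mapping keys while flattening, so this raises:
# TypeError: '<' not supported between instances of 'Key' and 'Key'
tree.flatten({Key("a"): 1, Key("b"): 2})
```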

Expected Behavior

A faster, NNC-accelerated run of the model.

System Info

Please provide information about your setup

  • PyTorch version (run print(torch.__version__)): 1.12.1+cu113
  • Python version: 3.9
  • Beanmachine version: 0.2.0

Additional Context

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
feynmanliang commented, Sep 13, 2022

Yes, this works 😃 (screenshots omitted)

and it appears to be significantly faster than Pyro for the same model / results (screenshots omitted)

CC @eytan @Balandat @dme65, and I will evaluate and document the comparison in an internal doc as well.

Do we think it’ll be OK to override ordering on RVIdentifier until functorch cuts a new release? If so, we can close this issue.

0 reactions
horizon-blue commented, Sep 13, 2022

Glad to see the fast performance of BM + NNC 😄.

> Do we think it’ll be OK to override ordering on RVIdentifier until functorch cuts a new release?

Sure, I’m working on D39486353, which adds `__lt__` to `RVIdentifier` for now, and will create a PR once the tests pass. You might be able to use BM+NNC even without this patch internally, though, since Buck always uses the latest source code. 😃
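
For reference, a stop-gap along those lines could look like the sketch below. This is a hypothetical monkey-patch, not the actual D39486353 diff; the only requirement is that `RVIdentifier` gets some consistent total order so dm-tree can sort dict keys. The import path is the one used in Bean Machine 0.2.x and may differ in other versions.

```python
from beanmachine.ppl.model.rv_identifier import RVIdentifier


def _rv_lt(self, other):
    # Any consistent total order works; dm-tree only needs the keys to be sortable.
    return str(self) < str(other)


# Apply before calling infer(), so keys are orderable when functorch flattens them.
RVIdentifier.__lt__ = _rv_lt
```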
