[Feature Request] Support NNC/torchinductor usage in gpytorch + GlobalNUTS integration
Issue Description
We have a prototype integration of GlobalNUTS with gpytorch which I would like to accelerate with NNC.
Steps to Reproduce
Execute the cell with nnc_compile=True. It raises the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-136-2558f70fe3a8> in <module>
7 num_samples=10,
8 num_adaptive_samples=10,
----> 9 num_chains=1,
10 )
/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/base_inference.py in infer(self, queries, observations, num_samples, num_chains, num_adaptive_samples, show_progress_bar, initialize_fn, max_init_retries, run_in_parallel, mp_context, verbose)
210 chain_results = p.starmap(single_chain_infer, enumerate(seeds))
211
--> 212 all_samples, all_log_liklihoods = zip(*chain_results)
213 # the hash of RVIdentifier can change when it is being sent to another process,
214 # so we have to rely on the order of the returned list to determine which samples
/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/base_inference.py in _single_chain_infer(self, queries, observations, num_samples, num_adaptive_samples, show_progress_bar, initialize_fn, max_init_retries, chain_id, seed)
109 desc="Samples collected",
110 disable=not show_progress_bar,
--> 111 position=chain_id,
112 ):
113 for idx, obs in enumerate(observations):
/usr/local/lib/python3.7/dist-packages/tqdm/notebook.py in __iter__(self)
257 try:
258 it = super(tqdm_notebook, self).__iter__()
--> 259 for obj in it:
260 # return super(tqdm...) will not catch exception
261 yield obj
/usr/local/lib/python3.7/dist-packages/tqdm/std.py in __iter__(self)
1193
1194 try:
-> 1195 for obj in iterable:
1196 yield obj
1197 # Update and possibly print the progressbar.
/usr/lib/python3.7/_collections_abc.py in __next__(self)
315 When exhausted, raise StopIteration.
316 """
--> 317 return self.send(None)
318
319 @abstractmethod
/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/sampler.py in send(self, world)
73 for proposer in proposers:
74 try:
---> 75 new_world, accept_log_prob = proposer.propose(world)
76 accept_log_prob = accept_log_prob.clamp(max=0.0)
77 accepted = torch.rand_like(accept_log_prob).log() < accept_log_prob
/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/proposer/nuts_proposer.py in propose(self, world)
307 )
308 if direction == -1:
--> 309 new_tree = self._build_tree(tree.left, j, tree_args)
310 else:
311 new_tree = self._build_tree(tree.right, j, tree_args)
/usr/local/lib/python3.7/dist-packages/beanmachine/ppl/inference/proposer/nuts_proposer.py in _build_tree(self, root, tree_depth, args)
157 combine the two."""
158 if tree_depth == 0:
--> 159 return self._build_tree_base_case(root, args)
160
161 # build the first half of the tree
/usr/local/lib/python3.7/dist-packages/functorch/_src/aot_autograd.py in returned_function(*args, **kwargs)
426 # Now flatten the tensor args
427 if HAS_TREE:
--> 428 flat_tensor_args = tree.flatten((tensor_args, kwargs))
429 else:
430 flat_tensor_args, _ = pytree.tree_flatten((tensor_args, kwargs))
/usr/local/lib/python3.7/dist-packages/tree/__init__.py in flatten(structure)
225 TypeError: If `structure` is or contains a mapping with non-sortable keys.
226 """
--> 227 return _tree.flatten(structure)
228
229
TypeError: '<' not supported between instances of 'RVIdentifier' and 'RVIdentifier'
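The last frames suggest the cause: with dm-tree installed, functorch's aot_autograd flattens its arguments via tree.flatten, whose docstring says it raises TypeError when a mapping has non-sortable keys. RVIdentifier is hashable (it is used as a dict key) but, judging from the error, defines no '<'. The sketch below reproduces that failure with the standard library only. It assumes the compiled function receives a dict keyed by RVIdentifier objects; FakeRVIdentifier and flatten_sorted are hypothetical stand-ins, not Bean Machine or dm-tree code.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple


@dataclass(frozen=True)
class FakeRVIdentifier:
    """Hypothetical stand-in for RVIdentifier: hashable and comparable for
    equality, but with no ordering (dataclass order=False by default)."""
    name: str
    arguments: Tuple[Any, ...] = ()


def flatten_sorted(mapping: Dict[Any, Any]) -> List[Any]:
    """Hypothetical mimic of a flatten that visits values in sorted-key order,
    the behavior that makes dm-tree require sortable mapping keys."""
    return [mapping[key] for key in sorted(mapping)]


values = {FakeRVIdentifier("f", (0,)): 1.0, FakeRVIdentifier("f", (1,)): 2.0}

error_message: Optional[str] = None
try:
    flatten_sorted(values)  # sorted() needs '<' between the keys
except TypeError as err:
    error_message = str(err)

print(error_message)
```

Running it prints the same kind of message as the traceback, with FakeRVIdentifier in place of RVIdentifier.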
Expected Behavior
A faster run with NNC compilation enabled.
System Info
- PyTorch version (print(torch.__version__)): 1.12.1+cu113
- Python version: 3.9
- Beanmachine version: 0.2.0
Additional Context
Issue Analytics
- Created a year ago
- Comments: 5 (5 by maintainers)
Top GitHub Comments
Yes, this works 😃, and it appears to be significantly faster than Pyro for the same model and results.
CC @eytan @Balandat @dme65, and I will evaluate and document the comparison in an internal doc as well.
Do we think it’ll be OK to override ordering on RVIdentifier until functorch cuts a new release? If so, we can close this issue.
Glad to see the fast performance of BM + NNC 😄.
Sure, I’m working on D39486353, which adds __lt__ to RVIdentifier for now, and will create a PR once the tests pass. You might be able to use BM+NNC internally even without this patch, though, since Buck always uses the latest source code. 😃
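The workaround discussed in this thread is to give RVIdentifier an ordering by defining __lt__. The sketch below shows that idea on a hypothetical stand-in class; the sort key (name, then the string form of the arguments) is an assumption for illustration, since the comparison actually used in D39486353 is not shown here.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass(frozen=True)
class OrderableRVIdentifier:
    """Hypothetical stand-in for RVIdentifier with a hand-written __lt__.
    order=False (the default) lets us define __lt__ ourselves."""
    name: str
    arguments: Tuple[Any, ...] = ()

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, OrderableRVIdentifier):
            return NotImplemented
        # Assumed sort key: compare string forms so argument values of
        # different types never have to be mutually comparable.
        return (self.name, str(self.arguments)) < (other.name, str(other.arguments))


def flatten_sorted(mapping: Dict[Any, Any]) -> List[Any]:
    """Hypothetical mimic of a flatten that visits values in sorted-key order."""
    return [mapping[key] for key in sorted(mapping)]


values = {
    OrderableRVIdentifier("f", (1,)): 2.0,
    OrderableRVIdentifier("f", (0,)): 1.0,
}
print(flatten_sorted(values))  # [1.0, 2.0]
```

The ordering only has to be deterministic and consistent across calls so that flattened values line up with their keys; it does not need to be semantically meaningful (here "(10,)" would sort before "(2,)").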