Bayesian GPs with pyro (NUTS) - example notebook crashes when jit_compile=True
I'm trying to run your fully Bayesian GP example. The notebook runs OK as-is but, as you might expect, sampling becomes much slower when I increase the size of the training dataset. To speed it up, I tried to enable jit compilation in the pyro NUTS sampler:
nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
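For reference, here is the surrounding cell as best I can reconstruct it from the traceback frames below; model, mll, train_x, train_y, num_samples, warmup_steps, and smoke_test are all defined earlier in the example notebook, and any lines of the cell that do not appear in a frame are omitted:

from pyro.infer.mcmc import NUTS, MCMC

# Reconstructed sketch of the notebook cell (from the traceback below).
def pyro_model(x, y):
    model.pyro_sample_from_prior()     # draw hyperparameters from their registered GPyTorch priors
    output = model(x)
    loss = mll.pyro_factor(output, y)  # score the exact marginal log likelihood as a pyro factor

nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
mcmc_run = MCMC(nuts_kernel, num_samples=num_samples,
                warmup_steps=warmup_steps, disable_progbar=smoke_test)
mcmc_run.run(train_x, train_y)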
After this change, the NUTS sampler crashes:
Warmup: 0%| | 0/300 [00:00, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-4-56cc00113944> in <module>()
26 nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
27 mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
---> 28 mcmc_run.run(train_x, train_y)
24 frames
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
378 with optional(pyro.validation_enabled(not self.disable_validation),
379 self.disable_validation is not None):
--> 380 for x, chain_id in self.sampler.run(*args, **kwargs):
381 if num_samples[chain_id] == 0:
382 num_samples[chain_id] += 1
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
167 for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
168 i if self.num_chains > 1 else None,
--> 169 *args, **kwargs):
170 yield sample, i # sample, chain_id
171 self.kernel.cleanup()
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
109
110 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 111 kernel.setup(warmup_steps, *args, **kwargs)
112 params = kernel.initial_params
113 # yield structure (key, value.shape) of params
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
304 if self.initial_params:
305 z = {k: v.detach() for k, v in self.initial_params.items()}
--> 306 z_grads, potential_energy = potential_grad(self.potential_fn, z)
307 else:
308 z_grads, potential_energy = {}, self.potential_fn(self.initial_params)
/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
80 return grads, z_nodes[0].new_tensor(float('nan'))
81 else:
---> 82 raise e
83
84 grads = grad(potential_energy, z_nodes)
/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
73 node.requires_grad_(True)
74 try:
---> 75 potential_energy = potential_fn(z)
76 # deal with singular matrices
77 except RuntimeError as e:
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
287 if skip_jit_warnings:
288 _pe_jit = ignore_jit_warnings()(_pe_jit)
--> 289 self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
290
291 result = self._compiled_fn(*vals)
/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
978 var_lookup_fn,
979 strict,
--> 980 _force_outplace)
981
982 # Check the trace against new traces created from user-specified inputs
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _pe_jit(*zi)
283 def _pe_jit(*zi):
284 params = dict(zip(names, zi))
--> 285 return self._potential_fn(params)
286
287 if skip_jit_warnings:
/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn(self, params)
259 cond_model = poutine.condition(self.model, params_constrained)
260 model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
--> 261 **self.model_kwargs)
262 log_joint = self.trace_prob_evaluator.log_prob(model_trace)
263 for name, t in self.transforms.items():
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
169 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
170 exc = exc.with_traceback(traceback)
--> 171 raise exc from None
172 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
173 return ret
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError):
167 exc_type, exc_value, traceback = sys.exc_info()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
9 def _context_wrap(context, fn, *args, **kwargs):
10 with context:
---> 11 return fn(*args, **kwargs)
12
13
<ipython-input-4-56cc00113944> in pyro_model(x, y)
19
20 def pyro_model(x, y):
---> 21 model.pyro_sample_from_prior()
22 output = model(x)
23 loss = mll.pyro_factor(output, y)
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in pyro_sample_from_prior(self)
318 parameters of the model that have GPyTorch priors registered to them.
319 """
--> 320 return _pyro_sample_from_prior(module=self, memo=None, prefix="")
321
322 def local_load_samples(self, samples_dict, memo, prefix):
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
427 for mname, module_ in module.named_children():
428 submodule_prefix = prefix + ("." if prefix else "") + mname
--> 429 _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)
430
431
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
421 )
422 memo.add(prior)
--> 423 prior = prior.expand(closure().shape)
424 value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
425 setting_closure(value)
/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in closure()
226
227 def closure():
--> 228 return getattr(self, param_or_closure)
229
230 if setting_closure is not None:
/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/gaussian_likelihood.py in noise(self)
83 @property
84 def noise(self) -> Tensor:
---> 85 return self.noise_covar.noise
86
87 @noise.setter
/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/noise_models.py in noise(self)
33 @property
34 def noise(self):
---> 35 return self.raw_noise_constraint.transform(self.raw_noise)
36
37 @noise.setter
/usr/local/lib/python3.6/dist-packages/gpytorch/constraints/constraints.py in transform(self, tensor)
174
175 def transform(self, tensor):
--> 176 transformed_tensor = self._transform(tensor) if self.enforced else tensor
177 return transformed_tensor
178
RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
-1.2059
[ torch.FloatTensor{1} ]
Trace Shapes:
Param Sites:
Sample Sites:
I've done some googling and found https://github.com/pyro-ppl/pyro/issues/2292, which seems to indicate that I failed to properly register a prior, perhaps for the noise_covar.noise of my Gaussian likelihood. Is that right? In your example, I do see a noise prior being registered, namely:
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")
If so, how do I register the missing prior? Or am I looking at this the wrong way? Thanks!
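For completeness, here is the full set of priors I believe the example registers. Only the noise line is verbatim from the notebook; the mean/kernel lines are my sketch from memory, assume the standard ConstantMean + ScaleKernel(RBFKernel) model, and their bounds may differ from the notebook's:

from gpytorch.priors import UniformPrior

# Prior registration pattern from the example notebook (mean/kernel lines
# are a sketch and assume the standard model structure; bounds illustrative).
model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
model.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")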
Top GitHub Comments
+1. It would be great to get some more clarity on the issue here. Using jit should greatly speed things up.
This issue is basically superseded by #1578, as the bugs are caused by the same underlying problem, and a fix to #1578 will resolve this issue as well. I've confirmed that jit_compile works at least with the fix I have so far.
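In the meantime, a minimal workaround (per the original report that the notebook runs fine without jit, just more slowly) is simply to disable it:

# Interim workaround until the #1578 fix is released: fall back to the non-jit path.
nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=False)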