Bayesian GPs with pyro (NUTS) - example notebook crashes when jit_compile=True

See original GitHub issue

I’m trying to run your fully Bayesian GP example.

The notebook runs OK as-is, but, as you might expect, sampling is much slower when I increase the size of the training dataset. To speed things up, I tried enabling JIT compilation in the pyro NUTS sampler:

nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
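For context, here is the rest of that cell, reconstructed from the traceback below; `model`, `mll`, `train_x`, `train_y`, `num_samples`, `warmup_steps`, and `smoke_test` are all defined exactly as in the example notebook:

from pyro.infer.mcmc import NUTS, MCMC

def pyro_model(x, y):
    model.pyro_sample_from_prior()     # draw hyperparameters from their registered priors
    output = model(x)
    loss = mll.pyro_factor(output, y)  # register the GP marginal log likelihood as a pyro factor
    return y

nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
mcmc_run.run(train_x, train_y)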

After this change, the NUTS sampler crashes:

Warmup:   0%|          | 0/300 [00:00, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-56cc00113944> in <module>()
     26 nuts_kernel = NUTS(pyro_model, adapt_step_size=True, jit_compile=True)
     27 mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps, disable_progbar=smoke_test)
---> 28 mcmc_run.run(train_x, train_y)

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    378         with optional(pyro.validation_enabled(not self.disable_validation),
    379                       self.disable_validation is not None):
--> 380             for x, chain_id in self.sampler.run(*args, **kwargs):
    381                 if num_samples[chain_id] == 0:
    382                     num_samples[chain_id] += 1

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in run(self, *args, **kwargs)
    167             for sample in _gen_samples(self.kernel, self.warmup_steps, self.num_samples, hook_w_logging,
    168                                        i if self.num_chains > 1 else None,
--> 169                                        *args, **kwargs):
    170                 yield sample, i  # sample, chain_id
    171             self.kernel.cleanup()

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/api.py in _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs)
    109 
    110 def _gen_samples(kernel, warmup_steps, num_samples, hook, chain_id, *args, **kwargs):
--> 111     kernel.setup(warmup_steps, *args, **kwargs)
    112     params = kernel.initial_params
    113     # yield structure (key, value.shape) of params

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/hmc.py in setup(self, warmup_steps, *args, **kwargs)
    304         if self.initial_params:
    305             z = {k: v.detach() for k, v in self.initial_params.items()}
--> 306             z_grads, potential_energy = potential_grad(self.potential_fn, z)
    307         else:
    308             z_grads, potential_energy = {}, self.potential_fn(self.initial_params)

/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
     80             return grads, z_nodes[0].new_tensor(float('nan'))
     81         else:
---> 82             raise e
     83 
     84     grads = grad(potential_energy, z_nodes)

/usr/local/lib/python3.6/dist-packages/pyro/ops/integrator.py in potential_grad(potential_fn, z)
     73         node.requires_grad_(True)
     74     try:
---> 75         potential_energy = potential_fn(z)
     76     # deal with singular matrices
     77     except RuntimeError as e:

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
    287             if skip_jit_warnings:
    288                 _pe_jit = ignore_jit_warnings()(_pe_jit)
--> 289             self._compiled_fn = torch.jit.trace(_pe_jit, vals, **jit_options)
    290 
    291             result = self._compiled_fn(*vals)

/usr/local/lib/python3.6/dist-packages/torch/jit/__init__.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)
    978                                                   var_lookup_fn,
    979                                                   strict,
--> 980                                                   _force_outplace)
    981 
    982     # Check the trace against new traces created from user-specified inputs

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _pe_jit(*zi)
    283             def _pe_jit(*zi):
    284                 params = dict(zip(names, zi))
--> 285                 return self._potential_fn(params)
    286 
    287             if skip_jit_warnings:

/usr/local/lib/python3.6/dist-packages/pyro/infer/mcmc/util.py in _potential_fn(self, params)
    259         cond_model = poutine.condition(self.model, params_constrained)
    260         model_trace = poutine.trace(cond_model).get_trace(*self.model_args,
--> 261                                                           **self.model_kwargs)
    262         log_joint = self.trace_prob_evaluator.log_prob(model_trace)
    263         for name, t in self.transforms.items():

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    185         Calls this poutine and returns its trace instead of the function's return value.
    186         """
--> 187         self(*args, **kwargs)
    188         return self.msngr.get_trace()

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    169                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    170                 exc = exc.with_traceback(traceback)
--> 171                 raise exc from None
    172             self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
    173         return ret

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    163                                       args=args, kwargs=kwargs)
    164             try:
--> 165                 ret = self.fn(*args, **kwargs)
    166             except (ValueError, RuntimeError):
    167                 exc_type, exc_value, traceback = sys.exc_info()

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
      9 def _context_wrap(context, fn, *args, **kwargs):
     10     with context:
---> 11         return fn(*args, **kwargs)
     12 
     13 

<ipython-input-4-56cc00113944> in pyro_model(x, y)
     19 
     20 def pyro_model(x, y):
---> 21   model.pyro_sample_from_prior()
     22   output = model(x)
     23   loss = mll.pyro_factor(output, y)

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in pyro_sample_from_prior(self)
    318         parameters of the model that have GPyTorch priors registered to them.
    319         """
--> 320         return _pyro_sample_from_prior(module=self, memo=None, prefix="")
    321 
    322     def local_load_samples(self, samples_dict, memo, prefix):

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
    427     for mname, module_ in module.named_children():
    428         submodule_prefix = prefix + ("." if prefix else "") + mname
--> 429         _pyro_sample_from_prior(module=module_, memo=memo, prefix=submodule_prefix)
    430 
    431 

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in _pyro_sample_from_prior(module, memo, prefix)
    421                     )
    422                 memo.add(prior)
--> 423                 prior = prior.expand(closure().shape)
    424                 value = pyro.sample(prefix + ("." if prefix else "") + prior_name, prior)
    425                 setting_closure(value)

/usr/local/lib/python3.6/dist-packages/gpytorch/module.py in closure()
    226 
    227             def closure():
--> 228                 return getattr(self, param_or_closure)
    229 
    230             if setting_closure is not None:

/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/gaussian_likelihood.py in noise(self)
     83     @property
     84     def noise(self) -> Tensor:
---> 85         return self.noise_covar.noise
     86 
     87     @noise.setter

/usr/local/lib/python3.6/dist-packages/gpytorch/likelihoods/noise_models.py in noise(self)
     33     @property
     34     def noise(self):
---> 35         return self.raw_noise_constraint.transform(self.raw_noise)
     36 
     37     @noise.setter

/usr/local/lib/python3.6/dist-packages/gpytorch/constraints/constraints.py in transform(self, tensor)
    174 
    175     def transform(self, tensor):
--> 176         transformed_tensor = self._transform(tensor) if self.enforced else tensor
    177         return transformed_tensor
    178 

RuntimeError: Cannot insert a Tensor that requires grad as a constant. Consider making it a parameter or input, or detaching the gradient
Tensor:
-1.2059
[ torch.FloatTensor{1} ]
Trace Shapes:
 Param Sites:
Sample Sites:
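As a sanity check, the failure seems reproducible outside pyro and gpytorch entirely: torch.jit.trace refuses to bake a tensor that requires grad into the traced graph as a constant. A minimal standalone sketch of my understanding (the softplus merely stands in for the noise constraint's transform; none of this is notebook code):

import torch
import torch.nn.functional as F

# A leaf tensor with requires_grad=True that is *not* an input of the
# traced function, so the tracer has to treat it as a constant.
raw_noise = torch.tensor([-1.2059], requires_grad=True)

def potential(z):
    return (z * F.softplus(raw_noise)).sum()

# RuntimeError: Cannot insert a Tensor that requires grad as a constant.
torch.jit.trace(potential, (torch.zeros(1),))

So it looks like something inside the traced potential function is holding on to a live, grad-requiring tensor instead of receiving it as an input.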

I’ve done some googling and found https://github.com/pyro-ppl/pyro/issues/2292, which seems to indicate that I failed to properly register a prior, perhaps for the noise_covar.noise of my Gaussian likelihood. Is this true? In your example, I do see a noise prior being registered, namely

likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")
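That line is in my notebook too, along with the other priors from the example (copied here in case I've mangled one of them; the bounds are the example's):

from gpytorch.priors import UniformPrior

model.mean_module.register_prior("mean_prior", UniformPrior(-1, 1), "constant")
model.covar_module.base_kernel.register_prior("lengthscale_prior", UniformPrior(0.01, 0.5), "lengthscale")
model.covar_module.register_prior("outputscale_prior", UniformPrior(1, 2), "outputscale")
likelihood.register_prior("noise_prior", UniformPrior(0.05, 0.3), "noise")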

If so, how do I register the missing prior? Or am I looking at this the wrong way? Thanks!

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 10 (3 by maintainers)

Top GitHub Comments

3 reactions
sdaulton commented, Apr 16, 2021

+1. It would be great to get some more clarity on the issue here. Using jit should greatly speed things up.

2 reactions
jacobrgardner commented, Apr 17, 2021

This issue is basically superseded by #1578, as the bugs are caused by the same problem, and a fix to #1578 will resolve this issue as well. I’ve confirmed that jit_compile works with at least the fix I have so far.


