Delay log density computation in pyro.factor

In several places where we use pyro.factor or dist.Delta (e.g. NeuTra reparam), log_density is computed explicitly. This computation is not always needed, e.g. when making predictions. It would be nice to add a mechanism that defers the computation until an inference algorithm actually needs the log probability.

Proposed solution

Allow log_factor to be a callable in pyro.factor. In dist.Unit.log_prob, if log_factor is a callable, we will call it to get the required value (similar to the behavior in pyro.param).
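
The deferred evaluation proposed above can be illustrated without Pyro at all. The following is only a minimal sketch in plain Python: `LazyUnit` and `expensive_log_density` are hypothetical names, and the class is a stand-in for the idea, not Pyro's `dist.Unit` or its actual signature.

```python
from typing import Callable, List, Union

# A log factor is either an already computed number or a zero-argument thunk.
LogFactor = Union[float, Callable[[], float]]


class LazyUnit:
    """Hypothetical stand-in for a Unit-like distribution (not Pyro's class)."""

    def __init__(self, log_factor: LogFactor) -> None:
        self.log_factor = log_factor

    def log_prob(self) -> float:
        # Resolve the thunk only when a log probability is requested.
        if callable(self.log_factor):
            return self.log_factor()
        return self.log_factor


calls: List[int] = []  # records how often the expensive function runs


def expensive_log_density() -> float:
    # Placeholder for something costly, such as a log-abs-det-Jacobian.
    calls.append(1)
    return -1.5


site = LazyUnit(expensive_log_density)

# Prediction path: log_prob is never requested, so nothing has been computed.
calls_before = len(calls)

# Inference path: the value is computed on demand.
value = site.log_prob()
```

Constructing the site costs nothing here; the expensive function runs only when `log_prob` is called, which is the behavior the proposal asks for. A plain number is still accepted, so existing call sites would not need to change under this assumption.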

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Reactions: 1
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
fritzo commented, Feb 4, 2021

Looking at more real-world models, I see there is often other code surrounding pyro.factor, including plates and poutines:

def model():
    ...
    # This whole section could be omitted during prediction:
    tree_scale = pyro.sample("tree_scale", dist.LogNormal(0, 5))
    with pyro.plate("edges", len(edges), dim=-1):
        u, v = log_rate[edges].unbind(-1)
        pyro.sample(
            "rate_change",
            dist.Laplace(0, tree_scale),
            obs=u - v,
        )
    ...

One advantage of a check like is_masked() is that we could use it around entire sections of model code. In particular, plates and poutines cannot be embedded inside a lambda passed to pyro.factor.
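
As a rough illustration of guarding a whole section of model code, here is a framework-free sketch. Everything in it is an assumption made for illustration: `mask`, `is_masked`, and `_MASK_STACK` are hypothetical helpers written in plain Python, not Pyro's `poutine.mask` or any existing Pyro primitive, and the toy `model` only mimics the shape of the example above.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List

# Hypothetical handler stack; Pyro keeps its own effect-handler stack internally.
_MASK_STACK: List[bool] = []


@contextmanager
def mask(value: bool) -> Iterator[None]:
    """Toy stand-in for a mask handler: push on entry, pop on exit."""
    _MASK_STACK.append(value)
    try:
        yield
    finally:
        _MASK_STACK.pop()


def is_masked() -> bool:
    """Return True if any enclosing handler was installed with mask=False."""
    return any(m is False for m in _MASK_STACK)


def model(log_rate: List[float]) -> Dict[str, float]:
    out = {"prediction": sum(log_rate)}
    # This whole section is skipped when log densities are not needed.
    if not is_masked():
        pairs = zip(log_rate, log_rate[1:])
        out["penalty"] = -sum(abs(u - v) for u, v in pairs)
    return out


training = model([0.0, 1.0, 3.0])

with mask(False):
    predicting = model([0.0, 1.0, 3.0])
```

Because the check is an ordinary `if`, the guarded block can contain any statements, including context managers, which is the advantage over passing a lambda.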

1 reaction
fritzo commented, Jan 21, 2021

Here’s a possible solution that avoids the need for non-tensors in distributions (which would complicate typing):

What if we use poutine.mask(mask=False) to control whether various computations are needed? During prediction we can install a mask(False) handler to disable computation of the log density. A brute-force way to consult the mask would be to add a primitive such as is_mask_false(). Or, instead of a new poutine, we could add tracing logic in NeuTra like this:

- z_unconstrained = pyro.sample("{}_shared_latent".format(name),
-                               self.guide.get_base_dist().mask(False))
+ with poutine.trace() as trace:
+     z_name = f"{name}_shared_latent"
+     z_unconstrained = pyro.sample(z_name, self.guide.get_base_dist().mask(False))
+ z_mask = trace.nodes[z_name]["mask"]
  
  # Differentiably transform.
  x_unconstrained = self.transform(z_unconstrained)
- log_density = self.transform.log_abs_det_jacobian(z_unconstrained, x_unconstrained)
+ if z_mask is False:
+     log_density = z_unconstrained.new_zeros(())
+ else:
+     log_density = self.transform.log_abs_det_jacobian(z_unconstrained, x_unconstrained)
  self.x_unconstrained = list(reversed(list(self.guide._unpack_latent(x_unconstrained))))