
RelaxedBernoulliStraightThrough seems to give continuous samples when used in conjunction with AutoHierarchicalNormalMessenger/AutoNormalMessenger


Hi,

I have used the RelaxedBernoulliStraightThrough distribution in my model in a code block like this:

        with obs_plate:
            I_cm = pyro.sample('I_cm',
                               RelaxedBernoulliStraightThrough(probs=p_m,
                                                               temperature=self.one / 1000.
                                                               ).expand([batch_size, 1, self.n_modules]))

However, I do not see discrete samples in any of the following:

1.) the posterior samples for I_cm;

2.) the posterior samples for I_cm_tracking, when I add

        I_cm_tracking = pyro.deterministic('I_cm_tracking', I_cm)

3.) the printed output during training, when I add

        print(I_cm)

So I wonder: did you test that RelaxedBernoulliStraightThrough indeed gives discrete samples in the forward pass during training?

Thanks!

Alexander

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 15 (5 by maintainers)

Top GitHub Comments

1 reaction
martinjankowiak commented, Aug 16, 2022

@vitkl I'm not very familiar with the internals of AutoNormalMessenger, but something like that (but returning a distribution in the if branch) may be OK too. If it runs, and if you trace the guide and get the samples you expect, it's probably OK.

@fritzo would have a better idea.

1 reaction
AlexanderAivazidis commented, Aug 8, 2022

Hi,

After removing the .expand(...) call, I still have the same problem and don't see discrete samples anywhere during training:

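        import pyro
        import pyro.distributions as dist
        from pyro.distributions import RelaxedBernoulliStraightThrough

        # Workaround: give the distribution a `mean` property equal to `probs`
        RelaxedBernoulliStraightThrough.mean = property(lambda self: self.probs)

        # Per-module activation probability, shape [1, 1, n_modules]
        p_m = pyro.sample('p_m', dist.Beta(self.activation_probability_alpha,
                                           self.activation_probability_beta
                                           ).expand([1, 1, self.n_modules]).to_event(3))
        with obs_plate:
            # Straight-through relaxed Bernoulli, now without .expand()
            I_cm = pyro.sample('I_cm',
                               RelaxedBernoulliStraightThrough(probs=p_m,
                                                               temperature=self.one / 1000.))

        print('I_cm', I_cm)

        # Record I_cm as a deterministic site so it shows up in posterior samples
        I_cm_tracking = pyro.deterministic('I_cm_tracking', I_cm)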

I have also made this minimal example, which uses RelaxedBernoulliStraightThrough in a Gaussian mixture model:

https://github.com/AlexanderAivazidis/Minimum-Example/blob/main/RelaxedBernoulliMinimalExample.ipynb

Best wishes,

Alexander
