
opt.train_sample_n is 16?

See original GitHub issue

Hi @ruotianluo. May I ask why you are using train_sample_n = 16 for gen_result but n = 1 for the greedy baseline in the code below? In transformer_sc.yml you don't define train_sample_n, so it falls back to the default in opts, which is 16. Shouldn't this be > 1 only when using the new self-critical, where we want to generate multiple samples?

            self.model.eval()
            with torch.no_grad():
                greedy_res, _ = self.model(fc_feats, att_feats, att_masks,
                    mode='sample',
                    opt={'sample_method': opt.sc_sample_method,
                         'beam_size': opt.sc_beam_size})
            self.model.train()
            gen_result, sample_logprobs = self.model(fc_feats, att_feats, att_masks,
                    opt={'sample_method': opt.train_sample_method,
                         'beam_size': opt.train_beam_size,
                         'sample_n': opt.train_sample_n},
                    mode='sample')
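For context, in the "new self-critical" setup the multiple samples drawn per image are typically used as each other's baseline instead of a single greedy decode. A minimal sketch of that leave-one-out advantage, with illustrative names and shapes (not the repo's actual API):

```python
import torch

def leave_one_out_advantage(sample_rewards, sample_n):
    """Hypothetical sketch: each sample's baseline is the mean reward
    of the *other* sample_n - 1 samples drawn for the same image.

    sample_rewards: tensor of shape (batch * sample_n,)
    """
    r = sample_rewards.view(-1, sample_n)                    # (batch, sample_n)
    # leave-one-out mean: subtract the sample itself from the group sum
    baseline = (r.sum(dim=1, keepdim=True) - r) / (sample_n - 1)
    return (r - baseline).view(-1)

# batch = 1 image, sample_n = 4 sampled captions with made-up rewards
rewards = torch.tensor([1.0, 0.0, 2.0, 3.0])
adv = leave_one_out_advantage(rewards, 4)
```

Note that the leave-one-out advantages of one image always sum to zero, which is what makes the per-image samples act as a self-normalizing baseline.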

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

1 reaction
ruotianluo commented, Apr 17, 2020

No. It’s because some other loss may need the full probability.
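For example, one kind of loss that needs the full distribution is an entropy regularizer: it operates on every word's probability at each step, not just the log-prob of the sampled word. A hedged sketch with made-up shapes (not necessarily a loss this repo uses):

```python
import torch

# Full per-step log-probabilities over the vocabulary, shape
# (batch, seq_len, vocab_size); sizes here are illustrative.
torch.manual_seed(0)
logprobs = torch.log_softmax(torch.randn(2, 5, 100), dim=-1)

# Per-step entropy H = -sum_w p(w) * log p(w); this needs the whole
# vocab axis, so saving only the sampled word's log-prob would not do.
entropy = -(logprobs.exp() * logprobs).sum(dim=-1)   # (batch, seq_len)
entropy_bonus = entropy.mean()
```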

0 reactions
homelifes commented, Apr 17, 2020

@ruotianluo Alright, thanks a lot. I have one more question: I realized that when gathering the logprobs:

            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it > 0
            else:
                unfinished = unfinished * (it > 0)
            it = it * unfinished.type_as(it)
            seq[:, t] = it
            seqLogprobs[:, t] = logprobs

you are saving the full logprobs rather than the sampled ones (sampleLogprobs), and then in RewardCriterion you gather them according to the sampled sequence:
input = input.gather(2, seq.unsqueeze(2)).squeeze(2)

May I also ask whether that is done for computational efficiency, or is it just another way of achieving the same thing?
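On the equivalence itself: gathering afterwards from the stored full logprobs recovers exactly the values sampleLogprobs would have held at sampling time (before any masking). A small check, with illustrative tensor names and shapes:

```python
import torch

torch.manual_seed(0)
batch, seq_len, vocab = 3, 4, 10

# Full per-step log-probabilities, standing in for seqLogprobs.
logprobs = torch.log_softmax(torch.randn(batch, seq_len, vocab), dim=-1)

# Variant A: record the sampled word's log-prob at sampling time,
# the way sample_next_word's sampleLogprobs would.
seq = torch.zeros(batch, seq_len, dtype=torch.long)
sample_lp = torch.zeros(batch, seq_len)
for t in range(seq_len):
    it = torch.distributions.Categorical(logits=logprobs[:, t]).sample()
    seq[:, t] = it
    sample_lp[:, t] = logprobs[:, t].gather(1, it.unsqueeze(1)).squeeze(1)

# Variant B: gather afterwards from the stored full distribution,
# mirroring RewardCriterion's input.gather(2, seq.unsqueeze(2)).
gathered = logprobs.gather(2, seq.unsqueeze(2)).squeeze(2)
```

Both variants index the same tensor with the same indices, so the difference is only where the indexing happens, not what values come out.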
