Add a sampling_transform callback to generation for arbitrary probability-warps

🚀 Feature request

I’d like to add a sampling_transform callback argument to all generate functions in modeling_utils that allows arbitrary sampling from the probability distribution during sequence generation. The signature of this function would be (input_ids, next_probs, next_token) -> next_token.
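
For concreteness, a minimal sketch of such a callback. Note that sampling_transform is the proposed argument, not an existing one in transformers; this particular transform simply ignores the default draw and re-samples:

    import torch

    def redraw_token(input_ids, next_probs, next_token):
        # input_ids:  (batch, seq_len) tokens generated so far
        # next_probs: (batch, vocab_size) softmax distribution over the next token
        # next_token: (batch,) token chosen by the default sampler
        # Ignore the default draw and sample again from next_probs.
        return torch.multinomial(next_probs, num_samples=1).squeeze(1)

    # Proposed usage; this argument does not exist in transformers today:
    # output_ids = model.generate(input_ids, do_sample=True, sampling_transform=redraw_token)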

Motivation

The generation function in modeling_utils is getting pretty hairy: it takes parameters for top-p, top-k, bad tokens, temperature, repetition penalty, and so on. Every new way of sampling means adding more parameters, which further complicates the function. I believe the right way to solve this is to provide a function argument that lets users express arbitrary rules for the next sample. In the long run, we could replace all the other parameters with a set of pre-baked sampling_transform functions that can be composed at will, as sketched below.
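
Purely as an illustration of that composition idea, under the signature above; compose_transforms and the pre-baked names in the usage comment are hypothetical:

    def compose_transforms(*transforms):
        # Chain transforms left to right; each one sees the token picked so far
        # and may replace it.
        def composed(input_ids, next_probs, next_token):
            for transform in transforms:
                next_token = transform(input_ids, next_probs, next_token)
            return next_token
        return composed

    # Hypothetical pre-baked transforms composed at will:
    # sampling_transform = compose_transforms(ban_tokens(bad_ids), min_length_eos(eos_id, 10))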

This method scales better to strange warps. For example, in a project I’m working on (https://github.com/turtlesoupy/this-word-does-not-exist) I need to early-terminate sequences when they sample from a large set of bad tokens, and to keep generating when an EOS token is sampled too early. An example is here: https://github.com/turtlesoupy/this-word-does-not-exist/blob/260e33a8f420b9be8b1e7260cb03c74d6231686e/title_maker_pro/datasets.py#L386
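
A hedged sketch of the early-EOS case under the proposed signature; eos_token_id and min_length stand in for the project’s actual configuration:

    import torch

    def resample_early_eos(eos_token_id, min_length):
        # If EOS is drawn before min_length tokens exist, zero out its
        # probability and draw again for the offending rows.
        def transform(input_ids, next_probs, next_token):
            if input_ids.shape[1] >= min_length:
                return next_token
            too_early = next_token == eos_token_id
            if too_early.any():
                probs = next_probs.clone()
                probs[:, eos_token_id] = 0.0
                probs = probs / probs.sum(dim=-1, keepdim=True)
                redrawn = torch.multinomial(probs, num_samples=1).squeeze(1)
                next_token = torch.where(too_early, redrawn, next_token)
            return next_token
        return transform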

Your contribution

I’ll attach a sample PR that makes this work for PyTorch and non-beam sampling. If people like the idea, it should be easy to refine into a full PR that generalizes to beam search and TensorFlow.

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Reactions: 2
  • Comments: 7 (3 by maintainers)

Top GitHub Comments

1 reaction
turtlesoupy commented, May 12, 2020

@yjernite good call; the interface would have to be modified slightly to (input_ids, next_probs, next_token) -> next_probs. You might be able to fold the sampling procedure into the transforms. I know that in my case I want to operate after a sample has been chosen:

    Sampler.Compose([
        one_hot_draw(),
        my_crazy_custom_transform(),
    ])
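
One guess at what one_hot_draw and Sampler.Compose could look like under the revised probs-in, probs-out contract; neither exists in transformers, and both names come only from the snippet above:

    import torch
    import torch.nn.functional as F

    def one_hot_draw():
        # Sample once from next_probs and return a one-hot distribution over
        # the drawn token, so downstream transforms see a concrete choice.
        def transform(input_ids, next_probs, next_token):
            drawn = torch.multinomial(next_probs, num_samples=1).squeeze(1)
            return F.one_hot(drawn, num_classes=next_probs.shape[-1]).float()
        return transform

    class Sampler:
        @staticmethod
        def Compose(transforms):
            # Pipe next_probs through each transform in order.
            def composed(input_ids, next_probs, next_token):
                for transform in transforms:
                    next_probs = transform(input_ids, next_probs, next_token)
                return next_probs
            return composed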

1 reaction
patrickvonplaten commented, May 12, 2020

Very interesting idea! I think we will eventually have to make the generation function more general anyway.

Maybe it’s time to move this whole code:

            # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
            if repetition_penalty != 1.0:
                self.enforce_repetition_penalty_(next_token_logits, batch_size, 1, input_ids, repetition_penalty)
            if no_repeat_ngram_size > 0:
                # calculate a list of banned tokens to prevent repetitively generating the same ngrams
                # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
                banned_tokens = calc_banned_ngram_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
                for batch_idx in range(batch_size):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
            if bad_words_ids is not None:
                # calculate a list of banned tokens according to bad words
                banned_tokens = calc_banned_bad_words_ids(input_ids, bad_words_ids)
                for batch_idx in range(batch_size):
                    next_token_logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
            # set eos token prob to zero if min_length is not reached
            if eos_token_id is not None and cur_len < min_length:
                next_token_logits[:, eos_token_id] = -float("inf")
            if do_sample:
                # Temperature (higher temperature => more likely to sample low probability tokens)
                if temperature != 1.0:
                    next_token_logits = next_token_logits / temperature
                # Top-p/top-k filtering
                next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
                # Sample
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).squeeze(1)

                if sampling_transform:
                    next_token = sampling_transform(input_ids, probs, next_token)
            else:
                # Greedy decoding
                next_token = torch.argmax(next_token_logits, dim=-1)

into a Sampler class with a sampler.sample(input_ids, next_token_logits) method, which could also accept a generic function as proposed.
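
A rough sketch of what such a Sampler class might look like; this is illustrative only, and the transform list and constructor arguments are assumptions, not code from the thread:

    import torch
    import torch.nn.functional as F

    class Sampler:
        def __init__(self, logit_transforms=None, do_sample=True):
            # Each transform is (input_ids, next_token_logits) -> next_token_logits,
            # standing in for one of the if-blocks above (repetition penalty,
            # banned n-grams, bad words, top-k/top-p, ...).
            self.logit_transforms = logit_transforms or []
            self.do_sample = do_sample

        def sample(self, input_ids, next_token_logits):
            for transform in self.logit_transforms:
                next_token_logits = transform(input_ids, next_token_logits)
            if self.do_sample:
                probs = F.softmax(next_token_logits, dim=-1)
                return torch.multinomial(probs, num_samples=1).squeeze(1)
            # Greedy decoding
            return torch.argmax(next_token_logits, dim=-1)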

What are your thoughts on this @yjernite @thomwolf @sshleifer @LysandreJik ?
