
generation problem in a toy task

See original GitHub issue

Here is the full script for my toy task (x -> xx like “abc” to “abcabc”)

from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import gzip
import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 1e-4
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 100
ENC_SEQ_LEN=16
DEC_SEQ_LEN=40
NUM_TOKENS = 256 + 2
BUCKET_SIZE = 8

# helpers

def top_k(logits, thres=0.9):
    # top-k filtering helper: keep the highest (1 - thres) fraction of logits, mask the rest to -inf
    k = int((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs


def cycle():
    while True:
        # random source strings of length ENC_SEQ_LEN, drawn from the non-special tokens [2, 258)
        source = torch.randint(2, 258, (BATCH_SIZE, ENC_SEQ_LEN)).long().cuda()

        # target is the source repeated twice, prefixed with the <sos> token (1)
        target = torch.cat((source, source), dim=1)
        prefix = torch.ones((BATCH_SIZE, 1)).long().cuda()
        target = torch.cat((prefix, target), dim=1)

        x_mask = torch.ones(BATCH_SIZE, ENC_SEQ_LEN).bool().cuda()
        y_mask = torch.ones(BATCH_SIZE, target.shape[1]).bool().cuda()

        yield (source, target, x_mask, y_mask)

# instantiate model

class MySinkhornTransformer(nn.Module):
    def __init__(self, num_tokens, dim, depth, heads, bucket_size, enc_max_seq_len, dec_max_seq_len):
        super().__init__()
        
        self.pad_token = 0
        self.sos_token = 1

        self.enc = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, bucket_size=bucket_size, max_seq_len=enc_max_seq_len,
                                         reversible=True, return_embeddings=True)
        self.dec = SinkhornTransformerLM(num_tokens=num_tokens, dim=dim, depth=depth, heads=heads, causal=True, bucket_size=bucket_size, max_seq_len=dec_max_seq_len, 
                                         receives_context=True, context_bucket_size=bucket_size, reversible=True)
        self.dec = AutoregressiveWrapper(self.dec, pad_value=num_tokens-2)
    
    @torch.no_grad()
    def generate(self, x, x_mask):
        context = self.enc(x, input_mask=x_mask)
        start_tokens = (torch.ones((x.shape[0],1)) * self.sos_token).long().cuda()

        return self.dec.generate(start_tokens, 32, context=context, context_mask=x_mask)

    def forward(self, x, y, x_mask, y_mask, return_loss):
        context = self.enc(x, input_mask=x_mask)
        return self.dec(y, context=context, input_mask=y_mask, context_mask=x_mask, return_loss=return_loss)


model = MySinkhornTransformer(num_tokens=NUM_TOKENS, dim=512, depth=1, heads=1, bucket_size=BUCKET_SIZE, enc_max_seq_len=ENC_SEQ_LEN, dec_max_seq_len=DEC_SEQ_LEN)
model.cuda()
# optimizer

optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        source, target, x_mask, y_mask = next(cycle())
        loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
        loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            source, target, x_mask, y_mask = next(cycle())
            loss = model(x=source, y=target, x_mask=x_mask, y_mask=y_mask, return_loss=True)
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        
        source, target, x_mask, y_mask = next(cycle())
        
        sample = model.generate(x=source, x_mask=x_mask)
        print("input:  ", source[0])
        print("model output:  ", sample[0])

After a few steps the loss becomes practically zero. I checked the logits during training and they seem to be OK, but during the generation phase the model outputs a repeated pattern like "x,x,x,x,y,y,y,y", e.g. "aaaabbbb" instead of "abcdabcd". I was wondering what the underlying issue might be. Do you have any idea?
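One way to narrow this kind of symptom down is to compare teacher-forced predictions with free-running generation on the same batch. The sketch below is illustrative rather than part of the original report: it reuses model and cycle() from the script above and assumes that calling the wrapped decoder without return_loss returns per-position logits. If the teacher-forced argmax already reproduces the target while generate() does not, the problem lies in the decoding loop rather than in training.

# diagnostic sketch: teacher forcing vs. free-running decoding
model.eval()
with torch.no_grad():
    source, target, x_mask, y_mask = next(cycle())

    # teacher-forced pass: feed the gold prefix and take the argmax at every position
    # (assumes the AutoregressiveWrapper returns logits when return_loss is not set)
    context = model.enc(source, input_mask=x_mask)
    logits = model.dec(target[:, :-1], context=context, context_mask=x_mask)
    teacher_forced = logits.argmax(dim=-1)

    # free-running pass: decode from the <sos> token alone
    generated = model.generate(x=source, x_mask=x_mask)

    print("target:         ", target[0, 1:])
    print("teacher forced: ", teacher_forced[0])
    print("generated:      ", generated[0])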

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 18 (13 by maintainers)

Top GitHub Comments

1 reaction
lucidrains commented, May 5, 2020

thanks! it’s been a learning experience

1 reaction
py4 commented, May 5, 2020

You have very good repos (this and the reformer one), but I strongly suggest running toy tasks when implementing a paper. They catch bugs very well, especially for seq2seq. The rule of thumb is that a seq2seq model should be able to learn "x -> x" or "x -> xx" perfectly.
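One illustrative way to make that rule of thumb concrete, using the names from the script above (not code from the original thread), is to measure exact-match accuracy on fresh random strings; a healthy seq2seq model should reach 100% on x -> xx.

# illustrative sketch: exact-match accuracy on the x -> xx copy task
def exact_match_rate(model, batches=10):
    model.eval()
    hits, total = 0, 0
    with torch.no_grad():
        for _ in range(batches):
            source, target, x_mask, _ = next(cycle())
            sample = model.generate(x=source, x_mask=x_mask)
            gold = target[:, 1:]  # drop the <sos> prefix
            # compare on the overlapping length in case generate() returns a different number of tokens
            length = min(sample.shape[1], gold.shape[1])
            hits += (sample[:, :length] == gold[:, :length]).all(dim=-1).sum().item()
            total += source.shape[0]
    return hits / total

print("exact match:", exact_match_rate(model))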

Read more comments on GitHub >
