General questions about the algorithmic understanding

See original GitHub issue

I've been trying to get a grasp of the DALL-E code recently. However, there are a couple of things I can't quite wrap my head around, and since the paper is not published yet, I was wondering if we could maybe clarify them here.

So there is the VAE training which basically features the codebook in the bottleneck and is trained a priori.
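
As I understand it, that codebook is essentially just a learned embedding table sitting in the bottleneck, something like this sketch (toy sizes matching the DiscreteVAE config further down):

import torch.nn as nn

num_tokens, codebook_dim = 10, 256                 # toy sizes, see the DiscreteVAE config below
codebook = nn.Embedding(num_tokens, codebook_dim)  # the "codebook" living in the VAE bottleneck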

Next, DALL-E receives text and image pairs, embeds them, and adds positional encodings to both modalities individually. However, the image data is not embedded as in e.g. ViT, but by downsampling it via the encoder of the VAE (without accumulating gradients), taking the argmax over the feature dimension for each downsampled image patch, and finally indexing into the previously trained codebook.
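
In other words, the image path looks roughly like this (a minimal sketch of my reading of the code; encoder and codebook here are placeholders for the pretrained discrete VAE parts, not the exact library API):

import torch

def images_to_image_tokens(images, encoder, codebook):
    # sketch: image -> VAE encoder logits -> argmax codes -> codebook embeddings
    with torch.no_grad():             # the VAE is frozen, so no gradients accumulate here
        logits = encoder(images)      # (b, num_codebook_entries, h', w') after downsampling
    codes = logits.argmax(dim=1)      # (b, h', w') hard code index per latent position
    codes = codes.flatten(1)          # (b, h'*w') image token ids
    return codebook(codes)            # (b, h'*w', codebook_dim) embeddings for the transformer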

The resulting representations of both modalities are then concatenated along the token dimension. And while every word of the text input is one token, the product of the height and width of the VAE-encoded image yields the number of image tokens.
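
In shapes, with the toy config further down (3 text tokens and a 2x2 latent grid, i.e. 4 image tokens), that would be something like:

import torch

dim = 256
text_emb  = torch.randn(1, 3, dim)   # one embedding per word, plus text positional encoding
image_emb = torch.randn(1, 4, dim)   # one embedding per latent grid cell, plus image positional encoding
tokens = torch.cat((text_emb, image_emb), dim=1)   # (1, 7, dim) goes into the transformer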

The combined embedding is then passed into a single transformer, which computes self-attention not only within each modality but also across both modalities, if I am not mistaken.

A masking of the form

mask = torch.ones_like(text).bool()

results in unmasked attention calculation, right?

A final MLP maps the transformer output to all possible token classes (both text and image).
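
So presumably something like a single linear head over the union of both vocabularies; a sketch with the toy numbers used below (the extra +1 slot is my reading of the EOS/padding class that shows up in the mask):

import torch.nn as nn

dim = 256
num_text_tokens, num_image_tokens = 4, 10
total_tokens = num_text_tokens + num_image_tokens + 1   # 15 classes in total, incl. one extra EOS/padding slot
to_logits = nn.Linear(dim, total_tokens)                # maps each transformer output to logits over all tokens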

Then I don't understand the masking:

logits_mask = (
    ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
    ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
    ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
)
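
For anyone else reading along: as far as I can tell, seq_range indexes the output position and logits_range the candidate class, broadcast against each other. A minimal reconstruction with the toy sizes from the config below reproduces the 7 x 15 tensor printed further down:

import torch

text_seq_len, num_text_tokens = 3, 4                    # toy sizes, see the DALLE config below
image_seq_len, num_image_tokens = 4, 10
seq_len = text_seq_len + image_seq_len                  # 7 transformer positions
total_tokens = num_text_tokens + num_image_tokens + 1   # 15 classes, incl. the extra EOS/padding slot

seq_range    = torch.arange(seq_len).view(1, seq_len, 1)             # (1, 7, 1)  which position predicts
logits_range = torch.arange(total_tokens).view(1, 1, total_tokens)   # (1, 1, 15) which class is predicted

logits_mask = (
    ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
    ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
    ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
)
print(logits_mask.shape)   # torch.Size([1, 7, 15]); True marks a logit that gets suppressed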

shouldn't there be one more row concerned with the text input and one less row for the image input?

For the following config with 3 text input tokens

import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 64,
    num_layers = 5,
    num_tokens = 10,
    codebook_dim = 256,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
).cuda()

dalle = DALLE(
    dim = 256,
    vae = vae,                  # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 4,    # vocab size for text
    text_seq_len = 3,         # text sequence length
    depth = 6,                 # should aim to be 64
    heads = 8,                 # attention heads
    dim_head = 64,              # attention head dimension
    attn_dropout = 0.1,         # attention dropout
    ff_dropout = 0.1            # feedforward dropout
).cuda()

text = torch.randint(0, 4, (1, 3)).cuda()
images = torch.randn(1, 3, 64, 64).cuda()
mask = torch.ones_like(text).bool().cuda()

the resulting logits_mask looks like this:

tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False]]])

shouldn't it be this instead?

tensor([[[False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [False, False, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False,  True],
         [ True,  True,  True,  True, False, False, False, False, False, False, False, False, False, False, False]]])

The purpose of the masking is that image tokens don't contribute to the predictions of the text tokens and vice versa. The code then proceeds to construct labels from the integer text tokens and from the VAE-encoded image, using its codebook indices.
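
Continuing from the toy config above, my reading of the label construction is roughly this (a sketch, not the exact library code; get_codebook_indices stands for whatever method returns the discrete VAE codes):

image_codes = vae.get_codebook_indices(images)      # (1, 4) discrete codes from the frozen VAE
image_labels = image_codes + 4                      # offset by num_text_tokens so image ids don't collide with text ids
labels = torch.cat((text, image_labels), dim=1)     # (1, 7) one target id per sequence position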

But what is it we are actually trying to predict with this classification task here? It is a 2D cross-entropy loss where, for each token (either text or image), we are trying to predict … exactly what? Somehow I am missing the intuition here, I guess…

And then, why does the label vector neglect the very first label entry but use the EOS entry?

    loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
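
My guess for the [:, 1:] part is that it is the usual next-token shift: the output at position i is scored against the token at position i + 1, so the very first label is only ever an input and never a prediction target. A tiny standalone example of that alignment (dummy shapes, not necessarily the repository's exact ones):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 6, 15)         # 6 output positions, 15 classes (dummy values)
labels = torch.randint(0, 15, (1, 7))  # 7 labels; the first one is never a target
loss = F.cross_entropy(logits.transpose(1, 2), labels[:, 1:])   # position i vs. label i + 1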

Maybe someone can help me (and others) better understand what's going on here. Thank you in advance!

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 11 (7 by maintainers)

Top GitHub Comments

lucidrains commented, Jan 21, 2021

> thank you once again for your answers and the possibility to discuss matters here.
>
> I believe the codebook is pretrained only in the VAE, and the DALL-E trains its own embeddings for the visual tokens
>
> mh but right now, imho the codebook is also adjusted during DALL-E training…

yea, someone else actually brought that up. I don’t believe so, because if you read the iGPT paper, they clustered the pixel space into 512 values and then simply retrained on those 512 values as unique embeddings, and it still worked. however, I have a branch in this repository named ‘end-to-end’ that contains what you are describing and you are free to try it out

> So in this specific setup, because text tokens precede image tokens, the image tokens can attend to all the text tokens (but not the other way around). However, one can imagine a future system where you don't have such restrictions and just mix all tokens from all modalities together in any order
>
> I cannot follow. The attention of the transformer receives the input, which is text and image tokens concatenated along the token dimension. But since the mask during training has True set everywhere, this is full-fledged attention from every token to every other token, isn't it? Maybe I didn't understand properly.

so the attention is only from future to past because the causal flag is turned on:
https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/transformer.py#L86
https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py#L294
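
i.e. a standard triangular causal mask, roughly like this sketch (not the repository's exact implementation):

import torch

n = 7                                          # total tokens: 3 text + 4 image, text first
causal_mask = torch.ones(n, n).triu(1).bool()  # True above the diagonal = attention blocked
# each position attends only to itself and earlier positions, so image tokens
# can see every text token, but text tokens never see image tokens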

> Another thing which I don't get is why there are two BOS tokens prepended. Once in the generate_images function
>
> https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L340
>
> and then again here
>
> https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L379
>
> because of this
>
> https://github.com/lucidrains/DALLE-pytorch/blob/6a50564387bb518d1e4c2ad9988e0e1cd09225ef/dalle_pytorch/dalle_pytorch.py#L400
>
> doesn't this cause the first image token to be predicted not by the last but by the second-to-last text token?

that’s a bug on my part, fixed in the latest commit! 🙏

CDitzel commented, Jan 19, 2021

thank you for your effort and time, Phil. I will have a look tomorrow and get back to you.

Again so many thanks!
