General questions about understanding the algorithm
Been trying to get a grasp of the DALLE code recently. However, there are a couple of things I can't quite wrap my head around, and since the paper is not published yet, I was wondering if we could maybe clarify them here.
So first there is the VAE training, which basically places the codebook in the bottleneck and is done a priori.
Next, DALLE receives text and image pairs, embeds them, and adds positional encodings to both modalities individually. However, the image data is not embedded as in e.g. ViT, but by downsampling it via the encoder of the VAE (without accumulating gradients), taking the argmax over the feature dimension for each downsampled image patch, and finally indexing into the previously trained codebook.
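If I sketch that step in code, it comes out roughly like this (a minimal, self-contained toy; `vae_encoder` and `image_token_emb` are placeholder modules, not the actual DALLE-pytorch attributes):

```python
import torch
import torch.nn as nn

# toy stand-ins: a single conv plays the role of the trained (frozen) dVAE encoder,
# and an nn.Embedding plays the role of the codebook / image token embedding
num_codebook_tokens, dim = 10, 256
vae_encoder = nn.Conv2d(3, num_codebook_tokens, kernel_size=32, stride=32)
image_token_emb = nn.Embedding(num_codebook_tokens, dim)

images = torch.randn(1, 3, 64, 64)
with torch.no_grad():                            # no gradients flow into the dVAE encoder
    logits = vae_encoder(images)                 # (b, num_codebook_tokens, 2, 2)
image_codes = logits.argmax(dim=1).flatten(1)    # (b, 4) discrete codebook indices
image_emb = image_token_emb(image_codes)         # (b, 4, dim) embedding lookup
```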
The resulting representations of both modalities are then concatenated along the token dimension. While every word of the text input is one token, the number of image tokens is given by the height and width of the VAE-encoded image.
The combined embedding is then passed into a single transformer, which calculates self-attention not only within each modality but also across both modalities, if I am not mistaken.
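In shapes, something like the following (the transformer layer is a generic stand-in, not the one from the repository, and the real model additionally applies a causal mask, see the comments further down):

```python
import torch
import torch.nn as nn

# toy shapes matching the config further down: 3 text tokens, 4 image tokens, dim 256
b, text_seq_len, image_seq_len, dim = 1, 3, 4, 256
text_emb = torch.randn(b, text_seq_len, dim)      # embedded text + positional encoding
image_emb = torch.randn(b, image_seq_len, dim)    # embedded image tokens + positional encoding

tokens = torch.cat((text_emb, image_emb), dim=1)  # (b, 7, dim) one combined sequence

# a generic self-attention layer over the concatenated sequence: every position,
# text or image, can attend to every other position (before any causal masking)
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=8, batch_first=True)
out = layer(tokens)                               # (b, 7, dim)
```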
A mask of the form `mask = torch.ones_like(text).bool()` results in an unmasked attention calculation, right?
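My reading is that this mask only marks padding, so an all-True mask removes nothing; here is a toy check of the usual `masked_fill` pattern (which is an assumption about the implementation, not a quote of it):

```python
import torch

# with an all-True padding mask, ~mask is all False, so masked_fill changes nothing
text = torch.randint(0, 4, (1, 3))
mask = torch.ones_like(text).bool()                      # (b, 3), True = real token

scores = torch.randn(1, 3, 3)                            # dummy attention scores (b, i, j)
masked = scores.masked_fill(~mask[:, None, :], float('-inf'))
assert torch.equal(scores, masked)                       # nothing was masked out
```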
A final MLP maps the transformer output to all possible token classes (both text and image).
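Schematically, and matching the 15 columns of the mask printed below (how exactly those 15 classes split into 4 text tokens, 10 image tokens, and one extra class is my reading, not something I can confirm):

```python
import torch
import torch.nn as nn

# projection head over the combined vocabulary; 15 = 4 text + 10 image + 1 extra class
num_text_tokens, num_image_tokens, dim = 4, 10, 256
total_tokens = num_text_tokens + num_image_tokens + 1
to_logits = nn.Linear(dim, total_tokens)

transformer_out = torch.randn(1, 7, dim)      # (b, seq_len, dim)
logits = to_logits(transformer_out)           # (b, 7, 15): a score for every possible token class
```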
Then I don't understand the masking:
```python
logits_mask = (
    ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
    ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
    ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
)
```
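For reference, the 7 × 15 tensor printed further down can be reproduced standalone with the toy sizes (3 text positions, 4 image positions, 15 token classes); how I build `seq_range` and `logits_range` here is just plain broadcasting, not necessarily the repo's exact code:

```python
import torch

text_seq_len, image_seq_len = 3, 4
num_text_tokens, total_tokens = 4, 15
seq_len = text_seq_len + image_seq_len

seq_range = torch.arange(seq_len)[None, :, None]          # (1, seq_len, 1)
logits_range = torch.arange(total_tokens)[None, None, :]  # (1, 1, total_tokens)

logits_mask = (
    ((seq_range >= (text_seq_len - 1)) & (logits_range < num_text_tokens)) |
    ((seq_range < (text_seq_len - 1)) & (logits_range >= num_text_tokens)) |
    ((seq_range != (seq_len - 1)) & (logits_range >= (total_tokens - 1)))
)
# True marks token classes that are forbidden at that position, i.e. (as far as I can
# tell) they get filled with a large negative value before the loss is computed
```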
Shouldn't there be one more row concerned with the text input and one fewer row for the image input?
For the following config with 3 text input tokens:

```python
import torch
from dalle_pytorch import DiscreteVAE, DALLE

vae = DiscreteVAE(
    image_size = 64,
    num_layers = 5,
    num_tokens = 10,
    codebook_dim = 256,
    hidden_dim = 64,
    num_resnet_blocks = 1,
    temperature = 0.9
).cuda()

dalle = DALLE(
    dim = 256,
    vae = vae,            # automatically infer (1) image sequence length and (2) number of image tokens
    num_text_tokens = 4,  # vocab size for text
    text_seq_len = 3,     # text sequence length
    depth = 6,            # should aim to be 64
    heads = 8,            # attention heads
    dim_head = 64,        # attention head dimension
    attn_dropout = 0.1,   # attention dropout
    ff_dropout = 0.1      # feedforward dropout
).cuda()

text = torch.randint(0, 4, (1, 3)).cuda()
images = torch.randn(1, 3, 64, 64).cuda()
mask = torch.ones_like(text).bool().cuda()
```
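With this setup, running the forward pass from the repository README is, if I am not mistaken, where the `logits_mask` below gets built:

```python
# training-style forward pass as shown in the DALLE-pytorch README
loss = dalle(text, images, mask = mask, return_loss = True)
loss.backward()
```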
The resulting `logits_mask` looks like this:
```
tensor([[[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True],
         [False, False, False, False, True, True, True, True, True, True, True, True, True, True, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, False]]])
```
Shouldn't it be the following instead?
```
tensor([[[False, False, False, False, True, True, True, True, True, True, True, True, True, True, True],
         [False, False, False, False, True, True, True, True, True, True, True, True, True, True, True],
         [False, False, False, False, True, True, True, True, True, True, True, True, True, True, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, True],
         [ True, True, True, True, False, False, False, False, False, False, False, False, False, False, False]]])
```
The purpose of the masking is that image tokens don't contribute to the predictions of text tokens and vice versa. The code then proceeds by constructing labels from the integer text tokens and from the codebook indices of the VAE-encoded image.
But what is it we are actually trying to predict with this classification task? It is a 2D cross-entropy loss where, for each token (either text or image), we are trying to predict … exactly what? I think I am missing the intuition here.
And then, why does the label vector neglect the very first label entry but use the EOS entry?
```python
loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
```
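My current reading is that this is plain next-token prediction over the concatenated text-then-image sequence: the logit at position i is scored against the token at position i + 1, which is why the first label entry is dropped. Here is a self-contained toy version of that shift (the concatenation and the `+ num_text_tokens` offset for the image codes are my assumptions, not quotes of the repo code):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

# toy sizes matching the config above: 3 text tokens, 4 image tokens, 15 classes
num_text_tokens, total_tokens = 4, 15
text = torch.randint(0, 4, (1, 3))             # (b, 3) text token ids
image_codes = torch.randint(0, 10, (1, 4))     # (b, 4) dVAE codebook indices

# shared label vocabulary: image ids are shifted behind the text ids (assumption)
labels = torch.cat((text, image_codes + num_text_tokens), dim=1)   # (b, 7)

logits = torch.randn(1, 6, total_tokens)       # (b, 6, 15): one prediction per next position
# position i predicts token i + 1, hence labels[:, 1:]
loss = F.cross_entropy(rearrange(logits, 'b n c -> b c n'), labels[:, 1:])
```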
Maybe someone can help me (and others) better understand what's going on here. Thank you in advance!
Top GitHub Comments
Yeah, someone else actually brought that up. I don't believe so, because if you read the iGPT paper, they clustered the pixel space into 512 values and then simply retrained on those 512 values as unique embeddings, and it still worked. However, I have a branch in this repository named ‘end-to-end’ that contains what you are describing, and you are free to try it out.
So the attention is only from future to past, because the causal flag is turned on: https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/transformer.py#L86 https://github.com/lucidrains/DALLE-pytorch/blob/main/dalle_pytorch/dalle_pytorch.py#L294
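(For anyone else reading along, a minimal illustration of what the causal flag amounts to, independent of the repository code:)

```python
import torch

# toy causal mask for a length-7 sequence: True marks the future positions
# that each query position is not allowed to attend to
seq_len = 7
causal_mask = torch.ones(seq_len, seq_len).triu(1).bool()
# row i is False for columns 0..i and True afterwards, i.e. attention looks only at the past
```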
that’s a bug on my part, fixed in the latest commit! 🙏
Thank you for your effort and time, Phil. I will have a look tomorrow and get back to you.
Again, so many thanks!