Textual Inversion Broken: it updates entire `embeddings` weights, loss diverges
Describe the bug
The weights of the entire `text_encoder` evolve over the course of training, thus breaking the `text_encoder`. I'm not sure why yet, but this in turn breaks Textual Inversion.
To demonstrate it:
1.) Save a random token id and its embedding outside the main loop:
# Frozen snapshot of the full embedding matrix, taken before the training loop.
token_embed_w_copy = text_encoder.get_input_embeddings().weight.data.clone().detach().requires_grad_(False).to(accelerator.device)
# A token that never appears in the training prompts.
test_tok_id = tokenizer.convert_tokens_to_ids('alligator')
test_tok = token_embed_w_copy[test_tok_id]
2.) Inside the loop, assert that it is not changing:
# Compare the live embedding row against the frozen copy.
test_tok_actual = text_encoder.get_input_embeddings().weight.data[test_tok_id]
assert(torch.allclose(test_tok, test_tok_actual))  # BREAKS!
The assertion passes until an entire batch completes (i.e. until the optimizer steps), at which point the embeddings diverge.
The code currently tries to prevent this by zeroing the gradients of all non-`placeholder_token` embedding rows, but this (or something else) fails to keep the weights from updating.
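For reference, the gradient-zeroing approach looks roughly like the sketch below; it assumes a single learned token whose id is `placeholder_token_id` (names are illustrative, not necessarily the exact script variables):

```python
# After accelerator.backward(loss) and before optimizer.step():
grads = text_encoder.get_input_embeddings().weight.grad

# Zero the gradient of every embedding row except the placeholder token,
# so that (in theory) only the placeholder embedding is updated.
index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
```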
I've confirmed that this is what breaks TI: manually copying the entire set of non-placeholder weights back after every batch fixes TI. But it's duct tape, really, and I'm hoping someone has a better idea.
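The copy-back workaround is roughly the following sketch, reusing the frozen copy `token_embed_w_copy` from step 1 (again, names are illustrative):

```python
# After optimizer.step(): restore every embedding row except the placeholder
# token from the frozen pre-training copy.
with torch.no_grad():
    weight = text_encoder.get_input_embeddings().weight
    keep = torch.arange(len(tokenizer), device=weight.device) != placeholder_token_id
    weight.data[keep] = token_embed_w_copy[keep]
```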
EDIT: this does not actually solve it. It helps somewhat, it seems, but the loss still random-walks / diverges. I can even zero out all the gradients at every step and it still behaves strangely.
System Info
Debian, Python 3.9.2, revision b2b3b1a8ab83b020ecaf32f45de3ef23644331cf
Top GitHub Comments
Working on the fix, should be ready by the end of the week, sorry to get back to this only now!
You are right @JunnYu! The `weight_decay` indeed updates the whole `embeddings`. Will send a fix soon.
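To illustrate the diagnosis: AdamW applies decoupled weight decay to a parameter on every optimizer step, regardless of its gradient, so zeroing the non-placeholder gradients does not stop the rest of the embedding matrix from shrinking. A minimal, self-contained sketch (not the actual training script) of the effect:

```python
import torch

emb = torch.nn.Embedding(10, 4)
original = emb.weight.detach().clone()

# AdamW's decoupled decay does w <- w * (1 - lr * weight_decay) on every step,
# even for rows whose gradient is exactly zero.
opt = torch.optim.AdamW(emb.parameters(), lr=1e-2, weight_decay=1e-2)

loss = emb(torch.tensor([3])).sum()            # only row 3 gets a gradient
loss.backward()
emb.weight.grad[torch.arange(10) != 3] = 0.0   # zero every other row's gradient
opt.step()

print(torch.allclose(emb.weight[5], original[5]))  # typically False: row 5 shrank anyway
```

One fix is to construct the optimizer with `weight_decay=0.0` for the embedding parameters (or keep restoring the frozen non-placeholder rows after each step).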