Initializing attention weights in T5
See original GitHub issue

@patrickvonplaten @patil-suraj @craffel Excuse me if this question has been asked before, but I did not find an answer to it.
In these lines:

```python
elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)):
    # Mesh TensorFlow attention initialization to avoid scaling before softmax
    # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
    d_model = self.config.d_model
    key_value_proj_dim = self.config.d_kv
    n_heads = self.config.num_heads
    module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
    module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
    module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
    module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
    if module.has_relative_attention_bias:
        module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
        if isinstance(module, LongT5TransientGlobalAttention):
            module.global_relative_attention_bias.weight.data.normal_(
                mean=0.0, std=factor * ((d_model) ** -0.5)
            )
```
from the LongT5 implementation: https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/longt5/modeling_longt5.py#L1291
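For context (my reading of the surrounding code, please correct me if I am wrong): the factor variable in _init_weights comes from the config's initializer_factor, which defaults to 1.0, so by default the standard deviations reduce to the bare ** -0.5 terms. A minimal sketch under that assumption:

```python
# Minimal sketch (my assumption, not quoted from the repo): where `factor` comes from.
from transformers import LongT5Config

config = LongT5Config()             # initializer_factor defaults to 1.0
factor = config.initializer_factor
print(factor)                       # 1.0 -> std reduces to (d_model * d_kv) ** -0.5, etc.
```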
-
We notice that the factor is multiplied by ((d_model * key_value_proj_dim) ** -0.5) only for the query, by (d_model ** -0.5) for the key and value, and by ((n_heads * key_value_proj_dim) ** -0.5) for the output. Why? Is there a detailed explanation of that? And is the initial value of the factor still 1.0?
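My current reading of the comment "to avoid scaling before softmax" (this is my own sketch, not something stated by the maintainers): if the inputs are roughly unit variance, folding an extra d_kv ** -0.5 into the query init makes the q·k dot products roughly unit variance without the usual division by sqrt(d_kv). A quick numerical check under that assumption:

```python
# Quick numerical check (my own sketch, not from the repo): with roughly
# unit-variance inputs, the extra d_kv ** -0.5 folded into the query init makes
# the q.k dot products roughly unit variance WITHOUT dividing by sqrt(d_kv),
# which seems to be what "avoid scaling before softmax" refers to.
import torch

d_model, d_kv = 512, 64          # toy sizes, assumed
n = 10_000                       # number of random hidden states

x_q = torch.randn(n, d_model)    # unit-variance "query side" activations (assumption)
x_k = torch.randn(n, d_model)    # unit-variance "key side" activations (assumption)

w_q = torch.randn(d_model, d_kv) * (d_model * d_kv) ** -0.5   # query init from the snippet
w_k = torch.randn(d_model, d_kv) * d_model ** -0.5            # key init from the snippet

q = x_q @ w_q                    # per-component variance ~= 1 / d_kv
k = x_k @ w_k                    # per-component variance ~= 1
scores = (q * k).sum(-1)         # dot product over d_kv dims, no 1/sqrt(d_kv) applied

print(q.var().item(), k.var().item(), scores.var().item())   # ~= 1/d_kv, ~= 1, ~= 1
```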
-
Also, today I found this issue: https://github.com/huggingface/transformers/issues/16749
According to my understanding of that issue (correct me if I am wrong):
@patrickvonplaten corrected the initialization, but what is still unclear to me is the relation between the tied word embedding initialization and the language model head initialization in this line https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L766 and why this condition is not included in the LongT5 implementation?
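For my own understanding, here is a toy sketch (assumed shapes, not the transformers source) of why a separate lm_head init would only matter when the embeddings are not tied: with tying, the head and the shared embedding are literally the same tensor, so the embedding init already covers both.

```python
# Toy sketch (assumed shapes, not the transformers source): weight tying means
# one init covers both the input embedding and the LM head.
import torch
import torch.nn as nn

d_model, vocab_size = 512, 32128                    # assumed toy sizes
shared = nn.Embedding(vocab_size, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)

lm_head.weight = shared.weight                      # "tie_word_embeddings=True"
shared.weight.data.normal_(mean=0.0, std=1.0)       # embedding init (factor * 1.0)

print(torch.equal(lm_head.weight, shared.weight))   # True: a separate lm_head init
# would be redundant here; it is only needed when the weights are NOT tied.
```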
Issue Analytics
- State:
- Created a year ago
- Comments: 6 (2 by maintainers)
Top GitHub Comments
It also confuses me.
@patrickvonplaten @patil-suraj @craffel
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.