Large differences between T5 weight initialization in TF and torch
See original GitHub issue.
transformers version: 4.18.0, master branch
I found some significant differences in weight init between the PT and TF implementations of T5.
The embeddings (model.shared):
- In PT, according to T5PreTrainedModel._init_weights, they are initialized from a random normal with std=1.0:
  module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
- In TF (TFT5Model), the embeddings are initialized as:
  self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")
  Since initializer_range is not provided, the default of hidden_size ** -0.5 is used (see TFSharedEmbeddings).
This means that in the base model (d_model=768), the embedding weights in PT are initialized with stddev=1.0, while in TF they are initialized with stddev=768 ** -0.5 ≈ 0.036.
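For a concrete comparison, here is a minimal standalone sketch of the two embedding initializations. Plain torch/Keras layers stand in for the transformers classes, and the t5-base dimensions (d_model=768, vocab_size=32128) are assumed:

import torch
import tensorflow as tf

d_model, vocab_size = 768, 32128  # t5-base dimensions
factor = 1.0                      # config.initializer_factor default

# PT: _init_weights draws the shared embedding from N(0, factor * 1.0)
pt_shared = torch.nn.Embedding(vocab_size, d_model)
pt_shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
print(pt_shared.weight.std().item())  # ~1.0

# TF: TFSharedEmbeddings falls back to N(0, hidden_size ** -0.5)
tf_shared = tf.keras.layers.Embedding(
    vocab_size, d_model,
    embeddings_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=d_model ** -0.5),
)
tf_shared.build(input_shape=(None,))
print(float(tf.math.reduce_std(tf_shared.embeddings)))  # ~0.036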
The LM head (model.lm_head):
- In PT, no initializer is specified, so nn.Linear’s default applies: a uniform distribution on [-sqrt(1/d_model), sqrt(1/d_model)] (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html). The lm_head weights don’t seem to be initialized in _init_weights either:
  lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
- In TF, the initializer is explicitly provided (TFT5ForConditionalGeneration):
  lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)
So, in the base model, the lm_head weights in PT are initialized from a uniform distribution on [-0.036, 0.036], while in TF they are initialized from a random normal with stddev=1.0 (initializer_factor defaults to 1.0).
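Again as a standalone sketch, with plain torch/Keras layers standing in for the transformers classes and t5-base dimensions assumed, the two head initializations side by side:

import math
import torch
import tensorflow as tf

d_model, vocab_size = 768, 32128

# PT: nn.Linear's default is U(-1/sqrt(in_features), 1/sqrt(in_features))
pt_head = torch.nn.Linear(d_model, vocab_size, bias=False)
print(1 / math.sqrt(d_model))             # ~0.036, the uniform bound
print(pt_head.weight.abs().max().item())  # never exceeds that bound

# TF: explicit RandomNormal with stddev=config.initializer_factor (1.0 by default)
tf_head = tf.keras.layers.Dense(
    vocab_size, use_bias=False,
    kernel_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=1.0),
)
tf_head.build(input_shape=(None, d_model))
print(float(tf.math.reduce_std(tf_head.kernel)))  # ~1.0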
I’m not entirely sure what the practical implications are for model training, but at the very least the lm_head initialization will have a huge impact on the initial loss values.
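To make the loss claim concrete, here is a rough sketch comparing the initial cross-entropy under the two head inits. Random hidden states with std≈1 stand in for real decoder outputs, so the absolute numbers are only indicative:

import torch
import torch.nn.functional as F

d_model, vocab_size, n_tokens = 768, 32128, 16
hidden = torch.randn(n_tokens, d_model)           # stand-in for final decoder hidden states
targets = torch.randint(vocab_size, (n_tokens,))  # random labels

# TF-style head: N(0, 1) weights -> logits with std ~ sqrt(d_model) ~ 28
tf_style = torch.nn.Linear(d_model, vocab_size, bias=False)
tf_style.weight.data.normal_(mean=0.0, std=1.0)
print(F.cross_entropy(tf_style(hidden), targets).item())  # far above ln(vocab_size)

# PT-style head: nn.Linear's default uniform init -> logits with std well below 1
pt_style = torch.nn.Linear(d_model, vocab_size, bias=False)
print(F.cross_entropy(pt_style(hidden), targets).item())  # close to ln(32128) ≈ 10.4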
Based on other transformer models I’ve seen, the “correct” answer seems to be that both weights should be initialized with stddev=1.0, yet neither implementation actually does this.
Comments
Yep, MTF initializes embeddings as a standard Gaussian. https://github.com/tensorflow/mesh/blob/a32810e32709e0eaad3b241475d3be0957409adc/mesh_tensorflow/layers.py#L2096
Please take over the issue, @patrickvonplaten. This got pretty muddy and I’m not sure what the right approach is here.