a possible hack for FSMT's SinusoidalPositionalEmbedding peculiarity
(With normal CIs not running USE_CUDA=1 I completely missed testing this, so I found one issue with the torchscript tests that I need help with.)
We are talking about FSMT, the ported fairseq transformer model.
If I understand correctly, their SinusoidalPositionalEmbedding was designed so that it won’t be part of the model params:
https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py#L25
most likely so that it won’t be part of the state_dict, which saves space in their already huge 3.3GB model dump (well, 13GB actually, since they use an ensemble of 4 models). I could be wrong about the reason for this design choice.
I had to copy their implementation rather than use Bart’s version, since the pretrained weights rely on it and the positions it produces are different.
So their SinusoidalPositionalEmbedding’s self.weights is a normal variable (not a buffer and not an nn.parameter.Parameter). They create a dummy buffer self._float_tensor to hold the device, so when model.to() is called, self._float_tensor gets the right device. During forward, self.weights gets .to(self._float_tensor) and all is good. So self.weights is kind of a ghost variable: now you see me and now you don’t.
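Here is a minimal, self-contained sketch of that pattern (a toy stand-in, not the real fairseq class; the class name is made up and the sinusoidal table is replaced with random values):

import torch
import torch.nn as nn

class GhostPositionalEmbedding(nn.Module):
    # toy illustration of the fairseq pattern described above
    def __init__(self, num_positions, dim):
        super().__init__()
        # plain attribute: not a Parameter and not a buffer, so it never enters the state_dict
        self.weights = torch.randn(num_positions, dim)
        # dummy buffer whose only job is to follow the model's device/dtype
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, positions):
        # lazily move the ghost weights to wherever the dummy buffer ended up
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]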
This approach works just fine until we get to torchscript, in particular these 2 common tests:
def test_torchscript_output_attentions(self):
def test_torchscript_output_hidden_state(self):
which blow up under USE_CUDA=1 with:
Comparison exception: Expected all tensors to be on the same device,
but found at least two devices, cuda:0 and cpu!
Everything is on cuda:0, but SinusoidalPositionalEmbedding’s self.weights is still on cpu at this point. The first time forward encounters self.weights, before it gets a chance to be moved to the device, torchscript blows up: it wants all variables to be on the same device before forward runs.
Solution 1
So, I solved this problem with the following hack:
class FSMTForConditionalGeneration(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge 1: pass the .to() call down to the base model
        self.base_model.to(*args, **kwargs)
        return self

class FSMTModel(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge 2: pass the .to() call down to the positional embeddings
        self.encoder.embed_positions.to(*args, **kwargs)
        self.decoder.embed_positions.to(*args, **kwargs)
        return self

class SinusoidalPositionalEmbedding(nn.Module):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge 3: finally move the ghost weights themselves
        self.weights = self.weights.to(*args, **kwargs)
        return self
It’s absolutely crazy, but it works. Basically it forwards the model.to() call to SinusoidalPositionalEmbedding’s self.weights via 3 “bridges”. I thought that each torch module got to() called, but that doesn’t seem to be the case: I think it traverses the model structure instead and doesn’t call to() on each module. Hence the 2 extra classes are involved to bridge it on.
(There is also half() that needs to be dealt with too, since model.half() won’t get forwarded to this non-parameter variable either.)
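An untested sketch of what that extra bridge could look like, extending the same class as in the hack above, just to show the shape of it:

class SinusoidalPositionalEmbedding(nn.Module):
    # ...same class as above, with one more bridge for half()
    def half(self):
        super().half()
        self.weights = self.weights.half()
        return self

and the two outer classes would need matching half() overrides, exactly like their to() ones.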
Solution 2
The second solution is to make SinusoidalPositionalEmbedding’s self.weights a parameter, but then we have to hack save/load so that the model.encoder.embed_positions.* and model.decoder.embed_positions.* keys are not saved and are ignored on load.
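A rough sketch of that filtering in plain PyTorch terms (the key prefixes and helper names here are illustrative; transformers’ own save/load machinery would need the equivalent treatment):

import torch

IGNORED_PREFIXES = ("model.encoder.embed_positions.", "model.decoder.embed_positions.")

def save_filtered(model, path):
    # drop the deterministic positional-embedding entries before saving
    state = {k: v for k, v in model.state_dict().items()
             if not k.startswith(IGNORED_PREFIXES)}
    torch.save(state, path)

def load_filtered(model, path):
    # the dropped keys are expected to be missing, so load non-strictly
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state, strict=False)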
Solution 3
The third solution is to save the useless weights (useless as they aren’t trained and get calculated deterministically).
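In code, that just means registering the precomputed table as a regular (persistent) buffer, so it follows to()/half() and gets written into the state_dict, at the cost of a bigger checkpoint. A minimal sketch, with the sinusoidal table computation replaced by a placeholder:

import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions, dim):
        super().__init__()
        table = torch.randn(num_positions, dim)  # placeholder for the real sinusoidal table
        # persistent buffer: moves with .to()/.half() and is saved in the state_dict
        self.register_buffer("weights", table)

    def forward(self, positions):
        return self.weights[positions]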
Perhaps you can think of other solutions.
Thank you.
Thank you for the idea, @sgugger.
Alas, if I understood your suggestion correctly, I already tried it and it doesn’t work: torchscript wants the vars to be on the same device before any forward call.
Indeed, buffers have been around for a long time, but the need here is different. We want a non-persistent buffer, a functionality which was added only a few months ago and doesn’t yet work with torchscript, so it doesn’t help to solve the problem at hand.
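(For reference, the non-persistent buffer mentioned above is the persistent=False flag of register_buffer, available in recent PyTorch versions; a minimal sketch of what it would give us once torchscript supports it:)

import torch
import torch.nn as nn

class Example(nn.Module):
    def __init__(self):
        super().__init__()
        # behaves like a buffer for .to()/.half(), but is excluded from the state_dict
        self.register_buffer("weights", torch.randn(10, 4), persistent=False)

m = Example()
assert "weights" not in m.state_dict()  # not saved
m.to(torch.device("cpu"))               # but still follows device moves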