
a possible hack for FSMT's SinusoidalPositionalEmbedding peculiarity


(With the normal CI not running USE_CUDA=1 I completely missed testing this, so I found one issue with the torchscript tests that I need help with.)

We are talking about FSMT, the fairseq translation model ported to transformers.

If I understand correctly, their SinusoidalPositionalEmbedding was designed so that it won’t be part of the model params (https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py#L25), most likely so that it won’t be part of the state_dict and will save space in their already huge 3.3GB model dump (well, 13GB actually, as they use an ensemble of 4 models). I could be wrong about the reason for this design choice.

I had to copy their implementation rather than use Bart’s version, since the pretrained weights rely on it and the positions it produces are different.

So their SinusoidalPositionalEmbedding’s self.weights is a normal attribute (neither a buffer nor an nn.parameter.Parameter). They create a dummy buffer self._float_tensor to hold the device, so when model.to() is called, self._float_tensor ends up on the right device. During forward, self.weights gets .to(self._float_tensor) and all is good. So self.weights is kind of a ghost variable: now you see me and now you don’t.
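
In stripped-down form the pattern looks roughly like this (a minimal sketch with made-up names, not the actual fairseq/FSMT code):

import torch
import torch.nn as nn

class GhostWeightsEmbedding(nn.Module):
    # sketch of the pattern: `weights` is a plain attribute (neither a
    # parameter nor a buffer), and a dummy buffer tracks whatever
    # device/dtype the model was moved to
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # plain tensor attribute: invisible to state_dict() and to .to()
        self.weights = torch.zeros(num_positions, embedding_dim)
        # dummy buffer whose only job is to follow model.to()/model.half()
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, positions):
        # lazily drag the ghost weights onto the right device/dtype
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]

emb = GhostWeightsEmbedding(512, 16)
print("_float_tensor" in emb.state_dict())  # True
print("weights" in emb.state_dict())        # False - the ghost variable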

This approach works just fine until we get to torchscript, in particular two common tests:

    def test_torchscript_output_attentions(self):
    def test_torchscript_output_hidden_state(self):

which blow up under USE_CUDA=1, with:

Comparison exception:   Expected all tensors to be on the same device, 
but found at least two devices, cuda:0 and cpu!

Everything is on cuda:0, but SinusoidalPositionalEmbedding’s self.weights is still on cpu at this point.

The first time it encounters self.weights inside forward, before it gets a chance to be moved to the device, torchscript blows up. It wants all variables to be on the same device before forward.

Solution 1

So, I solved this problem with the following hack:

class FSMTForConditionalGeneration(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge the call down to FSMTModel
        self.base_model.to(*args, **kwargs)
        return self

class FSMTModel(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge the call down to the positional embeddings
        self.encoder.embed_positions.to(*args, **kwargs)
        self.decoder.embed_positions.to(*args, **kwargs)
        return self

class SinusoidalPositionalEmbedding(nn.Module):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # finally move the ghost weights themselves
        self.weights = self.weights.to(*args, **kwargs)
        return self

It’s absolutely crazy, but it works.

Basically, it forwards the model.to() call to SinusoidalPositionalEmbedding’s self.weights via 3 “bridges”.

I thought that to() got called on each torch module, but that doesn’t seem to be the case: it traverses the model structure and converts parameters and buffers directly, without calling to() on each submodule. Hence the two extra classes are involved to bridge it on.

(And half() needs to be dealt with too, since model.half() won’t get forwarded to this non-parameter variable either.)
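
A quick way to see this (toy module names, nothing FSMT-specific):

import torch
import torch.nn as nn

class Child(nn.Module):
    def to(self, *args, **kwargs):
        print("Child.to() called")          # never printed below
        return super().to(*args, **kwargs)

class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()

Parent().to(torch.float16)
# nothing is printed: Module.to() converts parameters and buffers by
# walking the module tree internally, it does not call each submodule's
# to(), which is why the explicit "bridges" above are needed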

Solution 2

The second solution is to make SinusoidalPositionalEmbedding’s self.weights a parameter, but then we have to hack save/load so that the model.encoder.embed_positions.* and model.decoder.embed_positions.* keys are not saved and are ignored on load.
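
Very roughly, that hack could look like the following sketch (made-up class names, plain PyTorch key filtering rather than any transformers-specific mechanism, and strict=False is a blunt simplification):

import torch
import torch.nn as nn

# hypothetical key prefixes, just to illustrate the idea
IGNORED_PREFIXES = ("encoder.embed_positions.", "decoder.embed_positions.")

class SinusoidalPositionalEmbeddingAsParam(nn.Module):
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # a frozen parameter follows .to()/.half() automatically ...
        self.weights = nn.Parameter(
            torch.zeros(num_positions, embedding_dim), requires_grad=False)
        # ... but it now shows up in state_dict(), hence the filtering below

class FSMTModelSketch(nn.Module):
    def state_dict(self, *args, **kwargs):
        state = super().state_dict(*args, **kwargs)
        # drop the deterministic positional weights before saving
        for key in list(state.keys()):
            if key.startswith(IGNORED_PREFIXES):
                del state[key]
        return state

    def load_state_dict(self, state_dict, strict=True):
        # the positional keys are never in the checkpoint; strict=False
        # is a blunt way to ignore them (it silences all missing keys)
        return super().load_state_dict(state_dict, strict=False)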

Solution 3

The third solution is to simply save the useless weights (useless because they aren’t trained and are calculated deterministically).
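
As a sketch, that just means registering the table as an ordinary (persistent) buffer, so it travels with to()/half() and gets written into the checkpoint (made-up class name):

import torch
import torch.nn as nn

class SinusoidalPositionalEmbeddingSaved(nn.Module):
    # hypothetical variant for solution 3: the precomputed table becomes a
    # regular buffer, so it is saved/loaded and moved with the model, at
    # the cost of redundant data in the checkpoint
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        self.register_buffer("weights", torch.zeros(num_positions, embedding_dim))

    def forward(self, positions):
        return self.weights[positions]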

Perhaps you can think of other solutions.

Thank you.

@sgugger, @patrickvonplaten, @sshleifer, @LysandreJik

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 11 (11 by maintainers)

Top GitHub Comments

3 reactions
stas00 commented, Sep 18, 2020

Another hack different than solution 1 is to check the device of the inputs in the forward pass and maybe move the matrices to the right device.

Thank you for the idea, @sgugger.

Alas, if I understood your suggestion correctly, I already tried it and it doesn’t work. torchscript wants the vars to be on the same device before any forward call.
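
For reference, the suggested hack would look roughly like this sketch (made-up class name); it is essentially what the current forward already does via self._float_tensor, and torchscript objects the same way:

import torch
import torch.nn as nn

class MoveInForwardEmbedding(nn.Module):
    # sketch of the suggestion: follow the device of the incoming tensor
    # instead of a dummy buffer
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        self.weights = torch.zeros(num_positions, embedding_dim)

    def forward(self, positions):
        if self.weights.device != positions.device:
            self.weights = self.weights.to(positions.device)
        return self.weights[positions]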

1 reaction
stas00 commented, Sep 21, 2020

Indeed, buffers have been around for a long time, but the need here is different. We want a non-persistent buffer, functionality that was added just a few months ago and doesn’t yet work with torchscript, so it doesn’t help to solve the problem at hand.
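
For context, a non-persistent buffer (register_buffer(..., persistent=False), available in recent torch releases if I remember correctly) would otherwise be exactly what’s needed here: moved by to()/half(), but excluded from the state_dict. A sketch with a made-up class name:

import torch
import torch.nn as nn

class NonPersistentBufferEmbedding(nn.Module):
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # non-persistent buffer: follows .to()/.half() like any buffer,
        # but never ends up in state_dict() - unfortunately not yet
        # supported by torchscript at the time of this issue
        self.register_buffer(
            "weights", torch.zeros(num_positions, embedding_dim), persistent=False)

emb = NonPersistentBufferEmbedding(512, 16)
print("weights" in emb.state_dict())  # False, yet emb.to()/emb.half() still move it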
