
Model Parallelism for BERT Models


Hi,

I’m trying to implement model parallelism for BERT models by splitting layers and assigning them across GPUs, and I took DeBERTa as an example. For DeBERTa, I’m able to split the entire model into ‘embeddings’, ‘encoder’, ‘pooler’, ‘classifier’, and ‘dropout’ modules, as shown in the picture below.

[Figure: the DeBERTa model split into embeddings, encoder, pooler, classifier, and dropout modules]

With this approach, I trained on the IMDB classification task by assigning the ‘encoder’ to the second GPU and everything else to the first GPU. By the end of training, the second GPU had consumed far more memory than the first, which amounted to roughly a 20-80 split of the model across the two devices.
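
For reference, the pattern being described is the usual naive model-parallel setup in PyTorch: submodules live on different GPUs and the forward pass moves activations between devices with .to(). A minimal, generic sketch of that pattern (placeholder layers, not DeBERTa’s actual architecture):

import torch
import torch.nn as nn

class TwoGpuToyModel(nn.Module):
    # Toy naive model parallelism; requires two visible GPUs.
    def __init__(self, hidden=768, num_labels=2):
        super().__init__()
        # "embedding"-like front and classification head on the first GPU
        self.front = nn.Linear(hidden, hidden).to('cuda:0')
        self.head = nn.Linear(hidden, num_labels).to('cuda:0')
        # "encoder"-like stack on the second GPU
        self.encoder = nn.Sequential(*[nn.Linear(hidden, hidden) for _ in range(6)]).to('cuda:1')

    def forward(self, x):
        x = self.front(x.to('cuda:0'))
        x = self.encoder(x.to('cuda:1'))   # activations hop to the second GPU
        return self.head(x.to('cuda:0'))   # and back for pooling/classification

model = TwoGpuToyModel()
logits = model(torch.randn(4, 768))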

So I also tried splitting the encoder layers, as shown below, but I get this error: “TypeError: forward() takes 1 positional argument but 2 were given”.

# split DeBERTa's modules across two GPUs
embed = dberta.deberta.embeddings.to('cuda:0')
f6e = dberta.deberta.encoder.layer[:6].to('cuda:0')   # first 6 encoder layers
l6e = dberta.deberta.encoder.layer[6:].to('cuda:1')   # remaining encoder layers
pooler = dberta.pooler.to('cuda:0')
classifier = dberta.classifier.to('cuda:0')
dropout = dberta.dropout.to('cuda:0')

# tokenize a test sentence
test = "this is to test deberta"
inp_ids = tok_dberta(test, return_tensors='pt').input_ids
att_mask = tok_dberta(test, return_tensors='pt').attention_mask

emb_out = embed(inp_ids.to('cuda:0'))
first_6_enc_lay_out = f6e(emb_out)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-15-379d948e5ba5> in <module>
----> 1 first_6_enc_lay_out = f6e(emb_out)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

TypeError: forward() takes 1 positional argument but 2 were given

Please suggest how to proceed further.
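
A likely cause of the TypeError above: dberta.deberta.encoder.layer is an nn.ModuleList, so f6e and l6e are ModuleList slices, and a ModuleList has no forward of its own and cannot be called like a module; the layers have to be applied one at a time. Below is a minimal sketch of that pattern only, assuming each layer’s forward accepts (hidden_states, attention_mask); this is a simplification, since the real DeBERTa encoder also builds an extended attention mask and relative position embeddings before invoking its layers, so the exact call differs by transformers version.

# Hypothetical sketch: iterate over the sliced layers, keeping activations
# on the GPU that hosts each group of layers.
hidden = emb_out
for layer in f6e:                      # first 6 layers on cuda:0
    hidden = layer(hidden, att_mask.to('cuda:0'))
for layer in l6e:                      # remaining layers on cuda:1
    hidden = hidden.to('cuda:1')
    hidden = layer(hidden, att_mask.to('cuda:1'))
hidden = hidden.to('cuda:0')           # back to cuda:0 for pooler / classifier
# Depending on the model/version, a layer may return a tuple rather than a
# plain tensor, in which case the hidden states are its first element.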

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 22 (10 by maintainers)

Top GitHub Comments

2 reactions
stas00 commented, Feb 19, 2021

Yay, so glad to hear you found a solution, @saichandrapandraju!

Thank you for updating the notebook too!

If the issue has been fully resolved for you please don’t hesitate to close this Issue.

If some new problem occurs, please open a new dedicated issue. Thank you.

1 reaction
saichandrapandraju commented, Feb 22, 2021

Tested DeepSpeed on multi-GPU as well, and it worked!

By setting NCCL_SOCKET_IFNAME=lo, everything worked as expected.

Thanks a lot @stas00
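
For reference, a minimal sketch of applying that setting: NCCL_SOCKET_IFNAME must be in the environment before NCCL initializes, so it can either be exported in the shell before launching or set at the very top of the training script.

import os
# Force NCCL to use the loopback interface (as described above).
# Must run before any torch.distributed / DeepSpeed initialization.
os.environ["NCCL_SOCKET_IFNAME"] = "lo"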
