
Building and training a RoutingTransformerEncDec from pre-trained RoutingTransformerLMs


I am trying to build and train an encoder-decoder from pretrained Routing Transformer LMs. My approach was to replace the encoder and decoder of a RoutingTransformerEncDec with a pretrained RoutingTransformerLM, as follows:

enc_dec.enc = pretrained_lm
enc_dec.dec = AutoregressiveWrapper(pretrained_lm)

I then train enc_dec as usual, and loss.backward() raises the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-9-681d3315d6dc> in <module>
    147         grad_accum_steps=1,
    148         temperature=1,
--> 149         model_suffix=''
    150 
    151     )

~/projects/trlabs_routing_transformer/routing_sum/train_and_eval.py in train_routing_single(epoch, model, tokenizer, train_chunk_bucket, val_data_bucket, model_dir, optimizer, lr, max_seq_len, pred_target_len, src_pad_len, tgt_pad_len, max_src_len, max_tgt_len, log_interval, eval_interval, save_interval, train_logger, global_step, grad_accum_steps, temperature, model_suffix)
    469         train_seq_out = padded_target[:, :max_seq_len].to(device)
    470         loss, aux_loss = model(train_seq_in, train_seq_out, return_loss=True)
--> 471         loss.backward()
    472         aux_loss.backward()
    473         train_loss += loss.item()

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199 
    200     def register_hook(self, hook):

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101 
    102 

RuntimeError: new kmeans has not been supplied

I would appreciate any feedback on what may be the problem or what is the best way to build an enc_dec from pretrained LM checkpoints.
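Both assignments in the question reference the same `pretrained_lm` object, so the swap does not create a separate decoder: whatever that LM was built as is what the decoder slot now holds. The sketch below uses plain-Python stand-ins (hypothetical classes, not the routing_transformer API) to show that aliasing. It assumes, following the closing comment in this thread, that the decoder slot is expected to be causal and to receive encoder context, and that the pretrained LM was built without context input.

```python
class StandInLM:
    """Hypothetical stand-in for a RoutingTransformerLM that only records two flags."""

    def __init__(self, causal=False, receives_context=False):
        self.causal = causal
        self.receives_context = receives_context


class StandInWrapper:
    """Hypothetical stand-in for AutoregressiveWrapper: keeps a reference, not a copy."""

    def __init__(self, net):
        self.net = net


class StandInEncDec:
    """Hypothetical enc-dec whose decoder is causal and receives context, per the thread."""

    def __init__(self):
        self.enc = StandInLM(causal=False, receives_context=False)
        self.dec = StandInWrapper(StandInLM(causal=True, receives_context=True))


enc_dec = StandInEncDec()
# Assumption: the pretrained LM is a causal LM that was not built to receive context.
pretrained_lm = StandInLM(causal=True, receives_context=False)

# The swap from the question: both slots now reference one object.
enc_dec.enc = pretrained_lm
enc_dec.dec = StandInWrapper(pretrained_lm)

print(enc_dec.dec.net is enc_dec.enc)    # True: encoder and decoder are the same module
print(enc_dec.dec.net.receives_context)  # False: the decoder cannot take encoder context
```

Because the wrapper stores a reference, any weights, flags, or hooks of the original decoder are simply gone after the swap; nothing from the freshly constructed decoder survives.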

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

AliOskooeiTR commented, May 1, 2021 (0 reactions)

Just closing this issue, as I have figured out why it wasn't possible to use the same LM for both the encoder and the decoder. The decoder LM must receive context and be causal. This gives the decoder a different architecture and state dictionary, so it is not interchangeable with the encoder LM.
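One way to check whether an LM checkpoint can be reused in a given slot is to compare parameter names and shapes before loading. The sketch below is pure Python and works on name-to-shape dictionaries; the parameter names and sizes in the example are hypothetical and are not the actual routing_transformer keys. The extra `cross_attn` entry stands in for whatever parameters a context-receiving decoder adds.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def compare_state_dicts(
    source: Dict[str, Shape], target: Dict[str, Shape]
) -> Tuple[List[str], List[str], List[str]]:
    """Compare a checkpoint (source) against a model (target) by name and shape."""
    missing = sorted(k for k in target if k not in source)  # target needs, source lacks
    unexpected = sorted(k for k in source if k not in target)  # source has, target ignores
    mismatched = sorted(k for k in source if k in target and source[k] != target[k])
    return missing, unexpected, mismatched


def loadable_subset(source: Dict[str, Shape], target: Dict[str, Shape]) -> Dict[str, Shape]:
    """Keep only entries whose name and shape both match the target."""
    return {k: v for k, v in source.items() if target.get(k) == v}


# Hypothetical parameter names and shapes, for illustration only.
encoder_lm = {
    "token_emb.weight": (20000, 512),
    "pos_emb.weight": (4096, 512),
    "layers.0.self_attn.weight": (512, 512),
}
decoder = {
    "token_emb.weight": (20000, 512),
    "pos_emb.weight": (8192, 512),
    "layers.0.self_attn.weight": (512, 512),
    "layers.0.cross_attn.weight": (512, 512),  # extra attention over encoder context
}

missing, unexpected, mismatched = compare_state_dicts(encoder_lm, decoder)
print(missing)     # ['layers.0.cross_attn.weight']
print(unexpected)  # []
print(mismatched)  # ['pos_emb.weight']
print(sorted(loadable_subset(encoder_lm, decoder)))  # ['layers.0.self_attn.weight', 'token_emb.weight']
```

With PyTorch, such a dictionary can be built as `{k: tuple(v.shape) for k, v in model.state_dict().items()}`, and the filtered tensors passed to `load_state_dict(..., strict=False)`. Note that `strict=False` only tolerates missing and unexpected keys; a shape mismatch still raises, which is why the sketch filters on shape as well as name.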
