Building and training a RoutingTransformerEncDec from pre-trained RoutingTransformerLMs
I am trying to build and train an encoder-decoder from pre-trained routing transformer LMs. My approach was to replace the encoder and decoder in a RoutingTransformerEncDec with the pre-trained RoutingTransformerLMs as follows:
```python
enc_dec.enc = pretrained_lm
enc_dec.dec = AutoregressiveWrapper(pretrained_lm)
```
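For completeness, here is a stripped-down sketch of the full setup. All hyperparameters and the checkpoint path are placeholders standing in for my actual training configuration, and the AutoregressiveWrapper import path may differ depending on the package version:

```python
import torch
from routing_transformer import RoutingTransformerLM, RoutingTransformerEncDec
from routing_transformer.autoregressive_wrapper import AutoregressiveWrapper

# Pretrained LM, rebuilt with the same (placeholder) hyperparameters it was trained with
pretrained_lm = RoutingTransformerLM(
    num_tokens = 32000, dim = 512, depth = 6, heads = 8,
    max_seq_len = 4096, causal = True
)
pretrained_lm.load_state_dict(torch.load('pretrained_lm.pt'))  # placeholder checkpoint path

# Fresh encoder-decoder, then the swap shown above
enc_dec = RoutingTransformerEncDec(
    dim = 512,
    enc_num_tokens = 32000, enc_depth = 6, enc_heads = 8, enc_max_seq_len = 4096,
    dec_num_tokens = 32000, dec_depth = 6, dec_heads = 8, dec_max_seq_len = 4096
)
enc_dec.enc = pretrained_lm
enc_dec.dec = AutoregressiveWrapper(pretrained_lm)
```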
When I then train the enc_dec as normal, I get the following error:
```
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-681d3315d6dc> in <module>
    147     grad_accum_steps=1,
    148     temperature=1,
--> 149     model_suffix=''
    150
    151 )

~/projects/trlabs_routing_transformer/routing_sum/train_and_eval.py in train_routing_single(epoch, model, tokenizer, train_chunk_bucket, val_data_bucket, model_dir, optimizer, lr, max_seq_len, pred_target_len, src_pad_len, tgt_pad_len, max_src_len, max_tgt_len, log_interval, eval_interval, save_interval, train_logger, global_step, grad_accum_steps, temperature, model_suffix)
    469         train_seq_out = padded_target[:, :max_seq_len].to(device)
    470         loss, aux_loss = model(train_seq_in, train_seq_out, return_loss=True)
--> 471         loss.backward()
    472         aux_loss.backward()
    473         train_loss += loss.item()

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    196                 products. Defaults to ``False``.
    197         """
--> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    199
    200     def register_hook(self, hook):

~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     98     Variable._execution_engine.run_backward(
     99         tensors, grad_tensors, retain_graph, create_graph,
--> 100         allow_unreachable=True)  # allow_unreachable flag
    101
    102

RuntimeError: new kmeans has not been supplied
```
I would appreciate any feedback on what the problem might be, or on the best way to build an enc_dec from pretrained LM checkpoints.
Top GitHub Comments
@AliOskooeiTR also, as shown at https://github.com/lucidrains/routing-transformer/blob/master/routing_transformer/encoder_decoder.py#L59
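For anyone landing here later: that line sits in the RoutingTransformerEncDec constructor, which does the wrapping itself. The decoder LM is wrapped in AutoregressiveWrapper internally, and, as far as I can tell, the k-means routing updates are hooked into the backward pass at construction time, which is what the "new kmeans has not been supplied" assert relates to. Assigning .enc/.dec after the fact bypasses that wiring. A sketch of the alternative, letting the constructor build everything and then copying pretrained weights in, is below; the hyperparameters, checkpoint path, the .net attribute on the wrapper, and the strict=False key matching are my assumptions rather than anything the repo documents:

```python
import torch
from routing_transformer import RoutingTransformerEncDec

# Build the enc-dec the normal way so the constructor does its own wiring
# (AutoregressiveWrapper around the decoder, k-means update hook on backward).
# Hyperparameters are placeholders and should match the pretrained checkpoint.
enc_dec = RoutingTransformerEncDec(
    dim = 512,
    enc_num_tokens = 32000, enc_depth = 6, enc_heads = 8, enc_max_seq_len = 4096,
    dec_num_tokens = 32000, dec_depth = 6, dec_heads = 8, dec_max_seq_len = 4096
)

pretrained_state = torch.load('pretrained_lm.pt')  # placeholder checkpoint path

# Copy whatever overlaps into the encoder; strict=False tolerates keys that exist
# on only one side, but shape mismatches will still raise, so the configs must line up.
enc_dec.enc.load_state_dict(pretrained_state, strict = False)

# The decoder sits behind the AutoregressiveWrapper (stored as .net, as far as I can tell)
# and carries extra cross-attention parameters, so only the overlapping keys transfer.
enc_dec.dec.net.load_state_dict(pretrained_state, strict = False)
```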
Just closing this issue, as I have figured out why it wasn't possible to use the same LM for both the encoder and the decoder. The decoder LM must receive context and be causal, so it ends up with a different architecture and state dictionary from the encoder LM, and the two are not interchangeable.
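To make that concrete, the sketch below builds two standalone LMs, one encoder-style and one decoder-style, and compares their state dicts. The receives_context flag and the hyperparameters are assumptions on my part (based on how the enc-dec appears to configure its decoder), so treat this as illustrative rather than as documented API:

```python
from routing_transformer import RoutingTransformerLM

# Placeholder hyperparameters; only the relative difference between the two configs matters.
common = dict(num_tokens = 32000, dim = 512, depth = 6, heads = 8, max_seq_len = 4096)

enc_lm = RoutingTransformerLM(causal = False, **common)
# Assumed decoder-style config: causal, and cross-attending over encoder context.
dec_lm = RoutingTransformerLM(causal = True, receives_context = True, **common)

enc_keys = set(enc_lm.state_dict().keys())
dec_keys = set(dec_lm.state_dict().keys())

# The decoder carries extra (cross-attention) parameters the encoder does not have,
# which is why a single pretrained LM checkpoint cannot serve as both halves.
print(sorted(dec_keys - enc_keys))
```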