
Model multiple parameters on TPU

See original GitHub issue

🐛 Bug

load_from_checkpoint fails for a model whose constructor takes additional required parameters (besides hparams) when training on a TPU with more than one core.

To Reproduce

Steps to reproduce the behavior:

  1. Add an additional required parameter (besides hparams) to the model constructor, e.g. dataset
  2. Run training on a TPU with more than one core
  3. Observe the following error:
Traceback (most recent call last):
  File "train.py", line 83, in <module>
    trainer.fit(model)   
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py", line 721, in fit
    self.load_spawn_weights(model)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 372, in load_spawn_weights
    loaded_model = original_model.__class__.load_from_checkpoint(path)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/lightning.py", line 1512, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/lightning.py", line 1543, in _load_model_state
    model = cls(*model_args)
TypeError: __init__() missing 1 required positional argument: 'dataset'
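
For context, the final frame of the traceback is the core of the problem: the loader rebuilds the model from the checkpoint with only the stored hparams, so any extra required constructor argument is missing. A minimal sketch of the failure mode (simplified for illustration; not the actual Lightning source):

# Simplified sketch of the failing reload: only hparams are recovered from
# the checkpoint, so the constructor is called with one positional argument.
class CoolSystem:
    def __init__(self, hparams, dataset):  # two required arguments
        self.hparams = hparams
        self.dataset = dataset

model_args = ({"test_param": 2},)  # what the loader has available
model = CoolSystem(*model_args)    # TypeError: __init__() missing 1 required
                                   # positional argument: 'dataset'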

Code sample

Google Colab Notebook

from pytorch_lightning import Trainer
from argparse import Namespace

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self, hparams, dataset):
        super().__init__()
        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, 10)
        self.hparams = hparams
        # the extra required constructor argument that breaks reloading
        self.dataset = dataset

    def forward(self, x):
        # called with self(x)
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0004)

    def prepare_data(self):
        self.mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def train_dataloader(self):
        loader = DataLoader(self.mnist_train, batch_size=32, num_workers=2)
        return loader

    def val_dataloader(self):
        loader = DataLoader(self.mnist_test, batch_size=32)
        return loader

class Dataset:
    # dummy stand-in for a real dataset object
    pass

model = CoolSystem({ "test_param": 2 }, Dataset())

trainer = Trainer(num_tpu_cores=8, train_percent_check=0.02, val_percent_check=0.1, max_epochs=1)
trainer.fit(model)   

Expected behavior

Model parameters should be saved and the model reloaded correctly, even when the constructor takes required arguments beyond hparams.

Environment

  • PyTorch Version (e.g., 1.0): 1.6.0a0+3e5d25f
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip
  • Build command you used (if compiling from source): -
  • Python version: 3.6
  • CUDA/cuDNN version: -
  • GPU models and configuration: TPU
  • Any other relevant information: PyTorch Lightning from master branch

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 8 (7 by maintainers)

Top GitHub Comments

3 reactions
williamFalcon commented, Apr 7, 2020

Oh, I see. Yeah, the dataset argument in your constructor is breaking the load.

For the trainer to autoload, you have to use only hparams (put the dataset in the hparams object, which can be a dict as well). The second option is to submit a PR to enable loading other params as well.
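
A minimal sketch of that first suggestion, reusing the CoolSystem and Dataset classes from the code sample above (assumption: your dataset object can live inside the hparams dict; whether it also serializes cleanly into the checkpoint depends on what it contains):

import torch
import pytorch_lightning as pl

class Dataset:
    pass  # dummy stand-in, as in the code sample

class CoolSystem(pl.LightningModule):
    # single-argument constructor: everything travels inside hparams,
    # so the autoloader can rebuild the model from hparams alone
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.dataset = hparams["dataset"]  # pulled out of hparams, not a second arg
        self.l1 = torch.nn.Linear(28 * 28, 10)

model = CoolSystem({"test_param": 2, "dataset": Dataset()})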

2 reactions
Borda commented, Jun 11, 2020

This should be fixed by #2047.

Read more comments on GitHub >

