[QUESTION][REQUEST] Way to move model to a specific device

In PyTorch you can move a model to a specific device via my_model.to(device). Is there currently a way to do that with Darts?

This is especially useful on M1 Macs, which do not use CUDA; instead you need to specify the MPS backend. In pure PyTorch this is easy, but I couldn't find a way to do it with Darts.
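Darts' torch-based models appear to delegate device placement to the PyTorch Lightning Trainer rather than exposing .to(device), so one option is to choose the accelerator through pl_trainer_kwargs, as the RNNModel example in the comment below does. The following is a minimal sketch, assuming PyTorch Lightning 1.7 or later (where "cuda" and "mps" are valid accelerator names). The helper pick_trainer_kwargs is hypothetical; the availability flags are passed in explicitly so the selection logic can be read and tested without torch installed.

```python
def pick_trainer_kwargs(cuda_available: bool, mps_available: bool) -> dict:
    """Return pl_trainer_kwargs that select the best available device.

    Hypothetical helper. In real code, pass torch.cuda.is_available() and
    torch.backends.mps.is_available() as the two flags.
    Assumes PyTorch Lightning >= 1.7, which accepts "cuda" and "mps".
    """
    if cuda_available:
        # Use the first CUDA GPU.
        return {"accelerator": "cuda", "devices": [0]}
    if mps_available:
        # Apple-silicon GPU via the MPS backend (there is only one device).
        return {"accelerator": "mps", "devices": 1}
    # Fall back to the CPU.
    return {"accelerator": "cpu"}
```

The returned dict would then be handed to the model constructor, e.g. RNNModel(..., pl_trainer_kwargs=pick_trainer_kwargs(torch.cuda.is_available(), torch.backends.mps.is_available())).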

Issue Analytics

  • State: open
  • Created a year ago
  • Comments: 11 (3 by maintainers)

Top GitHub Comments

skeenan commented, Aug 27, 2022

Got this working. The fix for num_loader_workers was to wrap the code execution logic in an if __name__ == '__main__' guard.

However, even with num_workers > 0, training was still very slow on the GPU: it takes about 2 seconds on CPU, whereas I killed the GPU run after it had been going for 30 minutes. I can only conclude there is some issue with the current state of PyTorch's M1 support. In my opinion it's not worth bothering with until support improves. Hope this helps someone in the future.
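The guard matters because on macOS, Python's multiprocessing uses the "spawn" start method by default (since Python 3.8), so every worker process re-imports the main script. num_loader_workers is passed to the PyTorch DataLoader as its number of worker processes, so without the guard each worker would re-run the top-level training code and try to start workers of its own. A minimal stdlib-only sketch of the same pattern follows; it forces the spawn context so it behaves the same on any OS, and the names square and main are illustrative only.

```python
import multiprocessing as mp


def square(x: int) -> int:
    # Runs inside a worker process. It must be defined at module top level
    # so spawned workers can find it after re-importing this file.
    return x * x


def main() -> list:
    # Force "spawn" (the macOS default since Python 3.8) so the example
    # behaves the same on every platform.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        return pool.map(square, range(5))


if __name__ == "__main__":
    # Without this guard, each spawned worker would re-import this file,
    # reach the Pool creation again and fail with a RuntimeError.
    print(main())  # prints [0, 1, 4, 9, 16]
```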

Complete code below.

```python
import numpy as np
import pandas as pd

from darts.dataprocessing.transformers import Scaler
from darts.datasets import AirPassengersDataset
from darts.models import RNNModel


def main():
    # Read data; cast to float32 because the MPS backend does not support float64.
    series = AirPassengersDataset().load()
    series = series.astype(np.float32)

    # Create training and validation sets:
    train, val = series.split_after(pd.Timestamp("19590101"))

    # Normalize the time series (note: we avoid fitting the transformer on the validation set)
    transformer = Scaler()
    train_transformed = transformer.fit_transform(train)
    val_transformed = transformer.transform(val)

    my_model = RNNModel(
        model="RNN",
        hidden_dim=20,
        dropout=0,
        batch_size=40,
        n_epochs=300,
        optimizer_kwargs={"lr": 1e-3},
        model_name="Air_RNN",
        log_tensorboard=True,
        random_state=42,
        training_length=20,
        input_chunk_length=14,
        force_reset=True,
        # Passed through to the PyTorch Lightning Trainer. Assuming Lightning >= 1.7,
        # "gpu" selects the MPS backend on Apple silicon ("mps" can also be given).
        pl_trainer_kwargs={
            "accelerator": "gpu",
            "devices": [0],
        },
    )

    # num_loader_workers > 0 starts DataLoader worker processes, which is why
    # the __main__ guard below is required on macOS.
    my_model.fit(
        train_transformed,
        val_series=val_transformed,
        verbose=False,
        num_loader_workers=4,
    )


if __name__ == "__main__":
    main()
```

skeenan commented, Aug 27, 2022

I'm using the latest PyTorch nightly build. I've tried digging deeper, but no luck so far. If I ever resolve it, I'll post an update.

I wonder whether this is related: https://github.com/Lightning-AI/lightning/issues/4289
