
Ray Train / Ray Tune fails with Matplotlib

This simple script finishes training, but then the actor coordinating the training dies at the end. Ray's fault-tolerance mechanism then re-runs it, creating an endless loop.

The failure may be caused by the inclusion of matplotlib, but that shouldn’t matter: data scientists often have such code, and Ray Train shouldn’t make it explode.

System info

% which python
/Users/will/opt/anaconda3/envs/torch_to_train/bin/python
% which ipython
/Users/will/opt/anaconda3/envs/torch_to_train/bin/ipython
% python --version
Python 3.8.12
% ipython
In [1]: import ray
In [2]: ray.__version__
Out[2]: '1.9.2'

Code:

# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
from ray import train
import ray.train.torch
from ray.train import Trainer
from ray.train.torch import TorchConfig
 
 
########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].
 
 
########################################################################
# Let us show some of the training images, for fun.
 
 
def imshow(img):
   img = img / 2 + 0.5     # unnormalize
   npimg = img.numpy()
   plt.imshow(np.transpose(npimg, (1, 2, 0)))
   plt.show()
 
 
 
class Net(nn.Module):
   def __init__(self):
       super().__init__()
       self.conv1 = nn.Conv2d(3, 6, 5)
       self.pool = nn.MaxPool2d(2, 2)
       self.conv2 = nn.Conv2d(6, 16, 5)
       self.fc1 = nn.Linear(16 * 5 * 5, 120)
       self.fc2 = nn.Linear(120, 84)
       self.fc3 = nn.Linear(84, 10)
 
   def forward(self, x):
       x = self.pool(F.relu(self.conv1(x)))
       x = self.pool(F.relu(self.conv2(x)))
       x = torch.flatten(x, 1) # flatten all dimensions except batch
       x = F.relu(self.fc1(x))
       x = F.relu(self.fc2(x))
       x = self.fc3(x)
       return x
 
 
def my_train_func(config):
   # set up data and transforms
   transform = transforms.Compose(
       [transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
   )
 
   batch_size = 4
 
   trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                           download=True, transform=transform)
   trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           shuffle=True, num_workers=2)
 
   testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
   testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                           shuffle=False, num_workers=2)
 
   classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
  
   # setup model
   net = Net()
   net = train.torch.prepare_model(net)
 
   # setup loss/optimizer
   criterion = nn.CrossEntropyLoss()
   optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
 
   # get some random training images
   dataiter = iter(trainloader)
   images, labels = next(dataiter)
   # # show images
   # imshow(torchvision.utils.make_grid(images))
   # # print labels
   # print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
 
   for epoch in range(2):  # loop over the dataset multiple times
 
       running_loss = 0.0
       for i, data in enumerate(trainloader, 0):
           # get the inputs; data is a list of [inputs, labels]
           inputs, labels = data
 
           # zero the parameter gradients
           optimizer.zero_grad()
 
           # forward + backward + optimize
           outputs = net(inputs)
           loss = criterion(outputs, labels)
           loss.backward()
           optimizer.step()
 
           # print statistics
           running_loss += loss.item()
           if i % 2000 == 1999:    # print every 2000 mini-batches
               print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
               running_loss = 0.0
 
   print('Finished Training')
 
   ########################################################################
   # Let's quickly save our trained model:
 
   PATH = './cifar_net.pth'
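   # prepare_model may wrap the model in DistributedDataParallel (e.g. with
   # num_workers > 1), which prefixes every state_dict key with "module.".
   # Save the unwrapped model so it loads back into a plain Net() below.
   torch.save(getattr(net, "module", net).state_dict(), PATH)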
 
   ########################################################################
   # See `here <https://pytorch.org/docs/stable/notes/serialization.html>`_
   # for more details on saving PyTorch models.
   #
   # 5. Test the network on the test data
   # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
   #
   # We have trained the network for 2 passes over the training dataset.
   # But we need to check if the network has learnt anything at all.
   #
   # We will check this by predicting the class label that the neural network
   # outputs, and checking it against the ground-truth. If the prediction is
   # correct, we add the sample to the list of correct predictions.
   #
   # Okay, first step. Let us display an image from the test set to get familiar.
 
   dataiter = iter(testloader)
   images, labels = next(dataiter)
 
   # print images
   imshow(torchvision.utils.make_grid(images))
   print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
 
   ########################################################################
   # Next, let's load back in our saved model (note: saving and re-loading the model
   # wasn't necessary here, we only did it to illustrate how to do so):
 
   net = Net()
   net.load_state_dict(torch.load(PATH))
 
   ########################################################################
   # Okay, now let us see what the neural network thinks these examples above are:
 
   outputs = net(images)
 
   ########################################################################
   # The outputs are energies for the 10 classes.
   # The higher the energy for a class, the more the network
   # thinks that the image is of the particular class.
   # So, let's get the index of the highest energy:
   _, predicted = torch.max(outputs, 1)
 
   print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                               for j in range(4)))
 
   ########################################################################
   # The results seem pretty good.
   #
   # Let us look at how the network performs on the whole dataset.
 
   correct = 0
   total = 0
   # since we're not training, we don't need to calculate the gradients for our outputs
   with torch.no_grad():
       for data in testloader:
           images, labels = data
           # calculate outputs by running images through the network
           outputs = net(images)
           # the class with the highest energy is what we choose as prediction
           _, predicted = torch.max(outputs.data, 1)
           total += labels.size(0)
           correct += (predicted == labels).sum().item()
 
   print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
 
   ########################################################################
   # That looks way better than chance, which is 10% accuracy (randomly picking
   # a class out of 10 classes).
   # Seems like the network learnt something.
   #
   # Hmmm, what are the classes that performed well, and the classes that did
   # not perform well:
 
   # prepare to count predictions for each class
   correct_pred = {classname: 0 for classname in classes}
   total_pred = {classname: 0 for classname in classes}
 
   # again no gradients needed
   with torch.no_grad():
       for data in testloader:
           images, labels = data
           outputs = net(images)
           _, predictions = torch.max(outputs, 1)
           # collect the correct predictions for each class
           for label, prediction in zip(labels, predictions):
               if label == prediction:
                   correct_pred[classes[label]] += 1
               total_pred[classes[label]] += 1
 
 
   # print accuracy for each class
   for classname, correct_count in correct_pred.items():
       accuracy = 100 * float(correct_count) / total_pred[classname]
       print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
 
 
 
if __name__ == '__main__':
   config = {}
   trainer = Trainer(backend="torch", num_workers=4)
  
   trainer.start() # set up resources
   trainer.run(my_train_func, config=config)
   trainer.shutdown()

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 7 (7 by maintainers)

Top GitHub Comments

matthewdeng commented, Feb 11, 2022

I took a quick look; here’s a simple repro:

import numpy as np
import matplotlib.pyplot as plt
from ray.train import Trainer

def plot_func():
    arr = np.array([[1]])
    plt.imshow(arr)
    plt.show()

trainer = Trainer(backend="torch", num_workers=1)
trainer.start()
trainer.run(plot_func)

The root cause does indeed seem to be that Ray Train executes the training function in a separate thread. This is because the main thread of each process is currently used to allow the Trainer (more specifically, the WorkerGroup) to call ray.get to retrieve the workers’ reported/checkpointed data.
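
The thread layout described above can be modeled with the standard library alone. This is a toy sketch, not Ray's actual code: `run_like_ray_train` is a hypothetical helper, and a `queue.Queue` stands in for the `ray.get` calls that occupy each worker's main thread. It shows that a function launched this way does not run on the process's main thread, which is where interactive GUI toolkits (for example matplotlib's macOS backend) generally expect to be driven.

```python
import queue
import threading


def run_like_ray_train(train_func):
    """Toy model, NOT Ray's real implementation.

    Runs train_func in a background thread while the calling (main) thread
    blocks on a queue, standing in for the ray.get calls that, per the
    comment above, occupy each worker process's main thread.
    """
    results = queue.Queue()

    def target():
        # Forward either the return value or the exception to the main thread,
        # so a failing train_func cannot leave results.get() blocked forever.
        try:
            results.put(("ok", train_func()))
        except BaseException as exc:
            results.put(("error", exc))

    worker = threading.Thread(target=target, name="train-thread")
    worker.start()
    status, value = results.get()  # main thread waits here, like ray.get
    worker.join()
    if status == "error":
        raise value
    return value


def train_func():
    # Report whether the "training function" is running on the main thread.
    return threading.current_thread() is threading.main_thread()


print(run_like_ray_train(train_func))  # False
```

Any GUI call placed inside `train_func` in this model would therefore execute off the main thread, which matches the failure seen with `plt.show()` inside the Trainer.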

Fixing this likely requires us to rearchitect the backend pretty significantly. While I do agree that this isn’t a great user experience, I’m not sure if we want to allocate 20+ hours to fix it at this moment.

As a short-term solution, would it be reasonable to add an FAQ entry to the documentation explaining how to work around this, specifically by running the plot outside of the Trainer? The user experience isn’t terrible, since everything can still run in a single Python script (and by doing this you don’t end up creating num_workers plots).
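
The suggested workaround (plot outside the Trainer) can be sketched with the same stdlib-only model. `render_plot` and `run_workers` are hypothetical stand-ins for matplotlib's GUI calls and for `Trainer.run`; nothing here is Ray's API. The training function only returns the data to plot, and the driver's main thread plots it once.

```python
import threading


def render_plot(data):
    """Hypothetical stand-in for plt.imshow(data); plt.show().

    Mimics an interactive GUI backend by refusing to run off the main thread.
    """
    if threading.current_thread() is not threading.main_thread():
        raise RuntimeError("GUI calls must run on the main thread")
    return f"plotted {data}"


def train_func(rank):
    # Workaround: compute and return the data to plot instead of plotting here.
    return [[rank]]


def run_workers(func, num_workers):
    """Toy stand-in for Trainer.run: call func(rank) once per worker, each in
    a background thread, and return the results in rank order."""
    results = [None] * num_workers

    def target(rank):
        results[rank] = func(rank)

    threads = [threading.Thread(target=target, args=(rank,)) for rank in range(num_workers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


results = run_workers(train_func, num_workers=4)
# Plot once, from the driver's main thread, using only the first worker's output.
print(render_plot(results[0]))  # plotted [[0]]
```

With the real Ray 1.x API, the equivalent would be to return the array from the training function, call `results = trainer.run(train_func)` in the driver script, and run `plt.imshow(results[0])` and `plt.show()` after `trainer.shutdown()`. As far as we know, `Trainer.run` returns one result per worker, so plotting only `results[0]` also avoids the duplicate per-worker plots mentioned above.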

matthewdeng commented, Mar 2, 2022

@worldveil I created a PR to add this to the docs. Please take a look!
