Ray Train / Ray Tune fails with Matplotlib
See original GitHub issue.
This simple script finishes training, but then the actor coordinating the training dies at the end, causing Ray’s fault-tolerance engine to re-run the training, creating an endless loop.
It may be caused by the inclusion of matplotlib, but that doesn’t matter – data scientists will often have such code, and Ray Train shouldn’t cause it to explode.
System info
% which python
/Users/will/opt/anaconda3/envs/torch_to_train/bin/python
% which ipython
/Users/will/opt/anaconda3/envs/torch_to_train/bin/ipython
% python --version
Python 3.8.12
% ipython
In [1]: import ray
In [2]: ray.__version__
Out[2]: '1.9.2'
Code:
# -*- coding: utf-8 -*-
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ray import train
import ray.train.torch
from ray.train import Trainer
from ray.train.torch import TorchConfig
########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].
########################################################################
# Let us show some of the training images, for fun.
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
def my_train_func(config):
    # set up data and transforms
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    batch_size = 4

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # set up model
    net = Net()
    net = train.torch.prepare_model(net)

    # set up loss/optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    # get some random training images
    dataiter = iter(trainloader)
    images, labels = next(dataiter)

    # # show images
    # imshow(torchvision.utils.make_grid(images))
    # # print labels
    # print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

    for epoch in range(2):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0

    print('Finished Training')
    ########################################################################
    # Let's quickly save our trained model:
    PATH = './cifar_net.pth'
    torch.save(net.state_dict(), PATH)

    ########################################################################
    # See `here <https://pytorch.org/docs/stable/notes/serialization.html>`_
    # for more details on saving PyTorch models.
    #
    # 5. Test the network on the test data
    # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    #
    # We have trained the network for 2 passes over the training dataset.
    # But we need to check if the network has learnt anything at all.
    #
    # We will check this by predicting the class label that the neural network
    # outputs, and checking it against the ground-truth. If the prediction is
    # correct, we add the sample to the list of correct predictions.
    #
    # Okay, first step. Let us display an image from the test set to get familiar.
    dataiter = iter(testloader)
    images, labels = next(dataiter)

    # print images
    imshow(torchvision.utils.make_grid(images))
    print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

    ########################################################################
    # Next, let's load back in our saved model (note: saving and re-loading the model
    # wasn't necessary here, we only did it to illustrate how to do so):
    net = Net()
    net.load_state_dict(torch.load(PATH))

    ########################################################################
    # Okay, now let us see what the neural network thinks these examples above are:
    outputs = net(images)

    ########################################################################
    # The outputs are energies for the 10 classes.
    # The higher the energy for a class, the more the network
    # thinks that the image is of the particular class.
    # So, let's get the index of the highest energy:
    _, predicted = torch.max(outputs, 1)
    print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                                  for j in range(4)))

    ########################################################################
    # The results seem pretty good.
    #
    # Let us look at how the network performs on the whole dataset.
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            # calculate outputs by running images through the network
            outputs = net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

    ########################################################################
    # That looks way better than chance, which is 10% accuracy (randomly picking
    # a class out of 10 classes).
    # Seems like the network learnt something.
    #
    # Hmmm, what are the classes that performed well, and the classes that did
    # not perform well:

    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}

    # again no gradients needed
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            outputs = net(images)
            _, predictions = torch.max(outputs, 1)
            # collect the correct predictions for each class
            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1

    # print accuracy for each class
    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
if __name__ == '__main__':
    config = {}
    trainer = Trainer(backend="torch", num_workers=4)
    trainer.start()  # set up resources
    trainer.run(my_train_func, config=config)
    trainer.shutdown()
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
I took a quick look, here’s a simple repro:
The root cause indeed seems to be because Ray Train executes the training function in a separate thread. This is because the main thread for each process is currently used to allow the Trainer (more specifically the WorkerGroup) to call ray.get to get the workers’ reported/checkpointed data.
Fixing this likely requires us to rearchitect the backend pretty significantly. While I do agree that this isn’t a great user experience, I’m not sure if we want to allocate 20+ hours to fix it at this moment.
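For illustration, here is a minimal standalone sketch of the mechanism described in that comment (this sketch is not code from the thread): the training function runs on a worker thread while the main thread stays free for coordination, and GUI matplotlib backends generally expect to be driven from the main thread, which is why a plt.show() issued inside the training function can fail.
# Illustration only (not from the thread): a stand-in for how Ray Train's legacy
# Trainer keeps each worker's main thread free for coordination (e.g. ray.get on
# reported data) and runs the user's training function on a separate thread.
# GUI matplotlib backends generally expect to run on the main thread, so a
# plt.show() from the worker thread is the risky call.
import threading

def user_train_func():
    print("training function runs on:", threading.current_thread().name)
    # a plt.show() issued here would execute off the main thread

worker_thread = threading.Thread(target=user_train_func, name="train-fn-thread")
worker_thread.start()
# ... in Ray Train, the real main thread would be busy servicing ray.get(...) here ...
worker_thread.join()
print("main thread is:", threading.main_thread().name)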
As a short term solution, would it be reasonable to add an FAQ to the documentation and explain how to work around this - specifically to just run the plot outside of the Trainer? The user experience isn’t terrible as everything can still run in a single Python script (and by doing this you don’t end up creating num_worker plots).
@worldveil I created a PR to add this to the docs. Please take a look!
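For reference, a minimal sketch of the suggested workaround (not taken from the PR; it assumes the Ray 1.9 Trainer API, where trainer.run() returns each worker's return value): keep all plotting in the driver script after the Trainer has finished, so matplotlib runs on the driver's main thread.
# Sketch of the workaround described above (assumes Ray 1.9's legacy Trainer API,
# where trainer.run() returns a list with one entry per worker).
import matplotlib.pyplot as plt
from ray.train import Trainer

def my_train_func(config):
    # ... train exactly as in the repro above, but do NOT call imshow()/plt.show() here ...
    # return whatever the driver needs for plotting; the values below are placeholders
    return {"losses": [2.3, 1.1, 0.7]}

if __name__ == "__main__":
    trainer = Trainer(backend="torch", num_workers=4)
    trainer.start()
    results = trainer.run(my_train_func, config={})  # one result per worker
    trainer.shutdown()

    # plotting happens here, on the driver's main thread, and only once
    plt.plot(results[0]["losses"])
    plt.show()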