
PyTorch Dataloading doesn't work with >0 workers

See original GitHub issue

Hi!

I’m new to the JAX ecosystem, have used PyTorch and TensorFlow extensively for over 5 years.

My issue is that I can’t get PyTorch data loading to work with jax/flax when num_workers>0. Below is a minimal example that reproduces the issue:

import argparse
from typing import Sequence
from functools import partial
import flax
from typing import Any
import optax
from flax.training import train_state
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import tqdm
from torchvision.datasets import CIFAR10
from flax.training import common_utils
import torch
import torchvision.transforms as transforms
import torch.multiprocessing as multiprocessing
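# Module-level start-method call recommended by an earlier issue (see the problem description below).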
multiprocessing.set_start_method('spawn')


NUM_CLASSES = 10
NUM_EPOCHS = 50
BATCH_SIZE = 512


parser = argparse.ArgumentParser()
parser.add_argument("--num_workers", default=0, type=int)


def collate_fn(batch):
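    # Convert each CHW torch tensor to an HWC numpy array; the Flax convolutions below expect NHWC batches.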
    inputs_np = []
    targets_np = []
    for item in batch:
        inp_np = item[0].permute(1, 2, 0).detach().numpy()
        tgts_np = item[1]
        inputs_np.append(inp_np)
        targets_np.append(tgts_np)
    inputs_np = np.asarray(inputs_np)
    targets_np = np.asarray(targets_np)
    return inputs_np, targets_np


class CNN(nn.Module):
    @nn.compact
    def __call__(self, inputs, train=False):
        conv = partial(nn.Conv, kernel_size=(3, 3), strides=(2, 2), 
                       use_bias=False, kernel_init=jax.nn.initializers.kaiming_normal())
        bn = partial(nn.BatchNorm, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5)
        x = conv(features=32)(inputs)
        x = bn()(x)
        x = nn.relu(x)
        x = conv(features=64)(x)
        x = bn()(x)
        x = nn.relu(x)
        x = conv(features=128)(x)
        x = bn()(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(4, 4), strides=(1, 1))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(NUM_CLASSES)(x)
        return x


def initialize(key, inp_shape, model):
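  # Initialize the model variables from a dummy input and split them into params and batch_stats.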
  input_shape = (1,) + inp_shape
  @jax.jit
  def init(*args):
    return model.init(*args)
  variables = init({'params': key}, jnp.ones(input_shape))
  return variables['params'], variables['batch_stats']


@jax.jit
def cross_entropy_loss(logits, labels):
    one_hot_labels = common_utils.onehot(labels, num_classes=NUM_CLASSES)
    xentropy = optax.softmax_cross_entropy(logits=logits, labels=one_hot_labels)
    return jnp.mean(xentropy)

@jax.jit
def calculate_accuracy(logits, labels):
    accuracy = jnp.mean(jnp.argmax(logits, -1) == labels)
    return accuracy


@jax.jit
def train_step(state, images, labels):
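    # One training step: forward pass with mutable batch_stats, cross-entropy loss plus an L2 weight penalty, then an optimizer update.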
    step = state.step
    @jax.jit
    def cost_fn(params):
        logits, new_model_state = state.apply_fn(
            {"params": params, "batch_stats": state.batch_stats},
            images,
            mutable=['batch_stats'],
            train=True
        )
        loss = cross_entropy_loss(logits, labels)
        weight_penalty_params = jax.tree_leaves(params)
        weight_l2 = sum([jnp.sum(x ** 2)
                        for x in weight_penalty_params
                        if x.ndim > 1])
        weight_decay=0.0001
        weight_penalty = weight_decay * 0.5 * weight_l2
        loss = loss + weight_penalty
        return loss, (new_model_state, logits)
    grad_fn = jax.value_and_grad(cost_fn, has_aux=True)
    aux, grads = grad_fn(state.params)
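    # With has_aux=True, grad_fn returns ((loss, (new_model_state, logits)), grads), so aux[0] is the loss and aux[1] the auxiliary outputs.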
    new_model_state, logits = aux[1]
    acc = calculate_accuracy(logits, labels)
    new_state = state.apply_gradients(grads=grads, batch_stats=new_model_state['batch_stats'])
    return new_state, aux[0], acc

@jax.jit
def eval_step(state, images, labels):
    logits = state.apply_fn(
        {"params": state.params, 
        "batch_stats": state.batch_stats}, 
        images, train=False, mutable=False)
    return calculate_accuracy(logits, labels)


class TrainState(train_state.TrainState):
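    # Extend Flax's TrainState with the BatchNorm running statistics so they are carried through training.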
    batch_stats: Any


if __name__ == "__main__":
  args = parser.parse_args()
  cnn = CNN()
  key = jax.random.PRNGKey(0)
  key, *subkeys = jax.random.split(key, 4)
  params, batch_stats = initialize(subkeys[0], (32, 32, 3), cnn)
  tx = optax.adam(
    1e-3
  )
  state = TrainState.create(
      apply_fn=cnn.apply,
      params=params,
      tx=tx,
      batch_stats=batch_stats
  )
  transform = transforms.Compose([
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
  ])

  batch_size = BATCH_SIZE
  trainset = CIFAR10(root='./data', train=True,
                     download=True, transform=transform)
  trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, drop_last=True,
                                            shuffle=True, num_workers=args.num_workers,
                                            collate_fn=collate_fn)
  num_tr_steps = len(trainloader)
  testset = CIFAR10(root='./data', train=False,
                    download=True, transform=transform)
  testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, drop_last=True,
                                           shuffle=False, num_workers=args.num_workers,
                                           collate_fn=collate_fn)
  num_test_steps = len(testloader)
  
  for epoch in range(1, NUM_EPOCHS+1):
    print("Starting epoch {}".format(epoch))
    train_loss = []
    train_acc = []
    itercnt = 0
    for batch in trainloader:
      images, labels = batch
      state, loss, acc = train_step(state, images, labels)
      if itercnt == 0:
        print("Input shape:", images.shape)
        print("labels shape:", labels.shape)
      if itercnt % 25 == 0:
        print("[{:03d}] | Step: [{:04d}/{:04d}] | Loss: {:.04f} | Acc: {:.04f}".format(
          epoch, itercnt, num_tr_steps, loss, acc
        ))
      train_loss.append(jax.device_get(loss))
      train_acc.append(jax.device_get(acc))
      itercnt += 1
    print("Validating...")
    val_accs = []
    for batch in testloader:
      images, labels = batch
      acc = eval_step(state, images, labels)
      val_accs.append(jax.device_get(acc))

    print("Epoch {:03d} done...".format(epoch))
    print("\t Train loss: {:.04f} | Train Acc: {:.04f}".format(
      np.mean(train_loss), np.mean(train_acc)))
    print("\t Val Acc: {:.04f}".format(np.mean(val_accs)))

Problem encountered:

I’ve tried running the script on both TPU and GPU: it works fine when num_workers = 0, but doesn’t work with num_workers > 0.

An earlier issue from 2020 recommended setting torch.multiprocessing.set_start_method('spawn'), but that didn’t fix the issue for me. Unlike the author of that issue, I’m not using JAX primitives in the data loading pipeline at all, as can be seen in the collate_fn() function above.

With num_workers>0, I get the following errors:

On GPU

  • With torch.multiprocessing.set_start_method('spawn'), the script throws RuntimeError: context has already been set
  • With torch.multiprocessing.set_start_method('fork'), it throws Failed setting context: CUDA_ERROR_NOT_INITIALIZED: initialization error

On TPUv2-8 VM

  • With torch.multiprocessing.set_start_method('spawn'), it throws libtpu.so already in use by another process, followed later in the stack trace by RuntimeError: context has already been set.
  • With torch.multiprocessing.set_start_method('fork'), there is no error, but the dataloader hangs indefinitely.

Following are the packages being used:

torch==1.9.0+cu111
jax==0.2.26
jaxlib==0.1.75       #+cuda11.cudnn82 for GPU

Any help is appreciated!

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 11

Top GitHub Comments

4 reactions
nikitakit commented, Feb 10, 2022

This issue appears to be a regression compared to one year ago. I was using multi-worker data loaders in sabertooth and they worked fine at the time, but they no longer work with newly started TPU VMs. I want to emphasize that the data workers are neither using JAX nor accessing the TPUs in any way; they are doing pure numpy computation.

torch.multiprocessing.set_start_method('spawn') sort of works as a workaround. I’ve managed to avoid the error RuntimeError: context has already been set with the idiom if __name__ == '__main__': torch.multiprocessing.set_start_method('spawn') – I had to wrap it so that each spawned worker doesn’t itself attempt to set the start method. However, this workaround still has issues: each worker takes a really long time to spawn and generates a bunch of libtpu.so already in use by another process messages. Setting persistent_workers=True helps cut down on these, but it’s still annoying.
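For reference, here is a minimal sketch of that guard applied to the script from the question (trainset, collate_fn, BATCH_SIZE, and args are names defined in the script above; persistent_workers is the only new DataLoader argument):

import torch.multiprocessing as multiprocessing
from torch.utils.data import DataLoader

if __name__ == "__main__":
    # Only the parent process sets the start method; spawned workers re-import
    # this module but skip the guarded block, which avoids
    # "RuntimeError: context has already been set".
    multiprocessing.set_start_method('spawn')

    # trainset, collate_fn, BATCH_SIZE and args are defined in the script above.
    trainloader = DataLoader(
        trainset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True,
        num_workers=args.num_workers, collate_fn=collate_fn,
        # Keep workers alive across epochs instead of respawning them every epoch;
        # this cuts down on the slow spawns and libtpu.so warnings mentioned above.
        persistent_workers=True)

As described above, this only removes the start-method error; the slow worker startup and the libtpu.so messages remain a separate problem.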

Given that this is a regression, is it really the case that it can’t be fixed? None of the child processes are actually doing anything with the TPU.

3 reactions
levskaya commented, Jan 15, 2022

Hi! I think it’s going to be really hard to make this work. We generally don’t try to support Python multiprocessing: the internal C++ libraries we use aren’t written to be fork-safe, and I’m not sure that libtpu.so can be used with multiprocessing at all.

Usually we recommend that people use TFDS / tf.data-based data loaders, as they’re far more CPU-efficient for feeding multiple GPUs or TPUs than torch dataloaders with multiprocessing.
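For illustration, a rough sketch of what a TFDS / tf.data input pipeline for the CIFAR-10 setup above could look like (the function name, shuffle buffer, and prefetch settings are illustrative assumptions, not an official recipe; the normalization mirrors the torchvision transform in the question):

import tensorflow as tf
import tensorflow_datasets as tfds


def make_cifar10_iterator(batch_size, train=True):
    ds = tfds.load('cifar10', split='train' if train else 'test', as_supervised=True)

    def preprocess(image, label):
        # Match the torchvision transform above: scale to [0, 1], then normalize
        # with mean 0.5 and std 0.5 per channel. Images stay in HWC layout,
        # which is what the Flax model expects.
        image = tf.cast(image, tf.float32) / 255.0
        image = (image - 0.5) / 0.5
        return image, label

    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    if train:
        ds = ds.shuffle(10_000)
    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    # tfds.as_numpy yields plain numpy batches, which the jitted train_step /
    # eval_step above accept just like the output of collate_fn.
    return tfds.as_numpy(ds)

Since tf.data does its parallelism with internal threads rather than Python multiprocessing, no start-method configuration is involved.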

Read more comments on GitHub >

Top Results From Across the Web

  • Torch dataloader num_workers>0 not spawning workers
    "I'm currently working on porting code from Keras to PyTorch. I'm working with many GPUs and CPUs so it's important to have batch..."
  • DataLoader, when num_worker >0, there is bug
    "Since PyTorch seems to adopt lazy way of initializing workers, this means that the actual file opening has to happen inside of the..."
  • Pytorch DataLoader freezes when num_workers > 0 - vision
    "i am facing exactly this same issue: DataLoader freezes randomly when num_workers > 0 (Multiple threads train models on different GPUs in..."
  • Network doesn't learn with num_workers=0, works fine ...
    "... no real need for multiple workers since there is no intensive loading/data augmentation for now. After that change, training didn't work ..."
  • Num_workers>0 Windows 10 ERROR - PyTorch Forums
    "... I try to run on windows I can only run if num-workers = 0 but that makes the ... type=int, help='Number of..."
