
TrainingArguments does not support `mps` device (Mac M1 GPU)


System Info

  • transformers version: 4.21.0.dev0
  • Platform: macOS-12.4-arm64-arm-64bit
  • Python version: 3.8.9
  • Huggingface_hub version: 0.8.1
  • PyTorch version (GPU?): 1.12.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

@sgugger

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, …)
  • My own task or dataset (give details below)

Reproduction

export TASK_NAME=wnli
python run_glue.py \
  --model_name_or_path bert-base-cased \
  --task_name $TASK_NAME \
  --do_train \
  --do_eval \
  --max_seq_length 128 \
  --per_device_train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3 \
  --output_dir /tmp/$TASK_NAME/

Expected behavior

When running Trainer.train on a machine with an MPS GPU, training still runs on the CPU. I expected it to use the MPS GPU. MPS is supported by PyTorch as of the newest version, 1.12.0, and we can check whether the MPS GPU is available with torch.backends.mps.is_available().
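The availability check described above can be wrapped in a small diagnostic. This is a minimal sketch: `mps_status` and its dictionary keys are hypothetical names, while `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` are the PyTorch 1.12 calls it relies on. The `getattr` guard assumes older PyTorch releases simply lack the `torch.backends.mps` module, and the `ImportError` branch lets the script run where PyTorch is not installed at all.

```python
from typing import Dict


def mps_status() -> Dict[str, bool]:
    """Report whether PyTorch is installed and whether its MPS backend is usable.

    Hypothetical diagnostic helper; the key names are illustrative only.
    """
    status = {"torch_installed": False, "mps_built": False, "mps_available": False}
    try:
        import torch
    except ImportError:
        # No PyTorch at all, so no MPS either.
        return status
    status["torch_installed"] = True
    # torch.backends.mps only exists in PyTorch 1.12 and later.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None:
        # is_built(): this PyTorch binary was compiled with MPS support.
        status["mps_built"] = bool(mps_backend.is_built())
        # is_available(): MPS can actually be used on this machine.
        status["mps_available"] = bool(mps_backend.is_available())
    return status


if __name__ == "__main__":
    print(mps_status())
```

On a machine where `mps_available` is False, the Trainer falling back to the CPU is expected regardless of the `_setup_devices` question.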

It seems like the issue lies in the TrainingArguments._setup_devices method, which doesn’t appear to allow for the case where device = "mps".

Issue Analytics

  • State: closed
  • Created a year ago
  • Reactions: 1
  • Comments: 12 (10 by maintainers)

Top GitHub Comments

7 reactions
saattrupdan commented, Jun 30, 2022

A simple hack fixed the issue: override the device property of TrainingArguments in a subclass:

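import torch
from transformers import TrainingArguments


class TrainingArgumentsWithMPSSupport(TrainingArguments):
    """TrainingArguments whose `device` property also considers Apple's MPS backend."""

    @property
    def device(self) -> torch.device:
        # Note: this bypasses TrainingArguments._setup_devices, so flags such as
        # no_cuda and any distributed settings are not consulted here.
        # Prefer CUDA when present, keeping the usual behaviour on NVIDIA machines.
        if torch.cuda.is_available():
            return torch.device("cuda")
        # torch.backends.mps exists only in PyTorch >= 1.12; getattr avoids an
        # AttributeError on older versions.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return torch.device("mps")
        # Otherwise fall back to the CPU.
        return torch.device("cpu")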

This at least suggests that the aforementioned _setup_devices method may be the only thing that needs changing.
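Why the hack works can be shown without PyTorch: a property defined on a subclass shadows the base-class property of the same name, so any code that reads `args.device` (as the Trainer does when placing the model) sees the override. In this sketch, `BaseArguments`, `MPSAwareArguments` and `describe_placement` are hypothetical stand-ins for `TrainingArguments`, the subclass above, and the Trainer's use of the device; plain strings replace `torch.device` objects.

```python
class BaseArguments:
    """Stand-in for TrainingArguments: its device logic only knows CUDA and CPU."""

    def __init__(self, cuda_available: bool, mps_available: bool) -> None:
        self.cuda_available = cuda_available
        self.mps_available = mps_available

    @property
    def device(self) -> str:
        return "cuda" if self.cuda_available else "cpu"


class MPSAwareArguments(BaseArguments):
    """Stand-in for the subclass in the comment: overrides only `device`."""

    @property
    def device(self) -> str:
        if self.cuda_available:
            return "cuda"
        if self.mps_available:
            return "mps"
        return "cpu"


def describe_placement(args: BaseArguments) -> str:
    # Consumer code only reads `args.device`, so it picks up the subclass override.
    return f"model moved to {args.device}"


print(describe_placement(BaseArguments(cuda_available=False, mps_available=True)))      # model moved to cpu
print(describe_placement(MPSAwareArguments(cuda_available=False, mps_available=True)))  # model moved to mps
```

Because the consumer never calls the base-class device logic directly, swapping in the subclass is enough; nothing else in the calling code has to change.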

2 reactions
sgugger commented, Jun 30, 2022

This is not supported yet: MPS support was introduced in PyTorch 1.12, which also breaks all speech models due to a regression. We will look into supporting Mac M1 GPUs once we officially support PyTorch 1.12 (probably not before they release a 1.12.1 patch).
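Since MPS availability depends on the PyTorch release, a script that wants to opt in conditionally might gate on the version first. A minimal sketch, assuming version strings shaped like `torch.__version__` (for example `1.12.0`, `1.13.0.dev20220601` or `2.0.1+cu117`); the helper names are hypothetical, and the check covers only the 1.12 threshold mentioned above, not the speech-model regression or the 1.12.1 patch. Where the `packaging` library is available, `packaging.version.Version` is a more robust parser.

```python
import re
from typing import Tuple


def parse_major_minor(version: str) -> Tuple[int, int]:
    """Extract (major, minor) as integers from a PyTorch-style version string."""
    match = re.match(r"^(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def torch_version_has_mps(version: str) -> bool:
    """True if this PyTorch version is at least 1.12, the release that introduced MPS."""
    return parse_major_minor(version) >= (1, 12)
```

Comparing integer tuples rather than raw strings matters here: as strings, "1.9.1" would sort after "1.12.0".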


Top Results From Across the Web

Mac M1 gpu not detected - Beginners
I have a gpu on my pc, but it does not show up when I use, import torch print([torch.cuda.device(i) for i in ...
Fastai on Apple M1 - Deep Learning - fast.ai Course Forums
my tests shows gpu acceleration being turned off on mac, can anybody confirm? edit: They reverted initial mps support: revert auto-enable of mac...
Training doesn't converge when running on M1 pro GPU ...
Hi, I'm trying to train a network model on Macbook M1 pro GPU by using the MPS device, but for some reason the...
Enable Training on Apple Silicon Processors in PyTorch
How to Enable GPU-Accelerated Training on Apple Silicon in PyTorch ... You can use the MPS device in PyTorch like so: ...
Release 0.7.2 Carsen Stringer & Marius Pachitariu
From the command line you can choose the Mac device with python -m cellpose --dir path --gpu_device mps --use_gpu.