Supporting PyTorch GPU compatibility on Apple Silicon chips

🚀 Feature

PyTorch recently released support for GPU acceleration on Apple Silicon chips. stable-baselines3 should support this through the "mps" device (I believe).

Minimal Example

from stable_baselines3 import PPO
import gym

env = gym.make("QbertNoFrameskip-v0")
ppo = PPO("CnnPolicy", env, device = "mps")

ppo.learn(total_timesteps = 1000000)

The Apple Silicon GPU is not automatically recognized by stable-baselines3 at the moment, so it defaults to "cpu". If you force it to use the "mps" device, the following stack trace appears.

A.L.E: Arcade Learning Environment (version 0.7.4+069f8bd)
[Powered by Stella]
Traceback (most recent call last):
  File "/Users/ryanrudes/Desktop/pydt/train_min.py", line 7, in <module>
    ppo.learn(total_timesteps = 1000000)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/ppo/ppo.py", line 310, in learn
    return super().learn(
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/on_policy_algorithm.py", line 247, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts
    actions, values, log_probs = self.policy(obs_tensor)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/policies.py", line 592, in forward
    distribution = self._get_action_dist_from_latent(latent_pi)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/policies.py", line 610, in _get_action_dist_from_latent
    return self.action_dist.proba_distribution(action_logits=mean_actions)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/stable_baselines3-1.5.1a6-py3.9.egg/stable_baselines3/common/distributions.py", line 274, in proba_distribution
    self.distribution = Categorical(logits=action_logits)
  File "/Users/ryanrudes/miniforge3/lib/python3.9/site-packages/torch/distributions/categorical.py", line 60, in __init__
    self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
NotImplementedError: Could not run 'aten::amax.out' with arguments from the 'MPS' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::amax.out' is only available for these backends: [Dense, Conjugate, UNKNOWN_TENSOR_TYPE_ID, QuantizedXPU, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseCPU, SparseCUDA, SparseHIP, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, SparseXPU, UNKNOWN_TENSOR_TYPE_ID, SparseVE, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, NestedTensorCUDA, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID, UNKNOWN_TENSOR_TYPE_ID].

CPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterCPU.cpp:37386 [kernel]
Meta: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterMeta.cpp:31637 [kernel]
BackendSelect: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:133 [backend fallback]
Named: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/ADInplaceOrViewType_1.cpp:3288 [kernel]
AutogradOther: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradCPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradCUDA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradXLA: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradMPS: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradIPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradXPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradHPU: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradLazy: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse1: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse2: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
AutogradPrivateUse3: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/VariableType_1.cpp:11242 [autograd kernel]
Tracer: registered at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/generated/TraceType_1.cpp:11951 [kernel]
AutocastCPU: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:481 [backend fallback]
Autocast: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/autocast_mode.cpp:324 [backend fallback]
Batched: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/BatchingRegistrations.cpp:1064 [backend fallback]
VmapMode: fallthrough registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
Functionalize: registered at /Users/runner/work/pytorch/pytorch/pytorch/build/aten/src/ATen/RegisterFunctionalization_3.cpp:12118 [kernel]
PythonTLSSnapshot: registered at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:137 [backend fallback
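
The trace above fails inside `logits.logsumexp(...)` because `aten::amax.out` has no MPS kernel in that PyTorch build. A minimal sketch of one possible workaround follows, assuming it is acceptable to run the unsupported operator on the CPU and move the result back. The helper name `run_with_cpu_fallback` is hypothetical and is not part of stable-baselines3 or PyTorch; it only relies on the tensor exposing `.cpu()`, `.to()` and `.device`, as `torch.Tensor` does.

```python
def run_with_cpu_fallback(op, tensor):
    """Apply `op` to `tensor` on its own device; if the backend raises
    NotImplementedError (as in the trace above), retry on the CPU.

    Hypothetical helper for illustration only.
    """
    try:
        return op(tensor)
    except NotImplementedError:
        # Compute on a CPU copy, then move the result to the original device.
        # This costs two device transfers per call.
        return op(tensor.cpu()).to(tensor.device)
```

With PyTorch installed, a call might look like `run_with_cpu_fallback(lambda t: t.logsumexp(dim=-1, keepdim=True), logits)`. Newer PyTorch releases also document a `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable with a similar purpose; the issue does not say whether it applies to the build shown in this trace.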

Issue Analytics

  • State: open
  • Created a year ago
  • Reactions: 2
  • Comments: 8 (2 by maintainers)

Top GitHub Comments

2 reactions
traderpedroso commented, Aug 19, 2022

For the moment, consider that SB3 is not compatible with MPS. But we are working on it: #951

Have you seen something in the documentation about float64 and MPS? (You can answer in the PR conversation.)

Yes, for the MPS framework.

For a more extensive list of which data types do and don’t run:

  • Avoid Float64 on all Apple devices. Even if the hardware supports Double physically (AMD or Intel), the Metal API doesn’t let you access it.
  • Avoid BFloat16. That is natively supported by the latest Nvidia GPUs, but not supported in Metal. Also don’t try to use TF18/TF32 or Int4.
  • All standard integer types (UInt8, UInt16, UInt32, UInt64) and their signed counterparts work natively on Apple devices. Not exactly 8-bit integers, which are cast to 16-bit integers before being stored into registers, but those aren’t going to harm performance.
  • Yes, they run 64-bit integers on the Apple GPU and not 64-bit floats. Metal allows you to use 64-bit integers in shaders on AMD and Intel, but the arithmetic there might just happen through emulation (slow). I think that’s where I experienced the crash in MPSGraph previously - trying to run an operation on UInt64 on my Intel Mac mini.

https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
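
The data-type advice above can be sketched as a small mapping applied before a tensor is moved to the "mps" device. This is only an illustration: the names `MPS_DTYPE_REPLACEMENTS`, `mps_safe_dtype_name` and `to_mps_safe` are hypothetical, and replacing BFloat16 with float32 (rather than float16) is an assumption, not something the comment prescribes.

```python
# Data types the comment above says to avoid on Metal, mapped to a stand-in.
# The choice of float32 as the replacement is an assumption for this sketch.
MPS_DTYPE_REPLACEMENTS = {
    "float64": "float32",
    "bfloat16": "float32",
}


def mps_safe_dtype_name(name: str) -> str:
    """Return a dtype name expected to work on MPS; other names pass through."""
    return MPS_DTYPE_REPLACEMENTS.get(name, name)


def to_mps_safe(tensor):
    """Cast a torch tensor to an MPS-friendly dtype (torch is imported lazily)."""
    import torch

    # str(torch.float64) is "torch.float64"; strip the module prefix.
    name = str(tensor.dtype).replace("torch.", "")
    return tensor.to(getattr(torch, mps_safe_dtype_name(name)))
```

Integer types are left untouched, in line with the comment's note that they work natively on Apple devices.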

2 reactions
Miffyli commented, May 20, 2022

Ideally yes, SB3 should support that device too (not a big change), but at the moment it seems that full support would require some operation-call changes. Those need to be addressed first (or we wait until torch has equivalent functions on all platforms). The changes should not interfere with existing code at all; otherwise this could introduce many hidden changes.
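
Until the device is recognized automatically, a caller could pick it explicitly and pass the result as `device=...`. The following is a rough sketch under stated assumptions: `choose_device` and `detect_device` are hypothetical helper names, the preference order (CUDA, then MPS, then CPU) is an arbitrary choice, and `torch.backends.mps` only exists in PyTorch builds that ship the MPS backend, which is why it is looked up defensively.

```python
def choose_device(cuda_available: bool, mps_available: bool) -> str:
    """Pick a device string from availability flags (CUDA, then MPS, then CPU)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device() -> str:
    """Query PyTorch for available backends (torch is imported lazily)."""
    import torch

    # Older PyTorch builds have no torch.backends.mps attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return choose_device(torch.cuda.is_available(), mps_ok)
```

Selecting "mps" this way does not make the unsupported operators work; as the comments above note, the operator gaps still have to be addressed.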
