device support
For array creation functions, device support will be needed, unless we intend to only support operations on the default device. Otherwise, what will happen in any function that creates a new array (e.g. creates the output array with `empty()` before filling it with the results of some computation) is that the new array will be on the default device, and an exception will be raised if an input array is on a non-default device.
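A minimal sketch of that failure mode (the function name is made up, and it assumes an array-API-style namespace where `empty()` has no `device=` keyword and arrays carry a `.device` attribute):

```python
def scaled_copy(x):
    # Without device support in creation functions, the output always lands
    # on the default device, regardless of where `x` lives.
    out = empty(x.shape, dtype=x.dtype)  # allocated on the default device
    out[...] = x * 2                     # raises if `x` is on a non-default device
    return out
```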
We discussed this in the Aug 27th call, and the preference was to do something PyTorch-like, perhaps a simplified version to start with (we may not need the context manager part), as the most robust option. Summary of some points that were made:
- TensorFlow has an issue where its `.shape` attribute is also a tensor, and that interacts badly with its context manager approach to specifying devices, because metadata like `.shape` typically should live on the host, not on an accelerator.
- PyTorch uses a mix of a default device, a context manager, and `device=` keywords (see the sketch after this list).
- JAX also has a context manager-like approach; it has a global default that can be set, and then `pmap`s can be decorated to override that. The difference with other libraries that use a context is that JAX is fairly (too) liberal about implicit device copies.
- It'd be best for operations where data is not all on the same device to raise an exception. Implicit device transfers make it very hard to get a good performance story.
- Propagating device assignments through operations is important.
- Control over where operations get executed is important; trying to be fully implicit doesn't scale to situations with multiple GPUs.
- It may not make sense to add syntax for device support for libraries that only support a single device (i.e., CPU).
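For reference, a short sketch of the PyTorch mix of default device, context manager, and `device=` keyword (requires a CUDA-enabled build with at least two GPUs):

```python
import torch

x = torch.empty(2, 3)                     # default device (CPU unless changed)
y = torch.empty(2, 3, device="cuda:0")    # explicit device= keyword

with torch.cuda.device(1):                # context manager sets the current GPU
    z = torch.empty(2, 3, device="cuda")  # "cuda" resolves to cuda:1 here

# Mixing devices in one operation raises rather than copying implicitly:
# x + y  ->  RuntimeError: expected all tensors to be on the same device
```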
Links to the relevant docs for each library:
- PyTorch: https://pytorch.org/docs/stable/notes/cuda.html
- TensorFlow: https://www.tensorflow.org/api_docs/python/tf/device
- CuPy: https://docs.cupy.dev/en/stable/tutorial/basic.html#current-device
- JAX: https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices
- MXNet: https://mxnet.apache.org/versions/1.6/api/python/docs/api/mxnet/context/index.html
The next step should be to write up a proposal for something PyTorch-like.
Top GitHub Comments
With SYCL, one writes a kernel once, compiles it with a SYCL compiler to an IR, and can then submit it to different queues targeting different devices (e.g. CPU, GPU, FPGA).
This example constructs a Python extension, compiled with Intel’s DPCPP compiler, to compute column-wise sums of an array.
Running it on CPU or GPU is just a matter of changing the queue the work is submitted to:
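Roughly, using such an extension from Python might look like the sketch below; the `columnwise_sum` module and the `"cpu"`/`"gpu"` queue selectors are purely illustrative names, not an actual package API:

```python
import numpy as np
from sycl_ext import columnwise_sum  # hypothetical DPCPP-compiled extension

x = np.random.rand(1024, 16)

# The same compiled kernel is dispatched to a different device simply by
# choosing which SYCL queue the work is submitted to:
sums_cpu = columnwise_sum(x, queue="cpu")
sums_gpu = columnwise_sum(x, queue="gpu")
```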
An array-consuming library author need not be aware of this, I thought, just as they need not be aware of which array implementation is powering the application.
That’s good to know. In that case I’ll remove the note on that; there’s no point in mentioning it if it’s being phased out.
The “mutating global state” point gets at the exact problem with context managers. Having global state generally makes it harder to write correct code. For the person writing that code it may be fine to keep it all in their head, but it affects any library call that gets invoked. That is probably still fine in single-device situations (e.g. switching between CPU and one GPU), but beyond that it gets tricky.
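As an illustration of how that global state leaks into callees, here is a sketch with a hypothetical `default_device()` context manager (not an actual API):

```python
def library_helper(x):
    # Written against the library's default device; it has no way of knowing
    # that a caller changed the default further up the stack.
    tmp = empty(x.shape, dtype=x.dtype)  # lands on whatever the current default is
    tmp[...] = x + 1                     # now a cross-device operation
    return tmp

with default_device("gpu:1"):
    # Every allocation inside any function called in this block is silently
    # redirected, including allocations inside third-party libraries.
    library_helper(x_on_cpu)  # fails (or silently copies) because of the context
```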
The consensus of our conversation in September was that a context manager isn’t always enough, and that the PyTorch model was more powerful. That still left open whether we should also add a context manager though.
Re cost - do you mean cost in verbosity? Passing through a keyword shouldn’t have significant performance cost.
I think the typical pattern would be to either use the default, or obtain it from the local context. E.g.
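A minimal sketch of that pattern, assuming creation functions grow a `device=` keyword and arrays expose a `.device` attribute (the function name is made up):

```python
def add_one(x):
    # Allocate the output on the same device as the input; the caller never
    # has to name a concrete device explicitly.
    out = empty(x.shape, dtype=x.dtype, device=x.device)
    out[...] = x + 1
    return out
```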
And only in more complex situations would the actual device need to be known explicitly.
That is a good question: should it be enforced or just recommended? Having device transfers be explicit is usually better (implicit transfers can make for hard-to-track-down performance issues), but perhaps not always.
Interesting, I’m not familiar with this hard/soft distinction, will look at the TF docs.
That should not be a problem if shape and size aren’t arrays but are instead either custom objects or tuples/ints, right?
That may be a good idea. Would be great to discuss in more detail later today.