Simplify CrossAttention to run on Apple Neural Engine
Is your feature request related to a problem? Please describe. I’m trying to convert portions of the UNet to Core ML. However, CrossAttention fails to compile for the Apple Neural Engine.
Describe the solution you’d like. After a lot of experimentation with converting CrossAttention in various ways, my best guess is that it contains too many reshapes and transposes.
Is there a way to simplify these two methods?
def reshape_heads_to_batch_dim(self, tensor):
    # (batch, seq_len, dim) -> (batch * heads, seq_len, dim // heads)
    batch_size, seq_len, dim = tensor.shape
    head_size = self.heads
    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
    return tensor

def reshape_batch_dim_to_heads(self, tensor):
    # (batch * heads, seq_len, head_dim) -> (batch, seq_len, heads * head_dim)
    batch_size, seq_len, dim = tensor.shape
    head_size = self.heads
    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
    return tensor
Perhaps with an einsum, or some other method?
Thanks!
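One possible simplification, sketched here under assumptions: instead of folding the heads into the batch dimension and unfolding them again afterwards, view the projections as 4-D `(batch, seq_len, heads, head_dim)` tensors and let `einsum` contract over `head_dim` directly, so no explicit `permute` is needed. The sketch uses NumPy so it can be checked without PyTorch; `torch.einsum` accepts the same subscript strings. The function names are hypothetical, and whether coremltools converts these exact equations and places them on the Neural Engine is not verified here (a converter may lower an einsum back into transposes plus a batched matmul).

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along one axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_reference(q, k, v, heads):
    """NumPy mirror of the issue's path: fold heads into batch, bmm, unfold."""
    b, n, dim = q.shape
    d = dim // heads

    def to_batch(t):
        # Same as reshape_heads_to_batch_dim: (b, m, dim) -> (b * heads, m, d)
        m = t.shape[1]
        return t.reshape(b, m, heads, d).transpose(0, 2, 1, 3).reshape(b * heads, m, d)

    qb, kb, vb = to_batch(q), to_batch(k), to_batch(v)
    scores = qb @ kb.transpose(0, 2, 1) / np.sqrt(d)   # (b * heads, n, m)
    out = softmax(scores, axis=-1) @ vb                 # (b * heads, n, d)
    # Same as reshape_batch_dim_to_heads: (b * heads, n, d) -> (b, n, dim)
    return out.reshape(b, heads, n, d).transpose(0, 2, 1, 3).reshape(b, n, dim)


def attention_einsum(q, k, v, heads):
    """Hypothetical simplification: keep heads as their own axis, no permutes."""
    b, n, dim = q.shape
    d = dim // heads
    q4 = q.reshape(b, n, heads, d)
    k4 = k.reshape(b, k.shape[1], heads, d)
    v4 = v.reshape(b, v.shape[1], heads, d)
    # scores[b, h, n, m] = sum over d of q[b, n, h, d] * k[b, m, h, d]
    scores = np.einsum("bnhd,bmhd->bhnm", q4, k4) / np.sqrt(d)
    # out[b, n, h, d] = sum over m of weights[b, h, n, m] * v[b, m, h, d]
    out = np.einsum("bhnm,bmhd->bnhd", softmax(scores, axis=-1), v4)
    return out.reshape(b, n, dim)  # merging (heads, d) is a plain reshape
```

Comparing both functions on random inputs only shows that the einsum form is numerically equivalent; it says nothing about whether the converted model actually runs on the Neural Engine.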
Top GitHub Comments
Yeah, giving it a shot now. I want to make sure the einsum I’m using will convert with coremltools, though. Will check back in!
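When testing which einsum forms convert cleanly, one variant worth trying is the channels-first `(batch, channels, 1, seq_len)` layout with each attention head handled as a separate chunk, roughly the approach described in Apple’s "Deploying Transformers on the Apple Neural Engine" write-up. Heads are sliced along the channel axis and each gets its own small einsum, so there is no reshape or permute at all. This NumPy sketch is an illustration under assumptions, not Apple’s reference code: the function name and the exact equation strings are choices made here, and conversion and Neural Engine placement are untested.

```python
import numpy as np


def split_head_attention(q, k, v, heads):
    """Attention on channels-first (batch, dim, 1, seq) tensors, one einsum per head.

    q: (b, dim, 1, n) queries; k, v: (b, dim, 1, m) keys/values.
    Heads are sliced along the channel axis, so no reshape/permute is used.
    """
    b, dim, _, n = q.shape
    d = dim // heads
    outs = []
    for h in range(heads):
        sl = slice(h * d, (h + 1) * d)
        qh, kh, vh = q[:, sl], k[:, sl], v[:, sl]  # (b, d, 1, n) / (b, d, 1, m)
        # scores[b, k, 1, q]: contract over channels c; key positions land on axis 1
        scores = np.einsum("bchq,bchk->bkhq", qh, kh) / np.sqrt(d)
        scores = scores - scores.max(axis=1, keepdims=True)
        w = np.exp(scores)
        w = w / w.sum(axis=1, keepdims=True)  # softmax over key positions (axis 1)
        # out[b, c, 1, q] = sum over k of w[b, k, 1, q] * v[b, c, 1, k]
        outs.append(np.einsum("bkhq,bchk->bchq", w, vh))
    return np.concatenate(outs, axis=1)  # (b, dim, 1, n)
```

This trades one reshape/permute pair for `heads` separate small contractions; whether that is a net win on the Neural Engine would have to be measured after conversion.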
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.