Add 32-bit support to neural_network module
Related to https://github.com/scikit-learn/scikit-learn/issues/11000, it would be good to add support for 32-bit computations for estimators in the neural_network module. This was done for BernoulliRBM in https://github.com/scikit-learn/scikit-learn/pull/16352. Because performance is bound by the dot product, this is going to have a large impact (cf. https://github.com/scikit-learn/scikit-learn/pull/17641#issuecomment-646879638).

I would even argue that, unlike other models, it could make sense to add a dtype=np.float32 parameter and run calculations in 32 bit by default, regardless of X.dtype. We could also consider supporting dtype=np.float16.
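To illustrate why the dot product dominates here, a minimal NumPy sketch (not scikit-learn code) comparing the same matrix product in both precisions: float32 halves the memory of every array and intermediate, and float32 BLAS products are typically faster on CPUs with good vectorization support.

```python
import numpy as np

rng = np.random.default_rng(0)
X64 = rng.standard_normal((1000, 100))   # NumPy defaults to float64
W64 = rng.standard_normal((100, 50))

# Casting inputs to float32 keeps the whole product in 32 bit.
X32, W32 = X64.astype(np.float32), W64.astype(np.float32)

out64 = X64 @ W64
out32 = X32 @ W32

print(out64.dtype, out64.nbytes)   # float64 400000
print(out32.dtype, out32.nbytes)   # float32 200000
```

The result arrays differ only in precision; the float32 one needs half the bytes, which also helps cache behavior during backpropagation-style repeated products.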
Issue Analytics
- Created: 3 years ago
- Comments: 5 (5 by maintainers)
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
Sure @rth, will open a PR in a bit.
Great! I was hoping that it would be a bit more significant, but I guess it also depends on the CPU vectorization support.
Maybe let’s not add the dtype parameter after all, but rather do as in https://github.com/scikit-learn/scikit-learn/pull/16352: run input validation first and then use X.dtype for the creation of new arrays. That would probably be more consistent with other estimators. Would you like to make a PR?
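The pattern from that PR (accept both float64 and float32 at validation time, then create every new array with X.dtype) can be sketched as below. This is a NumPy-only illustration with hypothetical helper names, not scikit-learn internals; in scikit-learn itself the validation step is `check_array(X, dtype=[np.float64, np.float32])`.

```python
import numpy as np

def _check_array(X, dtype=(np.float64, np.float32)):
    # Hypothetical stand-in for sklearn.utils.check_array: keep X's dtype
    # if it is already one of the accepted dtypes, otherwise cast to the
    # first (default) one.
    X = np.asarray(X)
    if X.dtype in [np.dtype(d) for d in dtype]:
        return X
    return X.astype(dtype[0])

def _init_coefs(X, n_components):
    # New arrays are allocated with X.dtype, so a float32 input keeps the
    # whole computation in 32 bit without an explicit dtype parameter.
    return np.zeros((X.shape[1], n_components), dtype=X.dtype)

X32 = _check_array(np.ones((4, 3), dtype=np.float32))
coefs = _init_coefs(X32, n_components=2)
print(X32.dtype, coefs.dtype)   # float32 float32

Xint = _check_array(np.ones((4, 3), dtype=np.int64))
print(Xint.dtype)               # float64 (ints are upcast to the default)
```

The upside of this approach is that the precision follows the data the user passes in, matching the behavior of other estimators, rather than introducing a new constructor parameter.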