tf.loadModel on a simple LSTM model with weights on the order of 3 MB uses 5 GB of memory.
Code to reproduce:
const model = await tf.loadModel('https://timotheebernard.github.io/models/model.json');
console.log(tf.memory());
// This yields:
// {
//   unreliable: false,
//   numTensors: 8991,
//   numDataBuffers: 8986,
//   numBytes: 5050462344
// }
// The code below computes the total size of the weights across all layers,
// which gives 766601 values. At 4 bytes per float32 weight, that's
// 3066404 bytes ~= 3 MB.
const totalWeightsSize = tf.util
  .flatten(model.layers.map(l => l.weights).filter(x => x.length > 0))
  .map(x => x.val)
  .reduce((accumulator, weight) => accumulator + weight.size, 0);
Note that the weight files themselves are small (< 3 MB): https://github.com/timotheebernard/timotheebernard.github.io/tree/master/models
Is there something the recurrent cells are doing that causes this much memory blow-up? @caisq, @ericdnielsen, @bileschi, can you take a look?
Issue Analytics
- State:
- Created 5 years ago
- Reactions: 3
- Comments: 5
Top GitHub Comments
My guess is that this has to do with the Orthogonal initializer of the LSTM. When the model JSON is loaded, the Orthogonal initializer calls the QR decomposition under the hood. Owing to the relatively large size of the matrix, this slows things down. Maybe we can optimize the memory usage of the QR decomposition a little. We could also consider adding logic to skip initialization entirely when weights are going to be loaded afterwards.
FYI, I plan to speed up orthogonal initializers by replacing the QR decomposition with the Gram-Schmidt process.
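For illustration, a minimal sketch of the (modified) Gram-Schmidt process mentioned above, in plain JavaScript on number arrays rather than tensors. This is not the actual TF.js implementation, just the idea: each row is made orthogonal to the previously accepted rows, then normalized, avoiding a full QR decomposition.

```javascript
// Hypothetical sketch: orthonormalize a set of row vectors with
// modified Gram-Schmidt. Assumes the input rows are linearly independent.
function gramSchmidt(rows) {
  const dot = (a, b) => a.reduce((s, v, i) => s + v * b[i], 0);
  const ortho = [];
  for (const row of rows) {
    let v = row.slice();
    for (const u of ortho) {
      // u is already unit-norm, so the projection coefficient is just v·u.
      const c = dot(v, u);
      v = v.map((x, i) => x - c * u[i]);
    }
    const norm = Math.sqrt(dot(v, v));
    ortho.push(v.map(x => x / norm));
  }
  return ortho;
}
```

Unlike QR via Householder reflections, this touches one row at a time, so peak memory stays close to the size of the output matrix itself.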