Potential memory leakage of TensorFlow Swin model on Kaggle
System Info
Info:
- Framework: TensorFlow 2 (Keras)
- Version: 2.6
- Platform: Kaggle
Who can help?
Swin Model Card @amyeroberts TensorFlow: @Rocketknight1 Vision: @NielsRogge, @sgugger
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
In a recent Kaggle competition (hosted by Google), I tried to use the pretrained TF Swin transformer model from Hugging Face, but even with the base model I consistently received an out-of-memory error. Below is the submission status with a base_tf_swin model.
Some notes:
- Other frameworks like PyTorch work fine here.
- Other than this model, much larger models like tf_convnext_xlarge are able to run without OOM.
So, I’m assuming there might be some memory leakage in the tf_swin implementation. Below is the code I used to build the complete model.
```python
import tensorflow as tf
from tensorflow import keras
from transformers import AutoFeatureExtractor, TFSwinModel

INPUT_SHAPE = (224, 224)  # target height/width for the checkpoint below
model_id = "microsoft/swin-base-patch4-window7-224-in22k"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

inputs = keras.Input(shape=(None, None, 3), dtype='uint8')
model_inputs = tf.cast(inputs, tf.float32)
model_inputs = keras.layers.Resizing(*INPUT_SHAPE)(model_inputs)
model_inputs = keras.layers.Rescaling(scale=1.0 / 255)(model_inputs)
model_inputs = keras.layers.Normalization(
    mean=feature_extractor.image_mean,
    variance=[x ** 2 for x in feature_extractor.image_std],
    axis=3,
)(model_inputs)
# Swin expects channels-first (N, C, H, W) inputs.
model_inputs = keras.layers.Permute(dims=(3, 1, 2))(model_inputs)

tf_huggingface_module = TFSwinModel.from_pretrained(model_id)
tf_huggingface_module.trainable = False
logits = tf_huggingface_module(model_inputs)
adv_logits = keras.layers.Dense(64)(logits.pooler_output)
outputs = keras.layers.Lambda(
    lambda x: tf.math.l2_normalize(x, axis=-1), name='embedding_norm'
)(adv_logits)
tf_huggingface_classifier = keras.Model(inputs, outputs)
```
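The Rescaling + Normalization layers above are meant to reproduce the feature extractor's `(pixel / 255 - mean) / std` preprocessing; Keras `Normalization` divides by the square root of its `variance` argument, which is why `std ** 2` is passed in. A minimal NumPy sketch checking that equivalence, using made-up per-channel stats standing in for `feature_extractor.image_mean` / `image_std`:

```python
import numpy as np

# Hypothetical per-channel stats (ImageNet-style placeholders).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

pixels = np.random.randint(0, 256, size=(2, 4, 4, 3)).astype(np.float32)

# Feature-extractor style: scale to [0, 1], then standardize per channel.
reference = (pixels / 255.0 - mean) / std

# Keras Normalization computes (x - mean) / sqrt(variance),
# so passing variance = std ** 2 reproduces division by std.
keras_style = (pixels / 255.0 - mean) / np.sqrt(std ** 2)

assert np.allclose(reference, keras_style)
```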
Expected behavior
It should work like other models. To reproduce the issue exactly, you may (in the worst case) need to run it on the Kaggle platform. The Kaggle submission status (as shown in the diagram above) is not very descriptive beyond the status itself 😦. Mainly, I would like to know what could be the cause and any possible solution.
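If the failure turns out to be peak memory during inference rather than a true leak, one generic workaround (a sketch, not specific to TFSwinModel; `predict_fn` here is any batched prediction callable) is to run prediction over the test set in small chunks and collect garbage between them:

```python
import gc

def predict_in_chunks(predict_fn, samples, chunk_size=64):
    """Run predict_fn over samples in fixed-size chunks to cap peak memory."""
    outputs = []
    for start in range(0, len(samples), chunk_size):
        outputs.extend(predict_fn(samples[start:start + chunk_size]))
        gc.collect()  # release per-chunk temporaries promptly
    return outputs
```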
Issue Analytics
- State:
- Created a year ago
- Comments: 14 (4 by maintainers)
Top GitHub Comments
Randomly jumping in this thread 😃
Hi @innat. As mentioned above, it’s quite hard to debug without knowing what’s happening during submission and without logs from the Kaggle notebook. My current best guess is that it’s due to the size of the saved Swin model.
Using your script to create and save out a model, I looked at the sizes across different checkpoints:
I haven’t dug much into why the model is so much larger. A cursory glance at the model graphs didn’t reveal anything particularly surprising.
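For anyone wanting to repeat that size comparison, a minimal sketch (assuming each model was written out with `model.save(path)`, which produces a directory of files) that sums the on-disk footprint of a saved checkpoint:

```python
from pathlib import Path

def dir_size_mb(path):
    """Total size of all files under `path`, in megabytes."""
    return sum(p.stat().st_size for p in Path(path).rglob("*") if p.is_file()) / 1e6
```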