tf2.4.1 + Keras model + Large Dataset = Memory leak?
Hey guys.
I work at a relatively large-scale company, and during my training loop I eventually run out of memory (OOM) on my VMs. This is unexpected, and I was hoping someone could take a look at what I have here and maybe point me in the right direction to solving this problem.
I am using tensorflow_ranking for a large-scale recommender at my company. Having upgraded to tensorflow==2.4.1 and tensorflow-ranking==0.3.3, I believe I am running into some kind of unexpected memory leak. I am running in a Kubeflow cluster, with gcr.io/deeplearning-platform-release/tf2-gpu.2-4:latest as the container.
This is what I consistently see during training (monitoring screenshot omitted here):

Notice that both CPU and GPU memory slowly creep upwards, eventually causing an OOM error.
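For reference, this is roughly how the numbers above can be captured per epoch. It is a minimal sketch rather than part of my original setup: psutil is an extra dependency, the `GPU:0` device name is an assumption, and `tf.config.experimental.get_memory_info` only exists from TF 2.5 onwards.

```python
# Minimal memory-logging sketch (not part of the original run).
import psutil
import tensorflow as tf


class MemoryLogger(tf.keras.callbacks.Callback):
    """Logs host RSS and, if available, GPU memory at the end of each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        rss_mb = psutil.Process().memory_info().rss / 2**20
        msg = f"epoch={epoch} host_rss={rss_mb:.0f} MiB"
        try:
            # Available from TF 2.5 onwards.
            gpu = tf.config.experimental.get_memory_info("GPU:0")
            msg += f" gpu_current={gpu['current'] / 2**20:.0f} MiB"
        except Exception:
            pass
        print(msg)
```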
My data is TFRecords, generated from Beam, encoded in ELWC style. The list size is at most 240 but varies by session. In total I have about 230 GB of train data, with a 0.01 test/eval split.
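For context, each record looks roughly like this when it is written. This is a hypothetical sketch of the ELWC layout only; the `query_id` and `item_id` feature names are invented, while `labels` matches the label feature used below.

```python
# Hypothetical sketch of one ELWC (ExampleListWithContext) record.
from tensorflow_serving.apis import input_pb2

elwc = input_pb2.ExampleListWithContext()
# Session-level (context) features.
elwc.context.features.feature["query_id"].bytes_list.value.append(b"abc123")
# One tf.train.Example per candidate item in the session (up to ~240 of them).
for item_id, label in [(b"item_1", 1.0), (b"item_2", 0.0)]:
    example = elwc.examples.add()
    example.features.feature["item_id"].bytes_list.value.append(item_id)
    example.features.feature["labels"].float_list.value.append(label)
serialized = elwc.SerializeToString()  # this is what goes into the TFRecord
```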
I am loading the data using the following function:
```python
import tensorflow as tf
import tensorflow_ranking as tfr


def load_dataset(kind='train', list_size_feature='list_size', limit=limit):
    context_spec, example_spec = get_features_dataset_spec(with_label=True)
    file_pattern = f"{data_dir}/{kind}/*.tfrecord"
    log.info(f"loading dataset from file_pattern: {file_pattern}")
    dataset = tfr.data.build_ranking_dataset(
        file_pattern=file_pattern,
        data_format=tfr.data.ELWC,
        batch_size=batch_size,
        context_feature_spec=context_spec,
        example_feature_spec=example_spec,
        reader=tf.data.TFRecordDataset,
        prefetch_buffer_size=0,
        sloppy_ordering=True,
        shuffle=True if kind == 'train' else False,
        num_epochs=num_epochs if kind == 'train' else 1,
        size_feature_name=list_size_feature,
        list_size=run.static_list_size,
    )

    @tf.function
    def _separate_features_and_label(features):
        # Split the parsed feature dict into (features, labels) for Keras.
        features_without_labels = {
            feature: value
            for feature, value in features.items()
            if feature != 'labels'
        }
        labels = tf.cast(
            x=tf.squeeze(features[_LABEL_FEATURE_NAME], axis=2),
            dtype=tf.float32,
            name='labels'
        )
        return features_without_labels, labels

    dataset = dataset.map(_separate_features_and_label)

    if limit:
        dataset = dataset.take(limit)
    return dataset


train_data = load_dataset('train')
eval_data = load_dataset('eval')
test_data = load_dataset('test')
```
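Before involving the model at all, one check is whether the tf.data pipeline alone leaks. A minimal sketch of that check, where psutil and the step counts are assumptions rather than part of the original run:

```python
# Sketch: iterate batches without the model and watch host RSS.
import psutil

proc = psutil.Process()
for step, (features, labels) in enumerate(train_data.take(2000)):
    if step % 200 == 0:
        print(f"step={step} rss={proc.memory_info().rss / 2**20:.0f} MiB")
```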
Here I am building the network/ranker:
```python
context_columns, example_columns = get_feature_columns()

network = tfr.keras.canned.DNNRankingNetwork(
    context_feature_columns=context_columns,
    example_feature_columns=example_columns,
    hidden_layer_dims=[1024, 512, 256, 128, 64],
    activation=tf.nn.relu,
    dropout=0.3,
    use_batch_norm=True,
)

metrics = [
    *[
        dict(key="ndcg", topn=k, name=f"metric/ndcg_{k}")
        for k in (1, 2, 5, 10, 50, 100, 200)
    ],
    dict(key="ordered_pair_accuracy", name="metric/ordered_pair_accuracy"),
]

ranker: tf.keras.models.Model = tfr.keras.model.create_keras_model(
    network=network,
    loss='pairwise_logistic_loss',
    optimizer='adam',
    metrics=[tfr.keras.metrics.get(**x) for x in metrics],
    list_size=run.static_list_size,
    size_feature_name=_LIST_SIZE_FEATURE_NAME,
)
```
Finally, I fit with the following call:
```python
ranker.fit(
    train_data,
    verbose=1,
    steps_per_epoch=max(1, train_samples // batch_size),
    epochs=run.epochs,
    validation_freq=1,
    validation_data=eval_data,
    callbacks=[],
)
```
With this code I see a continual increase in memory on both CPU and GPU as training progresses, until an out-of-memory error eventually occurs.
Any thoughts or ideas towards a solution would be greatly appreciated.
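One workaround I am considering is forcing a garbage-collection pass at the end of each epoch. This is a sketch of an assumption on my side, not a confirmed fix for this leak:

```python
# Hypothetical workaround sketch: force a gc pass after every epoch to see
# whether the host-memory creep slows down.
import gc
import tensorflow as tf


class EpochGarbageCollect(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        gc.collect()


# e.g. callbacks=[EpochGarbageCollect()] in the fit() call above
```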
@xuanhuiwang I have since upgraded to tensorflow-ranking==0.4.0 and tensorflow==2.5.0.
I am pleased to report that GPU memory is now steady on both A100s and V100s.
However, I am seeing CPU memory jump up in spikes on V100s only (for some reason?). It jumps, then plateaus, then jumps and plateaus again. Very unusual.
That said, I have not seen a run hit OOM yet.
So I can confirm the tensorflow-ranking==0.4.0 release seems to have resolved some of the issues, with more testing needed.
Let me run one without a validation_data set and get the results back. It takes me a while to run out of memory, so I will report back.
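For that control run, the plan is roughly the same fit() call with validation_data dropped, to check whether the periodic evaluation drives the CPU-memory spikes (sketch only):

```python
# Planned control run: identical fit() call, just without validation_data.
ranker.fit(
    train_data,
    verbose=1,
    steps_per_epoch=max(1, train_samples // batch_size),
    epochs=run.epochs,
    callbacks=[],
)
```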