
Dataset slow during model training

See original GitHub issue

Describe the bug

While migrating towards 🤗 Datasets, I encountered an odd performance degradation: training suddenly slows down dramatically. I train on an image dataset using Keras and call to_tf_dataset just before training.

First, I optimized my dataset following https://discuss.huggingface.co/t/solved-image-dataset-seems-slow-for-larger-image-size/10960/6, which improved the situation compared to what I had before but did not completely solve it.

Second, I saved and loaded my dataset using tf.data.experimental.save and tf.data.experimental.load before training (for which I would have expected no performance change). However, I ended up with the performance I had before tinkering with 🤗 Datasets.

Any idea what the reason for this is and how to speed up training with 🤗 Datasets?

Steps to reproduce the bug

# Sample code to reproduce the bug

from datasets import load_dataset
import os

dataset_dir = "./dataset"
prep_dataset_dir = "./prepdataset"
model_dir = "./model"

# These hyperparameters are not defined in the issue; placeholder values are
# assumed here so the script runs end to end.
size = 224
batch_size = 32
test_size = 0.2
seed = 42
epochs = 10

# Load Data
dataset = load_dataset("Lehrig/Monkey-Species-Collection", "downsized")
def read_image_file(example):
    with open(example["image"].filename, "rb") as f:
        example["image"] = {"bytes": f.read()}
        return example
dataset = dataset.map(read_image_file)
dataset.save_to_disk(dataset_dir)

# Preprocess
from datasets import (
    Array3D,
    DatasetDict,
    Features,
    load_from_disk,
    Sequence,
    Value
)
import numpy as np
from transformers import ImageFeatureExtractionMixin

dataset = load_from_disk(dataset_dir)

num_classes = dataset["train"].features["label"].num_classes
one_hot_matrix = np.eye(num_classes)
feature_extractor = ImageFeatureExtractionMixin()

def to_pixels(image):
    image = feature_extractor.resize(image, size=size)
    image = feature_extractor.to_numpy_array(image, channel_first=False)
    image = image / 255.0
    return image

def process(examples):
    examples["pixel_values"] = [
        to_pixels(image) for image in examples["image"]
    ]
    examples["label"] = [
        one_hot_matrix[label] for label in examples["label"]
    ]
    return examples

features = Features({
    "pixel_values": Array3D(dtype="float32", shape=(size, size, 3)),
    "label": Sequence(feature=Value(dtype="int32"), length=num_classes)
})

prep_dataset = dataset.map(
    process,
    remove_columns=["image"],
    batched=True,
    batch_size=batch_size,
    num_proc=2,
    features=features,
)

prep_dataset = prep_dataset.with_format("numpy")

# Split
train_dev_dataset = prep_dataset['test'].train_test_split(
    test_size=test_size,
    shuffle=True,
    seed=seed
)

train_dev_test_dataset = DatasetDict({
    'train': train_dev_dataset['train'],
    'dev': train_dev_dataset['test'],
    'test': prep_dataset['test'],
})

train_dev_test_dataset.save_to_disk(prep_dataset_dir)

# Train Model
import datetime
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, BatchNormalization
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from transformers import DefaultDataCollator

dataset = load_from_disk(prep_dataset_dir)

data_collator = DefaultDataCollator(return_tensors="tf")

train_dataset = dataset["train"].to_tf_dataset(
    columns=['pixel_values'],
    label_cols=['label'],
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator
)

validation_dataset = dataset["dev"].to_tf_dataset(
    columns=['pixel_values'],
    label_cols=['label'],
    shuffle=False,
    batch_size=batch_size,
    collate_fn=data_collator
)

print(f'{datetime.datetime.now()} - Saving Data')
tf.data.experimental.save(train_dataset, model_dir+"/train")
tf.data.experimental.save(validation_dataset, model_dir+"/val")

print(f'{datetime.datetime.now()} - Loading Data')
train_dataset = tf.data.experimental.load(model_dir+"/train")
validation_dataset = tf.data.experimental.load(model_dir+"/val")

shape = np.shape(dataset["train"][0]["pixel_values"])
backbone = InceptionV3(
    include_top=False,
    weights='imagenet',
    input_shape=shape
)

for layer in backbone.layers:
    layer.trainable = False

model = Sequential()
model.add(backbone)
model.add(GlobalAveragePooling2D())
model.add(Dense(128, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(Dense(64, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.3))
model.add(Dense(10, activation='softmax'))

model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

model.summary()

earlyStopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    verbose=0,
    mode='min'
)

mcp_save = ModelCheckpoint(
    f'{model_dir}/best_model.hdf5',
    save_best_only=True,
    monitor='val_loss',
    mode='min'
)

reduce_lr_loss = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.1,
    patience=7,
    verbose=1,
    min_delta=0.0001,
    mode='min'
)

hist = model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=validation_dataset,
    callbacks=[earlyStopping, mcp_save, reduce_lr_loss]
)

Expected results

Same performance when training without my “save/load hack”, or a good explanation/recommendation about the issue.

Actual results

Performance is slower without my “save/load hack”.

Epoch Breakdown (without my “save/load hack”):

  • Epoch 1/10 41s 2s/step - loss: 1.6302 - accuracy: 0.5048 - val_loss: 1.4713 - val_accuracy: 0.3273 - lr: 0.0010
  • Epoch 2/10 32s 2s/step - loss: 0.5357 - accuracy: 0.8510 - val_loss: 1.0447 - val_accuracy: 0.5818 - lr: 0.0010
  • Epoch 3/10 36s 3s/step - loss: 0.3547 - accuracy: 0.9231 - val_loss: 0.6245 - val_accuracy: 0.7091 - lr: 0.0010
  • Epoch 4/10 36s 3s/step - loss: 0.2721 - accuracy: 0.9231 - val_loss: 0.3395 - val_accuracy: 0.9091 - lr: 0.0010
  • Epoch 5/10 32s 2s/step - loss: 0.1676 - accuracy: 0.9856 - val_loss: 0.2187 - val_accuracy: 0.9636 - lr: 0.0010
  • Epoch 6/10 42s 3s/step - loss: 0.2066 - accuracy: 0.9615 - val_loss: 0.1635 - val_accuracy: 0.9636 - lr: 0.0010
  • Epoch 7/10 32s 2s/step - loss: 0.1814 - accuracy: 0.9423 - val_loss: 0.1418 - val_accuracy: 0.9636 - lr: 0.0010
  • Epoch 8/10 32s 2s/step - loss: 0.1301 - accuracy: 0.9856 - val_loss: 0.1388 - val_accuracy: 0.9818 - lr: 0.0010
  • Epoch 9/10 loss: 0.1102 - accuracy: 0.9856 - val_loss: 0.1185 - val_accuracy: 0.9818 - lr: 0.0010
  • Epoch 10/10 32s 2s/step - loss: 0.1013 - accuracy: 0.9808 - val_loss: 0.0978 - val_accuracy: 0.9818 - lr: 0.0010

Epoch Breakdown (with my “save/load hack”):

  • Epoch 1/10 13s 625ms/step - loss: 3.0478 - accuracy: 0.1146 - val_loss: 2.3061 - val_accuracy: 0.0727 - lr: 0.0010
  • Epoch 2/10 0s 80ms/step - loss: 2.3105 - accuracy: 0.2656 - val_loss: 2.3085 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 3/10 0s 77ms/step - loss: 1.8608 - accuracy: 0.3542 - val_loss: 2.3130 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 4/10 1s 98ms/step - loss: 1.8677 - accuracy: 0.3750 - val_loss: 2.3157 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 5/10 1s 204ms/step - loss: 1.5561 - accuracy: 0.4583 - val_loss: 2.3049 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 6/10 1s 210ms/step - loss: 1.4657 - accuracy: 0.4896 - val_loss: 2.2944 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 7/10 1s 205ms/step - loss: 1.4018 - accuracy: 0.5312 - val_loss: 2.2917 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 8/10 1s 207ms/step - loss: 1.2370 - accuracy: 0.5729 - val_loss: 2.2814 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 9/10 1s 214ms/step - loss: 1.1190 - accuracy: 0.6250 - val_loss: 2.2733 - val_accuracy: 0.0909 - lr: 0.0010
  • Epoch 10/10 1s 207ms/step - loss: 1.1484 - accuracy: 0.6302 - val_loss: 2.2624 - val_accuracy: 0.0909 - lr: 0.0010

Environment info

  • datasets version: 2.2.2
  • Platform: Linux-4.18.0-305.45.1.el8_4.ppc64le-ppc64le-with-glibc2.17
  • Python version: 3.8.13
  • PyArrow version: 7.0.0
  • Pandas version: 1.4.2
  • TensorFlow: 2.8.0
  • GPU (used during training): Tesla V100-SXM2-32GB

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

3 reactions
Rocketknight1 commented, Jun 13, 2022

Hi @lehrig, I suspect what's happening here is that our to_tf_dataset() method has some performance issues when streaming samples. These are usually not a problem, but they become apparent when streaming a vision dataset into a very small vision model, which needs a lot of sample throughput to saturate the GPU.

When you save a tf.data.Dataset with tf.data.experimental.save, all of the samples from the dataset (in this case, batches of images) are saved to disk. When you load this saved dataset, you're effectively bypassing to_tf_dataset() entirely, which alleviates this performance bottleneck.

to_tf_dataset() is something we're actively working on overhauling right now - particularly for image datasets, we want to make it possible to access the underlying images with tf.data without going through the current layer of indirection with Arrow, which should massively improve simplicity and performance.
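
For completeness, a generic tf.data-level mitigation (not suggested in this thread; sketched here assuming the train_dataset, validation_dataset, and model_dir from the reproduction script above) is to cache and prefetch the batches coming out of to_tf_dataset(), so the Arrow-to-tensor conversion is only paid during the first epoch. Note that with shuffle=True the cached batch order from that first epoch is then reused in later epochs:

import tensorflow as tf

cached_train_dataset = (
    train_dataset
    .cache(model_dir + "/train_cache")  # materialize batches during the first pass
    .prefetch(tf.data.AUTOTUNE)         # overlap data loading with GPU compute
)

hist = model.fit(
    cached_train_dataset,
    epochs=epochs,
    validation_data=validation_dataset
)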

However, if you just want this to work quickly without needing your save/load hack, my advice would be to simply load the dataset into memory, if it's small enough to fit. Since all your samples have the same dimensions, you can do this simply with:

dataset = load_from_disk(prep_dataset_dir)
dataset = dataset.with_format("numpy")
data_in_memory = dataset[:]

Then you can simply do something like:

model.fit(data_in_memory["pixel_values"], data_in_memory["label"])
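
Spelling that suggestion out with the column names and callbacks from the reproduction script (batch_size and epochs are the placeholder values assumed above), the in-memory training call might look like this; the validation split is loaded the same way:

train_in_memory = load_from_disk(prep_dataset_dir)["train"].with_format("numpy")[:]
dev_in_memory = load_from_disk(prep_dataset_dir)["dev"].with_format("numpy")[:]

model.fit(
    train_in_memory["pixel_values"],
    train_in_memory["label"],
    validation_data=(dev_in_memory["pixel_values"], dev_in_memory["label"]),
    batch_size=batch_size,
    epochs=epochs,
    callbacks=[earlyStopping, mcp_save, reduce_lr_loss]
)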
0 reactions
Rocketknight1 commented, Jun 14, 2022

That 5 minute wait is quite surprising! I don't have a good explanation for why it's happening, but it can't be an issue with datasets or tf.data because you're just fitting directly on Numpy arrays at this point. All I can suggest is seeing if you can isolate the issue - for example, does fitting on a smaller dataset containing only 10% of the original data reduce the wait? This might indicate the delay is caused by your data being copied or converted somehow. Alternatively, you could try removing things like callbacks and seeing if you could isolate the issue there.
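
A minimal way to run that isolation experiment, assuming the data_in_memory dict from the previous comment, is to slice out roughly the first 10% of the arrays and time a short fit (the subset size here is arbitrary):

import time

n_samples = len(data_in_memory["pixel_values"])
subset = slice(0, max(1, n_samples // 10))  # roughly the first 10% of the data

start = time.time()
model.fit(
    data_in_memory["pixel_values"][subset],
    data_in_memory["label"][subset],
    epochs=1,
)
print(f"Fit on the 10% subset took {time.time() - start:.1f}s")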
