How to batch train with fit_generator()?

Apologies if this is the wrong place to raise my issue (if so, please point me to a better one). I’m a novice with Keras and Python, so I hope responses keep that in mind.

I’m trying to train a CNN steering model that takes images as input. It’s a fairly large dataset, so I created a data generator to work with fit_generator(). It’s not clear to me how to make this method train on batches, so I assumed that the generator has to return batches to fit_generator(). The generator looks like this:

import csv

import cv2
import numpy as np


def gen(file_name, batchsz=64):
    csvfile = open(file_name)
    reader = csv.reader(csvfile)
    batchCount = 0
    while True:
        for line in reader:
            # NOTE: re-created on every row, so each yielded X holds only the last row's image
            inputs = []
            targets = []
            temp_image = cv2.imread(line[1]) # line[1] is path to image
            measurement = float(line[3]) # steering angle, parsed from the CSV string
            inputs.append(temp_image)
            targets.append(measurement)
            batchCount += 1
            if batchCount >= batchsz:
                batchCount = 0
                X = np.array(inputs)
                y = np.array(targets)
                yield X, y
        csvfile.seek(0) # rewind to loop over the file again

It reads a CSV file containing telemetry data (steering angle etc.) and paths to image samples, and returns arrays of size batchsz. The call to fit_generator() looks like this:

    tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz=128) # Train data generator
    vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz=128) # Validation data generator
    # model.fit(X_all, y_all, validation_split=0.2, shuffle=True, nb_epoch=epochs)
    # Keras 1 argument names: samples_per_epoch and nb_val_samples count samples, not batches
    model.fit_generator(
        tgen,
        samples_per_epoch=113526,
        nb_epoch=6,
        validation_data=vgen,
        nb_val_samples=20001
    )

The dataset contains 113526 sample points, yet the training progress output looks like this (for example):

  1020/113526 [..............................] - ETA: 27737s - loss: 0.0080
  1021/113526 [..............................] - ETA: 27723s - loss: 0.0080
  1022/113526 [..............................] - ETA: 27709s - loss: 0.0080
  1023/113526 [..............................] - ETA: 27696s - loss: 0.0080

This appears to be training sample by sample (stochastically?). The resulting model is useless. I previously trained on a much smaller dataset using .fit() with the whole dataset loaded into memory, and that produced a model that at least works, even if poorly. Clearly something is wrong with my fit_generator() approach. I would be very grateful for some help with this.
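In Keras 1, both samples_per_epoch and the progress bar count samples, so a counter that advances by exactly one per step (1020, 1021, 1022, ...) suggests that every yielded batch holds a single sample. In the generator above, inputs = [] and targets = [] sit inside the for loop, so they are emptied on every CSV row and X only ever contains the last row's image. A minimal sketch of a generator that accumulates full batches is below. It assumes the same CSV layout as the question (image path in column 1, steering angle in column 3); the name batch_gen and the load_image parameter are hypothetical additions (load_image defaults to cv2.imread and exists only so the sketch can be exercised without real images).

```python
import csv

import numpy as np


def batch_gen(file_name, batchsz=64, load_image=None):
    """Yield (X, y) batches of exactly `batchsz` rows, looping over the CSV forever.

    `inputs` and `targets` are reset only after a full batch has been yielded,
    so each batch holds `batchsz` samples.
    """
    if load_image is None:
        import cv2  # assumed installed, as in the question
        load_image = cv2.imread
    inputs, targets = [], []
    with open(file_name, newline="") as csvfile:
        reader = csv.reader(csvfile)
        while True:
            for line in reader:
                inputs.append(load_image(line[1]))  # line[1]: image path
                targets.append(float(line[3]))      # line[3]: steering angle
                if len(inputs) == batchsz:
                    yield np.array(inputs), np.array(targets)
                    inputs, targets = [], []  # start a fresh batch
            csvfile.seek(0)  # rewind so the next pass starts at the first row
```

Because the lists are reset only after a yield, rows left over at the end of the file are carried into the first batch of the next pass, so every batch is full.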

Issue Analytics

  • State: closed
  • Created 6 years ago
  • Reactions: 5
  • Comments: 6 (1 by maintainers)

Top GitHub Comments

2 reactions
StripedBanana commented, Aug 24, 2017

I don’t think you should use a for loop in your generator. The reason is that Keras will spawn multiple threads when using fit_generator, each calling your generator to fetch examples in advance. This helps parallelize data fetching on the CPU.
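Because of that threading, a plain Python generator shared between threads can fail with ValueError: generator already executing when two threads call next() on it at the same moment. A common workaround (not part of Keras and not taken from this thread) is to wrap the generator in an iterator that serializes next() with a lock. The names ThreadSafeIter and threadsafe_generator are illustrative.

```python
import functools
import threading


class ThreadSafeIter:
    """Wrap an iterator so that concurrent next() calls are serialized by a lock."""

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:  # only one thread advances the wrapped generator at a time
            return next(self.it)


def threadsafe_generator(f):
    """Decorator: calling f(...) returns a ThreadSafeIter around the generator."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        return ThreadSafeIter(f(*args, **kwargs))
    return wrapper
```

The lock makes sharing safe, but it also means only one thread loads data at a time; the benefit is that loading still overlaps with training rather than running in parallel with itself.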

From your code I understand you want to go through your whole dataset in one epoch of fit_generator. This makes sense, but unfortunately, if I got it right, the method wasn’t really designed that way. You have two ways of doing it:

  • fetch random batches indefinitely in your while True: loop
  • fetch batches by indexing your dataset, and set steps_per_epoch so that an epoch stops exactly at the end of your data
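For the second option, the number of steps follows from the dataset size and the batch size. In Keras 2, fit_generator takes steps_per_epoch, epochs and validation_steps, counted in batches, in place of the Keras 1 samples_per_epoch, nb_epoch and nb_val_samples used in the question. A sketch of the arithmetic for the question's numbers (113526 training and 20001 validation samples, batches of 128); the helper name steps_for is hypothetical:

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to cover num_samples once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


train_steps = steps_for(113526, 128)  # 887
val_steps = steps_for(20001, 128)     # 157

# With the Keras 2 argument names this would be passed roughly as:
# model.fit_generator(tgen, steps_per_epoch=train_steps, epochs=6,
#                     validation_data=vgen, validation_steps=val_steps)
```

Rounding up assumes the generator yields a final, smaller batch at the end of the data. With a generator that always yields full batches, 887 × 128 = 113536 samples would run 10 samples into the next pass; rounding down to 886 steps and skipping the remaining 118 samples is the alternative.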

I opted for the latter, and it works well, but be careful with the threading nature of the method: it may try to fetch data outside your range, hence the condition in my example below:

def my_generator(data, labels, indices, batch_size, steps):
    """Generator used by `keras.models.Sequential.fit_generator` to yield batches
    of pairs.

    Such a generator is required by the parallel nature of the aforementioned
    Keras function. It can theoretically feed batches of pairs indefinitely
    (looping over the dataset). Ideally, it would be called so that an epoch ends
    exactly with the last batch of the dataset.
    """
    i = 1
    while True:
        (batch_pairs, batch_labels) = fetch_batch(i, data, labels,
                                                  indices, batch_size)
        if i == steps:
            i = 1 # avoids going too far in the data
            # will preload the first batches for the next epoch
        else:
            i += 1 # go for the next batch
        yield [batch_pairs[:, 0], batch_pairs[:, 1]], batch_labels

Don’t mind my fetch_batch function; it basically indexes a batch of data using the index i.
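The thread does not show fetch_batch. One plausible version is sketched below, assuming data is a NumPy array of pairs with shape (samples, 2, ...) (which the batch_pairs[:, 0] and batch_pairs[:, 1] indexing suggests), i is 1-indexed as in my_generator, and indices is an array giving the order in which samples are visited (for example a shuffled permutation). The body and the meaning of indices are assumptions, not the commenter's code.

```python
import numpy as np


def fetch_batch(i, data, labels, indices, batch_size):
    """Hypothetical helper: return batch number i (1-indexed) of data and labels.

    Batch i covers positions (i - 1) * batch_size up to i * batch_size in
    `indices`, so the last batch may be smaller than batch_size.
    """
    batch_idx = indices[(i - 1) * batch_size : i * batch_size]
    return data[batch_idx], labels[batch_idx]


# Example: 10 pairs of 3-feature vectors, visited in a shuffled order.
rng = np.random.default_rng(0)
data = rng.normal(size=(10, 2, 3))   # (samples, 2 members per pair, features)
labels = rng.integers(0, 2, size=10)
indices = rng.permutation(10)
pairs, y = fetch_batch(4, data, labels, indices, batch_size=3)
print(pairs.shape, y.shape)  # (1, 2, 3) (1,)
```

With 10 samples and batch_size=3, steps would be 4, and batch 4 is the single leftover sample; my_generator then resets i to 1 so the next call starts the next epoch from the beginning.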

0 reactions
stale[bot] commented, Dec 12, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.
