Stuck on an issue?

Lightrun Answers was designed to reduce the constant googling that comes with debugging third-party libraries. It collects links to all the places you might be looking while hunting down a tough bug.

And, if you’re still stuck at the end, we’re happy to hop on a call to see how we can help out.

Using tf.data.Dataset API

See original GitHub issue

I was trying to use the tf.data Dataset API with Keras, but I am getting weird errors. Here is my code:


import tensorflow as tf
from keras import backend as K
from keras.utils import to_categorical

def data_gen(X=None, y=None, batch_size=32, nb_epochs=1, sess=None):
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.cast(tf.image.decode_jpeg(image_string), tf.float32)
        image_decoded = (image_decoded / tf.constant(127.5)) - tf.constant(1.)
        image_resized = tf.image.resize_images(image_decoded, [224, 224])
        
        return image_resized, label
    
    dataset = tf.data.Dataset.from_tensor_slices((X,y))
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(batch_size).repeat(nb_epochs)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    
    for i in range(nb_epochs):
        sess.run(iterator.initializer)
        while True:
            try:
                nxb, nxl = sess.run(next_element)
                nxl = to_categorical(nxl, num_classes=10)
                yield nxb, nxl
            except tf.errors.OutOfRangeError:
                break


train_images = tf.constant(train_df['image'].values)  
train_labels = tf.constant([labels_dict[l] for l in train_df['label'].values])

valid_images = tf.constant(valid_df['image'].values)
valid_labels = tf.constant([labels_dict[l] for l in valid_df['label'].values])

sess = K.get_session()
model = get_model()

train_gen = data_gen(X=train_images, y=train_labels, nb_epochs=10, sess=sess)
valid_gen = data_gen(X=valid_images, y=valid_labels, nb_epochs=10, sess=sess)

batch_size=32
nb_train_steps = train_images.shape.num_elements() // batch_size
nb_valid_steps = valid_images.shape.num_elements() // batch_size

# Fit the model
model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)

The last line throws this error:

model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)

Epoch 1/1

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-32-be01c033aa94> in <module>()
----> 1 model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1415             use_multiprocessing=use_multiprocessing,
   1416             shuffle=shuffle,
-> 1417             initial_epoch=initial_epoch)
   1418 
   1419     @interfaces.legacy_generator_methods_support

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    153             batch_index = 0
    154             while steps_done < steps_per_epoch:
--> 155                 generator_output = next(output_generator)
    156 
    157                 if not hasattr(generator_output, '__len__'):

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/utils/data_utils.py in get(self)
    791             success, value = self.queue.get()
    792             if not success:
--> 793                 six.reraise(value.__class__, value, value.__traceback__)

/opt/conda/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/utils/data_utils.py in _data_generator_task(self)
    656                             # => Serialize calls to
    657                             # infinite iterator/generator's next() function
--> 658                             generator_output = next(self._generator)
    659                             self.queue.put((True, generator_output))
    660                         else:

<ipython-input-18-a1077a5b59f8> in data_gen(X, y, batch_size, nb_epochs, sess)
     11     dataset = dataset.map(_parse_function)
     12     dataset = dataset.batch(batch_size).repeat(nb_epochs)
---> 13     iterator = dataset.make_initializable_iterator()
     14     next_element = iterator.get_next()
     15 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in make_initializable_iterator(self, shared_name)
    106             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    107     with ops.colocate_with(iterator_resource):
--> 108       initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
    109                                                   iterator_resource)
    110     return iterator_ops.Iterator(iterator_resource, initializer,

/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_variant_tensor(self)
   1402   def _as_variant_tensor(self):
   1403     return gen_dataset_ops.repeat_dataset(
-> 1404         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
   1405         count=self._count,
   1406         output_shapes=nest.flatten(

/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_variant_tensor(self)
   1646             sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
   1647         output_types=nest.flatten(
-> 1648             sparse.as_dense_types(self.output_types, self.output_classes)))
   1649 
   1650   @property

/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py in batch_dataset(input_dataset, batch_size, output_types, output_shapes, name)
     54     _, _, _op = _op_def_lib._apply_op_helper(
     55         "BatchDataset", input_dataset=input_dataset, batch_size=batch_size,
---> 56         output_types=output_types, output_shapes=output_shapes, name=name)
     57     _result = _op.outputs[:]
     58     _inputs_flat = _op.inputs

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    348       # Need to flatten all the arguments into a list.
    349       # pylint: disable=protected-access
--> 350       g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    351       # pylint: enable=protected-access
    352     except AssertionError as e:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _get_graph_from_inputs(op_input_list, graph)
   5651         graph = graph_element.graph
   5652       elif original_graph_element is not None:
-> 5653         _assert_same_graph(original_graph_element, graph_element)
   5654       elif graph_element.graph is not graph:
   5655         raise ValueError("%s is not from the passed-in graph." % graph_element)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _assert_same_graph(original_item, item)
   5587   if original_item.graph is not item.graph:
   5588     raise ValueError("%s must be from the same graph as %s." % (item,
-> 5589                                                                 original_item))
   5590 
   5591 

ValueError: Tensor("batch_size:0", shape=(), dtype=int64) must be from the same graph as Tensor("MapDataset_3:0", shape=(), dtype=variant).
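
A plausible reading of the traceback: because data_gen is a generator function, its body does not run when it is called, but only when Keras first calls next() on it, inside a fit_generator worker thread. In TF 1.x, graph construction depends on the (thread-local) default graph, so part of the pipeline can end up in a different tf.Graph than the rest, which is exactly what the "must be from the same graph" error reports. A sketch of a workaround, untested and assuming TF 1.x with Keras 2.1: build the whole pipeline eagerly in the calling thread, inside the session's graph, and keep only the sess.run() loop lazy. Note that nb_epochs is dropped, since the outer while True re-initializes the iterator forever, which is what fit_generator expects.

import tensorflow as tf
from keras.utils import to_categorical

def data_gen(X=None, y=None, batch_size=32, sess=None):
    # Build every op up front, in the calling thread and in the
    # session's graph -- not lazily inside Keras's worker thread.
    with sess.graph.as_default():
        def _parse_function(filename, label):
            image_string = tf.read_file(filename)
            image = tf.cast(tf.image.decode_jpeg(image_string), tf.float32)
            image = image / 127.5 - 1.0
            return tf.image.resize_images(image, [224, 224]), label

        dataset = tf.data.Dataset.from_tensor_slices((X, y))
        dataset = dataset.map(_parse_function).batch(batch_size)
        iterator = dataset.make_initializable_iterator()
        next_element = iterator.get_next()

    def _iterate():
        while True:  # fit_generator expects a generator that never ends
            sess.run(iterator.initializer)
            while True:
                try:
                    nxb, nxl = sess.run(next_element)
                    yield nxb, to_categorical(nxl, num_classes=10)
                except tf.errors.OutOfRangeError:
                    break  # dataset exhausted; re-initialize and continue

    return _iterate()

The call sites stay the same apart from the removed nb_epochs argument, e.g. train_gen = data_gen(X=train_images, y=train_labels, sess=sess).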


Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Reactions: 1
  • Comments: 13 (3 by maintainers)

Top GitHub Comments

1 reaction
zhangyi02 commented on Jan 17, 2019

@fchollet

x, y = iterator.get_next()
model.fit(x, y, steps_per_epoch=steps_per_epoch, epochs=epochs)

When I try this way, I get this error:

When feeding symbolic tensors to a model, we expect the tensors to have a static batch size. Got tensor with shape: (None, None)

which is caused by iterator.get_next() returning tensors with shape (?, ?). But I don’t understand why the batch size is None, even with dataset.batch(BATCH_SIZE).
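
For what it’s worth, in TF 1.x Dataset.batch() leaves the leading dimension dynamic because the final batch may be smaller than BATCH_SIZE. Since roughly TF 1.10, the usual workaround has been drop_remainder=True, which drops the last partial batch and makes the batch dimension static. A minimal sketch, reusing the names from the comment above:

# Drop the final partial batch so the leading dimension is static.
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
x, y = dataset.make_one_shot_iterator().get_next()
print(x.shape)  # leading dimension is now BATCH_SIZE, not None
model.fit(x, y, steps_per_epoch=steps_per_epoch, epochs=epochs)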

1 reaction
fchollet commented on May 16, 2018

Will this work with fit_generator() too?

What do you mean? fit_generator is no longer necessary if you are fitting from a TF dataset.

I was using Eager execution with Keras, and so many things break in that. Should I open another issue, or post the issue in this thread only?

Please open a new issue.

Why would I need to use the Dataset API with Keras? Does it provide any functionality that fit_generator on its own does not? Thank you!

Use it if your data is already in Dataset format. One reason to use Dataset is that it may offer better performance than multi-process Python generators in some cases.
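
Putting that together with the pipeline from the question, a minimal end-to-end sketch under the same TF 1.x assumptions (_parse_function is the mapping function from the question; the num_parallel_calls and prefetch values are illustrative):

dataset = (tf.data.Dataset.from_tensor_slices((train_images, train_labels))
           .map(_parse_function, num_parallel_calls=4)        # parallel decode
           .map(lambda img, lbl: (img, tf.one_hot(lbl, 10)))  # in-graph one-hot
           .batch(32, drop_remainder=True)
           .repeat()      # run forever; epochs are counted by fit()
           .prefetch(1))  # overlap preprocessing with training
x, y = dataset.make_one_shot_iterator().get_next()
model.fit(x, y, steps_per_epoch=nb_train_steps, epochs=10)

The one-hot encoding moves into the graph (tf.one_hot) because to_categorical cannot run on symbolic tensors; prefetch() is the main lever behind the performance advantage mentioned above.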

Read more comments on GitHub >

Top Results From Across the Web

tf.data.Dataset | TensorFlow v2.11.0
Represents a potentially large set of elements.
Read more >
A Gentle Introduction to the tensorflow.data API
Yet another way of providing data is to use tf.data dataset. In this tutorial, you will see how you can use the tf.data...
Read more >
Building a data pipeline - CS230 Deep Learning
# An overview of tf.data ... The Dataset API allows you to build an asynchronous, highly optimized data pipeline to prevent your GPU...
Read more >
How to use Dataset in TensorFlow - Towards Data Science
I can now easily create a Dataset from it by calling tf.contrib.data.make_csv_dataset . Be aware that the iterator will create a dictionary with...
Read more >
Intro to Data Input Pipelines with tf.data - Kaggle
tf.data is indeed an API to build an ETL (Extract, Transform, and Load) pipeline to feed data to a tensorflow model. If you're...
Read more >
