Using tf.data.Dataset API
I was trying to use the tf.data Dataset API with Keras, but I am getting weird errors. Here is my code:
def data_gen(X=None, y=None, batch_size=32, nb_epochs=1, sess=None):
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.cast(tf.image.decode_jpeg(image_string), tf.float32)
        image_decoded = (image_decoded / tf.constant(127.5)) - tf.constant(1.)
        image_resized = tf.image.resize_images(image_decoded, [224, 224])
        return image_resized, label

    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(batch_size).repeat(nb_epochs)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    for i in range(nb_epochs):
        sess.run(iterator.initializer)
        while True:
            try:
                nxb, nxl = sess.run(next_element)
                nxl = to_categorical(nxl, num_classes=10)
                yield nxb, nxl
            except tf.errors.OutOfRangeError:
                break
train_images = tf.constant(train_df['image'].values)
train_labels = tf.constant([labels_dict[l] for l in train_df['label'].values])
valid_images = tf.constant(valid_df['image'].values)
valid_labels = tf.constant([labels_dict[l] for l in valid_df['label'].values])

sess = K.get_session()
model = get_model()

train_gen = data_gen(X=train_images, y=train_labels, nb_epochs=10, sess=sess)
valid_gen = data_gen(X=valid_images, y=valid_labels, nb_epochs=10, sess=sess)

batch_size = 32
nb_train_steps = train_images.shape.num_elements() // batch_size
nb_valid_steps = valid_images.shape.num_elements() // batch_size

# Fit the model
model.fit_generator(train_gen, steps_per_epoch=nb_train_steps, validation_data=valid_gen, validation_steps=nb_valid_steps)
The last line throws this error:
model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)
Epoch 1/1
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-32-be01c033aa94> in <module>()
----> 1 model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)
/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
89 warnings.warn('Update your `' + object_name +
90 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91 return func(*args, **kwargs)
92 wrapper._original_function = func
93 return wrapper
/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
1415 use_multiprocessing=use_multiprocessing,
1416 shuffle=shuffle,
-> 1417 initial_epoch=initial_epoch)
1418
1419 @interfaces.legacy_generator_methods_support
/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
153 batch_index = 0
154 while steps_done < steps_per_epoch:
--> 155 generator_output = next(output_generator)
156
157 if not hasattr(generator_output, '__len__'):
/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/utils/data_utils.py in get(self)
791 success, value = self.queue.get()
792 if not success:
--> 793 six.reraise(value.__class__, value, value.__traceback__)
/opt/conda/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
691 if value.__traceback__ is not tb:
692 raise value.with_traceback(tb)
--> 693 raise value
694 finally:
695 value = None
/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/utils/data_utils.py in _data_generator_task(self)
656 # => Serialize calls to
657 # infinite iterator/generator's next() function
--> 658 generator_output = next(self._generator)
659 self.queue.put((True, generator_output))
660 else:
<ipython-input-18-a1077a5b59f8> in data_gen(X, y, batch_size, nb_epochs, sess)
11 dataset = dataset.map(_parse_function)
12 dataset = dataset.batch(batch_size).repeat(nb_epochs)
---> 13 iterator = dataset.make_initializable_iterator()
14 next_element = iterator.get_next()
15
/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in make_initializable_iterator(self, shared_name)
106 sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
107 with ops.colocate_with(iterator_resource):
--> 108 initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
109 iterator_resource)
110 return iterator_ops.Iterator(iterator_resource, initializer,
/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_variant_tensor(self)
1402 def _as_variant_tensor(self):
1403 return gen_dataset_ops.repeat_dataset(
-> 1404 self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
1405 count=self._count,
1406 output_shapes=nest.flatten(
/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_variant_tensor(self)
1646 sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
1647 output_types=nest.flatten(
-> 1648 sparse.as_dense_types(self.output_types, self.output_classes)))
1649
1650 @property
/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py in batch_dataset(input_dataset, batch_size, output_types, output_shapes, name)
54 _, _, _op = _op_def_lib._apply_op_helper(
55 "BatchDataset", input_dataset=input_dataset, batch_size=batch_size,
---> 56 output_types=output_types, output_shapes=output_shapes, name=name)
57 _result = _op.outputs[:]
58 _inputs_flat = _op.inputs
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
348 # Need to flatten all the arguments into a list.
349 # pylint: disable=protected-access
--> 350 g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
351 # pylint: enable=protected-access
352 except AssertionError as e:
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _get_graph_from_inputs(op_input_list, graph)
5651 graph = graph_element.graph
5652 elif original_graph_element is not None:
-> 5653 _assert_same_graph(original_graph_element, graph_element)
5654 elif graph_element.graph is not graph:
5655 raise ValueError("%s is not from the passed-in graph." % graph_element)
/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _assert_same_graph(original_item, item)
5587 if original_item.graph is not item.graph:
5588 raise ValueError("%s must be from the same graph as %s." % (item,
-> 5589 original_item))
5590
5591
ValueError: Tensor("batch_size:0", shape=(), dtype=int64) must be from the same graph as Tensor("MapDataset_3:0", shape=(), dtype=variant).
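The ValueError says that the batch_size scalar lives in a different graph from the MapDataset it is being combined with. A generator's body does not run until next() is first called, and fit_generator calls next() from a background worker thread, so all of the dataset ops inside data_gen are constructed lazily, in whatever graph is current at that point, rather than in the graph that owns train_images and the Keras session. One way around this is to build the entire pipeline once, up front, and keep only sess.run calls inside the generator. Below is a minimal sketch of that restructuring, not a confirmed fix from the thread: get_model, labels_dict, train_images, and nb_train_steps are taken from the snippet above, and the repeat()/steps_per_epoch bookkeeping is one possible choice.

```python
import tensorflow as tf
from keras import backend as K
from keras.utils import to_categorical

def build_batch_op(X, y, batch_size=32):
    """Builds the whole tf.data pipeline immediately, in the caller's
    graph, and returns the (images, labels) tensors for one batch."""
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.cast(tf.image.decode_jpeg(image_string), tf.float32)
        image_decoded = (image_decoded / 127.5) - 1.0
        image_resized = tf.image.resize_images(image_decoded, [224, 224])
        return image_resized, label

    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(batch_size).repeat()  # repeat forever; Keras counts steps
    return dataset.make_one_shot_iterator().get_next()

sess = K.get_session()

# Graph construction happens here, in the same graph as train_images.
next_train = build_batch_op(train_images, train_labels)

def train_gen():
    # Only session calls happen inside the generator, so it is safe to
    # drive it from Keras's worker thread.
    while True:
        nxb, nxl = sess.run(next_train)
        yield nxb, to_categorical(nxl, num_classes=10)

model = get_model()
model.fit_generator(train_gen(), steps_per_epoch=nb_train_steps, epochs=10)
```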
@fchollet When I try this way I get an error, which is caused by iterator.get_next() returning tensors with shape (?, ?). But I don't understand why the batch size is None, even with dataset.batch(BATCH_SIZE).
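For what it's worth, the (?, ?) static shape is expected: batch() leaves the batch dimension unknown because the last batch of an epoch may contain fewer than BATCH_SIZE elements, so the true size is only known at runtime. A minimal sketch of the difference, assuming TF 1.10+ for the drop_remainder argument:

```python
import tensorflow as tf

# batch() leaves the leading dimension unknown statically, because the
# last batch may hold fewer than 32 elements.
ds = tf.data.Dataset.range(100).batch(32)
print(ds.output_shapes)   # (?,)

# drop_remainder=True (TF 1.10+) discards the final partial batch,
# so the batch dimension becomes static.
ds = tf.data.Dataset.range(100).batch(32, drop_remainder=True)
print(ds.output_shapes)   # (32,)
```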
What do you mean?

fit_generator is no longer necessary if you are fitting from a TF dataset. Please open a new issue.
Use it if your data is already in Dataset format. One reason to use Dataset is that it may offer better performance than multi-process Python generators in some cases.
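For completeness, here is roughly what fitting straight from a dataset looks like with tf.keras (supported since around TF 1.9). This is a sketch under assumptions, not code from the thread: model is assumed to be a compiled tf.keras model, train_images/train_labels and nb_train_steps come from the question, and the to_categorical step is moved into the graph via tf.one_hot.

```python
import tensorflow as tf

def _parse_function(filename, label):
    # Same decoding as in the question, with one-hot encoding done
    # in-graph so no Python-side post-processing is needed.
    image = tf.cast(tf.image.decode_jpeg(tf.read_file(filename)), tf.float32)
    image = tf.image.resize_images(image / 127.5 - 1.0, [224, 224])
    return image, tf.one_hot(label, depth=10)

dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
dataset = dataset.map(_parse_function).batch(32).repeat()

# No generator and no fit_generator: the dataset is passed to fit directly.
model.fit(dataset, steps_per_epoch=nb_train_steps, epochs=10)
```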