
After Flatten(), both the batch and output shapes of a tensor become None


If I use Flatten() with the functional API, both the batch and output dimensions become None, whereas a Sequential model prints the correct output shape. I need both the batch and output sizes later on, and I have to use the functional API because of my model's complexity. Is this an issue in Keras? @fchollet @farizrahman4u @Dref360

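from keras.layers import Input, Conv2D, Flatten, InputLayer
from keras.models import Sequential

# Functional API
inputs = Input(batch_shape=[64, 224, 224, 3])
test = Conv2D(96, kernel_size=(11, 11), strides=(4, 4), activation='relu', name='conv1')(inputs)
print('b4 flatten shape', test.get_shape())
test = Flatten()(test)
print('after flatten shape', test.get_shape())

# Sequential model with the same layers
# (img_input is not defined in the original snippet; assumed to be the same batch shape)
img_input = [64, 224, 224, 3]
alexnet = Sequential()
alexnet.add(InputLayer(batch_input_shape=img_input))
alexnet.add(Conv2D(96, kernel_size=(11, 11), strides=(4, 4),
                   activation='relu', name='conv1'))
alexnet.add(Flatten())
print('Sequential model shape', alexnet.output_shape)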

The output I got:

b4 flatten shape (64, 30, 30, 96)
after flatten shape (?, ?)
Sequential model shape (64, 86400)
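For reference, the static shape Flatten() is expected to report follows from simple arithmetic: keep the batch dimension and multiply the remaining ones. A minimal pure-Python sketch (`flatten_output_shape` is a hypothetical helper, not a Keras API) that reproduces the Sequential result above:

```python
from math import prod
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None stands for an unknown dimension, as Keras prints it


def flatten_output_shape(shape: Sequence[Dim]) -> Tuple[Dim, Dim]:
    """Return the (batch, features) shape a Flatten layer should produce.

    The batch dimension is passed through unchanged (it may be None).
    If any non-batch dimension is unknown, the flattened size cannot be
    determined statically, so it is reported as None.
    """
    batch, *rest = shape
    if any(d is None for d in rest):
        return (batch, None)
    # prod([]) == 1, so a shape with no feature axis flattens to (batch, 1)
    return (batch, prod(rest))


print(flatten_output_shape((64, 30, 30, 96)))     # (64, 86400)
print(flatten_output_shape((None, 12, 100, 32)))  # (None, 38400)
```

Only an unknown non-batch dimension should make the flattened size unknown; a known batch size such as 64 has no reason to disappear, which is why the `(?, ?)` result above looks like a shape-inference gap rather than a real limitation. The sketch needs Python 3.8+ for `math.prod`.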

Issue Analytics

  • State: closed
  • Created 6 years ago
  • Reactions: 8
  • Comments: 16 (1 by maintainers)

Top GitHub Comments

4 reactions
thomasmooon commented, Jun 21, 2018

My current workaround is to query the shape before Flatten() is applied, like this:

# imports
import numpy as np
from keras.layers import Input, Dense, Flatten

# toy net
x = Input(shape = (12,100,10))
x = Dense(32)(x)
f = Flatten()(x)

# shape inference
shape_before_flatten = x.shape.as_list()[1:] # [1:] to skip None
shape_flatten = np.prod(shape_before_flatten) # value of shape in the non-batch dimension

# print
print("x = "+ str(x))
print("shape before Flatten() = " + str(shape_before_flatten))
print("shape after Flatten() = " + str(f.shape.as_list()))
print("shape_flatten in the non-batch dimension: " + str(shape_flatten))
x = Tensor("dense_17/BiasAdd:0", shape=(?, 12, 100, 32), dtype=float32)
shape before Flatten() = [12, 100, 32]
shape after Flatten() = [None, None]
shape_flatten in the non-batch dimension: 38400

4 reactions
demba commented, Apr 18, 2018

Hi, I am using Keras 2.1.5 and am seeing similar issues, with Flatten() and Reshape() not behaving as expected. The output of the same commands as above is:

Tensor("input_5:0", shape=(16, 10, 10), dtype=float32)
Tensor("flatten_15/Reshape:0", shape=(?, ?), dtype=float32)
Tensor("reshape_27/Reshape:0", shape=(?, ?), dtype=float32)

The workaround suggested in the previous comment doesn't work. My current workaround is:

import keras
from keras.layers import Input
x = Input(batch_shape=(16, 10, 10))
x = keras.layers.Reshape((100,))(x)
print(x)

which outputs:

Tensor("reshape_29/Reshape:0", shape=(16, 100), dtype=float32)
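The Reshape workaround gives the layer an explicit target size, which is likely why its static output shape can be determined from the layer arguments alone. A minimal pure-Python sketch of that inference, assuming the usual Keras conventions (the target excludes the batch axis and may contain at most one -1); `reshape_output_shape` is a hypothetical helper, not the Keras implementation:

```python
from math import prod
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None stands for an unknown dimension


def reshape_output_shape(input_shape: Sequence[Dim],
                         target: Sequence[int]) -> Tuple[Dim, ...]:
    """Infer the static output shape of a Reshape-style layer.

    `target` excludes the batch axis, as in Reshape((100,)). At most one
    entry may be -1; it is solved from the known input size. Other target
    entries are assumed to be positive integers.
    """
    batch, *rest = input_shape
    unknown = [i for i, d in enumerate(target) if d == -1]
    if len(unknown) > 1:
        raise ValueError("at most one -1 is allowed in the target shape")

    if any(d is None for d in rest):
        # Input size unknown: a -1 entry cannot be resolved statically.
        return (batch, *[None if d == -1 else d for d in target])

    total = prod(rest)                           # elements per sample
    known = prod(d for d in target if d != -1)   # product of fixed entries
    resolved = list(target)
    if unknown:
        if total % known != 0:
            raise ValueError(f"cannot reshape {total} elements into {tuple(target)}")
        resolved[unknown[0]] = total // known
    elif known != total:
        raise ValueError(f"cannot reshape {total} elements into {tuple(target)}")
    # The batch dimension is carried through untouched.
    return (batch, *resolved)


print(reshape_output_shape((16, 10, 10), (100,)))   # (16, 100)
print(reshape_output_shape((16, 10, 10), (4, -1)))  # (16, 4, 25)
```

Because the batch axis is never part of the target, a known batch size such as 16 survives the reshape, which matches the `(16, 100)` shape printed above.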

Top Results From Across the Web

keras Flatten giving wrong output shape - Stack Overflow
I am using keras Flatten() layer after a dropout layer whose output shape is (?,35,50). The output of Flatten() is (?,?) whereas it...
What is tensorflow flatten with Examples? - eduCBA
Flatten() function will be (size of the batch, 12). ... Flatten(set of input, collection of output: name of class = None, scope: name...
Flatten layer - Keras
Note: If inputs are shaped (batch,) without a feature axis, then flattening adds an extra channel dimension and output shape is (batch, 1)...
tf.data.Dataset | TensorFlow v2.11.0
Represents a potentially large set of elements.
Beginners Guide to Debugging TensorFlow Models - KDnuggets
This is because we are passing the input shape of (28,28) and 1 extra dimension added by TensorFlow for Batch size, so the...
