Bug in masking of output in K.rnn(..., unroll=False) (for tensorflow and cntk)
Summary
Outputs are not masked correctly in tensorflow_backend.rnn(..., unroll=False). The issue is that states[0] is assumed to be equal to the output of the step_function in this line (an assumption not made in the other backends or with unroll=True). The assumption happens to hold for the built-in RNN cells, which is why the bug has gone undetected, but especially since the introduction of output_size in the RNNCell API it is clear that it should not be assumed in general.
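The failure mode is easy to show outside of Keras. Below is a minimal NumPy sketch of the backend's masking loop (illustrative names, not the actual TensorFlow implementation); the toy step function mirrors the Cell used in the examples further down, where the output is the input itself and the state is incremented each step:

import numpy as np

def masked_rnn_buggy(step_fn, inputs, mask, initial_states):
    # inputs: (samples, timesteps, features); mask: (samples, timesteps)
    states = list(initial_states)
    outputs = []
    for t in range(inputs.shape[1]):
        output, new_states = step_fn(inputs[:, t], states)
        m = mask[:, t:t + 1]  # (samples, 1), broadcasts over features
        # BUG: masked steps fall back to states[0] instead of the
        # previous output -- only valid when output == states[0].
        output = np.where(m, output, states[0])
        states = [np.where(m, ns, s) for ns, s in zip(new_states, states)]
        outputs.append(output)
    return np.stack(outputs, axis=1), states

step = lambda x_t, s: (x_t, [si + 1 for si in s])
x = np.array([[[1.], [2.], [0.]]])
mask = np.array([[True, True, False]])  # last step masked
outs, final = masked_rnn_buggy(step, x, mask, [np.array([[10.]])])
print(outs[0, -1], final[0])  # [12.] [[12.]] -- the expected last output is [2.]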
Implications

- RNN returns the wrong output when a mask is used and the output is not equal to states[0] but has the same size, i.e. a quiet error:
import numpy as np
from numpy.testing import assert_allclose

import keras
from keras.layers import Input, Masking, recurrent
from keras.models import Model

class Cell(keras.layers.Layer):
    def __init__(self):
        self.state_size = None
        self.output_size = None
        super(Cell, self).__init__()

    def build(self, input_shape):
        self.state_size = input_shape[-1]
        self.output_size = input_shape[-1]

    def call(self, inputs, states):
        # output is the input itself; states[0] is incremented each step
        return inputs, [s + 1 for s in states]

x = Input((3, 1), name="x")
x_masked = Masking()(x)
s_0 = Input((1,), name="s_0")
y, s = recurrent.RNN(Cell(),
                     return_state=True,
                     unroll=False)(x_masked, initial_state=s_0)
model = Model([x, s_0], [y, s])
model.compile(optimizer='sgd', loss='mse')

# last time step masked
x_arr = np.array([[[1.], [2.], [0.]]])
s_0_arr = np.array([[10.]])
y_arr, s_arr = model.predict([x_arr, s_0_arr])

# 1 is added to the initial state two times
assert_allclose(s_arr, s_0_arr + 2)
# expect the last output to be the same as the last output before masking
assert_allclose(y_arr, x_arr[:, 1, :])  # Fails!
Gives:

AssertionError:
Not equal to tolerance rtol=1e-07, atol=0

(mismatch 100.0%)
 x: array([12.], dtype=float32)
 y: array([2.])

The returned last output is 12.0, the value of states[0] at the final (masked) step, instead of 2.0, the last output before masking.
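As a cross-check (reusing Cell, the model inputs, and the arrays from above), the identical model built with unroll=True returns the expected values, consistent with the claim that the unrolled code path does not make this assumption:

y_u, s_u = recurrent.RNN(Cell(),
                         return_state=True,
                         unroll=True)(x_masked, initial_state=s_0)
model_u = Model([x, s_0], [y_u, s_u])
y_arr_u, s_arr_u = model_u.predict([x_arr, s_0_arr])
assert_allclose(s_arr_u, s_0_arr + 2)
assert_allclose(y_arr_u, x_arr[:, 1, :])  # Passes with unroll=True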
- An exception is raised when trying to apply an RNN with a cell whose output_size != state_size[0]:
class Cell(keras.layers.Layer):
    def __init__(self):
        self.state_size = None
        self.output_size = None
        super(Cell, self).__init__()

    def build(self, input_shape):
        self.state_size = input_shape[-1]
        self.output_size = input_shape[-1] * 2  # output twice as wide as the state

    def call(self, inputs, states):
        return keras.layers.concatenate([inputs] * 2), [s + 1 for s in states]

x = Input((3, 1), name="x")
x_masked = Masking()(x)
s_0 = Input((1,), name="s_0")
y, s = recurrent.RNN(Cell(),
                     return_state=True,
                     unroll=False)(x_masked, initial_state=s_0)  # Fails!
Gives:

ValueError: Dimension 1 in both shapes must be equal, but are 2 and 1. Shapes are [?,2] and [?,1]. for 'rnn_1/while/Select' (op: 'Select') with input shapes: [?,?], [?,2], [?,1].

The Select op that implements the masking requires the step output and states[0] to have the same shape, which no longer holds here.
Comments
I’ve put minor because you’ve already made a PR for this bug. But of course this is not something to overlook; let me change the tag. Also, thanks for your quick reaction and PR on this bug!
Yes, thanks, that was the fix I made in the PR I opened together with this issue 😉
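For completeness, the idea of the fix can be shown on the NumPy sketch from above: track the previous output as its own loop variable, initialized to zeros before the first step, instead of reading it from states[0]. This mirrors the intent of the PR, not its exact diff:

def masked_rnn_fixed(step_fn, inputs, mask, initial_states):
    states = list(initial_states)
    prev_output = None
    outputs = []
    for t in range(inputs.shape[1]):
        output, new_states = step_fn(inputs[:, t], states)
        if prev_output is None:
            prev_output = np.zeros_like(output)  # before the first step
        m = mask[:, t:t + 1]
        # Fixed: carry the previous output, whatever its shape.
        output = np.where(m, output, prev_output)
        states = [np.where(m, ns, s) for ns, s in zip(new_states, states)]
        prev_output = output
        outputs.append(output)
    return np.stack(outputs, axis=1), states

# Reusing step, x and mask from the sketch above:
outs, final = masked_rnn_fixed(step, x, mask, [np.array([[10.]])])
print(outs[0, -1], final[0])  # [2.] [[12.]] -- as expected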