Modification of network: How to handle multiple inputs
Hi, I am modifying this code to write a new function that conditions on both the action and the state to reduce variance (I know this may introduce bias, but it can be corrected later, as in the Q-Prop paper).
I have two batches of data, `action` and `observ`, with shapes `[batch_size, act_dim]` and `[batch_size, obs_dim]` respectively, and I want to feed them into `tf.nn.dynamic_rnn`. Since `tf.nn.dynamic_rnn` expects input of shape `[batch_size, max_time, input_size]`, we can pass `action[:, None]` and `observ[:, None]` instead to match that shape.
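As a quick sketch of that shape change (numpy is used here purely for illustration; the same `[:, None]` indexing works on TensorFlow tensors, and the sizes below are made up):

```python
import numpy as np

# Hypothetical sizes, for illustration only.
batch_size, act_dim, obs_dim = 4, 3, 5
action = np.zeros((batch_size, act_dim))
observ = np.zeros((batch_size, obs_dim))

# `[:, None]` inserts a time axis of length one, turning
# [batch_size, dim] into [batch_size, 1, dim].
print(action[:, None].shape)  # (4, 1, 3)
print(observ[:, None].shape)  # (4, 1, 5)
```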
What I want is to inherit from `tf.contrib.rnn.RNNCell` and process `action` and `observ` inside `__call__(self, input, state)`, so I really need to pass both `observ` and `action` separately rather than merging them first. However, I do not know how to pass two inputs to `tf.nn.dynamic_rnn`.
The documentation says that it accepts a tuple of tensors, so I pass `tuple_input = [action, observ]` and hope to retrieve `action` and `observ` inside `__call__` via `tuple_input[0]` and `tuple_input[1]`. However, an error occurs:

```
File "/opt/anaconda/envs/rl/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 547, in dynamic_rnn
  flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
File "/opt/anaconda/envs/rl/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 547, in <genexpr>
  flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
File "/opt/anaconda/envs/rl/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 67, in _transpose_batch_time
  (x, x_static_shape))
ValueError: Expected input tensor Tensor("main/CheckNumerics_1:0", shape=(2,), dtype=float32, device=/device:CPU:0) to have rank at least 2, but saw shape: (2,)
```
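For context on what the traceback is checking: `dynamic_rnn` flattens the input structure and converts each flat tensor from batch-major to time-major, which requires every tensor to have rank at least 2. The following is only a simplified numpy analogue of that internal transpose (the real `_transpose_batch_time` is internal to TensorFlow; behavior here is inferred from the traceback), showing why a rank-1 tensor like the shape-`(2,)` one above is rejected:

```python
import numpy as np

def transpose_batch_time(x):
    # Simplified numpy analogue of the internal batch/time transpose:
    # swap axis 0 (batch) and axis 1 (time).
    if x.ndim < 2:
        raise ValueError(
            "Expected input tensor to have rank at least 2, "
            "but saw shape: %s" % (x.shape,))
    return np.swapaxes(x, 0, 1)

batch_major = np.zeros((4, 1, 3))        # [batch, time, dim]
time_major = transpose_batch_time(batch_major)
print(time_major.shape)                  # (1, 4, 3)

# A rank-1 tensor, like the shape-(2,) one in the traceback, fails:
try:
    transpose_batch_time(np.zeros((2,)))
except ValueError as e:
    print(e)
```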
It seems that I cannot pass a tuple. Could you please suggest how to handle multiple inputs for `tf.nn.dynamic_rnn`? The actual code using `tf.nn.dynamic_rnn()` is as follows:

```python
tuple_input = [action[:, None], observ[:, None]]
cell = self._config.network(self._batch_env.action.shape[1].value)
(mean, logstd, value), state = tf.nn.dynamic_rnn(
    cell, tuple_input, length, state, tf.float32, swap_memory=True)
```
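To make the intended data flow concrete, here is a minimal numpy emulation (not TensorFlow code; the function and cell names are made up) of how a dynamic-RNN-style loop would slice a tuple of batch-major inputs per time step and hand the per-step tuple to the cell's `__call__`:

```python
import numpy as np

def toy_dynamic_rnn(cell, inputs, state):
    """Minimal numpy emulation of a dynamic_rnn-style loop over a
    tuple of batch-major inputs, each shaped [batch, time, dim]."""
    action, observ = inputs
    outputs = []
    for t in range(action.shape[1]):
        # The cell sees one time step of each tensor in the tuple.
        out, state = cell((action[:, t], observ[:, t]), state)
        outputs.append(out)
    return np.stack(outputs, axis=1), state

def toy_cell(obsact, state):
    act_t, obs_t = obsact
    # Toy "value": sum of all features at this step.
    return act_t.sum(axis=1) + obs_t.sum(axis=1), state

action = np.ones((4, 1, 3))   # [batch, time=1, act_dim]
observ = np.ones((4, 1, 5))   # [batch, time=1, obs_dim]
values, _ = toy_dynamic_rnn(toy_cell, (action, observ), state=None)
print(values.shape)           # (4, 1)
```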
And this is how we inherit from `tf.contrib.rnn.RNNCell`:

```python
class NewNetwork(tf.contrib.rnn.RNNCell):
  """Inherited RNN network conditioning on both observation and action."""

  def __init__(
      self, layers, action_size,
      mean_weights_initializer=_MEAN_WEIGHTS_INITIALIZER,
      logstd_initializer=_LOGSTD_INITIALIZER):
    self._layers = layers
    self._action_size = action_size
    self._mean_weights_initializer = mean_weights_initializer
    self._logstd_initializer = logstd_initializer

  @property
  def state_size(self):
    unused_state_size = 1
    return unused_state_size

  @property
  def output_size(self):
    return tf.TensorShape([])

  def __call__(self, obsact, state):
    with tf.variable_scope('network'):
      observation = obsact[0]
      action = obsact[1]
      x = tf.contrib.layers.flatten(observation)
      y = tf.contrib.layers.flatten(action)
      for size in self._layers:  # was `self._stein_layers`, which is never set
        x = tf.contrib.layers.fully_connected(x, size, tf.nn.relu)
        y = tf.contrib.layers.fully_connected(y, size, tf.nn.relu)
      # tf.concat takes a list of tensors; join the features along axis 1.
      xy = tf.concat([x, y], axis=1)
      value = tf.contrib.layers.fully_connected(xy, 1, None)[:, 0]
      return value, state
```
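A note on the concatenation step in the cell above: for two per-example feature tensors of shape `[batch, features]`, joining the features calls for `axis=1`; concatenating along `axis=0` would stack the two batches instead. A small numpy sketch of the difference (sizes are made up):

```python
import numpy as np

x = np.zeros((4, 16))  # [batch, x_features]
y = np.zeros((4, 8))   # [batch, y_features]

# axis=1 joins the feature dimensions: one row per example.
xy = np.concatenate([x, y], axis=1)
print(xy.shape)  # (4, 24)

# axis=0 stacks batches instead (and requires equal feature sizes),
# which is not what a per-example value head expects.
stacked = np.concatenate([x, np.zeros((4, 16))], axis=0)
print(stacked.shape)  # (8, 16)
```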
Issue Analytics
- Created 6 years ago
- Comments: 7 (2 by maintainers)
@danijar Hi, when I try a tuple input (batch x time x observation, batch x time x action) in `tf.nn.dynamic_rnn()`, I get the error: `Shape (2, 11, 256) must have rank 2`… Is this the correct way?
Did you try the last release? @danijar changed it so it no longer directly inherits from RNNCell; instead, it uses a function.