Dueling network with atari
I'm using the dueling DQN with an Atari environment. After creating the neural network, I get an error:
```python
model = Sequential()
model.add(Lambda(lambda a: a / 255.0, input_shape=(minecraft_resolution[0], minecraft_resolution[1], 3)))
model.add(Permute((3, 1, 2)))
model.add(Conv2D(32, (8, 8), strides=(2, 2), activation=activation))
model.add(Conv2D(32, (4, 4), strides=(2, 2), activation=activation))
model.add(Conv2D(32, (3, 3), strides=(2, 2), activation=activation))
model.add(Conv2D(32, (2, 2), strides=(1, 1), activation=activation))
model.add(TimeDistributed(Flatten()))
model.add(LSTM(128))
for i in xrange(nb_layers):
    model.add(Dense(hidden_size, activation=activation))
model.add(Dense(env.action_space.n + 1))
model.add(Lambda(lambda a: K.expand_dims(a[:, 0], axis=-1) + a[:, 1:], output_shape=(env.action_space.n,)))
# …
memory = SequentialMemory(limit=10000000, window_length=1)
agent = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10, enable_dueling_network=True, dueling_type='avg', target_model_update=1e-3, policy=policy, processor=processor)
agent.compile(optimizer, metrics=['mse'])
# …
```
`ValueError: Error when checking : expected input_1 to have 4 dimensions, but got array with shape (1, 1, 200, 200, 3)`
I think it's because the neural network is trained with a batch.
With `def process_observation(self, observation): return(imresize(observation,(200,200)).shape)` the shape is (200, 200, 3).
Top GitHub Comments
You can remove the unnecessary dimension(s) by using your own processor.
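A minimal sketch of such a processor, under a few assumptions: keras-rl passes each state batch through the processor's `process_state_batch` method before feeding the model, NumPy is available, and the class name `SqueezeWindowProcessor` is hypothetical. In real code the class would subclass `rl.core.Processor`; it is written as a plain class here so the snippet runs without keras-rl installed. It only helps if the model itself expects 4-D input.

```python
import numpy as np


class SqueezeWindowProcessor:
    """Hypothetical processor that drops a length-1 window axis.

    In a keras-rl project this would subclass rl.core.Processor (assumption:
    the agent calls process_state_batch on every batch it builds).
    """

    def process_state_batch(self, batch):
        batch = np.asarray(batch)
        # (batch, window_length=1, H, W, C) -> (batch, H, W, C)
        if batch.ndim == 5 and batch.shape[1] == 1:
            return np.squeeze(batch, axis=1)
        # Leave anything else untouched rather than guessing.
        return batch


processor = SqueezeWindowProcessor()
states = np.zeros((1, 1, 200, 200, 3))
print(processor.process_state_batch(states).shape)  # (1, 200, 200, 3)
```

The check on `shape[1] == 1` means the axis is only removed when `window_length=1`; with a longer window the batch is returned unchanged.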
The problem is that the emulator returns observations of shape (1,1,224,224,3), which, when passed to your input (shape ?,224,224,3), causes a ValueError due to the dimension mismatch. Try changing the input shape so that your input expects a shape of (?,1,224,224,3). Note that the extra dimension we add here represents the window of observations stored in the memory, which we set using the `window_length=1` parameter in SequentialMemory.
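How `window_length` feeds into the shape the model has to accept can be sketched in plain Python. The helper names `model_input_shape` and `state_batch_shape` are hypothetical and used only for illustration; the sketch assumes the window axis comes right after the batch axis, as in the shapes quoted above.

```python
def model_input_shape(window_length, observation_shape):
    """Per-sample input shape: the window axis, then the observation axes."""
    return (window_length,) + tuple(observation_shape)


def state_batch_shape(batch_size, window_length, observation_shape):
    """Shape of the array the agent hands to the model for one batch."""
    return (batch_size,) + model_input_shape(window_length, observation_shape)


# With window_length=1 and 224x224 RGB observations:
print(model_input_shape(1, (224, 224, 3)))     # (1, 224, 224, 3)
print(state_batch_shape(1, 1, (224, 224, 3)))  # (1, 1, 224, 224, 3)
```

The first tuple is what would go into the first layer's `input_shape`; the second is the 5-D array shape that appears in the error message.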