
[Bug] [rllib] Attention and FrameStackingModel work poorly


Search before asking

  • I searched the issues and found no similar issues.

Ray Component

RLlib

What happened + What you expected to happen

I’ve been experimenting with Ray RLlib’s StatelessCartPole environment, where some observations are hidden, and with different options for how to deal with these partial observations. I noticed two problems:

  • Frame stacking inside the model via the FrameStackingCartPoleModel works much worse than frame stacking in the environment with the FrameStack wrapper. I expected the two to do the same thing and perform roughly equally well.
  • Using attention with default parameters does not improve learning at all. I expected attention to help a lot here, as suggested in the attention net example.

See reproduction scripts for details.
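
For context, StatelessCartPole hides the velocity components of CartPole's 4-dim observation, leaving only cart position and pole angle. A minimal, illustrative wrapper that does the equivalent (my own sketch, not RLlib's implementation, which lives in ray.rllib.examples.env.stateless_cartpole):

import gym
import numpy as np


class VelocityMaskedCartPole(gym.ObservationWrapper):
    """Hides cart and pole velocities, keeping only position and angle."""

    def __init__(self, env):
        super().__init__(env)
        # Keep dims 0 (cart position) and 2 (pole angle) of the 4-dim obs.
        low = env.observation_space.low[[0, 2]]
        high = env.observation_space.high[[0, 2]]
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def observation(self, obs):
        return obs[[0, 2]].astype(np.float32)


env = VelocityMaskedCartPole(gym.make("CartPole-v1"))
print(env.reset().shape)  # (2,)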

There’s also a discussion on discourse: https://discuss.ray.io/t/lstm-and-attention-on-stateless-cartpole/4293/3

Here’s my blog post with a notebook and details: https://stefanbschneider.github.io/blog/rl-partial-observability

Versions / Dependencies

  • Python 3.8 on Windows 10
  • Gym 0.21.0 (had errors with frame stacking on lower versions!)
  • Ray 2.0.0.dev0: latest wheels from Dec 1, 2021; commit on master 0467bc9

Reproduction script

Details for reproducibility: I’m experimenting with PPO, the default config, 10 training iterations, and 3 CPUs.

Default PPO on StatelessCartPole: Reward 51

By default (no frame stacking, no attention), PPO does not work well on StatelessCartPole, as expected. For me, it reaches a reward of 51 after 10 training iterations; code below:

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry


registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ray.init(num_cpus=3)

# Copy the defaults so the shared DEFAULT_CONFIG dict is not mutated in place.
config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "StatelessCartPole"

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)
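
The reward numbers quoted here can be read off the Tune console output; programmatically, something like this should work on the ExperimentAnalysis object returned by tune.run (a hedged sketch):

# Read the final mean episode reward from the Tune results.
best_trial = results.get_best_trial(metric="episode_reward_mean", mode="max")
print(best_trial.last_result["episode_reward_mean"])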

PPO + Attention on StatelessCartPole: Reward 39 -> Bug?

Now, the same thing with attention enabled and otherwise default params. I expected a much higher reward than without attention, but it’s almost the same - even slightly worse! Why?

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry


registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "StatelessCartPole"
config["model"] = {
    # attention
    "_use_default_native_models": True,
    "use_attention": True,
}

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)

Even adding extra model params from the attention net example doesn’t help:

# extra model config from the attention_net.py example
config["model"].update({
    "max_seq_len": 10,
    "attention_num_transformer_units": 1,
    "attention_dim": 32,
    "attention_memory_inference": 10,
    "attention_memory_training": 10,
    "attention_num_heads": 1,
    "attention_head_dim": 32,
    "attention_position_wise_mlp_dim": 32,
})

Stacked Frames in Env: Reward 202

Now, stacking frames in the environment with gym.wrappers.FrameStack works really well, even without attention. For me, the reward reaches 202 after 10 training iterations (with 10 stacked frames). The code is identical except for registering a StackedStatelessCartPole env with stacked frames:

from gym.wrappers import FrameStack

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry


NUM_FRAMES = 10

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())
registry.register_env("StackedStatelessCartPole",
                      lambda _: FrameStack(StatelessCartPole(), NUM_FRAMES))


ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "StackedStatelessCartPole"

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)
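
As a quick sanity check (hedged; based on gym 0.21's FrameStack behavior, continuing from the imports in the script above), the wrapper turns the 2-dim stateless observation into a (NUM_FRAMES, 2) stack:

import numpy as np

env = StatelessCartPole()
print(env.observation_space.shape)        # (2,): cart position, pole angle
stacked = FrameStack(env, NUM_FRAMES)
print(stacked.observation_space.shape)    # (10, 2)
print(np.asarray(stacked.reset()).shape)  # (10, 2); reset returns LazyFrames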

Stacked Frames in Model: Reward 105 --> Bug?

Now, the same thing - stacking 10 frames - but within the model rather than the environment. Surprisingly, this leads to a reward that’s only about half as high: 95 after 10 iterations. Why? Isn’t this supposed to do the same thing?

import ray
from ray import tune
from ray.rllib.agents import ppo
from ray.rllib.examples.env.stateless_cartpole import StatelessCartPole
from ray.tune import registry
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.examples.models.trajectory_view_utilizing_models import FrameStackingCartPoleModel


NUM_FRAMES = 10

registry.register_env("StatelessCartPole", lambda _: StatelessCartPole())

ModelCatalog.register_custom_model("stacking_model", FrameStackingCartPoleModel)


ray.init(num_cpus=3)

config = ppo.DEFAULT_CONFIG.copy()
config["env"] = "StatelessCartPole"
config["model"] = {
    "custom_model": "stacking_model",
    "custom_model_config": {
        "num_frames": NUM_FRAMES,
    },
}

stop = {"training_iteration": 10}
results = tune.run("PPO", config=config, stop=stop)
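
For reference, the FrameStackingCartPoleModel does not stack frames in the environment; it asks RLlib's trajectory view API for the last num_frames observations. A simplified sketch of that mechanism (the real model is in ray.rllib.examples.models.trajectory_view_utilizing_models; the class name here is my own):

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.policy.view_requirement import ViewRequirement


class FrameStackingSketch(TFModelV2):
    """Sketch: request the last `num_frames` observations via the
    trajectory view API instead of stacking frames in the env."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        num_frames = model_config["custom_model_config"]["num_frames"]
        # Ask the sampler for observations at shifts -(num_frames-1)..0.
        self.view_requirements["prev_n_obs"] = ViewRequirement(
            data_col="obs",
            shift="-{}:0".format(num_frames - 1),  # e.g. "-9:0" for 10 frames
            space=obs_space,
        )

    def forward(self, input_dict, state, seq_lens):
        # input_dict["prev_n_obs"] arrives with shape [B, num_frames, obs_dim],
        # zero-padded at episode starts; a dense net over the flattened stack
        # goes here (elided in this sketch).
        ...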

TL;DR: Why does attention not help? Why does frame stacking within the model lead to much worse results than within the env?

Anything else

I think these are really two issues (or misunderstandings/misconfigurations from my side):

  • Why is frame stacking in the model worse than in the environment and how to fix it?
  • Why does attention not lead to better results here?

Maybe this is just an issue with the default configuration in this scenario…

I’m also willing to submit a PR, but am not sure where to start looking. I’ll post any new insights in the comments.

Are you willing to submit a PR?

  • Yes I am willing to submit a PR!

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 10 (10 by maintainers)

Top GitHub Comments

sven1977 commented, Jan 18, 2022 (2 reactions)

I’m closing this issue. Feel free to re-open should you think that the above answers are not sufficient 😃

@stefanbschneider @mickelliu

mickelliu commented, Jan 18, 2022 (2 reactions)

Thanks @sven1977. It could be because attention nets and LSTMs are notoriously harder to train. In our custom CV environment in particular, we found that the LSTM’s stagnant performance was mainly attributable to the shared layers between the value and actor networks; separating the two branches helped the LSTM models converge. And as for GTrXL, the original paper used 12 transformer layers and trained for billions of timesteps to make it work in their environments…
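
In RLlib, separating the two branches as described roughly corresponds to disabling shared value/policy layers in the model config (a hedged sketch; key names as of the Ray version used above):

from ray.rllib.agents import ppo

config = ppo.DEFAULT_CONFIG.copy()
config["model"] = {
    "vf_share_layers": False,  # separate value and policy branches
    "use_lstm": True,          # or "use_attention": True for the GTrXL wrapper
}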
