
Performance on CartPole-v0

See original GitHub issue

Hi Ku. I tried to use discrete SAC on CartPole-v0 with some modifications to env.py, and used an FC network in model.py. However, it failed to achieve reasonable results (e.g. reaching ~200 reward). Do you have any idea why discrete SAC fails with my code?

I list all the changes here (changes are marked with #):

# in env.py 
def make_gym(env_id):
    env = gym.make(env_id)
    return env

def make_pytorch_env(env_id, episode_life=True, clip_rewards=True,
                     frame_stack=True, scale=False):
    env = make_gym(env_id)
    return env
# in sacd.py 
    def explore(self, state):
        # Act with randomness.
        state = torch.ByteTensor(
            state[None, ...]).to(self.device).float()#
        with torch.no_grad():
            action, _, _ = self.policy.sample(state)
        return action.item()

    def exploit(self, state):
        # Act without randomness.
        state = torch.ByteTensor(
            state[None, ...]).to(self.device).float()#
        with torch.no_grad():
            action = self.policy.act(state)
        return action.item()
# in model.py 
class DQNBase(BaseNetwork):

    def __init__(self, num_channels):
        super(DQNBase, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(num_channels, 64),  #
            nn.ReLU(),
            nn.Linear(64, 64), #
            nn.ReLU(),
            nn.Linear(64, 32), #
            nn.ReLU(),
            Flatten(),
        ).apply(initialize_weights_he)

    def forward(self, states):
        return self.net(states)


class QNetwork(BaseNetwork):

    def __init__(self, num_channels, num_actions, shared=False,
                 dueling_net=False):
        super().__init__()

        if not shared:
            self.fc = DQNBase(num_channels)

        if not dueling_net:
            self.head = nn.Sequential(
                nn.Linear(32, 32),#
                nn.ReLU(inplace=True),
                nn.Linear(32, num_actions))#
        else:
            self.a_head = nn.Sequential(
                nn.Linear(32, 32),#
                nn.ReLU(inplace=True),
                nn.Linear(32, num_actions))#
            self.v_head = nn.Sequential(
                nn.Linear(32, 32),#
                nn.ReLU(inplace=True),
                nn.Linear(32, 1))#

        self.shared = shared
        self.dueling_net = dueling_net

    def forward(self, states):
        if not self.shared:
            states = self.fc(states)#

        if not self.dueling_net:
            return self.head(states)
        else:
            a = self.a_head(states)
            v = self.v_head(states)
            return v + a - a.mean(1, keepdim=True)


class CateoricalPolicy(BaseNetwork):

    def __init__(self, num_channels, num_actions, shared=False):
        super().__init__()
        if not shared:
            self.fc = DQNBase(num_channels)

        self.head = nn.Sequential(
            nn.Linear(32, 32),#
            nn.ReLU(inplace=True),
            nn.Linear(32, num_actions))#

        self.shared = shared

    def act(self, states):
        if not self.shared:
            states = self.fc(states)#

        action_logits = self.head(states)
        greedy_actions = torch.argmax(
            action_logits, dim=1, keepdim=True)
        return greedy_actions

    def sample(self, states):
        if not self.shared:
            states = self.fc(states)#

        action_probs = F.softmax(self.head(states), dim=1)
        action_dist = Categorical(action_probs)
        actions = action_dist.sample().view(-1, 1)

        # Avoid numerical instability.
        z = (action_probs == 0.0).float() * 1e-8
        log_action_probs = torch.log(action_probs + z)

        return actions, action_probs, log_action_probs

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 10 (4 by maintainers)

Top GitHub Comments

3 reactions
toshikwa commented, Jun 30, 2020

Hi @YongyiTang92

I don’t think so. I implemented SAC-Discrete for CartPole for you here.

I assume that target_entropy_ratio=0.98 is too high, which results in a low train return. However, even with target_entropy_ratio=0.98, the test return reaches around 200 at 30k steps.

Sorry for the cluttered code, but I hope it helps.

[Screenshot, 2020-06-30 16:54:37]
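
For context (a minimal sketch, not code from this repository or the thread): target_entropy_ratio is typically interpreted as a fraction of the maximum policy entropy log|A|, so with CartPole's two actions a ratio of 0.98 asks the stochastic policy to stay almost uniform during training, which depresses the train return even while the greedy test policy scores well.

import numpy as np

def target_entropy_from_ratio(num_actions, target_entropy_ratio):
    # A discrete policy's entropy is maximised by the uniform policy,
    # where it equals log(num_actions); the target is a fraction of that.
    return target_entropy_ratio * np.log(num_actions)

# CartPole-v0 has 2 actions, so the maximum entropy is log(2) ~= 0.693.
print(target_entropy_from_ratio(2, 0.98))  # ~0.679: nearly uniform, exploration dominates
print(target_entropy_from_ratio(2, 0.7))   # ~0.485: leaves room for a greedier train policy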

1 reaction
toshikwa commented, Dec 23, 2020

@VCasecnikovs

Thank you for your suggestion 😉 Actually, I did it on purpose.

It reduces latency when transferring tensors between RAM and the GPU. In other words, when tensors are large (e.g. images), the cost of the extra casting (convert to "uint8", transfer to the GPU, then convert to "float32") is much smaller than the cost of transferring "float32" tensors to the GPU. (Note that this is lossless only when the tensors are images, i.e. already uint8.)

This repository only supports image inputs (Atari) for now; however, I will make it compatible with vector inputs (e.g. CartPole) when I have time 😃
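
To illustrate the trade-off described above (a minimal sketch under the image-vs-vector assumption, not code from the repository): image frames are already uint8, so they can cross the CPU-GPU boundary in a quarter of the bytes and be widened to float32 on the device, whereas float vector states such as CartPole's must be transferred as floats, since squeezing them through uint8 would truncate the values.

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def image_to_device(frame):
    # `frame` is an HxWxC uint8 array: send the small uint8 tensor to the
    # device first, then cast to float32 there (the pattern described above).
    return torch.from_numpy(frame).to(device).float()

def vector_to_device(state):
    # `state` is a float array (e.g. CartPole's 4-dim observation): cast to
    # float32 on the CPU and transfer as-is; routing it through uint8 would
    # destroy the values.
    return torch.from_numpy(np.asarray(state, dtype=np.float32)).to(device)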

Thanks 😃

Read more comments on GitHub.
