
Unexpected interaction between dm_control and JAX


The order of importing jax and dm_control has a large effect on FPS. I’m using dm-control==1.0.3, jax==0.3.1 and jaxlib==0.3.0+cuda11.cudnn82

The script below reproduces the issue, with code adapted from https://github.com/ikostrikov/jaxrl/tree/main/jaxrl/wrappers

What is the correct order for importing the libraries?

# True imports dm_control before jax (fast); False imports jax first (slow).
fast = True
if fast:
    from dm_control import suite
    import jax
else:
    import jax
    from dm_control import suite

from dm_env import specs
import numpy as np
from collections import OrderedDict
from typing import Dict, Optional
import copy
import gym
from gym import core, spaces



def dmc_spec2gym_space(spec):
    if isinstance(spec, OrderedDict):
        spec = copy.copy(spec)
        for k, v in spec.items():
            spec[k] = dmc_spec2gym_space(v)
        return spaces.Dict(spec)
    elif isinstance(spec, specs.BoundedArray):
        return spaces.Box(low=spec.minimum,
                          high=spec.maximum,
                          shape=spec.shape,
                          dtype=spec.dtype)
    elif isinstance(spec, specs.Array):
        return spaces.Box(low=-float('inf'),
                          high=float('inf'),
                          shape=spec.shape,
                          dtype=spec.dtype)
    else:
        raise NotImplementedError


class DMCEnv(core.Env):
    def __init__(self,
                 domain_name: str,
                 task_name: str,
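                 task_kwargs: Optional[Dict] = None,
                 environment_kwargs: Optional[Dict] = None):
        # Use None instead of a mutable {} default, which would be shared across calls.
        task_kwargs = task_kwargs or {}
        assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'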

        self._env = suite.load(domain_name=domain_name,
                               task_name=task_name,
                               task_kwargs=task_kwargs,
                               environment_kwargs=environment_kwargs)
        self.action_space = dmc_spec2gym_space(self._env.action_spec())

        self.observation_space = dmc_spec2gym_space(
            self._env.observation_spec())

        self.seed(seed=task_kwargs['random'])

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        assert self.action_space.contains(action)

        time_step = self._env.step(action)
        reward = time_step.reward or 0
        done = time_step.last()
        obs = time_step.observation

        info = {}
        if done and time_step.discount == 1.0:
            info['TimeLimit.truncated'] = True

        return obs, reward, done, info

    def reset(self):
        time_step = self._env.reset()
        return time_step.observation


def make_env(env, seed):
    domain_name, task_name = env.split('-')
    env = DMCEnv(
        domain_name=domain_name,
        task_name=task_name,
        task_kwargs={'random': seed},
    )
    if isinstance(env.observation_space, gym.spaces.Dict):
        env = gym.wrappers.FlattenObservation(env)
    return env


env = make_env('humanoid-run', 42)
for t in range(10000):
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)
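The script above does not report a frame rate itself. A minimal sketch for putting a number on the difference is to time a fixed number of step calls with time.perf_counter, running the script once with fast = True and once with fast = False. The measure_fps helper below is hypothetical (it is not part of dm_control, gym, or JAX) and accepts any zero-argument callable, so a stand-in workload is used here in place of the environment.

```python
import time
from typing import Callable


def measure_fps(step_fn: Callable[[], object], n_steps: int = 10000) -> float:
    """Return calls per second for invoking step_fn n_steps times.

    step_fn is any zero-argument callable; with the DMCEnv above it could be
    lambda: env.step(env.action_space.sample()).
    """
    if n_steps <= 0:
        raise ValueError("n_steps must be positive")
    start = time.perf_counter()
    for _ in range(n_steps):
        step_fn()
    elapsed = time.perf_counter() - start
    return n_steps / elapsed


if __name__ == "__main__":
    # Stand-in workload so the sketch runs without dm_control installed.
    fps = measure_fps(lambda: sum(range(1000)), n_steps=1000)
    print(f"{fps:.1f} steps/s")
```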

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 11 (6 by maintainers)

Top GitHub Comments

2 reactions
nimrod-gileadi commented, Jul 14, 2022

Thanks.

I ran a stripped-down version of the script above with cProfile. It appears that the slowdown comes from a specific list comprehension (linked in the original comment). I have no idea why.

The comprehension is over a list of pybind11 structs, so maybe JAX affects pybind11 bindings in some way.

The simplified script, with cProfile:

# True imports dm_control before jax (fast); False imports jax first (slow).
fast = True
if fast:
    from dm_control import suite
    import jax
else:
    import jax
    from dm_control import suite

import cProfile
from dm_env import specs
import numpy as np

def make_env(env, seed):
    domain_name, task_name = env.split('-')
    env = suite.load(domain_name=domain_name,
                     task_name=task_name,
                     task_kwargs={'random': seed},
                     environment_kwargs=None)
    return env


env = make_env('humanoid-run', 42)
action = np.zeros(env.action_spec().shape)

def loop(env, action):
    for _ in range(10000):
        timestep = env.step(action)

cProfile.run("loop(env, action)")
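By default, cProfile.run sorts its report by function name, so a hotspot such as the list comprehension mentioned above can be buried in a long table. cProfile.run also accepts a sort argument (for example sort="cumulative"); the sketch below instead collects the profile in a cProfile.Profile object and prints only the top rows by cumulative time with pstats. The profile_top helper and the workload function are hypothetical stand-ins for loop(env, action).

```python
import cProfile
import io
import pstats
from typing import Callable


def profile_top(fn: Callable[[], object], limit: int = 10) -> str:
    """Profile one call of fn and return the top `limit` rows by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    fn()
    profiler.disable()

    out = io.StringIO()
    stats = pstats.Stats(profiler, stream=out)
    stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(limit)
    return out.getvalue()


def workload() -> int:
    # Stand-in for loop(env, action): a list comprehension inside a hot loop.
    total = 0
    for _ in range(1000):
        total += sum([i * i for i in range(50)])
    return total


if __name__ == "__main__":
    print(profile_top(workload))
```

On Python 3.12 and later, list comprehensions are inlined (PEP 709) and no longer appear as separate <listcomp> rows, so their cost is attributed to the enclosing function instead.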

1 reaction
nimrod-gileadi commented, Jul 18, 2022

We still don’t know the cause of this.

It’s not a high-priority issue for us, so it’s unlikely to be fixed soon. For now, could you import in alphabetical order? 😝
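Alphabetical order puts dm_control before jax, which matches the fast = True branch of the scripts above. Because a module's import-time side effects happen only once per process, the two orders cannot be compared reliably inside a single interpreter. A minimal sketch, using only the standard library, imports a given list of modules in a fresh subprocess and then runs a snippet; run_with_import_order is a hypothetical helper, and the module names in the comments show where the real comparison would plug in.

```python
import subprocess
import sys
from typing import Sequence


def run_with_import_order(modules: Sequence[str], snippet: str) -> str:
    """Import `modules` in the given order in a fresh interpreter, then run `snippet`.

    Returns the subprocess's stdout so the caller can parse, e.g., a printed FPS.
    """
    lines = [f"import {name}" for name in modules]
    lines.append(snippet)
    result = subprocess.run(
        [sys.executable, "-c", "\n".join(lines)],
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child fails
    )
    return result.stdout


if __name__ == "__main__":
    # For this issue one would compare, e.g.,
    #   ["dm_control.suite", "jax"] versus ["jax", "dm_control.suite"]
    # with a snippet that times env.step(...) and prints the rate.
    snippet = "import sys; print([m for m in ('json', 'math') if m in sys.modules])"
    print(run_with_import_order(["json", "math"], snippet))
```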
