Unexpected interaction between dm_control and JAX
The order of importing jax and dm_control has a large effect on FPS.
I’m using dm-control==1.0.3, jax==0.3.1 and jaxlib==0.3.0+cuda11.cudnn82.
The script below reproduces the issue, with code adapted from https://github.com/ikostrikov/jaxrl/tree/main/jaxrl/wrappers
What should be the correct order of importing the libraries?
fast = True
if fast:
    from dm_control import suite
    import jax
else:
    import jax
    from dm_control import suite
from dm_env import specs
import numpy as np
from typing import Dict, Optional, OrderedDict
import copy
import gym
from gym import core, spaces
def dmc_spec2gym_space(spec):
    if isinstance(spec, OrderedDict):
        spec = copy.copy(spec)
        for k, v in spec.items():
            spec[k] = dmc_spec2gym_space(v)
        return spaces.Dict(spec)
    elif isinstance(spec, specs.BoundedArray):
        return spaces.Box(low=spec.minimum,
                          high=spec.maximum,
                          shape=spec.shape,
                          dtype=spec.dtype)
    elif isinstance(spec, specs.Array):
        return spaces.Box(low=-float('inf'),
                          high=float('inf'),
                          shape=spec.shape,
                          dtype=spec.dtype)
    else:
        raise NotImplementedError
class DMCEnv(core.Env):
    def __init__(self,
                 domain_name: str,
                 task_name: str,
                 task_kwargs: Optional[Dict] = {},
                 environment_kwargs=None):
        assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour'
        self._env = suite.load(domain_name=domain_name,
                               task_name=task_name,
                               task_kwargs=task_kwargs,
                               environment_kwargs=environment_kwargs)
        self.action_space = dmc_spec2gym_space(self._env.action_spec())
        self.observation_space = dmc_spec2gym_space(
            self._env.observation_spec())
        self.seed(seed=task_kwargs['random'])

    def __getattr__(self, name):
        return getattr(self._env, name)

    def step(self, action):
        assert self.action_space.contains(action)
        time_step = self._env.step(action)
        reward = time_step.reward or 0
        done = time_step.last()
        obs = time_step.observation
        info = {}
        if done and time_step.discount == 1.0:
            info['TimeLimit.truncated'] = True
        return obs, reward, done, info

    def reset(self):
        time_step = self._env.reset()
        return time_step.observation
def make_env(env, seed):
    domain_name, task_name = env.split('-')
    env = DMCEnv(
        domain_name=domain_name,
        task_name=task_name,
        task_kwargs={'random': seed}
    )
    if isinstance(env.observation_space, gym.spaces.Dict):
        env = gym.wrappers.FlattenObservation(env)
    return env

env = make_env('humanoid-run', 42)
for t in range(10000):
    action = env.action_space.sample()
    next_state, reward, done, info = env.step(action)
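The script above does not itself report FPS. One way to put a number on each import order is to time the step loop. The sketch below is only an illustration: `measure_fps` is a hypothetical helper, not part of dm_control, gym, or JAX, and the workload in the demo is a stand-in so the block runs without those libraries installed.

```python
import time
from typing import Callable


def measure_fps(step_fn: Callable[[], object], num_steps: int = 10_000) -> float:
    """Call step_fn num_steps times and return the achieved steps per second.

    measure_fps is a hypothetical helper written for this illustration.
    """
    start = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    # Guard against a zero reading on very coarse timers.
    elapsed = max(time.perf_counter() - start, 1e-12)
    return num_steps / elapsed


if __name__ == "__main__":
    # Stand-in workload (assumption). With the script above, the callable
    # would instead be: lambda: env.step(env.action_space.sample())
    fps = measure_fps(lambda: sum(range(100)), num_steps=1_000)
    print(f"{fps:.1f} steps/s")
```

Because the import order is fixed once a process has started, each setting of `fast` would need its own fresh Python process for the two numbers to be comparable.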
Thanks.
I ran a stripped version of the script above with cProfile. It appears that this list comprehension is where the slowdown comes from. I have no idea why.
The comprehension is over a list of pybind11 structs, so maybe JAX affects pybind11 bindings in some way.
The simplified script, with cProfile:
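The simplified script from the original comment is not reproduced on this page. As a rough illustration of the approach only, the sketch below profiles a callable with the standard library's `cProfile` and `pstats` and returns the entries with the largest cumulative time. `profile_top` and `workload` are hypothetical names, and the workload is a stand-in list comprehension rather than the dm_control code in question.

```python
import cProfile
import io
import pstats


def profile_top(fn, top_n=10):
    """Profile fn() and return a text report of the top_n entries,
    sorted by cumulative time. Hypothetical helper for illustration."""
    profiler = cProfile.Profile()
    profiler.enable()
    fn()
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats("cumulative").print_stats(top_n)
    return stream.getvalue()


def workload():
    # Stand-in for the environment step loop (assumption): a plain
    # list comprehension, not the pybind11-backed one discussed above.
    return [x * x for x in range(100_000)]


if __name__ == "__main__":
    print(profile_top(workload))
```

Running the same profile once per import order, each in a fresh process, and comparing the cumulative-time columns is one way to see which call accounts for the difference.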
We still don’t know the cause for this.
It’s not a high priority issue for us, so it’s unlikely to be fixed soon. For now, could you import in alphabetical order? 😝