Skip to content

Commit

Permalink
Lazily import locomotion envs to prevent ModuleNotFoundError when lab…
Browse files Browse the repository at this point in the history
…maze is not installed (#125)
  • Loading branch information
GaetanLepage authored Oct 8, 2024
1 parent 780d86b commit 095d576
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions shimmy/registration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Registers environments within gymnasium for optional modules."""

from __future__ import annotations

from functools import partial
Expand Down Expand Up @@ -68,6 +69,41 @@ def _make_dm_control_generic_env(env, **render_kwargs):
# Register all suite environments
import dm_control.suite

def _register_locomotion_envs():
try:
from dm_control import composer
from dm_control.locomotion.examples import basic_cmu_2019, basic_rodent_2020
except ImportError:
print(
"Warning, registration of `dm_control` locomotion envs has failed due to an ImportError"
)
return

def _make_dm_control_example_locomotion_env(
env_fn: Callable[[np.random.RandomState | None], composer.Environment],
random_state: np.random.RandomState | None = None,
**render_kwargs,
):
return DmControlCompatibilityV0(env_fn(random_state), **render_kwargs)

for locomotion_env, nondeterministic in (
(basic_cmu_2019.cmu_humanoid_run_walls, False),
(basic_cmu_2019.cmu_humanoid_run_gaps, False),
(basic_cmu_2019.cmu_humanoid_go_to_target, False),
(basic_cmu_2019.cmu_humanoid_maze_forage, True),
(basic_cmu_2019.cmu_humanoid_heterogeneous_forage, True),
(basic_rodent_2020.rodent_escape_bowl, False),
(basic_rodent_2020.rodent_run_gaps, False),
(basic_rodent_2020.rodent_maze_forage, True),
(basic_rodent_2020.rodent_two_touch, True),
# (cmu_2020_tracking.cmu_humanoid_tracking, False),
):
register(
f"dm_control/{locomotion_env.__name__.title().replace('_', '')}-v0",
partial(_make_dm_control_example_locomotion_env, env_fn=locomotion_env),
nondeterministic=nondeterministic,
)

def _make_dm_control_suite_env(
domain_name: str,
task_name: str,
Expand Down Expand Up @@ -98,33 +134,7 @@ def _make_dm_control_suite_env(

# Register all example locomotion environments
# Listed in https://github.com/deepmind/dm_control/blob/main/dm_control/locomotion/examples/examples_test.py
from dm_control import composer
from dm_control.locomotion.examples import basic_cmu_2019, basic_rodent_2020

def _make_dm_control_example_locomotion_env(
env_fn: Callable[[np.random.RandomState | None], composer.Environment],
random_state: np.random.RandomState | None = None,
**render_kwargs,
):
return DmControlCompatibilityV0(env_fn(random_state), **render_kwargs)

for locomotion_env, nondeterministic in (
(basic_cmu_2019.cmu_humanoid_run_walls, False),
(basic_cmu_2019.cmu_humanoid_run_gaps, False),
(basic_cmu_2019.cmu_humanoid_go_to_target, False),
(basic_cmu_2019.cmu_humanoid_maze_forage, True),
(basic_cmu_2019.cmu_humanoid_heterogeneous_forage, True),
(basic_rodent_2020.rodent_escape_bowl, False),
(basic_rodent_2020.rodent_run_gaps, False),
(basic_rodent_2020.rodent_maze_forage, True),
(basic_rodent_2020.rodent_two_touch, True),
# (cmu_2020_tracking.cmu_humanoid_tracking, False),
):
register(
f"dm_control/{locomotion_env.__name__.title().replace('_', '')}-v0",
partial(_make_dm_control_example_locomotion_env, env_fn=locomotion_env),
nondeterministic=nondeterministic,
)
_register_locomotion_envs()

# Register all manipulation environments
import dm_control.manipulation
Expand Down

0 comments on commit 095d576

Please sign in to comment.