diff --git a/shimmy/registration.py b/shimmy/registration.py index 492ddbd4..4501000f 100644 --- a/shimmy/registration.py +++ b/shimmy/registration.py @@ -1,4 +1,5 @@ """Registers environments within gymnasium for optional modules.""" + from __future__ import annotations from functools import partial @@ -68,6 +69,41 @@ def _make_dm_control_generic_env(env, **render_kwargs): # Register all suite environments import dm_control.suite + def _register_locomotion_envs(): + try: + from dm_control import composer + from dm_control.locomotion.examples import basic_cmu_2019, basic_rodent_2020 + except ImportError: + print( + "Warning, registration of `dm_control` locomotion envs has failed due to an ImportError" + ) + return + + def _make_dm_control_example_locomotion_env( + env_fn: Callable[[np.random.RandomState | None], composer.Environment], + random_state: np.random.RandomState | None = None, + **render_kwargs, + ): + return DmControlCompatibilityV0(env_fn(random_state), **render_kwargs) + + for locomotion_env, nondeterministic in ( + (basic_cmu_2019.cmu_humanoid_run_walls, False), + (basic_cmu_2019.cmu_humanoid_run_gaps, False), + (basic_cmu_2019.cmu_humanoid_go_to_target, False), + (basic_cmu_2019.cmu_humanoid_maze_forage, True), + (basic_cmu_2019.cmu_humanoid_heterogeneous_forage, True), + (basic_rodent_2020.rodent_escape_bowl, False), + (basic_rodent_2020.rodent_run_gaps, False), + (basic_rodent_2020.rodent_maze_forage, True), + (basic_rodent_2020.rodent_two_touch, True), + # (cmu_2020_tracking.cmu_humanoid_tracking, False), + ): + register( + f"dm_control/{locomotion_env.__name__.title().replace('_', '')}-v0", + partial(_make_dm_control_example_locomotion_env, env_fn=locomotion_env), + nondeterministic=nondeterministic, + ) + def _make_dm_control_suite_env( domain_name: str, task_name: str, @@ -98,33 +134,7 @@ def _make_dm_control_suite_env( # Register all example locomotion environments # Listed in https://github.com/deepmind/dm_control/blob/main/dm_control/locomotion/examples/examples_test.py - from dm_control import composer - from dm_control.locomotion.examples import basic_cmu_2019, basic_rodent_2020 - - def _make_dm_control_example_locomotion_env( - env_fn: Callable[[np.random.RandomState | None], composer.Environment], - random_state: np.random.RandomState | None = None, - **render_kwargs, - ): - return DmControlCompatibilityV0(env_fn(random_state), **render_kwargs) - - for locomotion_env, nondeterministic in ( - (basic_cmu_2019.cmu_humanoid_run_walls, False), - (basic_cmu_2019.cmu_humanoid_run_gaps, False), - (basic_cmu_2019.cmu_humanoid_go_to_target, False), - (basic_cmu_2019.cmu_humanoid_maze_forage, True), - (basic_cmu_2019.cmu_humanoid_heterogeneous_forage, True), - (basic_rodent_2020.rodent_escape_bowl, False), - (basic_rodent_2020.rodent_run_gaps, False), - (basic_rodent_2020.rodent_maze_forage, True), - (basic_rodent_2020.rodent_two_touch, True), - # (cmu_2020_tracking.cmu_humanoid_tracking, False), - ): - register( - f"dm_control/{locomotion_env.__name__.title().replace('_', '')}-v0", - partial(_make_dm_control_example_locomotion_env, env_fn=locomotion_env), - nondeterministic=nondeterministic, - ) + _register_locomotion_envs() # Register all manipulation environments import dm_control.manipulation