diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3067c3a8..cbbea960 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: hooks: - id: flake8 args: - - '--per-file-ignores=*/__init__.py:F401' + - "--per-file-ignores=*/__init__.py:F401" - --ignore=E203,W503,E741 - --max-complexity=30 - --max-line-length=456 @@ -64,6 +64,6 @@ repos: language: node pass_filenames: false types: [python] - additional_dependencies: ["pyright"] + additional_dependencies: ["pyright@1.1.347"] args: - --project=pyproject.toml diff --git a/docs/_static/videos/breakable-bottles.gif b/docs/_static/videos/breakable-bottles.gif index 4f901d79..ddca154e 100644 Binary files a/docs/_static/videos/breakable-bottles.gif and b/docs/_static/videos/breakable-bottles.gif differ diff --git a/docs/_static/videos/deep-sea-treasure-concave.gif b/docs/_static/videos/deep-sea-treasure-concave.gif index 4bad0688..a0a76216 100644 Binary files a/docs/_static/videos/deep-sea-treasure-concave.gif and b/docs/_static/videos/deep-sea-treasure-concave.gif differ diff --git a/docs/_static/videos/deep-sea-treasure-mirrored.gif b/docs/_static/videos/deep-sea-treasure-mirrored.gif new file mode 100644 index 00000000..dc5c4994 Binary files /dev/null and b/docs/_static/videos/deep-sea-treasure-mirrored.gif differ diff --git a/docs/_static/videos/deep-sea-treasure.gif b/docs/_static/videos/deep-sea-treasure.gif index 6f5b5d81..b06f7471 100644 Binary files a/docs/_static/videos/deep-sea-treasure.gif and b/docs/_static/videos/deep-sea-treasure.gif differ diff --git a/docs/_static/videos/four-room.gif b/docs/_static/videos/four-room.gif index 24282928..8641a8f3 100644 Binary files a/docs/_static/videos/four-room.gif and b/docs/_static/videos/four-room.gif differ diff --git a/docs/_static/videos/fruit-tree.gif b/docs/_static/videos/fruit-tree.gif new file mode 100644 index 00000000..483af1e2 Binary files /dev/null and b/docs/_static/videos/fruit-tree.gif differ diff --git a/docs/_static/videos/minecart-deterministic.gif b/docs/_static/videos/minecart-deterministic.gif index 3559f6ba..39ad4172 100644 Binary files a/docs/_static/videos/minecart-deterministic.gif and b/docs/_static/videos/minecart-deterministic.gif differ diff --git a/docs/_static/videos/minecart.gif b/docs/_static/videos/minecart.gif index 6242074e..0c3e99dc 100644 Binary files a/docs/_static/videos/minecart.gif and b/docs/_static/videos/minecart.gif differ diff --git a/docs/_static/videos/mo-halfcheetah.gif b/docs/_static/videos/mo-halfcheetah.gif index bf4cf6be..3fe0efc6 100644 Binary files a/docs/_static/videos/mo-halfcheetah.gif and b/docs/_static/videos/mo-halfcheetah.gif differ diff --git a/docs/_static/videos/mo-hopper.gif b/docs/_static/videos/mo-hopper.gif index 402fc20a..8677eecf 100644 Binary files a/docs/_static/videos/mo-hopper.gif and b/docs/_static/videos/mo-hopper.gif differ diff --git a/docs/_static/videos/mo-lunar-lander.gif b/docs/_static/videos/mo-lunar-lander.gif index abb23939..2051d754 100644 Binary files a/docs/_static/videos/mo-lunar-lander.gif and b/docs/_static/videos/mo-lunar-lander.gif differ diff --git a/docs/_static/videos/mo-mountaincar.gif b/docs/_static/videos/mo-mountaincar.gif index d0b7db84..c9a8d4ba 100644 Binary files a/docs/_static/videos/mo-mountaincar.gif and b/docs/_static/videos/mo-mountaincar.gif differ diff --git a/docs/_static/videos/mo-mountaincarcontinuous.gif b/docs/_static/videos/mo-mountaincarcontinuous.gif index 9af3f1e4..81d96abb 100644 Binary files 
a/docs/_static/videos/mo-mountaincarcontinuous.gif and b/docs/_static/videos/mo-mountaincarcontinuous.gif differ diff --git a/docs/_static/videos/mo-reacher.gif b/docs/_static/videos/mo-reacher.gif index 4f2d878d..cd2d6086 100644 Binary files a/docs/_static/videos/mo-reacher.gif and b/docs/_static/videos/mo-reacher.gif differ diff --git a/docs/_static/videos/mo-supermario.gif b/docs/_static/videos/mo-supermario.gif index b60f379d..f87efc03 100644 Binary files a/docs/_static/videos/mo-supermario.gif and b/docs/_static/videos/mo-supermario.gif differ diff --git a/docs/_static/videos/resource-gathering.gif b/docs/_static/videos/resource-gathering.gif index 2187043b..916630bc 100644 Binary files a/docs/_static/videos/resource-gathering.gif and b/docs/_static/videos/resource-gathering.gif differ diff --git a/docs/_static/videos/water-reservoir.gif b/docs/_static/videos/water-reservoir.gif index c0709c40..5ebbbcdf 100644 Binary files a/docs/_static/videos/water-reservoir.gif and b/docs/_static/videos/water-reservoir.gif differ diff --git a/docs/environments/all-environments.md b/docs/environments/all-environments.md index 4d2f011a..0975216a 100644 --- a/docs/environments/all-environments.md +++ b/docs/environments/all-environments.md @@ -7,25 +7,26 @@ title: "Environments" MO-Gymnasium includes environments taken from the MORL literature, as well as multi-objective version of classical environments, such as Mujoco. -| Env | Obs/Action spaces | Objectives | Description | -|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------|---------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [`deep-sea-treasure-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure/)
| Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasures values taken from [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). | -| [`deep-sea-treasure-concave-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure-concave/)
| Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasures values taken from [Vamplew et al. 2010](https://link.springer.com/article/10.1007/s10994-010-5232-5). | -| [`resource-gathering-v0`](https://mo-gymnasium.farama.org/environments/resource-gathering/)
| Discrete / Discrete | `[enemy, gold, gem]` | Agent must collect gold or gem. Enemies have a 10% chance of killing the agent. From [Barret & Narayanan 2008](https://dl.acm.org/doi/10.1145/1390156.1390162). | -| [`fishwood-v0`](https://mo-gymnasium.farama.org/environments/fishwood/)
| Discrete / Discrete | `[fish_amount, wood_amount]` | ESR environment, the agent must collect fish and wood to light a fire and eat. From [Roijers et al. 2018](https://www.researchgate.net/publication/328718263_Multi-objective_Reinforcement_Learning_for_the_Expected_Utility_of_the_Return). | -| [`breakable-bottles-v0`](https://mo-gymnasium.farama.org/environments/breakable-bottles/)
| Discrete (Dictionary) / Discrete | `[time_penalty, bottles_delivered, potential]` | Gridworld with 5 cells. The agents must collect bottles from the source location and deliver to the destination. From [Vamplew et al. 2021](https://www.sciencedirect.com/science/article/pii/S0952197621000336). | -| [`fruit-tree-v0`](https://mo-gymnasium.farama.org/environments/fruit-tree/)
| Discrete / Discrete | `[nutri1, ..., nutri6]` | Full binary tree of depth d=5,6 or 7. Every leaf contains a fruit with a value for the nutrients Protein, Carbs, Fats, Vitamins, Minerals and Water. From [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). | -| [`water-reservoir-v0`](https://mo-gymnasium.farama.org/environments/water-reservoir/)
| Continuous / Continuous | `[cost_flooding, deficit_water]` | A Water reservoir environment. The agent executes a continuous action, corresponding to the amount of water released by the dam. From [Pianosi et al. 2013](https://iwaponline.com/jh/article/15/2/258/3425/Tree-based-fitted-Q-iteration-for-multi-objective). | -| [`four-room-v0`](https://mo-gymnasium.farama.org/environments/four-room/)
| Discrete / Discrete | `[item1, item2, item3]` | Agent must collect three different types of items in the map and reach the goal. From [Alegre et al. 2022](https://proceedings.mlr.press/v162/alegre22a.html). | -| [`mo-mountaincar-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincar/)
| Continuous / Discrete | `[time_penalty, reverse_penalty, forward_penalty]` | Classic Mountain Car env, but with extra penalties for the forward and reverse actions. From [Vamplew et al. 2011](https://www.researchgate.net/publication/220343783_Empirical_evaluation_methods_for_multiobjective_reinforcement_learning_algorithms). | -| [`mo-mountaincarcontinuous-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincarcontinuous/)
| Continuous / Continuous | `[time_penalty, fuel_consumption_penalty]` | Continuous Mountain Car env, but with penalties for fuel consumption. | -| [`mo-lunar-lander-v2`](https://mo-gymnasium.farama.org/environments/mo-lunar-lander/)
| Continuous / Discrete or Continuous | `[landed, shaped_reward, main_engine_fuel, side_engine_fuel]` | MO version of the `LunarLander-v2` [environment](https://gymnasium.farama.org/environments/box2d/lunar_lander/). Objectives defined similarly as in [Hung et al. 2022](https://openreview.net/forum?id=AwWaBXLIJE). | -| [`minecart-v0`](https://mo-gymnasium.farama.org/environments/minecart/)
| Continuous or Image / Discrete | `[ore1, ore2, fuel]` | Agent must collect two types of ores and minimize fuel consumption. From [Abels et al. 2019](https://arxiv.org/abs/1809.07803v2). | -| [`mo-highway-v0`](https://mo-gymnasium.farama.org/environments/mo-highway/) and `mo-highway-fast-v0`
| Continuous / Discrete | `[speed, right_lane, collision]` | The agent's objective is to reach a high speed while avoiding collisions with neighbouring vehicles and staying on the rightest lane. From [highway-env](https://github.com/eleurent/highway-env). | -| [`mo-supermario-v0`](https://mo-gymnasium.farama.org/environments/mo-supermario/)
| Image / Discrete | `[x_pos, time, death, coin, enemy]` | [:warning: SuperMarioBrosEnv support is limited.] Multi-objective version of [SuperMarioBrosEnv](https://github.com/Kautenja/gym-super-mario-bros). Objectives are defined similarly as in [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). | -| [`mo-reacher-v4`](https://mo-gymnasium.farama.org/environments/mo-reacher/)
| Continuous / Discrete | `[target_1, target_2, target_3, target_4]` | Mujoco version of `mo-reacher-v0`, based on `Reacher-v4` [environment](https://gymnasium.farama.org/environments/mujoco/reacher/). | -| [`mo-hopper-v4`](https://mo-gymnasium.farama.org/environments/mo-hopper/)
| Continuous / Continuous | `[velocity, height, energy]` | Multi-objective version of [Hopper-v4](https://gymnasium.farama.org/environments/mujoco/hopper/) env. | -| [`mo-halfcheetah-v4`](https://mo-gymnasium.farama.org/environments/mo-halfcheetah/)
| Continuous / Continuous | `[velocity, energy]` | Multi-objective version of [HalfCheetah-v4](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) env. Similar to [Xu et al. 2020](https://github.com/mit-gfx/PGMORL). | +| Env | Obs/Action spaces | Objectives | Description | +|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------------------|---------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [`deep-sea-treasure-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure/)
| Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasure values taken from [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). | +| [`deep-sea-treasure-concave-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure-concave/)
| Discrete / Discrete | `[treasure, time_penalty]` | Agent is a submarine that must collect a treasure while taking into account a time penalty. Treasure values taken from [Vamplew et al. 2010](https://link.springer.com/article/10.1007/s10994-010-5232-5). | +| [`deep-sea-treasure-mirrored-v0`](https://mo-gymnasium.farama.org/environments/deep-sea-treasure-mirrored/)
| Discrete / Discrete | `[treasure, time_penalty]` | Harder version of the concave DST. From [Felten et al. 2022](https://www.scitepress.org/Papers/2022/109891/109891.pdf). | +| [`resource-gathering-v0`](https://mo-gymnasium.farama.org/environments/resource-gathering/)
| Discrete / Discrete | `[enemy, gold, gem]` | Agent must collect gold or a gem. Enemies have a 10% chance of killing the agent. From [Barrett & Narayanan 2008](https://dl.acm.org/doi/10.1145/1390156.1390162). | +| [`fishwood-v0`](https://mo-gymnasium.farama.org/environments/fishwood/)
| Discrete / Discrete | `[fish_amount, wood_amount]` | ESR environment where the agent must collect fish and wood to light a fire and eat. From [Roijers et al. 2018](https://www.researchgate.net/publication/328718263_Multi-objective_Reinforcement_Learning_for_the_Expected_Utility_of_the_Return). | +| [`breakable-bottles-v0`](https://mo-gymnasium.farama.org/environments/breakable-bottles/)
| Discrete (Dictionary) / Discrete | `[time_penalty, bottles_delivered, potential]` | Gridworld with 5 cells. The agent must collect bottles from the source location and deliver them to the destination. From [Vamplew et al. 2021](https://www.sciencedirect.com/science/article/pii/S0952197621000336). | +| [`fruit-tree-v0`](https://mo-gymnasium.farama.org/environments/fruit-tree/)
| Discrete / Discrete | `[nutri1, ..., nutri6]` | Full binary tree of depth d=5, 6 or 7. Every leaf contains a fruit with a value for the nutrients Protein, Carbs, Fats, Vitamins, Minerals and Water. From [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). | +| [`water-reservoir-v0`](https://mo-gymnasium.farama.org/environments/water-reservoir/)
| Continuous / Continuous | `[cost_flooding, deficit_water]` | A water reservoir environment. The agent executes a continuous action, corresponding to the amount of water released by the dam. From [Pianosi et al. 2013](https://iwaponline.com/jh/article/15/2/258/3425/Tree-based-fitted-Q-iteration-for-multi-objective). | +| [`four-room-v0`](https://mo-gymnasium.farama.org/environments/four-room/)
| Discrete / Discrete | `[item1, item2, item3]` | Agent must collect three different types of items in the map and reach the goal. From [Alegre et al. 2022](https://proceedings.mlr.press/v162/alegre22a.html). | +| [`mo-mountaincar-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincar/)
| Continuous / Discrete | `[time_penalty, reverse_penalty, forward_penalty]` | Classic Mountain Car env, but with extra penalties for the forward and reverse actions. From [Vamplew et al. 2011](https://www.researchgate.net/publication/220343783_Empirical_evaluation_methods_for_multiobjective_reinforcement_learning_algorithms). | +| [`mo-mountaincarcontinuous-v0`](https://mo-gymnasium.farama.org/environments/mo-mountaincarcontinuous/)
| Continuous / Continuous | `[time_penalty, fuel_consumption_penalty]` | Continuous Mountain Car env, but with penalties for fuel consumption. | +| [`mo-lunar-lander-v2`](https://mo-gymnasium.farama.org/environments/mo-lunar-lander/)
| Continuous / Discrete or Continuous | `[landed, shaped_reward, main_engine_fuel, side_engine_fuel]` | MO version of the `LunarLander-v2` [environment](https://gymnasium.farama.org/environments/box2d/lunar_lander/). Objectives defined similarly as in [Hung et al. 2022](https://openreview.net/forum?id=AwWaBXLIJE). | +| [`minecart-v0`](https://mo-gymnasium.farama.org/environments/minecart/)
| Continuous or Image / Discrete | `[ore1, ore2, fuel]` | Agent must collect two types of ores and minimize fuel consumption. From [Abels et al. 2019](https://arxiv.org/abs/1809.07803v2). | +| [`mo-highway-v0`](https://mo-gymnasium.farama.org/environments/mo-highway/) and `mo-highway-fast-v0`
| Continuous / Discrete | `[speed, right_lane, collision]` | The agent's objective is to reach a high speed while avoiding collisions with neighbouring vehicles and staying in the rightmost lane. From [highway-env](https://github.com/eleurent/highway-env). | +| [`mo-supermario-v0`](https://mo-gymnasium.farama.org/environments/mo-supermario/)
| Image / Discrete | `[x_pos, time, death, coin, enemy]` | [:warning: SuperMarioBrosEnv support is limited.] Multi-objective version of [SuperMarioBrosEnv](https://github.com/Kautenja/gym-super-mario-bros). Objectives are defined similarly as in [Yang et al. 2019](https://arxiv.org/pdf/1908.08342.pdf). | +| [`mo-reacher-v4`](https://mo-gymnasium.farama.org/environments/mo-reacher/)
| Continuous / Discrete | `[target_1, target_2, target_3, target_4]` | Mujoco version of `mo-reacher-v0`, based on `Reacher-v4` [environment](https://gymnasium.farama.org/environments/mujoco/reacher/). | +| [`mo-hopper-v4`](https://mo-gymnasium.farama.org/environments/mo-hopper/)
| Continuous / Continuous | `[velocity, height, energy]` | Multi-objective version of [Hopper-v4](https://gymnasium.farama.org/environments/mujoco/hopper/) env. | +| [`mo-halfcheetah-v4`](https://mo-gymnasium.farama.org/environments/mo-halfcheetah/)
| Continuous / Continuous | `[velocity, energy]` | Multi-objective version of [HalfCheetah-v4](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) env. Similar to [Xu et al. 2020](https://github.com/mit-gfx/PGMORL). | ```{toctree} diff --git a/mo_gymnasium/envs/deep_sea_treasure/__init__.py b/mo_gymnasium/envs/deep_sea_treasure/__init__.py index 65799cdc..152d4f54 100644 --- a/mo_gymnasium/envs/deep_sea_treasure/__init__.py +++ b/mo_gymnasium/envs/deep_sea_treasure/__init__.py @@ -1,6 +1,9 @@ from gymnasium.envs.registration import register -from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import CONCAVE_MAP +from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import ( + CONCAVE_MAP, + MIRRORED_MAP, +) register( @@ -15,3 +18,10 @@ max_episode_steps=100, kwargs={"dst_map": CONCAVE_MAP}, ) + +register( + id="deep-sea-treasure-mirrored-v0", + entry_point="mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure:DeepSeaTreasure", + max_episode_steps=100, + kwargs={"dst_map": MIRRORED_MAP}, +) diff --git a/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py b/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py index efd95750..374c15a0 100644 --- a/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py +++ b/mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py @@ -68,6 +68,23 @@ np.array([124.0, -19]), ] +# As in Felten et al. 2022, same PF as concave, just harder map +MIRRORED_MAP = np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, -10, -10, 2.0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, 3.0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, -10, -10, 5.0, 8.0, 16.0, 0, 0, 0, 0], + [0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0], + [0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0], + [0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 24.0, 50.0, 0, 0], + [0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0], + [0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 74.0, 0], + [0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 124.0], + ] +) + class DeepSeaTreasure(gym.Env, EzPickle): """ @@ -96,7 +113,7 @@ class DeepSeaTreasure(gym.Env, EzPickle): The episode terminates when the agent reaches a treasure. ## Arguments - - dst_map: the map of the deep sea treasure. Default is the convex map from Yang et al. (2019). To change, use `mo_gymnasium.make("DeepSeaTreasure-v0", dst_map=CONCAVE_MAP).` + - dst_map: the map of the deep sea treasure. Default is the convex map from Yang et al. (2019). To change, use `mo_gymnasium.make("DeepSeaTreasure-v0", dst_map=CONCAVE_MAP | MIRRORED_MAP).` - float_state: if True, the state is a 2D continuous box with values in [0.0, 1.0] for the x and y coordinates of the submarine. 
## Credits @@ -115,8 +132,18 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float # The map of the deep sea treasure (convex version) self.sea_map = dst_map - self._pareto_front = CONVEX_FRONT if np.all(dst_map == DEFAULT_MAP) else CONCAVE_FRONT - assert self.sea_map.shape == DEFAULT_MAP.shape, "The map's shape must be 11x11" + if dst_map.shape[0] == DEFAULT_MAP.shape[0] and dst_map.shape[1] == DEFAULT_MAP.shape[1]: + if np.all(dst_map == DEFAULT_MAP): + self.map_name = "convex" + elif np.all(dst_map == CONCAVE_MAP): + self.map_name = "concave" + else: + raise ValueError("Invalid map") + elif np.all(dst_map == MIRRORED_MAP): + self.map_name = "mirrored" + else: + raise ValueError("Invalid map") + self._pareto_front = CONVEX_FRONT if self.map_name == "convex" else CONCAVE_FRONT self.dir = { 0: np.array([-1, 0], dtype=np.int32), # up @@ -130,7 +157,7 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float if self.float_state: self.observation_space = Box(low=0.0, high=1.0, shape=(2,), dtype=obs_type) else: - self.observation_space = Box(low=0, high=10, shape=(2,), dtype=obs_type) + self.observation_space = Box(low=0, high=len(self.sea_map[0]), shape=(2,), dtype=obs_type) # action space specification: 1 dimension, 0 up, 1 down, 2 left, 3 right self.action_space = Discrete(4) @@ -144,11 +171,15 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float self.current_state = np.array([0, 0], dtype=np.int32) # pygame - self.window_size = (min(64 * self.sea_map.shape[1], 512), min(64 * self.sea_map.shape[0], 512)) + ratio = self.sea_map.shape[1] / self.sea_map.shape[0] + padding = 10 + self.pix_inside = (min(64 * self.sea_map.shape[1], 512) * ratio, min(64 * self.sea_map.shape[0], 512)) + # adding some padding on the sides + self.window_size = (self.pix_inside[0] + 2 * padding, self.pix_inside[1]) # The size of a single grid square in pixels self.pix_square_size = ( - self.window_size[1] // self.sea_map.shape[1] + 1, - self.window_size[0] // self.sea_map.shape[0] + 1, + self.pix_inside[0] // self.sea_map.shape[1] + 1, + self.pix_inside[1] // self.sea_map.shape[0] + 1, # watch out for axis inversions here ) self.window = None self.clock = None @@ -257,7 +288,12 @@ def _get_state(self): def reset(self, seed=None, **kwargs): super().reset(seed=seed) - self.current_state = np.array([0, 0], dtype=np.int32) + if self.map_name == "convex" or self.map_name == "concave": + self.current_state = np.array([0, 0], dtype=np.int32) + elif self.map_name == "mirrored": + self.current_state = np.array([0, 10], dtype=np.int32) + else: + raise ValueError("Invalid map") self.step_count = 0.0 state = self._get_state() if self.render_mode == "human": diff --git a/mo_gymnasium/envs/fruit_tree/assets/agent.png b/mo_gymnasium/envs/fruit_tree/assets/agent.png new file mode 100644 index 00000000..8027fcb1 Binary files /dev/null and b/mo_gymnasium/envs/fruit_tree/assets/agent.png differ diff --git a/mo_gymnasium/envs/fruit_tree/assets/node_blue.png b/mo_gymnasium/envs/fruit_tree/assets/node_blue.png new file mode 100644 index 00000000..17645780 Binary files /dev/null and b/mo_gymnasium/envs/fruit_tree/assets/node_blue.png differ diff --git a/mo_gymnasium/envs/fruit_tree/fruit_tree.py b/mo_gymnasium/envs/fruit_tree/fruit_tree.py index fac03c06..7d978c38 100644 --- a/mo_gymnasium/envs/fruit_tree/fruit_tree.py +++ b/mo_gymnasium/envs/fruit_tree/fruit_tree.py @@ -1,8 +1,10 @@ # Environment from 
https://github.com/RunzheYang/MORL/blob/master/synthetic/envs/fruit_tree.py -from typing import List +from os import path +from typing import List, Optional import gymnasium as gym import numpy as np +import pygame from gymnasium import spaces from gymnasium.utils import EzPickle @@ -264,16 +266,16 @@ class FruitTreeEnv(gym.Env, EzPickle): The episode terminates when the agent reaches a leaf node. """ - def __init__(self, depth=6): + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} + + def __init__(self, depth=6, render_mode: Optional[str] = None): assert depth in [5, 6, 7], "Depth must be 5, 6 or 7." EzPickle.__init__(self, depth) + self.render_mode = render_mode self.reward_dim = 6 self.tree_depth = depth # zero based depth branches = np.zeros((int(2**self.tree_depth - 1), self.reward_dim)) - # fruits = np.random.randn(2**self.tree_depth, self.reward_dim) - # fruits = np.abs(fruits) / np.linalg.norm(fruits, 2, 1, True) - # print(fruits*10) fruits = np.array(FRUITS[str(depth)]) self.tree = np.concatenate([branches, fruits]) @@ -288,9 +290,35 @@ def __init__(self, depth=6): self.current_state = np.array([0, 0], dtype=np.int32) self.terminal = False + # pygame + self.row_height = 20 + self.top_margin = 15 + + # Add margin at the bottom to account for the node rewards + self.window_size = (1200, self.row_height * self.tree_depth + 150) + self.window_padding = 15 # padding on the left and right of the window + self.node_square_size = np.array([10, 10], dtype=np.int32) + self.font_size = 12 + pygame.font.init() + self.font = pygame.font.SysFont(None, self.font_size) + + self.window = None + self.clock = None + self.node_img = None + self.agent_img = None + def get_ind(self, pos): + """Given the pos = current_state = [row_ind, pos_in_row] + return the index of the node in the tree array""" return int(2 ** pos[0] - 1) + pos[1] + def ind_to_state(self, ind): + """Given the index of the node in the tree array return the + current_state = [row_ind, pos_in_row]""" + x = int(np.log2(ind + 1)) + y = ind - 2**x + 1 + return np.array([x, y], dtype=np.int32) + def get_tree_value(self, pos): return np.array(self.tree[self.get_ind(pos)], dtype=np.float32) @@ -325,5 +353,120 @@ def step(self, action): reward = self.get_tree_value(self.current_state) if self.current_state[0] == self.tree_depth: self.terminal = True - return self.current_state.copy(), reward, self.terminal, False, {} + return self.current_state.copy(), reward, self.terminal, False, {} + + def get_pos_in_window(self, row, index_in_row): + """Given the row and index_in_row of the node + calculate its position in the window in pixels""" + window_width = self.window_size[0] - 2 * self.window_padding + distance_between_nodes = window_width / (2 ** (row)) + pos_x = self.window_padding + (index_in_row + 0.5) * distance_between_nodes + pos_y = row * self.row_height + return np.array([pos_x, pos_y]) + + def render(self): + if self.render_mode is None: + assert self.spec is not None + gym.logger.warn( + "You are calling render method without specifying render mode. " + "You can specify the render_mode at initialization, " + f'e.g.
mo_gym.make("{self.spec.id}", render_mode="rgb_array")' + ) + return + + if self.clock is None and self.render_mode == "human": + self.clock = pygame.time.Clock() + + if self.window is None: + pygame.init() + + if self.render_mode == "human": + pygame.display.init() + pygame.display.set_caption("Fruit Tree") + self.window = pygame.display.set_mode(self.window_size) + self.clock.tick(self.metadata["render_fps"]) + else: + self.window = pygame.Surface(self.window_size) + + if self.node_img is None: + filename = path.join(path.dirname(__file__), "assets", "node_blue.png") + self.node_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size) + self.node_img = pygame.transform.flip(self.node_img, flip_x=True, flip_y=False) + + if self.agent_img is None: + filename = path.join(path.dirname(__file__), "assets", "agent.png") + self.agent_img = pygame.transform.scale(pygame.image.load(filename), self.node_square_size) + + canvas = pygame.Surface(self.window_size) + canvas.fill((255, 255, 255)) # White + + # draw branches + for ind, node in enumerate(self.tree): + row, index_in_row = self.ind_to_state(ind) + node_pos = self.get_pos_in_window(row, index_in_row) + if row < self.tree_depth: + # Get childerns' positions and draw branches + child1_pos = self.get_pos_in_window(row + 1, 2 * index_in_row) + child2_pos = self.get_pos_in_window(row + 1, 2 * index_in_row + 1) + half_square = self.node_square_size / 2 + pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child1_pos + half_square, 1) + pygame.draw.line(canvas, (90, 82, 85), node_pos + half_square, child2_pos + half_square, 1) + + for ind, node in enumerate(self.tree): + row, index_in_row = self.ind_to_state(ind) + if (row, index_in_row) == tuple(self.current_state): + img = self.agent_img + font_color = (164, 0, 0) # Red digits for agent node + else: + img = self.node_img + if ind % 2: + font_color = (250, 128, 114) # Green + else: + font_color = (45, 72, 101) # Dark Blue + + node_pos = self.get_pos_in_window(row, index_in_row) + + canvas.blit(img, np.array(node_pos)) + + # Print node values at the bottom of the tree + if row == self.tree_depth: + odd_nodes_values_offset = 0.5 * (ind % 2) + values_imgs = [self.font.render(f"{val:.2f}", True, font_color) for val in node] + for i, val_img in enumerate(values_imgs): + canvas.blit(val_img, node_pos + np.array([-5, (i + 1 + odd_nodes_values_offset) * 1.5 * self.font_size])) + + background = pygame.Surface(self.window_size) + background.fill((255, 255, 255)) # White + background.blit(canvas, (0, self.top_margin)) + + self.window.blit(background, (0, 0)) + + if self.render_mode == "human": + pygame.event.pump() + pygame.display.update() + self.clock.tick(self.metadata["render_fps"]) + elif self.render_mode == "rgb_array": + return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)) + + background = pygame.Surface(self.window_size) + background.fill((255, 255, 255)) # White + + background.blit(canvas, (0, self.top_margin)) + + self.window.blit(background, (0, 0)) + + +if __name__ == "__main__": + import time + + import mo_gymnasium as mo_gym + + env = mo_gym.make("fruit-tree", depth=6, render_mode="human") + env.reset() + while True: + env.render() + obs, r, terminal, truncated, info = env.step(env.action_space.sample()) + if terminal or truncated: + env.render() + time.sleep(2) + env.reset() diff --git a/tests/test_envs.py b/tests/test_envs.py index 4443789c..7e338be4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -64,8 +64,8 @@ def 
test_env_determinism_rollout(env_spec: EnvSpec): if env_spec.nondeterministic is True: return - env_1 = env_spec.make(disable_env_checker=True) - env_2 = env_spec.make(disable_env_checker=True) + env_1 = mo_gym.make(env_spec.id) + env_2 = mo_gym.make(env_spec.id) env_1 = mo_gym.LinearReward(env_1) env_2 = mo_gym.LinearReward(env_2)
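Usage sketch (not part of the patch): the snippet below shows one way the newly registered deep-sea-treasure-mirrored-v0 environment could be exercised, assuming only the mo_gymnasium API already visible in this diff (mo_gym.make, vector rewards, and the LinearReward wrapper used in tests/test_envs.py); the weight vector is illustrative.

import numpy as np

import mo_gymnasium as mo_gym

# Random rollout on the mirrored map; the episode is capped at 100 steps by the registration above.
env = mo_gym.make("deep-sea-treasure-mirrored-v0")
obs, info = env.reset(seed=42)
vec_return = np.zeros(2, dtype=np.float32)  # [treasure, time_penalty]
done = False
while not done:
    obs, vec_reward, terminated, truncated, info = env.step(env.action_space.sample())
    vec_return += vec_reward
    done = terminated or truncated
print("vector return:", vec_return)

# Scalarized rollout, mirroring how tests/test_envs.py wraps environments
# (the weight argument is an illustrative equal weighting, not taken from the patch).
scalar_env = mo_gym.LinearReward(mo_gym.make("deep-sea-treasure-mirrored-v0"), weight=np.array([0.5, 0.5]))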