Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre committed Feb 12, 2024
2 parents 47ab9c1 + 2bf89bf commit 3014a87
Show file tree
Hide file tree
Showing 25 changed files with 228 additions and 38 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
hooks:
- id: flake8
args:
- '--per-file-ignores=*/__init__.py:F401'
- "--per-file-ignores=*/__init__.py:F401"
- --ignore=E203,W503,E741
- --max-complexity=30
- --max-line-length=456
Expand Down Expand Up @@ -64,6 +64,6 @@ repos:
language: node
pass_filenames: false
types: [python]
additional_dependencies: ["pyright"]
additional_dependencies: ["pyright@1.1.347"]
args:
- --project=pyproject.toml
Binary file modified docs/_static/videos/breakable-bottles.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/deep-sea-treasure-concave.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/deep-sea-treasure.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/four-room.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/_static/videos/fruit-tree.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/minecart-deterministic.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/minecart.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-halfcheetah.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-hopper.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-lunar-lander.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-mountaincar.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-mountaincarcontinuous.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-reacher.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/mo-supermario.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/resource-gathering.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/videos/water-reservoir.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
39 changes: 20 additions & 19 deletions docs/environments/all-environments.md

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion mo_gymnasium/envs/deep_sea_treasure/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from gymnasium.envs.registration import register

from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import CONCAVE_MAP
from mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure import (
CONCAVE_MAP,
MIRRORED_MAP,
)


register(
Expand All @@ -15,3 +18,10 @@
max_episode_steps=100,
kwargs={"dst_map": CONCAVE_MAP},
)

# Mirrored variant: the same DeepSeaTreasure entry point, constructed with
# MIRRORED_MAP as its `dst_map` and capped at 100 steps per episode.
register(
id="deep-sea-treasure-mirrored-v0",
entry_point="mo_gymnasium.envs.deep_sea_treasure.deep_sea_treasure:DeepSeaTreasure",
max_episode_steps=100,
kwargs={"dst_map": MIRRORED_MAP},
)
52 changes: 44 additions & 8 deletions mo_gymnasium/envs/deep_sea_treasure/deep_sea_treasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@
np.array([124.0, -19]),
]

# Mirrored map, as in Felten et al. 2022: an 11-row x 20-column grid that is
# scored against the same Pareto front (PF) as the concave map, but is harder
# to explore. Positive entries are treasure values; the -10 entries are
# presumably the impassable sea-floor cells used by the other maps (TODO
# confirm against the step logic). Episodes on this map start at row 0,
# column 10 rather than in the top-left corner.
MIRRORED_MAP = np.array(
[
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, -10, -10, 2.0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, 3.0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, -10, -10, -10, -10, -10, -10, 5.0, 8.0, 16.0, 0, 0, 0, 0],
[0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0],
[0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0, 0, 0],
[0, 0, 0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 24.0, 50.0, 0, 0],
[0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 0, 0],
[0, 0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 74.0, 0],
[0, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, -10, 124.0],
]
)


class DeepSeaTreasure(gym.Env, EzPickle):
"""
Expand Down Expand Up @@ -96,7 +113,7 @@ class DeepSeaTreasure(gym.Env, EzPickle):
The episode terminates when the agent reaches a treasure.
## Arguments
- dst_map: the map of the deep sea treasure. Default is the convex map from Yang et al. (2019). To change, use `mo_gymnasium.make("DeepSeaTreasure-v0", dst_map=CONCAVE_MAP).`
- dst_map: the map of the deep sea treasure. Default is the convex map from Yang et al. (2019). To change, use e.g. `mo_gymnasium.make("DeepSeaTreasure-v0", dst_map=CONCAVE_MAP)` (or `dst_map=MIRRORED_MAP`).
- float_state: if True, the state is a 2D continuous box with values in [0.0, 1.0] for the x and y coordinates of the submarine.
## Credits
Expand All @@ -115,8 +132,18 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float

# The map of the deep sea treasure (convex version)
self.sea_map = dst_map
self._pareto_front = CONVEX_FRONT if np.all(dst_map == DEFAULT_MAP) else CONCAVE_FRONT
assert self.sea_map.shape == DEFAULT_MAP.shape, "The map's shape must be 11x11"
if dst_map.shape[0] == DEFAULT_MAP.shape[0] and dst_map.shape[1] == DEFAULT_MAP.shape[1]:
if np.all(dst_map == DEFAULT_MAP):
self.map_name = "convex"
elif np.all(dst_map == CONCAVE_MAP):
self.map_name = "concave"
else:
raise ValueError("Invalid map")
elif np.all(dst_map == MIRRORED_MAP):
self.map_name = "mirrored"
else:
raise ValueError("Invalid map")
self._pareto_front = CONVEX_FRONT if self.map_name == "convex" else CONCAVE_FRONT

self.dir = {
0: np.array([-1, 0], dtype=np.int32), # up
Expand All @@ -130,7 +157,7 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float
if self.float_state:
self.observation_space = Box(low=0.0, high=1.0, shape=(2,), dtype=obs_type)
else:
self.observation_space = Box(low=0, high=10, shape=(2,), dtype=obs_type)
self.observation_space = Box(low=0, high=len(self.sea_map[0]), shape=(2,), dtype=obs_type)

# action space specification: 1 dimension, 0 up, 1 down, 2 left, 3 right
self.action_space = Discrete(4)
Expand All @@ -144,11 +171,15 @@ def __init__(self, render_mode: Optional[str] = None, dst_map=DEFAULT_MAP, float
self.current_state = np.array([0, 0], dtype=np.int32)

# pygame
self.window_size = (min(64 * self.sea_map.shape[1], 512), min(64 * self.sea_map.shape[0], 512))
ratio = self.sea_map.shape[1] / self.sea_map.shape[0]
padding = 10
self.pix_inside = (min(64 * self.sea_map.shape[1], 512) * ratio, min(64 * self.sea_map.shape[0], 512))
# adding some padding on the sides
self.window_size = (self.pix_inside[0] + 2 * padding, self.pix_inside[1])
# The size of a single grid square in pixels
self.pix_square_size = (
self.window_size[1] // self.sea_map.shape[1] + 1,
self.window_size[0] // self.sea_map.shape[0] + 1,
self.pix_inside[0] // self.sea_map.shape[1] + 1,
self.pix_inside[1] // self.sea_map.shape[0] + 1, # watch out for axis inversions here
)
self.window = None
self.clock = None
Expand Down Expand Up @@ -257,7 +288,12 @@ def _get_state(self):
def reset(self, seed=None, **kwargs):
super().reset(seed=seed)

self.current_state = np.array([0, 0], dtype=np.int32)
if self.map_name == "convex" or self.map_name == "concave":
self.current_state = np.array([0, 0], dtype=np.int32)
elif self.map_name == "mirrored":
self.current_state = np.array([0, 10], dtype=np.int32)
else:
raise ValueError("Invalid map")
self.step_count = 0.0
state = self._get_state()
if self.render_mode == "human":
Expand Down
Binary file added mo_gymnasium/envs/fruit_tree/assets/agent.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added mo_gymnasium/envs/fruit_tree/assets/node_blue.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
155 changes: 149 additions & 6 deletions mo_gymnasium/envs/fruit_tree/fruit_tree.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Environment from https://github.com/RunzheYang/MORL/blob/master/synthetic/envs/fruit_tree.py
from typing import List
from os import path
from typing import List, Optional

import gymnasium as gym
import numpy as np
import pygame
from gymnasium import spaces
from gymnasium.utils import EzPickle

Expand Down Expand Up @@ -264,16 +266,16 @@ class FruitTreeEnv(gym.Env, EzPickle):
The episode terminates when the agent reaches a leaf node.
"""

def __init__(self, depth=6):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

def __init__(self, depth=6, render_mode: Optional[str] = None):
assert depth in [5, 6, 7], "Depth must be 5, 6 or 7."
EzPickle.__init__(self, depth)

self.render_mode = render_mode
self.reward_dim = 6
self.tree_depth = depth # zero based depth
branches = np.zeros((int(2**self.tree_depth - 1), self.reward_dim))
# fruits = np.random.randn(2**self.tree_depth, self.reward_dim)
# fruits = np.abs(fruits) / np.linalg.norm(fruits, 2, 1, True)
# print(fruits*10)
fruits = np.array(FRUITS[str(depth)])
self.tree = np.concatenate([branches, fruits])

Expand All @@ -288,9 +290,35 @@ def __init__(self, depth=6):
self.current_state = np.array([0, 0], dtype=np.int32)
self.terminal = False

# pygame
self.row_height = 20
self.top_margin = 15

# Add margin at the bottom to account for the node rewards
self.window_size = (1200, self.row_height * self.tree_depth + 150)
self.window_padding = 15 # padding on the left and right of the window
self.node_square_size = np.array([10, 10], dtype=np.int32)
self.font_size = 12
pygame.font.init()
self.font = pygame.font.SysFont(None, self.font_size)

self.window = None
self.clock = None
self.node_img = None
self.agent_img = None

def get_ind(self, pos):
    """Map a tree position to its flat index in the tree array.

    Args:
        pos: the state ``[row_ind, pos_in_row]``.

    Returns:
        The index of that node in the level-ordered tree array.
    """
    row_ind, pos_in_row = pos[0], pos[1]
    # Rows above `row_ind` hold 2**row_ind - 1 nodes in total.
    nodes_above = int(2**row_ind - 1)
    return nodes_above + pos_in_row

def ind_to_state(self, ind):
    """Map a flat index in the tree array back to a tree position.

    Args:
        ind: index of the node in the level-ordered tree array.

    Returns:
        The state ``[row_ind, pos_in_row]`` as an ``int32`` array.
    """
    row_ind = int(np.log2(ind + 1))
    # Index of the first node on that row; the remainder is the offset in the row.
    row_start = 2**row_ind - 1
    return np.array([row_ind, ind - row_start], dtype=np.int32)

def get_tree_value(self, pos):
    """Return the reward vector stored at tree position ``pos`` as a float32 array."""
    node_index = self.get_ind(pos)
    node_value = self.tree[node_index]
    return np.array(node_value, dtype=np.float32)

Expand Down Expand Up @@ -325,5 +353,120 @@ def step(self, action):
reward = self.get_tree_value(self.current_state)
if self.current_state[0] == self.tree_depth:
self.terminal = True

return self.current_state.copy(), reward, self.terminal, False, {}

def get_pos_in_window(self, row, index_in_row):
    """Compute the pixel position of a node inside the render window.

    Args:
        row: depth of the node in the tree (0 is the root).
        index_in_row: position of the node within its row, counted from the left.

    Returns:
        A 2-element array ``[x, y]`` in pixels.
    """
    padding = self.window_padding
    usable_width = self.window_size[0] - 2 * padding
    # Each row splits the usable width evenly among its 2**row nodes;
    # the 0.5 shift centres a node inside its slot.
    slot_width = usable_width / (2 ** (row))
    x = padding + (index_in_row + 0.5) * slot_width
    y = row * self.row_height
    return np.array([x, y])

def render(self):
    """Render the fruit tree and the agent's current position.

    Draws every branch, then every node on top of the branches (the node the
    agent occupies uses a dedicated sprite), and prints the components of each
    leaf's reward vector underneath that leaf.

    Returns:
        ``None`` in ``"human"`` mode (the frame is shown in a pygame window and
        paced with ``metadata["render_fps"]``) and when no render mode was set
        (a warning is logged instead). In ``"rgb_array"`` mode, the drawn
        canvas as a ``(height, width, 3)`` array.
    """
    if self.render_mode is None:
        # Environments built without `make` have no spec; still emit the
        # warning instead of failing on it.
        env_id = self.spec.id if self.spec is not None else "<env_id>"
        gym.logger.warn(
            "You are calling render method without specifying render mode. "
            "You can specify the render_mode at initialization, "
            f'e.g. mo_gym.make("{env_id}", render_mode="rgb_array")'
        )
        return

    if self.clock is None and self.render_mode == "human":
        self.clock = pygame.time.Clock()

    # Lazily create the drawing target: a real display window in human mode,
    # an off-screen surface otherwise.
    if self.window is None:
        pygame.init()

        if self.render_mode == "human":
            pygame.display.init()
            pygame.display.set_caption("Fruit Tree")
            self.window = pygame.display.set_mode(self.window_size)
            self.clock.tick(self.metadata["render_fps"])
        else:
            self.window = pygame.Surface(self.window_size)

    # Lazily load and scale the sprites the first time a frame is drawn.
    assets_dir = path.join(path.dirname(__file__), "assets")
    if self.node_img is None:
        node_img = pygame.image.load(path.join(assets_dir, "node_blue.png"))
        node_img = pygame.transform.scale(node_img, self.node_square_size)
        self.node_img = pygame.transform.flip(node_img, flip_x=True, flip_y=False)

    if self.agent_img is None:
        agent_img = pygame.image.load(path.join(assets_dir, "agent.png"))
        self.agent_img = pygame.transform.scale(agent_img, self.node_square_size)

    canvas = pygame.Surface(self.window_size)
    canvas.fill((255, 255, 255))  # White

    # Resolve every node's tree coordinates and pixel position once; both
    # drawing passes below reuse them.
    layout = []
    for ind in range(len(self.tree)):
        row, index_in_row = self.ind_to_state(ind)
        layout.append((row, index_in_row, self.get_pos_in_window(row, index_in_row)))

    # First pass: branches. Lines run between sprite centres, hence the
    # half-sprite shift.
    branch_color = (90, 82, 85)
    half_square = self.node_square_size / 2
    for row, index_in_row, node_pos in layout:
        if row >= self.tree_depth:
            continue  # leaves have no children
        start = node_pos + half_square
        for child_index in (2 * index_in_row, 2 * index_in_row + 1):
            child_pos = self.get_pos_in_window(row + 1, child_index)
            pygame.draw.line(canvas, branch_color, start, child_pos + half_square, 1)

    # Second pass: nodes, drawn after the branches so sprites sit on top.
    agent_state = tuple(self.current_state)
    for ind, (row, index_in_row, node_pos) in enumerate(layout):
        if (row, index_in_row) == agent_state:
            img = self.agent_img
            font_color = (164, 0, 0)  # Red digits for agent node
        else:
            img = self.node_img
            # Alternate the text colour between neighbouring nodes.
            if ind % 2:
                font_color = (250, 128, 114)  # Salmon
            else:
                font_color = (45, 72, 101)  # Dark Blue

        canvas.blit(img, node_pos)

        # Print node values at the bottom of the tree
        if row == self.tree_depth:
            # Odd-indexed leaves are shifted down by half a line so that
            # neighbouring columns of numbers do not overlap.
            odd_nodes_values_offset = 0.5 * (ind % 2)
            for i, val in enumerate(self.tree[ind]):
                val_img = self.font.render(f"{val:.2f}", True, font_color)
                text_dy = (i + 1 + odd_nodes_values_offset) * 1.5 * self.font_size
                canvas.blit(val_img, node_pos + np.array([-5, text_dy]))

    # Compose the final frame: the canvas shifted down by the top margin.
    background = pygame.Surface(self.window_size)
    background.fill((255, 255, 255))  # White
    background.blit(canvas, (0, self.top_margin))

    self.window.blit(background, (0, 0))

    if self.render_mode == "human":
        pygame.event.pump()
        pygame.display.update()
        self.clock.tick(self.metadata["render_fps"])
    elif self.render_mode == "rgb_array":
        # NOTE(review): this returns the raw canvas, i.e. without the top
        # margin that the human-mode window shows - confirm this is intended.
        return np.transpose(np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2))


if __name__ == "__main__":
    # Manual demo: a random agent walks down the tree in a pygame window, forever.
    import time

    import mo_gymnasium as mo_gym

    demo_env = mo_gym.make("fruit-tree", depth=6, render_mode="human")
    demo_env.reset()
    while True:
        demo_env.render()
        random_action = demo_env.action_space.sample()
        _, _, terminal, truncated, _ = demo_env.step(random_action)
        if not (terminal or truncated):
            continue
        # Episode over: show the final node for two seconds, then start again.
        demo_env.render()
        time.sleep(2)
        demo_env.reset()
4 changes: 2 additions & 2 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def test_env_determinism_rollout(env_spec: EnvSpec):
if env_spec.nondeterministic is True:
return

env_1 = env_spec.make(disable_env_checker=True)
env_2 = env_spec.make(disable_env_checker=True)
env_1 = mo_gym.make(env_spec.id)
env_2 = mo_gym.make(env_spec.id)
env_1 = mo_gym.LinearReward(env_1)
env_2 = mo_gym.LinearReward(env_2)

Expand Down

0 comments on commit 3014a87

Please sign in to comment.