Skip to content

Commit

Permalink
Fix minecart rgb observation
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasAlegre committed Feb 13, 2024
1 parent 3014a87 commit d6b4d16
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 22 deletions.
1 change: 1 addition & 0 deletions mo_gymnasium/envs/minecart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
id="minecart-rgb-v0",
entry_point="mo_gymnasium.envs.minecart.minecart:Minecart",
kwargs={"image_observation": True},
nondeterministic=True, # This is a nondeterministic environment due to the random placement of the mines
max_episode_steps=1000,
)

Expand Down
63 changes: 41 additions & 22 deletions mo_gymnasium/envs/minecart/minecart.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def __init__(

self.render_mode = render_mode
self.screen = None
self.canvas = None
self.clock = None
self.last_render_mode_used = None
self.config = config
self.frame_skip = frame_skip
Expand Down Expand Up @@ -422,9 +424,17 @@ def initialize_mines(self):
self.mine_rects = []
for mine in self.mines:
mine_sprite = pygame.sprite.Sprite()
mine_sprite.image = pygame.transform.rotozoom(
pygame.image.load(MINE_IMG), mine.rotation, MINE_SCALE
).convert_alpha()
# mine_sprite.image = pygame.transform.rotozoom(
# pygame.image.load(MINE_IMG), mine.rotation, MINE_SCALE,
# ).convert_alpha()
mine_sprite.image = pygame.image.load(MINE_IMG)
mine_sprite.image = pygame.transform.scale(
mine_sprite.image,
(int(mine_sprite.image.get_width() * MINE_SCALE), int(mine_sprite.image.get_height() * MINE_SCALE)),
)
mine_sprite.image = pygame.transform.rotate(mine_sprite.image, mine.rotation)
if self.render_mode == "human":
mine_sprite.image = mine_sprite.image.convert_alpha()
self.mine_sprites.add(mine_sprite)
mine_sprite.rect = mine_sprite.image.get_rect()
mine_sprite.rect.centerx = (mine.pos[0] * (1 - 2 * MARGIN)) * WIDTH + MARGIN * WIDTH
Expand Down Expand Up @@ -517,7 +527,7 @@ def get_pixels(self, update=True):
np.array -- array of pixels, with shape (width, height, channels)
"""
if update:
self.pixels = pygame.surfarray.array3d(self.screen)
self.pixels = np.transpose(np.array(pygame.surfarray.pixels3d(self.canvas)), axes=(1, 0, 2))

return self.pixels

Expand Down Expand Up @@ -564,7 +574,7 @@ def reset(self, seed=None, **kwargs):
"""
super().reset(seed=seed)

if self.screen is None and self.image_observation:
if self.canvas is None and self.image_observation:
self.render() # init pygame

if self.image_observation:
Expand All @@ -590,52 +600,61 @@ def __str__(self):
return string

def render(self):
if self.screen is None or self.last_render_mode_used != self.render_mode:
if self.canvas is None or self.last_render_mode_used != self.render_mode:
self.last_render_mode_used = self.render_mode
pygame.init()
self.screen = pygame.display.set_mode(
(WIDTH, HEIGHT),
flags=pygame.HIDDEN if self.render_mode == "rgb_array" else 0,
)
self.clock = pygame.time.Clock()
self.canvas = pygame.Surface((WIDTH, HEIGHT))
if self.render_mode == "human":
pygame.display.init()
self.screen = pygame.display.set_mode(
(WIDTH, HEIGHT),
)

if self.clock is None and self.render_mode == "human":
self.clock = pygame.time.Clock()

self.initialize_mines()

self.cart_sprite = pygame.sprite.Sprite()
self.cart_sprites = pygame.sprite.Group()
self.cart_sprites.add(self.cart_sprite)
self.cart_image = pygame.transform.rotozoom(pygame.image.load(CART_IMG).convert_alpha(), 0, CART_SCALE)
self.cart_image = pygame.image.load(CART_IMG)
if self.render_mode == "human":
self.cart_image = self.cart_image.convert_alpha()
self.cart_image = pygame.transform.scale(
self.cart_image,
(int(self.cart_image.get_width() * CART_SCALE), int(self.cart_image.get_height() * CART_SCALE)),
)

if not self.image_observation:
self.render_pygame() # if the obs is not an image, then step would not have rendered the screen

if self.render_mode == "human":
self.clock.tick(FPS)
self.screen.blit(self.canvas, self.canvas.get_rect())
pygame.event.pump()
pygame.display.update()
self.clock.tick(FPS)
elif self.render_mode == "rgb_array":
string_image = pygame.image.tostring(self.screen, "RGB")
temp_surf = pygame.image.fromstring(string_image, (WIDTH, HEIGHT), "RGB")
tmp_arr = pygame.surfarray.array3d(temp_surf)
return tmp_arr
return np.transpose(np.array(pygame.surfarray.pixels3d(self.canvas)), axes=(1, 0, 2))

def render_pygame(self):
pygame.event.get()

self.mine_sprites.update()

# Clear canvas
self.screen.fill(GRAY)
self.canvas.fill(GRAY)

# Draw Home
pygame.draw.circle(
self.screen,
self.canvas,
RED,
(int(WIDTH * HOME_X), int(HEIGHT * HOME_Y)),
int(WIDTH / 3 * BASE_SCALE),
)

# Draw Mines
self.mine_sprites.draw(self.screen)
self.mine_sprites.draw(self.canvas)

# Draw cart
self.cart_sprite.image = rot_center(self.cart_image, -self.cart.angle).copy()
Expand All @@ -647,7 +666,7 @@ def render_pygame(self):

self.cart_sprites.update()

self.cart_sprites.draw(self.screen)
self.cart_sprites.draw(self.canvas)

# Draw cart content
width = self.cart_sprite.rect.width / (2 * self.ore_cnt)
Expand All @@ -659,7 +678,7 @@ def render_pygame(self):

if rect_height >= 1:
pygame.draw.rect(
self.screen,
self.canvas,
self.ore_colors[i],
(
self.cart_sprite.rect.left + offset + i * (width + 1),
Expand Down

0 comments on commit d6b4d16

Please sign in to comment.