Super suit migration example #1091

Draft: wants to merge 2 commits into base: master
95 changes: 95 additions & 0 deletions pettingzoo/utils/wrappers/supersuit/basic_wrappers.py
@@ -0,0 +1,95 @@
from reward_lambda import reward_lambda_v0, AecRewardLambda
from observation_lambda import observation_lambda_v0, AecObservationLambda
from utils.basic_transforms import color_reduction
from typing import Literal, Any
from types import ModuleType
from pettingzoo import AECEnv, ParallelEnv
from gymnasium.spaces import Space
import numpy as np


def basic_obs_wrapper(env: AECEnv | ParallelEnv, module: ModuleType, param: Any) -> AecObservationLambda:
"""
Wrap an environment to modify its observation space and observations using a specified module and parameter.

This function takes an environment, a module, and a parameter, and creates a new environment with an observation
space and observations modified based on the provided module and parameter.

Parameters:

Member:
If you could have the docstrings be more similar to the syntax in conversions.py and the other wrappers in PettingZoo, that would be good. Also, I don't know the wrapper super well, but I think my_module and my_param may not be the correct names. In general, using GPT to write this stuff is pretty risky because it could very well just be completely made up, and I don't have the time to look through all of the specifics to ensure it's correct. If you can do so yourself and double-check, then that's great, but I'm hesitant to include too much detail because it could be incorrect. Look elsewhere throughout the repo to see how the formatting is done for the docstrings.

And for the example format, see chess.py, as we have an example using the >>> format, which gets tested under the doctests with pytest.

Contributor Author:
Ok, thanks for the advice. I'll fix the docstring format and write the docstrings from scratch.

Member:
Agree, definitely do not use the output of ChatGPT directly for docstrings.

- env (AECEnv | ParallelEnv): The environment to be wrapped.
- module (ModuleType): The module responsible for modifying the observation space and observations; it must provide check_param, change_obs_space, and change_observation.
- param (Any): The parameter passed to the module when modifying the observation space and observations.

Returns:
- AecObservationLambda: A wrapped environment that applies the observation space and observation modifications. #TODO fix this line

Example:
```python
modified_env = basic_obs_wrapper(original_env, my_module, my_param)
```
In the above example, `modified_env` is a new environment whose observation space and observations are modified
according to `my_module` and `my_param`.
"""

def change_space(space: Space): # Box?
module.check_param(space, param)
space = module.change_obs_space(space, param)
return space

def change_obs(obs: np.ndarray, obs_space: Space): # not sure about ndarray
return module.change_observation(obs, obs_space, param)

return observation_lambda_v0(env, change_obs, change_space)
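
The `module` argument is expected to expose the three hooks used above (`check_param`, `change_obs_space`, `change_observation`). A minimal hypothetical stand-in that satisfies this contract, purely for illustration; `normalize_to_unit` and everything inside it are assumptions, not part of this PR:

```python
import numpy as np
from gymnasium.spaces import Box


class normalize_to_unit:
    """Hypothetical transform: rescale Box observations into [0, 1]."""

    @staticmethod
    def check_param(space, param):
        # Validate that this transform can be applied to the given space.
        assert isinstance(space, Box), "normalize_to_unit only supports Box spaces"

    @staticmethod
    def change_obs_space(space, param):
        # The wrapped environment will report observations in [0, 1].
        return Box(low=0.0, high=1.0, shape=space.shape, dtype=np.float32)

    @staticmethod
    def change_observation(obs, obs_space, param):
        # Rescale the raw observation using the original space's bounds.
        span = np.where(obs_space.high > obs_space.low, obs_space.high - obs_space.low, 1.0)
        return ((obs - obs_space.low) / span).astype(np.float32)


# Usage sketch (original_env is a placeholder):
# normalized_env = basic_obs_wrapper(original_env, normalize_to_unit, param=None)
```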


def color_reduction_v0(env: AECEnv | ParallelEnv, mode: Literal["full", "R", "G", "B"] = "full") -> AecObservationLambda:
"""
Wrap an environment to perform color reduction on its observations.

This function takes an environment and an optional mode to specify the color reduction technique. It then creates
a new environment that performs color reduction on the observations based on the specified mode.

Parameters:
- env (AECEnv | ParallelEnv): The environment to be wrapped.
- mode (Literal["full", "R", "G", "B"], optional): The color reduction mode to apply (default is "full").
Valid modes are defined in the color_reduction module.

Returns:
- AecObservationLambda: A wrapped environment that applies color reduction to its observations. #TODO fix this line

Example:
```python
reduced_color_env = color_reduction_v0(original_env, mode="full")
```
In the above example, `reduced_color_env` is a new environment that applies full grayscale color reduction to its
observations.
"""

return basic_obs_wrapper(env, color_reduction, mode)
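
For the example format the reviewer asks for above, a hedged sketch of how this docstring example could be rewritten in the >>> doctest style; pistonball_v6 is only a stand-in environment, and the snippet has not been verified against the transform module in this diff:

```python
>>> from pettingzoo.butterfly import pistonball_v6
>>> env = pistonball_v6.env()
>>> env = color_reduction_v0(env, mode="full")  # full grayscale reduction
>>> env.reset(seed=42)
>>> observation, reward, termination, truncation, info = env.last()
```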


def clip_reward_v0(env: AECEnv | ParallelEnv, lower_bound: float = -1, upper_bound: float = 1) -> AecRewardLambda:
"""
Clip rewards in an environment using the specified lower and upper bounds.

This function applies a reward clipping transformation to an environment's rewards. It takes an environment and
two optional bounds: `lower_bound` and `upper_bound`. Any reward in the environment that falls below the
`lower_bound` will be set to `lower_bound`, and any reward that exceeds the `upper_bound` will be set to
`upper_bound`. Rewards within the specified range are left unchanged.

Parameters:
- env (AECEnv | ParallelEnv): The environment on which to apply the reward clipping.
- lower_bound (float, optional): The lower bound for clipping rewards (default is -1).
- upper_bound (float, optional): The upper bound for clipping rewards (default is 1).

Returns:
- AecRewardLambda: A reward transformation function that applies the specified reward clipping when called. #TODO fix this line

Example:
```python
clipped_env = clip_reward_v0(my_environment, lower_bound=-0.5, upper_bound=0.5)
```
In the above example, the rewards in `my_environment` will be clipped to the range [-0.5, 0.5].
"""

return reward_lambda_v0(env, lambda rew: max(min(rew, upper_bound), lower_bound))
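
The nested max(min(...)) above is a standard clamp; a quick sanity check of the same expression:

```python
def clamp(rew, lower_bound=-1.0, upper_bound=1.0):
    # Same expression passed to reward_lambda_v0 by clip_reward_v0 above.
    return max(min(rew, upper_bound), lower_bound)

assert clamp(0.3) == 0.3     # inside the range: unchanged
assert clamp(2.5) == 1.0     # above the upper bound: clipped down
assert clamp(-7.0) == -1.0   # below the lower bound: clipped up
```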
140 changes: 140 additions & 0 deletions pettingzoo/utils/wrappers/supersuit/observation_lambda.py
@@ -0,0 +1,140 @@
import functools
import numpy as np
from gymnasium.spaces import Box, Discrete
from utils.base_aec_wrapper import BaseWrapper
from typing import Callable
from pettingzoo import AECEnv, ParallelEnv
from pettingzoo.utils.env import ActionType, AgentID


class AecObservationLambda(BaseWrapper):
"""
A wrapper for AEC environments that allows the modification of observation spaces and observations.

Args:
env (AECEnv | ParallelEnv): The environment to be wrapped.
change_observation_fn (Callable): A function that modifies observations.
change_obs_space_fn (Callable, optional): A function that modifies observation spaces. Default is None.

Raises:
AssertionError: If `change_observation_fn` is not callable, or if `change_obs_space_fn` is provided and is not callable.

Note:
- The `change_observation_fn` should be a function that accepts observation data and optionally the observation space and agent ID as arguments and returns a modified observation.
- The `change_obs_space_fn` should be a function that accepts an old observation space and optionally the agent ID as arguments and returns a modified observation space.

Attributes:
change_observation_fn (Callable): The function used to modify observations.
change_obs_space_fn (Callable, optional): The function used to modify observation spaces.

Methods:
_modify_action(agent: AgentID, action: ActionType) -> ActionType:
Return the action unchanged (this wrapper does not modify actions).

_check_wrapper_params() -> None:
Check wrapper parameters for consistency.

observation_space(agent: AgentID) -> Box:
Get the modified observation space for a specific agent.

_modify_observation(agent: AgentID, observation: Box) -> Box:
Modify the observation using change_observation_fn.

"""
def __init__(self, env: AECEnv | ParallelEnv, change_observation_fn: Callable, change_obs_space_fn: Callable = None):
assert callable(
change_observation_fn
), "change_observation_fn needs to be a function. It is {}".format(
change_observation_fn
)
assert change_obs_space_fn is None or callable(
change_obs_space_fn
), "change_obs_space_fn needs to be a function. It is {}".format(
change_obs_space_fn
)

self.change_observation_fn = change_observation_fn
self.change_obs_space_fn = change_obs_space_fn

super().__init__(env)

if hasattr(self, "possible_agents"):
for agent in self.possible_agents:
# call any validation logic in this function
self.observation_space(agent)

def _modify_action(self, agent: AgentID, action: ActionType) -> ActionType:
"""
Modify the action.

Args:
agent (AgentID): The agent whose action is being processed.
action (ActionType): The original action.

Returns:
ActionType: The action, returned unchanged.
"""
return action

def _check_wrapper_params(self) -> None:
"""
Check wrapper parameters for consistency.

Raises:
AssertionError: If the provided parameters are inconsistent.
"""
if self.change_obs_space_fn is None and hasattr(self, "possible_agents"):
for agent in self.possible_agents:
assert isinstance(
self.observation_space(agent), Box
), "the observation_lambda_wrapper only allows the change_obs_space_fn argument to be optional for Box observation spaces"

@functools.lru_cache(maxsize=None)
def observation_space(self, agent: AgentID) -> Box:
"""
Get the modified observation space for a specific agent.

Args:
agent (AgentID): The agent for which to retrieve the observation space.

Returns:
Box: The modified observation space.
"""
if self.change_obs_space_fn is None:
space = self.env.observation_space(agent)
try:
trans_low = self.change_observation_fn(space.low, space, agent)
trans_high = self.change_observation_fn(space.high, space, agent)
except TypeError:
trans_low = self.change_observation_fn(space.low, space)
trans_high = self.change_observation_fn(space.high, space)
new_low = np.minimum(trans_low, trans_high)
new_high = np.maximum(trans_low, trans_high)

return Box(low=new_low, high=new_high, dtype=new_low.dtype)
else:
old_obs_space = self.env.observation_space(agent)
try:
return self.change_obs_space_fn(old_obs_space, agent)
except TypeError:
return self.change_obs_space_fn(old_obs_space)

def _modify_observation(self, agent: AgentID, observation: Box) -> Box:
"""
Modify the observation.

Args:
agent (AgentID): The agent for which to modify the observation.
observation (Box): The original observation.

Returns:
Box: The modified observation.
"""
old_obs_space = self.env.observation_space(agent)
try:
return self.change_observation_fn(observation, old_obs_space, agent)
except TypeError:
return self.change_observation_fn(observation, old_obs_space)


observation_lambda_v0 = AecObservationLambda
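
A hedged usage sketch of the alias above, assuming an AEC environment with Box observations (`some_aec_env` is a placeholder); both lambdas take two arguments and rely on the documented fallback that drops the agent ID:

```python
from gymnasium.spaces import Box

# Hypothetical: shift every observation up by one and widen the space to match.
shifted_env = observation_lambda_v0(
    some_aec_env,  # placeholder for any AEC environment with Box observations
    change_observation_fn=lambda obs, obs_space: obs + 1.0,
    change_obs_space_fn=lambda obs_space: Box(
        low=obs_space.low + 1.0,
        high=obs_space.high + 1.0,
        dtype=obs_space.dtype,
    ),
)
```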
92 changes: 92 additions & 0 deletions pettingzoo/utils/wrappers/supersuit/reward_lambda.py
@@ -0,0 +1,92 @@
from utils.base_aec_wrapper import PettingzooWrap
from utils.make_defaultdict import make_defaultdict
from typing import Callable
from pettingzoo import AECEnv, ParallelEnv
from pettingzoo.utils.env import ActionType


class AecRewardLambda(PettingzooWrap):
"""
A wrapper for AEC environments that allows the modification of rewards.

Args:
env (AECEnv | ParallelEnv): The environment to be wrapped.
change_reward_fn (Callable): A function that modifies rewards.

Raises:
AssertionError: If `change_reward_fn` is not callable.

Attributes:
_change_reward_fn (Callable): The function used to modify rewards.

Methods:
reset(seed: int = None, options: dict = None) -> None:
Reset the environment, applying the reward modification to initial rewards.

step(action: ActionType) -> None:
Take a step in the environment, applying the reward modification to the received rewards.

"""
def __init__(self, env: AECEnv | ParallelEnv, change_reward_fn: Callable):
assert callable(
change_reward_fn
), f"change_reward_fn needs to be a function. It is {change_reward_fn}"
self._change_reward_fn = change_reward_fn

super().__init__(env)

def _check_wrapper_params(self) -> None:
"""
Check wrapper parameters for consistency.

This method is currently empty and does not perform any checks.
"""
pass

def _modify_spaces(self) -> None:
"""
Modify the spaces of the wrapped environment.

This method is currently empty and does not modify the spaces.
"""
pass

def reset(self, seed: int = None, options: dict = None) -> None:
"""
Reset the environment, applying the reward modification to initial rewards.

Args:
seed (int, optional): A seed for environment randomization. Default is None.
options (dict, optional): Additional options for environment initialization. Default is None.
"""
super().reset(seed=seed, options=options)
self.rewards = {
agent: self._change_reward_fn(reward)
for agent, reward in self.rewards.items()
}
self.__cumulative_rewards = make_defaultdict({a: 0 for a in self.agents})
self._accumulate_rewards()

def step(self, action: ActionType) -> None:
"""
Take a step in the environment, applying the reward modification to the received rewards.

Args:
action (ActionType): The action to be taken in the environment.
"""
agent = self.env.agent_selection
super().step(action)
self.rewards = {
agent: self._change_reward_fn(reward)
for agent, reward in self.rewards.items()
}
self.__cumulative_rewards[agent] = 0
self._cumulative_rewards = self.__cumulative_rewards
self._accumulate_rewards()


reward_lambda_v0 = AecRewardLambda
""" example:
reward_lambda_v0 = WrapperChooser(
aec_wrapper=AecRewardLambda, par_wrapper=ParRewardLambda
)"""
57 changes: 57 additions & 0 deletions pettingzoo/utils/wrappers/supersuit/test/dummy_aec_env.py
@@ -0,0 +1,57 @@
from pettingzoo import AECEnv
from pettingzoo.utils.agent_selector import agent_selector


class DummyEnv(AECEnv):
metadata = {"render_modes": ["human"], "is_parallelizable": True}

def __init__(self, observations, observation_spaces, action_spaces):
super().__init__()
self._observations = observations
self._observation_spaces = observation_spaces

self.agents = sorted([x for x in observation_spaces.keys()])
self.possible_agents = self.agents[:]
self._agent_selector = agent_selector(self.agents)
self.agent_selection = self._agent_selector.reset()
self._action_spaces = action_spaces

self.steps = 0

def observation_space(self, agent):
return self._observation_spaces[agent]

def action_space(self, agent):
return self._action_spaces[agent]

def observe(self, agent):
return self._observations[agent]

def step(self, action, observe=True):
if (
self.terminations[self.agent_selection]
or self.truncations[self.agent_selection]
):
return self._was_dead_step(action)
self._cumulative_rewards[self.agent_selection] = 0
self.agent_selection = self._agent_selector.next()
self.steps += 1
if self.steps >= 5 * len(self.agents):
self.truncations = {a: True for a in self.agents}

self._accumulate_rewards()
self._deads_step_first()

def reset(self, seed=None, options=None):
self.agents = self.possible_agents[:]
self._agent_selector = agent_selector(self.agents)
self.agent_selection = self._agent_selector.reset()
self.rewards = {a: 1 for a in self.agents}
self._cumulative_rewards = {a: 0 for a in self.agents}
self.terminations = {a: False for a in self.agents}
self.truncations = {a: False for a in self.agents}
self.infos = {a: {} for a in self.agents}
self.steps = 0

def close(self):
pass
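
A hedged sketch of how this dummy could exercise one of the wrappers above. The import paths, agent names, and space sizes are assumptions, and it relies on the base wrapper proxying attributes such as rewards to the wrapped environment (as supersuit-style wrappers do), so treat it as a starting point rather than a verified test:

```python
import numpy as np
from gymnasium.spaces import Box, Discrete

# Assumed local import paths; adjust to the final package layout.
from dummy_aec_env import DummyEnv
from basic_wrappers import clip_reward_v0


def test_clip_reward_stays_in_bounds():
    agents = ["agent_0", "agent_1"]
    observations = {a: np.zeros((4, 4, 3), dtype=np.float32) for a in agents}
    observation_spaces = {a: Box(low=0.0, high=1.0, shape=(4, 4, 3)) for a in agents}
    action_spaces = {a: Discrete(3) for a in agents}

    env = clip_reward_v0(
        DummyEnv(observations, observation_spaces, action_spaces),
        lower_bound=-0.5,
        upper_bound=0.5,
    )
    env.reset()
    # DummyEnv hands every agent a reward of 1 on reset; after clipping,
    # each per-step reward should sit inside [-0.5, 0.5].
    assert all(-0.5 <= r <= 0.5 for r in env.rewards.values())
```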