Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add __setattr__ to wrappers. Fixes #1176 #1180

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion pettingzoo/utils/wrappers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,49 @@ class BaseWrapper(AECEnv[AgentID, ObsType, ActionType]):
All AECEnv wrappers should inherit from this base class
"""

# This is a list of object variables (as strings), used by THIS wrapper,
# which should be stored by the wrapper object and not by the underlying
# environment. They are used to store information that the wrapper needs
# to behave correctly. The list is used by __setattr__() to determine where
# to store variables. It is very important that this list is correct to
# prevent confusing bugs.
# Wrappers inheriting from this class should include their own _local_vars
# list with object variables used by that class. Note that 'env' is hardcoded
# as part of the __setattr__ function so should not be included.
_local_vars = []

def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
super().__init__()
self.env = env

def __getattr__(self, name: str) -> Any:
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name.startswith("_") and name != "_cumulative_rewards":
if name.startswith("_") and name not in [
"_cumulative_rewards",
"_skip_agent_selection",
]:
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
return getattr(self.env, name)

def __setattr__(self, name: str, value: Any) -> None:
"""Set attribute ``name`` if it is this class's value, otherwise send to env."""
# these are the attributes that can be set on this wrapper directly
if name == "env" or name in self._local_vars:
self.__dict__[name] = value
else:
# If this is being raised by your wrapper while you are trying to access
# a variable that is owned by the wrapper and NOT part of the env, you
# may have forgotten to add the variable to the _local_vars list.
if name.startswith("_") and name not in [
"_cumulative_rewards",
"_skip_agent_selection",
]:
raise AttributeError(
f"setting private attribute '{name}' is prohibited"
)
# send to the underlying environment to handle
setattr(self.__dict__["env"], name, value)

@property
def unwrapped(self) -> AECEnv:
return self.env.unwrapped
Expand Down
2 changes: 2 additions & 0 deletions pettingzoo/utils/wrappers/multi_episode_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class MultiEpisodeEnv(BaseWrapper):
The result of this wrapper is that the environment is no longer Markovian around the environment reset.
"""

_local_vars = ["_num_episodes", "_episodes_elapsed", "_seed", "_options"]

def __init__(self, env: AECEnv, num_episodes: int):
"""__init__.

Expand Down
4 changes: 3 additions & 1 deletion pettingzoo/utils/wrappers/order_enforcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
* warn on calling step after environment is terminated or truncated
"""

_local_vars = ["_has_reset", "_has_rendered", "_has_updated"]

def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]):
assert isinstance(
env, AECEnv
), "OrderEnforcingWrapper is only compatible with AEC environments"
super().__init__(env)
self._has_reset = False
self._has_rendered = False
self._has_updated = False
super().__init__(env)

def __getattr__(self, value: str) -> Any:
"""Raises an error message when data is gotten from the env.
Expand Down
2 changes: 2 additions & 0 deletions pettingzoo/utils/wrappers/terminate_illegal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class TerminateIllegalWrapper(BaseWrapper[AgentID, ObsType, ActionType]):
illegal_reward: number that is the value of the player making an illegal move.
"""

_local_vars = ["_prev_obs", "_prev_info", "_terminated", "_illegal_value"]

def __init__(
self, env: AECEnv[AgentID, ObsType, ActionType], illegal_reward: float
):
Expand Down
138 changes: 137 additions & 1 deletion test/wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv
from pettingzoo.utils.env import AECEnv
from pettingzoo.utils.wrappers import (
BaseWrapper,
MultiEpisodeEnv,
MultiEpisodeParallelEnv,
)


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
Expand Down Expand Up @@ -67,3 +72,134 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None:
assert (
steps == num_episodes * 125
), f"Expected to have 125 steps per episode, got {steps / num_episodes}."


class FakeEnv(AECEnv):
"""Fake environment used by the getattr and setattr tests."""

def __init__(self):
self.public_value: int = 123
self._private_value: int = 456
self.agents = ["a1, a2"]
self.terminations = {agent: True for agent in self.agents}
self.agent_selection = self.agents[0]
self._name = "env" # should never be used

def compare_private(self, value: int) -> bool:
"""Return comparison of value with _private_value."""
return self._private_value == value


class FakeWrapper(BaseWrapper):
"""Fake wrapper used by the getattr and setattr tests."""

# these variables should be settable
_local_vars = ["wrapper_variable", "_private_wrapper_variable"]

def __init__(self, env: FakeEnv):
super().__init__(env)
# bypass __setattr__ so we have a private variable that is not in
# the _local_vars list. We should be able to access this.
self.__dict__["_name"] = "wrapper"


def test_wrapper_getattr() -> None:
"""Test that the base wrapper's __getattr__ works correctly.

Public variables of the env can be accessed from the wrapper.
Private variables cannot and will raise an AttributeError.
"""
wrapped = FakeWrapper(FakeEnv())

# Public values: fall through the the base env
expected_public = wrapped.env.public_value # can access directly from env
assert (
wrapped.public_value == expected_public
), "Wrapper can't access public env value"

# Private values: trying to access should trigger an AttributeError
expected_private = wrapped.env._private_value # can access directly from env
with pytest.raises(AttributeError):
result = wrapped._private_value == expected_private

# Meanwhile, calling an env function that does the same thing
# should be fine because the the function is delegated to the env.
result = wrapped.compare_private(expected_private)

# Wrapper should not set any default value when trying to access a variable
# that is not defined in the env or wrapper. It should trigger an AttributeError
with pytest.raises(AttributeError):
result = wrapped.nonexistant_value

# However, should be able to intentionally assign a default value when
# using getattr, even with a private variable.
# Note: this works because the attempt to access _private_value
# raises a new AttributeError from __getattr__ that causes getattr
# to return the given default value.
default = wrapped.env._private_value + 1 # ensure default is different
result = getattr(wrapped, "_private_value", default)
assert result == default, "Default value not set correctly"

# Should be able to get any private variables owned by the wrapper,
# even if not defined in _local_vars.
# Note: This is not a design choice, it's a consequence the implementation.
# FakeWrapper has _name defined on itself, but not listed in _local_vars.
assert wrapped._name == "wrapper"


def test_wrapper_setattr() -> None:
"""Test that wrapper's setattr works properly.

It should pass everything that isn't in _local_vars through to the
base environment. Everything in _local vars should be stored in the
wrapper object and not be part of the base environment.
"""
wrapped = FakeWrapper(FakeEnv())

# Having the wrapper directly set an env's public variable should:
# 1) change the value in the env and 2) not set it in the wrapper.
target_value = wrapped.public_value + 1 # ensure new value is different
wrapped.public_value = target_value
assert (
wrapped.env.public_value == target_value
), "Wrapper didn't correctly set env value"
assert "public_value" not in wrapped.__dict__, "Wrapper set value in wrong place"

# Setting env's private value should only be allowed by the env.
# Trying to directly do so from the wrapper should raise an AttributeError
with pytest.raises(AttributeError):
wrapped._private_value = target_value

# Should work normally when accessed from the env
wrapped.env._private_value = target_value

# AECEnv._deads_step_first() currently sets _skip_agent_selection and
# agent_selection. These should both be dispatched to the env, not set
# on the wrapper.
wrapped._deads_step_first()
assert "_skip_agent_selection" in wrapped.env.__dict__, "Value not set on env"
assert "agent_selection" in wrapped.env.__dict__, "Value not set on env"
assert (
"_skip_agent_selection" not in wrapped.__dict__
), "Wrapper set value in wrong place"
assert "agent_selection" not in wrapped.__dict__, "Wrapper set value in wrong place"

# All values in _local_vars that are set should go to the wrapper and
# not the env, regardless of whether they are private or not
for name in wrapped._local_vars:
# should not be in either before being set
assert (
name not in wrapped.__dict__
), "test logic failure: variable should not be set"
assert (
name not in wrapped.env.__dict__
), "test logic failure: variable should not be set"
setattr(wrapped, name, 1)
assert name in wrapped.__dict__, "local wrapper value not set"
assert name not in wrapped.env.__dict__, "local wrapper value set on env"

# Not able to set any private variables, even if owned by the
# wrapper, unless they are listed in _local_vars.
# FakeWrapper has _name defined on itself, but not listed in _local_vars.
with pytest.raises(AttributeError):
wrapped._name = "changed wrapper"
Loading