From c082b23e69a3289b529c69cdb71be4b303d8fa34 Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Sat, 10 Feb 2024 15:18:54 +0000 Subject: [PATCH 1/2] Add __setattr__ to wrappers. Fixes #1176 Adding this ensures that variables get set to the appropriate location. By default, any public value set by the wrapper is sent to the env to be set there instead of on the wrapper. If a variable is meant to be set by the wrapper, it should be listed in the _local_vars class variable of the wrapper. This is not ideal, but seems to be the most reasonable design. An example of needing to specify which vars to keep locally is here: https://python-patterns.guide/gang-of-four/decorator-pattern/#implementing-dynamic-wrapper The solution is to list which vars should be in the wrapper and check them when setting a value. That is the approach used in this commit, but more generalized. In line with __getattr__, private values cannot be set on underlying envs. There are two exceptions: _cumulative_rewards was previously exempted in __getattr__ because it is used by many envs. _skip_agent_selection is added because is used byt the dead step handling. If a wrapper can't set this, that functionality will break. --- pettingzoo/utils/wrappers/base.py | 35 ++++++++++++++++++- .../utils/wrappers/multi_episode_env.py | 2 ++ pettingzoo/utils/wrappers/order_enforcing.py | 4 ++- .../utils/wrappers/terminate_illegal.py | 2 ++ 4 files changed, 41 insertions(+), 2 deletions(-) diff --git a/pettingzoo/utils/wrappers/base.py b/pettingzoo/utils/wrappers/base.py index cea324189..41769bc8c 100644 --- a/pettingzoo/utils/wrappers/base.py +++ b/pettingzoo/utils/wrappers/base.py @@ -14,16 +14,49 @@ class BaseWrapper(AECEnv[AgentID, ObsType, ActionType]): All AECEnv wrappers should inherit from this base class """ + # This is a list of object variables (as strings), used by THIS wrapper, + # which should be stored by the wrapper object and not by the underlying + # environment. They are used to store information that the wrapper needs + # to behave correctly. The list is used by __setattr__() to determine where + # to store variables. It is very important that this list is correct to + # prevent confusing bugs. + # Wrappers inheriting from this class should include their own _local_vars + # list with object variables used by that class. Note that 'env' is hardcoded + # as part of the __setattr__ function so should not be included. + _local_vars = [] + def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): super().__init__() self.env = env def __getattr__(self, name: str) -> Any: """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" - if name.startswith("_") and name != "_cumulative_rewards": + if name.startswith("_") and name not in [ + "_cumulative_rewards", + "_skip_agent_selection", + ]: raise AttributeError(f"accessing private attribute '{name}' is prohibited") return getattr(self.env, name) + def __setattr__(self, name: str, value: Any) -> None: + """Set attribute ``name`` if it is this class's value, otherwise send to env.""" + # these are the attributes that can be set on this wrapper directly + if name == "env" or name in self._local_vars: + self.__dict__[name] = value + else: + # If this is being raised by your wrapper while you are trying to access + # a variable that is owned by the wrapper and NOT part of the env, you + # may have forgotten to add the variable to the _local_vars list. + if name.startswith("_") and name not in [ + "_cumulative_rewards", + "_skip_agent_selection", + ]: + raise AttributeError( + f"setting private attribute '{name}' is prohibited" + ) + # send to the underlying environment to handle + setattr(self.__dict__["env"], name, value) + @property def unwrapped(self) -> AECEnv: return self.env.unwrapped diff --git a/pettingzoo/utils/wrappers/multi_episode_env.py b/pettingzoo/utils/wrappers/multi_episode_env.py index 9c233f0c9..a00c5ba00 100644 --- a/pettingzoo/utils/wrappers/multi_episode_env.py +++ b/pettingzoo/utils/wrappers/multi_episode_env.py @@ -15,6 +15,8 @@ class MultiEpisodeEnv(BaseWrapper): The result of this wrapper is that the environment is no longer Markovian around the environment reset. """ + _local_vars = ["_num_episodes", "_episodes_elapsed", "_seed", "_options"] + def __init__(self, env: AECEnv, num_episodes: int): """__init__. diff --git a/pettingzoo/utils/wrappers/order_enforcing.py b/pettingzoo/utils/wrappers/order_enforcing.py index 649c23caa..016ce536f 100644 --- a/pettingzoo/utils/wrappers/order_enforcing.py +++ b/pettingzoo/utils/wrappers/order_enforcing.py @@ -26,14 +26,16 @@ class OrderEnforcingWrapper(BaseWrapper[AgentID, ObsType, ActionType]): * warn on calling step after environment is terminated or truncated """ + _local_vars = ["_has_reset", "_has_rendered", "_has_updated"] + def __init__(self, env: AECEnv[AgentID, ObsType, ActionType]): assert isinstance( env, AECEnv ), "OrderEnforcingWrapper is only compatible with AEC environments" + super().__init__(env) self._has_reset = False self._has_rendered = False self._has_updated = False - super().__init__(env) def __getattr__(self, value: str) -> Any: """Raises an error message when data is gotten from the env. diff --git a/pettingzoo/utils/wrappers/terminate_illegal.py b/pettingzoo/utils/wrappers/terminate_illegal.py index a49d9a0be..281b456bb 100644 --- a/pettingzoo/utils/wrappers/terminate_illegal.py +++ b/pettingzoo/utils/wrappers/terminate_illegal.py @@ -13,6 +13,8 @@ class TerminateIllegalWrapper(BaseWrapper[AgentID, ObsType, ActionType]): illegal_reward: number that is the value of the player making an illegal move. """ + _local_vars = ["_prev_obs", "_prev_info", "_terminated", "_illegal_value"] + def __init__( self, env: AECEnv[AgentID, ObsType, ActionType], illegal_reward: float ): From b6fe82ff489e360fec51c208cc61680ea2c5cacc Mon Sep 17 00:00:00 2001 From: David Ackerman <145808634+dm-ackerman@users.noreply.github.com> Date: Sat, 10 Feb 2024 15:19:55 +0000 Subject: [PATCH 2/2] Add tests to validate fix of 1176 These check that the wrapper functions of __getattr__ and __setattr__ work as intended --- test/wrapper_test.py | 138 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 1 deletion(-) diff --git a/test/wrapper_test.py b/test/wrapper_test.py index 650fe328b..e5eb24cf1 100644 --- a/test/wrapper_test.py +++ b/test/wrapper_test.py @@ -4,7 +4,12 @@ from pettingzoo.butterfly import pistonball_v6 from pettingzoo.classic import texas_holdem_no_limit_v6 -from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv +from pettingzoo.utils.env import AECEnv +from pettingzoo.utils.wrappers import ( + BaseWrapper, + MultiEpisodeEnv, + MultiEpisodeParallelEnv, +) @pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6]) @@ -67,3 +72,134 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None: assert ( steps == num_episodes * 125 ), f"Expected to have 125 steps per episode, got {steps / num_episodes}." + + +class FakeEnv(AECEnv): + """Fake environment used by the getattr and setattr tests.""" + + def __init__(self): + self.public_value: int = 123 + self._private_value: int = 456 + self.agents = ["a1, a2"] + self.terminations = {agent: True for agent in self.agents} + self.agent_selection = self.agents[0] + self._name = "env" # should never be used + + def compare_private(self, value: int) -> bool: + """Return comparison of value with _private_value.""" + return self._private_value == value + + +class FakeWrapper(BaseWrapper): + """Fake wrapper used by the getattr and setattr tests.""" + + # these variables should be settable + _local_vars = ["wrapper_variable", "_private_wrapper_variable"] + + def __init__(self, env: FakeEnv): + super().__init__(env) + # bypass __setattr__ so we have a private variable that is not in + # the _local_vars list. We should be able to access this. + self.__dict__["_name"] = "wrapper" + + +def test_wrapper_getattr() -> None: + """Test that the base wrapper's __getattr__ works correctly. + + Public variables of the env can be accessed from the wrapper. + Private variables cannot and will raise an AttributeError. + """ + wrapped = FakeWrapper(FakeEnv()) + + # Public values: fall through the the base env + expected_public = wrapped.env.public_value # can access directly from env + assert ( + wrapped.public_value == expected_public + ), "Wrapper can't access public env value" + + # Private values: trying to access should trigger an AttributeError + expected_private = wrapped.env._private_value # can access directly from env + with pytest.raises(AttributeError): + result = wrapped._private_value == expected_private + + # Meanwhile, calling an env function that does the same thing + # should be fine because the the function is delegated to the env. + result = wrapped.compare_private(expected_private) + + # Wrapper should not set any default value when trying to access a variable + # that is not defined in the env or wrapper. It should trigger an AttributeError + with pytest.raises(AttributeError): + result = wrapped.nonexistant_value + + # However, should be able to intentionally assign a default value when + # using getattr, even with a private variable. + # Note: this works because the attempt to access _private_value + # raises a new AttributeError from __getattr__ that causes getattr + # to return the given default value. + default = wrapped.env._private_value + 1 # ensure default is different + result = getattr(wrapped, "_private_value", default) + assert result == default, "Default value not set correctly" + + # Should be able to get any private variables owned by the wrapper, + # even if not defined in _local_vars. + # Note: This is not a design choice, it's a consequence the implementation. + # FakeWrapper has _name defined on itself, but not listed in _local_vars. + assert wrapped._name == "wrapper" + + +def test_wrapper_setattr() -> None: + """Test that wrapper's setattr works properly. + + It should pass everything that isn't in _local_vars through to the + base environment. Everything in _local vars should be stored in the + wrapper object and not be part of the base environment. + """ + wrapped = FakeWrapper(FakeEnv()) + + # Having the wrapper directly set an env's public variable should: + # 1) change the value in the env and 2) not set it in the wrapper. + target_value = wrapped.public_value + 1 # ensure new value is different + wrapped.public_value = target_value + assert ( + wrapped.env.public_value == target_value + ), "Wrapper didn't correctly set env value" + assert "public_value" not in wrapped.__dict__, "Wrapper set value in wrong place" + + # Setting env's private value should only be allowed by the env. + # Trying to directly do so from the wrapper should raise an AttributeError + with pytest.raises(AttributeError): + wrapped._private_value = target_value + + # Should work normally when accessed from the env + wrapped.env._private_value = target_value + + # AECEnv._deads_step_first() currently sets _skip_agent_selection and + # agent_selection. These should both be dispatched to the env, not set + # on the wrapper. + wrapped._deads_step_first() + assert "_skip_agent_selection" in wrapped.env.__dict__, "Value not set on env" + assert "agent_selection" in wrapped.env.__dict__, "Value not set on env" + assert ( + "_skip_agent_selection" not in wrapped.__dict__ + ), "Wrapper set value in wrong place" + assert "agent_selection" not in wrapped.__dict__, "Wrapper set value in wrong place" + + # All values in _local_vars that are set should go to the wrapper and + # not the env, regardless of whether they are private or not + for name in wrapped._local_vars: + # should not be in either before being set + assert ( + name not in wrapped.__dict__ + ), "test logic failure: variable should not be set" + assert ( + name not in wrapped.env.__dict__ + ), "test logic failure: variable should not be set" + setattr(wrapped, name, 1) + assert name in wrapped.__dict__, "local wrapper value not set" + assert name not in wrapped.env.__dict__, "local wrapper value set on env" + + # Not able to set any private variables, even if owned by the + # wrapper, unless they are listed in _local_vars. + # FakeWrapper has _name defined on itself, but not listed in _local_vars. + with pytest.raises(AttributeError): + wrapped._name = "changed wrapper"