diff --git a/sot/opcode_translator/executor/dispatcher.py b/sot/opcode_translator/executor/dispatcher.py index b85ff34fb..1f8fd64e8 100644 --- a/sot/opcode_translator/executor/dispatcher.py +++ b/sot/opcode_translator/executor/dispatcher.py @@ -193,6 +193,7 @@ class Dispatcher: handlers: dict[ Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]] ] = {} + graph: Any = None @classmethod def register( diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 581bd862a..d210d8bca 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -35,6 +35,7 @@ operator_in, operator_not_in, ) +from .dispatcher import Dispatcher from .function_graph import FunctionGraph from .guard import Guard from .instr_flag import FORMAT_VALUE_FLAG as FV @@ -268,7 +269,7 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None: except Exception as e: raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e finally: - simulator._graph.pycode_gen = None + simulator.cleanup() def tos_op_wrapper(fn: Callable): @@ -1511,6 +1512,11 @@ def __init__(self, frame: types.FrameType, **kwargs): self._name = "Executor" self.call_stack[:] = [] super().__init__(frame.f_code, graph) + Dispatcher.graph = graph + + def cleanup(self): + self._graph.pycode_gen = None + Dispatcher.graph = None @event_register("OpcodeExecutor: _prepare_virtual_env", event_level=2) def _prepare_virtual_env(self): diff --git a/sot/opcode_translator/executor/variable_dispatch.py b/sot/opcode_translator/executor/variable_dispatch.py index 580d10ed3..6db9d26a8 100644 --- a/sot/opcode_translator/executor/variable_dispatch.py +++ b/sot/opcode_translator/executor/variable_dispatch.py @@ -108,6 +108,15 @@ def inner(*args, **kwargs): ) # dict +Dispatcher.register( + dict, + (), + lambda: VariableFactory.from_value( + {}, + graph=Dispatcher.graph, + tracker=DummyTracker([]), + ), +) Dispatcher.register( dict.get, ("DictVariable", "ConstantVariable", optional("VariableBase")), @@ -196,6 +205,16 @@ def inner(*args, **kwargs): ) # list +Dispatcher.register( + list, + (), + lambda: VariableFactory.from_value( + [], + graph=Dispatcher.graph, + tracker=DummyTracker([]), + ), +) + Dispatcher.register( list, ("ContainerVariable | EnumerateVariable",), @@ -399,16 +418,6 @@ def dispatch_reversed(var: ContainerVariable): ("ContainerVariable",), lambda var: var.bool(), ) -Dispatcher.register( - bool, - ("ConstantVariable",), - lambda var: var.bool(), -) -Dispatcher.register( - operator.truth, - ("ContainerVariable",), - lambda var: var.bool(), -) Dispatcher.register( operator.truth, ("ConstantVariable",), diff --git a/tests/test_04_list.py b/tests/test_04_list.py index 7e9690a44..89922924f 100644 --- a/tests/test_04_list.py +++ b/tests/test_04_list.py @@ -181,7 +181,15 @@ def list_tensor_min_api(x: paddle.Tensor): return x.min() -class TestExecutor(TestCaseBase): +def list_no_arguments(): + l1 = list() # noqa: C408 + l1.append(1) + l2 = list() # noqa: C408 + l2.append(2) + return l1[0] + l2[0] + + +class TestList(TestCaseBase): def test_simple(self): self.assert_results(list_getitem_int, 1, paddle.to_tensor(2)) self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2)) @@ -258,6 +266,9 @@ def test_simple(self): ) self.assert_results(list_tensor_min_api, paddle.to_tensor([1, 2, 3])) + def test_list_noargs(self): + self.assert_results(list_no_arguments) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_05_dict.py b/tests/test_05_dict.py index 622ce3ed6..742cfd822 100644 --- a/tests/test_05_dict.py +++ b/tests/test_05_dict.py @@ -140,7 +140,15 @@ def dict_construct_from_comprehension(): return d -class TestExecutor(TestCaseBase): +def dict_no_arguments(): + d1 = dict() # noqa: C408 + d1.update({1: 2}) + d2 = dict() # noqa: C408 + d2.update({3: 4}) + return d1[1] + d2[3] + + +class TestDict(TestCaseBase): def test_build_map(self): self.assert_results(build_map, 1, paddle.to_tensor(2)) @@ -199,6 +207,9 @@ def test_construct(self): self.assert_results(dict_construct_from_tuple) self.assert_results(dict_construct_from_comprehension) + def test_dict_noargs(self): + self.assert_results(dict_no_arguments) + if __name__ == "__main__": unittest.main()