Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Dispatcher] Implement parameterless call support for list and dict #326

Merged
merged 2 commits into from
Aug 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sot/opcode_translator/executor/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class Dispatcher:
handlers: dict[
Callable[..., Any], list[tuple[Pattern, Callable[..., Any]]]
] = {}
graph: Any = None

@classmethod
def register(
Expand Down
8 changes: 7 additions & 1 deletion sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
operator_in,
operator_not_in,
)
from .dispatcher import Dispatcher
from .function_graph import FunctionGraph
from .guard import Guard
from .instr_flag import FORMAT_VALUE_FLAG as FV
Expand Down Expand Up @@ -268,7 +269,7 @@ def start_translate(frame: types.FrameType, **kwargs) -> GuardedFunction | None:
except Exception as e:
raise InnerError(OpcodeExecutorBase.error_message_summary(e)) from e
finally:
simulator._graph.pycode_gen = None
simulator.cleanup()


def tos_op_wrapper(fn: Callable):
Expand Down Expand Up @@ -1511,6 +1512,11 @@ def __init__(self, frame: types.FrameType, **kwargs):
self._name = "Executor"
self.call_stack[:] = []
super().__init__(frame.f_code, graph)
Dispatcher.graph = graph

def cleanup(self):
self._graph.pycode_gen = None
Dispatcher.graph = None

@event_register("OpcodeExecutor: _prepare_virtual_env", event_level=2)
def _prepare_virtual_env(self):
Expand Down
29 changes: 19 additions & 10 deletions sot/opcode_translator/executor/variable_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ def inner(*args, **kwargs):
)

# dict
Dispatcher.register(
dict,
(),
lambda: VariableFactory.from_value(
{},
graph=Dispatcher.graph,
tracker=DummyTracker([]),
),
)
Dispatcher.register(
dict.get,
("DictVariable", "ConstantVariable", optional("VariableBase")),
Expand Down Expand Up @@ -196,6 +205,16 @@ def inner(*args, **kwargs):
)

# list
Dispatcher.register(
list,
(),
lambda: VariableFactory.from_value(
[],
graph=Dispatcher.graph,
tracker=DummyTracker([]),
),
)

Dispatcher.register(
list,
("ContainerVariable | EnumerateVariable",),
Expand Down Expand Up @@ -399,16 +418,6 @@ def dispatch_reversed(var: ContainerVariable):
("ContainerVariable",),
lambda var: var.bool(),
)
Dispatcher.register(
bool,
("ConstantVariable",),
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ContainerVariable",),
lambda var: var.bool(),
)
Dispatcher.register(
operator.truth,
("ConstantVariable",),
Expand Down
13 changes: 12 additions & 1 deletion tests/test_04_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,15 @@ def list_tensor_min_api(x: paddle.Tensor):
return x.min()


class TestExecutor(TestCaseBase):
def list_no_arguments():
l1 = list() # noqa: C408
l1.append(1)
l2 = list() # noqa: C408
l2.append(2)
return l1[0] + l2[0]


class TestList(TestCaseBase):
def test_simple(self):
self.assert_results(list_getitem_int, 1, paddle.to_tensor(2))
self.assert_results(list_getitem_tensor, 1, paddle.to_tensor(2))
Expand Down Expand Up @@ -258,6 +266,9 @@ def test_simple(self):
)
self.assert_results(list_tensor_min_api, paddle.to_tensor([1, 2, 3]))

def test_list_noargs(self):
self.assert_results(list_no_arguments)


if __name__ == "__main__":
unittest.main()
13 changes: 12 additions & 1 deletion tests/test_05_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,15 @@ def dict_construct_from_comprehension():
return d


class TestExecutor(TestCaseBase):
def dict_no_arguments():
d1 = dict() # noqa: C408
d1.update({1: 2})
d2 = dict() # noqa: C408
d2.update({3: 4})
return d1[1] + d2[3]


class TestDict(TestCaseBase):
def test_build_map(self):
self.assert_results(build_map, 1, paddle.to_tensor(2))

Expand Down Expand Up @@ -199,6 +207,9 @@ def test_construct(self):
self.assert_results(dict_construct_from_tuple)
self.assert_results(dict_construct_from_comprehension)

def test_dict_noargs(self):
self.assert_results(dict_no_arguments)


if __name__ == "__main__":
unittest.main()