Skip to content

Commit

Permalink
Merge pull request #64 from emnigma/unsafe_classes_speedup
Browse files Browse the repository at this point in the history
Unsafe classes speedup
  • Loading branch information
gsvgit authored Jul 20, 2023
2 parents 261a529 + 7421933 commit a4ef30f
Show file tree
Hide file tree
Showing 15 changed files with 210 additions and 126 deletions.
38 changes: 22 additions & 16 deletions VSharp.ML.AIAgent/agent/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
from enum import Enum
from typing import Optional

from .unsafe_json import obj_from_dict

from common.game import GameMap, GameState, Reward
from dataclasses_json import config, dataclass_json
from config import FeatureConfig


class ClientMessageType(str, Enum):
Expand All @@ -16,28 +19,28 @@ class ClientMessageType(str, Enum):


@dataclass_json
@dataclass
@dataclass(slots=True)
class ClientMessageBody:
def type(self) -> ClientMessageType:
pass


@dataclass_json
@dataclass
@dataclass(slots=True)
class GetTrainMapsMessageBody(ClientMessageBody):
def type(self) -> ClientMessageType:
return ClientMessageType.GETTRAINMAPS


@dataclass_json
@dataclass
@dataclass(slots=True)
class GetValidationMapsMessageBody(ClientMessageBody):
def type(self) -> ClientMessageType:
return ClientMessageType.GETVALIDATIONMAPS


@dataclass_json
@dataclass
@dataclass(slots=True)
class StartMessageBody(ClientMessageBody):
MapId: int
StepsToPlay: int
Expand All @@ -47,7 +50,7 @@ def type(self) -> ClientMessageType:


@dataclass_json
@dataclass
@dataclass(slots=True)
class StepMessageBody(ClientMessageBody):
StateId: int
PredictedStateUsefulness: float
Expand All @@ -57,7 +60,7 @@ def type(self) -> ClientMessageType:


@dataclass_json
@dataclass
@dataclass(slots=True)
class ClientMessage:
MessageType: str = field(init=False)
MessageBody: ClientMessageBody = field(
Expand All @@ -73,7 +76,7 @@ def __post_init__(self):


@dataclass_json
@dataclass
@dataclass(slots=True)
class MapsMessageBody:
Maps: list[GameMap]

Expand All @@ -87,45 +90,48 @@ class ServerMessageType(str, Enum):


@dataclass_json
@dataclass
@dataclass(slots=True)
class ServerMessage:
MessageType: ServerMessageType

class DeserializationException(Exception):
pass

def from_json_handle(*args, expected, **kwargs):
def from_json_handle(data, expected):
if FeatureConfig.DISABLE_MESSAGE_CHECKS:
return obj_from_dict(json.loads(data))

try:
return expected.from_json(args[0], **kwargs)
return expected.from_json(data)
except Exception as e:
err_to_display = f"{type(e)} - {e}: tried to decode {expected}, got unmatched structure, registered to app.log under [ERROR] tag"
error = f"{type(e)} - {e}: tried to decode {expected}, got raw data: {json.dumps(json.loads(args[0]), indent=2)}"
error = f"{type(e)} - {e}: tried to decode {expected}, got raw data: {json.dumps(json.loads(data), indent=2)}"
logging.error(error)
raise ServerMessage.DeserializationException(err_to_display)


@dataclass
@dataclass(slots=True)
class GameStateServerMessage(ServerMessage):
MessageBody: GameState


@dataclass
@dataclass(slots=True)
class RewardServerMessage(ServerMessage):
MessageBody: Reward


@dataclass
@dataclass(slots=True)
class MapsServerMessage(ServerMessage):
MessageBody: MapsMessageBody


@dataclass
@dataclass(slots=True)
class GameOverServerMessageBody:
ActualCoverage: Optional[int]
TestsCount: int
ErrorsCount: int


@dataclass
@dataclass(slots=True)
class GameOverServerMessage(ServerMessage):
MessageBody: GameOverServerMessageBody
41 changes: 41 additions & 0 deletions VSharp.ML.AIAgent/agent/unsafe_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from dataclasses import asdict as dataclasses_asdict, is_dataclass
from typing import ClassVar, Protocol


class Dataclass(Protocol):
__dataclass_fields__: ClassVar[dict]


BareObject = type("BareObject", (), {})


def obj_from_dict(data: dict | list | str | float) -> BareObject:
if isinstance(data, list):
return [obj_from_dict(item) for item in data]
if isinstance(data, dict):
inner_dict = {
field_name: obj_from_dict(field_raw)
for field_name, field_raw in data.items()
}
bare_obj = BareObject()
bare_obj.__dict__ = inner_dict
return bare_obj
return data


def asdict(
data: BareObject | Dataclass | list | dict | str | float,
) -> dict | list | str | float:
if isinstance(data, BareObject):
return {asdict(k): asdict(v) for k, v in data.__dict__.items()}
elif is_dataclass(data):
try:
return data.to_json()
except TypeError:
return asdict(dataclasses_asdict(data))
elif isinstance(data, list):
return [asdict(item) for item in data]
elif isinstance(data, dict):
return {asdict(k): asdict(v) for k, v in data.items()}
else:
return data
4 changes: 3 additions & 1 deletion VSharp.ML.AIAgent/agent/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def switch_maps_type(

def send_all(message_body: ClientMessage):
for ws_string in websocket_strings:
with closing(websocket.create_connection(ws_string)) as ws:
with closing(
websocket.create_connection(ws_string, skip_utf8_validation=True)
) as ws:
ws.send(message_body.to_json())
ws.recv()

Expand Down
18 changes: 9 additions & 9 deletions VSharp.ML.AIAgent/common/game.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@


@dataclass_json
@dataclass
@dataclass(slots=True)
class StateHistoryElem:
GraphVertexId: int
NumOfVisits: int


@dataclass_json
@dataclass
@dataclass(slots=True)
class State:
Id: int
Position: int
Expand All @@ -28,7 +28,7 @@ def __hash__(self) -> int:


@dataclass_json
@dataclass
@dataclass(slots=True)
class GameMapVertex:
Uid: int
Id: int
Expand All @@ -41,29 +41,29 @@ class GameMapVertex:


@dataclass_json
@dataclass
@dataclass(slots=True)
class GameEdgeLabel:
Token: int


@dataclass_json
@dataclass
@dataclass(slots=True)
class GameMapEdge:
VertexFrom: int
VertexTo: int
Label: GameEdgeLabel


@dataclass_json
@dataclass
@dataclass(slots=True)
class GameState:
GraphVertices: list[GameMapVertex]
States: list[State]
Map: list[GameMapEdge]


@dataclass_json
@dataclass
@dataclass(slots=True)
class GameMap:
Id: int
MaxSteps: int
Expand All @@ -78,7 +78,7 @@ def __hash__(self) -> int:


@dataclass_json
@dataclass
@dataclass(slots=True)
class MoveReward:
ForCoverage: int
ForVisitedInstructions: int
Expand All @@ -103,7 +103,7 @@ def printable(self, verbose=False) -> str:


@dataclass_json
@dataclass
@dataclass(slots=True)
class Reward:
ForMove: MoveReward
MaxPossibleReward: int
4 changes: 4 additions & 0 deletions VSharp.ML.AIAgent/config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging

import ml.models


class GeneralConfig:
SERVER_COUNT = 8
MAX_STEPS = 3000
LOGGER_LEVEL = logging.INFO
MODEL_INIT = lambda: ml.models.SAGEConvModel(16)


class BrokerConfig:
Expand All @@ -20,3 +23,4 @@ class FeatureConfig:
SHOW_SUCCESSORS = True
NAME_LEN = 7
N_BEST_SAVED_EACH_GEN = 2
DISABLE_MESSAGE_CHECKS = True
13 changes: 11 additions & 2 deletions VSharp.ML.AIAgent/conn/classes.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from dataclasses import dataclass, field
from typing import Callable

from dataclasses_json import config, dataclass_json

from ml.model_wrappers.nnwrapper import NNWrapper, encode, decode
from agent.unsafe_json import asdict
from config import FeatureConfig
from ml.model_wrappers.nnwrapper import NNWrapper, decode, encode
from selection.classes import Map2Result


def custom_encoder_if_disable_message_checks() -> Callable | None:
return asdict if FeatureConfig.DISABLE_MESSAGE_CHECKS else None


@dataclass_json
@dataclass
class Agent2ResultsOnMaps:
agent: NNWrapper = field(metadata=config(encoder=encode, decoder=decode))
results: list[Map2Result]
results: list[Map2Result] = field(
metadata=config(encoder=custom_encoder_if_disable_message_checks())
)
2 changes: 1 addition & 1 deletion VSharp.ML.AIAgent/conn/socket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
@contextmanager
def game_server_socket_manager():
socket_url = aquire_ws()
socket = websocket.create_connection(socket_url)
socket = websocket.create_connection(socket_url, skip_utf8_validation=True)
try:
yield socket
finally:
Expand Down
8 changes: 4 additions & 4 deletions VSharp.ML.AIAgent/epochs_statistics/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from dataclasses import dataclass


@dataclass
@dataclass(slots=True)
class Name2ResultViewModel:
model_name: str
pretty_result: str


@dataclass
@dataclass(slots=True)
class Interval:
left: float
right: float
Expand All @@ -16,14 +16,14 @@ def pretty(self):
return f"{self.left:.2f}%-{self.right:.2f}%, diff={self.right - self.left:.2f}%"


@dataclass
@dataclass(slots=True)
class Name2Stats:
mutable_name: str
av_coverage: float
interval: Interval


@dataclass
@dataclass(slots=True)
class CoverageStats:
euc_dist2_full_cov: float
average_cov: float
Expand Down
Loading

0 comments on commit a4ef30f

Please sign in to comment.