Skip to content

Commit

Permalink
Big refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma committed Jul 20, 2023
1 parent 007e307 commit dc549dd
Show file tree
Hide file tree
Showing 30 changed files with 127 additions and 123 deletions.
File renamed without changes.
20 changes: 19 additions & 1 deletion VSharp.ML.AIAgent/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
from .game import GameState
from collections import defaultdict

from common.classes import Agent2Result, AgentResultsOnGameMaps, GameMapsModelResults

from common.game import GameState


def get_states(game_state: GameState) -> set[int]:
return {s.Id for s in game_state.States}


def invert_mapping_mrgm_gmmr(
model_results_on_map: AgentResultsOnGameMaps,
) -> GameMapsModelResults:
inverse_mapping: GameMapsModelResults = defaultdict(list)

for named_agent, list_of_map_result_mappings in model_results_on_map.items():
for map_result_mapping in list_of_map_result_mappings:
map, result = (map_result_mapping.map, map_result_mapping.game_result)

inverse_mapping[map].append(Agent2Result(named_agent, result))

return inverse_mapping
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from dataclasses_json import config, dataclass_json

from agent.unsafe_json import asdict
from common.classes import Map2Result
from config import FeatureConfig
from connection.game_server_conn.unsafe_json import asdict
from ml.model_wrappers.nnwrapper import NNWrapper, decode, encode
from selection.classes import Map2Result


def custom_encoder_if_disable_message_checks() -> Callable | None:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from contextlib import contextmanager

import websocket

from .requests import aquire_ws, return_ws


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
)


class NAgent:
class WrongAgentStateError(Exception):
class Connector:
class WrongConnectorStateError(Exception):
def __init__(
self, source: str, received: str, expected: str, at_step: int
) -> None:
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(

def _raise_if_gameover(self, msg) -> GameOverServerMessage | str:
if self.game_is_over:
raise NAgent.GameOver
raise Connector.GameOver

matching_message_type = ServerMessage.from_json_handle(
msg, expected=ServerMessage
Expand All @@ -75,7 +75,7 @@ def _raise_if_gameover(self, msg) -> GameOverServerMessage | str:
)
self.game_is_over = True
logging.debug(f"--> {matching_message_type}")
raise NAgent.GameOver(
raise Connector.GameOver(
actual_coverage=deser_msg.MessageBody.ActualCoverage,
tests_count=deser_msg.MessageBody.TestsCount,
errors_count=deser_msg.MessageBody.ErrorsCount,
Expand Down Expand Up @@ -114,7 +114,7 @@ def recv_reward_or_throw_gameover(self) -> Reward:
def _process_reward_server_message(self, msg):
match msg.MessageType:
case ServerMessageType.INCORRECT_PREDICTED_STATEID:
raise NAgent.IncorrectSentStateError(
raise Connector.IncorrectSentStateError(
f"Sending state_id={self._sent_state_id} \
at step #{self._current_step} resulted in {msg.MessageType}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from enum import Enum
from typing import Optional

from .unsafe_json import obj_from_dict
from dataclasses_json import config, dataclass_json

from common.game import GameMap, GameState, Reward
from dataclasses_json import config, dataclass_json
from config import FeatureConfig

from .unsafe_json import obj_from_dict


class ClientMessageType(str, Enum):
START = "start"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import asdict as dataclasses_asdict, is_dataclass
from dataclasses import asdict as dataclasses_asdict
from dataclasses import is_dataclass
from typing import ClassVar, Protocol


Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion VSharp.ML.AIAgent/epochs_statistics/gen_stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from statistics import mean, median

from epochs_statistics.common import CoverageStats, Interval
from selection.classes import Map2Result
from common.classes import Map2Result


def euc_dist2full_coverage(data: list[float]) -> float:
Expand Down
4 changes: 2 additions & 2 deletions VSharp.ML.AIAgent/epochs_statistics/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@

import pandas as pd

from common.classes import Agent2Result, AgentResultsOnGameMaps
from common.strings import (
AV_COVERAGE_COL_NAME,
COV_DEVIATION_COL_NAME,
EUC_DIST2FULL_COV_COL_NAME,
MEDIAN_COVERAGE_COL_NAME,
)
from common.utils import invert_mapping_mrgm_gmmr
from config import FeatureConfig
from epochs_statistics.common import Interval, Name2ResultViewModel
from epochs_statistics.gen_stats import compute_euc_dist_to_full_coverage
from selection.classes import Agent2Result, AgentResultsOnGameMaps
from selection.utils import invert_mapping_mrgm_gmmr


def get_sample_val(d: dict):
Expand Down
6 changes: 3 additions & 3 deletions VSharp.ML.AIAgent/epochs_statistics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from shutil import rmtree

from common.constants import (
LEADERS_TABLES_LOG_FILE,
TABLES_LOG_FILE,
APP_LOG_FILE,
EPOCH_BEST_DIR,
BASE_REPORT_DIR,
EPOCH_BEST_DIR,
LEADERS_TABLES_LOG_FILE,
TABLES_LOG_FILE,
)


Expand Down
8 changes: 4 additions & 4 deletions VSharp.ML.AIAgent/learning/entry_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
init_log_file,
init_tables_file,
)
from selection.crossover_type import CrossoverType
from selection.mutation_type import MutationType
from selection.parent_selection_type import ParentSelectionType
from timer.resources_manager import manage_inference_stats
from learning.selection.crossover_type import CrossoverType
from learning.selection.mutation_type import MutationType
from learning.selection.parent_selection_type import ParentSelectionType
from learning.timer.resources_manager import manage_inference_stats

logging.basicConfig(
level=GeneralConfig.LOGGER_LEVEL,
Expand Down
20 changes: 10 additions & 10 deletions VSharp.ML.AIAgent/learning/genetic_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,27 @@
import pygad.torchga

import ml
from agent.utils import MapsType
from common.classes import AgentResultsOnGameMaps
from common.constants import DEVICE
from config import FeatureConfig, GeneralConfig
from conn.classes import Agent2ResultsOnMaps
from conn.requests import recv_game_result_list, send_game_results
from connection.broker_conn.classes import Agent2ResultsOnMaps
from connection.broker_conn.requests import recv_game_result_list, send_game_results
from connection.game_server_conn.utils import MapsType
from epochs_statistics.tables import create_pivot_table, table_to_csv, table_to_string
from epochs_statistics.utils import (
append_to_tables_file,
create_epoch_subdir,
rewrite_best_tables_file,
)
from ml.fileop import save_model
from ml.model_wrappers.nnwrapper import NNWrapper
from selection.classes import AgentResultsOnGameMaps
from selection.scorer import straight_scorer
from timer.stats import compute_statistics
from timer.utils import (
from learning.selection.scorer import straight_scorer
from learning.timer.stats import compute_statistics
from learning.timer.utils import (
create_temp_epoch_inference_dir,
dump_and_reset_epoch_times,
load_times_array,
)
from ml.fileop import save_model
from ml.model_wrappers.nnwrapper import NNWrapper

from .play_game import play_game

Expand Down Expand Up @@ -133,7 +133,7 @@ def fitness_function(ga_inst, solution, solution_idx) -> float:
nnwrapper = NNWrapper(model, weights_flat=solution)

list_of_map2result = play_game(
weighted_predictor=nnwrapper,
with_predictor=nnwrapper,
max_steps=GeneralConfig.MAX_STEPS,
maps_type=MapsType.TRAIN,
)
Expand Down
104 changes: 54 additions & 50 deletions VSharp.ML.AIAgent/learning/play_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,54 @@

import tqdm

from agent.n_agent import NAgent
from agent.utils import MapsType, get_maps
from common.classes import GameResult, Map2Result
from common.constants import TQDM_FORMAT_DICT
from common.utils import get_states
from config import FeatureConfig, GeneralConfig
from conn.socket_manager import game_server_socket_manager
from config import FeatureConfig
from connection.broker_conn.socket_manager import game_server_socket_manager
from connection.game_server_conn.connector import Connector
from connection.game_server_conn.utils import MapsType, get_maps
from learning.timer.resources_manager import manage_map_inference_times_array
from learning.timer.stats import compute_statistics
from learning.timer.utils import get_map_inference_times
from ml.fileop import save_model
from ml.model_wrappers.protocols import WeightedPredictor
from selection.classes import GameResult, Map2Result
from timer.resources_manager import manage_map_inference_times_array
from timer.stats import compute_statistics
from timer.utils import get_map_inference_times
from ml.model_wrappers.protocols import Predictor

TimeDuration: TypeAlias = float


def play_map(
with_agent: NAgent, with_weighted_model: WeightedPredictor
with_connector: Connector, with_predictor: Predictor
) -> tuple[GameResult, TimeDuration]:
steps_count = 0
game_state = None
actual_coverage = None
steps = with_agent.steps
steps = with_connector.steps

start_time = perf_counter()

try:
for _ in range(steps):
game_state = with_agent.recv_state_or_throw_gameover()
predicted_state_id = with_weighted_model.predict(game_state)
game_state = with_connector.recv_state_or_throw_gameover()
predicted_state_id = with_predictor.predict(game_state)
logging.debug(
f"<{with_weighted_model.name()}> step: {steps_count}, available states: {get_states(game_state)}, predicted: {predicted_state_id}"
f"<{with_predictor.name()}> step: {steps_count}, available states: {get_states(game_state)}, predicted: {predicted_state_id}"
)

with_agent.send_step(
with_connector.send_step(
next_state_id=predicted_state_id,
predicted_usefullness=42.0, # left it a constant for now
)

_ = with_agent.recv_reward_or_throw_gameover()
_ = with_connector.recv_reward_or_throw_gameover()
steps_count += 1

_ = with_agent.recv_state_or_throw_gameover() # wait for gameover
_ = with_connector.recv_state_or_throw_gameover() # wait for gameover
steps_count += 1
except NAgent.GameOver as gameover:
except Connector.GameOver as gameover:
if game_state is None:
logging.error(
f"<{with_weighted_model.name()}>: immediate GameOver on {with_agent.map.MapName}"
logging.warning(
f"<{with_predictor.name()}>: immediate GameOver on {with_connector.map.MapName}"
)
return GameResult(
steps_count=steps,
Expand All @@ -68,20 +68,9 @@ def play_map(

end_time = perf_counter()

if (
FeatureConfig.DUMP_BY_TIMEOUT.enabled
and end_time - start_time > FeatureConfig.DUMP_BY_TIMEOUT.timeout_seconds
):
save_model(
GeneralConfig.MODEL_INIT(),
to=FeatureConfig.DUMP_BY_TIMEOUT.save_path
/ f"{with_weighted_model.name()}.pth",
weights=with_weighted_model.weights(),
)

if actual_coverage != 100 and steps_count != steps:
logging.error(
f"<{with_weighted_model.name()}>: not all steps exshausted on {with_agent.map.MapName} with non-100% coverage"
logging.warning(
f"<{with_predictor.name()}>: not all steps exshausted on {with_connector.map.MapName} with non-100% coverage"
f"steps taken: {steps_count}, actual coverage: {actual_coverage:.2f}"
)
steps_count = steps
Expand All @@ -93,49 +82,64 @@ def play_map(
actual_coverage_percent=actual_coverage,
)

return model_result, end_time - start_time


def play_map_with_stats(
with_connector: Connector, with_predictor: Predictor
) -> tuple[GameResult, TimeDuration]:
model_result, time_duration = play_map(with_connector, with_predictor)

with manage_map_inference_times_array():
try:
map_inference_times = get_map_inference_times()
mean, std = compute_statistics(map_inference_times)
logging.info(
f"Inference stats for <{with_weighted_model.name()}> on {with_agent.map.MapName}: {mean=}ms, {std=}ms"
f"Inference stats for <{with_predictor.name()}> on {with_connector.map.MapName}: {mean=}ms, {std=}ms"
)
except StatisticsError:
logging.info(
f"<{with_weighted_model.name()}> on {with_agent.map.MapName}: too few samples for stats count"
f"<{with_predictor.name()}> on {with_connector.map.MapName}: too few samples for stats count"
)

return model_result, end_time - start_time
return model_result, time_duration


def play_game(
weighted_predictor: WeightedPredictor, max_steps: int, maps_type: MapsType
):
def play_game(with_predictor: Predictor, max_steps: int, maps_type: MapsType):
with game_server_socket_manager() as ws:
maps = get_maps(websocket=ws, type=maps_type)
with tqdm.tqdm(
total=len(maps),
desc=f"{weighted_predictor.name():20}: {maps_type.value}",
desc=f"{with_predictor.name():20}: {maps_type.value}",
**TQDM_FORMAT_DICT,
) as pbar:
rst: list[GameResult] = []
list_of_map2result: list[Map2Result] = []
for game_map in maps:
logging.info(
f"<{weighted_predictor.name()}> is playing {game_map.MapName}"
)
logging.info(f"<{with_predictor.name()}> is playing {game_map.MapName}")

game_result, time = play_map(
with_agent=NAgent(ws, game_map, max_steps),
with_weighted_model=weighted_predictor,
game_result, time = play_map_with_stats(
with_connector=Connector(ws, game_map, max_steps),
with_predictor=with_predictor,
)
rst.append(game_result)
list_of_map2result.append(Map2Result(game_map, game_result))

logging.info(
f"<{weighted_predictor.name()}> finished map {game_map.MapName} "
message = (
f"<{with_predictor.name()}> finished map {game_map.MapName} "
f"in {game_result.steps_count} steps, {time} seconds, "
f"actual coverage: {game_result.actual_coverage_percent:.2f}"
)
logging_func = logging.info
if (
FeatureConfig.DUMP_BY_TIMEOUT.enabled
and time > FeatureConfig.DUMP_BY_TIMEOUT.timeout_seconds
):
logging_func = logging.warning
message = "OVERTIME: " + message
save_model(
with_predictor.model(),
to=FeatureConfig.DUMP_BY_TIMEOUT.save_path
/ f"{with_predictor.name()}.pth",
)
logging_func(message)
pbar.update(1)
return list_of_map2result
Loading

0 comments on commit dc549dd

Please sign in to comment.