Skip to content

Commit

Permalink
Merge pull request #68 from emnigma/rework_learning
Browse files Browse the repository at this point in the history
Rework learning
  • Loading branch information
gsvgit authored Jul 21, 2023
2 parents 6105ca4 + dc549dd commit 6b6fe96
Show file tree
Hide file tree
Showing 31 changed files with 371 additions and 337 deletions.
File renamed without changes.
20 changes: 19 additions & 1 deletion VSharp.ML.AIAgent/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
from .game import GameState
from collections import defaultdict

from common.classes import Agent2Result, AgentResultsOnGameMaps, GameMapsModelResults

from common.game import GameState


def get_states(game_state: GameState) -> set[int]:
    """Collect the ids of all states held by ``game_state``.

    Reads the ``Id`` attribute of every element of ``game_state.States`` and
    gathers them into a set, so repeated ids collapse into a single entry.
    """
    state_ids: set[int] = set()
    for state in game_state.States:
        state_ids.add(state.Id)
    return state_ids


def invert_mapping_mrgm_gmmr(
    model_results_on_map: AgentResultsOnGameMaps,
) -> GameMapsModelResults:
    """Regroup per-agent results so that they are keyed by map instead.

    The input maps each agent to a list of entries exposing ``map`` and
    ``game_result``.  The output maps each map to a list of
    ``Agent2Result(agent, game_result)`` items, appended in the order the
    input is iterated.

    Returns:
        A ``defaultdict(list)`` keyed by map.
    """
    results_by_map: GameMapsModelResults = defaultdict(list)

    for agent, entries in model_results_on_map.items():
        for entry in entries:
            game_map = entry.map
            outcome = entry.game_result
            results_by_map[game_map].append(Agent2Result(agent, outcome))

    return results_by_map
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

from dataclasses_json import config, dataclass_json

from agent.unsafe_json import asdict
from common.classes import Map2Result
from config import FeatureConfig
from connection.game_server_conn.unsafe_json import asdict
from ml.model_wrappers.nnwrapper import NNWrapper, decode, encode
from selection.classes import Map2Result


def custom_encoder_if_disable_message_checks() -> Callable | None:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from contextlib import contextmanager

import websocket

from .requests import aquire_ws, return_ws


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
)


class NAgent:
class WrongAgentStateError(Exception):
class Connector:
class WrongConnectorStateError(Exception):
def __init__(
self, source: str, received: str, expected: str, at_step: int
) -> None:
Expand Down Expand Up @@ -63,7 +63,7 @@ def __init__(

def _raise_if_gameover(self, msg) -> GameOverServerMessage | str:
if self.game_is_over:
raise NAgent.GameOver
raise Connector.GameOver

matching_message_type = ServerMessage.from_json_handle(
msg, expected=ServerMessage
Expand All @@ -75,7 +75,7 @@ def _raise_if_gameover(self, msg) -> GameOverServerMessage | str:
)
self.game_is_over = True
logging.debug(f"--> {matching_message_type}")
raise NAgent.GameOver(
raise Connector.GameOver(
actual_coverage=deser_msg.MessageBody.ActualCoverage,
tests_count=deser_msg.MessageBody.TestsCount,
errors_count=deser_msg.MessageBody.ErrorsCount,
Expand Down Expand Up @@ -114,7 +114,7 @@ def recv_reward_or_throw_gameover(self) -> Reward:
def _process_reward_server_message(self, msg):
match msg.MessageType:
case ServerMessageType.INCORRECT_PREDICTED_STATEID:
raise NAgent.IncorrectSentStateError(
raise Connector.IncorrectSentStateError(
f"Sending state_id={self._sent_state_id} \
at step #{self._current_step} resulted in {msg.MessageType}"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
from enum import Enum
from typing import Optional

from .unsafe_json import obj_from_dict
from dataclasses_json import config, dataclass_json

from common.game import GameMap, GameState, Reward
from dataclasses_json import config, dataclass_json
from config import FeatureConfig

from .unsafe_json import obj_from_dict


class ClientMessageType(str, Enum):
START = "start"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import asdict as dataclasses_asdict, is_dataclass
from dataclasses import asdict as dataclasses_asdict
from dataclasses import is_dataclass
from typing import ClassVar, Protocol


Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion VSharp.ML.AIAgent/epochs_statistics/gen_stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from statistics import mean, median

from epochs_statistics.common import CoverageStats, Interval
from selection.classes import Map2Result
from common.classes import Map2Result


def euc_dist2full_coverage(data: list[float]) -> float:
Expand Down
4 changes: 2 additions & 2 deletions VSharp.ML.AIAgent/epochs_statistics/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@

import pandas as pd

from common.classes import Agent2Result, AgentResultsOnGameMaps
from common.strings import (
AV_COVERAGE_COL_NAME,
COV_DEVIATION_COL_NAME,
EUC_DIST2FULL_COV_COL_NAME,
MEDIAN_COVERAGE_COL_NAME,
)
from common.utils import invert_mapping_mrgm_gmmr
from config import FeatureConfig
from epochs_statistics.common import Interval, Name2ResultViewModel
from epochs_statistics.gen_stats import compute_euc_dist_to_full_coverage
from selection.classes import Agent2Result, AgentResultsOnGameMaps
from selection.utils import invert_mapping_mrgm_gmmr


def get_sample_val(d: dict):
Expand Down
6 changes: 3 additions & 3 deletions VSharp.ML.AIAgent/epochs_statistics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from shutil import rmtree

from common.constants import (
LEADERS_TABLES_LOG_FILE,
TABLES_LOG_FILE,
APP_LOG_FILE,
EPOCH_BEST_DIR,
BASE_REPORT_DIR,
EPOCH_BEST_DIR,
LEADERS_TABLES_LOG_FILE,
TABLES_LOG_FILE,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
init_log_file,
init_tables_file,
)
from selection.crossover_type import CrossoverType
from selection.mutation_type import MutationType
from selection.parent_selection_type import ParentSelectionType
from timer.resources_manager import manage_inference_stats
from learning.selection.crossover_type import CrossoverType
from learning.selection.mutation_type import MutationType
from learning.selection.parent_selection_type import ParentSelectionType
from learning.timer.resources_manager import manage_inference_stats

logging.basicConfig(
level=GeneralConfig.LOGGER_LEVEL,
Expand All @@ -29,7 +29,7 @@

os.environ["NUMEXPR_NUM_THREADS"] = "1"

from .r_learn import fitness_function, on_generation
from .genetic_learning import fitness_function, on_generation


@contextmanager
Expand Down
146 changes: 146 additions & 0 deletions VSharp.ML.AIAgent/learning/genetic_learning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import copy
import json
import random
from collections import defaultdict
from os import getpid

import pygad.torchga

import ml
from common.classes import AgentResultsOnGameMaps
from common.constants import DEVICE
from config import FeatureConfig, GeneralConfig
from connection.broker_conn.classes import Agent2ResultsOnMaps
from connection.broker_conn.requests import recv_game_result_list, send_game_results
from connection.game_server_conn.utils import MapsType
from epochs_statistics.tables import create_pivot_table, table_to_csv, table_to_string
from epochs_statistics.utils import (
append_to_tables_file,
create_epoch_subdir,
rewrite_best_tables_file,
)
from learning.selection.scorer import straight_scorer
from learning.timer.stats import compute_statistics
from learning.timer.utils import (
create_temp_epoch_inference_dir,
dump_and_reset_epoch_times,
load_times_array,
)
from ml.fileop import save_model
from ml.model_wrappers.nnwrapper import NNWrapper

from .play_game import play_game

# Module-level accumulators shared across calls to ``on_generation``.
# ``info_for_tables`` holds, per agent, the results received from the broker;
# entries are overwritten whenever the same agent key is reported again.
info_for_tables: AgentResultsOnGameMaps = defaultdict(list)
# ``leader_table`` receives one entry per ``on_generation`` call: the results
# of the best solution of that generation.
leader_table: AgentResultsOnGameMaps = defaultdict(list)


def get_n_best_weights_in_last_generation(ga_instance, n: int):
    """Return the ``n`` fittest individuals of the last evaluated generation.

    Every member of ``ga_instance.population`` is paired (by position) with
    the corresponding entry of ``ga_instance.last_generation_fitness`` and the
    pairs are ordered by fitness, highest first.  The sort is stable, so
    individuals with equal fitness keep their population order.

    Args:
        ga_instance: object exposing ``population`` and
            ``last_generation_fitness``.
        n: number of individuals to return; must satisfy
            ``0 <= n <= len(population)``.

    Returns:
        A list with the ``n`` best individuals, best first.

    Raises:
        AssertionError: if ``n`` is negative or exceeds the population size.
    """
    population = ga_instance.population
    population_fitnesses = ga_instance.last_generation_fitness
    population_size = len(population)

    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # when Python runs with ``-O``; the exception type and message stay the
    # same as before.
    if n > population_size:
        raise AssertionError(
            f"asked for {n} best when population size is {population_size}"
        )
    # A negative count would otherwise slice from the end and silently return
    # "all but the worst |n|" instead of the n best.
    if n < 0:
        raise AssertionError(f"asked for {n} best, expected a non-negative count")

    ranked = sorted(
        zip(population, population_fitnesses),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return [individual for individual, _ in ranked[:n]]


def on_generation(ga_instance):
    """Record, persist and report the results of a finished generation.

    Presumably registered as the genetic algorithm's per-generation callback
    (TODO confirm against the caller).  In order, it:

    1. receives the serialized game results, decodes them and stores them in
       the module-level ``info_for_tables`` (keyed by agent);
    2. saves the ``FeatureConfig.N_BEST_SAVED_EACH_GEN`` best weight vectors
       of the generation into a fresh epoch sub-directory;
    3. narrows ``info_for_tables`` to the agents whose ``weights_hash``
       belongs to the current population;
    4. adds the best solution of the generation to the module-level
       ``leader_table`` and rewrites the leaders file;
    5. appends the pivot and statistics tables to the tables file, optionally
       dumping the epoch table to CSV;
    6. prints inference time statistics and re-creates the temporary
       inference directory.

    Args:
        ga_instance: object exposing ``population``,
            ``last_generation_fitness``, ``generations_completed`` and
            ``best_solution(pop_fitness=...)``.

    Raises:
        StopIteration: if no received result matches the hash of the best
            solution (propagated from ``next``).
    """
    game_results_raw = json.loads(recv_game_result_list())
    game_results_decoded = [
        Agent2ResultsOnMaps.from_json(item) for item in game_results_raw
    ]

    for full_game_result in game_results_decoded:
        info_for_tables[full_game_result.agent] = full_game_result.results

    generation = ga_instance.generations_completed
    print(f"Generation = {generation};")
    epoch_subdir = create_epoch_subdir(generation)

    for weights in get_n_best_weights_in_last_generation(
        ga_instance, FeatureConfig.N_BEST_SAVED_EACH_GEN
    ):
        save_model(
            GeneralConfig.MODEL_INIT(),
            to=epoch_subdir / f"{sum(weights)}.pth",
            weights=weights,
        )

    # A set gives O(1) membership tests in the filter below; a list would be
    # rescanned for every stored wrapper.
    ga_pop_inner_hashes = {
        hash(tuple(individual)) for individual in ga_instance.population
    }
    info_for_tables_filtered = {
        nnwrapper: res
        for nnwrapper, res in info_for_tables.items()
        if nnwrapper.weights_hash in ga_pop_inner_hashes
    }

    best_solution_hash = hash(
        tuple(
            ga_instance.best_solution(
                pop_fitness=ga_instance.last_generation_fitness
            )[0]
        )
    )
    best_solution_nnwrapper, best_solution_results = next(
        filter(
            lambda item: item[0].weights_hash == best_solution_hash,
            info_for_tables_filtered.items(),
        )
    )

    append_to_tables_file(f"Generations completed: {generation}" + "\n")

    if best_solution_nnwrapper in leader_table:
        # The same wrapper already leads an earlier generation: store the new
        # results under a copy whose ``weights_hash`` is shifted.  The offset
        # starts at 1 because an offset of 0 would leave the copy's hash
        # unchanged.
        # NOTE(review): presumably the wrapper's dict identity depends on
        # ``weights_hash`` - confirm against NNWrapper.
        best_wrapper_copy = copy.copy(best_solution_nnwrapper)
        best_wrapper_copy.weights_hash += random.randint(1, 10**3)
        leader_table[best_wrapper_copy] = best_solution_results
    else:
        leader_table[best_solution_nnwrapper] = best_solution_results

    _, stats, _ = create_pivot_table(leader_table, sort=False)
    rewrite_best_tables_file(table_to_string(stats) + "\n")

    pivot, stats, epoch_table = create_pivot_table(info_for_tables_filtered)
    if FeatureConfig.SAVE_EPOCHS_COVERAGES.enabled:
        path_to_save_to = (
            FeatureConfig.SAVE_EPOCHS_COVERAGES.save_path / f"{generation}.csv"
        )
        table_to_csv(epoch_table, path=path_to_save_to)
    append_to_tables_file(table_to_string(pivot) + "\n")
    append_to_tables_file(table_to_string(stats) + "\n")

    mean, std = compute_statistics(load_times_array())
    print(f"Gen#{generation} inference statistics:")
    print(f"{mean=}ms")
    print(f"{std=}ms")
    create_temp_epoch_inference_dir()


def fitness_function(ga_inst, solution, solution_idx) -> float:
    """Score one candidate weight vector by letting it play the train maps.

    A fresh model is created, run once on a dummy input, loaded with the
    weights from ``solution``, moved to ``DEVICE`` and put into eval mode.
    The wrapped model then plays the ``MapsType.TRAIN`` maps; the per-map
    results are sent out via ``send_game_results``, the collected epoch
    timings are dumped, and the game results are passed to
    ``straight_scorer``.

    Args:
        ga_inst: object exposing ``generations_completed`` (used only to tag
            the timings dump).
        solution: flat weight vector of the candidate.
        solution_idx: index of the candidate; unused here.

    Returns:
        Whatever ``straight_scorer`` computes from the game results.
    """
    net = GeneralConfig.MODEL_INIT()
    net.forward(*ml.onnx.onnx_import.create_torch_dummy_input())

    state_dict = pygad.torchga.model_weights_as_dict(
        model=net, weights_vector=solution
    )
    net.load_state_dict(state_dict)
    net.to(DEVICE)
    net.eval()

    predictor = NNWrapper(net, weights_flat=solution)
    map_results = play_game(
        with_predictor=predictor,
        max_steps=GeneralConfig.MAX_STEPS,
        maps_type=MapsType.TRAIN,
    )
    send_game_results(Agent2ResultsOnMaps(predictor, map_results))

    timings_tag = (
        f"{predictor.name()}_epoch{ga_inst.generations_completed}_pid{getpid()}"
    )
    dump_and_reset_epoch_times(timings_tag)

    # Lazy, single-pass iterable of game results, consumed by the scorer.
    return straight_scorer(entry.game_result for entry in map_results)
Loading

0 comments on commit 6b6fe96

Please sign in to comment.