From 61f5817f497c24746b685d20f822f76de34ea561 Mon Sep 17 00:00:00 2001 From: Anya497 Date: Tue, 5 Dec 2023 14:33:40 +0300 Subject: [PATCH] change model in best_models dict creator, add files with last layers to models directories, move random seed setting to train function, replace update condition with '<=' --- VSharp.ML.AIAgent/ml/common_model/dataset.py | 2 +- VSharp.ML.AIAgent/ml/common_model/utils.py | 5 +-- .../model_modified.py | 36 +++++++++++++++++++ .../model_modified.py | 36 +++++++++++++++++++ .../run_common_model_training.py | 23 +++++++----- 5 files changed, 88 insertions(+), 14 deletions(-) create mode 100644 VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG2VerticesDouble/model_modified.py create mode 100644 VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG3SageDouble/model_modified.py diff --git a/VSharp.ML.AIAgent/ml/common_model/dataset.py b/VSharp.ML.AIAgent/ml/common_model/dataset.py index 51b57fd0d..6ef9e736c 100644 --- a/VSharp.ML.AIAgent/ml/common_model/dataset.py +++ b/VSharp.ML.AIAgent/ml/common_model/dataset.py @@ -124,7 +124,7 @@ def update( x.to("cpu") filtered_map_steps = self.filter_map_steps(map_steps) if map_name in self.maps_data.keys(): - if self.maps_data[map_name][0] < map_result: + if self.maps_data[map_name][0] <= map_result: logging.info( f"The model with result = {self.maps_data[map_name][0]} was replaced with the model with " f"result = {map_result} on the map {map_name}" diff --git a/VSharp.ML.AIAgent/ml/common_model/utils.py b/VSharp.ML.AIAgent/ml/common_model/utils.py index 3fe27a59d..4afe5f281 100644 --- a/VSharp.ML.AIAgent/ml/common_model/utils.py +++ b/VSharp.ML.AIAgent/ml/common_model/utils.py @@ -130,10 +130,7 @@ def load_dataset_state_dict(path): return dataset_state_dict -def get_model( - path_to_weights: Path, model_init: t.Callable[[], torch.nn.Module], random_seed: int -): - np.random.seed(random_seed) +def get_model(path_to_weights: Path, model_init: t.Callable[[], torch.nn.Module]): model = model_init() weights = torch.load(path_to_weights) weights["lin_last.weight"] = torch.tensor(np.random.random([1, 8])) diff --git a/VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG2VerticesDouble/model_modified.py b/VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG2VerticesDouble/model_modified.py new file mode 100644 index 000000000..92f7aea95 --- /dev/null +++ b/VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG2VerticesDouble/model_modified.py @@ -0,0 +1,36 @@ +from torch_geometric.nn import Linear +from torch.nn.functional import softmax +from .model import StateModelEncoder + + +class StateModelEncoderLastLayer(StateModelEncoder): + def __init__(self, hidden_channels, out_channels): + super().__init__(hidden_channels, out_channels) + self.lin_last = Linear(out_channels, 1) + + def forward( + self, + game_x, + state_x, + edge_index_v_v, + edge_type_v_v, + edge_index_history_v_s, + edge_attr_history_v_s, + edge_index_in_v_s, + edge_index_s_s, + ): + return softmax( + self.lin_last( + super().forward( + game_x=game_x, + state_x=state_x, + edge_index_v_v=edge_index_v_v, + edge_type_v_v=edge_type_v_v, + edge_index_history_v_s=edge_index_history_v_s, + edge_attr_history_v_s=edge_attr_history_v_s, + edge_index_in_v_s=edge_index_in_v_s, + edge_index_s_s=edge_index_s_s, + ) + ), + dim=0, + ) diff --git a/VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG3SageDouble/model_modified.py b/VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG3SageDouble/model_modified.py new file mode 100644 index 000000000..92f7aea95 --- /dev/null +++ b/VSharp.ML.AIAgent/ml/models/RGCNEdgeTypeTAG3SageDouble/model_modified.py @@ -0,0 +1,36 @@ +from torch_geometric.nn import Linear +from torch.nn.functional import softmax +from .model import StateModelEncoder + + +class StateModelEncoderLastLayer(StateModelEncoder): + def __init__(self, hidden_channels, out_channels): + super().__init__(hidden_channels, out_channels) + self.lin_last = Linear(out_channels, 1) + + def forward( + self, + game_x, + state_x, + edge_index_v_v, + edge_type_v_v, + edge_index_history_v_s, + edge_attr_history_v_s, + edge_index_in_v_s, + edge_index_s_s, + ): + return softmax( + self.lin_last( + super().forward( + game_x=game_x, + state_x=state_x, + edge_index_v_v=edge_index_v_v, + edge_type_v_v=edge_type_v_v, + edge_index_history_v_s=edge_index_history_v_s, + edge_attr_history_v_s=edge_attr_history_v_s, + edge_index_in_v_s=edge_index_in_v_s, + edge_index_s_s=edge_index_s_s, + ) + ), + dim=0, + ) diff --git a/VSharp.ML.AIAgent/run_common_model_training.py b/VSharp.ML.AIAgent/run_common_model_training.py index d493992b2..8545807b7 100644 --- a/VSharp.ML.AIAgent/run_common_model_training.py +++ b/VSharp.ML.AIAgent/run_common_model_training.py @@ -30,7 +30,12 @@ ) from ml.common_model.utils import csv2best_models, get_model from ml.common_model.wrapper import BestModelsWrapper, CommonModelWrapper -from ml.models.TAGSageSimple.model_modified import StateModelEncoderLastLayer +from ml.models.RGCNEdgeTypeTAG2VerticesDouble.model_modified import ( + StateModelEncoderLastLayer, +) +from ml.models.StateGNNEncoderConvEdgeAttr.model_modified import ( + StateModelEncoderLastLayer as RefStateModelEncoderLastLayer, +) import optuna from functools import partial import joblib @@ -99,21 +104,21 @@ def train(trial: optuna.trial.Trial, dataset: FullDataset): loss=trial.suggest_categorical("loss", [nn.KLDivLoss]), random_seed=937, ) + np.random.seed(config.random_seed) # for name, param in model.named_parameters(): # if "lin_last" not in name: # param.requires_grad = False path_to_weights = os.path.join( PRETRAINED_MODEL_PATH, - "TAGSageSimple", - "32ch", - "20e", + "RGCNEdgeTypeTAG2VerticesDouble", + "64ch", + "100e", "GNN_state_pred_het_dict", ) model = get_model( Path(path_to_weights), - StateModelEncoderLastLayer(hidden_channels=32, out_channels=8), - random_seed=config.random_seed, + lambda: StateModelEncoderLastLayer(hidden_channels=64, out_channels=8), ) model.to(GeneralConfig.DEVICE) @@ -206,7 +211,7 @@ def train(trial: optuna.trial.Trial, dataset: FullDataset): ) all_average_results.append(average_result) table, _, _ = create_pivot_table( - {cmwrapper: all_results.sort(key=lambda x: x.map.MapName)} + {cmwrapper: sorted(all_results, key=lambda x: x.map.MapName)} ) table = table_to_string(table) append_to_file( @@ -250,12 +255,12 @@ def get_dataset( def main(): print(GeneralConfig.DEVICE) - model_initializer = lambda: StateModelEncoderLastLayer( + ref_model_initializer = lambda: RefStateModelEncoderLastLayer( hidden_channels=32, out_channels=8 ) generate_dataset = False - dataset = get_dataset(generate_dataset, ref_model_init=model_initializer) + dataset = get_dataset(generate_dataset, ref_model_init=ref_model_initializer) sampler = optuna.samplers.TPESampler(n_startup_trials=10) study = optuna.create_study(sampler=sampler, direction="maximize")