Skip to content

Commit

Permalink
Save models as a state_dict, add model load function
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma committed Jul 20, 2023
1 parent fcd3fa8 commit 43a6c91
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 14 deletions.
15 changes: 12 additions & 3 deletions VSharp.ML.AIAgent/learning/r_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
create_epoch_subdir,
rewrite_best_tables_file,
)
from ml.fileop import save_model
from ml.model_wrappers.nnwrapper import NNWrapper
from selection.classes import AgentResultsOnGameMaps, GameResult, Map2Result
from selection.scorer import straight_scorer
Expand All @@ -37,7 +38,6 @@
get_map_inference_times,
load_times_array,
)
from weights_dump.weights_dump import save_weights

TimeDuration: TypeAlias = float

Expand Down Expand Up @@ -93,7 +93,12 @@ def play_map(
FeatureConfig.DUMP_BY_TIMEOUT.enabled
and end_time - start_time > FeatureConfig.DUMP_BY_TIMEOUT.timeout_seconds
):
save_weights(with_model.weights, to=FeatureConfig.DUMP_BY_TIMEOUT.save_path)
save_model(
GeneralConfig.MODEL_INIT(),
to=FeatureConfig.DUMP_BY_TIMEOUT.save_path
/ f"{sum(with_model.weights)}.pth",
weights=with_model.weights,
)

if actual_coverage != 100 and steps_count != steps:
logging.error(
Expand Down Expand Up @@ -158,7 +163,11 @@ def on_generation(ga_instance):
for weights in get_n_best_weights_in_last_generation(
ga_instance, FeatureConfig.N_BEST_SAVED_EACH_GEN
):
save_weights(weights, to=epoch_subdir)
save_model(
GeneralConfig.MODEL_INIT(),
to=epoch_subdir / f"{sum(weights)}.pth",
weights=weights,
)

ga_pop_inner_hashes = [
tuple(individual).__hash__() for individual in ga_instance.population
Expand Down
21 changes: 21 additions & 0 deletions VSharp.ML.AIAgent/ml/fileop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from pathlib import Path

import pygad
import torch

import ml


def save_model(model: torch.nn.Module, to: Path, weights=None):
if weights is None:
torch.save(model.state_dict(), to)
else:
model.forward(*ml.onnx.onnx_import.create_torch_dummy_input())
state_dict = pygad.torchga.model_weights_as_dict(model, weights)
torch.save(state_dict, to)


def load_model_from_file(model: torch.nn.Module, file: Path):
model.load_state_dict(torch.load(file))
model.eval()
return model
11 changes: 0 additions & 11 deletions VSharp.ML.AIAgent/weights_dump/weights_dump.py

This file was deleted.

0 comments on commit 43a6c91

Please sign in to comment.