Skip to content

Commit

Permalink
Save all NNs now, fix onnx code
Browse files Browse the repository at this point in the history
  • Loading branch information
emnigma committed Aug 1, 2023
1 parent 1344a44 commit eb0ff5e
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 18 deletions.
1 change: 0 additions & 1 deletion VSharp.ML.AIAgent/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class FeatureConfig:
VERBOSE_TABLES = True
SHOW_SUCCESSORS = True
NAME_LEN = 7
N_BEST_SAVED_EACH_GEN = 2
DISABLE_MESSAGE_CHECKS = True
DUMP_BY_TIMEOUT = DumpByTimeoutFeature(
enabled=True, timeout_sec=600, save_path=Path("./report/timeouted_agents/")
Expand Down
4 changes: 1 addition & 3 deletions VSharp.ML.AIAgent/learning/genetic_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def on_generation(ga_instance):
print(f"Generation = {ga_instance.generations_completed};")
epoch_subdir = create_epoch_subdir(ga_instance.generations_completed)

for weights in get_n_best_weights_in_last_generation(
ga_instance, FeatureConfig.N_BEST_SAVED_EACH_GEN
):
for weights in ga_instance.population:
save_model(
GeneralConfig.MODEL_INIT(),
to=epoch_subdir / f"{sum(weights)}.pth",
Expand Down
4 changes: 2 additions & 2 deletions VSharp.ML.AIAgent/ml/fileop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
import pygad
import torch

import ml
from ml.onnx.onnx_import import create_torch_dummy_input
from common.constants import DEVICE


def save_model(model: torch.nn.Module, to: Path, weights=None):
if weights is None:
torch.save(model.state_dict(), to)
else:
model.forward(*ml.onnx.onnx_import.create_torch_dummy_input())
model.forward(*create_torch_dummy_input())
state_dict = pygad.torchga.model_weights_as_dict(model, weights)
torch.save(state_dict, to)

Expand Down
15 changes: 3 additions & 12 deletions VSharp.ML.AIAgent/ml/onnx/onnx_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,12 @@ def create_torch_dummy_input():
def export_onnx_model(model: torch.nn.Module, save_path: str):
torch.onnx.export(
model=model,
args=create_onnx_dummy_input(),
args=(*create_torch_dummy_input(), {}),
f=save_path,
verbose=False,
export_params=True,
input_names=["x_dict", "edge_index_dict"],
# input_names=[
# "game_vertex",
# "state_vertex",
# "gv2gv",
# "sv_in_gv",
# "gv_in_sv",
# "sv_his_gv",
# "gv_his_sv",
# "sv_parentof_sv",
# ],
input_names=["x_dict", "edge_index_dict", "edge_attr_dict"],
opset_version=16,
)

torch_model_out = model(*create_torch_dummy_input())
Expand Down

0 comments on commit eb0ff5e

Please sign in to comment.