Commit: add optuna
Anya497 committed Dec 5, 2023
1 parent 408a938 commit cc53d57
Showing 1 changed file with 46 additions and 57 deletions.
103 changes: 46 additions & 57 deletions VSharp.ML.AIAgent/run_common_model_training.py
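For context before the diff: this commit swaps the hand-rolled random search in main() (the deleted while True loop below) for an Optuna study. A minimal sketch of the objective/study pattern, with a toy objective standing in for the repo's train():

import optuna

# Toy stand-in for train(trial, dataset); Optuna calls it once per trial.
def objective(trial: optuna.trial.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-7, 1e-3)               # sampled per trial
    batch_size = trial.suggest_int("batch_size", 32, 1024)
    # Return the score the study should maximize.
    return -abs(lr - 1e-4) - abs(batch_size - 256) * 1e-6

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params, study.best_value)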
@@ -31,6 +31,9 @@
 from ml.common_model.utils import csv2best_models, get_model
 from ml.common_model.wrapper import BestModelsWrapper, CommonModelWrapper
 from ml.models.TAGSageSimple.model_modified import StateModelEncoderLastLayer
+import optuna
+from functools import partial
+import joblib

 LOG_PATH = Path("./ml_app.log")
 TABLES_PATH = Path("./ml_tables.log")
@@ -82,19 +85,45 @@ class TrainConfig:
     lr: float
     epochs: int
     batch_size: int
-
-
-def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
+    optimizer: torch.optim.Optimizer
+    loss: any
+    random_seed: int
+
+
+def train(trial: optuna.trial.Trial, dataset: FullDataset):
+    config = TrainConfig(
+        lr=trial.suggest_float("lr", 1e-7, 1e-3),
+        batch_size=trial.suggest_int("batch_size", 32, 1024),
+        epochs=10,
+        optimizer=trial.suggest_categorical("optimizer", [torch.optim.Adam]),
+        loss=trial.suggest_categorical("loss", [nn.KLDivLoss]),
+        random_seed=937,
+    )
     # for name, param in model.named_parameters():
     #     if "lin_last" not in name:
     #         param.requires_grad = False

+    path_to_weights = os.path.join(
+        PRETRAINED_MODEL_PATH,
+        "TAGSageSimple",
+        "32ch",
+        "20e",
+        "GNN_state_pred_het_dict",
+    )
+    model = get_model(
+        Path(path_to_weights),
+        StateModelEncoderLastLayer(hidden_channels=32, out_channels=8),
+        random_seed=config.random_seed,
+    )
+
     model.to(GeneralConfig.DEVICE)
-    optimizer = torch.optim.Adam(model.parameters(), lr=train_config.lr)
-    criterion = nn.KLDivLoss()
+    optimizer = config.optimizer(model.parameters(), lr=config.lr)
+    criterion = config.loss()

     timestamp = datetime.now().timestamp()
-    run_name = f"{datetime.fromtimestamp(timestamp)}_{train_config.batch_size}_Adam_{train_config.lr}_KLDL"
+    run_name = (
+        f"{datetime.fromtimestamp(timestamp)}_{config.batch_size}_Adam_{config.lr}_KLDL"
+    )

     print(run_name)
     path_to_saved_models = os.path.join(COMMON_MODELS_PATH, run_name)
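A side note on the hunk above (an observation, not part of the commit): suggest_categorical persists its choices in the study's storage, and Optuna warns when the choices are not None/bool/int/float/str, as with the class objects torch.optim.Adam and nn.KLDivLoss here. A common workaround is to suggest string keys and map them to classes; a sketch with hypothetical names:

import torch
import torch.nn as nn

OPTIMIZERS = {"Adam": torch.optim.Adam, "SGD": torch.optim.SGD}
LOSSES = {"KLDiv": nn.KLDivLoss}

def pick_components(trial):
    # Suggest serializable string keys instead of class objects.
    opt_cls = OPTIMIZERS[trial.suggest_categorical("optimizer", list(OPTIMIZERS))]
    loss_cls = LOSSES[trial.suggest_categorical("loss", list(LOSSES))]
    return opt_cls, loss_cls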
@@ -118,11 +147,9 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
     # p = Pool(GeneralConfig.SERVER_COUNT)

     all_average_results = []
-    for epoch in range(train_config.epochs):
+    for epoch in range(config.epochs):
         data_list = dataset.get_plain_data()
-        data_loader = DataLoader(
-            data_list, batch_size=train_config.batch_size, shuffle=True
-        )
+        data_loader = DataLoader(data_list, batch_size=config.batch_size, shuffle=True)
         print("DataLoader size", len(data_loader))

         model.train()
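Worth noting for the criterion chosen above: PyTorch's nn.KLDivLoss expects log-probabilities as input and probabilities as target, so model outputs fed to it must already be log-softmaxed. A self-contained example of one loss computation, independent of the repo's models:

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.KLDivLoss(reduction="batchmean")      # input: log-probs, target: probs
logits = torch.randn(4, 8, requires_grad=True)       # fake batch of model outputs
target = torch.softmax(torch.randn(4, 8), dim=1)     # fake target distribution
loss = criterion(F.log_softmax(logits, dim=1), target)
loss.backward()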
@@ -178,8 +205,9 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
             list(map(lambda x: x.game_result.actual_coverage_percent, all_results))
         )
         all_average_results.append(average_result)
-        all_results = sorted(all_results, key=lambda x: x.map.MapName)
-        table, _, _ = create_pivot_table({cmwrapper: all_results})
+        table, _, _ = create_pivot_table(
+            {cmwrapper: sorted(all_results, key=lambda x: x.map.MapName)}
+        )
         table = table_to_string(table)
         append_to_file(
             TABLES_PATH,
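A Python detail relevant to the create_pivot_table argument above: list.sort() sorts in place and returns None, while sorted() returns a new list, so only sorted() can be used inline inside an expression such as a dict literal:

results = [3, 1, 2]
assert sorted(results) == [1, 2, 3]  # returns a new sorted list
assert results.sort() is None        # sorts in place, returns None
assert results == [1, 2, 3]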
@@ -193,7 +221,7 @@ def train(train_config: TrainConfig, model: torch.nn.Module, dataset: FullDataset):
         del data_loader
     # p.close()

-    return all_average_results
+    return max(all_average_results)


 def get_dataset(
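The objective now reports only the best per-epoch average coverage via return max(all_average_results). If per-epoch values should steer the search, Optuna can also prune unpromising trials mid-training with trial.report and trial.should_prune; a sketch of that pattern (not used in this commit), with a stand-in for the epoch metric:

import optuna

def objective(trial: optuna.trial.Trial) -> float:
    best = 0.0
    for epoch in range(10):
        value = 1.0 - 1.0 / (epoch + 1)  # stand-in for the epoch's average coverage
        best = max(best, value)
        trial.report(value, step=epoch)  # expose intermediate value to the pruner
        if trial.should_prune():
            raise optuna.TrialPruned()
    return best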
@@ -222,57 +250,18 @@ def get_dataset(

 def main():
     print(GeneralConfig.DEVICE)
-    path_to_weights = os.path.join(
-        PRETRAINED_MODEL_PATH,
-        "TAGSageSimple",
-        "32ch",
-        "20e",
-        "GNN_state_pred_het_dict",
-    )
     model_initializer = lambda: StateModelEncoderLastLayer(
         hidden_channels=32, out_channels=8
     )

-    best_result = {"average_coverage": 0, "config": dict(), "epoch": 0}
     generate_dataset = False
     dataset = get_dataset(generate_dataset, ref_model_init=model_initializer)

-    while True:
-        config = TrainConfig(
-            lr=random.choice([10 ** (-i) for i in range(3, 8)]),
-            batch_size=random.choice([2**i for i in range(5, 10)]),
-            epochs=20,
-        )
-        print("Current hyperparameters")
-        data_frame = pd.DataFrame(
-            data=[asdict(config).values()],
-            columns=asdict(config).keys(),
-            index=["value"],
-        )
-        print(data_frame)
-
-        model = get_model(
-            Path(path_to_weights),
-            model_initializer,
-            random_seed=937,
-        )
-
-        results = train(train_config=config, model=model, dataset=dataset)
-        max_value = max(results)
-        max_ind = results.index(max_value)
-        if best_result["average_coverage"] < max_value:
-            best_result["average_coverage"] = max_value
-            best_result["config"] = asdict(config)
-            best_result["epoch"] = max_ind + 1
-        print(
-            f"The best result for now:\nAverage coverage: {best_result['average_coverage']}"
-        )
-        data_frame = pd.DataFrame(
-            data=[best_result["config"].values()],
-            columns=best_result["config"].keys(),
-            index=["value"],
-        )
-        print(data_frame)
+    sampler = optuna.samplers.TPESampler(n_startup_trials=10)
+    study = optuna.create_study(sampler=sampler, direction="maximize")
+    objective = partial(train, dataset=dataset)
+    study.optimize(objective, n_trials=100)
+    joblib.dump(study, f"{datetime.fromtimestamp(datetime.now().timestamp())}.pkl")


 if __name__ == "__main__":
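Because main() ends with joblib.dump(study, ...), a later session can reload the study to inspect or extend the search. A small sketch, with a hypothetical filename standing in for the timestamped .pkl produced by the run:

import joblib

study = joblib.load("study.pkl")           # hypothetical name of the dumped file
print("best value:", study.best_value)     # highest average coverage seen
print("best params:", study.best_params)   # lr, batch_size, optimizer, loss
print(study.trials_dataframe().head())     # per-trial history as a DataFrame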
