Skip to content

Commit

Permalink
Added model loading in train_cli.py
Browse files Browse the repository at this point in the history
Fixed model export
Set kernels for alpha-vile-large
  • Loading branch information
QueensGambit committed Aug 5, 2024
1 parent 7cbd4f8 commit aafca5c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,29 @@ def get_alpha_vile_model(args, model_size='normal'):

kernels = [3] * depth
end_idx = int(len(kernels) * kernel_5_ratio + 0.5)
for idx in range(end_idx):
kernels[idx] = 5
random.shuffle(kernels)

if model_size == 'large':
kernels[1] = 5
kernels[6] = 5
kernels[7] = 5
kernels[9] = 5
kernels[10] = 5
kernels[14] = 5
kernels[18] = 5
kernels[19] = 5
kernels[23] = 5
kernels[25] = 5
kernels[26] = 5
kernels[27] = 5
kernels[28] = 5
kernels[29] = 5
kernels[33] = 5
kernels[34] = 5
kernels[35] = 5
else:
for idx in range(end_idx):
kernels[idx] = 5
random.shuffle(kernels)

use_transformers = [False] * len(kernels)
if nb_transformers > 0:
Expand Down
7 changes: 6 additions & 1 deletion DeepCrazyhouse/src/training/train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
import sys
import torch
import logging
from pathlib import Path

sys.path.insert(0, '../../../')

from DeepCrazyhouse.src.runtime.color_logger import enable_color_logging
from DeepCrazyhouse.configs.train_config import TrainConfig, TrainObjects
from DeepCrazyhouse.src.training.train_cli_util import create_pytorch_model, get_validation_data, fill_train_objects,\
print_model_summary, export_best_model_state, fill_train_config, export_configs, create_export_dirs, export_cmd_args
from DeepCrazyhouse.src.training.trainer_agent_pytorch import TrainerAgentPytorch
from DeepCrazyhouse.src.training.trainer_agent_pytorch import TrainerAgentPytorch, load_torch_state


def parse_args(train_config: TrainConfig):
Expand Down Expand Up @@ -91,6 +92,10 @@ def main():

train_objects = TrainObjects()
fill_train_objects(train_config, train_objects)
if train_config.tar_file != "":
print("load model weights")
load_torch_state(model, torch.optim.SGD(model.parameters(), lr=train_config.max_lr), Path(train_config.tar_file),
train_config.device_id)

create_export_dirs(train_config)
export_configs(args, train_config)
Expand Down
8 changes: 7 additions & 1 deletion DeepCrazyhouse/src/training/train_cli_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
LinearWarmUp, MomentumSchedule
from DeepCrazyhouse.src.training.train_util import get_metrics
from DeepCrazyhouse.src.training.trainer_agent_pytorch import save_torch_state, export_to_onnx, get_context,\
get_data_loader
get_data_loader, load_torch_state


class Args:
Expand Down Expand Up @@ -232,6 +232,12 @@ def export_best_model_state(k_steps_best: int, k_steps_final: int, model, policy
shutil.copy(model_tar_path, best_model_tar_path)

# ## Convert to onnx
print("load current best model")
load_torch_state(model, torch.optim.SGD(model.parameters(), lr=train_config.max_lr), Path(model_tar_path),
train_config.device_id)

if hasattr(model, "merge_bn"):
model.merge_bn()
convert_model_to_onnx(input_shape, k_steps_best, model, model_name, train_config)

print("Saved weight & onnx files of the best model to %s" % (train_config.export_dir + "best-model"))
Expand Down

0 comments on commit aafca5c

Please sign in to comment.