Skip to content

Commit

Permalink
model registry
Browse files Browse the repository at this point in the history
  • Loading branch information
11carlesnavarro committed Apr 26, 2024
1 parent 6271348 commit af617be
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 39 deletions.
Binary file modified acegen/priors/custom_model.ckpt
Binary file not shown.
28 changes: 16 additions & 12 deletions scripts/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tqdm
import yaml

from acegen.models import adapt_state_dict, models
from acegen.models import adapt_state_dict, models, register_model
from acegen.rl_env import generate_complete_smiles, SMILESEnv
from acegen.scoring_functions import custom_scoring_functions, Task
from acegen.vocabulary import SMILESVocabulary
Expand Down Expand Up @@ -125,17 +125,19 @@ def run_a2c(cfg, task):
)

# Get model and vocabulary checkpoints
if cfg.model in models:
(
create_actor,
create_critic,
create_shared,
voc_path,
ckpt_path,
tokenizer,
) = models[cfg.model]
if cfg.model not in models and cfg.model_factory is not None:
register_model(cfg.model, cfg.model_factory)
else:
raise ValueError(f"Unknown model type: {cfg.model}")
raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.")

(
create_actor,
create_critic,
create_shared,
voc_path,
ckpt_path,
tokenizer
) = models[cfg.model](cfg)

# Create vocabulary
####################################################################################################################
Expand All @@ -158,7 +160,9 @@ def run_a2c(cfg, task):
critic_training, critic_inference = create_critic(len(vocabulary))

# Load pretrained weights
ckpt = torch.load(ckpt_path)
ckpt_path = cfg.get("model_weights", ckpt_path)
ckpt = torch.load(ckpt_path, map_location=device)

actor_inference.load_state_dict(
adapt_state_dict(ckpt, actor_inference.state_dict())
)
Expand Down
11 changes: 7 additions & 4 deletions scripts/ahc/ahc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tqdm
import yaml

from acegen.models import adapt_state_dict, models
from acegen.models import adapt_state_dict, models, register_model
from acegen.rl_env import generate_complete_smiles, SMILESEnv
from acegen.scoring_functions import custom_scoring_functions, Task
from acegen.vocabulary import SMILESVocabulary
Expand Down Expand Up @@ -127,10 +127,12 @@ def run_ahc(cfg, task):
)

# Get model and vocabulary checkpoints
if cfg.model in models:
create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model]
if cfg.model not in models and cfg.model_factory is not None:
register_model(cfg.model, cfg.model_factory)
else:
raise ValueError(f"Unknown model type: {cfg.model}")
raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.")

create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg)

# Create vocabulary
####################################################################################################################
Expand All @@ -140,6 +142,7 @@ def run_ahc(cfg, task):
# Create models
####################################################################################################################

ckpt_path = cfg.get("model_weights", ckpt_path)
ckpt = torch.load(ckpt_path)

actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary))
Expand Down
25 changes: 14 additions & 11 deletions scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import tqdm
import yaml

from acegen.models import adapt_state_dict, models
from acegen.models import adapt_state_dict, models, register_model
from acegen.rl_env import generate_complete_smiles, SMILESEnv
from acegen.scoring_functions import custom_scoring_functions, Task
from acegen.vocabulary import SMILESVocabulary
Expand Down Expand Up @@ -128,17 +128,20 @@ def run_ppo(cfg, task):
)

# Get model and vocabulary checkpoints
if cfg.model in models:
(
create_actor,
create_critic,
create_shared,
voc_path,
ckpt_path,
tokenizer,
) = models[cfg.model](cfg)
if cfg.model not in models and cfg.model_factory is not None:
register_model(cfg.model, cfg.model_factory)
else:
raise ValueError(f"Unknown model type: {cfg.model}")
raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.")

(
create_actor,
create_critic,
create_shared,
voc_path,
ckpt_path,
tokenizer
) = models[cfg.model](cfg)


# Create vocabulary
####################################################################################################################
Expand Down
11 changes: 7 additions & 4 deletions scripts/reinforce/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import tqdm
import yaml

from acegen.models import adapt_state_dict, models
from acegen.models import adapt_state_dict, models, register_model
from acegen.rl_env import generate_complete_smiles, SMILESEnv
from acegen.scoring_functions import custom_scoring_functions, Task
from acegen.vocabulary import SMILESVocabulary
Expand Down Expand Up @@ -127,10 +127,12 @@ def run_reinforce(cfg, task):
)

# Get model and vocabulary checkpoints
if cfg.model in models:
create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model]
if cfg.model not in models and cfg.model_factory is not None:
register_model(cfg.model, cfg.model_factory)
else:
raise ValueError(f"Unknown model type: {cfg.model}")
raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.")

create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg)

# Create vocabulary
####################################################################################################################
Expand All @@ -140,6 +142,7 @@ def run_reinforce(cfg, task):
# Create models
####################################################################################################################

ckpt_path = cfg.get("model_weights", ckpt_path)
ckpt = torch.load(ckpt_path)

actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary))
Expand Down
13 changes: 5 additions & 8 deletions scripts/reinvent/reinvent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import torch
import tqdm
import yaml
from importlib import import_module

from acegen.models import adapt_state_dict, models, register_model
from acegen.rl_env import generate_complete_smiles, SMILESEnv
Expand Down Expand Up @@ -128,13 +127,12 @@ def run_reinvent(cfg, task):
)

# Get model and vocabulary checkpoints
if cfg.model in models:
create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg)
if cfg.model not in models and cfg.model_factory is not None:
register_model(cfg.model, cfg.model_factory)
else:
m, f = cfg.model_factory.rsplit('.', 1)
model_factory = getattr(import_module(m), f)
register_model(cfg.model, model_factory)
create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg)
raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.")

create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg)

# Create vocabulary
####################################################################################################################
Expand All @@ -144,7 +142,6 @@ def run_reinvent(cfg, task):
# Create models
####################################################################################################################

#ckpt = torch.load(ckpt_path)
ckpt_path = cfg.get("model_weights", ckpt_path)
ckpt = torch.load(ckpt_path)

Expand Down

0 comments on commit af617be

Please sign in to comment.