diff --git a/acegen/priors/custom_model.ckpt b/acegen/priors/custom_model.ckpt index 5d60c9fd..35a00b00 100644 Binary files a/acegen/priors/custom_model.ckpt and b/acegen/priors/custom_model.ckpt differ diff --git a/scripts/a2c/a2c.py b/scripts/a2c/a2c.py index 1223d465..024186d4 100644 --- a/scripts/a2c/a2c.py +++ b/scripts/a2c/a2c.py @@ -13,7 +13,7 @@ import tqdm import yaml -from acegen.models import adapt_state_dict, models +from acegen.models import adapt_state_dict, models, register_model from acegen.rl_env import generate_complete_smiles, SMILESEnv from acegen.scoring_functions import custom_scoring_functions, Task from acegen.vocabulary import SMILESVocabulary @@ -125,17 +125,19 @@ def run_a2c(cfg, task): ) # Get model and vocabulary checkpoints - if cfg.model in models: - ( - create_actor, - create_critic, - create_shared, - voc_path, - ckpt_path, - tokenizer, - ) = models[cfg.model] + if cfg.model not in models and cfg.model_factory is not None: + register_model(cfg.model, cfg.model_factory) else: - raise ValueError(f"Unknown model type: {cfg.model}") + raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.") + + ( + create_actor, + create_critic, + create_shared, + voc_path, + ckpt_path, + tokenizer + ) = models[cfg.model](cfg) # Create vocabulary #################################################################################################################### @@ -158,7 +160,9 @@ def run_a2c(cfg, task): critic_training, critic_inference = create_critic(len(vocabulary)) # Load pretrained weights - ckpt = torch.load(ckpt_path) + ckpt_path = cfg.get("model_weights", ckpt_path) + ckpt = torch.load(ckpt_path, map_location=device) + actor_inference.load_state_dict( adapt_state_dict(ckpt, actor_inference.state_dict()) ) diff --git a/scripts/ahc/ahc.py b/scripts/ahc/ahc.py index d2b3b0b3..ee15f378 100644 --- a/scripts/ahc/ahc.py +++ b/scripts/ahc/ahc.py @@ -14,7 +14,7 @@ import tqdm import yaml -from acegen.models import adapt_state_dict, models +from acegen.models import adapt_state_dict, models, register_model from acegen.rl_env import generate_complete_smiles, SMILESEnv from acegen.scoring_functions import custom_scoring_functions, Task from acegen.vocabulary import SMILESVocabulary @@ -127,10 +127,12 @@ def run_ahc(cfg, task): ) # Get model and vocabulary checkpoints - if cfg.model in models: - create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model] + if cfg.model not in models and cfg.model_factory is not None: + register_model(cfg.model, cfg.model_factory) else: - raise ValueError(f"Unknown model type: {cfg.model}") + raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.") + + create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg) # Create vocabulary #################################################################################################################### @@ -140,6 +142,7 @@ def run_ahc(cfg, task): # Create models #################################################################################################################### + ckpt_path = cfg.get("model_weights", ckpt_path) ckpt = torch.load(ckpt_path) actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary)) diff --git a/scripts/ppo/ppo.py b/scripts/ppo/ppo.py index 18c67ec8..9a96dd7c 100644 --- a/scripts/ppo/ppo.py +++ b/scripts/ppo/ppo.py @@ -13,7 +13,7 @@ import tqdm import yaml -from acegen.models import adapt_state_dict, models +from acegen.models import adapt_state_dict, models, register_model from acegen.rl_env import generate_complete_smiles, SMILESEnv from acegen.scoring_functions import custom_scoring_functions, Task from acegen.vocabulary import SMILESVocabulary @@ -128,17 +128,20 @@ def run_ppo(cfg, task): ) # Get model and vocabulary checkpoints - if cfg.model in models: - ( - create_actor, - create_critic, - create_shared, - voc_path, - ckpt_path, - tokenizer, - ) = models[cfg.model](cfg) + if cfg.model not in models and cfg.model_factory is not None: + register_model(cfg.model, cfg.model_factory) else: - raise ValueError(f"Unknown model type: {cfg.model}") + raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.") + + ( + create_actor, + create_critic, + create_shared, + voc_path, + ckpt_path, + tokenizer + ) = models[cfg.model](cfg) + # Create vocabulary #################################################################################################################### diff --git a/scripts/reinforce/reinforce.py b/scripts/reinforce/reinforce.py index 8ad5b8ee..cee13f26 100644 --- a/scripts/reinforce/reinforce.py +++ b/scripts/reinforce/reinforce.py @@ -14,7 +14,7 @@ import tqdm import yaml -from acegen.models import adapt_state_dict, models +from acegen.models import adapt_state_dict, models, register_model from acegen.rl_env import generate_complete_smiles, SMILESEnv from acegen.scoring_functions import custom_scoring_functions, Task from acegen.vocabulary import SMILESVocabulary @@ -127,10 +127,12 @@ def run_reinforce(cfg, task): ) # Get model and vocabulary checkpoints - if cfg.model in models: - create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model] + if cfg.model not in models and cfg.model_factory is not None: + register_model(cfg.model, cfg.model_factory) else: - raise ValueError(f"Unknown model type: {cfg.model}") + raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.") + + create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg) # Create vocabulary #################################################################################################################### @@ -140,6 +142,7 @@ def run_reinforce(cfg, task): # Create models #################################################################################################################### + ckpt_path = cfg.get("model_weights", ckpt_path) ckpt = torch.load(ckpt_path) actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary)) diff --git a/scripts/reinvent/reinvent.py b/scripts/reinvent/reinvent.py index f477a7d6..f189001d 100644 --- a/scripts/reinvent/reinvent.py +++ b/scripts/reinvent/reinvent.py @@ -13,7 +13,6 @@ import torch import tqdm import yaml -from importlib import import_module from acegen.models import adapt_state_dict, models, register_model from acegen.rl_env import generate_complete_smiles, SMILESEnv @@ -128,13 +127,12 @@ def run_reinvent(cfg, task): ) # Get model and vocabulary checkpoints - if cfg.model in models: - create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg) + if cfg.model not in models and cfg.model_factory is not None: + register_model(cfg.model, cfg.model_factory) else: - m, f = cfg.model_factory.rsplit('.', 1) - model_factory = getattr(import_module(m), f) - register_model(cfg.model, model_factory) - create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg) + raise ValueError(f"Model {cfg.model} not found. For custom models, create and register a model factory.") + + create_actor, _, _, voc_path, ckpt_path, tokenizer = models[cfg.model](cfg) # Create vocabulary #################################################################################################################### @@ -144,7 +142,6 @@ def run_reinvent(cfg, task): # Create models #################################################################################################################### - #ckpt = torch.load(ckpt_path) ckpt_path = cfg.get("model_weights", ckpt_path) ckpt = torch.load(ckpt_path)