diff --git a/tutorials/adding_custom_model.md b/tutorials/adding_custom_model.md index dedfbf27..8a9857e1 100644 --- a/tutorials/adding_custom_model.md +++ b/tutorials/adding_custom_model.md @@ -90,7 +90,7 @@ TensorDict( ## Creating a custom model -Now we will define a custom model using the transformers library from HuggingFace. We will use the GPT-2 model as an example. +Now we will define a toy custom model to show how to integrate any model in Acegen. The model will be a `torch.nn.Module`, and will provide different outputs depending on the phase (training or inference), defined by its `train_mode` attribute. @@ -105,15 +105,23 @@ return the prediction for all the tokens. ```python import torch from torch import nn -from transformers import GPT2Config, GPT2Model -class GPT2(nn.Module): +class CustomFeatureExtractor(nn.Module): + def __init__(self, config): + super().__init__() + self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd) + def forward(self, batch, mask=None): + out = self.tok_embeddings(batch) + return out + +class AceGenCustomModel(nn.Module): def __init__(self, config=None): - super(GPT2, self).__init__() - self.feature_extractor = GPT2Model(config) if config is not None else None + super(AceGenCustomModel, self).__init__() + self.config = config + self.feature_extractor = CustomFeatureExtractor(config) self._train_mode = False - + @property def train_mode(self): return self._train_mode @@ -121,23 +129,25 @@ class GPT2(nn.Module): def set_train_mode(self, train_mode: bool = True): if train_mode is self._train_mode: return self - out = GPT2() + out = AceGenCustomModel(self.config) out.feature_extractor = self.feature_extractor out._train_mode = train_mode return out - + def forward(self, sequence, sequence_mask): - out = self.feature_extractor( - input_ids=sequence, - attention_mask=sequence_mask.long(), - ).last_hidden_state - - if self.train_mode is False: # If inference, return only last token set to True by the sequence_mask + batch = sequence, + mask = sequence_mask.long()) + + if torch.isnan(out).any(): + raise ValueError("NaN detected in model output.") + + if self.train_mode is False: obs_length = sequence_mask.sum(-1) out = out[torch.arange(len(out)), obs_length.to(torch.int64) - 1] - + return out + ``` Now, we will wrap the model in a `tensordict.nn.TensordictModule` to make it Tensordict-compatible. @@ -150,54 +160,69 @@ from tensordict.nn import TensorDictModule, TensorDictSequential from torchrl.envs import ExplorationType from torchrl.modules import ProbabilisticActor +def load_model_config(config, vocabulary_size): -def create_gpt2_actor( - vocabulary_size: int, -): - # Define transformer - config = GPT2Config() # Original GPT2 configuration, can be customized - config.vocab_size = vocabulary_size - lm = GPT2(config) + model_config = ModelConfig() + model_config.model_name = config.model + model_config.vocab_size = vocabulary_size + model_config.n_embd = config.n_embd + + return model_config, vocabulary_size - # Wrap the transformer in a TensorDictModule to make TensorDict compatible +def transform_config(func): + def wrapper(config, vocabulary_size, *args, **kwargs): + transformed_config, vocabulary_size = load_model_config(config, vocabulary_size) + return func(transformed_config, vocabulary_size, *args[1:], **kwargs) + return wrapper + +@transform_config +def create_custom_actor(config, vocabulary_size, return_log_prob=True): + """ Create a CLM Model actor for language modeling. """ + + # Define the transformer + lm = AceGenCustomModel(config) + + # Wrap it in a TensorDictModule to make it TensorDictCompatible lm_training = TensorDictModule( lm.set_train_mode(True), in_keys=["sequence", "sequence_mask"], - out_keys=["features"], + out_keys = ["features"] ) + lm_inference = TensorDictModule( lm, in_keys=["sequence", "sequence_mask"], out_keys=["features"], ) - # Define final layer and also make + # Define head layer and make it a TensorDictModule lm_head = TensorDictModule( nn.Linear(config.n_embd, vocabulary_size, bias=False), - in_keys=["features"], - out_keys=["logits"], + in_keys = ["features"], + out_keys = ["logits"] ) - + # Concatenate lm and head, similar to torch.nn.Sequential policy_training = TensorDictSequential(lm_training, lm_head) policy_inference = TensorDictSequential(lm_inference, lm_head) - - # To make the actor probabilistic, wrap the policy in a ProbabilisticActor - # This module will take care of sampling and computing log probabilities + + # To make actor probabilistic, wrap the policy in a ProbabilisticActor + # This will take care of sampling and computing log probabilities probabilistic_policy_training = ProbabilisticActor( - module=policy_training, - in_keys=["logits"], - out_keys=["action"], - distribution_class=torch.distributions.Categorical, - return_log_prob=return_log_prob, - default_interaction_type=ExplorationType.RANDOM, + module = policy_training, + in_keys = ["logits"], + out_keys = ["action"], + distribution_class = torch.distributions.Categorical, + return_log_prob = return_log_prob, + default_interaction_type = ExplorationType.RANDOM, ) + probabilistic_policy_inference = ProbabilisticActor( module=policy_inference, in_keys=["logits"], out_keys=["action"], distribution_class=torch.distributions.Categorical, - return_log_prob=True, + return_log_prob=return_log_prob, default_interaction_type=ExplorationType.RANDOM, ) return probabilistic_policy_training, probabilistic_policy_inference @@ -205,10 +230,10 @@ def create_gpt2_actor( ## How to make the custom model available in the training scripts -Available models to the training scripts are defined in `/acegen/__init__.py` as a mapping to tuples with the following format: +Available models to the training scripts are defined in `/acegen/__init__.py` as a mapping to factory classes with the following format: - model_mapping = { - "example_model": ( + def default_example_model_factory(*args, **kwargs): + return ( create_actor_method: Callable # A method to create the actor model create_critic_method: Callable # A method to create the critic model (Optional) create_actor_critic_method: Callable # A method to create the actor-critic model (Optional) @@ -216,22 +241,31 @@ Available models to the training scripts are defined in `/acegen/__init__.py` as weights_file_path: Path # The path to the weights file (Optional) tokenizer: Tokenizer # The tokenizer to use for the model (Optional) ) - } - -New models can be added by creating a new tuple and appending it to the model_mapping dictionary. Then the model can be -selected in any configuration file by setting the `model` parameter to the name of the model. In the case of our example, -adding the models would look like this: - - model_mapping = { - "gpt2": ( - create_gpt2_actor, - None, - None, - Path(__file__).resolve().parent.parent.parent / "priors" / "enamine_real_vocabulary.txt", - Path(__file__).resolve().parent.parent.parent / "priors" / "gpt2_enamine_real.ckpt", - None, + +Currently GRU, LSTM and GPT2 default models are implemented, and can be loaded setting `model: model_name` in the configuration file: +``` +# Default models +models = { + "gru": default_gru_model_factory, + "lstm": default_lstm_model_factory, + "gpt2": default_gpt2_model_factory, +} +``` +New models can be registered by creating a model factory, then in any configuration the path to the factory function can be included, +so it will be directly imported in the script. For our custom model we can create a custom model factory and add the function module path +to the config file `model_factory: acegen_custom_model.custom_model.custom_model_factory`: +``` +def custom_model_factory(cfg, *args, **kwargs): + return ( + partial(create_custom_actor, cfg), + None, + None, + resources.files("acegen.priors") / "custom_ascii.pt", + resources.files("acegen.priors") / "custom_model.ckpt", + None, ) - } + +``` Here we have assigned vocabulary and weights files from out set of priors to the model. We could, however, use others. Now, we can already use the model in the Reinvent and AHC training scripts for de novo molecule generation. @@ -242,15 +276,13 @@ ProbabilisticActor wrapper. Let's see how to define it: ```python -def create_gpt2_critic( - vocabulary_size: int, -): - """Create a GPT2 critic for language modeling.""" +@transform_config +def create_custom_critic(config, vocabulary_size, critic_value_per_action=False): + + """ Create CLM critic for language modeling. """ # Define transformer - config = GPT2Config() # Original GPT2 configuration, can be customized - config.vocab_size = vocabulary_size - lm = GPT2(config) - + lm = AceGenCustomModel(config) + # Wrap the transformer in a TensorDictModule to make TensorDict compatible lm_training = TensorDictModule( lm.set_train_mode(True), @@ -263,33 +295,32 @@ def create_gpt2_critic( out_keys=["features"], ) - # Define final layer and also make it a TensorDictModule + # Define last layer and also make it TensorDictModule lm_head = TensorDictModule( nn.Linear( config.n_embd, - 1, - bias=False, + vocabulary_size if critic_value_per_action else 1, + bias = False ), - in_keys=["features"], - out_keys=["state_value"], + in_keys = ["features"], + out_keys = ["action_value"] if critic_value_per_action else ["state_value"], ) - - # Concatenate lm and head, similar to torch.nn.Sequential + + # Concatenate lm and head, similar to troch.nn.Sequential # Critic does not need to be probabilistic, so we can return directly critic_training = TensorDictSequential(lm_training, lm_head) critic_inference = TensorDictSequential(lm_inference, lm_head) return critic_training, critic_inference ``` -and then add it to the model_mapping dictionary: +and then add it to the model factory: - model_mapping = { - "gpt2": ( - create_gpt2_actor, - create_gpt2_critic, - None, - Path(__file__).resolve().parent.parent.parent / "priors" / "enamine_real_vocabulary.txt", - Path(__file__).resolve().parent.parent.parent / "priors" / "gpt2_enamine_real.ckpt", - SMILEStokenizer2(), # Constratined generation tasks require a tokenizer +def custom_model_factory(cfg, *args, **kwargs): + return ( + partial(create_custom_actor, cfg), + partial(create_custom_critic, cfg), + None, + resources.files("acegen.priors") / "custom_ascii.pt", + resources.files("acegen.priors") / "custom_model.ckpt", + None, ) - }