Skip to content

Commit

Permalink
Update adding_custom_model.md
Browse files Browse the repository at this point in the history
  • Loading branch information
11carlesnavarro authored Apr 29, 2024
1 parent e20a7e2 commit a811794
Showing 1 changed file with 112 additions and 81 deletions.
193 changes: 112 additions & 81 deletions tutorials/adding_custom_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ TensorDict(

## Creating a custom model

Now we will define a custom model using the transformers library from HuggingFace. We will use the GPT-2 model as an example.
Now we will define a toy custom model to show how to integrate any model in Acegen.
The model will be a `torch.nn.Module`, and will provide different outputs depending on the phase (training or inference),
defined by its `train_mode` attribute.

Expand All @@ -105,39 +105,49 @@ return the prediction for all the tokens.
```python
import torch
from torch import nn
from transformers import GPT2Config, GPT2Model

class GPT2(nn.Module):
class CustomFeatureExtractor(nn.Module):
def __init__(self, config):
super().__init__()
self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)

def forward(self, batch, mask=None):
out = self.tok_embeddings(batch)
return out

class AceGenCustomModel(nn.Module):
def __init__(self, config=None):
super(GPT2, self).__init__()
self.feature_extractor = GPT2Model(config) if config is not None else None
super(AceGenCustomModel, self).__init__()
self.config = config
self.feature_extractor = CustomFeatureExtractor(config)
self._train_mode = False

@property
def train_mode(self):
return self._train_mode

def set_train_mode(self, train_mode: bool = True):
if train_mode is self._train_mode:
return self
out = GPT2()
out = AceGenCustomModel(self.config)
out.feature_extractor = self.feature_extractor
out._train_mode = train_mode
return out

def forward(self, sequence, sequence_mask):

out = self.feature_extractor(
input_ids=sequence,
attention_mask=sequence_mask.long(),
).last_hidden_state

if self.train_mode is False: # If inference, return only last token set to True by the sequence_mask
batch = sequence,
mask = sequence_mask.long())

if torch.isnan(out).any():
raise ValueError("NaN detected in model output.")

if self.train_mode is False:
obs_length = sequence_mask.sum(-1)
out = out[torch.arange(len(out)), obs_length.to(torch.int64) - 1]

return out

```

Now, we will wrap the model in a `tensordict.nn.TensordictModule` to make it Tensordict-compatible.
Expand All @@ -150,88 +160,112 @@ from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import ExplorationType
from torchrl.modules import ProbabilisticActor

def load_model_config(config, vocabulary_size):

def create_gpt2_actor(
vocabulary_size: int,
):
# Define transformer
config = GPT2Config() # Original GPT2 configuration, can be customized
config.vocab_size = vocabulary_size
lm = GPT2(config)
model_config = ModelConfig()
model_config.model_name = config.model
model_config.vocab_size = vocabulary_size
model_config.n_embd = config.n_embd

return model_config, vocabulary_size

# Wrap the transformer in a TensorDictModule to make TensorDict compatible
def transform_config(func):
def wrapper(config, vocabulary_size, *args, **kwargs):
transformed_config, vocabulary_size = load_model_config(config, vocabulary_size)
return func(transformed_config, vocabulary_size, *args[1:], **kwargs)
return wrapper

@transform_config
def create_custom_actor(config, vocabulary_size, return_log_prob=True):
""" Create a CLM Model actor for language modeling. """

# Define the transformer
lm = AceGenCustomModel(config)

# Wrap it in a TensorDictModule to make it TensorDictCompatible
lm_training = TensorDictModule(
lm.set_train_mode(True),
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
out_keys = ["features"]
)

lm_inference = TensorDictModule(
lm,
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)

# Define final layer and also make
# Define head layer and make it a TensorDictModule
lm_head = TensorDictModule(
nn.Linear(config.n_embd, vocabulary_size, bias=False),
in_keys=["features"],
out_keys=["logits"],
in_keys = ["features"],
out_keys = ["logits"]
)

# Concatenate lm and head, similar to torch.nn.Sequential
policy_training = TensorDictSequential(lm_training, lm_head)
policy_inference = TensorDictSequential(lm_inference, lm_head)

# To make the actor probabilistic, wrap the policy in a ProbabilisticActor
# This module will take care of sampling and computing log probabilities
# To make actor probabilistic, wrap the policy in a ProbabilisticActor
# This will take care of sampling and computing log probabilities
probabilistic_policy_training = ProbabilisticActor(
module=policy_training,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
module = policy_training,
in_keys = ["logits"],
out_keys = ["action"],
distribution_class = torch.distributions.Categorical,
return_log_prob = return_log_prob,
default_interaction_type = ExplorationType.RANDOM,
)

probabilistic_policy_inference = ProbabilisticActor(
module=policy_inference,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=True,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
)
return probabilistic_policy_training, probabilistic_policy_inference
```

## How to make the custom model available in the training scripts

Available models to the training scripts are defined in `/acegen/__init__.py` as a mapping to tuples with the following format:
Available models to the training scripts are defined in `/acegen/__init__.py` as a mapping to factory classes with the following format:

model_mapping = {
"example_model": (
def default_example_model_factory(*args, **kwargs):
return (
create_actor_method: Callable # A method to create the actor model
create_critic_method: Callable # A method to create the critic model (Optional)
create_actor_critic_method: Callable # A method to create the actor-critic model (Optional)
vocabulary_file_path: Path # The path to the vocabulary file
weights_file_path: Path # The path to the weights file (Optional)
tokenizer: Tokenizer # The tokenizer to use for the model (Optional)
)
}

New models can be added by creating a new tuple and appending it to the model_mapping dictionary. Then the model can be
selected in any configuration file by setting the `model` parameter to the name of the model. In the case of our example,
adding the models would look like this:

model_mapping = {
"gpt2": (
create_gpt2_actor,
None,
None,
Path(__file__).resolve().parent.parent.parent / "priors" / "enamine_real_vocabulary.txt",
Path(__file__).resolve().parent.parent.parent / "priors" / "gpt2_enamine_real.ckpt",
None,

Currently GRU, LSTM and GPT2 default models are implemented, and can be loaded setting `model: model_name` in the configuration file:
```
# Default models
models = {
"gru": default_gru_model_factory,
"lstm": default_lstm_model_factory,
"gpt2": default_gpt2_model_factory,
}
```
New models can be registered by creating a model factory, then in any configuration the path to the factory function can be included,
so it will be directly imported in the script. For our custom model we can create a custom model factory and add the function module path
to the config file `model_factory: acegen_custom_model.custom_model.custom_model_factory`:
```
def custom_model_factory(cfg, *args, **kwargs):
return (
partial(create_custom_actor, cfg),
None,
None,
resources.files("acegen.priors") / "custom_ascii.pt",
resources.files("acegen.priors") / "custom_model.ckpt",
None,
)
}
```

Here we have assigned vocabulary and weights files from out set of priors to the model. We could, however, use others.
Now, we can already use the model in the Reinvent and AHC training scripts for de novo molecule generation.
Expand All @@ -242,15 +276,13 @@ ProbabilisticActor wrapper. Let's see how to define it:

```python

def create_gpt2_critic(
vocabulary_size: int,
):
"""Create a GPT2 critic for language modeling."""
@transform_config
def create_custom_critic(config, vocabulary_size, critic_value_per_action=False):
""" Create CLM critic for language modeling. """
# Define transformer
config = GPT2Config() # Original GPT2 configuration, can be customized
config.vocab_size = vocabulary_size
lm = GPT2(config)

lm = AceGenCustomModel(config)

# Wrap the transformer in a TensorDictModule to make TensorDict compatible
lm_training = TensorDictModule(
lm.set_train_mode(True),
Expand All @@ -263,33 +295,32 @@ def create_gpt2_critic(
out_keys=["features"],
)

# Define final layer and also make it a TensorDictModule
# Define last layer and also make it TensorDictModule
lm_head = TensorDictModule(
nn.Linear(
config.n_embd,
1,
bias=False,
vocabulary_size if critic_value_per_action else 1,
bias = False
),
in_keys=["features"],
out_keys=["state_value"],
in_keys = ["features"],
out_keys = ["action_value"] if critic_value_per_action else ["state_value"],
)

# Concatenate lm and head, similar to torch.nn.Sequential
# Concatenate lm and head, similar to troch.nn.Sequential
# Critic does not need to be probabilistic, so we can return directly
critic_training = TensorDictSequential(lm_training, lm_head)
critic_inference = TensorDictSequential(lm_inference, lm_head)
return critic_training, critic_inference
```

and then add it to the model_mapping dictionary:
and then add it to the model factory:

model_mapping = {
"gpt2": (
create_gpt2_actor,
create_gpt2_critic,
None,
Path(__file__).resolve().parent.parent.parent / "priors" / "enamine_real_vocabulary.txt",
Path(__file__).resolve().parent.parent.parent / "priors" / "gpt2_enamine_real.ckpt",
SMILEStokenizer2(), # Constratined generation tasks require a tokenizer
def custom_model_factory(cfg, *args, **kwargs):
return (
partial(create_custom_actor, cfg),
partial(create_custom_critic, cfg),
None,
resources.files("acegen.priors") / "custom_ascii.pt",
resources.files("acegen.priors") / "custom_model.ckpt",
None,
)
}

0 comments on commit a811794

Please sign in to comment.