Merge pull request #36 from Acellera/mamba_latest
Mamba latest
albertbou92 authored Jun 5, 2024
2 parents a49efbb + eca156f commit 5e2de20
Showing 15 changed files with 478 additions and 25 deletions.
10 changes: 8 additions & 2 deletions README.md
@@ -151,13 +151,19 @@ We provide a variety of default priors that can be selected in the configuration
- A Long Short-Term Memory (LSTM) model
- pre-training dataset: [ChEMBL](https://www.ebi.ac.uk/chembl/)
- number of parameters: 5,807,909
- to select it, set the field `model` to `lstm` in any configuration file


- A GPT-2 model (requires installation of HuggingFace's `transformers` library)
- pre-training dataset: [REAL 350/3 lead-like, 613.86M cpds, CXSMILES](https://enamine.net/compound-collections/real-compounds/real-database-subsets)
- number of parameters: 5,030,400
- to select it, set the field `model` to `gpt2` in any configuration file


- A Mamba model (requires installation of the `mamba-ssm` library)
- pre-training dataset: [ChEMBL](https://www.ebi.ac.uk/chembl/)
- number of parameters: 2,809,216
- to select it, set the field `model` to `mamba` in any configuration file
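
For example, the pretraining configuration added in this pull request (`scripts/pretrain/config_mamba.yaml`, shown below) selects it with `model: mamba`.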

### Integration of custom models

13 changes: 13 additions & 0 deletions acegen/models/__init__.py
@@ -18,6 +18,11 @@
create_lstm_actor_critic,
create_lstm_critic,
)
from acegen.models.mamba import (
create_mamba_actor,
create_mamba_actor_critic,
create_mamba_critic,
)
from acegen.models.utils import adapt_state_dict
from acegen.vocabulary.tokenizers import (
SMILESTokenizerChEMBL,
@@ -71,6 +76,14 @@ def extract(path):
resources.files("acegen.priors") / "gpt2_enamine_real.ckpt",
SMILESTokenizerEnamine(),
),
"mamba": (
create_mamba_actor,
create_mamba_critic,
create_mamba_actor_critic,
resources.files("acegen.priors") / "mamba_chembl_filtered_vocabulary.ckpt",
resources.files("acegen.priors") / "mamba_chembl_filtered.ckpt",
SMILESTokenizerChEMBL(),
),
}
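
# Sketch of how a training script might consume this registry (assuming the
# dict is exposed as `models`; entries follow a fixed order):
#   (create_actor, create_critic, create_actor_critic,
#    voc_path, ckpt_path, tokenizer) = models["mamba"]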


254 changes: 254 additions & 0 deletions acegen/models/mamba.py
@@ -0,0 +1,254 @@
from dataclasses import dataclass

import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs import ExplorationType
from torchrl.modules import ActorValueOperator, ProbabilisticActor

try:
from mamba_ssm.models.mixer_seq_simple import MixerModel

_has_mamba = True
except ImportError as err:
_has_mamba = False
MAMBA_ERR = err


class Mamba(nn.Module):
"""Mamba model for language modelling. This model is a simple wrapper around the Mamba MixerModel."""

def __init__(self, config=None):
if not _has_mamba:
raise RuntimeError(
"mamba-ssm library not found, please install with pip install mamba-ssm==1.2.2 "
"(dependencies include: cuda, causal-conv1d>=1.2.0)."
) from MAMBA_ERR

super(Mamba, self).__init__()

# Define model
self.config = config
if config:
self.feature_extractor = MixerModel(
d_model=config.n_embd,
n_layer=config.n_layer,
vocab_size=config.vocab_size,
ssm_cfg={"use_fast_path": True},
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
)
else:
self.feature_extractor = None

# Start in evaluation mode
self._train_mode = False

@property
def train_mode(self):
return self._train_mode

    def set_train_mode(self, train_mode: bool = True):
        if train_mode is self._train_mode:
            return self
        out = Mamba()
        out.feature_extractor = self.feature_extractor
        out.config = self.config  # keep the configuration on the copy
        out._train_mode = train_mode  # honor the requested mode
        return out

def forward(self, sequence, sequence_mask):
out = self.feature_extractor(input_ids=sequence)

if self.train_mode is False: # Data collection, return only last token
obs_length = sequence_mask.sum(-1)
out = out[torch.arange(len(out)), obs_length.to(torch.int64) - 1]

return out
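
    # Shape note: with train_mode=True, features for every token are returned,
    # shape (batch, time, n_embd); during data collection only the features of
    # the last non-padded token are kept, so the output is (batch, n_embd).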


@dataclass
class MambaConfig:
"""Simple data class for mamba configuration."""

vocab_size: int
n_embd: int = 128
n_layer: int = 24
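
    # e.g. MambaConfig(vocab_size=55) keeps the defaults above: 128-dim
    # embeddings and 24 Mamba blocks (the vocabulary size is illustrative).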


def create_mamba_actor(
vocabulary_size: int,
n_embd: int = 128,
n_layer: int = 24,
return_log_prob: bool = True,
**kwargs,
):
"""Create a Mamba actor for language modelling."""
    # Define model
config = MambaConfig(
vocab_size=vocabulary_size,
n_embd=n_embd,
n_layer=n_layer,
)
lm = Mamba(config)

# Transform into TensorDict modules
lm_training = TensorDictModule(
lm.set_train_mode(True),
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)

lm_inference = TensorDictModule(
lm, in_keys=["sequence", "sequence_mask"], out_keys=["features"]
)

# Define final layer and also make it a TensorDictModule
lm_head = TensorDictModule(
nn.Linear(config.n_embd, vocabulary_size, bias=False),
in_keys=["features"],
out_keys=["logits"],
)

# Concatenate lm and head, similar to torch.nn.Sequential
policy_training = TensorDictSequential(lm_training, lm_head)
policy_inference = TensorDictSequential(lm_inference, lm_head)

# To make the actor probabilities, wrap the policy in a ProbabilisticActor
# This module will take care of sampling and computing log_probabilities
probabilistic_policy_training = ProbabilisticActor(
module=policy_training,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
)
probabilistic_policy_inference = ProbabilisticActor(
module=policy_inference,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
)
return probabilistic_policy_training, probabilistic_policy_inference
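
# Usage sketch (hypothetical vocabulary size; mamba-ssm generally needs CUDA):
#   actor_training, actor_inference = create_mamba_actor(vocabulary_size=55)
# Both map a TensorDict with "sequence" and "sequence_mask" entries to one
# with "logits", a sampled "action" and, when return_log_prob=True, the
# corresponding log-probability.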


def create_mamba_critic(
    vocabulary_size: int,
    n_embd: int = 128,
    n_layer: int = 24,
    critic_value_per_action=False,
    **kwargs,
):
"""Create a Mamba critic for language modelling."""
# Define model
config = MambaConfig(
vocab_size=vocabulary_size,
n_embd=n_embd,
n_layer=n_layer,
)
lm = Mamba(config)

    # Wrap the Mamba language model in TensorDictModules to make it TensorDict-compatible
lm_training = TensorDictModule(
lm.set_train_mode(True),
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)
lm_inference = TensorDictModule(
lm, in_keys=["sequence", "sequence_mask"], out_keys=["features"]
)

# Define final layer and also make it a TensorDictModule
lm_head = TensorDictModule(
nn.Linear(
config.n_embd,
vocabulary_size if critic_value_per_action else 1,
bias=False,
),
in_keys=["features"],
out_keys=["action_value"] if critic_value_per_action else ["state_value"],
)

# Concatenate lm and head, similar to torch.nn.Sequential
critic_training = TensorDictSequential(lm_training, lm_head)
critic_inference = TensorDictSequential(lm_inference, lm_head)
return critic_training, critic_inference
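
# Note: with critic_value_per_action=True the head emits one value per
# vocabulary token under the "action_value" key; otherwise it emits a single
# scalar per step under "state_value".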


def create_mamba_actor_critic(
vocabulary_size: int,
n_embd: int = 128,
n_layer: int = 24,
return_log_prob: bool = True,
critic_value_per_action=False,
**kwargs,
):
"""Create a Mamba actor-critic for language modelling."""
# Define model
config = MambaConfig(
vocab_size=vocabulary_size,
n_embd=n_embd,
n_layer=n_layer,
)
lm = Mamba(config)

    # Wrap the Mamba language model in TensorDictModules to make it TensorDict-compatible
lm_training = TensorDictModule(
lm.set_train_mode(True),
in_keys=["sequence", "sequence_mask"],
out_keys=["features"],
)
lm_inference = TensorDictModule(
lm, in_keys=["sequence", "sequence_mask"], out_keys=["features"]
)

    # Define the actor head and make it a probabilistic TensorDictModule
actor_head = TensorDictModule(
nn.Linear(config.n_embd, vocabulary_size, bias=False),
in_keys=["features"],
out_keys=["logits"],
)
actor_head = ProbabilisticActor(
module=actor_head,
in_keys=["logits"],
out_keys=["action"],
distribution_class=torch.distributions.Categorical,
return_log_prob=return_log_prob,
default_interaction_type=ExplorationType.RANDOM,
)

# Define critic head and also make it a TensorDictModule
critic_head = TensorDictModule(
nn.Linear(
config.n_embd,
vocabulary_size if critic_value_per_action else 1,
bias=False,
),
in_keys=["features"],
out_keys=["action_value"] if critic_value_per_action else ["state_value"],
)

# Create shared actor-critic TensorDictModule
actor_critic_train = ActorValueOperator(
common_operator=lm_training,
policy_operator=actor_head,
value_operator=critic_head,
)
actor_critic_inference = ActorValueOperator(
common_operator=lm_inference,
policy_operator=actor_head,
value_operator=critic_head,
)

# Get individual operators
actor_training = actor_critic_train.get_policy_operator()
critic_training = actor_critic_train.get_value_operator()
actor_inference = actor_critic_inference.get_policy_operator()
critic_inference = actor_critic_inference.get_value_operator()

return actor_training, actor_inference, critic_training, critic_inference
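
Putting the pieces together, a minimal end-to-end sketch of the new factory could look as follows. The vocabulary size, batch shape, and tensor placement are illustrative assumptions, and `mamba-ssm` generally requires a CUDA device, so in practice the modules and tensors would first be moved to GPU:

```python
import torch
from tensordict import TensorDict

from acegen.models.mamba import create_mamba_actor_critic

# Hypothetical sizes, for illustration only.
vocab_size, batch, time = 55, 4, 10

actor_train, actor_infer, critic_train, critic_infer = create_mamba_actor_critic(
    vocabulary_size=vocab_size
)

td = TensorDict(
    {
        "sequence": torch.randint(0, vocab_size, (batch, time)),
        "sequence_mask": torch.ones(batch, time, dtype=torch.bool),
    },
    batch_size=[batch],
)

td = actor_infer(td)   # adds "logits" and a sampled "action" for the next token
td = critic_infer(td)  # adds "state_value" (one scalar per sequence)
```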
Binary file added acegen/priors/mamba_chembl_filtered.ckpt
Binary file added acegen/priors/mamba_chembl_filtered_vocabulary.ckpt
6 changes: 4 additions & 2 deletions scripts/a2c/a2c.py
@@ -199,8 +199,10 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(get_primers_from_module(actor_training))
env.append_transform(get_primers_from_module(critic_training))
if primers := get_primers_from_module(actor_inference):
env.append_transform(primers)
if primers := get_primers_from_module(critic_inference):
env.append_transform(primers)
return env

env = create_env_fn()
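
These guards replace the previous unconditional `env.append_transform(get_primers_from_module(actor_training))` calls: recurrent priors such as the GRU and LSTM expose TensorDict primers for their hidden state, while Mamba exposes none, in which case `get_primers_from_module` returns no transform and appending must be skipped. The same pattern is applied in the scripts below.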
3 changes: 2 additions & 1 deletion scripts/ahc/ahc.py
@@ -181,7 +181,8 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(get_primers_from_module(actor_training))
if primers := get_primers_from_module(actor_inference):
env.append_transform(primers)
return env

env = create_env_fn()
6 changes: 4 additions & 2 deletions scripts/ppo/ppo.py
@@ -206,8 +206,10 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(get_primers_from_module(actor_training))
env.append_transform(get_primers_from_module(critic_training))
if primers := get_primers_from_module(actor_inference):
env.append_transform(primers)
if primers := get_primers_from_module(critic_inference):
env.append_transform(primers)
return env

env = create_env_fn()
28 changes: 28 additions & 0 deletions scripts/pretrain/config_mamba.yaml
@@ -0,0 +1,28 @@
# Logging configuration
experiment_name: mamba
agent_name: chembl28p
logger_backend: wandb # csv, wandb, tensorboard, or null
seed: 101
log_frequency: 500

# Dataset configuration
train_dataset_path: /workspace1/Priors/ChEMBL_potent/processed_data/ChEMBL28p_all_undersample-8.smi.gz
tokenizer: SMILESTokenizerChEMBL # SMILESTokenizer, SMILESTokenizer2, SMILESTokenizer3, DeepSMILESTokenizer, SELFIESTokenizer, AISTokenizer, SAFETokenizer, SmiZipTokenizer
#special_tokens: [".", "/", "\\", "@", "%", "*", "=", ":", "#", ">", "+", "-", "<UNK>", "<SEP>"]
recompute_dataset: True # if False and dataset_log_dir contains a dataset, it will be used
dataset_log_dir: /tmp/pretrain

# Model configuration
model: mamba # gru, lstm, gpt2 or mamba
model_log_dir: mamba/chembl28p

# Training configuration
lr: 0.001
lr_scheduler: StepLR
lr_scheduler_kwargs:
step_size: 500 # Change every n epochs
gamma: 0.95 # 1.0 = no decay
epochs: 25
batch_size: 512
randomize_smiles: True
num_test_smiles: 100
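
Note that `train_dataset_path` points to a machine-specific location; reproducing this pretraining run requires pointing it at a local SMILES dataset (`.smi` or `.smi.gz`).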
3 changes: 2 additions & 1 deletion scripts/pretrain/pretrain_distributed.py
@@ -168,7 +168,8 @@ def main(cfg: "DictConfig"):
)
test_env = TransformedEnv(test_env)
test_env.append_transform(InitTracker())
test_env.append_transform(get_primers_from_module(actor_inference))
if primers := get_primers_from_module(actor_inference):
test_env.append_transform(primers)

logging.info("\nCreating test scoring function...")

3 changes: 2 additions & 1 deletion scripts/pretrain/pretrain_single_node.py
@@ -113,7 +113,8 @@ def main(cfg: "DictConfig"):
)
test_env = TransformedEnv(test_env)
test_env.append_transform(InitTracker())
test_env.append_transform(get_primers_from_module(actor_inference))
if primers := get_primers_from_module(actor_inference):
test_env.append_transform(primers)

logging.info("\nCreating optimizer...")
actor_optimizer = torch.optim.Adam(actor_training.parameters(), lr=cfg.lr)
3 changes: 2 additions & 1 deletion scripts/reinforce/reinforce.py
@@ -179,7 +179,8 @@ def create_env_fn():
env = SMILESEnv(**env_kwargs)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(get_primers_from_module(actor_training))
if primers := get_primers_from_module(actor_inference):
env.append_transform(primers)
return env

env = create_env_fn()