test sequence level classification
holgerroth committed Feb 5, 2025
1 parent 5f1baa0 commit 9941e25
Showing 5 changed files with 188 additions and 50 deletions.
72 changes: 62 additions & 10 deletions examples/advanced/bionemo/downstream/downstream_nvflare.ipynb
@@ -47,17 +47,30 @@
"\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.13a0+0d33366-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_thunder-0.2.0.dev0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n",
"Requirement already satisfied: fuzzywuzzy in /usr/local/lib/python3.12/dist-packages (0.18.0)\n",
"Requirement already satisfied: PyTDC in /usr/local/lib/python3.12/dist-packages (1.1.12)\n",
"Collecting fuzzywuzzy\n",
" Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl.metadata (4.9 kB)\n",
"Collecting PyTDC\n",
" Downloading pytdc-1.1.12.tar.gz (151 kB)\n",
" Preparing metadata (setup.py) ... \u001b[?2done\n",
"\u001b[?25hDownloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)\n",
"Building wheels for collected packages: PyTDC\n",
" Building wheel for PyTDC (setup.py) ... \u001b[done\n",
"\u001b[?25h Created wheel for PyTDC: filename=PyTDC-1.1.12-py3-none-any.whl size=189419 sha256=9048cd857b3364c6a849e786b96fc3be325e6570054366190aebadf9283d8376\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-z8ep7o__/wheels/de/2e/da/e46690d98c1256fee1741d03030623bc926cebc13e3e359145\n",
"Successfully built PyTDC\n",
"Installing collected packages: PyTDC, fuzzywuzzy\n",
"Successfully installed PyTDC-1.1.12 fuzzywuzzy-0.18.0\n",
"\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n",
"\u001b[0m"
"\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n"
]
}
],
"source": [
"# %%capture --no-display --no-stderr cell_output\n",
"! pip install fuzzywuzzy PyTDC --no-dependencies # install tdc without dependencies to avoid version conflicts in the BioNeMo container\n",
"#! pip install nvflare~=2.5\n",
"! pip install nvflare~=2.5\n",
"#! pip install biopython\n",
"#! pip install scikit-learn\n",
"#! pip install matplotlib\n",
Expand All @@ -82,9 +95,38 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading data from 'nvidia/clara/esm2nv8m:2.0' to file '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz'.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"download_end\": \"2025-02-04 15:59:35\",\n",
" \"download_start\": \"2025-02-04 15:59:33\",\n",
" \"download_time\": \"1s\",\n",
" \"files_downloaded\": 1,\n",
" \"local_path\": \"/root/.cache/bionemo/tmperx2hsc3/esm2nv8m_v2.0\",\n",
" \"size_downloaded\": \"16.97 MB\",\n",
" \"status\": \"COMPLETED\"\n",
"}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Untarring contents of '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz' to '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar'\n"
]
},
{
"name": "stdout",
"output_type": "stream",
@@ -177,15 +219,14 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading...\n",
"100%|███████████████████████████████████████| 601k/601k [00:00<00:00, 1.96MiB/s]\n",
"Found local copy...\n",
"Loading...\n",
"Done!\n",
"Sampling with alpha=1.0\n",
@@ -240,9 +281,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Traceback (most recent call last):\n",
" File \"/bionemo_nvflare_examples/downstream/sabdab/run_sim_sabdab.py\", line 15, in <module>\n",
" from nvflare.job_config.script_runner import BaseScriptRunner\n",
"ModuleNotFoundError: No module named 'nvflare'\n"
]
}
],
"source": [
"! cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py"
]
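The simulation is launched through run_sim_sabdab.py, which is not part of this diff. As a rough sketch only, assuming the standard NVFlare 2.5 job API (FedJob, ScriptRunner, FedAvg), such a simulator script could look like the following; the job name, client count, data paths, and workspace are placeholders, not values from the repository:

# Sketch of an NVFlare 2.5 simulator job for the fine-tuning script; not the repository's
# actual run_sim_sabdab.py. Job name, client count, data paths, and workspace are placeholders.
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.job_config.api import FedJob
from nvflare.job_config.script_runner import ScriptRunner

n_clients = 2
job = FedJob(name="esm2_sabdab_finetune")

# Server side: FedAvg over the participating sites for a few rounds.
job.to(FedAvg(num_clients=n_clients, num_rounds=3), "server")

# Client side: every site runs the same fine-tuning script on its own data split.
for i in range(n_clients):
    runner = ScriptRunner(
        script="finetune_esm2.py",
        script_args=f"--task-type classification --train-data-path data/site{i + 1}_train.csv "
        f"--valid-data-path data/site{i + 1}_valid.csv",
    )
    job.to(runner, f"site-{i + 1}")

# Run everything locally with the FL simulator.
job.simulator_run("/tmp/nvflare/bionemo_sabdab")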
149 changes: 117 additions & 32 deletions examples/advanced/bionemo/downstream/sabdab/finetune_esm2.py
@@ -35,8 +35,8 @@
InMemoryProteinDataset,
InMemorySingleValueDataset,
)
from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
from bionemo.esm2.model.finetune.sequence_model import ESM2FineTuneSeqConfig
from bionemo.esm2.model.finetune.token_model import ESM2FineTuneTokenConfig
from bionemo.llm.model.biobert.lightning import biobert_lightning_module
from bionemo.llm.model.biobert.model import BioBertConfig
from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
@@ -80,9 +80,18 @@ def train_model(
experiment_name: str,
resume_if_exists: bool,
precision: PrecisionTypes,
task_type: str = "regression",
encoder_frozen: bool = False,
scale_lr_layer: Optional[str] = None,
lr_multiplier: float = 1.0,
# single value classification / regression mlp
mlp_ft_dropout: float = 0.25,
mlp_hidden_size: int = 256,
mlp_target_size: int = 1,
# token-level classification cnn
cnn_dropout: float = 0.25,
cnn_hidden_size: int = 32,
cnn_num_classes: int = 3,
wandb_entity: Optional[str] = None,
wandb_project: Optional[str] = None,
wandb_offline: bool = False,
@@ -132,9 +141,16 @@ def train_model(
result_dir that stores the logs and checkpoints.
resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
task_type (str): Fine-tuning task type. Default is regression.
encoder_frozen (bool): Freeze the encoder parameters. Default is False.
scale_lr_layer (Optional[str]): layer names for which the lr is scaled by lr_multiplier
lr_multiplier (float): lr multiplier for parameters in scale_lr_layer
mlp_ft_dropout (float): dropout for single value classification / regression mlp
mlp_hidden_size (int): dimension of hidden layer in mlp task head
mlp_target_size (int): output dimension of the mlp task head (number of classes in classification tasks)
cnn_dropout (float): dropout for token-level classification cnn
cnn_hidden_size (int): hidden dimension of cnn head
cnn_num_classes (int): number of classes in token-level classification
wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
wandb_project (Optional[str]): The name of the project to which this run will belong
wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
@@ -163,7 +179,6 @@ def train_model(
average_in_collective (bool): average in collective
grad_reduce_in_fp32 (bool): gradient reduction in fp32
"""
print("XXXXXX starting train_model")
# Create the result directory if it does not exist.
result_dir.mkdir(parents=True, exist_ok=True)

@@ -177,7 +192,6 @@
pipeline_model_parallel_size=pipeline_model_parallel_size,
)

print("XXXXXX starting MegatronStrategy")
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
@@ -229,8 +243,6 @@
)
)

print("XXXXXX starting Trainer")

trainer = nl.Trainer(
devices=devices,
max_steps=num_steps,
@@ -249,14 +261,12 @@
autocast_enabled=False,
),
)
# (2) patch the lightning trainer
flare.patch(trainer, restore_state=False, load_state_dict_strict=False)

tokenizer = get_tokenizer()

# Initialize the data module.
train_dataset = dataset_class.from_csv(train_data_path)
valid_dataset = dataset_class.from_csv(valid_data_path)
train_dataset = dataset_class.from_csv(train_data_path, task_type=task_type)
valid_dataset = dataset_class.from_csv(valid_data_path, task_type=task_type)

data_module = ESM2FineTuneDataModule(
train_dataset=train_dataset,
@@ -270,15 +280,29 @@
)
# Configure the model
config = config_class(
task_type=task_type,
encoder_frozen=encoder_frozen,
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
initial_ckpt_path=str(restore_from_checkpoint_path),
# initial_ckpt_skip_keys_with_these_prefixes=[], # load everything from the checkpoint.
)
initial_ckpt_skip_keys_with_these_prefixes=[f"{task_type}_head"],
)
# Mapping of task-dependent config attributes to their new values
task_dependent_attr = {
"mlp_ft_dropout": mlp_ft_dropout,
"mlp_hidden_size": mlp_hidden_size,
"mlp_target_size": mlp_target_size,
"cnn_dropout": cnn_dropout,
"cnn_hidden_size": cnn_hidden_size,
"cnn_num_classes": cnn_num_classes,
}
# Update attributes only if they exist in the config
for attr, value in task_dependent_attr.items():
if hasattr(config, attr):
setattr(config, attr, value)

optimizer = MegatronOptimizerModule(
config=OptimizerConfig(
@@ -317,28 +341,24 @@
ckpt_callback=checkpoint_callback,
)

# (3) receives FLModel from NVFlare
# Note that we don't need to pass this input_model to trainer
# because after flare.patch the trainer.fit/validate will get the
# global model internally
input_model = flare.receive()
print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n")

llm.train(
model=module,
data=data_module,
trainer=trainer,
log=None, #nemo_logger,
resume=None #resume.AutoResume(
#resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training.
#resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
#),
)
#llm.train(
# model=module,
# data=data_module,
# trainer=trainer,
# log=nemo_logger,
# resume=resume.AutoResume(
# resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training.
# resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
# ),
#)

flare.shutdown()
# (2) patch the lightning trainer
flare.patch(trainer, restore_state=False, load_state_dict_strict=False)

while flare.is_running():
trainer.fit(module, data_module)

ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))

return ckpt_path, metric_tracker, trainer


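The hunk above moves flare.patch() to just before training and drives the Lightning trainer inside a while flare.is_running() loop instead of the earlier one-shot flare.receive()/llm.train() flow. Condensed to its essentials (the explicit receive and the print are optional and shown only for illustration), the NVFlare Lightning client pattern is:

# Minimal restatement of the client-side pattern; `trainer`, `module`, and `data_module`
# are assumed to be built exactly as in train_model() above.
import nvflare.client.lightning as flare

flare.patch(trainer, restore_state=False, load_state_dict_strict=False)  # wrap fit/validate for FL

while flare.is_running():             # one iteration per federated round
    input_model = flare.receive()     # optional: inspect the incoming global model
    print(f"round={input_model.current_round}, site={flare.get_site_name()}")
    trainer.fit(module, data_module)  # local training; updated weights flow back to the server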
@@ -347,6 +367,11 @@ def finetune_esm2_entrypoint():
# 1. get arguments
parser = get_parser()
args = parser.parse_args()

# to avoid padding for single value labels:
if args.min_seq_length is not None and args.dataset_class is InMemorySingleValueDataset:
parser.error("Argument --min-seq-length cannot be set when using InMemorySingleValueDataset.")

# 2. Call pretrain with args
train_model(
train_data_path=args.train_data_path,
@@ -375,9 +400,18 @@ def finetune_esm2_entrypoint():
tensor_model_parallel_size=args.tensor_model_parallel_size,
accumulate_grad_batches=args.accumulate_grad_batches,
precision=args.precision,
task_type=args.task_type,
encoder_frozen=args.encoder_frozen,
scale_lr_layer=args.scale_lr_layer,
lr_multiplier=args.lr_multiplier,
# single value classification / regression mlp
mlp_ft_dropout=args.mlp_ft_dropout,
mlp_hidden_size=args.mlp_hidden_size,
mlp_target_size=args.mlp_target_size,
# token-level classification cnn
cnn_dropout=args.cnn_dropout,
cnn_hidden_size=args.cnn_hidden_size,
cnn_num_classes=args.cnn_num_classes,
experiment_name=args.experiment_name,
resume_if_exists=args.resume_if_exists,
restore_from_checkpoint_path=args.restore_from_checkpoint_path,
@@ -423,6 +457,14 @@ def get_parser():
default="bf16-mixed",
help="Precision type to use for training.",
)
parser.add_argument(
"--task-type",
type=str,
choices=["regression", "classification"],
required=True,
default="regression",
help="Fine-tuning task type.",
)
parser.add_argument(
"--encoder-frozen",
action="store_true",
@@ -450,6 +492,48 @@
default=1.0,
help="Learning rate multiplier for layers with scale-lr-layer in their name",
)
parser.add_argument(
"--mlp-ft-dropout",
type=float,
required=False,
default=0.25,
help="Dropout for single value classification / regression mlp. Default is 0.25",
)
parser.add_argument(
"--mlp-hidden-size",
type=int,
required=False,
default=256,
help="Dimension of hidden layer in mlp task head. Default is 256",
)
parser.add_argument(
"--mlp-target-size",
type=int,
required=False,
default=1,
help="Output dimension of the mlp task head. Set to 1 for regression and number of classes for classification tasks. Default is 1",
)
parser.add_argument(
"--cnn-dropout",
type=float,
required=False,
default=0.25,
help="Dropout for token-level classification cnn. Default is 0.25",
)
parser.add_argument(
"--cnn-hidden-size",
type=int,
required=False,
default=32,
help="Hidden dimension of cnn head. Default is 32",
)
parser.add_argument(
"--cnn-num-classes",
type=int,
required=False,
default=3,
help="Number of classes for token-level classification cnn. Default is 3",
)
parser.add_argument(
"--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
)
@@ -523,7 +607,7 @@ def get_parser():
"--min-seq-length",
type=float_or_int_or_none,
required=False,
default=1024,
default=None,
help="Minimum sequence length. Sampled will be padded if less than this value. Set 'None' to unset minimum.",
)
parser.add_argument(
Expand Down Expand Up @@ -689,3 +773,4 @@ def dataset_class_type(desc: str) -> Type[InMemoryProteinDataset]:

if __name__ == "__main__":
finetune_esm2_entrypoint()

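For reference, the new task-type and head arguments combine roughly as follows for a sequence-level classification run. This is an illustration only: flags that do not appear in this diff (for example --train-data-path, --valid-data-path, --restore-from-checkpoint-path) are inferred from the references in finetune_esm2_entrypoint() and may be named differently, and all values are placeholders.

# Illustrative argument set for a sequence-level (binary) classification fine-tune.
# Flag names not visible in this diff are assumptions; paths and values are placeholders.
example_argv = [
    "--restore-from-checkpoint-path", "/root/.cache/bionemo/esm2nv8m_v2.0",
    "--train-data-path", "data/sabdab_train.csv",
    "--valid-data-path", "data/sabdab_valid.csv",
    "--task-type", "classification",
    "--mlp-target-size", "2",                   # two output classes for a binary label
    "--mlp-hidden-size", "256",
    "--encoder-frozen",
    "--scale-lr-layer", "classification_head",  # head name follows f"{task_type}_head" above
    "--lr-multiplier", "100",
]
args = get_parser().parse_args(example_argv)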