From 9941e25caaaef59139b85a3b17f5da867c2b4366 Mon Sep 17 00:00:00 2001 From: Holger Roth Date: Wed, 5 Feb 2025 02:59:03 +0000 Subject: [PATCH] test sequence level classification --- .../downstream/downstream_nvflare.ipynb | 72 +++++++-- .../downstream/sabdab/finetune_esm2.py | 149 ++++++++++++++---- .../downstream/sabdab/prepare_sabdab_data.py | 7 +- .../downstream/sabdab/run_sim_sabdab.py | 8 +- examples/advanced/bionemo/start_bionemo.sh | 2 +- 5 files changed, 188 insertions(+), 50 deletions(-) diff --git a/examples/advanced/bionemo/downstream/downstream_nvflare.ipynb b/examples/advanced/bionemo/downstream/downstream_nvflare.ipynb index 7e3bcd1af5..afe643b0f2 100644 --- a/examples/advanced/bionemo/downstream/downstream_nvflare.ipynb +++ b/examples/advanced/bionemo/downstream/downstream_nvflare.ipynb @@ -47,17 +47,30 @@ "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/nvfuser-0.2.13a0+0d33366-py3.12-linux-x86_64.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", "\u001b[0m\u001b[33mDEPRECATION: Loading egg at /usr/local/lib/python3.12/dist-packages/lightning_thunder-0.2.0.dev0-py3.12.egg is deprecated. pip 25.1 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n", "\u001b[0mLooking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com\n", - "Requirement already satisfied: fuzzywuzzy in /usr/local/lib/python3.12/dist-packages (0.18.0)\n", - "Requirement already satisfied: PyTDC in /usr/local/lib/python3.12/dist-packages (1.1.12)\n", + "Collecting fuzzywuzzy\n", + " Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl.metadata (4.9 kB)\n", + "Collecting PyTDC\n", + " Downloading pytdc-1.1.12.tar.gz (151 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?2done\n", + "\u001b[?25hDownloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)\n", + "Building wheels for collected packages: PyTDC\n", + " Building wheel for PyTDC (setup.py) ... \u001b[done\n", + "\u001b[?25h Created wheel for PyTDC: filename=PyTDC-1.1.12-py3-none-any.whl size=189419 sha256=9048cd857b3364c6a849e786b96fc3be325e6570054366190aebadf9283d8376\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-z8ep7o__/wheels/de/2e/da/e46690d98c1256fee1741d03030623bc926cebc13e3e359145\n", + "Successfully built PyTDC\n", + "Installing collected packages: PyTDC, fuzzywuzzy\n", + "Successfully installed PyTDC-1.1.12 fuzzywuzzy-0.18.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable.It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", - "\u001b[0m" + "\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m24.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m25.0\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpython -m pip install --upgrade pip\u001b[0m\n" ] } ], "source": [ "# %%capture --no-display --no-stderr cell_output\n", "! pip install fuzzywuzzy PyTDC --no-dependencies # install tdc without dependencies to avoid version conflicts in the BioNeMo container\n", - "#! pip install nvflare~=2.5\n", + "! pip install nvflare~=2.5\n", "#! pip install biopython\n", "#! pip install scikit-learn\n", "#! pip install matplotlib\n", @@ -82,9 +95,38 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Downloading data from 'nvidia/clara/esm2nv8m:2.0' to file '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz'.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{\n", + " \"download_end\": \"2025-02-04 15:59:35\",\n", + " \"download_start\": \"2025-02-04 15:59:33\",\n", + " \"download_time\": \"1s\",\n", + " \"files_downloaded\": 1,\n", + " \"local_path\": \"/root/.cache/bionemo/tmperx2hsc3/esm2nv8m_v2.0\",\n", + " \"size_downloaded\": \"16.97 MB\",\n", + " \"status\": \"COMPLETED\"\n", + "}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Untarring contents of '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz' to '/root/.cache/bionemo/2957b2c36d5978d0f595d6f1b72104b312621cf0329209086537b613c1c96d16-esm2_hf_converted_8m_checkpoint.tar.gz.untar'\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -177,15 +219,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Downloading...\n", - "100%|███████████████████████████████████████| 601k/601k [00:00<00:00, 1.96MiB/s]\n", + "Found local copy...\n", "Loading...\n", "Done!\n", "Sampling with alpha=1.0\n", @@ -240,9 +281,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Traceback (most recent call last):\n", + " File \"/bionemo_nvflare_examples/downstream/sabdab/run_sim_sabdab.py\", line 15, in \n", + " from nvflare.job_config.script_runner import BaseScriptRunner\n", + "ModuleNotFoundError: No module named 'nvflare'\n" + ] + } + ], "source": [ "! 
cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py"
   ]
diff --git a/examples/advanced/bionemo/downstream/sabdab/finetune_esm2.py b/examples/advanced/bionemo/downstream/sabdab/finetune_esm2.py
index 5980deacbf..516a677847 100644
--- a/examples/advanced/bionemo/downstream/sabdab/finetune_esm2.py
+++ b/examples/advanced/bionemo/downstream/sabdab/finetune_esm2.py
@@ -35,8 +35,8 @@
     InMemoryProteinDataset,
     InMemorySingleValueDataset,
 )
-from bionemo.esm2.model.finetune.finetune_regressor import ESM2FineTuneSeqConfig
-from bionemo.esm2.model.finetune.finetune_token_classifier import ESM2FineTuneTokenConfig
+from bionemo.esm2.model.finetune.sequence_model import ESM2FineTuneSeqConfig
+from bionemo.esm2.model.finetune.token_model import ESM2FineTuneTokenConfig
 from bionemo.llm.model.biobert.lightning import biobert_lightning_module
 from bionemo.llm.model.biobert.model import BioBertConfig
 from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size
@@ -80,9 +80,18 @@ def train_model(
     experiment_name: str,
     resume_if_exists: bool,
     precision: PrecisionTypes,
+    task_type: str = "regression",
     encoder_frozen: bool = False,
     scale_lr_layer: Optional[str] = None,
     lr_multiplier: float = 1.0,
+    # single value classification / regression mlp
+    mlp_ft_dropout: float = 0.25,
+    mlp_hidden_size: int = 256,
+    mlp_target_size: int = 1,
+    # token-level classification cnn
+    cnn_dropout: float = 0.25,
+    cnn_hidden_size: int = 32,
+    cnn_num_classes: int = 3,
     wandb_entity: Optional[str] = None,
     wandb_project: Optional[str] = None,
     wandb_offline: bool = False,
@@ -132,9 +141,16 @@ def train_model(
             result_dir that stores the logs and checkpoints.
         resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
         precision (PrecisionTypes): Precision type for training (e.g., float16, float32)
+        task_type (str): Fine-tuning task type. Default is regression.
         encoder_frozen (bool): Freeze the encoder parameters. Default is False.
         scale_lr_layer (Optional[str]): layer names for which the lr is scaled by lr_multiplier
         lr_multiplier (float): lr multiplier for parameters in scale_lr_layer
+        mlp_ft_dropout (float): dropout for single value classification / regression mlp
+        mlp_hidden_size (int): dimension of hidden layer in mlp task head
+        mlp_target_size (int): output dimension of the mlp task head (number of classes in classification tasks)
+        cnn_dropout (float): dropout for token-level classification cnn
+        cnn_hidden_size (int): hidden dimension of cnn head
+        cnn_num_classes (int): number of classes in token-level classification
         wandb_entity (Optional[str]): The team posting this run (default: your username or your default team)
         wandb_project (Optional[str]): The name of the project to which this run will belong
         wandb_offline (bool): Run offline (data can be streamed later to wandb servers).
@@ -163,7 +179,6 @@ def train_model(
         average_in_collective (bool): average in collective
         grad_reduce_in_fp32 (bool): gradient reduction in fp32
     """
-    print("XXXXXX starting train_model")
     # Create the result directory if it does not exist.
result_dir.mkdir(parents=True, exist_ok=True) @@ -177,7 +192,6 @@ def train_model( pipeline_model_parallel_size=pipeline_model_parallel_size, ) - print("XXXXXX starting MegatronStrategy") strategy = nl.MegatronStrategy( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, @@ -229,8 +243,6 @@ def train_model( ) ) - print("XXXXXX starting Trainer") - trainer = nl.Trainer( devices=devices, max_steps=num_steps, @@ -249,14 +261,12 @@ def train_model( autocast_enabled=False, ), ) - # (2) patch the lightning trainer - flare.patch(trainer, restore_state=False, load_state_dict_strict=False) tokenizer = get_tokenizer() # Initialize the data module. - train_dataset = dataset_class.from_csv(train_data_path) - valid_dataset = dataset_class.from_csv(valid_data_path) + train_dataset = dataset_class.from_csv(train_data_path, task_type=task_type) + valid_dataset = dataset_class.from_csv(valid_data_path, task_type=task_type) data_module = ESM2FineTuneDataModule( train_dataset=train_dataset, @@ -270,6 +280,7 @@ def train_model( ) # Configure the model config = config_class( + task_type=task_type, encoder_frozen=encoder_frozen, params_dtype=get_autocast_dtype(precision), pipeline_dtype=get_autocast_dtype(precision), @@ -277,8 +288,21 @@ def train_model( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, initial_ckpt_path=str(restore_from_checkpoint_path), - # initial_ckpt_skip_keys_with_these_prefixes=[], # load everything from the checkpoint. - ) + initial_ckpt_skip_keys_with_these_prefixes=[f"{task_type}_head"], + ) + # Mapping of task-dependent config attributes to their new values + task_dependent_attr = { + "mlp_ft_dropout": mlp_ft_dropout, + "mlp_hidden_size": mlp_hidden_size, + "mlp_target_size": mlp_target_size, + "cnn_dropout": cnn_dropout, + "cnn_hidden_size": cnn_hidden_size, + "cnn_num_classes": cnn_num_classes, + } + # Update attributes only if they exist in the config + for attr, value in task_dependent_attr.items(): + if hasattr(config, attr): + setattr(config, attr, value) optimizer = MegatronOptimizerModule( config=OptimizerConfig( @@ -317,28 +341,24 @@ def train_model( ckpt_callback=checkpoint_callback, ) - # (3) receives FLModel from NVFlare - # Note that we don't need to pass this input_model to trainer - # because after flare.patch the trainer.fit/validate will get the - # global model internally - input_model = flare.receive() - print(f"\n[Current Round={input_model.current_round}, Site = {flare.get_site_name()}]\n") - - llm.train( - model=module, - data=data_module, - trainer=trainer, - log=None, #nemo_logger, - resume=None #resume.AutoResume( - #resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training. - #resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. - #), - ) + #llm.train( + # model=module, + # data=data_module, + # trainer=trainer, + # log=nemo_logger, + # resume=resume.AutoResume( + # resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training. + # resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. 
+    #        ),
+    #)
-    flare.shutdown()
+    # (2) patch the lightning trainer
+    flare.patch(trainer, restore_state=False, load_state_dict_strict=False)
+    while flare.is_running():
+        trainer.fit(module, data_module)
+
     ckpt_path = Path(checkpoint_callback.last_model_path.replace(".ckpt", ""))
-
     return ckpt_path, metric_tracker, trainer
@@ -347,6 +367,11 @@ def finetune_esm2_entrypoint():
     # 1. get arguments
     parser = get_parser()
     args = parser.parse_args()
+
+    # to avoid padding for single value labels:
+    if args.min_seq_length is not None and args.dataset_class is InMemorySingleValueDataset:
+        parser.error("Argument --min-seq-length cannot be set when using InMemorySingleValueDataset.")
+
     # 2. Call pretrain with args
     train_model(
         train_data_path=args.train_data_path,
@@ -375,9 +400,18 @@ def finetune_esm2_entrypoint():
         tensor_model_parallel_size=args.tensor_model_parallel_size,
         accumulate_grad_batches=args.accumulate_grad_batches,
         precision=args.precision,
+        task_type=args.task_type,
         encoder_frozen=args.encoder_frozen,
         scale_lr_layer=args.scale_lr_layer,
         lr_multiplier=args.lr_multiplier,
+        # single value classification / regression mlp
+        mlp_ft_dropout=args.mlp_ft_dropout,
+        mlp_hidden_size=args.mlp_hidden_size,
+        mlp_target_size=args.mlp_target_size,
+        # token-level classification cnn
+        cnn_dropout=args.cnn_dropout,
+        cnn_hidden_size=args.cnn_hidden_size,
+        cnn_num_classes=args.cnn_num_classes,
         experiment_name=args.experiment_name,
         resume_if_exists=args.resume_if_exists,
         restore_from_checkpoint_path=args.restore_from_checkpoint_path,
@@ -423,6 +457,14 @@ def get_parser():
         default="bf16-mixed",
         help="Precision type to use for training.",
     )
+    parser.add_argument(
+        "--task-type",
+        type=str,
+        choices=["regression", "classification"],
+        required=True,
+        default="regression",
+        help="Fine-tuning task type.",
+    )
     parser.add_argument(
         "--encoder-frozen",
         action="store_true",
@@ -450,6 +492,48 @@ def get_parser():
         default=1.0,
         help="Learning rate multiplier for layers with scale-lr-layer in their name",
     )
+    parser.add_argument(
+        "--mlp-ft-dropout",
+        type=float,
+        required=False,
+        default=0.25,
+        help="Dropout for single value classification / regression mlp. Default is 0.25",
+    )
+    parser.add_argument(
+        "--mlp-hidden-size",
+        type=int,
+        required=False,
+        default=256,
+        help="Dimension of hidden layer in mlp task head. Default is 256",
+    )
+    parser.add_argument(
+        "--mlp-target-size",
+        type=int,
+        required=False,
+        default=1,
+        help="Output dimension of the mlp task head. Set to 1 for regression and number of classes for classification tasks. Default is 1",
+    )
+    parser.add_argument(
+        "--cnn-dropout",
+        type=float,
+        required=False,
+        default=0.25,
+        help="Dropout for token-level classification cnn. Default is 0.25",
+    )
+    parser.add_argument(
+        "--cnn-hidden-size",
+        type=int,
+        required=False,
+        default=32,
+        help="Hidden dimension of cnn head. Default is 32",
+    )
+    parser.add_argument(
+        "--cnn-num-classes",
+        type=int,
+        required=False,
+        default=3,
+        help="Number of classes for token-level classification cnn. Default is 3",
+    )
     parser.add_argument(
         "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
     )
@@ -523,7 +607,7 @@ def get_parser():
         "--min-seq-length",
         type=float_or_int_or_none,
         required=False,
-        default=1024,
+        default=None,
         help="Minimum sequence length. Sampled will be padded if less than this value. 
Set 'None' to unset minimum.", ) parser.add_argument( @@ -689,3 +773,4 @@ def dataset_class_type(desc: str) -> Type[InMemoryProteinDataset]: if __name__ == "__main__": finetune_esm2_entrypoint() + diff --git a/examples/advanced/bionemo/downstream/sabdab/prepare_sabdab_data.py b/examples/advanced/bionemo/downstream/sabdab/prepare_sabdab_data.py index 5faa66633c..f6793cd00f 100644 --- a/examples/advanced/bionemo/downstream/sabdab/prepare_sabdab_data.py +++ b/examples/advanced/bionemo/downstream/sabdab/prepare_sabdab_data.py @@ -81,6 +81,7 @@ def main(): for s in ["train", "valid", "test"]: split[s] = split[s].rename(columns={"Antibody": "sequences"}) split[s] = split[s].rename(columns={"Y": "labels"}) + split[s]["labels"] = split[s]["labels"].map({0: "neg", 1: "pos"}) train_df = pd.concat([split["train"], split["valid"]]) test_df = split["test"] @@ -117,7 +118,7 @@ def main(): if do_clean_chains: train_df = clean_chains(train_df) test_df = clean_chains(test_df) - + _split_dir = os.path.join(split_dir, "train") if not os.path.isdir(_split_dir): os.makedirs(_split_dir) @@ -134,8 +135,8 @@ def main(): print(f"Saved {len(train_df)} training and {len(test_df)} testing proteins.") for _set, _df in zip(["TRAIN", "TEST"], [train_df, test_df]): - n_pos = np.sum(_df["labels"] == 0) - n_neg = np.sum(_df["labels"] == 1) + n_pos = np.sum(_df["labels"] == "pos") + n_neg = np.sum(_df["labels"] == "neg") n = len(_df) print(f" {_set} Pos/Neg ratio: neg={n_neg}, pos={n_pos}: {n_pos / n_neg:0.3f}") print(f" {_set} Trivial accuracy: {n_pos / n:0.3f}") diff --git a/examples/advanced/bionemo/downstream/sabdab/run_sim_sabdab.py b/examples/advanced/bionemo/downstream/sabdab/run_sim_sabdab.py index 645ad735dd..5b19cc0737 100644 --- a/examples/advanced/bionemo/downstream/sabdab/run_sim_sabdab.py +++ b/examples/advanced/bionemo/downstream/sabdab/run_sim_sabdab.py @@ -48,7 +48,7 @@ def __init__(self): super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) def process_dxo(self, dxo, shareable, fl_ctx): - #self.logger.info(f"#######RECEIVING DXO: {dxo}") + self.logger.info(f"#######RECEIVING DXO: {dxo}") print(f"#######RECEIVING DXO: {dxo}") return dxo @@ -61,7 +61,7 @@ def __init__(self): super().__init__(supported_data_kinds=data_kinds, data_kinds_to_filter=data_kinds) def process_dxo(self, dxo, shareable, fl_ctx): - #self.logger.info(f"#######SENDING DXO: {dxo}") + self.logger.info(f"#######SENDING DXO: {dxo}") print(f"#######SENDING DXO: {dxo}") return dxo @@ -89,10 +89,10 @@ def main(n_clients, num_rounds, train_script): for i in range(n_clients): client_name = f"site-{i+1}" runner = BaseScriptRunner(script=train_script, + script_args=f"--restore-from-checkpoint-path {checkpoint_path} --train-data-path {train_data_path} --valid-data-path {val_data_path} --config-class ESM2FineTuneSeqConfig --dataset-class InMemorySingleValueDataset --task-type classification --mlp-ft-dropout 0.25 --mlp-hidden-size 256 --mlp-target-size 3 --experiment-name {job.name} --num-steps 10 --num-gpus 1 --val-check-interval 10 --log-every-n-steps 10 --encoder-frozen --lr 5e-3 --lr-multiplier 1e2 --scale-lr-layer classification_head --result-dir . 
--micro-batch-size 2 --precision bf16-mixed", launch_external_process=True, framework="pytorch", - params_exchange_format="pytorch", - launcher=SubprocessLauncher(script=f"python custom/{train_script} --restore-from-checkpoint-path {checkpoint_path} --train-data-path {train_data_path} --valid-data-path {val_data_path} --config-class ESM2FineTuneSeqConfig --dataset-class InMemorySingleValueDataset --experiment-name {job.name} --num-steps 10 --num-gpus 1 --val-check-interval 10 --log-every-n-steps 10 --lr 5e-3 --lr-multiplier 1e2 --scale-lr-layer regression_head --result-dir . --micro-batch-size 2 --num-gpus 1 --precision bf16-mixed", launch_once=False)) + params_exchange_format="pytorch") job.to(runner, client_name) job.to(ReceiveFilter(), client_name, tasks=["*"], filter_type=FilterType.TASK_DATA) job.to(SendFilter(), client_name, tasks=["*"], filter_type=FilterType.TASK_RESULT) diff --git a/examples/advanced/bionemo/start_bionemo.sh b/examples/advanced/bionemo/start_bionemo.sh index bc0808483a..b6c5a77f9b 100755 --- a/examples/advanced/bionemo/start_bionemo.sh +++ b/examples/advanced/bionemo/start_bionemo.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash -#DOCKER_IMAGE="nvcr.io/nvidia/clara/bionemo-framework:2.2" +#DOCKER_IMAGE="nvcr.io/nvidia/clara/bionemo-framework:2.3@sha256:715388e62a55ee55f3f0796174576299b121ca9f95d3b83d2b397f7501b21061" DOCKER_IMAGE="nvcr.io/nvidia/clara/bionemo-framework:nightly" GPU="all"