From c59aa53815c02f0028c812732a174b8422df6fa6 Mon Sep 17 00:00:00 2001 From: Gabriele Corso Date: Tue, 10 Dec 2024 18:55:22 -0800 Subject: [PATCH] Exposed step_scale parameter for diffusion temperature (#85) --- docs/prediction.md | 37 +++++++++++++++++++------------------ src/boltz/main.py | 12 +++++++++++- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/docs/prediction.md b/docs/prediction.md index c9c2760..a167699 100644 --- a/docs/prediction.md +++ b/docs/prediction.md @@ -117,24 +117,25 @@ As an example to predict a structure using 10 recycling steps and 25 samples (th boltz predict input_path --recycling_steps 10 --diffusion_samples 25 -| **Option** | **Type** | **Default** | **Description** | -|-------------------------------|-----------------|--------------------|----------------------------------------------------------------------------------------------------| -| `--out_dir PATH` | `PATH` | `./` | The path where to save the predictions. | -| `--cache PATH` | `PATH` | `~/.boltz` | The directory where to download the data and model. | -| `--checkpoint PATH` | `PATH` | None | An optional checkpoint. Uses the provided Boltz-1 model by default. | -| `--devices INTEGER` | `INTEGER` | `1` | The number of devices to use for prediction. | -| `--accelerator` | `[gpu,cpu,tpu]` | `gpu` | The accelerator to use for prediction. | -| `--recycling_steps INTEGER` | `INTEGER` | `3` | The number of recycling steps to use for prediction. | -| `--sampling_steps INTEGER` | `INTEGER` | `200` | The number of sampling steps to use for prediction. | -| `--diffusion_samples INTEGER` | `INTEGER` | `1` | The number of diffusion samples to use for prediction. | -| `--output_format` | `[pdb,mmcif]` | `mmcif` | The output format to use for the predictions. | -| `--num_workers INTEGER` | `INTEGER` | `2` | The number of dataloader workers to use for prediction. | -| `--override` | `FLAG` | `False` | Whether to override existing predictions if found. | -| `--use_msa_server` | `FLAG` | `False` | Whether to use the msa server to generate msa's. | -| `--msa_server_url` | str | `https://api.colabfold.com` | MSA server url. Used only if --use_msa_server is set. | -| `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' | -| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. | -| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. | +| **Option** | **Type** | **Default** | **Description** | +|-------------------------|-----------------|-----------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `--out_dir` | `PATH` | `./` | The path where to save the predictions. | +| `--cache` | `PATH` | `~/.boltz` | The directory where to download the data and model. | +| `--checkpoint` | `PATH` | None | An optional checkpoint. Uses the provided Boltz-1 model by default. | +| `--devices` | `INTEGER` | `1` | The number of devices to use for prediction. | +| `--accelerator` | `[gpu,cpu,tpu]` | `gpu` | The accelerator to use for prediction. | +| `--recycling_steps` | `INTEGER` | `3` | The number of recycling steps to use for prediction. | +| `--sampling_steps` | `INTEGER` | `200` | The number of sampling steps to use for prediction. | +| `--diffusion_samples` | `INTEGER` | `1` | The number of diffusion samples to use for prediction. | +| `--step_scale` | `FLOAT` | `1.638` | The step size is related to the temperature at which the diffusion process samples the distribution. The lower the higher the diversity among samples (recommended between 1 and 2). | +| `--output_format` | `[pdb,mmcif]` | `mmcif` | The output format to use for the predictions. | +| `--num_workers` | `INTEGER` | `2` | The number of dataloader workers to use for prediction. | +| `--override` | `FLAG` | `False` | Whether to override existing predictions if found. | +| `--use_msa_server` | `FLAG` | `False` | Whether to use the msa server to generate msa's. | +| `--msa_server_url` | str | `https://api.colabfold.com` | MSA server url. Used only if --use_msa_server is set. | +| `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' | +| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. | +| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. | ## Output diff --git a/src/boltz/main.py b/src/boltz/main.py index d6685db..71428e3 100644 --- a/src/boltz/main.py +++ b/src/boltz/main.py @@ -437,6 +437,13 @@ def cli() -> None: help="The number of diffusion samples to use for prediction. Default is 1.", default=1, ) +@click.option( + "--step_scale", + type=float, + help="The step size is related to the temperature at which the diffusion process samples the distribution." + "The lower the higher the diversity among samples (recommended between 1 and 2). Default is 1.638.", + default=1.638, +) @click.option( "--write_full_pae", type=bool, @@ -499,6 +506,7 @@ def predict( recycling_steps: int = 3, sampling_steps: int = 200, diffusion_samples: int = 1, + step_scale: float = 1.638, write_full_pae: bool = False, write_full_pde: bool = False, output_format: Literal["pdb", "mmcif"] = "mmcif", @@ -600,12 +608,14 @@ def predict( "write_full_pae": write_full_pae, "write_full_pde": write_full_pde, } + diffusion_params = BoltzDiffusionParams() + diffusion_params.step_scale = step_scale model_module: Boltz1 = Boltz1.load_from_checkpoint( checkpoint, strict=True, predict_args=predict_args, map_location="cpu", - diffusion_process_args=asdict(BoltzDiffusionParams()), + diffusion_process_args=asdict(diffusion_params), ema=False, ) model_module.eval()