Skip to content

Commit

Permalink
Exposed step_scale parameter for diffusion temperature (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
gcorso committed Dec 11, 2024
1 parent d379ab7 commit c59aa53
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 19 deletions.
37 changes: 19 additions & 18 deletions docs/prediction.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,24 +117,25 @@ As an example to predict a structure using 10 recycling steps and 25 samples (th
boltz predict input_path --recycling_steps 10 --diffusion_samples 25


| **Option** | **Type** | **Default** | **Description** |
|-------------------------------|-----------------|--------------------|----------------------------------------------------------------------------------------------------|
| `--out_dir PATH` | `PATH` | `./` | The path where to save the predictions. |
| `--cache PATH` | `PATH` | `~/.boltz` | The directory where to download the data and model. |
| `--checkpoint PATH` | `PATH` | None | An optional checkpoint. Uses the provided Boltz-1 model by default. |
| `--devices INTEGER` | `INTEGER` | `1` | The number of devices to use for prediction. |
| `--accelerator` | `[gpu,cpu,tpu]` | `gpu` | The accelerator to use for prediction. |
| `--recycling_steps INTEGER` | `INTEGER` | `3` | The number of recycling steps to use for prediction. |
| `--sampling_steps INTEGER` | `INTEGER` | `200` | The number of sampling steps to use for prediction. |
| `--diffusion_samples INTEGER` | `INTEGER` | `1` | The number of diffusion samples to use for prediction. |
| `--output_format` | `[pdb,mmcif]` | `mmcif` | The output format to use for the predictions. |
| `--num_workers INTEGER` | `INTEGER` | `2` | The number of dataloader workers to use for prediction. |
| `--override` | `FLAG` | `False` | Whether to override existing predictions if found. |
| `--use_msa_server` | `FLAG` | `False` | Whether to use the msa server to generate msa's. |
| `--msa_server_url` | str | `https://api.colabfold.com` | MSA server url. Used only if --use_msa_server is set. |
| `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' |
| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. |
| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. |
| **Option** | **Type** | **Default** | **Description** |
|-------------------------|-----------------|-----------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `--out_dir` | `PATH` | `./` | The path where to save the predictions. |
| `--cache` | `PATH` | `~/.boltz` | The directory where to download the data and model. |
| `--checkpoint` | `PATH` | None | An optional checkpoint. Uses the provided Boltz-1 model by default. |
| `--devices` | `INTEGER` | `1` | The number of devices to use for prediction. |
| `--accelerator` | `[gpu,cpu,tpu]` | `gpu` | The accelerator to use for prediction. |
| `--recycling_steps` | `INTEGER` | `3` | The number of recycling steps to use for prediction. |
| `--sampling_steps` | `INTEGER` | `200` | The number of sampling steps to use for prediction. |
| `--diffusion_samples` | `INTEGER` | `1` | The number of diffusion samples to use for prediction. |
| `--step_scale` | `FLOAT` | `1.638` | The step size is related to the temperature at which the diffusion process samples the distribution. The lower the higher the diversity among samples (recommended between 1 and 2). |
| `--output_format` | `[pdb,mmcif]` | `mmcif` | The output format to use for the predictions. |
| `--num_workers` | `INTEGER` | `2` | The number of dataloader workers to use for prediction. |
| `--override` | `FLAG` | `False` | Whether to override existing predictions if found. |
| `--use_msa_server` | `FLAG` | `False` | Whether to use the msa server to generate msa's. |
| `--msa_server_url` | str | `https://api.colabfold.com` | MSA server url. Used only if --use_msa_server is set. |
| `--msa_pairing_strategy` | str | `greedy` | Pairing strategy to use. Used only if --use_msa_server is set. Options are 'greedy' and 'complete' |
| `--write_full_pae` | `FLAG` | `False` | Whether to save the full PAE matrix as a file. |
| `--write_full_pde` | `FLAG` | `False` | Whether to save the full PDE matrix as a file. |

## Output

Expand Down
12 changes: 11 additions & 1 deletion src/boltz/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ def cli() -> None:
help="The number of diffusion samples to use for prediction. Default is 1.",
default=1,
)
@click.option(
"--step_scale",
type=float,
help="The step size is related to the temperature at which the diffusion process samples the distribution."
"The lower the higher the diversity among samples (recommended between 1 and 2). Default is 1.638.",
default=1.638,
)
@click.option(
"--write_full_pae",
type=bool,
Expand Down Expand Up @@ -499,6 +506,7 @@ def predict(
recycling_steps: int = 3,
sampling_steps: int = 200,
diffusion_samples: int = 1,
step_scale: float = 1.638,
write_full_pae: bool = False,
write_full_pde: bool = False,
output_format: Literal["pdb", "mmcif"] = "mmcif",
Expand Down Expand Up @@ -600,12 +608,14 @@ def predict(
"write_full_pae": write_full_pae,
"write_full_pde": write_full_pde,
}
diffusion_params = BoltzDiffusionParams()
diffusion_params.step_scale = step_scale
model_module: Boltz1 = Boltz1.load_from_checkpoint(
checkpoint,
strict=True,
predict_args=predict_args,
map_location="cpu",
diffusion_process_args=asdict(BoltzDiffusionParams()),
diffusion_process_args=asdict(diffusion_params),
ema=False,
)
model_module.eval()
Expand Down

0 comments on commit c59aa53

Please sign in to comment.