
Commit

Merge branch 'main' into pipeline
jwohlwend committed Dec 21, 2024
2 parents d159cb9 + 2d9d45b commit ff681f2
Showing 7 changed files with 134 additions and 19 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -46,7 +46,8 @@ To see all available options: `boltz predict --help` and for more information on

To encourage reproducibility and facilitate comparison with other models, we provide the evaluation scripts and predictions for Boltz-1, Chai-1, and AlphaFold3 on our test benchmark dataset as well as CASP15. These datasets were created to contain biomolecules different from the training data, and to benchmark the performance of these models we run them with the same input MSAs and the same number of recycling and diffusion steps. More details on these evaluations can be found in our [evaluation instructions](docs/evaluation.md).

-![Test set evaluations](docs/test_evals.png)
+![Test set evaluations](docs/plot_test.png)
+![CASP15 set evaluations](docs/plot_casp.png)


## Training
Binary file modified docs/boltz1_pred_figure.png
3 changes: 2 additions & 1 deletion docs/evaluation.md
@@ -2,7 +2,8 @@

To encourage reproducibility and facilitate comparison with other models, we provide the evaluation scripts and predictions for Boltz-1, Chai-1, and AlphaFold3 on our test benchmark dataset as well as CASP15. These datasets were created to contain biomolecules different from the training data, and to benchmark the performance of these models we run them with the same input MSAs and the same number of recycling and diffusion steps.

-![Test set evaluations](../docs/test_evals.png)
+![Test set evaluations](../docs/plot_test.png)
+![CASP15 set evaluations](../docs/plot_casp.png)


## Evaluation files
Binary file added docs/plot_casp.png
Binary file added docs/plot_test.png
Binary file removed docs/test_evals.png
147 changes: 130 additions & 17 deletions scripts/eval/aggregate_evals.py
@@ -3,6 +3,7 @@

import numpy as np
import pandas as pd
+import matplotlib.pyplot as plt
from tqdm import tqdm

METRICS = ["lddt", "bb_lddt", "tm_score", "rmsd"]
@@ -329,28 +330,40 @@ def eval_models(chai_preds, chai_evals, af3_preds, af3_evals, boltz_preds, boltz
if metric_name in chai_results and metric_name in boltz_results:
    if af3_results[metric_name]["len"] == chai_results[metric_name]["len"] and af3_results[metric_name]["len"] == boltz_results[metric_name]["len"]:
        results.append({
-            "tool": "af3",
+            "tool": "AF3 oracle",
            "target": name,
            "metric": metric_name,
-            "oracle": af3_results[metric_name]["oracle"],
-            "average": af3_results[metric_name]["average"],
-            "top1": af3_results[metric_name]["top1"],
+            "value": af3_results[metric_name]["oracle"],
        })
        results.append({
-            "tool": "chai",
+            "tool": "AF3 top-1",
            "target": name,
            "metric": metric_name,
-            "oracle": chai_results[metric_name]["oracle"],
-            "average": chai_results[metric_name]["average"],
-            "top1": chai_results[metric_name]["top1"],
+            "value": af3_results[metric_name]["top1"],
        })
        results.append({
-            "tool": "boltz",
+            "tool": "Chai-1 oracle",
            "target": name,
            "metric": metric_name,
-            "oracle": boltz_results[metric_name]["oracle"],
-            "average": boltz_results[metric_name]["average"],
-            "top1": boltz_results[metric_name]["top1"],
+            "value": chai_results[metric_name]["oracle"],
        })
+        results.append({
+            "tool": "Chai-1 top-1",
+            "target": name,
+            "metric": metric_name,
+            "value": chai_results[metric_name]["top1"],
+        })
+        results.append({
+            "tool": "Boltz-1 oracle",
+            "target": name,
+            "metric": metric_name,
+            "value": boltz_results[metric_name]["oracle"],
+        })
+        results.append({
+            "tool": "Boltz-1 top-1",
+            "target": name,
+            "metric": metric_name,
+            "value": boltz_results[metric_name]["top1"],
+        })
    else:
        print("Different lengths", name, metric_name, af3_results[metric_name]["len"], chai_results[metric_name]["len"], boltz_results[metric_name]["len"])
@@ -359,10 +372,110 @@ def eval_models(chai_preds, chai_evals, af3_preds, af3_evals, boltz_preds, boltz

    # Write the results to a file, ensure we only keep the target & metrics where we have all tools
    df = pd.DataFrame(results)
    df = df.groupby(["target", "metric"]).filter(lambda x: len(x["tool"]) == 3)
    return df


+def bootstrap_ci(series, n_boot=1000, alpha=0.05):
+    """
+    Compute 95% bootstrap confidence intervals for the mean of 'series'.
+    """
+    n = len(series)
+    boot_means = []
+    # Perform bootstrap resampling
+    for _ in range(n_boot):
+        sample = series.sample(n, replace=True)
+        boot_means.append(sample.mean())
+
+    boot_means = np.array(boot_means)
+    mean_val = np.mean(boot_means)
+    lower = np.percentile(boot_means, 100 * alpha / 2)
+    upper = np.percentile(boot_means, 100 * (1 - alpha / 2))
+    return mean_val, lower, upper
+
+
+def plot_data(desired_tools, desired_metrics, df, dataset, filename):
+
+    filtered_df = df[df['tool'].isin(desired_tools) & df['metric'].isin(desired_metrics)]
+
+    # Apply bootstrap to each (tool, metric) group
+    boot_stats = filtered_df.groupby(["tool", "metric"])["value"].apply(bootstrap_ci)
+
+    # boot_stats is a Series of tuples (mean, lower, upper). Convert to DataFrame:
+    boot_stats = boot_stats.apply(pd.Series)
+    boot_stats.columns = ["mean", "lower", "upper"]
+
+    # Unstack to get a DataFrame suitable for plotting
+    plot_data = boot_stats['mean'].unstack('tool')
+    plot_data = plot_data.reindex(desired_metrics)
+
+    lower_data = boot_stats['lower'].unstack('tool')
+    lower_data = lower_data.reindex(desired_metrics)
+
+    upper_data = boot_stats['upper'].unstack('tool')
+    upper_data = upper_data.reindex(desired_metrics)
+
+    # If you need a specific order of tools:
+    tool_order = ["AF3 oracle", "AF3 top-1",
+                  "Chai-1 oracle", "Chai-1 top-1",
+                  "Boltz-1 oracle", "Boltz-1 top-1"]
+    plot_data = plot_data[tool_order]
+    lower_data = lower_data[tool_order]
+    upper_data = upper_data[tool_order]
+
+    # Rename metrics
+    renaming = {
+        "lddt_pli": "Mean LDDT-PLI",
+        "rmsd<2": "L-RMSD < 2A",
+        "lddt": "Mean LDDT",
+        "dockq_>0.23": "DockQ > 0.23",
+    }
+    plot_data = plot_data.rename(index=renaming)
+    lower_data = lower_data.rename(index=renaming)
+    upper_data = upper_data.rename(index=renaming)
+    mean_vals = plot_data.values
+
+    # Colors
+    tool_colors = [
+        "#004D80",  # AF3 oracle
+        "#55C2FF",  # AF3 top-1
+        "#931652",  # Chai-1 oracle
+        "#FC8AD9",  # Chai-1 top-1
+        "#188F52",  # Boltz-1 oracle
+        "#86E935"   # Boltz-1 top-1
+    ]
+
+    fig, ax = plt.subplots(figsize=(10, 5))
+
+    x = np.arange(len(plot_data.index))
+    bar_spacing = 0.015
+    total_width = 0.7
+    # Adjust width to account for the spacing
+    width = (total_width - (len(tool_order) - 1) * bar_spacing) / len(tool_order)
+
+    for i, tool in enumerate(tool_order):
+        # Each subsequent bar moves over by width + bar_spacing
+        offsets = x - (total_width - width) / 2 + i * (width + bar_spacing)
+        # Extract the means and errors for this tool
+        tool_means = plot_data[tool].values
+        tool_yerr_lower = mean_vals[:, i] - lower_data.values[:, i]
+        tool_yerr_upper = upper_data.values[:, i] - mean_vals[:, i]
+        # Construct yerr array specifically for this tool
+        tool_yerr = np.vstack([tool_yerr_lower, tool_yerr_upper])
+
+        ax.bar(offsets, tool_means, width=width, color=tool_colors[i], label=tool, yerr=tool_yerr, capsize=2, error_kw={'elinewidth': 0.75})
+
+    ax.set_xticks(x)
+    ax.set_xticklabels(plot_data.index, rotation=0)
+    ax.set_ylabel("Value")
+    ax.set_title(f"Performances on {dataset} with 95% CI (Bootstrap)")
+
+    plt.tight_layout()
+    ax.legend(loc='lower center', bbox_to_anchor=(0.5, 0.85), ncol=3, frameon=False)
+
+    plt.savefig(filename)
+    plt.show()
+

def main():
    eval_folder = "../../boltz_results_final/"

@@ -379,8 +492,9 @@ def main():
    df = eval_models(chai_preds, chai_evals, af3_preds, af3_evals, boltz_preds, boltz_evals)
    df.to_csv(eval_folder + "results_test.csv", index=False)

-    print("Test results: mean")
-    print(df[["tool", "metric", "oracle", "average", "top1"]].groupby(["tool", "metric"]).mean())
+    desired_tools = ["AF3 oracle", "AF3 top-1", "Chai-1 oracle", "Chai-1 top-1", "Boltz-1 oracle", "Boltz-1 top-1"]
+    desired_metrics = ["lddt", "dockq_>0.23", "lddt_pli", "rmsd<2"]
+    plot_data(desired_tools, desired_metrics, df, "PDB Test", eval_folder + "plot_test.png")

    # Eval CASP
    chai_preds = eval_folder + "outputs/casp15/chai"
@@ -395,8 +509,7 @@ def main():
    df = eval_models(chai_preds, chai_evals, af3_preds, af3_evals, boltz_preds, boltz_evals)
    df.to_csv(eval_folder + "results_casp.csv", index=False)

-    print("CASP15 results: mean")
-    print(df[["tool", "metric", "oracle", "average", "top1"]].groupby(["tool", "metric"]).mean())
+    plot_data(desired_tools, desired_metrics, df, "CASP15", eval_folder + "plot_casp.png")


if __name__ == "__main__":
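
For anyone who wants to exercise the new helpers outside of `main()`, here is a minimal sketch that reloads one of the CSVs written above and rebuilds a plot. It is illustrative only, not part of the commit: the CSV path and output filename are assumptions, and `bootstrap_ci` and `plot_data` from the script above are assumed to already be defined in the session.

```python
import pandas as pd

# Load the long-format results written by main() (path is an assumption).
df = pd.read_csv("results_test.csv")  # columns: tool, target, metric, value

# Bootstrap CI for a single (tool, metric) slice, using the new helper.
boltz_lddt = df[(df["tool"] == "Boltz-1 oracle") & (df["metric"] == "lddt")]["value"]
mean_val, lower, upper = bootstrap_ci(boltz_lddt, n_boot=1000, alpha=0.05)
print(f"Boltz-1 oracle LDDT: {mean_val:.3f} (95% CI {lower:.3f}-{upper:.3f})")

# Rebuild the grouped bar chart for the same tools and metrics used in main().
desired_tools = ["AF3 oracle", "AF3 top-1", "Chai-1 oracle", "Chai-1 top-1",
                 "Boltz-1 oracle", "Boltz-1 top-1"]
desired_metrics = ["lddt", "dockq_>0.23", "lddt_pli", "rmsd<2"]
plot_data(desired_tools, desired_metrics, df, "PDB Test", "plot_test_rebuilt.png")
```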
