Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fraction best ligands metric #122

Open
wants to merge 1 commit into
base: classification_metric
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 120 additions & 5 deletions cinnabar/classification_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from scipy import stats
import numpy as np
import ast
import math


def _experiment_prediction_binning(experiment_dG: Iterable[float], predict_dG: Iterable[float], n_classes:int =2, best_class_fraction:float=None):
Expand Down Expand Up @@ -111,12 +112,126 @@ def acc_boots_tfunc(data):
return acc, s.standard_error


def FBVL(experiment_dG:Iterable[float], perdict_dG: Iterable[float], max_best_molecules_ratio:float=0.5)->float:
def _create_2d_histogram(y_true, y_pred):
"""
Metric inspired by the talk of Chris Bailey on alchemistry 2024
Create a 2D histogram from two arrays of data.

Parameters
----------
y_true : array-like
The true values.
y_pred : array-like
The predicted values.

Returns
-------
histogram : ndarray
The 2D histogram of the input data.
bins_true : ndarray
The bin edges along the y_true axis.
bins_pred : ndarray
The bin edges along the y_pred axis.

Raises
------
ValueError
If `y_true` and `y_pred` have different lengths.
TypeError
If `y_true` or `y_pred` cannot be converted to numpy arrays.
"""
raise NotImplementedError()

fbvl_score = 0
try:
y_true = np.asarray(y_true)
y_pred = np.asarray(y_pred)
except Exception as e:
raise TypeError("Input data cannot be converted to numpy arrays.") from e

if y_true.shape != y_pred.shape:
raise ValueError("y_true and y_pred must have the same length.")

y_true_sorted = np.sort(y_true)
y_pred_sorted = np.sort(y_pred)

bins_true = np.concatenate(([y_true.min()], (y_true_sorted[:-1] + y_true_sorted[1:]) / 2, [y_true.max()]))
bins_pred = np.concatenate(([y_pred.min()], (y_pred_sorted[:-1] + y_pred_sorted[1:]) / 2, [y_pred.max()]))

histogram, bins_true, bins_pred = np.histogram2d(y_true, y_pred, bins=[bins_true, bins_pred])

return histogram, bins_true, bins_pred


def _compute_overlap_coefficient(histogram, ranking):
"""
Compute the overlap coefficient from a 2D histogram.

The overlap coefficient is calculated based on the counts in the histogram
for the top N ranked ligands (most active).

Parameters
----------
histogram : ndarray
A 2D histogram array where the counts are stored.
ranking : int
The number of rankings to consider when computing overlap.

Returns
-------
float
The overlap coefficient.

Raises
------
ValueError
If `top_n_ligands` is greater than the number of ligands in the histogram.
"""
if ranking < 1:
raise ValueError("Ranking must be greater than 0.")

if histogram.shape[0] < ranking:
raise ValueError("Ranking must be less than the number of ligands.")

overlap = np.sum(histogram[:ranking, :ranking])

return overlap / ranking


def compute_fraction_best_ligands(y_true, y_pred, fraction=0.5):
    """
    Compute the fraction of the best ligands metric introduced by Chris Bayly.

    With ``n`` ligands and ``k = floor(n * fraction)``, the metric is the mean
    of the overlap coefficients for rankings ``1 .. k``, each computed on the
    2D histogram of `y_true` against `y_pred`.

    Parameters
    ----------
    y_true : array-like
        The true values.
    y_pred : array-like
        The predicted values.
    fraction : float, optional
        The fraction of ligands to consider as the best (default is 0.5).

    Returns
    -------
    float
        The computed fraction of the best ligands.

    Raises
    ------
    ValueError
        If `fraction` is not between 0 and 1, or if it is so small that
        ``floor(n * fraction)`` selects no ligand at all.
    """
    if not (0 <= fraction <= 1):
        raise ValueError("Fraction must be between 0 and 1.")

    histogram = _create_2d_histogram(y_true, y_pred)[0]
    num_ligands = histogram.shape[0]
    num_best_ligands = math.floor(num_ligands * fraction)

    if num_best_ligands < 1:
        # Averaging over zero rankings is undefined; fail with a clear message
        # rather than a ZeroDivisionError from the mean below.
        raise ValueError(
            f"Fraction {fraction} selects no ligands out of {num_ligands}; "
            "increase the fraction or provide more ligands."
        )

    # Only the first `num_best_ligands` rankings contribute to the average.
    best_coefficients = [
        _compute_overlap_coefficient(histogram, ranking)
        for ranking in range(1, num_best_ligands + 1)
    ]

    return sum(best_coefficients) / num_best_ligands