From d06294ebb38524c45219102f0f54ecfa50fe7ed9 Mon Sep 17 00:00:00 2001 From: Jonathan Romano Date: Fri, 31 Jan 2025 18:40:55 -0500 Subject: [PATCH] Support non-normalized probability matrices in hungarian --- src/arnie/pk_predictors.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/arnie/pk_predictors.py b/src/arnie/pk_predictors.py index ed1a9db..3b8f52b 100644 --- a/src/arnie/pk_predictors.py +++ b/src/arnie/pk_predictors.py @@ -112,7 +112,7 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior= bpp_orig = bpp.copy() if add_p_unpaired: - p_unpaired = 1 - np.sum(bpp, axis=0) + p_unpaired = 1 - np.clip(np.sum(bpp, axis=0), 0, 1) for i, punp in enumerate(p_unpaired): bpp[i, i] = punp @@ -137,15 +137,26 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior= bpp = _sigmoid(bpp, slope_factor=sigmoid_slope_factor) # should think about order of above functions and possibly normalize again here + # (normalize again we shall...) + if add_p_unpaired: + row_sums = bpp.sum(axis=1) + bpp = bpp / row_sums[:, np.newaxis] # run hungarian algorithm to find base pairs _, row_pairs = linear_sum_assignment(-bpp) bp_list = [] + conf = {} for col, row in enumerate(row_pairs): # if bpp_orig[col, row] != bpp[col, row]: # print(col, row, bpp_orig[col, row], bpp[col, row]) if bpp_orig[col, row] > theta and col < row: - bp_list.append([col, row]) + p = max(conf.get(col,0), conf.get(row,0)) + if p != 0: + raise ValueError('conflicting pairs') + else: + conf[col] = 1 + conf[row] = 1 + bp_list.append([col, row]) structure = convert_bp_list_to_dotbracket(bp_list, bpp.shape[0]) structure = post_process_struct(structure, allowed_buldge_len, min_len_helix)