Skip to content

Commit

Permalink
Support non-normalized probability matrices in hungarian
Browse files (browse the repository at this point in the history)
  • Loading branch information
luxaritas committed Jan 31, 2025
1 parent 7d0d9e7 commit d06294e
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions src/arnie/pk_predictors.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -112,7 +112,7 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior=
bpp_orig = bpp.copy()

if add_p_unpaired:
p_unpaired = 1 - np.sum(bpp, axis=0)
p_unpaired = 1 - np.clip(np.sum(bpp, axis=0), 0, 1)
for i, punp in enumerate(p_unpaired):
bpp[i, i] = punp

Expand All @@ -137,15 +137,26 @@ def _hungarian(bpp, exp=1, sigmoid_slope_factor=None, prob_to_0_threshold_prior=
bpp = _sigmoid(bpp, slope_factor=sigmoid_slope_factor)

# should think about order of above functions and possibly normalize again here
# (normalize again we shall...)
if add_p_unpaired:
row_sums = bpp.sum(axis=1)
bpp = bpp / row_sums[:, np.newaxis]

# run hungarian algorithm to find base pairs
_, row_pairs = linear_sum_assignment(-bpp)
bp_list = []
conf = {}
for col, row in enumerate(row_pairs):
# if bpp_orig[col, row] != bpp[col, row]:
# print(col, row, bpp_orig[col, row], bpp[col, row])
if bpp_orig[col, row] > theta and col < row:
bp_list.append([col, row])
p = max(conf.get(col,0), conf.get(row,0))
if p != 0:
raise ValueError('conflicting pairs')
else:
conf[col] = 1
conf[row] = 1
bp_list.append([col, row])

structure = convert_bp_list_to_dotbracket(bp_list, bpp.shape[0])
structure = post_process_struct(structure, allowed_buldge_len, min_len_helix)
Expand Down

0 comments on commit d06294e

Please sign in to comment.