Skip to content

Commit

Permalink
dev(narugo): add algo
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed May 4, 2024
1 parent d3036ad commit 4b70ef7
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion sdeval/fidelity/ccip.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def __init__(self, images: ImagesTyping, feats: Optional[np.ndarray] = None, mod
self._features = list(feats)

def score(self, images: ImagesTyping, silent: bool = None,
algo: Literal['same', 'diff'] = 'same',
mode: Literal['mean', 'seq'] = 'mean') -> Union[float, np.ndarray]:
"""
Calculate the similarity score between the reference dataset and a set of input images.
Expand All @@ -70,6 +71,9 @@ def score(self, images: ImagesTyping, silent: bool = None,
:type images: ImagesTyping
:param silent: If True, suppresses progress bars and additional output during calculation.
:type silent: bool
:param algo: Algorithm of the return value. Return float value represent same-or-not ratio
when using ``same``, return mean difference when using ``diff``. Default is ``same``.
:type algo: Literal['same', 'diff']
:param mode: Mode of the return value. Return a float value when ``mean`` is assigned,
return a numpy array when ``seq`` is assigned. Default is ``mean``.
:type mode: Literal['mean', 'seq']
Expand All @@ -89,7 +93,10 @@ def score(self, images: ImagesTyping, silent: bool = None,

diffs = ccip_batch_differences([*self._features, *_features])
matrix = diffs[:len(self._features), len(self._features):]
seq = (matrix < self._threshold).mean(axis=0)
if algo == 'same':
seq = (matrix < self._threshold).mean(axis=0)
else:
seq = matrix.mean(axis=0)
assert seq.shape == (len(_features),)

if mode == 'seq':
Expand Down

0 comments on commit 4b70ef7

Please sign in to comment.