RerankScorer implementation based on KNN #50

Open: wants to merge 3 commits into dev
5 changes: 3 additions & 2 deletions autointent/modules/__init__.py
@@ -11,7 +11,7 @@
)
from .regexp import RegExp
from .retrieval import RetrievalModule, VectorDBModule
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, ScoringModule
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, ScoringModule

T = TypeVar("T", bound=Module)

@@ -25,7 +25,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS

SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = create_modules_dict(
[DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer]
[DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer, RerankScorer]
)

SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = create_modules_dict(
@@ -58,6 +58,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
"Module",
"PredictionModule",
"RegExp",
"RerankScorer",
"RetrievalModule",
"ScoringModule",
"ThresholdPredictor",
12 changes: 10 additions & 2 deletions autointent/modules/scoring/__init__.py
@@ -1,8 +1,16 @@
from .base import ScoringModule
from .description import DescriptionScorer
from .dnnc import DNNCScorer
from .knn import KNNScorer
from .knn import KNNScorer, RerankScorer
from .linear import LinearScorer
from .mlknn import MLKnnScorer

__all__ = ["DNNCScorer", "DescriptionScorer", "KNNScorer", "LinearScorer", "MLKnnScorer", "ScoringModule"]
__all__ = [
"DNNCScorer",
"DescriptionScorer",
"KNNScorer",
"LinearScorer",
"MLKnnScorer",
"RerankScorer",
"ScoringModule",
]
3 changes: 2 additions & 1 deletion autointent/modules/scoring/knn/__init__.py
@@ -1,3 +1,4 @@
from .knn import KNNScorer
from .rerank_scorer import RerankScorer

__all__ = ["KNNScorer"]
__all__ = ["KNNScorer", "RerankScorer"]
43 changes: 29 additions & 14 deletions autointent/modules/scoring/knn/knn.py
@@ -52,6 +52,7 @@ class KNNScorer(ScoringModule):
_vector_index: VectorIndex
name = "knn"
prebuilt_index: bool = False
max_length: int | None

def __init__(
self,
@@ -194,13 +195,7 @@ def dump(self, path: str) -> None:

:param path: Path to the directory where assets will be dumped.
"""
self.metadata = KNNScorerDumpMetadata(
db_dir=self.db_dir,
n_classes=self.n_classes,
multilabel=self.multilabel,
batch_size=self.batch_size,
max_length=self.max_length,
)
self.metadata = self._store_state_to_metadata()

dump_dir = Path(path)

@@ -209,6 +204,15 @@ def dump(self, path: str) -> None:

self._vector_index.dump(dump_dir)

def _store_state_to_metadata(self) -> KNNScorerDumpMetadata:
return KNNScorerDumpMetadata(
db_dir=self.db_dir,
n_classes=self.n_classes,
multilabel=self.multilabel,
batch_size=self.batch_size,
max_length=self.max_length,
)

def load(self, path: str) -> None:
"""
Load the KNNScorer's metadata and vector index from disk.
@@ -220,24 +224,35 @@ def load(self, path: str) -> None:
with (dump_dir / self.metadata_dict_name).open() as file:
self.metadata: KNNScorerDumpMetadata = json.load(file)

self.n_classes = self.metadata["n_classes"]
self.multilabel = self.metadata["multilabel"]
self._restore_state_from_metadata(self.metadata)

def _restore_state_from_metadata(self, metadata: KNNScorerDumpMetadata) -> None:
self.n_classes = metadata["n_classes"]
self.multilabel = metadata["multilabel"]

vector_index_client = VectorIndexClient(
device=self.device,
db_dir=self.metadata["db_dir"],
embedder_batch_size=self.metadata["batch_size"],
embedder_max_length=self.metadata["max_length"],
db_dir=metadata["db_dir"],
embedder_batch_size=metadata["batch_size"],
embedder_max_length=metadata["max_length"],
)
self._vector_index = vector_index_client.get_index(self.embedder_name)

def _get_neighbours(
self, utterances: list[str]
) -> tuple[list[list[LabelType]], list[list[float]], list[list[str]]]:
return self._vector_index.query(utterances, self.k)

def _count_scores(self, labels: npt.NDArray[Any], distances: npt.NDArray[Any]) -> npt.NDArray[Any]:
return apply_weights(labels, distances, self.weights, self.n_classes, self.multilabel)

def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
"""
Predict class probabilities and retrieve neighbors for the given utterances.

:param utterances: List of query utterances.
:return: Tuple containing class probabilities and neighbor utterances.
"""
labels, distances, neighbors = self._vector_index.query(utterances, self.k)
scores = apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel)
labels, distances, neighbors = self._get_neighbours(utterances)
scores = self._count_scores(np.array(labels), np.array(distances))
return scores, neighbors
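
The refactor above splits _predict into two overridable hooks, _get_neighbours and _count_scores, so a subclass can change how neighbours are retrieved or filtered while reusing the distance weighting. A minimal sketch of a hypothetical subclass built on these hooks (IdentityRerankScorer is illustrative only; RerankScorer below is the real consumer):

from typing import Any

import numpy as np
import numpy.typing as npt

class IdentityRerankScorer(KNNScorer):
    """Hypothetical subclass: reuses both hooks without changing behaviour."""

    def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
        labels, distances, neighbors = self._get_neighbours(utterances)
        # A real subclass would filter or reorder the neighbours here before scoring.
        scores = self._count_scores(np.array(labels), np.array(distances))
        return scores, neighbors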
137 changes: 137 additions & 0 deletions autointent/modules/scoring/knn/rerank_scorer.py
@@ -0,0 +1,137 @@
import json
from pathlib import Path
from typing import Any

import numpy as np
import numpy.typing as npt
from sentence_transformers import CrossEncoder
from torch.nn import Sigmoid
from typing_extensions import Self

from autointent.context import Context
from autointent.custom_types import WEIGHT_TYPES, LabelType

from .knn import KNNScorer, KNNScorerDumpMetadata


class RerankScorerDumpMetadata(KNNScorerDumpMetadata):
scorer_name: str
m: int | None
rank_threshold_cutoff: int | None


class RerankScorer(KNNScorer):
name = "rerank_scorer"
_scorer: CrossEncoder

def __init__(
self,
embedder_name: str,
k: int,
weights: WEIGHT_TYPES,
scorer_name: str,
Collaborator: I would name this argument cross_encoder_name.

m: int | None = None,
rank_threshold_cutoff: int | None = None,
db_dir: str | None = None,
device: str = "cpu",
batch_size: int = 32,
max_length: int | None = None,
) -> None:
super().__init__(
embedder_name=embedder_name,
k=k,
weights=weights,
db_dir=db_dir,
device=device,
batch_size=batch_size,
max_length=max_length,
)

self.scorer_name = scorer_name
self.m = k if m is None else m
self.rank_threshold_cutoff = rank_threshold_cutoff
self._scorer = CrossEncoder(self.scorer_name, device=self.device, max_length=self.max_length) # type: ignore[arg-type]
Collaborator: This should be moved into fit(); it is too heavy an operation. Our philosophy here is the same as sklearn's: in the constructor we only store and validate the constructor arguments.

Collaborator: Because of this, loading the cross-encoder will also need to be implemented in the load() method.
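
A minimal sketch of the sklearn-style pattern the reviewers describe; the fit() and load() signatures here are assumptions for illustration, not the final API:

from sentence_transformers import CrossEncoder

class LazyRerankScorer:
    def __init__(self, scorer_name: str, device: str = "cpu", max_length: int | None = None) -> None:
        # The constructor only stores and validates arguments; no model is downloaded here.
        self.scorer_name = scorer_name
        self.device = device
        self.max_length = max_length
        self._scorer: CrossEncoder | None = None

    def fit(self, utterances: list[str], labels: list[int]) -> None:
        # The heavy CrossEncoder construction is deferred to fit().
        self._scorer = CrossEncoder(self.scorer_name, device=self.device, max_length=self.max_length)

    def load(self, path: str) -> None:
        # Restoring from disk must re-create the cross-encoder as well.
        self._scorer = CrossEncoder(self.scorer_name, device=self.device, max_length=self.max_length)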


@classmethod
def from_context(
cls,
context: Context,
k: int,
weights: WEIGHT_TYPES,
scorer_name: str,
embedder_name: str | None = None,
m: int | None = None,
rank_threshold_cutoff: int | None = None,
) -> Self:
if embedder_name is None:
embedder_name = context.optimization_info.get_best_embedder()
prebuilt_index = True
else:
prebuilt_index = context.vector_index_client.exists(embedder_name)

instance = cls(
embedder_name=embedder_name,
k=k,
weights=weights,
scorer_name=scorer_name,
m=m,
rank_threshold_cutoff=rank_threshold_cutoff,
db_dir=str(context.get_db_dir()),
device=context.get_device(),
batch_size=context.get_batch_size(),
max_length=context.get_max_length(),
)
# TODO: needs re-thinking....
instance.prebuilt_index = prebuilt_index
return instance

def _store_state_to_metadata(self) -> RerankScorerDumpMetadata:
return RerankScorerDumpMetadata(
**super()._store_state_to_metadata(),
m=self.m,
scorer_name=self.scorer_name,
rank_threshold_cutoff=self.rank_threshold_cutoff,
)
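
For reference, the merged metadata this method produces would look roughly like the following dict; all values here are hypothetical examples, not project defaults:

metadata_example = {
    # Fields inherited from KNNScorerDumpMetadata
    "db_dir": "/path/to/vector_db",  # hypothetical path
    "n_classes": 3,
    "multilabel": False,
    "batch_size": 32,
    "max_length": None,
    # Fields added by RerankScorerDumpMetadata
    "scorer_name": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    "m": 2,
    "rank_threshold_cutoff": None,
}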

def load(self, path: str) -> None:
dump_dir = Path(path)

with (dump_dir / self.metadata_dict_name).open() as file:
self.metadata: RerankScorerDumpMetadata = json.load(file)

self._restore_state_from_metadata(self.metadata)

def _restore_state_from_metadata(self, metadata: RerankScorerDumpMetadata) -> None:
super()._restore_state_from_metadata(metadata)

self.m = metadata["m"] if metadata["m"] else self.k
self.scorer_name = metadata["scorer_name"]
self.rank_threshold_cutoff = metadata["rank_threshold_cutoff"]

def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
knn_labels, knn_distances, knn_neighbors = self._get_neighbours(utterances)

labels: list[list[LabelType]] = []
distances: list[list[float]] = []
neighbours: list[list[str]] = []

for query, query_labels, query_distances, query_docs in zip(
utterances, knn_labels, knn_distances, knn_neighbors, strict=True
):
cur_ranks = self._scorer.rank(
query, query_docs, top_k=self.m, batch_size=self.batch_size, activation_fct=Sigmoid()
)
# if self.rank_threshold_cutoff:
# # remove neighbours where CrossEncoder is not confident enough
# while len(cur_ranks):
# if cur_ranks[-1]['score'] >= self.rank_threshold_cutoff:
# break
# cur_ranks.pop()

# keep only relevant data for the utterance
for dst, src in zip([labels, distances, neighbours], [query_labels, query_distances, query_docs],
strict=True):
dst.append([src[rank["corpus_id"]] for rank in cur_ranks]) # type: ignore[attr-defined, index]

scores = self._count_scores(np.array(labels), np.array(distances))
return scores, neighbours
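
A note on the commented-out confidence cutoff above: CrossEncoder.rank returns results sorted by descending score, so the filter can be a simple comprehension once it is enabled. A minimal standalone sketch, assuming rank_threshold_cutoff is compared against the sigmoid-activated scores:

from typing import Any

def filter_by_confidence(ranks: list[dict[str, Any]], threshold: float | None) -> list[dict[str, Any]]:
    # Drop neighbours whose cross-encoder score falls below the threshold.
    # Assumes `ranks` is sorted by descending "score", as CrossEncoder.rank returns it.
    if threshold is None:
        return ranks
    return [rank for rank in ranks if rank["score"] >= threshold]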
39 changes: 39 additions & 0 deletions tests/modules/scoring/test_rerank_scorer.py
@@ -0,0 +1,39 @@
import numpy as np

from autointent.context.data_handler import DataHandler
from autointent.modules import RerankScorer
from tests.conftest import setup_environment


def test_base_rerank_scorer(dataset):
db_dir, dump_dir, logs_dir = setup_environment()

data_handler = DataHandler(dataset)

scorer = RerankScorer(
k=3,
weights="distance",
embedder_name="sergeyzh/rubert-tiny-turbo",
m=2,
scorer_name="cross-encoder/ms-marco-MiniLM-L-6-v2",
db_dir=db_dir,
device="cpu",
)

test_data = [
"why is there a hold on my american saving bank account",
"i am nost sure why my account is blocked",
"why is there a hold on my capital one checking account",
"i think my account is blocked but i do not know the reason",
"can you tell me why is my bank account frozen",
]

scorer.fit(data_handler.utterances_train, data_handler.labels_train)
predictions = scorer.predict(test_data)
assert (
predictions == np.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]])
).all()

predictions, metadata = scorer.predict_with_metadata(test_data)
assert len(predictions) == len(test_data)
assert "neighbors" in metadata[0]