diff --git a/autointent/modules/__init__.py b/autointent/modules/__init__.py index 754fb96f9..109171295 100644 --- a/autointent/modules/__init__.py +++ b/autointent/modules/__init__.py @@ -11,7 +11,7 @@ TunablePredictor, ) from .retrieval import RetrievalModule, VectorDBModule -from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, ScoringModule +from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, ScoringModule T = TypeVar("T", bound=Module) @@ -25,7 +25,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]: RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS SCORING_MODULES_MULTICLASS: dict[str, type[ScoringModule]] = create_modules_dict( - [DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer], + [DNNCScorer, KNNScorer, LinearScorer, DescriptionScorer, RerankScorer], ) SCORING_MODULES_MULTILABEL: dict[str, type[ScoringModule]] = create_modules_dict( @@ -58,6 +58,7 @@ def create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]: "Module", "PredictionModule", "RegExp", + "RerankScorer", "RetrievalModule", "ScoringModule", "ThresholdPredictor", diff --git a/autointent/modules/scoring/__init__.py b/autointent/modules/scoring/__init__.py index 6ec2773fc..cbfabb5c8 100644 --- a/autointent/modules/scoring/__init__.py +++ b/autointent/modules/scoring/__init__.py @@ -1,8 +1,16 @@ from ._base import ScoringModule from ._description import DescriptionScorer from ._dnnc import DNNCScorer -from ._knn import KNNScorer +from ._knn import KNNScorer, RerankScorer from ._linear import LinearScorer from ._mlknn import MLKnnScorer -__all__ = ["DNNCScorer", "DescriptionScorer", "KNNScorer", "LinearScorer", "MLKnnScorer", "ScoringModule"] +__all__ = [ + "DNNCScorer", + "DescriptionScorer", + "KNNScorer", + "LinearScorer", + "MLKnnScorer", + "RerankScorer", + "ScoringModule", +] diff --git a/autointent/modules/scoring/_knn/__init__.py 
b/autointent/modules/scoring/_knn/__init__.py index 8db77af08..2da4dc8ee 100644 --- a/autointent/modules/scoring/_knn/__init__.py +++ b/autointent/modules/scoring/_knn/__init__.py @@ -1,3 +1,4 @@ from .knn import KNNScorer +from .rerank_scorer import RerankScorer -__all__ = ["KNNScorer"] +__all__ = ["KNNScorer", "RerankScorer"] diff --git a/autointent/modules/scoring/_knn/knn.py b/autointent/modules/scoring/_knn/knn.py index 8d1ea3fac..3e2babb26 100644 --- a/autointent/modules/scoring/_knn/knn.py +++ b/autointent/modules/scoring/_knn/knn.py @@ -51,6 +51,7 @@ class KNNScorer(ScoringModule): _vector_index: VectorIndex name = "knn" prebuilt_index: bool = False + max_length: int | None def __init__( self, @@ -193,13 +194,7 @@ def dump(self, path: str) -> None: :param path: Path to the directory where assets will be dumped. """ - self.metadata = KNNScorerDumpMetadata( - db_dir=self.db_dir, - n_classes=self.n_classes, - multilabel=self.multilabel, - batch_size=self.batch_size, - max_length=self.max_length, - ) + self.metadata = self._store_state_to_metadata() dump_dir = Path(path) @@ -208,6 +203,15 @@ def dump(self, path: str) -> None: self._vector_index.dump(dump_dir) + def _store_state_to_metadata(self) -> KNNScorerDumpMetadata: + return KNNScorerDumpMetadata( + db_dir=self.db_dir, + n_classes=self.n_classes, + multilabel=self.multilabel, + batch_size=self.batch_size, + max_length=self.max_length, + ) + def load(self, path: str) -> None: """ Load the KNNScorer's metadata and vector index from disk. 
@@ -219,17 +223,28 @@ def load(self, path: str) -> None: with (dump_dir / self.metadata_dict_name).open() as file: self.metadata: KNNScorerDumpMetadata = json.load(file) - self.n_classes = self.metadata["n_classes"] - self.multilabel = self.metadata["multilabel"] + self._restore_state_from_metadata(self.metadata) + + def _restore_state_from_metadata(self, metadata: KNNScorerDumpMetadata) -> None: + self.n_classes = metadata["n_classes"] + self.multilabel = metadata["multilabel"] vector_index_client = VectorIndexClient( device=self.device, - db_dir=self.metadata["db_dir"], - embedder_batch_size=self.metadata["batch_size"], - embedder_max_length=self.metadata["max_length"], + db_dir=metadata["db_dir"], + embedder_batch_size=metadata["batch_size"], + embedder_max_length=metadata["max_length"], ) self._vector_index = vector_index_client.get_index(self.embedder_name) + def _get_neighbours( + self, utterances: list[str] + ) -> tuple[list[list[LabelType]], list[list[float]], list[list[str]]]: + return self._vector_index.query(utterances, self.k) + + def _count_scores(self, labels: npt.NDArray[Any], distances: npt.NDArray[Any]) -> npt.NDArray[Any]: + return apply_weights(labels, distances, self.weights, self.n_classes, self.multilabel) + def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]: """ Predict class probabilities and retrieve neighbors for the given utterances. @@ -237,6 +252,6 @@ def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[s :param utterances: List of query utterances. :return: Tuple containing class probabilities and neighbor utterances. 
""" - labels, distances, neighbors = self._vector_index.query(utterances, self.k) - scores = apply_weights(np.array(labels), np.array(distances), self.weights, self.n_classes, self.multilabel) + labels, distances, neighbors = self._get_neighbours(utterances) + scores = self._count_scores(np.array(labels), np.array(distances)) return scores, neighbors diff --git a/autointent/modules/scoring/_knn/rerank_scorer.py b/autointent/modules/scoring/_knn/rerank_scorer.py new file mode 100644 index 000000000..98cdd847f --- /dev/null +++ b/autointent/modules/scoring/_knn/rerank_scorer.py @@ -0,0 +1,211 @@ +"""RerankScorer class for re-ranking based on cross-encoder scoring.""" + +import json +from pathlib import Path +from typing import Any + +import numpy as np +import numpy.typing as npt +from sentence_transformers import CrossEncoder +from torch.nn import Sigmoid +from typing_extensions import Self + +from autointent.context import Context +from autointent.custom_types import WEIGHT_TYPES, LabelType + +from .knn import KNNScorer, KNNScorerDumpMetadata + + +class RerankScorerDumpMetadata(KNNScorerDumpMetadata): + """ + Metadata for dumping the state of a RerankScorer. + + :ivar cross_encoder_name: Name of the cross-encoder model used. + :ivar m: Number of top-ranked neighbors to consider, or None to use k. + :ivar rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None. + """ + + cross_encoder_name: str + m: int | None + rank_threshold_cutoff: int | None + + +class RerankScorer(KNNScorer): + """ + Re-ranking scorer using a cross-encoder for intent classification. + + This module uses a cross-encoder to re-rank the nearest neighbors retrieved by a KNN scorer. + + :ivar name: Name of the scorer, defaults to "rerank". + :ivar _scorer: CrossEncoder instance for re-ranking. 
+ """ + + name = "rerank" + _scorer: CrossEncoder + + def __init__( + self, + embedder_name: str, + k: int, + weights: WEIGHT_TYPES, + cross_encoder_name: str, + m: int | None = None, + rank_threshold_cutoff: int | None = None, + db_dir: str | None = None, + device: str = "cpu", + batch_size: int = 32, + max_length: int | None = None, + ) -> None: + """ + Initialize the RerankScorer. + + :param embedder_name: Name of the embedder used for vectorization. + :param k: Number of closest neighbors to consider during inference. + :param weights: Weighting strategy: + - "uniform" (or False): Equal weight for all neighbors. + - "distance" (or True): Weight inversely proportional to distance. + - "closest": Only the closest neighbor of each class is weighted. + :param cross_encoder_name: Name of the cross-encoder model used for re-ranking. + :param m: Number of top-ranked neighbors to consider, or None to use k. + :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None. + :param db_dir: Path to the database directory, or None to use default. + :param device: Device to run operations on, e.g., "cpu" or "cuda". + :param batch_size: Batch size for embedding generation, defaults to 32. + :param max_length: Maximum sequence length for embedding, or None for default. + """ + super().__init__( + embedder_name=embedder_name, + k=k, + weights=weights, + db_dir=db_dir, + device=device, + batch_size=batch_size, + max_length=max_length, + ) + + self.cross_encoder_name = cross_encoder_name + self.m = k if m is None else m + self.rank_threshold_cutoff = rank_threshold_cutoff + + @classmethod + def from_context( + cls, + context: Context, + k: int, + weights: WEIGHT_TYPES, + cross_encoder_name: str, + embedder_name: str | None = None, + m: int | None = None, + rank_threshold_cutoff: int | None = None, + ) -> Self: + """ + Create a RerankScorer instance from a given context. + + :param context: Context object containing optimization information and vector index client. 
+ :param k: Number of closest neighbors to consider during inference. + :param weights: Weighting strategy. + :param cross_encoder_name: Name of the cross-encoder model used for re-ranking. + :param embedder_name: Name of the embedder used for vectorization, or None to use the best existing embedder. + :param m: Number of top-ranked neighbors to consider, or None to use k. + :param rank_threshold_cutoff: Rank threshold cutoff for re-ranking, or None. + :return: An instance of RerankScorer. + """ + if embedder_name is None: + embedder_name = context.optimization_info.get_best_embedder() + prebuilt_index = True + else: + prebuilt_index = context.vector_index_client.exists(embedder_name) + + instance = cls( + embedder_name=embedder_name, + k=k, + weights=weights, + cross_encoder_name=cross_encoder_name, + m=m, + rank_threshold_cutoff=rank_threshold_cutoff, + db_dir=str(context.get_db_dir()), + device=context.get_device(), + batch_size=context.get_batch_size(), + max_length=context.get_max_length(), + ) + # TODO: needs re-thinking.... + instance.prebuilt_index = prebuilt_index + return instance + + def fit(self, utterances: list[str], labels: list[LabelType]) -> None: + """ + Fit the RerankScorer with utterances and labels. + + :param utterances: List of utterances to fit the scorer. + :param labels: List of labels corresponding to the utterances. + """ + self._scorer = CrossEncoder(self.cross_encoder_name, device=self.device, max_length=self.max_length) # type: ignore[arg-type] + + super().fit(utterances, labels) + + def _store_state_to_metadata(self) -> RerankScorerDumpMetadata: + """ + Store the current state of the RerankScorer to metadata. + + :return: Metadata containing the current state of the RerankScorer. 
+ """ + return RerankScorerDumpMetadata( + **super()._store_state_to_metadata(), + m=self.m, + cross_encoder_name=self.cross_encoder_name, + rank_threshold_cutoff=self.rank_threshold_cutoff, + ) + + def load(self, path: str) -> None: + """ + Load the RerankScorer from a given path. + + :param path: Path to the directory containing the dumped metadata. + """ + dump_dir = Path(path) + + with (dump_dir / self.metadata_dict_name).open() as file: + self.metadata: RerankScorerDumpMetadata = json.load(file) + + self._restore_state_from_metadata(self.metadata) + + def _restore_state_from_metadata(self, metadata: RerankScorerDumpMetadata) -> None: + """ + Restore the state of the RerankScorer from metadata. + + :param metadata: Metadata containing the state of the RerankScorer. + """ + super()._restore_state_from_metadata(metadata) + + self.m = metadata["m"] if metadata["m"] else self.k + self.cross_encoder_name = metadata["cross_encoder_name"] + self.rank_threshold_cutoff = metadata["rank_threshold_cutoff"] + self._scorer = CrossEncoder(self.cross_encoder_name, device=self.device, max_length=self.max_length) # type: ignore[arg-type] + + def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]: + """ + Predict the scores and neighbors for given utterances. + + :param utterances: List of utterances to predict scores for. + :return: A tuple containing the scores and neighbors. 
+ """ + knn_labels, knn_distances, knn_neighbors = self._get_neighbours(utterances) + + labels: list[list[LabelType]] = [] + distances: list[list[float]] = [] + neighbours: list[list[str]] = [] + + for query, query_labels, query_distances, query_docs in zip( + utterances, knn_labels, knn_distances, knn_neighbors, strict=True + ): + cur_ranks = self._scorer.rank( + query, query_docs, top_k=self.m, batch_size=self.batch_size, activation_fct=Sigmoid() + ) + + for dst, src in zip( + [labels, distances, neighbours], [query_labels, query_distances, query_docs], strict=True + ): + dst.append([src[rank["corpus_id"]] for rank in cur_ranks]) # type: ignore[attr-defined, index] + + scores = self._count_scores(np.array(labels), np.array(distances)) + return scores, neighbours diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml index 84dc75396..f70332033 100644 --- a/tests/assets/configs/multiclass.yaml +++ b/tests/assets/configs/multiclass.yaml @@ -20,6 +20,12 @@ nodes: - avsolatorio/GIST-small-Embedding-v0 k: [1, 3] train_head: [false, true] + - module_type: rerank + k: [ 5, 10 ] + weights: [uniform, distance, closest] + m: [ 2, 3 ] + cross_encoder_name: + - cross-encoder/ms-marco-MiniLM-L-6-v2 - node_type: prediction metric: prediction_accuracy search_space: diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index 59937073c..807d1d099 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -16,6 +16,12 @@ nodes: - module_type: linear - module_type: mlknn k: [5] + - module_type: rerank + k: [ 5, 10 ] + weights: [ uniform, distance, closest ] + m: [ 2, 3 ] + cross_encoder_name: + - cross-encoder/ms-marco-MiniLM-L-6-v2 - node_type: prediction metric: prediction_accuracy search_space: diff --git a/tests/modules/scoring/test_rerank_scorer.py b/tests/modules/scoring/test_rerank_scorer.py new file mode 100644 index 000000000..8a2ff8a7c --- /dev/null +++ 
b/tests/modules/scoring/test_rerank_scorer.py @@ -0,0 +1,39 @@ +import numpy as np + +from autointent.context.data_handler import DataHandler +from autointent.modules import RerankScorer +from tests.conftest import setup_environment + + +def test_base_rerank_scorer(dataset): + db_dir, dump_dir, logs_dir = setup_environment() + + data_handler = DataHandler(dataset) + + scorer = RerankScorer( + k=3, + weights="distance", + embedder_name="sergeyzh/rubert-tiny-turbo", + m=2, + cross_encoder_name="cross-encoder/ms-marco-MiniLM-L-6-v2", + db_dir=db_dir, + device="cpu", + ) + + test_data = [ + "why is there a hold on my american saving bank account", + "i am nost sure why my account is blocked", + "why is there a hold on my capital one checking account", + "i think my account is blocked but i do not know the reason", + "can you tell me why is my bank account frozen", + ] + + scorer.fit(data_handler.train_utterances, data_handler.train_labels) + predictions = scorer.predict(test_data) + assert ( + predictions == np.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]) + ).all() + + predictions, metadata = scorer.predict_with_metadata(test_data) + assert len(predictions) == len(test_data) + assert "neighbors" in metadata[0] diff --git a/tests/nodes/test_scoring.py b/tests/nodes/test_scoring.py index db5a1b15d..5478a81dd 100644 --- a/tests/nodes/test_scoring.py +++ b/tests/nodes/test_scoring.py @@ -42,6 +42,14 @@ def test_scoring_multiclass(retrieval_optimizer_multiclass): "temperature": [1.0, 0.5, 0.1, 0.05], "embedder_name": ["sergeyzh/rubert-tiny-turbo"], }, + { + "module_type": "rerank", + "weights": ["uniform", "distance", "closest"], + "k": [3], + "m": [2], + "cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"], + "embedder_name": ["sergeyzh/rubert-tiny-turbo"], + }, ], } @@ -82,6 +90,14 @@ def test_scoring_multilabel(retrieval_optimizer_multilabel): "embedder_name": ["sergeyzh/rubert-tiny-turbo"], }, {"module_type": 
"mlknn", "k": [5], "embedder_name": ["sergeyzh/rubert-tiny-turbo"]}, + { + "module_type": "rerank", + "weights": ["uniform", "distance", "closest"], + "k": [3], + "m": [2], + "cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"], + "embedder_name": ["sergeyzh/rubert-tiny-turbo"], + }, ], }