feat: adding Maximum Margin Relevance Ranker (#8554)
* initial import

* linting

* adding MRR tests

* adding release notes

* fixing tests

* adding linting ignore to cross-encoder ranker

* update docstring

* refactoring

* making strategy Optional instead of Literal

* wip: adding unit tests

* refactoring MMR algorithm

* refactoring tests

* cleaning up and updating tests

* adding empty line between license + code

* bug in tests

* using Enum for strategy and similarity metric

* adding more tests

* adding empty line between license + code

* removing run time params

* PR comments

* PR comments

* fixing

* fixing serialisation

* fixing serialisation tests

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <[email protected]>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <[email protected]>

* fixing tests

* PR comments

* PR comments

* PR comments

* PR comments

---------

Co-authored-by: Daria Fokina <[email protected]>
davidsbatista and dfokina authored Nov 22, 2024
1 parent a8eeb20 commit b5a2fad
Showing 4 changed files with 341 additions and 78 deletions.
210 changes: 179 additions & 31 deletions haystack/components/rankers/sentence_transformers_diversity.py
@@ -2,7 +2,8 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Any, Dict, List, Literal, Optional
+from enum import Enum
+from typing import Any, Dict, List, Optional, Union
 
 from haystack import Document, component, default_from_dict, default_to_dict, logging
 from haystack.lazy_imports import LazyImport
@@ -16,47 +17,116 @@
     from sentence_transformers import SentenceTransformer
 
 
+class DiversityRankingStrategy(Enum):
+    """
+    The strategy to use for diversity ranking.
+    """
+
+    GREEDY_DIVERSITY_ORDER = "greedy_diversity_order"
+    MAXIMUM_MARGIN_RELEVANCE = "maximum_margin_relevance"
+
+    def __str__(self) -> str:
+        """
+        Convert a Strategy enum to a string.
+        """
+        return self.value
+
+    @staticmethod
+    def from_str(string: str) -> "DiversityRankingStrategy":
+        """
+        Convert a string to a Strategy enum.
+        """
+        enum_map = {e.value: e for e in DiversityRankingStrategy}
+        strategy = enum_map.get(string)
+        if strategy is None:
+            msg = f"Unknown strategy '{string}'. Supported strategies are: {list(enum_map.keys())}"
+            raise ValueError(msg)
+        return strategy
+
+
+class DiversityRankingSimilarity(Enum):
+    """
+    The similarity metric to use for comparing embeddings.
+    """
+
+    DOT_PRODUCT = "dot_product"
+    COSINE = "cosine"
+
+    def __str__(self) -> str:
+        """
+        Convert a Similarity enum to a string.
+        """
+        return self.value
+
+    @staticmethod
+    def from_str(string: str) -> "DiversityRankingSimilarity":
+        """
+        Convert a string to a Similarity enum.
+        """
+        enum_map = {e.value: e for e in DiversityRankingSimilarity}
+        similarity = enum_map.get(string)
+        if similarity is None:
+            msg = f"Unknown similarity metric '{string}'. Supported metrics are: {list(enum_map.keys())}"
+            raise ValueError(msg)
+        return similarity
+
+
 @component
 class SentenceTransformersDiversityRanker:
     """
     A Diversity Ranker based on Sentence Transformers.
-    Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
-    of the documents.
-
-    This component provides functionality to rank a list of documents based on their similarity with respect to the
-    query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
-    the Documents.
-
-    Usage example:
+
+    Applies a document ranking algorithm based on one of the two strategies:
+
+    1. Greedy Diversity Order:
+
+        Implements a document ranking algorithm that orders documents in a way that maximizes the overall diversity
+        of the documents based on their similarity to the query.
+
+        It uses a pre-trained Sentence Transformers model to embed the query and
+        the documents.
+
+    2. Maximum Margin Relevance:
+
+        Implements a document ranking algorithm that orders documents based on their Maximum Margin Relevance (MMR)
+        scores.
+
+        MMR scores are calculated for each document based on their relevance to the query and diversity from already
+        selected documents. The algorithm iteratively selects documents based on their MMR scores, balancing between
+        relevance to the query and diversity from already selected documents. The 'lambda_threshold' controls the
+        trade-off between relevance and diversity.
+
+    ### Usage example
     ```python
     from haystack import Document
     from haystack.components.rankers import SentenceTransformersDiversityRanker
 
-    ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
+    ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy="greedy_diversity_order")
     ranker.warm_up()
 
     docs = [Document(content="Paris"), Document(content="Berlin")]
     query = "What is the capital of germany?"
     output = ranker.run(query=query, documents=docs)
     docs = output["documents"]
     ```
-    """
+    """  # noqa: E501

     def __init__(
         self,
         model: str = "sentence-transformers/all-MiniLM-L6-v2",
         top_k: int = 10,
         device: Optional[ComponentDevice] = None,
         token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
-        similarity: Literal["dot_product", "cosine"] = "cosine",
+        similarity: Union[str, DiversityRankingSimilarity] = "cosine",
         query_prefix: str = "",
         query_suffix: str = "",
         document_prefix: str = "",
         document_suffix: str = "",
         meta_fields_to_embed: Optional[List[str]] = None,
         embedding_separator: str = "\n",
-    ):
+        strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
+        lambda_threshold: float = 0.5,
+    ):  # pylint: disable=too-many-positional-arguments
         """
         Initialize a SentenceTransformersDiversityRanker.
@@ -78,6 +148,10 @@ def __init__(
         :param document_suffix: A string to add to the end of each Document text before ranking.
         :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
         :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
+        :param strategy: The strategy to use for diversity ranking. Can be either "greedy_diversity_order" or
+            "maximum_margin_relevance".
+        :param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
+            "maximum_margin_relevance".
         """
         torch_and_sentence_transformers_import.check()
 
@@ -88,15 +162,16 @@
         self.device = ComponentDevice.resolve_device(device)
         self.token = token
         self.model = None
-        if similarity not in ["dot_product", "cosine"]:
-            raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
-        self.similarity = similarity
+        self.similarity = DiversityRankingSimilarity.from_str(similarity) if isinstance(similarity, str) else similarity
         self.query_prefix = query_prefix
         self.document_prefix = document_prefix
         self.query_suffix = query_suffix
         self.document_suffix = document_suffix
         self.meta_fields_to_embed = meta_fields_to_embed or []
         self.embedding_separator = embedding_separator
+        self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
+        self.lambda_threshold = lambda_threshold or 0.5
+        self._check_lambda_threshold(self.lambda_threshold, self.strategy)
 
     def warm_up(self):
         """
@@ -119,16 +194,18 @@ def to_dict(self) -> Dict[str, Any]:
         return default_to_dict(
             self,
             model=self.model_name_or_path,
+            top_k=self.top_k,
             device=self.device.to_dict(),
             token=self.token.to_dict() if self.token else None,
-            top_k=self.top_k,
-            similarity=self.similarity,
+            similarity=str(self.similarity),
             query_prefix=self.query_prefix,
-            document_prefix=self.document_prefix,
             query_suffix=self.query_suffix,
+            document_prefix=self.document_prefix,
             document_suffix=self.document_suffix,
             meta_fields_to_embed=self.meta_fields_to_embed,
             embedding_separator=self.embedding_separator,
+            strategy=str(self.strategy),
+            lambda_threshold=self.lambda_threshold,
         )
 
     @classmethod
@@ -182,14 +259,7 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
         """
         texts_to_embed = self._prepare_texts_to_embed(documents)
 
-        # Calculate embeddings
-        doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True)  # type: ignore[attr-defined]
-        query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True)  # type: ignore[attr-defined]
-
-        # Normalize embeddings to unit length for computing cosine similarity
-        if self.similarity == "cosine":
-            doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
-            query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
+        doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
 
         n = len(documents)
         selected: List[int] = []
@@ -218,14 +288,84 @@ def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
 
         return ranked_docs
 
+    def _embed_and_normalize(self, query, texts_to_embed):
+        # Calculate embeddings
+        doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True)  # type: ignore[attr-defined]
+        query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True)  # type: ignore[attr-defined]
+
+        # Normalize embeddings to unit length for computing cosine similarity
+        if self.similarity == DiversityRankingSimilarity.COSINE:
+            doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
+            query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
+        return doc_embeddings, query_embedding
+
+    def _maximum_margin_relevance(
+        self, query: str, documents: List[Document], lambda_threshold: float, top_k: int
+    ) -> List[Document]:
+        """
+        Orders the given list of documents according to the Maximum Margin Relevance (MMR) scores.
+
+        MMR scores are calculated for each document based on their relevance to the query and diversity from already
+        selected documents.
+
+        The algorithm iteratively selects documents based on their MMR scores, balancing between relevance to the query
+        and diversity from already selected documents. The 'lambda_threshold' controls the trade-off between relevance
+        and diversity.
+
+        A closer value to 0 favors diversity, while a closer value to 1 favors relevance to the query.
+
+        See: "The Use of MMR, Diversity-Based Reranking for Reordering Documents and Producing Summaries"
+        https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf
+        """
+        texts_to_embed = self._prepare_texts_to_embed(documents)
+        doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
+        top_k = top_k if top_k else len(documents)
+
+        selected: List[int] = []
+        query_similarities_as_tensor = query_embedding @ doc_embeddings.T
+        query_similarities = query_similarities_as_tensor.reshape(-1)
+        idx = int(torch.argmax(query_similarities))
+        selected.append(idx)
+        while len(selected) < top_k:
+            best_idx = None
+            best_score = -float("inf")
+            for idx, _ in enumerate(documents):
+                if idx in selected:
+                    continue
+                relevance_score = query_similarities[idx]
+                diversity_score = max(doc_embeddings[idx] @ doc_embeddings[j].T for j in selected)
+                mmr_score = lambda_threshold * relevance_score - (1 - lambda_threshold) * diversity_score
+                if mmr_score > best_score:
+                    best_score = mmr_score
+                    best_idx = idx
+            if best_idx is None:
+                raise ValueError("No best document found, check if the documents list contains any documents.")
+            selected.append(best_idx)
+
+        return [documents[i] for i in selected]
+
+    @staticmethod
+    def _check_lambda_threshold(lambda_threshold: float, strategy: DiversityRankingStrategy):
+        if (strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE) and not 0 <= lambda_threshold <= 1:
+            raise ValueError(f"lambda_threshold must be between 0 and 1, but got {lambda_threshold}.")
+
     @component.output_types(documents=List[Document])
-    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
+    def run(
+        self,
+        query: str,
+        documents: List[Document],
+        top_k: Optional[int] = None,
+        lambda_threshold: Optional[float] = None,
+    ) -> Dict[str, List[Document]]:
         """
         Rank the documents based on their diversity.
 
         :param query: The search query.
         :param documents: List of Document objects to be ranker.
         :param top_k: Optional. An integer to override the top_k set during initialization.
+        :param lambda_threshold: Override the trade-off parameter between relevance and diversity. Only used when
+            strategy is "maximum_margin_relevance".
 
         :returns: A dictionary with the following key:
             - `documents`: List of Document objects that have been selected based on the diversity ranking.
@@ -245,9 +385,17 @@ def run(
 
         if top_k is None:
             top_k = self.top_k
-        elif top_k <= 0:
-            raise ValueError(f"top_k must be > 0, but got {top_k}")
-
-        diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
+        elif not 0 < top_k <= len(documents):
+            raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}")
+
+        if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
+            if lambda_threshold is None:
+                lambda_threshold = self.lambda_threshold
+            self._check_lambda_threshold(lambda_threshold, self.strategy)
+            re_ranked_docs = self._maximum_margin_relevance(
+                query=query, documents=documents, lambda_threshold=lambda_threshold, top_k=top_k
+            )
+        else:
+            re_ranked_docs = self._greedy_diversity_order(query=query, documents=documents)
 
-        return {"documents": diversity_sorted[:top_k]}
+        return {"documents": re_ranked_docs[:top_k]}
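For reference, the selection rule implemented by `_maximum_margin_relevance` above is the MMR objective from the Carbonell and Goldstein (1998) paper cited in its docstring. A sketch in that paper's notation, where R is the candidate set, S the already-selected documents, q the query, and λ corresponds to `lambda_threshold`:

$$\mathrm{MMR} = \arg\max_{d_i \in R \setminus S}\Big[\lambda \cdot \mathrm{sim}(d_i, q) - (1 - \lambda)\max_{d_j \in S} \mathrm{sim}(d_i, d_j)\Big]$$

With λ = 1 the loop reduces to plain relevance ordering; with λ = 0 it only maximizes dissimilarity from the already-selected set. Seeding `selected` with the single most query-similar document coincides with the first iteration of this argmax, since S is empty and only the relevance term applies.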
4 changes: 2 additions & 2 deletions haystack/components/rankers/transformers_similarity.py
@@ -43,7 +43,7 @@ class TransformersSimilarityRanker:
     ```
     """
 
-    def __init__(  # noqa: PLR0913
+    def __init__(  # noqa: PLR0913, pylint: disable=too-many-positional-arguments
         self,
         model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
         device: Optional[ComponentDevice] = None,
@@ -201,7 +201,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "TransformersSimilarityRanker":
         return default_from_dict(cls, data)
 
     @component.output_types(documents=List[Document])
-    def run(
+    def run(  # pylint: disable=too-many-positional-arguments
         self,
         query: str,
         documents: List[Document],
@@ -0,0 +1,4 @@
+---
+enhancements:
+  - |
+    Added the Maximum Margin Relevance (MMR) strategy to the `SentenceTransformersDiversityRanker`. MMR scores are calculated for each document based on their relevance to the query and diversity from already selected documents.
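To make the release note concrete, here is a minimal usage sketch of the new strategy, assembled from the API added in this commit. The documents, query, and `lambda_threshold` value are illustrative, and the model is downloaded from the Hugging Face Hub on first use:

```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

# lambda_threshold closer to 1 favors relevance to the query,
# closer to 0 favors diversity from already selected documents.
ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    similarity="cosine",
    strategy="maximum_margin_relevance",
    lambda_threshold=0.7,
)
ranker.warm_up()

docs = [
    Document(content="Berlin is the capital of Germany."),
    Document(content="Berlin is Germany's capital city."),  # near-duplicate, penalized by MMR
    Document(content="Munich is the capital of Bavaria."),
]
output = ranker.run(query="What is the capital of Germany?", documents=docs, top_k=2)
print([d.content for d in output["documents"]])
```

With `lambda_threshold=0.7` the ranking leans toward relevance, so the expectation is that one of the two Berlin documents is selected first and the near-duplicate is then demoted in favor of a more diverse second pick.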
