diff --git a/tests/unit/test_evaluator.py b/tests/unit/test_evaluator.py index 198935d4..e1e90960 100644 --- a/tests/unit/test_evaluator.py +++ b/tests/unit/test_evaluator.py @@ -10,14 +10,33 @@ class MockVespaResponse: hits: List[Dict[str, Any]] + def add_namespace_to_hit_ids(self, hits) -> str: + new_hits = [] + for hit in hits: + hit["id"] = f"id:mynamespace:mydoctype::{hit['id']}" + new_hits.append(hit) + return new_hits + def get_json(self): - return {"root": {"children": self.hits}} + return {"root": {"children": self.add_namespace_to_hit_ids(self.hits)}} @property def status_code(self): return 200 +class QueryBodyCapturingApp: + """Mock Vespa app that captures query bodies passed to query_many.""" + + def __init__(self, responses): + self.responses = responses + self.captured_query_bodies = None + + def query_many(self, query_bodies): + self.captured_query_bodies = query_bodies + return self.responses + + class TestVespaEvaluator(unittest.TestCase): def setUp(self): # Sample queries @@ -40,6 +59,12 @@ def setUp(self): "q3": "doc6", } + self.relevant_docs_relevance = { + "q1": {"doc1": 1.0, "doc2": 0.5, "doc3": 0.2}, + "q2": {"doc4": 0.8, "doc5": 0.6}, + "q3": {"doc6": 1.0}, + } + # Mock Vespa responses # For q1: doc1 at rank 1, doc2 at rank 3, doc3 at rank 5 q1_response = MockVespaResponse( @@ -121,6 +146,16 @@ def test_init_single_relevant_docs(self): } self.assertEqual(evaluator.relevant_docs, relevant_docs_to_set) + def test_init_relevant_docs_with_relevance(self): + """Test initialization with relevant docs having relevance scores""" + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=self.relevant_docs_relevance, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) + self.assertEqual(evaluator.relevant_docs, self.relevant_docs_relevance) + def test_custom_k_values(self): """Test initialization with custom k values""" evaluator = VespaEvaluator( @@ -248,6 +283,96 @@ def test_map_metric(self): expected_map = 0.7519 # 
Approximate value self.assertAlmostEqual(results["map@5"], expected_map, places=4) + def test_graded_ndcg_metric(self): + """Test graded NDCG@k calculations""" + queries = {"535": "06 bmw 325i radio oem not navigation system"} + relevant_docs = { + "535": { + "B08VSJGP1N": 0.01, + "B08VJ66CNL": 0.01, + "B08SHMLP5S": 0.0, + "B08QGZMCYQ": 0.0, + "B08PB9TTKT": 1.0, + "B08NVQ8MZX": 0.01, + "B084TV3C1B": 0.01, + "B0742BZXC2": 1.0, + "B00DHUA9VA": 0.0, + "B00B4PJC9K": 0.0, + "B0072LFB68": 0.01, + "B0051GN8JI": 0.01, + "B000J1HDWI": 0.0, + "B0007KPS3C": 0.0, + "B01M0SFMIH": 1.0, + "B0007KPRIS": 0.0, + } + } + # B08PB9TTKT 1 0.463 + # B00B4PJC9K 2 0.431 + # B0051GN8JI 3 0.419 + # B084TV3C1B 4 0.417 + # B08NVQ8MZX 5 0.41 + # B00DHUA9VA 6 0.415 + # B08SHMLP5S 7 0.415 + # B08VSJGP1N 8 0.41 + # B08QGZMCYQ 9 0.411 + # B0007KPRIS 10 0.40 + # B08VJ66CNL 11 0.40 + # B000J1HDWI 12 0.40 + # B0007KPS3C 13 0.39 + # B0072LFB68 14 0.39 + # B01M0SFMIH 15 0.39 + # B0742BZXC2 16 0.37 + + # Mock Vespa responses - must match doc_ids in relevant_docs + q1_response = MockVespaResponse( + [ + {"id": "B08PB9TTKT", "relevance": 0.463}, + {"id": "B00B4PJC9K", "relevance": 0.431}, + {"id": "B0051GN8JI", "relevance": 0.419}, + {"id": "B084TV3C1B", "relevance": 0.417}, + {"id": "B08NVQ8MZX", "relevance": 0.41}, + {"id": "B00DHUA9VA", "relevance": 0.415}, + {"id": "B08SHMLP5S", "relevance": 0.415}, + {"id": "B08VSJGP1N", "relevance": 0.41}, + {"id": "B08QGZMCYQ", "relevance": 0.411}, + {"id": "B0007KPRIS", "relevance": 0.40}, + {"id": "B08VJ66CNL", "relevance": 0.40}, + {"id": "B000J1HDWI", "relevance": 0.40}, + {"id": "B0007KPS3C", "relevance": 0.39}, + {"id": "B0072LFB68", "relevance": 0.39}, + {"id": "B01M0SFMIH", "relevance": 0.39}, + {"id": "B0742BZXC2", "relevance": 0.37}, + ] + ) + + class MockVespaApp: + def __init__(self, mock_responses): + self.mock_responses = mock_responses + self.current_query = 0 + + def query_many(self, queries): + return self.mock_responses + + mock_app = 
MockVespaApp([q1_response]) + + def mock_vespa_query_fn(query_text: str, top_k: int) -> dict: + return { + "yql": f'select * from sources * where text contains "{query_text}";', + "hits": top_k, + } + + evaluator = VespaEvaluator( + queries=queries, + relevant_docs=relevant_docs, + vespa_query_fn=mock_vespa_query_fn, + app=mock_app, + ndcg_at_k=[16], + ) + + results = evaluator.run() + print(results) + self.assertAlmostEqual(results["ndcg@16"], 0.7046, places=4) + def test_vespa_query_fn_validation(self): """Test validation of vespa_query_fn with valid functions""" @@ -277,7 +402,7 @@ def test_vespa_query_fn_validation_errors(self): """Test validation of vespa_query_fn with invalid functions""" # Not a callable - with self.assertRaisesRegex(ValueError, "must be a callable"): + with self.assertRaisesRegex(ValueError, "must be callable"): VespaEvaluator( queries=self.queries, relevant_docs=self.relevant_docs, @@ -289,7 +414,7 @@ def test_vespa_query_fn_validation_errors(self): def fn1(query: str) -> dict: return {"yql": query} - with self.assertRaisesRegex(TypeError, "must take exactly 2 parameters"): + with self.assertRaisesRegex(TypeError, "must take 2 or 3 parameters"): VespaEvaluator( queries=self.queries, relevant_docs=self.relevant_docs, @@ -309,42 +434,290 @@ def fn2(query: int, k: str) -> dict: app=self.mock_app, ) - # Wrong return type annotation - def fn3(query: str, k: int) -> list: - return [query, k] + # No type hints + def fn3(query, k): + return {"yql": query, "hits": k} + + def test_validate_qrels(self): + """Test validation of qrels with valid qrels""" + # Valid qrels + qrels1 = { + "q1": {"doc1", "doc2", "doc3"}, + "q2": {"doc4", "doc5"}, + "q3": {"doc6"}, + } + qrels2 = { + "q1": "doc1", + "q2": "doc4", + "q3": "doc6", + } + qrels3 = { + "q1": {"doc1": 1.0, "doc2": 0.5, "doc3": 0.2}, + "q2": {"doc4": 0.8, "doc5": 0.6}, + "q3": {"doc6": 1.0}, + } - with self.assertRaisesRegex(ValueError, "must return a dict"): + # All should work without raising 
exceptions + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=qrels1, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) + self.assertIsInstance(evaluator, VespaEvaluator) + + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=qrels2, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) + self.assertIsInstance(evaluator, VespaEvaluator) + + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=qrels3, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) + self.assertIsInstance(evaluator, VespaEvaluator) + + def test_validate_qrels_errors(self): + """Test validation of qrels with invalid qrels""" + + # Not a dict + with self.assertRaisesRegex(ValueError, "qrels must be a dict"): VespaEvaluator( queries=self.queries, - relevant_docs=self.relevant_docs, - vespa_query_fn=fn3, + relevant_docs="not_a_dict", + vespa_query_fn=self.vespa_query_fn, app=self.mock_app, ) - # Function that raises error - def fn4(query: str, k: int) -> dict: - raise ValueError("Something went wrong") + # Relevant docs not a set, string, or dict + with self.assertRaisesRegex(ValueError, "must be a set, string, or dict"): + VespaEvaluator( + queries=self.queries, + relevant_docs={"q1": 1}, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) - with self.assertRaisesRegex(ValueError, "Error calling vespa_query_fn"): + # Relevance scores not numeric + with self.assertRaisesRegex( + ValueError, "must be a dict of string doc_id => numeric relevance" + ): VespaEvaluator( queries=self.queries, - relevant_docs=self.relevant_docs, - vespa_query_fn=fn4, + relevant_docs={"q1": {"doc1": "not_numeric"}}, + vespa_query_fn=self.vespa_query_fn, app=self.mock_app, ) - # Function that returns wrong type at runtime - def fn5(query: str, k: int) -> dict: - return [query, k] # Actually returns a list + # Relevance scores not between 0 and 1 + with self.assertRaisesRegex(ValueError, "must be between 0 and 1"): + VespaEvaluator( 
+ queries=self.queries, + relevant_docs={"q1": {"doc1": 1.1}}, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) - with self.assertRaisesRegex(ValueError, "must return a dict"): + with self.assertRaisesRegex(ValueError, "must be between 0 and 1"): VespaEvaluator( queries=self.queries, - relevant_docs=self.relevant_docs, - vespa_query_fn=fn5, + relevant_docs={"q1": {"doc1": -0.1}}, + vespa_query_fn=self.vespa_query_fn, app=self.mock_app, ) + def test_filter_queries(self): + """Test filter_queries method""" + queries = { + "q1": "what is machine learning", + "q2": "how to code python", + "q3": "what is the capital of France", + "q4": "irrelevant query", + } + + relevant_docs = { + "q1": {"doc1", "doc2", "doc3"}, + "q2": {"doc4", "doc5"}, + "q3": {"doc6"}, + } + + evaluator = VespaEvaluator( + queries=queries, + relevant_docs=relevant_docs, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) + + # Test that queries with no relevant docs are filtered out + self.assertEqual(len(evaluator.queries_ids), 3) + self.assertNotIn("q4", evaluator.queries_ids) + + # Test that queries with empty relevant docs are filtered out + relevant_docs["q4"] = set() + evaluator = VespaEvaluator( + queries=queries, + relevant_docs=relevant_docs, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) + self.assertEqual(len(evaluator.queries_ids), 3) + self.assertNotIn("q4", evaluator.queries_ids) + + # Test that queries with relevant docs are not filtered out + relevant_docs["q4"] = {"doc7"} + evaluator = VespaEvaluator( + queries=queries, + relevant_docs=relevant_docs, + vespa_query_fn=self.vespa_query_fn, + app=self.mock_app, + ) + self.assertEqual(len(evaluator.queries_ids), 4) + self.assertIn("q4", evaluator.queries_ids) + + def test_vespa_query_fn_with_query_id(self): + """Test that vespa_query_fn accepting query_id receives it as the third argument.""" + + def fn(query_text: str, top_k: int, query_id: str) -> dict: + return { + "yql": f'select * from 
sources * where text contains "{query_text}" and id="{query_id}";', + "hits": top_k, + "query_id": query_id, # Not for passing to Vespa, but for testing + } + + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=self.relevant_docs, + vespa_query_fn=fn, + app=self.mock_app, + ) + self.assertTrue(evaluator._vespa_query_fn_takes_query_id) + # Build query bodies and check that query_id is passed correctly. + query_bodies = [] + max_k = evaluator._find_max_k() + for qid, query_text in zip(evaluator.queries_ids, evaluator.queries): + query_body = evaluator.vespa_query_fn(query_text, max_k, qid) + query_bodies.append(query_body) + + for qid, qb in zip(evaluator.queries_ids, query_bodies): + self.assertIn("query_id", qb) + self.assertEqual(qb["query_id"], qid) + + def test_vespa_query_fn_without_query_id(self): + """Test that a vespa_query_fn accepting only 2 parameters does not receive a query_id.""" + + def fn(query_text: str, top_k: int) -> dict: + # Return a basic query body. + return {"yql": query_text, "hits": top_k} + + # Create a dummy response (the content is not used for these tests). + dummy_response = MockVespaResponse([{"id": "doc1", "relevance": 1.0}]) + capturing_app = QueryBodyCapturingApp([dummy_response] * len(self.queries)) + + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=self.relevant_docs, + vespa_query_fn=fn, + app=capturing_app, + ) + # Since fn accepts only 2 params, the evaluator should mark it as NOT taking a query_id. + self.assertFalse(evaluator._vespa_query_fn_takes_query_id) + + # Run the evaluator to trigger query body generation. + evaluator.run() + + # Verify that none of the query bodies include a "query_id" key and that default_body keys were added. 
+ for qb in capturing_app.captured_query_bodies: + self.assertNotIn("query_id", qb) + self.assertIn("timeout", qb) + self.assertEqual(qb["timeout"], "5s") + self.assertIn("presentation.timing", qb) + self.assertEqual(qb["presentation.timing"], True) + + def test_vespa_query_fn_no_type_hints(self): + """Test that a vespa_query_fn without type hints is handled correctly.""" + + def fn(query_text, top_k): + # Return a basic query body. + return {"yql": query_text, "hits": top_k} + + # Create a dummy response (the content is not used for these tests). + dummy_response = MockVespaResponse([{"id": "doc1", "relevance": 1.0}]) + capturing_app = QueryBodyCapturingApp([dummy_response] * len(self.queries)) + + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=self.relevant_docs, + vespa_query_fn=fn, + app=capturing_app, + ) + + # Run the evaluator to trigger query body generation. + evaluator.run() + + # Verify that none of the query bodies include a "query_id" key and that default_body keys were added. + for qb in capturing_app.captured_query_bodies: + self.assertNotIn("query_id", qb) + self.assertIn("timeout", qb) + self.assertEqual(qb["timeout"], "5s") + self.assertIn("presentation.timing", qb) + self.assertEqual(qb["presentation.timing"], True) + + def test_vespa_query_fn_default_body_override(self): + """Test that keys from default_body override any conflicting keys returned by vespa_query_fn.""" + + def fn_override(query_text: str, top_k: int) -> dict: + # Return a query body that has conflicting values for default keys. 
+ return { + "yql": query_text, + "hits": top_k, + "timeout": "10s", + "presentation.timing": False, + } + + dummy_response = MockVespaResponse([{"id": "doc1", "relevance": 1.0}]) + capturing_app = QueryBodyCapturingApp([dummy_response] * len(self.queries)) + + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=self.relevant_docs, + vespa_query_fn=fn_override, + app=capturing_app, + ) + evaluator.run() + + # After evaluator.run(), the default body should override the keys from fn_override. + for qb in capturing_app.captured_query_bodies: + self.assertEqual(qb["timeout"], "5s") + self.assertEqual(qb["presentation.timing"], True) + + def test_vespa_query_fn_preserves_extra_keys(self): + """Test that extra keys returned by vespa_query_fn are preserved after merging with default_body.""" + + def fn_extra(query_text: str, top_k: int) -> dict: + # Return a query body that includes an extra key. + return {"yql": query_text, "hits": top_k, "extra": "value"} + + dummy_response = MockVespaResponse([{"id": "doc1", "relevance": 1.0}]) + capturing_app = QueryBodyCapturingApp([dummy_response] * len(self.queries)) + + evaluator = VespaEvaluator( + queries=self.queries, + relevant_docs=self.relevant_docs, + vespa_query_fn=fn_extra, + app=capturing_app, + ) + evaluator.run() + + # Verify that the extra key is still present in each query body. + for qb in capturing_app.captured_query_bodies: + self.assertIn("extra", qb) + self.assertEqual(qb["extra"], "value") + if __name__ == "__main__": unittest.main() diff --git a/vespa/evaluation.py b/vespa/evaluation.py index 503559e7..7b382d4e 100644 --- a/vespa/evaluation.py +++ b/vespa/evaluation.py @@ -78,6 +78,13 @@ class VespaEvaluator: # "q2": "d101", # # ... # } + # Or, relevant_docs can be a dict of query_id => map of doc_id => relevance + # relevant_docs = { + # "q1": {"d12": 1, "d99": 0.1}, + # "q2": {"d101": 0.01}, + # # ... 
+ # Note that for non-binary relevance, the relevance values should be in [0, 1], and that + # only the nDCG metric will be computed. def my_vespa_query_fn(query_text: str, top_k: int) -> dict: return { @@ -110,10 +117,13 @@ def my_vespa_query_fn(query_text: str, top_k: int) -> dict: def __init__( self, queries: Dict[str, str], - relevant_docs: Union[Dict[str, Set[str]], Dict[str, str]], - vespa_query_fn: Callable[[str, int], dict], + relevant_docs: Union[ + Dict[str, Union[Set[str], Dict[str, float]]], Dict[str, str] + ], + vespa_query_fn: Callable[[str, int, Optional[str]], dict], app: Vespa, name: str = "", + id_field: str = "", accuracy_at_k: List[int] = [1, 3, 5, 10], precision_recall_at_k: List[int] = [1, 3, 5, 10], mrr_at_k: List[int] = [10], @@ -124,10 +134,11 @@ def __init__( ): """ :param queries: Dict of query_id => query text - :param relevant_docs: Dict of query_id => set of relevant doc_ids (the user-specified part of `id::::` in Vespa, see https://docs.vespa.ai/en/documents.html#document-ids) + :param relevant_docs: Dict of query_id => set of relevant doc_ids or query_id => dict of doc_id => relevance. See example usage. :param vespa_query_fn: Callable, with signature: my_func(query:str, top_k: int)-> dict: Given a query string and top_k, returns a Vespa query body (dict). :param app: A `vespa.application.Vespa` instance. :param name: A name or tag for this evaluation run. + :param id_field: Specify the field name in Vespa that contains the document ID (If unset, will try to use vespa internal document id, but this may fail in some cases, see https://docs.vespa.ai/en/documents.html#docid-in-results). :param accuracy_at_k: list of k-values for Accuracy@k :param precision_recall_at_k: list of k-values for Precision@k and Recall@k :param mrr_at_k: list of k-values for MRR@k @@ -136,10 +147,9 @@ def __init__( :param write_csv: If True, writes results to CSV :param csv_dir: Path in which to write the CSV file (default: current working dir). 
""" + self.id_field = id_field self._validate_queries(queries) - self._validate_vespa_query_fn( - vespa_query_fn - ) # Add this line before _validate_qrels + self._validate_vespa_query_fn(vespa_query_fn) relevant_docs = self._validate_qrels(relevant_docs) # Filter out any queries that have no relevant docs @@ -193,81 +203,114 @@ def filter_queries( filtered.append(qid) return filtered - def _validate_queries(self, queries: Dict[str, str]): + def _validate_queries(self, queries: Dict[Union[str, int], str]) -> Dict[str, str]: + """ + Validate and normalize queries. + Converts query IDs to strings if they are ints. + """ if not isinstance(queries, dict): raise ValueError("queries must be a dict of query_id => query_text") + normalized_queries = {} for qid, query_text in queries.items(): - if not isinstance(qid, str) or not isinstance(query_text, str): - raise ValueError("Each query must be a string.", qid, query_text) + if not isinstance(qid, (str, int)): + raise ValueError("Query ID must be a string or an int.", qid) + if not isinstance(query_text, str): + raise ValueError("Query text must be a string.", query_text) + normalized_queries[str(qid)] = query_text + return normalized_queries def _validate_qrels( - self, qrels: Union[Dict[str, Set[str]], Dict[str, str]] - ) -> Dict[str, Set[str]]: + self, + qrels: Union[ + Dict[Union[str, int], Union[Set[str], Dict[str, float]]], + Dict[Union[str, int], str], + ], + ) -> Dict[str, Union[Set[str], Dict[str, float]]]: + """ + Validate and normalize qrels. + Converts query IDs to strings if they are ints. 
+ """ if not isinstance(qrels, dict): raise ValueError( - "qrels must be a dict of query_id => set of relevant doc_ids" + "qrels must be a dict of query_id => set/dict of relevant doc_ids or a single doc_id string" ) - new_qrels: Dict[str, Set[str]] = {} + new_qrels: Dict[str, Union[Set[str], Dict[str, float]]] = {} for qid, relevant_docs in qrels.items(): - if not isinstance(qid, str): + if not isinstance(qid, (str, int)): raise ValueError( - "Each qrel must be a string query_id and a set of doc_ids.", - qid, - relevant_docs, + "Query ID in qrels must be a string or an int.", qid, relevant_docs ) + normalized_qid = str(qid) if isinstance(relevant_docs, str): - new_qrels[qid] = {relevant_docs} + new_qrels[normalized_qid] = {relevant_docs} elif isinstance(relevant_docs, set): - new_qrels[qid] = relevant_docs + new_qrels[normalized_qid] = relevant_docs + elif isinstance(relevant_docs, dict): + for doc_id, relevance in relevant_docs.items(): + if not isinstance(doc_id, str) or not isinstance( + relevance, (int, float) + ): + raise ValueError( + f"Relevance scores for query {normalized_qid} must be a dict of string doc_id => numeric relevance." + ) + if not 0 <= relevance <= 1: + raise ValueError( + f"Relevance scores for query {normalized_qid} must be between 0 and 1." + ) + new_qrels[normalized_qid] = relevant_docs else: raise ValueError( - f"Relevant docs for query {qid} must be a set or string." + f"Relevant docs for query {normalized_qid} must be a set, string, or dict." ) return new_qrels - def _validate_vespa_query_fn(self, fn: Callable[[str, int], dict]) -> None: + def _validate_vespa_query_fn(self, fn: Callable) -> None: """ - Validate that vespa_query_fn is callable and has correct signature. + Simplified validation of vespa_query_fn. 
- :param fn: Function to validate - :raises ValueError: If function doesn't meet requirements - :raises TypeError: If function signature is incorrect + The function must be callable and take either 2 or 3 parameters: + - (query_text: str, top_k: int) or + - (query_text: str, top_k: int, query_id: str) where query_id can also be Optional[str]. + It must return a dict; that is checked when run() calls it, not during validation. """ if not callable(fn): - raise ValueError("vespa_query_fn must be a callable") + raise ValueError("vespa_query_fn must be callable") import inspect sig = inspect.signature(fn) - params = list(sig.parameters.items()) - - # Check number of parameters - if len(params) != 2: - raise TypeError( - f"vespa_query_fn must take exactly 2 parameters (query_text, top_k), got {len(params)}" - ) - - # Check parameter types from type hints - param_types = {name: param.annotation for name, param in params} - - expected_types = {params[0][0]: str, params[1][0]: int} - - for param_name, expected_type in expected_types.items(): - if param_types.get(param_name) not in ( - expected_type, - inspect.Parameter.empty, + params = list(sig.parameters.values()) + + if len(params) not in (2, 3): + raise TypeError("vespa_query_fn must take 2 or 3 parameters") + + # Validate first parameter: query_text + if ( + params[0].annotation is not inspect.Parameter.empty + and params[0].annotation is not str + ): + raise TypeError("Parameter 'query_text' must be of type str") + + # Validate second parameter: top_k + if ( + params[1].annotation is not inspect.Parameter.empty + and params[1].annotation is not int + ): + raise TypeError("Parameter 'top_k' must be of type int") + + # If there's a third parameter, validate query_id + if len(params) == 3: + third = params[2] + if ( + third.annotation is not inspect.Parameter.empty + and third.annotation not in (str, Optional[str]) ): raise TypeError( - f"Parameter '{param_name}' must be of type {expected_type.__name__}" + "Parameter 'query_id' must be of type str or 
Optional[str]" ) - - # Validate the function can actually be called with test inputs - try: - result = fn("test query", 10) - if not isinstance(result, dict): - raise TypeError("vespa_query_fn must return a dict") - except Exception as e: - raise ValueError(f"Error calling vespa_query_fn with test inputs: {str(e)}") + self._vespa_query_fn_takes_query_id = True + else: + self._vespa_query_fn_takes_query_id = False def _find_max_k(self): """ @@ -314,8 +357,15 @@ def run(self) -> Dict[str, float]: # Build query bodies using the provided vespa_query_fn query_bodies = [] - for query_text in self.queries: - query_body: dict = self.vespa_query_fn(query_text, max_k) + for qid, query_text in zip(self.queries_ids, self.queries): + if getattr(self, "_vespa_query_fn_takes_query_id", False): + query_body: dict = self.vespa_query_fn(query_text, max_k, qid) + else: + query_body: dict = self.vespa_query_fn(query_text, max_k) + if not isinstance(query_body, dict): + raise ValueError( + f"vespa_query_fn must return a dict, got: {type(query_body)}" + ) # Add default body parameters query_body.update(self.default_body) query_bodies.append(query_body) @@ -337,8 +387,19 @@ def run(self) -> Dict[str, float]: hits = resp.hits or [] top_hit_list = [] for hit in hits[:max_k]: - # doc_id extraction logic - doc_id = str(hit.get("id", "").split("::")[-1]) + # May be a Vespa internal id. + if self.id_field == "": + full_id = hit.get("id", "") + if ( + "::" not in full_id + ): # vespa internal id - eg. 
index:content/0/35c332d6bc52ae1f8378f7b3 + # Trying 'id' field as a fallback + doc_id = str(hit.get("fields", {}).get("id", "")) + else: + doc_id = full_id.split("::")[-1] + else: + # doc_id extraction logic + doc_id = str(hit.get("fields", {}).get(self.id_field, "")) if not doc_id: raise ValueError(f"Could not extract doc_id from hit: {hit}") score = float(hit.get("relevance", float("nan"))) @@ -347,7 +408,6 @@ def run(self) -> Dict[str, float]: top_hit_list.append((doc_id, score)) queries_result_list.append(top_hit_list) - metrics = self._compute_metrics(queries_result_list) searchtime_stats = self._calculate_searchtime_stats() metrics.update(searchtime_stats) @@ -379,6 +439,30 @@ def _calculate_searchtime_stats(self) -> Dict[str, float]: def _compute_metrics(self, queries_result_list): num_queries = len(queries_result_list) + # Infer graded relevance on the fly instead of storing as a class variable. + graded = bool(self.relevant_docs) and isinstance( + next(iter(self.relevant_docs.values())), dict + ) + + if graded: + ndcg_at_k_list = {k: [] for k in self.ndcg_at_k} + for query_idx, top_hits in enumerate(queries_result_list): + qid = self.queries_ids[query_idx] + # For graded relevance, self.relevant_docs[qid] is a dict: doc_id -> grade + relevant: Dict[str, float] = self.relevant_docs[qid] + for k_val in self.ndcg_at_k: + predicted_relevance = [ + relevant.get(doc_id, 0.0) for doc_id, _ in top_hits[:k_val] + ] + dcg_pred = self._dcg_at_k(predicted_relevance, k_val) + # Ideal ranking is the sorted graded scores in descending order + ideal_relevances = sorted(relevant.values(), reverse=True)[:k_val] + dcg_true = self._dcg_at_k(ideal_relevances, k_val) + ndcg_val = dcg_pred / dcg_true if dcg_true > 0 else 0.0 + ndcg_at_k_list[k_val].append(ndcg_val) + metrics = {f"ndcg@{k}": mean(ndcg_at_k_list[k]) for k in self.ndcg_at_k} + return metrics + num_hits_at_k = {k: 0 for k in self.accuracy_at_k} precision_at_k_list = {k: [] for k in self.precision_recall_at_k} 
recall_at_k_list = {k: [] for k in self.precision_recall_at_k} @@ -436,10 +520,21 @@ def _compute_metrics(self, queries_result_list): sum_precisions = 0.0 top_k_hits = top_hits[:k_val] for rank, (doc_id, _) in enumerate(top_k_hits, start=1): - if doc_id in relevant: + if isinstance(relevant, dict): + if doc_id in relevant: + num_correct += 1 + sum_precisions += ( + relevant[doc_id] / rank + ) # Use relevance score + elif doc_id in relevant: num_correct += 1 sum_precisions += num_correct / rank - denom = min(k_val, len(relevant)) + denom = min( + k_val, + len(relevant) + if isinstance(relevant, set) + else len(relevant.keys()), + ) avg_precision = sum_precisions / denom if denom > 0 else 0.0 map_at_k_list[k_val].append(avg_precision)