From 5ebc0e218528918d71d7effcf8bdd6450b1bd947 Mon Sep 17 00:00:00 2001 From: Shahules786 Date: Fri, 13 Dec 2024 17:18:27 +0530 Subject: [PATCH] fix: add reference to simple scoring --- src/ragas/metrics/_simple_criteria.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/ragas/metrics/_simple_criteria.py b/src/ragas/metrics/_simple_criteria.py index 141415228..c9977dbbc 100644 --- a/src/ragas/metrics/_simple_criteria.py +++ b/src/ragas/metrics/_simple_criteria.py @@ -49,9 +49,7 @@ class SingleTurnSimpleCriteriaInput(BaseModel): class MultiTurnSimpleCriteriaInput(BaseModel): - user_input: t.Optional[str] = Field( - description="The input to the model", default=None - ) + user_input: str = Field(description="The input to the model") reference: t.Optional[str] = Field( description="The reference response", default=None ) @@ -172,20 +170,18 @@ async def _single_turn_ascore( async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: assert self.llm is not None, "set LLM before use" - user_input, context, response = ( - row["user_input"], + user_input, response, retrieved_contexts, reference = ( + row.get("user_input"), + row.get("response"), row.get("retrieved_contexts"), - row["response"], + row.get("reference"), ) - if context is not None: - if isinstance(context, list): - context = "\n".join(context) - user_input = f"Question: {user_input} Answer using context: {context}" - prompt_input = SingleTurnSimpleCriteriaInput( user_input=user_input, response=response, + retrieved_contexts=retrieved_contexts, + reference=reference, ) response = await self.single_turn_prompt.generate( @@ -200,11 +196,11 @@ async def _multi_turn_ascore( self, sample: MultiTurnSample, callbacks: Callbacks ) -> float: assert self.llm is not None, "LLM is not set" - assert sample.reference is not None, "Reference is not set" interaction = sample.pretty_repr() prompt_input = MultiTurnSimpleCriteriaInput( user_input=interaction, + reference=sample.reference, ) response = await self.multi_turn_prompt.generate( data=prompt_input,