diff --git a/src/ragas/llms/json_load.py b/src/ragas/llms/json_load.py
index 29197cbbb..509b31217 100644
--- a/src/ragas/llms/json_load.py
+++ b/src/ragas/llms/json_load.py
@@ -83,8 +83,8 @@ def _safe_load(self, text: str, llm: BaseRagasLLM, callbacks: Callbacks = None):
         retry = 0
         while retry <= self.max_retries:
             try:
-                start, end = self._find_outermost_json(text)
-                return json.loads(text[start:end])
+                _json = self._load_all_jsons(text)
+                return _json[0] if len(_json) == 1 else _json
             except ValueError:
                 from ragas.llms.prompt import PromptValue
@@ -104,8 +104,8 @@ async def _asafe_load(
         retry = 0
         while retry <= self.max_retries:
             try:
-                start, end = self._find_outermost_json(text)
-                return json.loads(text[start:end])
+                _json = self._load_all_jsons(text)
+                return _json[0] if len(_json) == 1 else _json
             except ValueError:
                 from ragas.llms.prompt import PromptValue
@@ -126,7 +126,7 @@ async def safe_load(
         callbacks: Callbacks = None,
         is_async: bool = True,
         run_config: RunConfig = RunConfig(),
-    ):
+    ) -> t.Union[t.Dict, t.List]:
         if is_async:
             _asafe_load_with_retry = add_async_retry(self._asafe_load, run_config)
             return await _asafe_load_with_retry(text=text, llm=llm, callbacks=callbacks)
@@ -141,6 +141,16 @@ async def safe_load(
                 safe_load,
             )
 
+    def _load_all_jsons(self, text):
+        start, end = self._find_outermost_json(text)
+        _json = json.loads(text[start:end])
+        text = text.replace(text[start:end], "", 1)
+        start, end = self._find_outermost_json(text)
+        if (start, end) == (-1, -1):
+            return [_json]
+        else:
+            return [_json] + self._load_all_jsons(text)
+
     def _find_outermost_json(self, text):
         stack = []
         start_index = -1
diff --git a/src/ragas/metrics/_context_precision.py b/src/ragas/metrics/_context_precision.py
index d3e6c0d57..221c5d8ea 100644
--- a/src/ragas/metrics/_context_precision.py
+++ b/src/ragas/metrics/_context_precision.py
@@ -138,6 +138,7 @@ async def _ascore(
             await json_loader.safe_load(item, self.llm, is_async=is_async)
             for item in responses
         ]
+        json_responses = t.cast(t.List[t.Dict], json_responses)
         score = self._calculate_average_precision(json_responses)
         return score
diff --git a/src/ragas/metrics/_faithfulness.py b/src/ragas/metrics/_faithfulness.py
index 0a38f10ad..9e10c54af 100644
--- a/src/ragas/metrics/_faithfulness.py
+++ b/src/ragas/metrics/_faithfulness.py
@@ -187,6 +187,7 @@ async def _ascore(
             is_async=is_async,
         )
 
+        assert isinstance(statements, dict), "Invalid JSON response"
         p = self._create_nli_prompt(row, statements.get("statements", []))
         nli_result = await self.llm.generate(p, callbacks=callbacks, is_async=is_async)
         json_output = await json_loader.safe_load(