fix: failsafe for non-valid json and failed LLM calls #7723

Merged: 27 commits, May 23, 2024

Commits (27)
a21b0c2
wip
davidsbatista May 21, 2024
91ad2ef
initial import
davidsbatista May 21, 2024
8746035
adding tests
davidsbatista May 21, 2024
3d16830
adding params
davidsbatista May 21, 2024
33dd22d
adding safeguards for nan in evaluators
davidsbatista May 21, 2024
7473d1f
adding docstrings
davidsbatista May 22, 2024
b2ff89a
fixing tests
davidsbatista May 22, 2024
75af5ff
Merge branch 'main' into failsafe-for-non-valid-JSON
davidsbatista May 22, 2024
860c2aa
removing unused imports
davidsbatista May 22, 2024
d502ed9
removing unused imports
davidsbatista May 22, 2024
2538ed3
removing unused imports
davidsbatista May 22, 2024
f5f3818
adding tests to context and faithfullness evaluators
davidsbatista May 22, 2024
a271db7
fixing docstrings
davidsbatista May 22, 2024
54a0146
nit
davidsbatista May 22, 2024
12164d8
removing unused imports
davidsbatista May 22, 2024
687312f
adding release notes
davidsbatista May 22, 2024
2b94818
attending PR comments
davidsbatista May 22, 2024
a2c69dd
fixing tests
davidsbatista May 22, 2024
e9497ec
fixing tests
davidsbatista May 22, 2024
f98930d
Merge branch 'main' into failsafe-for-non-valid-JSON
davidsbatista May 22, 2024
c0570ec
adding types
davidsbatista May 22, 2024
796588c
removing unused imports
davidsbatista May 22, 2024
50f6477
Update haystack/components/evaluators/context_relevance.py
davidsbatista May 23, 2024
8ce0c9d
Update haystack/components/evaluators/faithfulness.py
davidsbatista May 23, 2024
a7d7879
Merge branch 'main' into failsafe-for-non-valid-JSON
davidsbatista May 23, 2024
391e4fa
attending PR comments
davidsbatista May 23, 2024
a49fc65
Merge branch 'main' into failsafe-for-non-valid-JSON
davidsbatista May 23, 2024
10 changes: 9 additions & 1 deletion haystack/components/evaluators/context_relevance.py
@@ -70,6 +70,7 @@ def __init__(
progress_bar: bool = True,
api: str = "openai",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
raise_on_failure: bool = True,
):
"""
Creates an instance of ContextRelevanceEvaluator.
@@ -97,6 +98,9 @@ def __init__(
Supported APIs: "openai".
:param api_key:
The API key.
:param raise_on_failure:
Whether to raise an exception if the API call fails.

"""
self.instructions = (
"Your task is to judge how relevant the provided context is for answering a question. "
@@ -117,6 +121,7 @@ def __init__(
examples=self.examples,
api=self.api,
api_key=self.api_key,
raise_on_failure=raise_on_failure,
progress_bar=progress_bar,
)

@@ -138,7 +143,10 @@ def run(self, questions: List[str], contexts: List[List[str]]) -> Dict[str, Any]:
result = super().run(questions=questions, contexts=contexts)

# calculate average statement relevance score per query
for res in result["results"]:
for idx, res in enumerate(result["results"]):
if res is None:
result["results"][idx] = {"statements": [], "statement_scores": [], "score": float("nan")}
continue
if not res["statements"]:
res["score"] = 0
else:
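
For orientation, here is a minimal usage sketch of how the new fallback surfaces to callers (illustrative only; it assumes a valid OPENAI_API_KEY and made-up inputs). With raise_on_failure=False, every sample whose LLM call fails ends up in results as {"statements": [], "statement_scores": [], "score": NaN}, so callers can filter out NaN scores before aggregating. The FaithfulnessEvaluator change below follows the same pattern.

import math

from haystack.components.evaluators import ContextRelevanceEvaluator

# With raise_on_failure=False, failed LLM calls or invalid JSON replies yield
# NaN-scored entries instead of raising, so the evaluation run still completes.
evaluator = ContextRelevanceEvaluator(raise_on_failure=False)
result = evaluator.run(
    questions=["Who created the Python language?"],
    contexts=[["Python was created by Guido van Rossum in the late 1980s."]],
)

# Keep only the samples that were scored successfully.
valid = [r for r in result["results"] if not math.isnan(r["score"])]
print(f"{len(valid)} of {len(result['results'])} samples scored successfully")
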
9 changes: 8 additions & 1 deletion haystack/components/evaluators/faithfulness.py
@@ -84,6 +84,7 @@ def __init__(
progress_bar: bool = True,
api: str = "openai",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
raise_on_failure: bool = True,
):
"""
Creates an instance of FaithfulnessEvaluator.
@@ -112,6 +113,8 @@ def __init__(
Supported APIs: "openai".
:param api_key:
The API key.
:param raise_on_failure:
Whether to raise an exception if the API call fails.

"""
self.instructions = (
@@ -134,6 +137,7 @@ def __init__(
examples=self.examples,
api=self.api,
api_key=self.api_key,
raise_on_failure=raise_on_failure,
progress_bar=progress_bar,
)

@@ -157,7 +161,10 @@ def run(self, questions: List[str], contexts: List[List[str]], predicted_answers
result = super().run(questions=questions, contexts=contexts, predicted_answers=predicted_answers)

# calculate average statement faithfulness score per query
for res in result["results"]:
for idx, res in enumerate(result["results"]):
if res is None:
result["results"][idx] = {"statements": [], "statement_scores": [], "score": float("nan")}
continue
if not res["statements"]:
res["score"] = 0
else:
68 changes: 54 additions & 14 deletions haystack/components/evaluators/llm_evaluator.py
@@ -3,7 +3,8 @@
# SPDX-License-Identifier: Apache-2.0

import json
from typing import Any, Dict, List, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type
from warnings import warn

from tqdm import tqdm

@@ -54,6 +55,7 @@ def __init__(
examples: List[Dict[str, Any]],
progress_bar: bool = True,
*,
raise_on_failure: bool = True,
api: str = "openai",
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
):
@@ -73,6 +75,8 @@
`outputs` parameters.
Each example is a dictionary with keys "inputs" and "outputs"
They contain the input and output as dictionaries respectively.
:param raise_on_failure:
If True, the component will raise an exception on an unsuccessful API call.
:param progress_bar:
Whether to show a progress bar during the evaluation.
:param api:
@@ -83,6 +87,7 @@

"""
self.validate_init_parameters(inputs, outputs, examples)
self.raise_on_failure = raise_on_failure
self.instructions = instructions
self.inputs = inputs
self.outputs = outputs
@@ -168,7 +173,8 @@ def run(self, **inputs) -> Dict[str, Any]:
:returns:
A dictionary with a single `results` entry that contains a list of results.
Each result is a dictionary containing the keys as defined in the `outputs` parameter of the LLMEvaluator
and the evaluation results as the values.
and the evaluation results as the values. If an exception occurs for a particular input value, the result
will be `None` for that entry.
"""
self.validate_input_parameters(dict(self.inputs), inputs)

Expand All @@ -177,14 +183,31 @@ def run(self, **inputs) -> Dict[str, Any]:
input_names, values = inputs.keys(), list(zip(*inputs.values()))
list_of_input_names_to_values = [dict(zip(input_names, v)) for v in values]

results = []
results: List[Optional[Dict[str, Any]]] = []
errors = 0
for input_names_to_values in tqdm(list_of_input_names_to_values, disable=not self.progress_bar):
prompt = self.builder.run(**input_names_to_values)
result = self.generator.run(prompt=prompt["prompt"])

self.validate_outputs(expected=self.outputs, received=result["replies"][0])
parsed_result = json.loads(result["replies"][0])
results.append(parsed_result)
try:
result = self.generator.run(prompt=prompt["prompt"])
except Exception as e:
msg = f"Error while generating response for prompt: {prompt}. Error: {e}"
if self.raise_on_failure:
raise ValueError(msg)
warn(msg)
results.append(None)
errors += 1
continue

if self.is_valid_json_and_has_expected_keys(expected=self.outputs, received=result["replies"][0]):
parsed_result = json.loads(result["replies"][0])
results.append(parsed_result)
else:
results.append(None)
errors += 1

if errors > 0:
msg = f"LLM evaluator failed for {errors} out of {len(list_of_input_names_to_values)} inputs."
warn(msg)

return {"results": results}

@@ -299,20 +322,37 @@ def validate_input_parameters(expected: Dict[str, Any], received: Dict[str, Any]
)
raise ValueError(msg)

@staticmethod
def validate_outputs(expected: List[str], received: str) -> None:
def is_valid_json_and_has_expected_keys(self, expected: List[str], received: str) -> bool:
"""
Validate the output.
Output must be a valid JSON with the expected keys.

:param expected:
Names of expected outputs
:param received:
Names of received outputs

:raises ValueError:
If not all expected outputs are present in the received outputs
If the output is not a valid JSON with the expected keys:
- with `raise_on_failure` set to True a ValueError is raised.
- with `raise_on_failure` set to False a warning is issued and False is returned.

:returns:
True if the received output is a valid JSON with the expected keys, False otherwise.
"""
parsed_output = json.loads(received)
try:
parsed_output = json.loads(received)
except json.JSONDecodeError:
msg = "Response from LLM evaluator is not a valid JSON."
if self.raise_on_failure:
raise ValueError(msg)
warn(msg)
return False

if not all(output in parsed_output for output in expected):
msg = f"Expected response from LLM evaluator to be JSON with keys {expected}, got {received}."
raise ValueError(msg)
if self.raise_on_failure:
raise ValueError(msg)
warn(msg)
return False

return True
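
To make the control flow above easier to follow, the snippet below distils the failure-handling pattern into a standalone sketch (an illustration with assumed names such as evaluate_safely and call_llm, not the actual Haystack implementation): each input either yields a parsed dict or None, failures are counted, and an exception is only propagated when raise_on_failure is True.

import json
from typing import Any, Callable, Dict, List, Optional
from warnings import warn


def evaluate_safely(
    prompts: List[str],
    call_llm: Callable[[str], str],
    expected_keys: List[str],
    raise_on_failure: bool = True,
) -> List[Optional[Dict[str, Any]]]:
    results: List[Optional[Dict[str, Any]]] = []
    errors = 0
    for prompt in prompts:
        try:
            reply = call_llm(prompt)    # may raise on a failed API call
            parsed = json.loads(reply)  # may raise on an invalid JSON reply
            if not all(key in parsed for key in expected_keys):
                raise ValueError(f"expected keys {expected_keys}, got {reply}")
            results.append(parsed)
        except Exception as exc:
            if raise_on_failure:
                raise
            warn(f"Evaluation failed for one input: {exc}")
            results.append(None)
            errors += 1
    if errors:
        warn(f"Evaluator failed for {errors} out of {len(prompts)} inputs.")
    return results
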
@@ -0,0 +1,5 @@
---
enhancements:
- |
If an LLM-based evaluator (e.g., `Faithfulness` or `ContextRelevance`) is initialised with `raise_on_failure=False`, and if a call to an LLM fails or an LLM outputs an invalid JSON, the score of the sample is set to `NaN` instead of raising an exception.
The user is notified with a warning indicating the number of requests that failed.
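
As a hedged illustration of the behaviour described in this note (assuming a valid OPENAI_API_KEY; the inputs are made up), a caller opts in to the failsafe and then checks the aggregate score for NaN before trusting it:

import math
import warnings

from haystack.components.evaluators import FaithfulnessEvaluator

evaluator = FaithfulnessEvaluator(raise_on_failure=False)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = evaluator.run(
        questions=["Who created the Python language?"],
        contexts=[["Python was created by Guido van Rossum in the late 1980s."]],
        predicted_answers=["Guido van Rossum."],
    )

# A NaN aggregate score means at least one sample failed; the recorded
# warnings report how many requests failed.
if math.isnan(result["score"]):
    print("Some samples could not be scored:", [str(w.message) for w in caught])
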
41 changes: 41 additions & 0 deletions test/components/evaluators/test_context_relevance_evaluator.py
@@ -4,6 +4,8 @@
import os
from typing import List

import math

import pytest

from haystack.components.evaluators import ContextRelevanceEvaluator
@@ -159,6 +161,45 @@ def test_run_missing_parameters(self, monkeypatch):
with pytest.raises(TypeError, match="missing 2 required positional arguments"):
component.run()

def test_run_handles_nan(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = ContextRelevanceEvaluator(raise_on_failure=False)

def generator_run(self, *args, **kwargs):
if "Python" in kwargs["prompt"]:
raise Exception("OpenAI API request failed.")
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}

monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
[
"The popularity of sports can be measured in various ways, including TV viewership, social media "
"presence, number of participants, and economic impact. Football is undoubtedly the world's most "
"popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and "
"Messi, drawing a followership of more than 4 billion people."
],
[
"Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
"language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
"programmers write clear, logical code for both small and large-scale software projects."
],
]
results = component.run(questions=questions, contexts=contexts)

assert math.isnan(results["score"])

assert results["individual_scores"][0] == 1.0
assert math.isnan(results["individual_scores"][1])

assert results["results"][0] == {"statements": ["c", "d"], "statement_scores": [1, 1], "score": 1.0}

assert results["results"][1]["statements"] == []
assert results["results"][1]["statement_scores"] == []
assert math.isnan(results["results"][1]["score"])

@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
45 changes: 45 additions & 0 deletions test/components/evaluators/test_faithfulness_evaluator.py
@@ -2,8 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0
import os
import math
from typing import List

import numpy as np
import pytest

from haystack.components.evaluators import FaithfulnessEvaluator
@@ -191,6 +193,49 @@ def test_run_missing_parameters(self, monkeypatch):
with pytest.raises(TypeError, match="missing 3 required positional arguments"):
component.run()

def test_run_handles_nan(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = FaithfulnessEvaluator(raise_on_failure=False)

def generator_run(self, *args, **kwargs):
if "Python" in kwargs["prompt"]:
raise Exception("OpenAI API request failed.")
else:
return {"replies": ['{"statements": ["c", "d"], "statement_scores": [1, 1]}']}

monkeypatch.setattr("haystack.components.generators.openai.OpenAIGenerator.run", generator_run)

questions = ["Which is the most popular global sport?", "Who created the Python language?"]
contexts = [
[
"The popularity of sports can be measured in various ways, including TV viewership, social media "
"presence, number of participants, and economic impact. Football is undoubtedly the world's most "
"popular sport with major events like the FIFA World Cup and sports personalities like Ronaldo and "
"Messi, drawing a followership of more than 4 billion people."
],
[
"Python, created by Guido van Rossum in the late 1980s, is a high-level general-purpose programming "
"language. Its design philosophy emphasizes code readability, and its language constructs aim to help "
"programmers write clear, logical code for both small and large-scale software projects."
],
]
predicted_answers = [
"Football is the most popular sport with around 4 billion followers worldwide.",
"Guido van Rossum.",
]
results = component.run(questions=questions, contexts=contexts, predicted_answers=predicted_answers)

assert math.isnan(results["score"])

assert results["individual_scores"][0] == 1.0
assert math.isnan(results["individual_scores"][1])

assert results["results"][0] == {"statements": ["c", "d"], "statement_scores": [1, 1], "score": 1.0}

assert results["results"][1]["statements"] == []
assert results["results"][1]["statement_scores"] == []
assert math.isnan(results["results"][1]["score"])

@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
36 changes: 34 additions & 2 deletions test/components/evaluators/test_llm_evaluator.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List

import numpy as np
import pytest

from haystack.components.evaluators import LLMEvaluator
@@ -379,10 +380,41 @@ def test_invalid_outputs(self, monkeypatch):
],
)
with pytest.raises(ValueError):
component.validate_outputs(expected=["score", "another_expected_output"], received='{"score": 1.0}')
component.is_valid_json_and_has_expected_keys(
expected=["score", "another_expected_output"], received='{"score": 1.0}'
)

with pytest.raises(ValueError):
component.is_valid_json_and_has_expected_keys(expected=["score"], received='{"wrong_name": 1.0}')

def test_output_invalid_json_raise_on_failure_false(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = LLMEvaluator(
instructions="test-instruction",
inputs=[("predicted_answers", List[str])],
outputs=["score"],
examples=[
{"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
],
raise_on_failure=False,
)
assert (
component.is_valid_json_and_has_expected_keys(expected=["score"], received="some_invalid_json_output")
is False
)

def test_output_invalid_json_raise_on_failure_true(self, monkeypatch):
monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
component = LLMEvaluator(
instructions="test-instruction",
inputs=[("predicted_answers", List[str])],
outputs=["score"],
examples=[
{"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}}
],
)
with pytest.raises(ValueError):
component.validate_outputs(expected=["score"], received='{"wrong_name": 1.0}')
component.is_valid_json_and_has_expected_keys(expected=["score"], received="some_invalid_json_output")

def test_unsupported_api(self):
with pytest.raises(ValueError):