fix: failsafe for non-valid json and failed LLM calls #7723

Merged: 27 commits, merged on May 23, 2024

Changes from 1 commit

Commits (27)
a21b0c2  wip (davidsbatista, May 21, 2024)
91ad2ef  initial import (davidsbatista, May 21, 2024)
8746035  adding tests (davidsbatista, May 21, 2024)
3d16830  adding params (davidsbatista, May 21, 2024)
33dd22d  adding safeguards for nan in evaluators (davidsbatista, May 21, 2024)
7473d1f  adding docstrings (davidsbatista, May 22, 2024)
b2ff89a  fixing tests (davidsbatista, May 22, 2024)
75af5ff  Merge branch 'main' into failsafe-for-non-valid-JSON (davidsbatista, May 22, 2024)
860c2aa  removing unused imports (davidsbatista, May 22, 2024)
d502ed9  removing unused imports (davidsbatista, May 22, 2024)
2538ed3  removing unused imports (davidsbatista, May 22, 2024)
f5f3818  adding tests to context and faithfullness evaluators (davidsbatista, May 22, 2024)
a271db7  fixing docstrings (davidsbatista, May 22, 2024)
54a0146  nit (davidsbatista, May 22, 2024)
12164d8  removing unused imports (davidsbatista, May 22, 2024)
687312f  adding release notes (davidsbatista, May 22, 2024)
2b94818  attending PR comments (davidsbatista, May 22, 2024)
a2c69dd  fixing tests (davidsbatista, May 22, 2024)
e9497ec  fixing tests (davidsbatista, May 22, 2024)
f98930d  Merge branch 'main' into failsafe-for-non-valid-JSON (davidsbatista, May 22, 2024)
c0570ec  adding types (davidsbatista, May 22, 2024)
796588c  removing unused imports (davidsbatista, May 22, 2024)
50f6477  Update haystack/components/evaluators/context_relevance.py (davidsbatista, May 23, 2024)
8ce0c9d  Update haystack/components/evaluators/faithfulness.py (davidsbatista, May 23, 2024)
a7d7879  Merge branch 'main' into failsafe-for-non-valid-JSON (davidsbatista, May 23, 2024)
391e4fa  attending PR comments (davidsbatista, May 23, 2024)
a49fc65  Merge branch 'main' into failsafe-for-non-valid-JSON (davidsbatista, May 23, 2024)

adding safeguards for nan in evaluators
davidsbatista committed May 21, 2024
commit 33dd22dbb2c902884c3045a342a990342c80e3ae
6 changes: 5 additions & 1 deletion haystack/components/evaluators/context_relevance.py
@@ -4,6 +4,7 @@

 from typing import Any, Dict, List, Optional

+from numpy import isnan
 from numpy import mean as np_mean

 from haystack import default_from_dict
@@ -141,7 +142,10 @@ def run(self, questions: List[str], contexts: List[List[str]]) -> Dict[str, Any]
         result = super().run(questions=questions, contexts=contexts)

         # calculate average statement relevance score per query
-        for res in result["results"]:
+        for idx, res in enumerate(result["results"]):
+            if isinstance(res, float) and isnan(res):
+                result["results"][idx] = {"statements": [], "statement_scores": [], "score": 0}
+                continue
             if not res["statements"]:
                 res["score"] = 0
             else:

6 changes: 5 additions & 1 deletion haystack/components/evaluators/faithfulness.py
@@ -4,6 +4,7 @@

 from typing import Any, Dict, List, Optional

+from numpy import isnan
 from numpy import mean as np_mean

 from haystack import default_from_dict
@@ -159,7 +160,10 @@ def run(self, questions: List[str], contexts: List[List[str]], predicted_answers
         result = super().run(questions=questions, contexts=contexts, predicted_answers=predicted_answers)

         # calculate average statement faithfulness score per query
-        for res in result["results"]:
+        for idx, res in enumerate(result["results"]):
+            if isinstance(res, float) and isnan(res):
+                result["results"][idx] = {"statements": [], "statement_scores": [], "score": 0}
+                continue
             if not res["statements"]:
                 res["score"] = 0
             else:
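
Both evaluators apply the same safeguard. As a standalone illustration (not Haystack code, and the example data is made up): any NaN entry left behind by a failed LLM call or a non-valid JSON reply is replaced with an empty result scored 0 before the average score is computed.

from numpy import isnan
from numpy import mean as np_mean

# one entry per input question; the second LLM call failed or returned non-valid JSON
results = [
    {"statements": ["s1", "s2"], "statement_scores": [1, 0], "score": 0.5},
    float("nan"),
]

for idx, res in enumerate(results):
    if isinstance(res, float) and isnan(res):
        # same placeholder the evaluators insert in the diff above
        results[idx] = {"statements": [], "statement_scores": [], "score": 0}

print(np_mean([res["score"] for res in results]))  # 0.25 instead of a crash
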
45 changes: 30 additions & 15 deletions haystack/components/evaluators/llm_evaluator.py
@@ -79,7 +79,7 @@ def __init__(
             `outputs` parameters.
             Each example is a dictionary with keys "inputs" and "outputs"
             They contain the input and output as dictionaries respectively.
-        :param raises_on_failure:
+        :param raise_on_failure:
             If True, the component will raise an exception if the evaluation fails.
         :param api:
             The API to use for calling an LLM through a Generator.
@@ -170,6 +170,8 @@ def run(self, **inputs) -> Dict[str, Any]:
         """
         Run the LLM evaluator.

+        # ToDo: add more details about the behavior of this method and it's exceptions
+
         :param inputs:
             The input values to evaluate. The keys are the input names and the values are lists of input values.
         :returns:
@@ -187,13 +189,21 @@ def run(self, **inputs) -> Dict[str, Any]:
         results = []
         for input_names_to_values in tqdm(list_of_input_names_to_values, disable=not self.progress_bar):
             prompt = self.builder.run(**input_names_to_values)
-            result = self.generator.run(prompt=prompt["prompt"])

+            # ToDo: how to handle too large context
+            try:
+                result = self.generator.run(prompt=prompt["prompt"])
+            except Exception as e:
+                msg = f"Error while generating response for prompt: {prompt}. Error: {e}"
+                if self.raise_on_failure:
+                    raise ValueError(msg)
+                warn(msg)
+                results.append(np.nan)
+                continue

-            self.validate_outputs(expected=self.outputs, received=result["replies"][0])
-            parsed_result = json.loads(result["replies"][0])
-            results.append(parsed_result)
+            if self.is_valid_json(expected=self.outputs, received=result["replies"][0]):
+                parsed_result = json.loads(result["replies"][0])
+                results.append(parsed_result)
+            else:
+                results.append(np.nan)

         return {"results": results}

@@ -307,14 +317,14 @@ def validate_input_parameters(expected: Dict[str, Any], received: Dict[str, Any]
             )
             raise ValueError(msg)

-    def validate_outputs(self, expected: List[str], received: str) -> Optional[float]:
+    def is_valid_json(self, expected: List[str], received: str) -> bool:
         """
-        Validate the output.
+        Output must be a valid JSON with the expected keys.

-        If `raise_on_failure` is True, raise a ValueError if not all expected outputs are present in the received
-        outputs or if the received outputs are not a valid JSON.
-
-        If `raise_on_failure` is False, print a warning if the received outputs are not a valid JSON and return a `nan`.
+        If the output is not a valid JSON with the expected keys:
+        - with `raise_on_failure` set to True a ValueError is raised.
+        - with `raise_on_failure` set to False a warning is issued and False is returned.
+        If the output is a valid JSON with the expected keys, True is returned.

         :param expected:
             Names of expected outputs
@@ -323,6 +333,9 @@ def validate_outputs(self, expected: List[str], received: str) -> Optional[float

         :raises ValueError:
             If not all expected outputs are present in the received outputs
+
+        :returns:
+            True if the received output is a valid JSON with the expected keys, False otherwise.
         """
         try:
             parsed_output = json.loads(received)
@@ -332,11 +345,13 @@ def validate_outputs(self, expected: List[str], received: str) -> Optional[float
                 if self.raise_on_failure:
                     raise ValueError(msg)
                 warn(msg)
-                return np.nan
+                return False

         except json.JSONDecodeError:
             msg = "Response from LLM evaluator is not a valid JSON."
             if self.raise_on_failure:
                 raise ValueError(msg)
             warn(msg)
-            return np.nan
+            return False
+
+        return True
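
For reference, a usage sketch of how the failsafe surfaces to callers. The constructor arguments follow the LLMEvaluator docstring, raise_on_failure is the parameter this PR wires through, and an OpenAI key in the OPENAI_API_KEY environment variable is assumed; treat it as a sketch rather than a verbatim snippet from this commit.

from typing import List

from haystack.components.evaluators import LLMEvaluator

evaluator = LLMEvaluator(
    instructions="Is this answer problematic for children? Respond with JSON containing a 'score' key (0 or 1).",
    inputs=[("predicted_answers", List[str])],
    outputs=["score"],
    examples=[
        {"inputs": {"predicted_answers": "Football is the most popular sport."}, "outputs": {"score": 0}},
    ],
    raise_on_failure=False,  # failsafe mode: failures become NaN entries instead of exceptions
)

result = evaluator.run(predicted_answers=["Damn, this is straight outta hell!!!"])
# Each entry in result["results"] is either the parsed JSON reply, e.g. {"score": 1},
# or NaN if the LLM call failed or did not return valid JSON with the expected keys.
print(result["results"])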