fix(testset generation) : Improvements (#195)
shahules786 authored Oct 18, 2023
1 parent 462247b commit 2d42d69
Showing 5 changed files with 57 additions and 16 deletions.
1 change: 1 addition & 0 deletions requirements/test.txt
@@ -1,2 +1,3 @@
pytest
pytest-xdist[psutil]
llama_index
5 changes: 2 additions & 3 deletions src/ragas/testset/prompts.py
@@ -120,11 +120,10 @@

FILTER_QUESTION = HumanMessagePromptTemplate.from_template(
"""\
Determine if the given question can be clearly understood even when presented without any additional context? Reason before arriving at the answer.
Determine if the given question can be clearly understood even when presented without any additional context. Specify reason and verdict is a valid json format.
question: What is the keyword that best describes the paper's focus in natural language understanding tasks?
answer: The specific paper being referred to is not mentioned in the question. Hence, No.
{{"reason":"The specific paper being referred to is not mentioned in the question.", "verdict": "No"}}
question:{question}
answer:
""" # noqa: E501
)

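Note on the prompts.py change above: the critic is now asked to reply with a small JSON object instead of a free-form sentence ending in "Yes."/"No.". A minimal sketch of the expected reply shape, taken directly from the one-shot example in the template (the downstream parsing is handled by load_as_json, added in utils.py further below):

    {"reason": "The specific paper being referred to is not mentioned in the question.", "verdict": "No"}

The question is kept unless the verdict field is exactly "No"; see _filter_question in the next file.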
34 changes: 21 additions & 13 deletions src/ragas/testset/testset_generator.py
@@ -1,4 +1,5 @@
import re
from __future__ import annotations

import typing as t
import warnings
from collections import defaultdict, namedtuple
@@ -33,6 +34,7 @@
SCORE_CONTEXT,
SEED_QUESTION,
)
from ragas.testset.utils import load_as_json, load_as_score

DEFAULT_TEST_DISTRIBUTION = {
"simple": 0.4,
@@ -138,6 +140,7 @@ def from_default(
openai_filter_llm: str = "gpt-4",
chat_qa: float = 0.3,
chunk_size: int = 512,
testset_distribution: dict = DEFAULT_TEST_DISTRIBUTION,
):
generator_llm = ChatOpenAI(model=openai_generator_llm)
critic_llm = ChatOpenAI(model=openai_filter_llm)
@@ -148,6 +151,7 @@ def from_default(
embeddings_model=embeddings_model,
chat_qa=chat_qa,
chunk_size=chunk_size,
testset_distribution=testset_distribution,
)

def _get_evolve_type(self) -> str:
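The two hunks above add a testset_distribution parameter to from_default and forward it to the constructor. A minimal usage sketch, assuming the dict keys mirror the evolution types handled in generate(); only "simple" is visible in this diff and "multi_context" appears below, so the keys and weights here are illustrative:

    from ragas.testset.testset_generator import TestsetGenerator

    generator = TestsetGenerator.from_default(
        openai_filter_llm="gpt-4",      # default shown in the diff
        chat_qa=0.3,
        chunk_size=512,
        testset_distribution={"simple": 0.5, "multi_context": 0.5},  # hypothetical weights
    )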
@@ -175,12 +179,7 @@ def _filter_context(self, context: str) -> bool:
prompt = ChatPromptTemplate.from_messages([human_prompt])
results = generate(prompts=[prompt], llm=self.critic_llm)
output = results.generations[0][0].text.strip()
pattern = r"^[\d.]+$"
if not re.match(pattern, output):
score = 0.0
else:
score = eval(output)

score = load_as_score(output)
return score >= self.threshold

def _seed_question(self, context: str) -> str:
@@ -193,7 +192,9 @@ def _filter_question(self, question: str) -> bool:
human_prompt = FILTER_QUESTION.format(question=question)
prompt = ChatPromptTemplate.from_messages([human_prompt])
results = generate(prompts=[prompt], llm=self.critic_llm)
return bool(results.generations[0][0].text.strip().endswith("Yes."))
results = results.generations[0][0].text.strip()
json_results = load_as_json(results)
return json_results.get("verdict") != "No"

def _reasoning_question(self, question: str, context: str) -> str:
return self._qc_template(REASONING_QUESTION, question, context)
@@ -320,6 +321,9 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
if not score:
continue
seed_question = self._seed_question(text_chunk)
is_valid_question = self._filter_question(seed_question)
if not is_valid_question:
continue

if evolve_type == "multi_context":
# Find most similar chunk in same document
@@ -361,10 +365,14 @@ def generate(self, documents: t.List[Document], test_size: int) -> TestDataset:
else:
question = self._compress_question(question=question)

context = self._generate_context(question, text_chunk)
answer = self._generate_answer(question, context)
samples.append(DataRow(question.split("\n"), context, answer, evolve_type))
count += 1
pbar.update(count)
is_valid_question = self._filter_question(question)
if is_valid_question:
context = self._generate_context(question, text_chunk)
answer = self._generate_answer(question, context)
samples.append(
DataRow(question.split("\n"), context, answer, evolve_type)
)
count += 1
pbar.update(count)

return TestDataset(test_data=samples)
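One behavioral note on the new filtering in generate(): both the seed question and the evolved question now pass through _filter_question, and a question is dropped only when the critic explicitly answers "No". A small sketch of that check, assuming load_as_json returns an empty dict on malformed output (as defined in utils.py below), so unparsable critic replies fall through as valid; the helper name is_valid is hypothetical:

    from ragas.testset.utils import load_as_json

    def is_valid(critic_output: str) -> bool:
        # Mirrors the check in _filter_question: anything other than an explicit "No" passes,
        # including the {} that load_as_json returns when the reply is not valid JSON.
        verdict = load_as_json(critic_output)
        return verdict.get("verdict") != "No"

    is_valid('{"reason": "self-contained", "verdict": "Yes"}')  # True
    is_valid("not json at all")                                 # True (warns "Invalid json")
    is_valid('{"verdict": "No"}')                               # False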
31 changes: 31 additions & 0 deletions src/ragas/testset/utils.py
@@ -0,0 +1,31 @@
import json
import re
import warnings


def load_as_json(text):
"""
validate and return given text as json
"""

try:
return json.loads(text)
except ValueError:
warnings.warn("Invalid json")

return {}


def load_as_score(text):
"""
validate and returns given text as score
"""

pattern = r"^[\d.]+$"
if not re.match(pattern, text):
warnings.warn("Invalid score")
score = 0.0
else:
score = eval(text)

return score
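A short usage sketch of the two new helpers, with results that follow directly from the definitions above:

    from ragas.testset.utils import load_as_json, load_as_score

    load_as_json('{"reason": "ok", "verdict": "Yes"}')  # {'reason': 'ok', 'verdict': 'Yes'}
    load_as_json("no braces here")                      # warns "Invalid json", returns {}

    load_as_score("8.5")          # 8.5
    load_as_score("score: 8.5")   # warns "Invalid score", returns 0.0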
2 changes: 2 additions & 0 deletions tests/unit/test_simple.py
@@ -1,4 +1,6 @@
def test_import():
import ragas
from ragas.testset.testset_generator import TestsetGenerator

assert TestsetGenerator is not None
assert ragas is not None
