Skip to content

Commit

Permalink
kwargs for invoking LLMs
Browse files Browse the repository at this point in the history
  • Loading branch information
HuXiangkun committed Sep 2, 2024
1 parent 55c454b commit 9bafaaf
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ragchecker"
version = "0.1.2"
version = "0.1.3"
description = "RAGChecker: A Fine-grained Framework For Diagnosing Retrieval-Augmented Generation (RAG) systems."
authors = [
"Xiangkun Hu <[email protected]>",
Expand All @@ -15,7 +15,7 @@ license = "Apache-2.0"

[tool.poetry.dependencies]
python = "^3.9"
refchecker = "^0.2.3"
refchecker = "^0.2.4"
loguru = "^0.7"
dataclasses-json = "^0.6"

Expand Down
5 changes: 5 additions & 0 deletions ragchecker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def get_args():
help="Disable joint checking of the claims."
)
parser.set_defaults(joint_check=True)
parser.add_argument(
"--joint_check_num", type=int, default=5
)


return parser.parse_args()

Expand All @@ -75,6 +79,7 @@ def main():
batch_size_checker=args.batch_size_checker,
openai_api_key=args.openai_api_key,
joint_check=args.joint_check,
joint_check_num=args.joint_check_num
)
with open(args.input_path, "r") as f:
rag_results = RAGResults.from_json(f.read())
Expand Down
10 changes: 7 additions & 3 deletions ragchecker/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,15 @@ def __init__(
batch_size_checker=32,
openai_api_key=None,
joint_check=True,
joint_check_num=5
joint_check_num=5,
**kwargs
):
if openai_api_key:
os.environ['OPENAI_API_KEY'] = openai_api_key
self.extractor_max_new_tokens = extractor_max_new_tokens
self.joint_check = joint_check
self.joint_check_num = joint_check_num
self.kwargs = kwargs

self.extractor = LLMExtractor(
model=extractor_name,
Expand Down Expand Up @@ -102,7 +104,8 @@ def extract_claims(self, results: List[RAGResult], extract_type="gt_answer"):
extraction_results = self.extractor.extract(
batch_responses=texts,
batch_questions=questions,
max_new_tokens=self.extractor_max_new_tokens
max_new_tokens=self.extractor_max_new_tokens,
**self.kwargs
)
claims = [[c.content for c in res.claims] for res in extraction_results]
for i, result in enumerate(results):
Expand Down Expand Up @@ -161,7 +164,8 @@ def check_claims(self, results: RAGResults, check_type="answer2response"):
max_reference_segment_length=0,
merge_psg=merge_psg,
is_joint=self.joint_check,
joint_check_num=self.joint_check_num
joint_check_num=self.joint_check_num,
**self.kwargs
)
for i, result in enumerate(results):
if check_type == "answer2response":
Expand Down

0 comments on commit 9bafaaf

Please sign in to comment.