From 8d5685636ff84e29f0c1d3d6273f404b7206d0ec Mon Sep 17 00:00:00 2001 From: Fanyi Pu Date: Sun, 31 Mar 2024 00:15:47 +0800 Subject: [PATCH] Refactor get_question method in Context class --- lmms_eval/api/samplers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lmms_eval/api/samplers.py b/lmms_eval/api/samplers.py index 9278f9f1..f1c4fcf0 100644 --- a/lmms_eval/api/samplers.py +++ b/lmms_eval/api/samplers.py @@ -27,7 +27,8 @@ def __init__(self, task, few_shot_delimiter: str = "\n\n", target_delimiter: str self.contexts = [] def get_question(self, doc, model_specific_prompt_kwargs=None): - return self.doc_to_text(doc, model_specific_prompt_kwargs) if (self.doc_to_choice is None or type(self.doc_to_text(doc)) is str) else self.doc_to_choice(doc)[self.doc_to_text(doc)] + text = self.doc_to_text(doc, model_specific_prompt_kwargs) + return text if (self.doc_to_choice is None or isinstance(text, str)) else self.doc_to_choice(doc)[text] def get_target(self, doc): return (