diff --git a/auditor/evaluation/expected_behavior.py b/auditor/evaluation/expected_behavior.py index 119dcd2..7cc4a5e 100644 --- a/auditor/evaluation/expected_behavior.py +++ b/auditor/evaluation/expected_behavior.py @@ -11,6 +11,7 @@ from auditor.utils.progress_logger import ProgressLogger from auditor.utils.similarity import compute_similarity from auditor.utils.logging import get_logger +from auditor.utils.format import construct_llm_input FAILED_TEST = 0 PASSED_TEST = 1 @@ -228,6 +229,10 @@ def check( post_context: Optional[str], ) -> List[Tuple[bool, Dict[str, float]]]: test_results = [] + progress_bar = ProgressLogger( + total_steps=len(perturbed_generations), + description=f"Grading responses with {self.grading_model}" + ) for peturbed_gen in perturbed_generations: try: rationale, test_status = self._grade( @@ -241,9 +246,13 @@ def check( self.metric_key: rationale, } test_results.append((test_status, score_dict)) + progress_bar.update() except Exception as e: # LOG.error('Unable to complete semanatic similarity checks') + progress_bar.close() raise e + + progress_bar.close() return test_results def _grade( @@ -254,7 +263,11 @@ def _grade( pre_context: Optional[str], post_context: Optional[str], ): - query = pre_context + prompt + post_context + query = construct_llm_input( + prompt=prompt, + pre_context=pre_context, + post_context=post_context, + ) grading_str = ( f'Given the following context and question are the following two answers factually same?' # noqa: E501 f'If the reponses provide different details when asked a question they must be flagged as different.\n' # noqa: E501 diff --git a/auditor/utils/format.py b/auditor/utils/format.py new file mode 100644 index 0000000..38b7b37 --- /dev/null +++ b/auditor/utils/format.py @@ -0,0 +1,15 @@ +from typing import Optional + +def construct_llm_input( + prompt: str, + pre_context: Optional[str], + post_context: Optional[str], + delimiter: str = " ", + ) -> str: + if pre_context is not None: + full_prompt = pre_context + delimiter + prompt + else: + full_prompt = prompt + if post_context is not None: + full_prompt += delimiter + post_context + return full_prompt \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index fb20bfc..1bb3865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "fiddler-auditor" -version = "0.0.3" +version = "0.0.4.rc0" authors = [ { name="Fiddler Labs", email="support@fiddler.ai" }, ]