Skip to content

Commit

Permalink
Merge pull request #38 from jimilp7/feature/progress-bar
Browse files Browse the repository at this point in the history
Issue#17 - Add Progress Bar to show progress on perturbations and scores
  • Loading branch information
iterix authored Oct 6, 2023
2 parents 3167a96 + b592ea4 commit 033718f
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 0 deletions.
9 changes: 9 additions & 0 deletions auditor/evaluation/discriminative.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
)
from auditor.reporting import generate_robustness_report
from auditor.utils.logging import get_logger
from auditor.utils.progress_logger import ProgressLogger

LOG = get_logger(__name__)

Expand Down Expand Up @@ -90,6 +91,12 @@ def evaluate(
f'Started model evaluation with perturbation type '
f'{self.perturbed_dataset.perturbation_type}'
)
progress_bar = ProgressLogger(
total_steps=min(len(self.perturbed_dataset.data),
len(self.perturbed_dataset.metadata)),
description="Starting Model Evaluation"
)

for perturbed_samples, metadata_samples in zip(
self.perturbed_dataset.data, self.perturbed_dataset.metadata
):
Expand Down Expand Up @@ -126,12 +133,14 @@ def evaluate(
metadata=mdata,
)
)
progress_bar.update()
robust_accuracy = self.compute_accuracy(test_results)
LOG.info(f'Robust Accuracy: {robust_accuracy*100.}')
LOG.info(
'Completed model evaluation with perturbation type '
f'{self.perturbed_dataset.perturbation_type}'
)
progress_bar.close()
self.test_results = TestSummary(
results=test_results,
robust_accuracy=robust_accuracy,
Expand Down
8 changes: 8 additions & 0 deletions auditor/evaluation/expected_behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from sentence_transformers.SentenceTransformer import SentenceTransformer

from auditor.utils.progress_logger import ProgressLogger
from auditor.utils.similarity import compute_similarity
from auditor.utils.logging import get_logger

Expand Down Expand Up @@ -163,6 +164,9 @@ def check(
reference_generation: str,
) -> List[Tuple[bool, Dict[str, float]]]:
test_results = []
progress_bar = ProgressLogger(total_steps=len(perturbed_generations),
description="Fetching Scores")

for peturbed_gen in perturbed_generations:
try:
score = compute_similarity(
Expand All @@ -178,9 +182,13 @@ def check(
self.similarity_metric_key: round(score, ndigits=2)
}
test_results.append((test_status, score_dict))
progress_bar.update()
except Exception as e:
LOG.error('Unable to complete semanatic similarity checks')
progress_bar.close()
raise e

progress_bar.close()
return test_results

def behavior_description(self):
Expand Down
7 changes: 7 additions & 0 deletions auditor/evaluation/generative.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from auditor.utils.logging import get_logger
from auditor.perturbations import Paraphrase
from auditor.perturbations import TransformBase
from auditor.utils.progress_logger import ProgressLogger

LOG = get_logger(__name__)

Expand Down Expand Up @@ -98,6 +99,9 @@ def _evaluate_generations(
else:
evaluate_prompts = prompt_perturbations

progress_bar = ProgressLogger(total_steps=len(evaluate_prompts),
description="Applying Perturbations")

# generations for each of the perturbed prompts
alternative_generations = []
for alt_prompt in evaluate_prompts:
Expand All @@ -107,6 +111,9 @@ def _evaluate_generations(
post_context,
)
alternative_generations.append(resp)
progress_bar.update()

progress_bar.close()

# create test result
metric = self.expected_behavior.check(
Expand Down
17 changes: 17 additions & 0 deletions auditor/utils/progress_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import tqdm


class ProgressLogger:
"""class to show progress bar"""

def __init__(self, total_steps, description="Logging..."):
self.total_steps = total_steps
self.description = description

self.pbar = tqdm.tqdm(total=total_steps, desc=description)

def update(self, incremental=1):
self.pbar.update(incremental)

def close(self):
self.pbar.close()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"langchain >=0.0.158",
"openai >=0.27.0",
"sentence-transformers>=2.2.2",
"tqdm>=4.66.1"
]

[project.license]
Expand Down

0 comments on commit 033718f

Please sign in to comment.