From da2c480a2e3e6a938cdb05b9ba6266ec54760028 Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Mon, 4 Mar 2024 08:43:53 -0500 Subject: [PATCH 1/3] Add wandb logging to `BootstrapFewShot's` `metric_val` --- dspy/teleprompt/bootstrap.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index 6fc9dd2ef..3ad014efe 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -2,6 +2,7 @@ import threading import tqdm +import wandb import dsp import dspy @@ -29,7 +30,6 @@ # TODO: Add baselines=[...] - class BootstrapFewShot(Teleprompter): def __init__(self, metric=None, metric_threshold=None, teacher_settings={}, max_bootstrapped_demos=4, max_labeled_demos=16, max_rounds=1, max_errors=5): self.metric = metric @@ -43,13 +43,13 @@ def __init__(self, metric=None, metric_threshold=None, teacher_settings={}, max_ self.error_count = 0 self.error_lock = threading.Lock() - def compile(self, student, *, teacher=None, trainset, valset=None): + def compile(self, student, *, teacher=None, trainset, valset=None, wandb_enabled=False): self.trainset = trainset self.valset = valset self._prepare_student_and_teacher(student, teacher) self._prepare_predictor_mappings() - self._bootstrap() + self._bootstrap(wandb_enabled) self.student = self._train() self.student._compiled = True @@ -94,7 +94,8 @@ def _prepare_predictor_mappings(self): self.name2predictor = name2predictor self.predictor2name = predictor2name - def _bootstrap(self, *, max_bootstraps=None): + def _bootstrap(self, *, max_bootstraps=None, wandb_config=None): + wandb_enabled = True if wandb_config is not None else False max_bootstraps = max_bootstraps or self.max_bootstrapped_demos bootstrapped = {} @@ -106,7 +107,7 @@ def _bootstrap(self, *, max_bootstraps=None): break if example_idx not in bootstrapped: - success = self._bootstrap_one_example(example, round_idx) + success = self._bootstrap_one_example(example, round_idx, wandb_enabled) if success: bootstrapped[example_idx] = True @@ -124,7 +125,7 @@ def _bootstrap(self, *, max_bootstraps=None): # evaluate = Evaluate(program=self.teacher, metric=self.metric, num_threads=12) # score = evaluate(self.metric, display_table=False, display_progress=True) - def _bootstrap_one_example(self, example, round_idx=0): + def _bootstrap_one_example(self, example, round_idx=0, wandb_enabled=False): name2traces = self.name2traces teacher = self.teacher #.deepcopy() predictor_cache = {} @@ -148,6 +149,10 @@ def _bootstrap_one_example(self, example, round_idx=0): if self.metric: metric_val = self.metric(example, prediction, trace) + if wandb_enabled: + wandb.log({ + "metric_val": metric_val + }) if self.metric_threshold: success = metric_val >= self.metric_threshold else: From 67753218ecb581d21db25319a2116e13b5cd0bcc Mon Sep 17 00:00:00 2001 From: Connor Shorten Date: Mon, 4 Mar 2024 08:56:48 -0500 Subject: [PATCH 2/3] Whoops update --- dspy/teleprompt/bootstrap.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index 3ad014efe..ac687813e 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -49,7 +49,7 @@ def compile(self, student, *, teacher=None, trainset, valset=None, wandb_enabled self._prepare_student_and_teacher(student, teacher) self._prepare_predictor_mappings() - self._bootstrap(wandb_enabled) + self._bootstrap(wandb_enabled=wandb_enabled) self.student = self._train() self.student._compiled = True @@ -94,8 +94,7 @@ def _prepare_predictor_mappings(self): self.name2predictor = name2predictor self.predictor2name = predictor2name - def _bootstrap(self, *, max_bootstraps=None, wandb_config=None): - wandb_enabled = True if wandb_config is not None else False + def _bootstrap(self, *, max_bootstraps=None, wandb_enabled=False): max_bootstraps = max_bootstraps or self.max_bootstrapped_demos bootstrapped = {} From 3167807cfc75f6d7059ce57b213453337cdf7d2c Mon Sep 17 00:00:00 2001 From: CShorten Date: Mon, 4 Mar 2024 13:57:16 +0000 Subject: [PATCH 3/3] Automatic Style fixes --- dspy/teleprompt/bootstrap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/teleprompt/bootstrap.py b/dspy/teleprompt/bootstrap.py index ac687813e..76b5bf4d2 100644 --- a/dspy/teleprompt/bootstrap.py +++ b/dspy/teleprompt/bootstrap.py @@ -150,7 +150,7 @@ def _bootstrap_one_example(self, example, round_idx=0, wandb_enabled=False): metric_val = self.metric(example, prediction, trace) if wandb_enabled: wandb.log({ - "metric_val": metric_val + "metric_val": metric_val, }) if self.metric_threshold: success = metric_val >= self.metric_threshold