From 066bb89989eda07c214f094e1cd18a4b23139639 Mon Sep 17 00:00:00 2001 From: mertyg Date: Sun, 7 Jul 2024 10:57:37 -0700 Subject: [PATCH] better error handling in the optimizer --- textgrad/model.py | 4 +--- textgrad/optimizer/optimizer.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/textgrad/model.py b/textgrad/model.py index 882cf19..4e71d78 100644 --- a/textgrad/model.py +++ b/textgrad/model.py @@ -1,10 +1,8 @@ -from typing import List, Union -from functools import partial +from typing import Union from textgrad.variable import Variable from textgrad.autograd import LLMCall from textgrad.autograd.function import Module from textgrad.engine import EngineLM, get_engine -from textgrad.autograd.llm_ops import FormattedLLMCall from .config import SingletonBackwardEngine class BlackboxLLM(Module): diff --git a/textgrad/optimizer/optimizer.py b/textgrad/optimizer/optimizer.py index f3e3d8e..5ebec04 100644 --- a/textgrad/optimizer/optimizer.py +++ b/textgrad/optimizer/optimizer.py @@ -177,7 +177,13 @@ def step(self): prompt_update_parameter = self._update_prompt(parameter) new_text = self.engine(prompt_update_parameter, system_prompt=self.optimizer_system_prompt) logger.info(f"TextualGradientDescent optimizer response", extra={"optimizer.response": new_text}) - parameter.set_value(new_text.split(self.new_variable_tags[0])[1].split(self.new_variable_tags[1])[0].strip()) + try: + new_value = new_text.split(self.new_variable_tags[0])[1].split(self.new_variable_tags[1])[0].strip() + # Check if we got a cannot be indexed error + except IndexError: + logger.error(f"TextualGradientDescent optimizer response could not be indexed", extra={"optimizer.response": new_text}) + raise IndexError(f"TextualGradientDescent optimizer response could not be indexed. This can happen if the optimizer model cannot follow the instructions. You can try using a stronger model, or somehow reducing the context of the optimization. Response: {new_text}") + parameter.set_value(new_value) logger.info(f"TextualGradientDescent updated text", extra={"parameter.value": parameter.value}) if self.verbose: print("-----------------------TextualGradientDescent------------------------") @@ -263,5 +269,12 @@ def step(self): prompt_update_parameter = self._update_prompt(parameter, momentum_storage_idx=idx) new_text = self.engine(prompt_update_parameter, system_prompt=self.optimizer_system_prompt) logger.info(f"TextualGradientDescentwithMomentum optimizer response", extra={"optimizer.response": new_text}) - parameter.set_value(new_text.split(self.new_variable_tags[0])[1].split(self.new_variable_tags[1])[0].strip()) + try: + new_value = new_text.split(self.new_variable_tags[0])[1].split(self.new_variable_tags[1])[0].strip() + # Check if we got a cannot be indexed error + except IndexError: + logger.error(f"TextualGradientDescent optimizer response could not be indexed", extra={"optimizer.response": new_text}) + raise IndexError(f"TextualGradientDescent optimizer response could not be indexed. This can happen if the optimizer model cannot follow the instructions. You can try using a stronger model, or somehow reducing the context of the optimization. Response: {new_text}") + parameter.set_value(new_value) + logger.info(f"TextualGradientDescent updated text", extra={"parameter.value": parameter.value}) logger.info(f"TextualGradientDescentwithMomentum updated text", extra={"parameter.value": parameter.value})