From e9c1240e354187722936af22d2d61abc58d3bbfa Mon Sep 17 00:00:00 2001 From: mertyg Date: Fri, 5 Jul 2024 00:12:53 +0000 Subject: [PATCH] textgrad vision integration init --- textgrad/autograd/__init__.py | 1 + textgrad/autograd/llm_ops.py | 6 +- .../autograd/multimodal_backward_prompts.py | 6 + textgrad/autograd/multimodal_ops.py | 226 ++++++++++++++++++ textgrad/config.py | 11 + textgrad/engine/__init__.py | 22 +- textgrad/engine/anthropic.py | 70 +++++- textgrad/engine/base.py | 6 +- textgrad/engine/openai.py | 75 ++++-- textgrad/loss.py | 37 ++- textgrad/optimizer/optimizer.py | 86 +++++-- textgrad/optimizer/optimizer_prompts.py | 46 +++- textgrad/utils/image_utils.py | 38 +++ textgrad/variable.py | 35 +-- 14 files changed, 580 insertions(+), 85 deletions(-) create mode 100644 textgrad/autograd/multimodal_backward_prompts.py create mode 100644 textgrad/autograd/multimodal_ops.py create mode 100644 textgrad/utils/image_utils.py diff --git a/textgrad/autograd/__init__.py b/textgrad/autograd/__init__.py index 05f9700..619bcf7 100644 --- a/textgrad/autograd/__init__.py +++ b/textgrad/autograd/__init__.py @@ -1,4 +1,5 @@ from .functional import sum, aggregate from .llm_ops import LLMCall, FormattedLLMCall, LLMCall_with_in_context_examples +from .multimodal_ops import MultimodalLLMCall, OrderedFieldsMultimodalLLMCall from .function import Module from .string_based_ops import StringBasedFunction \ No newline at end of file diff --git a/textgrad/autograd/llm_ops.py b/textgrad/autograd/llm_ops.py index 0236130..c09b49e 100644 --- a/textgrad/autograd/llm_ops.py +++ b/textgrad/autograd/llm_ops.py @@ -3,6 +3,7 @@ VARIABLE_OUTPUT_DEFAULT_ROLE) from textgrad.variable import Variable from textgrad.engine import EngineLM +from textgrad.config import validate_engine_or_get_default from typing import List from .llm_backward_prompts import ( EVALUATE_VARIABLE_INSTRUCTION, @@ -23,13 +24,11 @@ def __init__(self, engine: EngineLM, system_prompt: Variable = None): :param engine: engine to use for the LLM call :type engine: EngineLM - :param input_role_description: role description for the input variable, defaults to VARIABLE_INPUT_DEFAULT_ROLE - :type input_role_description: str, optional :param system_prompt: system prompt to use for the LLM call, default depends on the engine. :type system_prompt: Variable, optional """ super().__init__() - self.engine = engine + self.engine = validate_engine_or_get_default(engine) self.system_prompt = system_prompt if self.system_prompt and self.system_prompt.get_role_description() is None: self.system_prompt.set_role_description(SYSTEM_PROMPT_DEFAULT_ROLE) @@ -112,6 +111,7 @@ def _backward_through_llm_chain(variables: List[Variable], prompt: str, system_prompt: str, backward_engine: EngineLM): + """ Backward through the LLM to compute gradients for each variable, in the case where the output has gradients on them. i.e. applying the chain rule. 
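Editorial note, not part of the patch: the llm_ops.py hunk above routes engine resolution through validate_engine_or_get_default, which this patch adds to textgrad/config.py further below. The practical effect is that an LLMCall can be constructed without an explicit engine and fall back to the globally registered backward engine. A minimal sketch of that behavior, assuming an OpenAI key is configured and the "gpt-4o" engine name resolves through get_engine:

    import textgrad
    from textgrad import Variable
    from textgrad.autograd import LLMCall

    # Register a global engine once; LLMCall(engine=None) will reuse it.
    textgrad.set_backward_engine("gpt-4o")

    llm_call = LLMCall(engine=None)
    prompt = Variable(
        "Explain what a textual gradient is in one sentence.",
        role_description="prompt to the language model",
        requires_grad=False,
    )
    response = llm_call(prompt)
    print(response.value)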
diff --git a/textgrad/autograd/multimodal_backward_prompts.py b/textgrad/autograd/multimodal_backward_prompts.py new file mode 100644 index 0000000..9b4ed9d --- /dev/null +++ b/textgrad/autograd/multimodal_backward_prompts.py @@ -0,0 +1,6 @@ +# First part of the prompt for the llm backward function +MULTIMODAL_CONVERSATION_TEMPLATE = ( + "\n Above messages are the \n\n" + " {system_prompt} \n\n" + " {response_value} \n\n" +) diff --git a/textgrad/autograd/multimodal_ops.py b/textgrad/autograd/multimodal_ops.py new file mode 100644 index 0000000..6cd68a6 --- /dev/null +++ b/textgrad/autograd/multimodal_ops.py @@ -0,0 +1,226 @@ +from textgrad import logger +from textgrad.defaults import (SYSTEM_PROMPT_DEFAULT_ROLE, + VARIABLE_OUTPUT_DEFAULT_ROLE) +from textgrad.variable import Variable +from textgrad.engine import EngineLM, validate_multimodal_engine +from typing import List +from .llm_backward_prompts import ( + EVALUATE_VARIABLE_INSTRUCTION, + CONVERSATION_START_INSTRUCTION_BASE, + CONVERSATION_START_INSTRUCTION_CHAIN, + OBJECTIVE_INSTRUCTION_CHAIN, + OBJECTIVE_INSTRUCTION_BASE, + BACKWARD_SYSTEM_PROMPT, +) +from .multimodal_backward_prompts import MULTIMODAL_CONVERSATION_TEMPLATE +from typing import Union +from textgrad.config import validate_engine_or_get_default +from .function import Function, BackwardContext + + +class MultimodalLLMCall(Function): + def __init__(self, + engine: Union[str, EngineLM], + system_prompt: Variable = None): + super().__init__() + self.engine = validate_engine_or_get_default(engine) + validate_multimodal_engine(self.engine) + + self.system_prompt = system_prompt + if self.system_prompt and self.system_prompt.get_role_description() is None: + self.system_prompt.set_role_description(SYSTEM_PROMPT_DEFAULT_ROLE) + + + def forward(self, + inputs: List[Variable], + response_role_description: str = VARIABLE_OUTPUT_DEFAULT_ROLE) -> Variable: + # First ensure that all keys are present in the fields + + # Assert that all variables are either strings or bytes + for variable in inputs: + if not isinstance(variable.get_value(), (str, bytes)): + raise ValueError(f"MultimodalLLMCall only accepts str or bytes, got {type(variable.get_value())}") + + system_prompt_value = self.system_prompt.value if self.system_prompt else None + input_content = [inp.value for inp in inputs] + # Make the LLM Call + response_text = self.engine(input_content, system_prompt=system_prompt_value) + + # Create the response variable + response = Variable( + value=response_text, + predecessors=[self.system_prompt, *inputs] if self.system_prompt else [*inputs], + role_description=response_role_description + ) + + logger.info(f"MultimodalLLMCall function forward", extra={"text": f"System:{system_prompt_value}\n{inputs}"}) + + # Populate the gradient function, using a container to store the backward function and the context + response.set_grad_fn(BackwardContext(backward_fn=self.backward, + response=response, + input_content=input_content, + system_prompt=system_prompt_value)) + + return response + + + def backward(self, response: Variable, input_content: List[Union[str, bytes]], system_prompt: str, backward_engine: EngineLM): + validate_multimodal_engine(backward_engine) + + children_variables = response.predecessors + if response.get_gradient_text() == "": + self._backward_through_multimodal_llm_base(children_variables, response, input_content, system_prompt, backward_engine) + else: + self._backward_through_multimodal_llm_chain(children_variables, response, input_content, system_prompt, 
backward_engine) + + @staticmethod + def _construct_multimodal_llm_chain_backward_content(backward_info: dict[str, str]) -> str: + content = [c for c in backward_info["input_content"]] + conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info) + backward_prompt = CONVERSATION_START_INSTRUCTION_CHAIN.format(conversation=conversation, **backward_info) + backward_prompt += OBJECTIVE_INSTRUCTION_CHAIN.format(**backward_info) + backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info) + content.append(backward_prompt) + return content + + @staticmethod + def _backward_through_multimodal_llm_chain(variables: List[Variable], + response: Variable, + input_content: List[Union[str, bytes]], + system_prompt: str, + backward_engine: EngineLM): + for variable in variables: + if not variable.requires_grad: + continue + + backward_info = { + "response_desc": response.get_role_description(), + "response_value": response.get_value(), + "response_gradient": response.get_gradient_text(), + "input_content": input_content, + "system_prompt": system_prompt, + "variable_desc": variable.get_role_description(), + "variable_short": variable.get_short_value() + } + + backward_content = MultimodalLLMCall._construct_multimodal_llm_chain_backward_content(backward_info) + + logger.info(f"_backward_through_llm prompt", extra={"_backward_through_llm": backward_content}) + gradient_value = backward_engine(backward_content, system_prompt=BACKWARD_SYSTEM_PROMPT) + logger.info(f"_backward_through_llm gradient", extra={"_backward_through_llm": gradient_value}) + + var_gradients = Variable(value=gradient_value, role_description=f"feedback to {variable.get_role_description()}") + variable.gradients.add(var_gradients) + conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info) + variable.gradients_context[var_gradients] = { + "context": input_content + [conversation], + "response_desc": response.get_role_description(), + "variable_desc": variable.get_role_description() + } + + if response._reduce_meta: + var_gradients._reduce_meta.extend(response._reduce_meta) + variable._reduce_meta.extend(response._reduce_meta) + + @staticmethod + def _construct_multimodal_llm_base_backward_content(backward_info: dict[str, str]) -> str: + content = [c for c in backward_info["input_content"]] + conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info) + backward_prompt = CONVERSATION_START_INSTRUCTION_BASE.format(conversation=conversation, **backward_info) + backward_prompt += OBJECTIVE_INSTRUCTION_BASE.format(**backward_info) + backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info) + content.append(backward_prompt) + return content + + @staticmethod + def _backward_through_multimodal_llm_base(variables: List[Variable], + response: Variable, + input_content: List[Union[str, bytes]], + system_prompt: str, + backward_engine: EngineLM): + for variable in variables: + if not variable.requires_grad: + continue + + backward_info = { + "response_desc": response.get_role_description(), + "response_value": response.get_value(), + "input_content": input_content, + "system_prompt": system_prompt, + "variable_desc": variable.get_role_description(), + "variable_short": variable.get_short_value() + } + + backward_content = MultimodalLLMCall._construct_multimodal_llm_base_backward_content(backward_info) + + logger.info(f"_backward_through_llm prompt", extra={"_backward_through_llm": backward_content}) + gradient_value = backward_engine(backward_content, 
system_prompt=BACKWARD_SYSTEM_PROMPT) + logger.info(f"_backward_through_llm gradient", extra={"_backward_through_llm": gradient_value}) + + conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info) + var_gradients = Variable(value=gradient_value, role_description=f"feedback to {variable.get_role_description()}") + variable.gradients.add(var_gradients) + variable.gradients_context[var_gradients] = { + "context": input_content + [conversation], + "response_desc": response.get_role_description(), + "variable_desc": variable.get_role_description() + } + + if response._reduce_meta: + var_gradients._reduce_meta.extend(response._reduce_meta) + variable._reduce_meta.extend(response._reduce_meta) + + + +class OrderedFieldsMultimodalLLMCall(MultimodalLLMCall): + def __init__(self, + engine: Union[str, EngineLM], + fields: List[str], + system_prompt: Variable = None): + + self.engine = validate_engine_or_get_default(engine) + validate_multimodal_engine(self.engine) + + self.system_prompt = system_prompt + if self.system_prompt and self.system_prompt.get_role_description() is None: + self.system_prompt.set_role_description(SYSTEM_PROMPT_DEFAULT_ROLE) + + self.fields = fields + + def forward(self, + inputs: dict[str, Variable], + response_role_description: str = VARIABLE_OUTPUT_DEFAULT_ROLE) -> Variable: + # Assert that all variables are either strings or bytes + for variable in inputs.values(): + if not isinstance(variable.get_value(), (str, bytes)): + raise ValueError(f"MultimodalLLMCall only accepts str or bytes, got {type(variable.get_value())}") + + assert set(inputs.keys()) == set(self.fields), f"Expected fields {self.fields.keys()} but got {inputs.keys()}" + forward_content = [] + for field in self.fields: + if type(inputs[field].value) == bytes: + forward_content.append(inputs[field].value) + else: + forward_content.append(f"{field}: {inputs[field].value}") + + system_prompt_value = self.system_prompt.value if self.system_prompt else None + + # Make the LLM Call + response_text = self.engine(forward_content, system_prompt=system_prompt_value) + + # Create the response variable + response = Variable( + value=response_text, + predecessors=[self.system_prompt, *(list(inputs.values()))] if self.system_prompt else [*(list(inputs.values()))], + role_description=response_role_description + ) + + logger.info(f"MultimodalLLMCall function forward", extra={"text": f"System:{system_prompt_value}\n{forward_content}"}) + + # Populate the gradient function, using a container to store the backward function and the context + response.set_grad_fn(BackwardContext(backward_fn=self.backward, + response=response, + input_content=forward_content, + system_prompt=system_prompt_value)) + + return response diff --git a/textgrad/config.py b/textgrad/config.py index eeca5ca..1b01388 100644 --- a/textgrad/config.py +++ b/textgrad/config.py @@ -47,3 +47,14 @@ def set_backward_engine(engine: Union[EngineLM, str], override: bool = False): if isinstance(engine, str): engine = get_engine(engine) singleton_backward_engine.set_engine(engine, override=override) + + +def validate_engine_or_get_default(engine): + if (engine is None) and (SingletonBackwardEngine().get_engine() is None): + raise Exception( + "No engine provided. 
Either provide an engine as the argument to this call, or use `textgrad.set_backward_engine(engine)` to set the backward engine.") + elif engine is None: + engine = SingletonBackwardEngine().get_engine() + if isinstance(engine, str): + engine = get_engine(engine) + return engine \ No newline at end of file diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py index c098afb..2eebcaf 100644 --- a/textgrad/engine/__init__.py +++ b/textgrad/engine/__init__.py @@ -7,6 +7,24 @@ "together-llama-3-70b": "together-meta-llama/Llama-3-70b-chat-hf", } +# Any better way to do this? +__MULTIMODAL_ENGINES__ = ["gpt-4-turbo", + "gpt-4o", + "claude-3-5-sonnet-20240620", + "claude-3-opus-20240229", + "claude-3-sonnet-20240229", + "claude-3-haiku-20240307", + "gpt-4-turbo-2024-04-09", + ] + +def _check_if_multimodal(engine_name: str): + return any([name == engine_name for name in __MULTIMODAL_ENGINES__]) + +def validate_multimodal_engine(engine): + if not _check_if_multimodal(engine.model_string): + raise ValueError( + f"The engine provided is not multimodal. Please provide a multimodal engine, one of the following: {__MULTIMODAL_ENGINES__}") + def get_engine(engine_name: str, **kwargs) -> EngineLM: if engine_name in __ENGINE_NAME_SHORTCUTS__: engine_name = __ENGINE_NAME_SHORTCUTS__[engine_name] @@ -16,10 +34,10 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM: if (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)): from .openai import ChatOpenAI - return ChatOpenAI(model_string=engine_name, **kwargs) + return ChatOpenAI(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs) elif "claude" in engine_name: from .anthropic import ChatAnthropic - return ChatAnthropic(model_string=engine_name, **kwargs) + return ChatAnthropic(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs) elif "gemini" in engine_name: from .gemini import ChatGemini return ChatGemini(model_string=engine_name, **kwargs) diff --git a/textgrad/engine/anthropic.py b/textgrad/engine/anthropic.py index 56ab676..9fde954 100644 --- a/textgrad/engine/anthropic.py +++ b/textgrad/engine/anthropic.py @@ -10,7 +10,9 @@ stop_after_attempt, wait_random_exponential, ) - +import base64 +import json +from typing import List, Union from .base import EngineLM, CachedEngine class ChatAnthropic(EngineLM, CachedEngine): @@ -20,6 +22,7 @@ def __init__( self, model_string="claude-3-opus-20240229", system_prompt=SYSTEM_PROMPT, + is_multimodal=False, ): root = platformdirs.user_cache_dir("textgrad") cache_path = os.path.join(root, f"cache_anthropic_{model_string}.db") @@ -33,13 +36,23 @@ def __init__( self.model_string = model_string self.system_prompt = system_prompt assert isinstance(self.system_prompt, str) + self.is_multimodal = is_multimodal - @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5)) def __call__(self, prompt, **kwargs): return self.generate(prompt, **kwargs) - + @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5)) - def generate( + def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs): + if isinstance(content, str): + return self._generate_text(content, system_prompt=system_prompt, **kwargs) + + elif isinstance(content, list): + if (not self.is_multimodal): + raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.") + + return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs) + + def _generate_text( self, 
prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99 ): @@ -65,3 +78,52 @@ def generate( response = response.content[0].text self._save_cache(sys_prompt_arg + prompt, response) return response + + def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]: + formatted_content = [] + for item in content: + if isinstance(item, bytes): + image_media_type = "image/jpeg" + base64_image = base64.b64encode(item).decode('utf-8') + formatted_content.append( { + "type": "image", + "source": { + "type": "base64", + "media_type": image_media_type, + "data": base64_image, + }, + }) + elif isinstance(item, str): + formatted_content.append({ + "type": "text", + "text": item + }) + else: + raise ValueError(f"Unsupported input type: {type(item)}") + return formatted_content + + def _generate_multimodal( + self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=300, top_p=0.99 + ): + sys_prompt_arg = system_prompt if system_prompt else self.system_prompt + formatted_content = self._format_content(content) + + cache_key = sys_prompt_arg + json.dumps(formatted_content) + cache_or_none = self._check_cache(cache_key) + if cache_or_none is not None: + return cache_or_none + + response = self.client.messages.create( + model=self.model_string, + messages=[ + {"role": "user", "content": formatted_content}, + ], + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + system=sys_prompt_arg + ) + + response_text = response.content[0].text + self._save_cache(cache_key, response_text) + return response_text diff --git a/textgrad/engine/base.py b/textgrad/engine/base.py index 1d73cd5..b57a719 100644 --- a/textgrad/engine/base.py +++ b/textgrad/engine/base.py @@ -8,7 +8,11 @@ class EngineLM(ABC): @abstractmethod def generate(self, prompt, system_prompt=None, **kwargs): pass - + + def __call__(self, *args, **kwargs): + pass + + class CachedEngine: def __init__(self, cache_path): super().__init__() diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py index e39e540..c59f489 100644 --- a/textgrad/engine/openai.py +++ b/textgrad/engine/openai.py @@ -4,15 +4,17 @@ raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.") import os +import json +import base64 import platformdirs from tenacity import ( retry, stop_after_attempt, wait_random_exponential, ) -import json +from typing import List, Union + from .base import EngineLM, CachedEngine -from openai._types import NotGiven class ChatOpenAI(EngineLM, CachedEngine): DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant." 
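Editorial note, not part of the patch: the ChatAnthropic changes above and the ChatOpenAI changes below share the same calling convention. A plain string keeps the original text-only path, while a list mixing strings and raw image bytes is routed to _generate_multimodal, whose _format_content base64-encodes each bytes item into the provider-specific message format (Anthropic image source blocks, OpenAI image_url data URLs). A rough sketch of calling an engine directly with mixed content, assuming an API key is configured and "cat.jpg" is a stand-in for a local image file:

    from textgrad.engine import get_engine

    # "gpt-4o" is listed in __MULTIMODAL_ENGINES__, so get_engine sets is_multimodal=True.
    engine = get_engine("gpt-4o")

    with open("cat.jpg", "rb") as f:   # hypothetical local JPEG
        image_bytes = f.read()

    # Mixed content: bytes items become base64 image parts, strings become text parts.
    answer = engine(
        ["What animal is shown in this image?", image_bytes],
        system_prompt="You answer questions about images.",
    )
    print(answer)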
@@ -21,6 +23,7 @@ def __init__( self, model_string="gpt-3.5-turbo-0613", system_prompt=DEFAULT_SYSTEM_PROMPT, + is_multimodal: bool=False, **kwargs): """ :param model_string: @@ -28,6 +31,8 @@ def __init__( """ root = platformdirs.user_cache_dir("textgrad") cache_path = os.path.join(root, f"cache_openai_{model_string}.db") + self.image_cache_dir = os.path.join(root, "image_cache") + os.makedirs(self.image_cache_dir, exist_ok=True) super().__init__(cache_path=cache_path) @@ -39,10 +44,20 @@ def __init__( api_key=os.getenv("OPENAI_API_KEY"), ) self.model_string = model_string + self.is_multimodal = is_multimodal + @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5)) + def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs): + if isinstance(content, str): + return self._generate_text(content, system_prompt=system_prompt, **kwargs) + + elif isinstance(content, list): + if (not self.is_multimodal): + raise NotImplementedError("Multimodal generation is only supported for GPT-4 models.") + + return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs) - - def generate( + def _generate_text( self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99 ): @@ -70,29 +85,51 @@ def generate( self._save_cache(sys_prompt_arg + prompt, response) return response - def generate_with_messages(self, messages, temperature=0, max_tokens=2000, top_p=0.99): - prompt = json.dumps(messages) + def __call__(self, prompt, **kwargs): + return self.generate(prompt, **kwargs) + + def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]: + formatted_content = [] + for item in content: + if isinstance(item, bytes): + base64_image = base64.b64encode(item).decode('utf-8') + formatted_content.append({ + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{base64_image}" + } + }) + elif isinstance(item, str): + formatted_content.append({ + "type": "text", + "text": item + }) + else: + raise ValueError(f"Unsupported input type: {type(item)}") + return formatted_content + + def _generate_multimodal( + self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=300, top_p=0.99 + ): + sys_prompt_arg = system_prompt if system_prompt else self.system_prompt + formatted_content = self._format_content(content) - cache_or_none = self._check_cache(prompt) + cache_key = sys_prompt_arg + json.dumps(formatted_content) + cache_or_none = self._check_cache(cache_key) if cache_or_none is not None: return cache_or_none response = self.client.chat.completions.create( model=self.model_string, - messages=messages, - frequency_penalty=0, - presence_penalty=0, - stop=None, + messages=[ + {"role": "system", "content": sys_prompt_arg}, + {"role": "user", "content": formatted_content}, + ], temperature=temperature, max_tokens=max_tokens, top_p=top_p, ) - response = response.choices[0].message.content - self._save_cache(prompt, response) - return response - - @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5)) - def __call__(self, prompt, **kwargs): - return self.generate(prompt, **kwargs) - + response_text = response.choices[0].message.content + self._save_cache(cache_key, response_text) + return response_text diff --git a/textgrad/loss.py b/textgrad/loss.py index 9c50c46..50542b8 100644 --- a/textgrad/loss.py +++ b/textgrad/loss.py @@ -1,7 +1,7 @@ from textgrad.engine import EngineLM, get_engine from textgrad.variable import Variable from typing import List, 
Union -from textgrad.autograd import LLMCall, FormattedLLMCall +from textgrad.autograd import LLMCall, FormattedLLMCall, OrderedFieldsMultimodalLLMCall from textgrad.autograd import Module from .config import SingletonBackwardEngine @@ -192,3 +192,38 @@ def forward(self, question: str, prediction: Variable) -> Variable: return self.formatted_llm_call(inputs=inputs, response_role_description=f"evaluation of the {prediction.get_role_description()}") +class ImageQALoss(Module): + def __init__(self, + evaluation_instruction: str, + engine: Union[EngineLM, str] = None, + system_prompt: Variable = None): + super().__init__() + self.evaluation_instruction = Variable(evaluation_instruction, role_description="evaluation instruction", requires_grad=False) + if ((engine is None) and (SingletonBackwardEngine().get_engine() is None)): + raise Exception("No engine provided. Either provide an engine as the argument to this call, or use `textgrad.set_backward_engine(engine)` to set the backward engine.") + elif engine is None: + engine = SingletonBackwardEngine().get_engine() + if isinstance(engine, str): + engine = get_engine(engine) + self.engine = engine + if system_prompt: + self.system_prompt = system_prompt + else: + self.system_prompt = Variable("You are an evaluation system that evaluates image-related questions.", + requires_grad=False, + role_description="system prompt for the evaluation") + + self.multimodal_llm_call = OrderedFieldsMultimodalLLMCall(engine=self.engine, + system_prompt=self.system_prompt, + fields=["Evaluation Instruction", "Question", "Image", "Answer"]) + + def forward(self, image: Variable, question: Variable, response: Variable) -> Variable: + + inputs = { + "Evaluation Instruction": self.evaluation_instruction, + "Question": question, + "Image": image, + "Answer": response + } + return self.multimodal_llm_call(inputs=inputs, + response_role_description=f"evaluation of the {response.get_role_description()}") \ No newline at end of file diff --git a/textgrad/optimizer/optimizer.py b/textgrad/optimizer/optimizer.py index ee28c78..f3e3d8e 100644 --- a/textgrad/optimizer/optimizer.py +++ b/textgrad/optimizer/optimizer.py @@ -3,9 +3,47 @@ from collections import defaultdict from textgrad.variable import Variable from textgrad import logger -from textgrad.engine import EngineLM, get_engine -from .optimizer_prompts import construct_tgd_prompt, OPTIMIZER_SYSTEM_PROMPT -from textgrad.config import SingletonBackwardEngine +from textgrad.engine import EngineLM +from textgrad.config import validate_engine_or_get_default +from .optimizer_prompts import construct_tgd_prompt, OPTIMIZER_SYSTEM_PROMPT, GRADIENT_TEMPLATE, GRADIENT_MULTIPART_TEMPLATE + + +def get_gradient_and_context_text(variable) -> Union[str, List[Union[str, bytes]]]: + """For the variable, aggregates and returns + i. the gradients + ii. the context for which the gradients are computed. + + This is used by the optimizer. + :return: A string containing the aggregated gradients and their corresponding context. + :rtype: str + """ + + gradient_content = [] + for g in variable.gradients: + if variable.gradients_context[g] is None: + gradient_content.append(g.value) + else: + # If context is a list, we handle it differently. + context = variable.gradients_context[g] + if isinstance(context["context"], str): + # The context could be all string. 
+ criticism_and_context = GRADIENT_TEMPLATE.format( + feedback=g.value, **context) + gradient_content.append(criticism_and_context) + elif isinstance(context["context"], list): + # The context may have a list of images / strings. In this case, we need to handle it differently. + context_prompt = GRADIENT_MULTIPART_TEMPLATE.format(**context, feedback=g.value) + criticism_and_context = context["context"] + [context_prompt] + gradient_content.extend(criticism_and_context) + else: + raise ValueError("Context must be either a string or a list.") + + # Check if all instances are string + if all(isinstance(i, str) for i in gradient_content): + return "\n".join(gradient_content) + else: + return gradient_content + class Optimizer(ABC): """ @@ -20,8 +58,11 @@ class Optimizer(ABC): """ def __init__(self, parameters: List[Variable]): + for parameter in parameters: + if type(parameter.value) != str: + raise NotImplementedError(f"We cannot yet update multimodal content and this data type: {type(parameter.value)}. We can only evaluate gradients using multimodal models. This may change soon (looking at you, GPT-5).") self.parameters = parameters - + def zero_grad(self): """ Clears the gradients of all parameters. @@ -43,7 +84,7 @@ def __init__(self, verbose: int=0, engine: Union[EngineLM, str]=None, constraints: List[str]=None, - new_variable_tags: List[str]=["", ""], + new_variable_tags: List[str]=None, optimizer_system_prompt: str=OPTIMIZER_SYSTEM_PROMPT, in_context_examples: List[str]=None, gradient_memory: int=0): @@ -65,13 +106,11 @@ def __init__(self, :type gradient_memory: int, optional """ super().__init__(parameters) - if ((engine is None) and (SingletonBackwardEngine().get_engine() is None)): - raise Exception("No engine provided. Either provide an engine as the argument to this call, or use `textgrad.set_backward_engine(engine)` to set the backward engine.") - elif engine is None: - engine = SingletonBackwardEngine().get_engine() - if isinstance(engine, str): - engine = get_engine(engine) - self.engine = engine + + if new_variable_tags is None: + new_variable_tags = ["", ""] + + self.engine = validate_engine_or_get_default(engine) self.verbose = verbose self.constraints = constraints if constraints is not None else [] self.optimizer_system_prompt = optimizer_system_prompt.format(new_variable_start_tag=new_variable_tags[0], new_variable_end_tag=new_variable_tags[1]) @@ -104,12 +143,12 @@ def get_gradient_memory_text(self, variable: Variable): def update_gradient_memory(self, variable: Variable): self.gradient_memory_dict[variable].append({"value": variable.get_gradient_text()}) - def _update_prompt(self, variable: Variable): + def _update_prompt(self, variable: Variable) -> Union[str, List[Union[str, bytes]]]: grad_memory = self.get_gradient_memory_text(variable) optimizer_information = { "variable_desc": variable.get_role_description(), "variable_value": variable.value, - "variable_grad": variable.get_gradient_and_context_text(), + "variable_grad": get_gradient_and_context_text(variable), "variable_short": variable.get_short_value(), "constraint_text": self.constraint_text, "new_variable_start_tag": self.new_variable_tags[0], @@ -154,21 +193,20 @@ def __init__(self, parameters: List[Variable], momentum_window: int = 0, constraints: List[str]=None, - new_variable_tags: List[str]=["", ""], + new_variable_tags: List[str]=None, in_context_examples: List[str]=None, optimizer_system_prompt: str=OPTIMIZER_SYSTEM_PROMPT): super().__init__(parameters) - if ((engine is None) and 
(SingletonBackwardEngine().get_engine() is None)): - raise Exception("No engine provided. Either provide an engine as the argument to this call, or use `textgrad.set_backward_engine(engine)` to set the backward engine.") - elif engine is None: - engine = SingletonBackwardEngine().get_engine() - if isinstance(engine, str): - engine = get_engine(engine) - self.engine = engine + + if new_variable_tags is None: + new_variable_tags = ["", ""] + + self.engine = validate_engine_or_get_default(engine) if momentum_window == 0: return TextualGradientDescent(engine=engine, parameters=parameters, constraints=constraints) - # Each item in the momentum storage will include past value and the criticsm + + # Each item in the momentum storage will include past value and the criticism self.momentum_storage = [[] for _ in range(len(parameters))] self.momentum_window = momentum_window self.do_momentum = True @@ -217,7 +255,7 @@ def _update_momentum_storage(self, variable: Variable, momentum_storage_idx: int if len(self.momentum_storage[momentum_storage_idx]) >= self.momentum_window: self.momentum_storage[momentum_storage_idx].pop(0) - self.momentum_storage[momentum_storage_idx].append({"value": variable.value, "gradients": variable.get_gradient_and_context_text()}) + self.momentum_storage[momentum_storage_idx].append({"value": variable.value, "gradients": get_gradient_and_context_text(variable)}) def step(self): for idx, parameter in enumerate(self.parameters): diff --git a/textgrad/optimizer/optimizer_prompts.py b/textgrad/optimizer/optimizer_prompts.py index 6c6021c..4cbe2f3 100644 --- a/textgrad/optimizer/optimizer_prompts.py +++ b/textgrad/optimizer/optimizer_prompts.py @@ -31,6 +31,17 @@ "Improve the variable ({variable_desc}) using the feedback provided in tags.\n" ) +# If the gradients are in a multi-part container +TGD_MULTIPART_PROMPT_INIT = ( + "Here is the role of the variable you will improve: {variable_desc}.\n\n" + "The variable is the text within the following span: {variable_short} \n\n" + "Here is the context and feedback we got for the variable:\n\n" +) + +TGD_MULTIPART_PROMPT_PREFIX = ( + "Improve the variable ({variable_desc}) using the feedback provided in tags.\n" +) + TGD_PROMPT_SUFFIX = ( "Send the improved variable " "in the following format:\n\n{new_variable_start_tag}{{the improved variable}}{new_variable_end_tag}\n\n" @@ -72,18 +83,41 @@ def construct_tgd_prompt(do_momentum: bool = False, :rtype: str """ - prompt = TGD_PROMPT_PREFIX.format(**optimizer_kwargs) - + if isinstance(optimizer_kwargs["variable_grad"], str): + multipart=False + prompt = TGD_PROMPT_PREFIX.format(**optimizer_kwargs) + + else: + gradient_context = optimizer_kwargs["variable_grad"] + gradient_context = [TGD_MULTIPART_PROMPT_INIT.format(**optimizer_kwargs)] + gradient_context + multipart=True + prompt = TGD_MULTIPART_PROMPT_PREFIX.format(**optimizer_kwargs) + if do_momentum: prompt += MOMENTUM_PROMPT_ADDITION.format(**optimizer_kwargs) - + if do_constrained: prompt += CONSTRAINT_PROMPT_ADDITION.format(**optimizer_kwargs) - + if do_in_context_examples: prompt += IN_CONTEXT_EXAMPLE_PROMPT_ADDITION.format(**optimizer_kwargs) - + prompt += TGD_PROMPT_SUFFIX.format(**optimizer_kwargs) - return prompt + if not multipart: + return prompt + + else: + return gradient_context + [prompt] +# This is how we save gradients to the variable. +GRADIENT_TEMPLATE = ( + "Here is a conversation:\n\n{context}\n\n" + "This conversation is potentially part of a larger system. 
The output is used as {response_desc}\n\n" + "Here is the feedback we got for {variable_desc} in the conversation:\n\n{feedback}\n\n" +) +GRADIENT_MULTIPART_TEMPLATE = ( + "Above is a conversation with a language model.\n" + "This conversation is potentially part of a larger system. The output is used as {response_desc}\n\n" + "Here is the feedback we got for {variable_desc} in the conversation:\n\n{feedback}\n\n" +) diff --git a/textgrad/utils/image_utils.py b/textgrad/utils/image_utils.py new file mode 100644 index 0000000..706f529 --- /dev/null +++ b/textgrad/utils/image_utils.py @@ -0,0 +1,38 @@ +import os +import requests +import hashlib +from urllib.parse import urlparse +from typing import Union +import platformdirs +import base64 + +def download_and_cache_image(image_url: str) -> str: + # Set up cache directory + root = platformdirs.user_cache_dir("textgrad") + image_cache_dir = os.path.join(root, "image_cache") + os.makedirs(image_cache_dir, exist_ok=True) + + # Generate a unique filename + file_name = hashlib.md5(image_url.encode()).hexdigest() + ".jpg" + cache_path = os.path.join(image_cache_dir, file_name) + + # Check if the image is already cached + if os.path.exists(cache_path): + print(f"Image already cached at: {cache_path}") + with open(cache_path, "rb") as f: + image_data = f.read() + else: + # Download the image + print(f"Downloading image from: {image_url}") + response = requests.get(image_url) + response.raise_for_status() + image_data = response.content + + # Save to cache + with open(cache_path, "wb") as f: + f.write(image_data) + print(f"Image cached at: {cache_path}") + + with open(cache_path, "rb") as image_file: + return image_file.read() + diff --git a/textgrad/variable.py b/textgrad/variable.py index ab3d045..23fb1a5 100644 --- a/textgrad/variable.py +++ b/textgrad/variable.py @@ -5,12 +5,12 @@ from collections import defaultdict from functools import partial from .config import SingletonBackwardEngine -from .prompts import GRADIENT_TEMPLATE +from typing import Union class Variable: def __init__( self, - value: str = "", + value: Union[str, bytes] = "", predecessors: List['Variable']=None, requires_grad: bool=True, *, @@ -20,7 +20,7 @@ def __init__( :param role_description: The role of this variable. We find that this has a huge impact on the optimization performance, and being specific often helps quite a bit! :type role_description: str :param value: The string value of this variable, defaults to "". In the future, we'll go multimodal, for sure! - :type value: str, optional + :type value: str or bytes, optional :param predecessors: predecessors of this variable in the computation graph, defaults to None. Here, for instance, if we have a prompt -> response through an LLM call, we'd call the prompt the predecessor, and the response the successor. :type predecessors: List[Variable], optional :param requires_grad: Whether this variable requires a gradient, defaults to True. If False, we'll not compute the gradients on this variable. @@ -36,7 +36,8 @@ def __init__( raise Exception("If the variable does not require grad, none of its predecessors should require grad." f"In this case, following predecessors require grad: {_predecessor_requires_grad}") - self.value = str(value) + assert type(value) in [str, bytes], "Value must be a string or image (bytes)." 
+ self.value = value self.gradients: Set[Variable] = set() self.gradients_context: Dict[Variable, str] = defaultdict(lambda: None) self.grad_fn = None @@ -44,9 +45,12 @@ def __init__( self.predecessors = set(predecessors) self.requires_grad = requires_grad self._reduce_meta = [] + + if requires_grad and (type(value) == bytes): + raise ValueError("Gradients are not yet supported for image inputs. Please provide a string input instead.") def __repr__(self): - return f"Variable(value={self.value}, role={self.get_role_description()}, grads={self.get_gradient_and_context_text()})" + return f"Variable(value={self.value}, role={self.get_role_description()}, grads={self.gradients})" def __str__(self): return str(self.value) @@ -114,26 +118,6 @@ def get_gradient_text(self) -> str: return "\n".join([g.value for g in self.gradients]) - def get_gradient_and_context_text(self) -> str: - """For the variable, aggregates and returns - i. the gradients - ii. the context for which the gradients are computed. - - :return: A string containing the aggregated gradients and their corresponding context. - :rtype: str - """ - - gradients = [] - for g in self.gradients: - if self.gradients_context[g] is None: - gradients.append(g.value) - else: - criticism_and_context = GRADIENT_TEMPLATE.format( - feedback=g.value, **(self.gradients_context[g])) - gradients.append(criticism_and_context) - gradient_text = "\n".join(gradients) - return gradient_text - def backward(self, engine: EngineLM = None): """ Backpropagate gradients through the computation graph starting from this variable. @@ -264,6 +248,7 @@ def get_grad_fn_name(name): return graph + def _check_and_reduce_gradients(variable: Variable, backward_engine: EngineLM) -> Set[Variable]: """ Check and reduce gradients for a given variable.
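Editorial note, not part of the patch: taken together, the pieces above compose into an image question-answering optimization loop in which only the text response is updated, since Variable now rejects requires_grad=True for bytes values. The following end-to-end sketch is illustrative only: the image URL is a placeholder, and it assumes a multimodal engine name such as "gpt-4o", an importable textgrad.utils.image_utils module, and the tg.TGD optimizer alias from the public API.

    import textgrad as tg
    from textgrad.autograd import MultimodalLLMCall
    from textgrad.loss import ImageQALoss
    from textgrad.utils.image_utils import download_and_cache_image

    # The backward engine must also be multimodal for MultimodalLLMCall.backward.
    tg.set_backward_engine("gpt-4o", override=True)

    # Placeholder URL; download_and_cache_image returns the raw image bytes.
    image_bytes = download_and_cache_image("https://example.com/cat.jpg")

    image = tg.Variable(image_bytes,
                        role_description="image to answer a question about",
                        requires_grad=False)
    question = tg.Variable("What is unusual about this image?",
                           role_description="question about the image",
                           requires_grad=False)

    # Forward pass through a multimodal engine; the response is a text Variable.
    response = MultimodalLLMCall("gpt-4o")([image, question])

    # Evaluate the answer and backpropagate textual feedback onto it.
    loss_fn = ImageQALoss(
        evaluation_instruction="Criticize this answer to the image question. "
                               "Do not provide a new answer.",
        engine="gpt-4o",
    )
    loss = loss_fn(image=image, question=question, response=response)
    loss.backward()

    # Only the text response is a parameter; image bytes never receive gradients.
    optimizer = tg.TGD(parameters=[response])
    optimizer.step()
    print(response.value)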