textgrad vision integration init
mertyg committed Jul 5, 2024
1 parent 60198d3 commit e9c1240
Showing 14 changed files with 580 additions and 85 deletions.
1 change: 1 addition & 0 deletions textgrad/autograd/__init__.py
@@ -1,4 +1,5 @@
from .functional import sum, aggregate
from .llm_ops import LLMCall, FormattedLLMCall, LLMCall_with_in_context_examples
from .multimodal_ops import MultimodalLLMCall, OrderedFieldsMultimodalLLMCall
from .function import Module
from .string_based_ops import StringBasedFunction
6 changes: 3 additions & 3 deletions textgrad/autograd/llm_ops.py
@@ -3,6 +3,7 @@
VARIABLE_OUTPUT_DEFAULT_ROLE)
from textgrad.variable import Variable
from textgrad.engine import EngineLM
from textgrad.config import validate_engine_or_get_default
from typing import List
from .llm_backward_prompts import (
EVALUATE_VARIABLE_INSTRUCTION,
@@ -23,13 +24,11 @@ def __init__(self, engine: EngineLM, system_prompt: Variable = None):
:param engine: engine to use for the LLM call
:type engine: EngineLM
:param input_role_description: role description for the input variable, defaults to VARIABLE_INPUT_DEFAULT_ROLE
:type input_role_description: str, optional
:param system_prompt: system prompt to use for the LLM call, default depends on the engine.
:type system_prompt: Variable, optional
"""
super().__init__()
self.engine = engine
self.engine = validate_engine_or_get_default(engine)
self.system_prompt = system_prompt
if self.system_prompt and self.system_prompt.get_role_description() is None:
self.system_prompt.set_role_description(SYSTEM_PROMPT_DEFAULT_ROLE)
@@ -112,6 +111,7 @@ def _backward_through_llm_chain(variables: List[Variable],
prompt: str,
system_prompt: str,
backward_engine: EngineLM):

"""
Backward through the LLM to compute gradients for each variable, in the case where the output has gradients on them.
i.e. applying the chain rule.
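The assignment above now goes through validate_engine_or_get_default, so LLMCall can take a string engine name or fall back to the globally configured backward engine. A minimal sketch of what that enables, not part of this diff; it assumes an API key is configured and, for the None case, that a backward engine has already been set:

import textgrad as tg
from textgrad.autograd import LLMCall

tg.set_backward_engine("gpt-4o")   # optional: lets engine=None fall back to this engine
call = LLMCall("gpt-4o")           # a string is resolved through get_engine
call_default = LLMCall(None)       # uses the singleton backward engine set above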
6 changes: 6 additions & 0 deletions textgrad/autograd/multimodal_backward_prompts.py
@@ -0,0 +1,6 @@
# First part of the prompt for the llm backward function
MULTIMODAL_CONVERSATION_TEMPLATE = (
"\n Above messages are the <LM_INPUT>\n\n"
"<LM_SYSTEM_PROMPT> {system_prompt} </LM_SYSTEM_PROMPT>\n\n"
"<LM_OUTPUT> {response_value} </LM_OUTPUT>\n\n"
)
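During the backward pass this template is filled with the system prompt and the model's response, while the original images and text are passed alongside it (see multimodal_ops.py below). A small illustrative sketch with assumed example values:

from textgrad.autograd.multimodal_backward_prompts import MULTIMODAL_CONVERSATION_TEMPLATE

conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(
    system_prompt="You are a helpful assistant.",
    response_value="The image shows a cat sitting on a chair.",
)
# "Above messages" refers to the multimodal input content that precedes this
# text when the backward prompt is assembled.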
226 changes: 226 additions & 0 deletions textgrad/autograd/multimodal_ops.py
@@ -0,0 +1,226 @@
from textgrad import logger
from textgrad.defaults import (SYSTEM_PROMPT_DEFAULT_ROLE,
VARIABLE_OUTPUT_DEFAULT_ROLE)
from textgrad.variable import Variable
from textgrad.engine import EngineLM, validate_multimodal_engine
from typing import List
from .llm_backward_prompts import (
EVALUATE_VARIABLE_INSTRUCTION,
CONVERSATION_START_INSTRUCTION_BASE,
CONVERSATION_START_INSTRUCTION_CHAIN,
OBJECTIVE_INSTRUCTION_CHAIN,
OBJECTIVE_INSTRUCTION_BASE,
BACKWARD_SYSTEM_PROMPT,
)
from .multimodal_backward_prompts import MULTIMODAL_CONVERSATION_TEMPLATE
from typing import Union
from textgrad.config import validate_engine_or_get_default
from .function import Function, BackwardContext


class MultimodalLLMCall(Function):
def __init__(self,
engine: Union[str, EngineLM],
system_prompt: Variable = None):
super().__init__()
self.engine = validate_engine_or_get_default(engine)
validate_multimodal_engine(self.engine)

self.system_prompt = system_prompt
if self.system_prompt and self.system_prompt.get_role_description() is None:
self.system_prompt.set_role_description(SYSTEM_PROMPT_DEFAULT_ROLE)


def forward(self,
inputs: List[Variable],
response_role_description: str = VARIABLE_OUTPUT_DEFAULT_ROLE) -> Variable:
        # Assert that all inputs are either strings (text) or bytes (image content)
for variable in inputs:
if not isinstance(variable.get_value(), (str, bytes)):
raise ValueError(f"MultimodalLLMCall only accepts str or bytes, got {type(variable.get_value())}")

system_prompt_value = self.system_prompt.value if self.system_prompt else None
input_content = [inp.value for inp in inputs]
# Make the LLM Call
response_text = self.engine(input_content, system_prompt=system_prompt_value)

# Create the response variable
response = Variable(
value=response_text,
predecessors=[self.system_prompt, *inputs] if self.system_prompt else [*inputs],
role_description=response_role_description
)

logger.info(f"MultimodalLLMCall function forward", extra={"text": f"System:{system_prompt_value}\n{inputs}"})

# Populate the gradient function, using a container to store the backward function and the context
response.set_grad_fn(BackwardContext(backward_fn=self.backward,
response=response,
input_content=input_content,
system_prompt=system_prompt_value))

return response


def backward(self, response: Variable, input_content: List[Union[str, bytes]], system_prompt: str, backward_engine: EngineLM):
validate_multimodal_engine(backward_engine)

children_variables = response.predecessors
if response.get_gradient_text() == "":
self._backward_through_multimodal_llm_base(children_variables, response, input_content, system_prompt, backward_engine)
else:
self._backward_through_multimodal_llm_chain(children_variables, response, input_content, system_prompt, backward_engine)

@staticmethod
    def _construct_multimodal_llm_chain_backward_content(backward_info: dict) -> List[Union[str, bytes]]:
content = [c for c in backward_info["input_content"]]
conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info)
backward_prompt = CONVERSATION_START_INSTRUCTION_CHAIN.format(conversation=conversation, **backward_info)
backward_prompt += OBJECTIVE_INSTRUCTION_CHAIN.format(**backward_info)
backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info)
content.append(backward_prompt)
return content

@staticmethod
def _backward_through_multimodal_llm_chain(variables: List[Variable],
response: Variable,
input_content: List[Union[str, bytes]],
system_prompt: str,
backward_engine: EngineLM):
for variable in variables:
if not variable.requires_grad:
continue

backward_info = {
"response_desc": response.get_role_description(),
"response_value": response.get_value(),
"response_gradient": response.get_gradient_text(),
"input_content": input_content,
"system_prompt": system_prompt,
"variable_desc": variable.get_role_description(),
"variable_short": variable.get_short_value()
}

backward_content = MultimodalLLMCall._construct_multimodal_llm_chain_backward_content(backward_info)

logger.info(f"_backward_through_llm prompt", extra={"_backward_through_llm": backward_content})
gradient_value = backward_engine(backward_content, system_prompt=BACKWARD_SYSTEM_PROMPT)
logger.info(f"_backward_through_llm gradient", extra={"_backward_through_llm": gradient_value})

var_gradients = Variable(value=gradient_value, role_description=f"feedback to {variable.get_role_description()}")
variable.gradients.add(var_gradients)
conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info)
variable.gradients_context[var_gradients] = {
"context": input_content + [conversation],
"response_desc": response.get_role_description(),
"variable_desc": variable.get_role_description()
}

if response._reduce_meta:
var_gradients._reduce_meta.extend(response._reduce_meta)
variable._reduce_meta.extend(response._reduce_meta)

@staticmethod
    def _construct_multimodal_llm_base_backward_content(backward_info: dict) -> List[Union[str, bytes]]:
content = [c for c in backward_info["input_content"]]
conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info)
backward_prompt = CONVERSATION_START_INSTRUCTION_BASE.format(conversation=conversation, **backward_info)
backward_prompt += OBJECTIVE_INSTRUCTION_BASE.format(**backward_info)
backward_prompt += EVALUATE_VARIABLE_INSTRUCTION.format(**backward_info)
content.append(backward_prompt)
return content

@staticmethod
def _backward_through_multimodal_llm_base(variables: List[Variable],
response: Variable,
input_content: List[Union[str, bytes]],
system_prompt: str,
backward_engine: EngineLM):
for variable in variables:
if not variable.requires_grad:
continue

backward_info = {
"response_desc": response.get_role_description(),
"response_value": response.get_value(),
"input_content": input_content,
"system_prompt": system_prompt,
"variable_desc": variable.get_role_description(),
"variable_short": variable.get_short_value()
}

backward_content = MultimodalLLMCall._construct_multimodal_llm_base_backward_content(backward_info)

logger.info(f"_backward_through_llm prompt", extra={"_backward_through_llm": backward_content})
gradient_value = backward_engine(backward_content, system_prompt=BACKWARD_SYSTEM_PROMPT)
logger.info(f"_backward_through_llm gradient", extra={"_backward_through_llm": gradient_value})

conversation = MULTIMODAL_CONVERSATION_TEMPLATE.format(**backward_info)
var_gradients = Variable(value=gradient_value, role_description=f"feedback to {variable.get_role_description()}")
variable.gradients.add(var_gradients)
variable.gradients_context[var_gradients] = {
"context": input_content + [conversation],
"response_desc": response.get_role_description(),
"variable_desc": variable.get_role_description()
}

if response._reduce_meta:
var_gradients._reduce_meta.extend(response._reduce_meta)
variable._reduce_meta.extend(response._reduce_meta)



class OrderedFieldsMultimodalLLMCall(MultimodalLLMCall):
def __init__(self,
engine: Union[str, EngineLM],
fields: List[str],
system_prompt: Variable = None):

self.engine = validate_engine_or_get_default(engine)
validate_multimodal_engine(self.engine)

self.system_prompt = system_prompt
if self.system_prompt and self.system_prompt.get_role_description() is None:
self.system_prompt.set_role_description(SYSTEM_PROMPT_DEFAULT_ROLE)

self.fields = fields

def forward(self,
inputs: dict[str, Variable],
response_role_description: str = VARIABLE_OUTPUT_DEFAULT_ROLE) -> Variable:
# Assert that all variables are either strings or bytes
for variable in inputs.values():
if not isinstance(variable.get_value(), (str, bytes)):
raise ValueError(f"MultimodalLLMCall only accepts str or bytes, got {type(variable.get_value())}")

        assert set(inputs.keys()) == set(self.fields), f"Expected fields {self.fields} but got {inputs.keys()}"
forward_content = []
for field in self.fields:
if type(inputs[field].value) == bytes:
forward_content.append(inputs[field].value)
else:
forward_content.append(f"{field}: {inputs[field].value}")

system_prompt_value = self.system_prompt.value if self.system_prompt else None

# Make the LLM Call
response_text = self.engine(forward_content, system_prompt=system_prompt_value)

# Create the response variable
response = Variable(
value=response_text,
predecessors=[self.system_prompt, *(list(inputs.values()))] if self.system_prompt else [*(list(inputs.values()))],
role_description=response_role_description
)

logger.info(f"MultimodalLLMCall function forward", extra={"text": f"System:{system_prompt_value}\n{forward_content}"})

# Populate the gradient function, using a container to store the backward function and the context
response.set_grad_fn(BackwardContext(backward_fn=self.backward,
response=response,
input_content=forward_content,
system_prompt=system_prompt_value))

return response
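To make the new op concrete, here is a minimal usage sketch, not part of the commit. It assumes an OpenAI API key is configured, that a local cat.png exists, and that Function.__call__ dispatches to forward as elsewhere in textgrad:

import textgrad as tg
from textgrad.autograd import MultimodalLLMCall

tg.set_backward_engine("gpt-4o")  # the backward engine must itself be multimodal

with open("cat.png", "rb") as f:
    image_bytes = f.read()

image = tg.Variable(image_bytes, role_description="input image", requires_grad=False)
question = tg.Variable("What is in this image?", role_description="question about the image", requires_grad=False)

response = MultimodalLLMCall("gpt-4o")([image, question])
print(response.value)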
11 changes: 11 additions & 0 deletions textgrad/config.py
@@ -47,3 +47,14 @@ def set_backward_engine(engine: Union[EngineLM, str], override: bool = False):
if isinstance(engine, str):
engine = get_engine(engine)
singleton_backward_engine.set_engine(engine, override=override)


def validate_engine_or_get_default(engine):
if (engine is None) and (SingletonBackwardEngine().get_engine() is None):
raise Exception(
"No engine provided. Either provide an engine as the argument to this call, or use `textgrad.set_backward_engine(engine)` to set the backward engine.")
elif engine is None:
engine = SingletonBackwardEngine().get_engine()
if isinstance(engine, str):
engine = get_engine(engine)
return engine
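In short, the helper resolves whatever it is given: a string becomes an engine via get_engine, None falls back to the singleton backward engine, and None with no backward engine set raises. A hedged sketch, assuming set_backward_engine populates the same singleton that SingletonBackwardEngine() returns:

import textgrad as tg
from textgrad.config import validate_engine_or_get_default

tg.set_backward_engine("gpt-4o")

engine = validate_engine_or_get_default(None)       # falls back to the singleton backward engine
engine = validate_engine_or_get_default("gpt-4o")   # a string is resolved via get_engine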
22 changes: 20 additions & 2 deletions textgrad/engine/__init__.py
@@ -7,6 +7,24 @@
"together-llama-3-70b": "together-meta-llama/Llama-3-70b-chat-hf",
}

# Any better way to do this?
__MULTIMODAL_ENGINES__ = ["gpt-4-turbo",
"gpt-4o",
"claude-3-5-sonnet-20240620",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229",
"claude-3-haiku-20240307",
"gpt-4-turbo-2024-04-09",
]

def _check_if_multimodal(engine_name: str):
return any([name == engine_name for name in __MULTIMODAL_ENGINES__])

def validate_multimodal_engine(engine):
if not _check_if_multimodal(engine.model_string):
raise ValueError(
f"The engine provided is not multimodal. Please provide a multimodal engine, one of the following: {__MULTIMODAL_ENGINES__}")

def get_engine(engine_name: str, **kwargs) -> EngineLM:
if engine_name in __ENGINE_NAME_SHORTCUTS__:
engine_name = __ENGINE_NAME_SHORTCUTS__[engine_name]
@@ -16,10 +34,10 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:

if (("gpt-4" in engine_name) or ("gpt-3.5" in engine_name)):
from .openai import ChatOpenAI
return ChatOpenAI(model_string=engine_name, **kwargs)
return ChatOpenAI(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs)
elif "claude" in engine_name:
from .anthropic import ChatAnthropic
return ChatAnthropic(model_string=engine_name, **kwargs)
return ChatAnthropic(model_string=engine_name, is_multimodal=_check_if_multimodal(engine_name), **kwargs)
elif "gemini" in engine_name:
from .gemini import ChatGemini
return ChatGemini(model_string=engine_name, **kwargs)
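A quick sketch of the multimodal check added here, assuming API keys are available when the engines are constructed:

from textgrad.engine import get_engine, validate_multimodal_engine

engine = get_engine("gpt-4o")           # constructed with is_multimodal=True
validate_multimodal_engine(engine)      # passes: "gpt-4o" is in __MULTIMODAL_ENGINES__

text_engine = get_engine("gpt-3.5-turbo")
# validate_multimodal_engine(text_engine) would raise ValueError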
