From 2aac4ffbec9f5c6e3f51324a35e96865a6ff2188 Mon Sep 17 00:00:00 2001 From: Li Yin Date: Mon, 16 Dec 2024 23:47:47 -0800 Subject: [PATCH] build cycle multi-hop rag, trainable, added component graph with cycle, output graph as simplified version of the dynamic computation graph --- adalflow/adalflow/core/generator.py | 7 +- adalflow/adalflow/core/retriever.py | 8 +- adalflow/adalflow/optim/grad_component.py | 12 +- adalflow/adalflow/optim/parameter.py | 356 +++++++++++++++- .../hotpot_qa/adal_exp/build_multi_hop_rag.py | 381 +++++++++++++----- .../hotpot_qa/adal_exp/build_vanilla_rag.py | 2 +- .../adal_exp/train_multi_hop_rag_cycle.py | 163 ++++++++ 7 files changed, 815 insertions(+), 114 deletions(-) create mode 100644 benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py index baedd8fb..63a85ba6 100644 --- a/adalflow/adalflow/core/generator.py +++ b/adalflow/adalflow/core/generator.py @@ -506,6 +506,7 @@ def forward( self.model_kwargs, model_kwargs ), } + output = self.call(**input_args, id=id) # 2. Generate a Parameter object from the output combined_prompt_kwargs = compose_model_kwargs(self.prompt_kwargs, prompt_kwargs) @@ -527,9 +528,12 @@ def forward( name=self.name + "_output", role_desc=f"Output from (llm) {self.name}", param_type=ParameterType.GENERATOR_OUTPUT, + data_id=id, ) response.set_predecessors(predecessors) - response.trace_forward_pass(input_args=input_args, full_response=output) + response.trace_forward_pass( + input_args=input_args, full_response=output, id=self.id, name=self.name + ) # *** special to the generator *** response.trace_api_kwargs(api_kwargs=self._trace_api_kwargs) # attach the demo to the demo parameter @@ -755,6 +759,7 @@ def _backward_through_one_predecessor( score=response._score, # add score to gradient param_type=ParameterType.GRADIENT, from_response_id=response.id, + data_id=response.data_id, ) pred.add_gradient(var_gradient) pred.set_score(response._score) diff --git a/adalflow/adalflow/core/retriever.py b/adalflow/adalflow/core/retriever.py index fb65a298..1e7e40d3 100644 --- a/adalflow/adalflow/core/retriever.py +++ b/adalflow/adalflow/core/retriever.py @@ -128,7 +128,13 @@ def forward( ) if input is None: raise ValueError("Input cannot be empty") - response = super().forward(input, top_k=top_k, **kwargs) + response = super().forward(input, top_k=top_k, id=id, **kwargs) + response.trace_forward_pass( + input_args={"input": input, "top_k": top_k}, + full_response=response, + id=self.id, + name=self.name, + ) response.param_type = ( ParameterType.RETRIEVER_OUTPUT ) # be more specific about the type diff --git a/adalflow/adalflow/optim/grad_component.py b/adalflow/adalflow/optim/grad_component.py index b73e536e..b9964aba 100644 --- a/adalflow/adalflow/optim/grad_component.py +++ b/adalflow/adalflow/optim/grad_component.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from collections import OrderedDict +import uuid import logging if TYPE_CHECKING: @@ -13,6 +14,7 @@ from adalflow.core.component import Component from adalflow.optim.function import BackwardContext + __all__ = ["GradComponent"] log = logging.getLogger(__name__) @@ -31,10 +33,12 @@ class GradComponent(Component): """ backward_engine: "BackwardEngine" _component_type = "grad" + id = None def __init__(self, *args, **kwargs): super().__init__() super().__setattr__("backward_engine", None) + super().__setattr__("id", str(uuid.uuid4())) def __call__(self, *args, **kwargs): if self.training: @@ -122,9 +126,15 @@ def forward(self, *args, **kwargs) -> "Parameter": name=self.name + "_output", role_desc=self.name + " response", param_type=ParameterType.OUTPUT, + data_id=kwargs.get("id", None), ) response.set_predecessors(predecessors) - response.trace_forward_pass(input_args=input_args, full_response=call_response) + response.trace_forward_pass( + input_args=input_args, + full_response=call_response, + id=self.id, + name=self.name, + ) response.set_grad_fn( BackwardContext( backward_fn=self.backward, diff --git a/adalflow/adalflow/optim/parameter.py b/adalflow/adalflow/optim/parameter.py index 5b60995c..1e553302 100644 --- a/adalflow/adalflow/optim/parameter.py +++ b/adalflow/adalflow/optim/parameter.py @@ -46,8 +46,18 @@ class GradientContext: ) +@dataclass(frozen=True) +class ComponentNode: + """Used to represent a node in the component graph.""" + + id: str = field(metadata={"desc": "The unique id of the component"}) + name: str = field(metadata={"desc": "The name of the component"}) + + @dataclass class ComponentTrace: + name: str = field(metadata={"desc": "The name of the component"}, default=None) + id: str = field(metadata={"desc": "The unique id of the component"}, default=None) input_args: Dict[str, Any] = field( metadata={"desc": "The input arguments of the GradComponent forward"}, default=None, @@ -159,7 +169,7 @@ class Parameter(Generic[T]): def __init__( self, *, - id: Optional[str] = None, + id: Optional[str] = None, # unique id of the parameter data: T = None, # for generator output, the data will be set up as raw_response data_id: str = None, # for tracing the data item in the training/val/test set requires_opt: bool = True, @@ -300,13 +310,21 @@ def trace_optimizer(self, api_kwargs: Dict[str, Any], response: "TGDData"): ############################################################################################################ # Trace component, include trace_forward_pass & trace_api_kwargs for now ############################################################################################################ - def trace_forward_pass(self, input_args: Dict[str, Any], full_response: object): - r"""Trace the forward pass of the parameter.""" + def trace_forward_pass( + self, + input_args: Dict[str, Any], + full_response: object, + id: str = None, + name: str = None, + ): + r"""Trace the forward pass of the parameter. Adding the component information to the trace""" self.input_args = input_args self.full_response = full_response # TODO: remove the input_args and full_response to use component_trace self.component_trace.input_args = input_args self.component_trace.full_response = full_response + self.component_trace.id = id + self.component_trace.name = name def trace_api_kwargs(self, api_kwargs: Dict[str, Any]): r"""Trace the api_kwargs for components like Generator and Retriever that pass to the model client.""" @@ -515,20 +533,20 @@ def build_graph(node: "Parameter"): build_graph(root) return nodes, edges - def report_cycle(cycle_nodes: List["Parameter"]): - """ - Report the detected cycle and provide guidance to the user on how to avoid it. - """ - cycle_names = [node.name for node in cycle_nodes] - log.warning(f"Cycle detected: {' -> '.join(cycle_names)}") - print(f"Cycle detected in the graph: {' -> '.join(cycle_names)}") - - # Provide guidance on how to avoid the cycle - print("To avoid the cycle, consider the following strategies:") - print("- Modify the graph structure to remove cyclic dependencies.") - print( - "- Check the relationships between these nodes to ensure no feedback loops." - ) + # def report_cycle(cycle_nodes: List["Parameter"]): + # """ + # Report the detected cycle and provide guidance to the user on how to avoid it. + # """ + # cycle_names = [node.name for node in cycle_nodes] + # log.warning(f"Cycle detected: {' -> '.join(cycle_names)}") + # print(f"Cycle detected in the graph: {' -> '.join(cycle_names)}") + + # # Provide guidance on how to avoid the cycle + # print("To avoid the cycle, consider the following strategies:") + # print("- Modify the graph structure to remove cyclic dependencies.") + # print( + # "- Check the relationships between these nodes to ensure no feedback loops." + # ) def backward( self, @@ -799,6 +817,7 @@ def wrap_and_escape(text, width=40): node_label = ( f"" + f"" f"" f"" f"" @@ -838,6 +857,12 @@ def wrap_and_escape(text, width=40): if n.tgd_optimizer_trace is not None: node_label += f"" + # show component trace, id and name + if n.component_trace.id is not None: + node_label += f"" + if n.component_trace.name is not None: + node_label += f"" + node_label += "
Name: {wrap_and_escape(n.id)}
Name: {wrap_and_escape(n.name)}
Role: {wrap_and_escape(n.role_desc.capitalize())}
Value: {wrap_and_escape(n.data)}
TGD Optimizer Trace: {wrap_and_escape(str(n.tgd_optimizer_trace))}
Component Trace ID: {wrap_and_escape(str(n.component_trace.id))}
Component Trace Name: {wrap_and_escape(str(n.component_trace.name))}
" # check if the name exists in dot if n.name in node_names: @@ -906,6 +931,303 @@ def wrap_and_escape(text, width=40): ) return {"graph_path": filepath, "root_path": f"{filepath}_root.json"} + def draw_output_subgraph( + self, + add_grads: bool = True, + format: str = "png", + rankdir: str = "TB", + filepath: str = None, + ): + """ + Build and visualize a subgraph containing only OUTPUT parameters. + + Args: + add_grads (bool): Whether to include gradient edges. + format (str): Format for output (e.g., png, svg). + rankdir (str): Graph layout direction ("LR" or "TB"). + filepath (str): Path to save the graph. + """ + assert rankdir in ["LR", "TB"] + from adalflow.utils.global_config import get_adalflow_default_root_path + + try: + from graphviz import Digraph + + except ImportError as e: + raise ImportError( + "Please install graphviz using 'pip install graphviz' to use this feature" + ) from e + + root_path = get_adalflow_default_root_path() + # # prepare the log directory + # log_dir = os.path.join(root_path, "logs") + + # # Set up TensorBoard logging + # writer = SummaryWriter(log_dir) + + filename = f"trace_component_output_graph_{self.name}_id_{self.id}" + filepath = ( + os.path.join(filepath, filename) + if filepath + else os.path.join(root_path, "graphs", filename) + ) + print(f"Saving graph to {filepath}.{format}") + + # Step 1: Collect OUTPUT nodes and edges + nodes, edges = self._collect_output_subgraph() + + # Step 2: Render using Graphviz + filename = f"output_subgraph_{self.name}_{self.id}" + filepath = filepath or f"./{filename}" + print(f"Saving OUTPUT subgraph to {filepath}.{format}") + + dot = Digraph(format=format, graph_attr={"rankdir": rankdir}) + node_ids = set() + + for node in nodes: + node_label = f""" + + + + """ + # add the component trace id and name + if node.component_trace.id is not None: + node_label += f"" + if node.component_trace.name is not None: + node_label += f"" + + node_label += "
Name:{node.name}
Type:{node.param_type}
Value:{node.get_short_value()}
Component Trace ID:{node.component_trace.id}
Component Trace Name:{node.component_trace.name}
" + dot.node( + name=node.id, + label=f"<{node_label}>", + shape="plaintext", + color="lightblue" if node.requires_opt else "gray", + ) + node_ids.add(node.id) + + for source, target in edges: + if source.id in node_ids and target.id in node_ids: + dot.edge(source.id, target.id) + + # Step 3: Save and render + dot.render(filepath, cleanup=True) + print(f"Graph saved as {filepath}.{format}") + + def draw_component_subgraph( + self, + format: str = "png", + rankdir: str = "TB", + filepath: str = None, + ): + """ + Build and visualize a subgraph containing only OUTPUT parameters. + + Args: + format (str): Format for output (e.g., png, svg). + rankdir (str): Graph layout direction ("LR" or "TB"). + filepath (str): Path to save the graph. + """ + assert rankdir in ["LR", "TB"] + + try: + from graphviz import Digraph + except ImportError as e: + raise ImportError( + "Please install graphviz using 'pip install graphviz' to use this feature" + ) from e + + # Step 1: Collect OUTPUT nodes and edges + component_nodes, edges, component_nodes_orders = ( + self._collect_component_subgraph() + ) + + # Step 2: Setup graph rendering + filename = f"output_component_{self.name}_{self.id}" + filepath = filepath or f"./{filename}" + print(f"Saving OUTPUT subgraph to {filepath}.{format}") + + dot = Digraph(format=format, graph_attr={"rankdir": rankdir}) + + # Add nodes + for node in component_nodes: + node_label = f""" + + + """ + + # add the list of orders + if node.id in component_nodes_orders: + node_label += f"" + node_label += "
ID:{node.id}
Name:{node.name}
Order:{component_nodes_orders[node.id]}
" + dot.node( + name=node.id, + label=f"<{node_label}>", + shape="plaintext", + color="lightblue", + ) + + # Add edges with order labels + for source_id, target_id, edge_order in edges: + dot.edge(source_id, target_id, label=str(edge_order), color="black") + + # Step 3: Save and render + dot.render(filepath, cleanup=True) + print(f"Graph saved as {filepath}.{format}") + + def _collect_output_subgraph( + self, + ) -> Tuple[Set["Parameter"], List[Tuple["Parameter", "Parameter"]]]: + """ + Collect nodes of type OUTPUT and their relationships. + + Returns: + nodes (Set[Parameter]): Set of OUTPUT nodes. + edges (List[Tuple[Parameter, Parameter]]): Edges between OUTPUT nodes. + """ + output_nodes = set() + edges = [] + + visited = set() # check component_trace.id and name + + def traverse(node: "Parameter"): + if node in visited: + return + visited.add(node) + + # Add OUTPUT nodes to the set + if ( + node.param_type == ParameterType.OUTPUT + or "OUTPUT" in node.param_type.name + ): + output_nodes.add(node) + + # Traverse predecessors and add edges + for pred in node.predecessors: + if ( + pred.param_type == ParameterType.OUTPUT + or "OUTPUT" in pred.param_type.name + ): + edges.append((pred, node)) + traverse(pred) + + traverse(self) + return output_nodes, edges + + # def _collect_output_subgraph( + # self, + # ) -> Tuple[Set[Tuple[str, str]], List[Tuple[str, str]]]: + # """ + # Collect OUTPUT nodes and their relationships using component_trace information. + + # Returns: + # nodes (Set[Tuple[str, str]]): Set of component nodes (component_id, label). + # edges (List[Tuple[str, str]]): Edges between component IDs. + # """ + # component_nodes = set() # To store component nodes as (component_id, label) + # edges = [] # To store edges between components + + # visited = set() # Track visited parameters to avoid cycles + + # def traverse(node: "Parameter"): + # if node in visited: + # return + # visited.add(node) + + # # Only consider OUTPUT-type parameters + # if ( + # node.param_type == ParameterType.OUTPUT + # or "OUTPUT" in node.param_type.name + # ): + # component_id = node.component_trace.id + # component_name = node.component_trace.name or "Unknown Component" + # label = f"{component_name}\nID: {component_id}" + + # # Add the component as a node + # component_nodes.add((component_id, label)) + + # # Traverse predecessors and add edges + # for pred in node.predecessors: + # if pred.param_type == ParameterType.OUTPUT: + # pred_id = pred.component_trace.id + # pred_name = pred.component_trace.name or "Unknown Component" + + # # Add predecessor as a node + # pred_label = f"{pred_name}\nID: {pred_id}" + # component_nodes.add((pred_id, pred_label)) + + # # Add edge between components + # edges.append((pred_id, component_id)) + + # # Recursive traversal + # traverse(pred) + + # # Start traversal from the current parameter + # traverse(self) + # return component_nodes, edges + + def _collect_component_subgraph( + self, + ) -> Tuple[Set[ComponentNode], List[Tuple[str, str]]]: + """ + Collect OUTPUT nodes and their relationships as ComponentNodes. + + Returns: + component_nodes (Set[ComponentNode]): Set of component nodes (id and name only). + edges (List[Tuple[str, str]]): Edges between component IDs. + """ + component_nodes = set() # To store component nodes as ComponentNode + component_nodes_orders: Dict[str, List[int]] = ( + {} + ) # To store component nodes order + edges = [] # To store edges between component IDs + + visited = set() # Track visited parameters to avoid cycles + edge_counter = [0] # Mutable counter for edge order tracking + + def traverse(node: "Parameter", depth: int): + if node in visited: + return + visited.add(node) + + # Check if node is of OUTPUT type + if ( + node.param_type == ParameterType.OUTPUT + or "OUTPUT" in node.param_type.name + ): + component_id = node.component_trace.id or f"unknown_id_{uuid.uuid4()}" + component_name = node.component_trace.name or "Unknown Component" + + # Create a ComponentNode and add to the set + component_node = ComponentNode(id=component_id, name=component_name) + component_nodes.add(component_node) + + # Traverse predecessors and add edges + for pred in node.predecessors: + pred_id = pred.component_trace.id or f"unknown_id_{uuid.uuid4()}" + pred_name = pred.component_trace.name or "Unknown Component" + + # Add edge if predecessor is also of OUTPUT type + if ( + pred.param_type == ParameterType.OUTPUT + or "OUTPUT" in pred.param_type.name + ): + edges.append((pred_id, component_id, depth)) + component_nodes.add(ComponentNode(id=pred_id, name=pred_name)) + edge_counter[0] += 1 + + traverse(pred, depth + 1) + + # Start traversal from the current parameter + traverse(self, depth=0) + # Reverse the edge order + # total_edges = len(edges) + # edges = [ + # (source, target, (total_edges - 1) - edge_number) + # for idx, (source, target, edge_number) in enumerate(edges) + # ] + + return component_nodes, edges, component_nodes_orders + def to_dict(self): return { "name": self.name, diff --git a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py index cebcfdf2..ba3f8bfc 100644 --- a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py +++ b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py @@ -62,7 +62,9 @@ class DeduplicateList(adal.GradComponent): def __init__(self): super().__init__() - def call(self, exisiting_list: List[str], new_list: List[str]) -> List[str]: + def call( + self, exisiting_list: List[str], new_list: List[str], id: str = None + ) -> List[str]: seen = set() return [x for x in exisiting_list + new_list if not (x in seen or seen.add(x))] @@ -78,7 +80,162 @@ def backward(self, *args, **kwargs): # NOTE: deprecated -class MultiHopRetriever(adal.Retriever): +# class MultiHopRetriever(adal.Retriever): +# def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): +# super().__init__() + +# self.passages_per_hop = passages_per_hop +# self.max_hops = max_hops + +# self.data_parser = adal.DataClassParser( +# data_class=QueryRewritterData, return_data_class=True, format_type="json" +# ) + +# # Grad Component +# self.query_generators: List[adal.Generator] = [] +# for i in range(self.max_hops): +# self.query_generators.append( +# adal.Generator( +# name=f"query_generator_{i}", +# model_client=model_client, +# model_kwargs=model_kwargs, +# prompt_kwargs={ +# "few_shot_demos": Parameter( +# name="few_shot_demos_1", +# data=None, +# role_desc="To provide few shot demos to the language model", +# requires_opt=True, +# param_type=ParameterType.DEMOS, +# ), +# "task_desc_str": Parameter( +# name="task_desc_str", +# data="""Write a simple search query that will help answer a complex question. + +# You will receive a context(may contain relevant facts) and a question. +# Think step by step.""", +# role_desc="Task description for the language model", +# requires_opt=True, +# param_type=ParameterType.PROMPT, +# ), +# "output_format_str": self.data_parser.get_output_format_str(), +# }, +# template=query_template, +# output_processors=self.data_parser, +# use_cache=True, +# ) +# ) +# self.retriever = DspyRetriever(top_k=passages_per_hop) +# self.deduplicater = DeduplicateList() + +# @staticmethod +# def context_to_str(context: List[str]) -> str: +# return "\n".join(context) + +# @staticmethod +# def deduplicate(seq: list[str]) -> list[str]: +# """ +# Source: https://stackoverflow.com/a/480227/1493011 +# """ + +# seen = set() +# return [x for x in seq if not (x in seen or seen.add(x))] + +# def call(self, *, question: str, id: str = None) -> adal.RetrieverOutput: +# context = [] +# print(f"question: {question}") +# for i in range(self.max_hops): +# gen_out = self.query_generators[i]( +# prompt_kwargs={ +# "context": self.context_to_str(context), +# "question": question, +# }, +# id=id, +# ) + +# query = gen_out.data.query if gen_out.data and gen_out.data.query else None + +# print(f"query {i}: {query}") + +# retrieve_out = self.retriever.call(input=query) +# passages = retrieve_out[0].documents +# context = self.deduplicate(context + passages) +# out = [adal.RetrieverOutput(documents=context, query=query, doc_indices=[])] +# return out + +# def forward(self, *, question: str, id: str = None) -> adal.Parameter: +# # assemble the foundamental building blocks +# context = [] +# print(f"question: {question}") +# # 1. make question a parameter as generator does not have it yet +# # can create the parameter at the leaf, but not the intermediate nodes +# question_param = adal.Parameter( +# name="question", +# data=question, +# role_desc="The question to be answered", +# requires_opt=True, +# param_type=ParameterType.INPUT, +# ) +# context_param = adal.Parameter( +# name="context", +# data=context, +# role_desc="The context to be used for the query", +# requires_opt=True, +# param_type=ParameterType.INPUT, +# ) +# context_param.add_successor_map_fn( +# successor=self.query_generators[0], +# map_fn=lambda x: self.context_to_str(x.data), +# ) + +# for i in range(self.max_hops): + +# gen_out = self.query_generators[i].forward( +# prompt_kwargs={ +# "context": context_param, +# "question": question_param, +# }, +# id=id, +# ) + +# success_map_fn = lambda x: ( # noqa E731 +# x.full_response.data.query +# if x.full_response +# and x.full_response.data +# and x.full_response.data.query +# else None +# ) +# print(f"query {i}: {success_map_fn(gen_out)}") + +# gen_out.add_successor_map_fn( +# successor=self.retriever, map_fn=success_map_fn +# ) + +# retrieve_out = self.retriever.forward(input=gen_out) + +# def retrieve_out_map_fn(x: adal.Parameter): +# return x.data[0].documents if x.data and x.data[0].documents else [] + +# print(f"retrieve_out: {retrieve_out}") + +# retrieve_out.add_successor_map_fn( +# successor=self.deduplicater, map_fn=retrieve_out_map_fn +# ) + +# context_param = self.deduplicater.forward( +# exisiting_list=context_param, new_list=retrieve_out +# ) + +# context_param.param_type = ParameterType.RETRIEVER_OUTPUT + +# return context_param + +query_generator_task_desc = """Write a simple search query that will help answer a complex question. + +You will receive a context(may contain relevant facts) and a question. +Think step by step.""" + + +class MultiHopRetrieverCycle(adal.Retriever): def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): super().__init__() @@ -89,39 +246,33 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): data_class=QueryRewritterData, return_data_class=True, format_type="json" ) - # Grad Component - self.query_generators: List[adal.Generator] = [] - for i in range(self.max_hops): - self.query_generators.append( - adal.Generator( - name=f"query_generator_{i}", - model_client=model_client, - model_kwargs=model_kwargs, - prompt_kwargs={ - "few_shot_demos": Parameter( - name="few_shot_demos_1", - data=None, - role_desc="To provide few shot demos to the language model", - requires_opt=True, - param_type=ParameterType.DEMOS, - ), - "task_desc_str": Parameter( - name="task_desc_str", - data="""Write a simple search query that will help answer a complex question. + # only one generator which will be used in a loop, called max_hops times + self.query_generator: adal.Generator = adal.Generator( + name="query_generator", + model_client=model_client, + model_kwargs=model_kwargs, + prompt_kwargs={ + "few_shot_demos": Parameter( + name="few_shot_demos", + data=None, + role_desc="To provide few shot demos to the language model", + requires_opt=False, + param_type=ParameterType.DEMOS, + ), + "task_desc_str": Parameter( + name="task_desc_str", + data=query_generator_task_desc, + role_desc="Task description for the language model", + requires_opt=True, + param_type=ParameterType.PROMPT, + ), + "output_format_str": self.data_parser.get_output_format_str(), + }, + template=query_template, + output_processors=self.data_parser, + use_cache=True, + ) -You will receive a context(may contain relevant facts) and a question. -Think step by step.""", - role_desc="Task description for the language model", - requires_opt=True, - param_type=ParameterType.PROMPT, - ), - "output_format_str": self.data_parser.get_output_format_str(), - }, - template=query_template, - output_processors=self.data_parser, - use_cache=True, - ) - ) self.retriever = DspyRetriever(top_k=passages_per_hop) self.deduplicater = DeduplicateList() @@ -129,78 +280,62 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): def context_to_str(context: List[str]) -> str: return "\n".join(context) - @staticmethod - def deduplicate(seq: list[str]) -> list[str]: - """ - Source: https://stackoverflow.com/a/480227/1493011 - """ - - seen = set() - return [x for x in seq if not (x in seen or seen.add(x))] - - def call(self, *, question: str, id: str = None) -> adal.RetrieverOutput: - context = [] - print(f"question: {question}") - for i in range(self.max_hops): - gen_out = self.query_generators[i]( - prompt_kwargs={ - "context": self.context_to_str(context), - "question": question, - }, - id=id, - ) - - query = gen_out.data.query if gen_out.data and gen_out.data.query else None + def call(self, *, input: str, id: str = None) -> List[adal.RetrieverOutput]: + # assemble the foundamental building blocks + printc(f"question: {input}", "yellow") + out = self.forward(input=input, id=id) - print(f"query {i}: {query}") + if not isinstance(out, adal.Parameter): + raise ValueError("The output should be a parameter") - retrieve_out = self.retriever.call(input=query) - passages = retrieve_out[0].documents - context = self.deduplicate(context + passages) - out = [adal.RetrieverOutput(documents=context, query=query, doc_indices=[])] - return out + return out.data # or full response its up to users - def forward(self, *, question: str, id: str = None) -> adal.Parameter: + def forward(self, *, input: str, id: str = None) -> adal.Parameter: # assemble the foundamental building blocks context = [] - print(f"question: {question}") + # queries: List[str] = [] + print(f"question: {input}") # 1. make question a parameter as generator does not have it yet # can create the parameter at the leaf, but not the intermediate nodes question_param = adal.Parameter( name="question", - data=question, + data=input, role_desc="The question to be answered", requires_opt=True, param_type=ParameterType.INPUT, ) - context_param = adal.Parameter( - name="context", - data=context, - role_desc="The context to be used for the query", - requires_opt=True, - param_type=ParameterType.INPUT, - ) - context_param.add_successor_map_fn( - successor=self.query_generators[0], - map_fn=lambda x: self.context_to_str(x.data), - ) + # context_param = adal.Parameter( + # name="context", + # data=context, + # role_desc="The context to be used for the query", + # requires_opt=True, + # param_type=ParameterType.INPUT, + # ) + # context_param.add_successor_map_fn( + # successor=self.query_generator, + # map_fn=lambda x: self.context_to_str(x.data), + # ) for i in range(self.max_hops): - gen_out = self.query_generators[i].forward( + gen_out = self.query_generator.forward( prompt_kwargs={ - "context": context_param, + "context": context, "question": question_param, }, id=id, ) - + # extract the query from the generator output success_map_fn = lambda x: ( # noqa E731 x.full_response.data.query if x.full_response and x.full_response.data and x.full_response.data.query - else None + else ( + x.full_response.raw_response + if x.full_response and x.full_response.raw_response + else None + ) ) print(f"query {i}: {success_map_fn(gen_out)}") @@ -208,27 +343,41 @@ def forward(self, *, question: str, id: str = None) -> adal.Parameter: successor=self.retriever, map_fn=success_map_fn ) - retrieve_out = self.retriever.forward(input=gen_out) + # retrieve the passages + retrieve_out: adal.Parameter = self.retriever.forward(input=gen_out, id=id) def retrieve_out_map_fn(x: adal.Parameter): return x.data[0].documents if x.data and x.data[0].documents else [] - print(f"retrieve_out: {retrieve_out}") - + # add the map function to the retrieve_out retrieve_out.add_successor_map_fn( successor=self.deduplicater, map_fn=retrieve_out_map_fn ) - context_param = self.deduplicater.forward( - exisiting_list=context_param, new_list=retrieve_out + # combine the context + deduplicated passages + context = self.deduplicater.forward( + exisiting_list=context, new_list=retrieve_out, id=id ) - context_param.param_type = ParameterType.RETRIEVER_OUTPUT + context.param_type = ParameterType.RETRIEVER_OUTPUT + # used as the final outptu - return context_param + # convert the context to the retriever output + def context_to_retrover_output(x): + return [ + adal.RetrieverOutput( + documents=x.data, + query=[input] + [success_map_fn(gen_out)], + doc_indices=[], + ) + ] + + context.data = context_to_retrover_output(context) + + return context -class MultiHopRetriever2(adal.Retriever): +class MultiHopRetriever(adal.Retriever): def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2): super().__init__() @@ -395,13 +544,13 @@ def context_to_retrover_output(x): context.data = context_to_retrover_output(context) - printc(f"MultiHopRetriever2 grad fn: {context.grad_fn}", "yellow") + printc(f"MultiHopRetriever grad fn: {context.grad_fn}", "yellow") return context def backward(self, *args, **kwargs): - printc(f"MultiHopRetriever2 backward: {args}", "yellow") + printc(f"MultiHopRetriever backward: {args}", "yellow") super().backward(*args, **kwargs) return @@ -418,7 +567,24 @@ def __init__( model_client=model_client, model_kwargs=model_kwargs, ) - self.retriever = MultiHopRetriever2( + self.retriever = MultiHopRetriever( + model_client=model_client, + model_kwargs=model_kwargs, + passages_per_hop=passages_per_hop, + max_hops=max_hops, + ) + + +class MultiHopRAGCycle(VanillaRAG): + def __init__( + self, passages_per_hop=3, max_hops=2, model_client=None, model_kwargs=None + ): + super().__init__( + passages_per_hop=passages_per_hop, + model_client=model_client, + model_kwargs=model_kwargs, + ) + self.retriever = MultiHopRetrieverCycle( model_client=model_client, model_kwargs=model_kwargs, passages_per_hop=passages_per_hop, @@ -451,13 +617,40 @@ def test_multi_hop_retriever(): output.draw_graph() +def test_multi_hop_retriever_cycle(): + + from use_cases.config import ( + gpt_3_model, + ) + + multi_hop_retriever = MultiHopRetrieverCycle( + **gpt_3_model, + passages_per_hop=3, + max_hops=2, + ) + + question = "How many storeys are in the castle that David Gregory inherited?" + + # eval mode + output = multi_hop_retriever.call(input=question, id="1") + print(output) + + # train mode + multi_hop_retriever.train() + output = multi_hop_retriever.forward(input=question, id="1") + print(output) + output.draw_graph() + output.draw_output_subgraph() + output.draw_component_subgraph() + + def test_multi_hop_retriever2(): from use_cases.config import ( gpt_3_model, ) - multi_hop_retriever = MultiHopRetriever2( + multi_hop_retriever = MultiHopRetriever( **gpt_3_model, passages_per_hop=3, max_hops=2, @@ -531,4 +724,6 @@ def test_multi_hop_rag(): # get_logger(level="DEBUG") # test_multi_hop_retriever() # test_multi_hop_retriever2() - test_multi_hop_rag() + + test_multi_hop_retriever_cycle() + # test_multi_hop_rag() diff --git a/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py b/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py index 3eae0598..fd4a1dcd 100644 --- a/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py +++ b/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py @@ -161,7 +161,7 @@ def __init__(self, passages_per_hop=3, model_client=None, model_kwargs=None): ), "few_shot_demos": adal.Parameter( data=None, - requires_opt=True, + requires_opt=None, role_desc="To provide few shot demos to the language model", param_type=adal.ParameterType.DEMOS, ), diff --git a/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py new file mode 100644 index 00000000..376ef664 --- /dev/null +++ b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py @@ -0,0 +1,163 @@ +from typing import Any, Callable, Dict, Tuple + +import adalflow as adal +from adalflow.eval.answer_match_acc import AnswerMatchAcc +from adalflow.datasets.types import HotPotQAData + +from benchmarks.hotpot_qa._adal_train import load_datasets +from benchmarks.hotpot_qa.adal_exp.build_multi_hop_rag import MultiHopRAGCycle +from use_cases.config import gpt_3_model, gpt_4o_model + + +# TODO: look more into the loss function +# TODO: test LLM judge too. +class MultiHopRAGAdal(adal.AdalComponent): + def __init__( + self, + model_client: adal.ModelClient, + model_kwargs: Dict, + backward_engine_model_config: Dict | None = None, + teacher_model_config: Dict | None = None, + text_optimizer_model_config: Dict | None = None, + ): + task = MultiHopRAGCycle( + model_client=model_client, + model_kwargs=model_kwargs, + passages_per_hop=3, + max_hops=2, + ) + eval_fn = AnswerMatchAcc(type="fuzzy_match").compute_single_item + loss_fn = adal.EvalFnToTextLoss( + eval_fn=eval_fn, eval_fn_desc="fuzzy_match: 1 if str(y) in str(y_gt) else 0" + ) + super().__init__( + task=task, + eval_fn=eval_fn, + loss_fn=loss_fn, + backward_engine_model_config=backward_engine_model_config, + teacher_model_config=teacher_model_config, + text_optimizer_model_config=text_optimizer_model_config, + ) + + # tell the trainer how to call the task + def prepare_task(self, sample: HotPotQAData) -> Tuple[Callable[..., Any], Dict]: + if self.task.training: + return self.task.forward, {"question": sample.question, "id": sample.id} + else: + return self.task.call, {"question": sample.question, "id": sample.id} + + # TODO: use two map fn to make the cde even simpler + + # eval mode: get the generator output, directly engage with the eval_fn + def prepare_eval(self, sample: HotPotQAData, y_pred: adal.GeneratorOutput) -> float: + y_label = "" + if y_pred and y_pred.data and y_pred.data.answer: + y_label = y_pred.data.answer + return self.eval_fn, {"y": y_label, "y_gt": sample.answer} + + # train mode: get the loss and get the data from the full_response + def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter): + # prepare gt parameter + y_gt = adal.Parameter( + name="y_gt", + data=sample.answer, + eval_input=sample.answer, + requires_opt=False, + ) + + # pred's full_response is the output of the task pipeline which is GeneratorOutput + pred.eval_input = ( + pred.full_response.data.answer + if pred.full_response + and pred.full_response.data + and pred.full_response.data.answer + else "" + ) + return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}} + + +# Note: diagnose is quite helpful, it helps you to quickly check if the evalfunction is the right metrics +# i checked the eval which does fuzzy match, and found some yes and Yes are not matched, then converted both strings to lower and +# the performances have gone up from 0.15 to 0.4 +def train_diagnose( + model_client: adal.ModelClient, + model_kwargs: Dict, +) -> Dict: + + trainset, valset, testset = load_datasets() + + adal_component = MultiHopRAGAdal( + model_client, + model_kwargs, + backward_engine_model_config=gpt_4o_model, + teacher_model_config=gpt_3_model, + text_optimizer_model_config=gpt_3_model, + ) + trainer = adal.Trainer(adaltask=adal_component) + trainer.diagnose(dataset=trainset, split="train") + # trainer.diagnose(dataset=valset, split="val") + # trainer.diagnose(dataset=testset, split="test") + + +def train( + train_batch_size=4, # larger batch size is not that effective, probably because of llm's lost in the middle + raw_shots: int = 0, + bootstrap_shots: int = 4, + max_steps=1, + num_workers=4, + strategy="constrained", + optimization_order="sequential", + debug=False, + resume_from_ckpt=None, + exclude_input_fields_from_bootstrap_demos=True, +): + adal_component = MultiHopRAGAdal( + **gpt_3_model, + teacher_model_config=gpt_3_model, + text_optimizer_model_config=gpt_4o_model, # gpt3.5 is not enough to be used as a good optimizer, it struggles for long contenxt + backward_engine_model_config=gpt_4o_model, + ) + print(adal_component) + trainer = adal.Trainer( + train_batch_size=train_batch_size, + adaltask=adal_component, + strategy=strategy, + max_steps=max_steps, + num_workers=num_workers, + raw_shots=raw_shots, + bootstrap_shots=bootstrap_shots, + debug=debug, + weighted_sampling=True, + optimization_order=optimization_order, + exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos, + sequential_order=["text", "demo"], + ) + print(trainer) + + train_dataset, val_dataset, test_dataset = load_datasets() + trainer.fit( + train_dataset=train_dataset, + val_dataset=val_dataset, + test_dataset=test_dataset, + resume_from_ckpt=resume_from_ckpt, + ) + + +if __name__ == "__main__": + from use_cases.config import gpt_3_model + + log = adal.get_logger(level="DEBUG", enable_console=False) + + adal.setup_env() + + # task = MultiHopRAGAdal(**gpt_3_model) + # print(task) + + # train_diagnose(**gpt_3_model) + + # train: 0.15 before the evaluator converted to lower and 0.4 after the conversion + train( + debug=False, + max_steps=12, + # resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json", + )