From 2aac4ffbec9f5c6e3f51324a35e96865a6ff2188 Mon Sep 17 00:00:00 2001
From: Li Yin
Date: Mon, 16 Dec 2024 23:47:47 -0800
Subject: [PATCH] build cycle multi-hop rag, trainable, added component graph
with cycle, output graph as simplified version of the dynamic computation
graph
---
adalflow/adalflow/core/generator.py | 7 +-
adalflow/adalflow/core/retriever.py | 8 +-
adalflow/adalflow/optim/grad_component.py | 12 +-
adalflow/adalflow/optim/parameter.py | 356 +++++++++++++++-
.../hotpot_qa/adal_exp/build_multi_hop_rag.py | 381 +++++++++++++-----
.../hotpot_qa/adal_exp/build_vanilla_rag.py | 2 +-
.../adal_exp/train_multi_hop_rag_cycle.py | 163 ++++++++
7 files changed, 815 insertions(+), 114 deletions(-)
create mode 100644 benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py
diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py
index baedd8fb..63a85ba6 100644
--- a/adalflow/adalflow/core/generator.py
+++ b/adalflow/adalflow/core/generator.py
@@ -506,6 +506,7 @@ def forward(
self.model_kwargs, model_kwargs
),
}
+
output = self.call(**input_args, id=id)
# 2. Generate a Parameter object from the output
combined_prompt_kwargs = compose_model_kwargs(self.prompt_kwargs, prompt_kwargs)
@@ -527,9 +528,12 @@ def forward(
name=self.name + "_output",
role_desc=f"Output from (llm) {self.name}",
param_type=ParameterType.GENERATOR_OUTPUT,
+ data_id=id,
)
response.set_predecessors(predecessors)
- response.trace_forward_pass(input_args=input_args, full_response=output)
+ response.trace_forward_pass(
+ input_args=input_args, full_response=output, id=self.id, name=self.name
+ )
# *** special to the generator ***
response.trace_api_kwargs(api_kwargs=self._trace_api_kwargs)
# attach the demo to the demo parameter
@@ -755,6 +759,7 @@ def _backward_through_one_predecessor(
score=response._score, # add score to gradient
param_type=ParameterType.GRADIENT,
from_response_id=response.id,
+ data_id=response.data_id,
)
pred.add_gradient(var_gradient)
pred.set_score(response._score)
diff --git a/adalflow/adalflow/core/retriever.py b/adalflow/adalflow/core/retriever.py
index fb65a298..1e7e40d3 100644
--- a/adalflow/adalflow/core/retriever.py
+++ b/adalflow/adalflow/core/retriever.py
@@ -128,7 +128,13 @@ def forward(
)
if input is None:
raise ValueError("Input cannot be empty")
- response = super().forward(input, top_k=top_k, **kwargs)
+ response = super().forward(input, top_k=top_k, id=id, **kwargs)
+ response.trace_forward_pass(
+ input_args={"input": input, "top_k": top_k},
+ full_response=response,
+ id=self.id,
+ name=self.name,
+ )
response.param_type = (
ParameterType.RETRIEVER_OUTPUT
) # be more specific about the type
diff --git a/adalflow/adalflow/optim/grad_component.py b/adalflow/adalflow/optim/grad_component.py
index b73e536e..b9964aba 100644
--- a/adalflow/adalflow/optim/grad_component.py
+++ b/adalflow/adalflow/optim/grad_component.py
@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING
from collections import OrderedDict
+import uuid
import logging
if TYPE_CHECKING:
@@ -13,6 +14,7 @@
from adalflow.core.component import Component
from adalflow.optim.function import BackwardContext
+
__all__ = ["GradComponent"]
log = logging.getLogger(__name__)
@@ -31,10 +33,12 @@ class GradComponent(Component):
"""
backward_engine: "BackwardEngine"
_component_type = "grad"
+ id = None
def __init__(self, *args, **kwargs):
super().__init__()
super().__setattr__("backward_engine", None)
+ super().__setattr__("id", str(uuid.uuid4()))
def __call__(self, *args, **kwargs):
if self.training:
@@ -122,9 +126,15 @@ def forward(self, *args, **kwargs) -> "Parameter":
name=self.name + "_output",
role_desc=self.name + " response",
param_type=ParameterType.OUTPUT,
+ data_id=kwargs.get("id", None),
)
response.set_predecessors(predecessors)
- response.trace_forward_pass(input_args=input_args, full_response=call_response)
+ response.trace_forward_pass(
+ input_args=input_args,
+ full_response=call_response,
+ id=self.id,
+ name=self.name,
+ )
response.set_grad_fn(
BackwardContext(
backward_fn=self.backward,
diff --git a/adalflow/adalflow/optim/parameter.py b/adalflow/adalflow/optim/parameter.py
index 5b60995c..1e553302 100644
--- a/adalflow/adalflow/optim/parameter.py
+++ b/adalflow/adalflow/optim/parameter.py
@@ -46,8 +46,18 @@ class GradientContext:
)
+@dataclass(frozen=True)
+class ComponentNode:
+ """Immutable, hashable (id, name) record for one component in the component graph."""
+
+ id: str = field(metadata={"desc": "The unique id of the component"})
+ name: str = field(metadata={"desc": "The name of the component"})
+
+
@dataclass
class ComponentTrace:
+ name: str = field(metadata={"desc": "The name of the component"}, default=None)
+ id: str = field(metadata={"desc": "The unique id of the component"}, default=None)
input_args: Dict[str, Any] = field(
metadata={"desc": "The input arguments of the GradComponent forward"},
default=None,
@@ -159,7 +169,7 @@ class Parameter(Generic[T]):
def __init__(
self,
*,
- id: Optional[str] = None,
+ id: Optional[str] = None, # unique id of the parameter
data: T = None, # for generator output, the data will be set up as raw_response
data_id: str = None, # for tracing the data item in the training/val/test set
requires_opt: bool = True,
@@ -300,13 +310,21 @@ def trace_optimizer(self, api_kwargs: Dict[str, Any], response: "TGDData"):
############################################################################################################
# Trace component, include trace_forward_pass & trace_api_kwargs for now
############################################################################################################
- def trace_forward_pass(self, input_args: Dict[str, Any], full_response: object):
- r"""Trace the forward pass of the parameter."""
+ def trace_forward_pass(
+ self,
+ input_args: Dict[str, Any],
+ full_response: object,
+ id: str = None,
+ name: str = None,
+ ):
+ r"""Trace the forward pass of the parameter and record the id and name of the component that produced it."""
self.input_args = input_args
self.full_response = full_response
# TODO: remove the input_args and full_response to use component_trace
self.component_trace.input_args = input_args
self.component_trace.full_response = full_response
+ self.component_trace.id = id
+ self.component_trace.name = name
def trace_api_kwargs(self, api_kwargs: Dict[str, Any]):
r"""Trace the api_kwargs for components like Generator and Retriever that pass to the model client."""
@@ -515,20 +533,20 @@ def build_graph(node: "Parameter"):
build_graph(root)
return nodes, edges
- def report_cycle(cycle_nodes: List["Parameter"]):
- """
- Report the detected cycle and provide guidance to the user on how to avoid it.
- """
- cycle_names = [node.name for node in cycle_nodes]
- log.warning(f"Cycle detected: {' -> '.join(cycle_names)}")
- print(f"Cycle detected in the graph: {' -> '.join(cycle_names)}")
-
- # Provide guidance on how to avoid the cycle
- print("To avoid the cycle, consider the following strategies:")
- print("- Modify the graph structure to remove cyclic dependencies.")
- print(
- "- Check the relationships between these nodes to ensure no feedback loops."
- )
+ # def report_cycle(cycle_nodes: List["Parameter"]):
+ # """
+ # Report the detected cycle and provide guidance to the user on how to avoid it.
+ # """
+ # cycle_names = [node.name for node in cycle_nodes]
+ # log.warning(f"Cycle detected: {' -> '.join(cycle_names)}")
+ # print(f"Cycle detected in the graph: {' -> '.join(cycle_names)}")
+
+ # # Provide guidance on how to avoid the cycle
+ # print("To avoid the cycle, consider the following strategies:")
+ # print("- Modify the graph structure to remove cyclic dependencies.")
+ # print(
+ # "- Check the relationships between these nodes to ensure no feedback loops."
+ # )
def backward(
self,
@@ -799,6 +817,7 @@ def wrap_and_escape(text, width=40):
node_label = (
f""
+ f"Name: | {wrap_and_escape(n.id)} |
"
f"Name: | {wrap_and_escape(n.name)} |
"
f"Role: | {wrap_and_escape(n.role_desc.capitalize())} |
"
f"Value: | {wrap_and_escape(n.data)} |
"
@@ -838,6 +857,12 @@ def wrap_and_escape(text, width=40):
if n.tgd_optimizer_trace is not None:
node_label += f"TGD Optimizer Trace: | {wrap_and_escape(str(n.tgd_optimizer_trace))} |
"
+ # show component trace, id and name
+ if n.component_trace.id is not None:
+ node_label += f"Component Trace ID: | {wrap_and_escape(str(n.component_trace.id))} |
"
+ if n.component_trace.name is not None:
+ node_label += f"Component Trace Name: | {wrap_and_escape(str(n.component_trace.name))} |
"
+
node_label += "
"
# check if the name exists in dot
if n.name in node_names:
@@ -906,6 +931,303 @@ def wrap_and_escape(text, width=40):
)
return {"graph_path": filepath, "root_path": f"{filepath}_root.json"}
+ def draw_output_subgraph(
+ self,
+ add_grads: bool = True,
+ format: str = "png",
+ rankdir: str = "TB",
+ filepath: str = None,
+ ):
+ """
+ Build and visualize a subgraph containing only OUTPUT parameters.
+
+ Args:
+ add_grads (bool): Reserved for gradient edges; currently not used by this method.
+ format (str): Format for output (e.g., png, svg).
+ rankdir (str): Graph layout direction ("LR" or "TB").
+ filepath (str): Path to save the graph.
+ """
+ assert rankdir in ["LR", "TB"]
+ from adalflow.utils.global_config import get_adalflow_default_root_path
+
+ try:
+ from graphviz import Digraph
+
+ except ImportError as e:
+ raise ImportError(
+ "Please install graphviz using 'pip install graphviz' to use this feature"
+ ) from e
+
+ root_path = get_adalflow_default_root_path()
+ # # prepare the log directory
+ # log_dir = os.path.join(root_path, "logs")
+
+ # # Set up TensorBoard logging
+ # writer = SummaryWriter(log_dir)
+
+ filename = f"trace_component_output_graph_{self.name}_id_{self.id}"
+ filepath = (
+ os.path.join(filepath, filename)
+ if filepath
+ else os.path.join(root_path, "graphs", filename)
+ )
+ print(f"Saving graph to {filepath}.{format}")
+
+ # Step 1: Collect OUTPUT nodes and edges
+ nodes, edges = self._collect_output_subgraph()
+
+ # Step 2: Render using Graphviz
+ filename = f"output_subgraph_{self.name}_{self.id}"
+ filepath = filepath or f"./{filename}"
+ print(f"Saving OUTPUT subgraph to {filepath}.{format}")
+
+ dot = Digraph(format=format, graph_attr={"rankdir": rankdir})
+ node_ids = set()
+
+ for node in nodes:
+ node_label = f"""
+
+ Name: | {node.name} |
+ Type: | {node.param_type} |
+ Value: | {node.get_short_value()} |
"""
+ # add the component trace id and name
+ if node.component_trace.id is not None:
+ node_label += f"Component Trace ID: | {node.component_trace.id} |
"
+ if node.component_trace.name is not None:
+ node_label += f"Component Trace Name: | {node.component_trace.name} |
"
+
+ node_label += "
"
+ dot.node(
+ name=node.id,
+ label=f"<{node_label}>",
+ shape="plaintext",
+ color="lightblue" if node.requires_opt else "gray",
+ )
+ node_ids.add(node.id)
+
+ for source, target in edges:
+ if source.id in node_ids and target.id in node_ids:
+ dot.edge(source.id, target.id)
+
+ # Step 3: Save and render
+ dot.render(filepath, cleanup=True)
+ print(f"Graph saved as {filepath}.{format}")
+
+ def draw_component_subgraph(
+ self,
+ format: str = "png",
+ rankdir: str = "TB",
+ filepath: str = None,
+ ):
+ """
+ Build and visualize the component-level subgraph: one node per traced component.
+
+ Args:
+ format (str): Format for output (e.g., png, svg).
+ rankdir (str): Graph layout direction ("LR" or "TB").
+ filepath (str): Path to save the graph.
+ """
+ assert rankdir in ["LR", "TB"]
+
+ try:
+ from graphviz import Digraph
+ except ImportError as e:
+ raise ImportError(
+ "Please install graphviz using 'pip install graphviz' to use this feature"
+ ) from e
+
+ # Step 1: Collect component nodes, edges and per-component orders
+ component_nodes, edges, component_nodes_orders = (
+ self._collect_component_subgraph()
+ )
+
+ # Step 2: Setup graph rendering
+ filename = f"output_component_{self.name}_{self.id}"
+ filepath = filepath or f"./{filename}"
+ print(f"Saving OUTPUT subgraph to {filepath}.{format}")
+
+ dot = Digraph(format=format, graph_attr={"rankdir": rankdir})
+
+ # Add nodes
+ for node in component_nodes:
+ node_label = f"""
+
+ ID: | {node.id} |
+ Name: | {node.name} |
"""
+
+ # add the list of orders
+ if node.id in component_nodes_orders:
+ node_label += f"Order: | {component_nodes_orders[node.id]} |
"
+ node_label += "
"
+ dot.node(
+ name=node.id,
+ label=f"<{node_label}>",
+ shape="plaintext",
+ color="lightblue",
+ )
+
+ # Add edges with order labels
+ for source_id, target_id, edge_order in edges:
+ dot.edge(source_id, target_id, label=str(edge_order), color="black")
+
+ # Step 3: Save and render
+ dot.render(filepath, cleanup=True)
+ print(f"Graph saved as {filepath}.{format}")
+
+    def _collect_output_subgraph(
+        self,
+    ) -> Tuple[Set["Parameter"], List[Tuple["Parameter", "Parameter"]]]:
+        """
+        Walk the predecessor graph from ``self`` and keep the OUTPUT-like part.
+
+        A parameter counts as OUTPUT-like when its ``param_type`` equals
+        ``ParameterType.OUTPUT`` or ``param_type.name`` contains "OUTPUT".
+
+        Returns:
+            nodes (Set[Parameter]): Every reachable OUTPUT-like parameter.
+            edges (List[Tuple[Parameter, Parameter]]): ``(pred, node)`` pairs,
+                recorded whenever the predecessor is OUTPUT-like.
+        """
+
+        def is_output(param: "Parameter") -> bool:
+            kind = param.param_type
+            return kind == ParameterType.OUTPUT or "OUTPUT" in kind.name
+
+        seen = set()  # parameters already expanded; guards against cycles
+        kept_nodes = set()
+        kept_edges = []
+
+        def walk(current: "Parameter"):
+            if current in seen:
+                return
+            seen.add(current)
+            if is_output(current):
+                kept_nodes.add(current)
+
+            # depth-first over predecessors, in their iteration order
+            for parent in current.predecessors:
+                if is_output(parent):
+                    kept_edges.append((parent, current))
+                walk(parent)
+
+        walk(self)
+        return kept_nodes, kept_edges
+
+ # def _collect_output_subgraph(
+ # self,
+ # ) -> Tuple[Set[Tuple[str, str]], List[Tuple[str, str]]]:
+ # """
+ # Collect OUTPUT nodes and their relationships using component_trace information.
+
+ # Returns:
+ # nodes (Set[Tuple[str, str]]): Set of component nodes (component_id, label).
+ # edges (List[Tuple[str, str]]): Edges between component IDs.
+ # """
+ # component_nodes = set() # To store component nodes as (component_id, label)
+ # edges = [] # To store edges between components
+
+ # visited = set() # Track visited parameters to avoid cycles
+
+ # def traverse(node: "Parameter"):
+ # if node in visited:
+ # return
+ # visited.add(node)
+
+ # # Only consider OUTPUT-type parameters
+ # if (
+ # node.param_type == ParameterType.OUTPUT
+ # or "OUTPUT" in node.param_type.name
+ # ):
+ # component_id = node.component_trace.id
+ # component_name = node.component_trace.name or "Unknown Component"
+ # label = f"{component_name}\nID: {component_id}"
+
+ # # Add the component as a node
+ # component_nodes.add((component_id, label))
+
+ # # Traverse predecessors and add edges
+ # for pred in node.predecessors:
+ # if pred.param_type == ParameterType.OUTPUT:
+ # pred_id = pred.component_trace.id
+ # pred_name = pred.component_trace.name or "Unknown Component"
+
+ # # Add predecessor as a node
+ # pred_label = f"{pred_name}\nID: {pred_id}"
+ # component_nodes.add((pred_id, pred_label))
+
+ # # Add edge between components
+ # edges.append((pred_id, component_id))
+
+ # # Recursive traversal
+ # traverse(pred)
+
+ # # Start traversal from the current parameter
+ # traverse(self)
+ # return component_nodes, edges
+
+    def _collect_component_subgraph(
+        self,
+    ) -> Tuple[Set[ComponentNode], List[Tuple[str, str, int]], Dict[str, List[int]]]:
+        """
+        Collapse the OUTPUT-like parameters reachable from ``self`` into components.
+
+        Each OUTPUT-like parameter is mapped to the component that produced it via
+        ``component_trace.id`` / ``component_trace.name``. A parameter without a
+        traced id gets one placeholder id that is reused everywhere it appears, so
+        its node and the edges that mention it stay consistent.
+
+        Returns:
+            component_nodes (Set[ComponentNode]): Component nodes (id and name only).
+            edges (List[Tuple[str, str, int]]): ``(pred_component_id,
+                component_id, depth)`` where ``depth`` is the traversal depth of
+                the consuming parameter (0 for ``self``).
+            component_nodes_orders (Dict[str, List[int]]): Per-component order
+                lists; currently never populated and returned empty.
+        """
+        component_nodes = set()  # To store component nodes as ComponentNode
+        component_nodes_orders: Dict[str, List[int]] = {}
+        edges = []  # To store (source id, target id, depth) triples
+
+        visited = set()  # Track visited parameters to avoid cycles
+        fallback_ids: Dict["Parameter", str] = {}  # placeholder ids, one per parameter
+
+        def is_output(param: "Parameter") -> bool:
+            kind = param.param_type
+            return kind == ParameterType.OUTPUT or "OUTPUT" in kind.name
+
+        def as_component(param: "Parameter") -> ComponentNode:
+            # Resolve the component behind a parameter, with stable fallbacks
+            trace_id = param.component_trace.id
+            if not trace_id:
+                if param not in fallback_ids:
+                    fallback_ids[param] = f"unknown_id_{uuid.uuid4()}"
+                trace_id = fallback_ids[param]
+            name = param.component_trace.name or "Unknown Component"
+            return ComponentNode(id=trace_id, name=name)
+
+        def traverse(node: "Parameter", depth: int):
+            if node in visited:
+                return
+            visited.add(node)
+
+            if not is_output(node):
+                return
+
+            component_node = as_component(node)
+            component_nodes.add(component_node)
+
+            # Traverse predecessors; only OUTPUT-like ones contribute an edge
+            for pred in node.predecessors:
+                if is_output(pred):
+                    pred_node = as_component(pred)
+                    edges.append((pred_node.id, component_node.id, depth))
+                    component_nodes.add(pred_node)
+                traverse(pred, depth + 1)
+
+        # Start traversal from the current parameter
+        traverse(self, depth=0)
+        return component_nodes, edges, component_nodes_orders
+
def to_dict(self):
return {
"name": self.name,
diff --git a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py
index cebcfdf2..ba3f8bfc 100644
--- a/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py
+++ b/benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py
@@ -62,7 +62,9 @@ class DeduplicateList(adal.GradComponent):
def __init__(self):
super().__init__()
- def call(self, exisiting_list: List[str], new_list: List[str]) -> List[str]:
+ def call(
+ self, exisiting_list: List[str], new_list: List[str], id: str = None
+ ) -> List[str]:
seen = set()
return [x for x in exisiting_list + new_list if not (x in seen or seen.add(x))]
@@ -78,7 +80,162 @@ def backward(self, *args, **kwargs):
# NOTE: deprecated
-class MultiHopRetriever(adal.Retriever):
+# class MultiHopRetriever(adal.Retriever):
+# def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
+# super().__init__()
+
+# self.passages_per_hop = passages_per_hop
+# self.max_hops = max_hops
+
+# self.data_parser = adal.DataClassParser(
+# data_class=QueryRewritterData, return_data_class=True, format_type="json"
+# )
+
+# # Grad Component
+# self.query_generators: List[adal.Generator] = []
+# for i in range(self.max_hops):
+# self.query_generators.append(
+# adal.Generator(
+# name=f"query_generator_{i}",
+# model_client=model_client,
+# model_kwargs=model_kwargs,
+# prompt_kwargs={
+# "few_shot_demos": Parameter(
+# name="few_shot_demos_1",
+# data=None,
+# role_desc="To provide few shot demos to the language model",
+# requires_opt=True,
+# param_type=ParameterType.DEMOS,
+# ),
+# "task_desc_str": Parameter(
+# name="task_desc_str",
+# data="""Write a simple search query that will help answer a complex question.
+
+# You will receive a context(may contain relevant facts) and a question.
+# Think step by step.""",
+# role_desc="Task description for the language model",
+# requires_opt=True,
+# param_type=ParameterType.PROMPT,
+# ),
+# "output_format_str": self.data_parser.get_output_format_str(),
+# },
+# template=query_template,
+# output_processors=self.data_parser,
+# use_cache=True,
+# )
+# )
+# self.retriever = DspyRetriever(top_k=passages_per_hop)
+# self.deduplicater = DeduplicateList()
+
+# @staticmethod
+# def context_to_str(context: List[str]) -> str:
+# return "\n".join(context)
+
+# @staticmethod
+# def deduplicate(seq: list[str]) -> list[str]:
+# """
+# Source: https://stackoverflow.com/a/480227/1493011
+# """
+
+# seen = set()
+# return [x for x in seq if not (x in seen or seen.add(x))]
+
+# def call(self, *, question: str, id: str = None) -> adal.RetrieverOutput:
+# context = []
+# print(f"question: {question}")
+# for i in range(self.max_hops):
+# gen_out = self.query_generators[i](
+# prompt_kwargs={
+# "context": self.context_to_str(context),
+# "question": question,
+# },
+# id=id,
+# )
+
+# query = gen_out.data.query if gen_out.data and gen_out.data.query else None
+
+# print(f"query {i}: {query}")
+
+# retrieve_out = self.retriever.call(input=query)
+# passages = retrieve_out[0].documents
+# context = self.deduplicate(context + passages)
+# out = [adal.RetrieverOutput(documents=context, query=query, doc_indices=[])]
+# return out
+
+# def forward(self, *, question: str, id: str = None) -> adal.Parameter:
+# # assemble the foundamental building blocks
+# context = []
+# print(f"question: {question}")
+# # 1. make question a parameter as generator does not have it yet
+# # can create the parameter at the leaf, but not the intermediate nodes
+# question_param = adal.Parameter(
+# name="question",
+# data=question,
+# role_desc="The question to be answered",
+# requires_opt=True,
+# param_type=ParameterType.INPUT,
+# )
+# context_param = adal.Parameter(
+# name="context",
+# data=context,
+# role_desc="The context to be used for the query",
+# requires_opt=True,
+# param_type=ParameterType.INPUT,
+# )
+# context_param.add_successor_map_fn(
+# successor=self.query_generators[0],
+# map_fn=lambda x: self.context_to_str(x.data),
+# )
+
+# for i in range(self.max_hops):
+
+# gen_out = self.query_generators[i].forward(
+# prompt_kwargs={
+# "context": context_param,
+# "question": question_param,
+# },
+# id=id,
+# )
+
+# success_map_fn = lambda x: ( # noqa E731
+# x.full_response.data.query
+# if x.full_response
+# and x.full_response.data
+# and x.full_response.data.query
+# else None
+# )
+# print(f"query {i}: {success_map_fn(gen_out)}")
+
+# gen_out.add_successor_map_fn(
+# successor=self.retriever, map_fn=success_map_fn
+# )
+
+# retrieve_out = self.retriever.forward(input=gen_out)
+
+# def retrieve_out_map_fn(x: adal.Parameter):
+# return x.data[0].documents if x.data and x.data[0].documents else []
+
+# print(f"retrieve_out: {retrieve_out}")
+
+# retrieve_out.add_successor_map_fn(
+# successor=self.deduplicater, map_fn=retrieve_out_map_fn
+# )
+
+# context_param = self.deduplicater.forward(
+# exisiting_list=context_param, new_list=retrieve_out
+# )
+
+# context_param.param_type = ParameterType.RETRIEVER_OUTPUT
+
+# return context_param
+
+query_generator_task_desc = """Write a simple search query that will help answer a complex question.
+
+You will receive a context(may contain relevant facts) and a question.
+Think step by step."""
+
+
+class MultiHopRetrieverCycle(adal.Retriever):
def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
super().__init__()
@@ -89,39 +246,33 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
data_class=QueryRewritterData, return_data_class=True, format_type="json"
)
- # Grad Component
- self.query_generators: List[adal.Generator] = []
- for i in range(self.max_hops):
- self.query_generators.append(
- adal.Generator(
- name=f"query_generator_{i}",
- model_client=model_client,
- model_kwargs=model_kwargs,
- prompt_kwargs={
- "few_shot_demos": Parameter(
- name="few_shot_demos_1",
- data=None,
- role_desc="To provide few shot demos to the language model",
- requires_opt=True,
- param_type=ParameterType.DEMOS,
- ),
- "task_desc_str": Parameter(
- name="task_desc_str",
- data="""Write a simple search query that will help answer a complex question.
+ # only one generator which will be used in a loop, called max_hops times
+ self.query_generator: adal.Generator = adal.Generator(
+ name="query_generator",
+ model_client=model_client,
+ model_kwargs=model_kwargs,
+ prompt_kwargs={
+ "few_shot_demos": Parameter(
+ name="few_shot_demos",
+ data=None,
+ role_desc="To provide few shot demos to the language model",
+ requires_opt=False,
+ param_type=ParameterType.DEMOS,
+ ),
+ "task_desc_str": Parameter(
+ name="task_desc_str",
+ data=query_generator_task_desc,
+ role_desc="Task description for the language model",
+ requires_opt=True,
+ param_type=ParameterType.PROMPT,
+ ),
+ "output_format_str": self.data_parser.get_output_format_str(),
+ },
+ template=query_template,
+ output_processors=self.data_parser,
+ use_cache=True,
+ )
-You will receive a context(may contain relevant facts) and a question.
-Think step by step.""",
- role_desc="Task description for the language model",
- requires_opt=True,
- param_type=ParameterType.PROMPT,
- ),
- "output_format_str": self.data_parser.get_output_format_str(),
- },
- template=query_template,
- output_processors=self.data_parser,
- use_cache=True,
- )
- )
self.retriever = DspyRetriever(top_k=passages_per_hop)
self.deduplicater = DeduplicateList()
@@ -129,78 +280,62 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
def context_to_str(context: List[str]) -> str:
return "\n".join(context)
- @staticmethod
- def deduplicate(seq: list[str]) -> list[str]:
- """
- Source: https://stackoverflow.com/a/480227/1493011
- """
-
- seen = set()
- return [x for x in seq if not (x in seen or seen.add(x))]
-
- def call(self, *, question: str, id: str = None) -> adal.RetrieverOutput:
- context = []
- print(f"question: {question}")
- for i in range(self.max_hops):
- gen_out = self.query_generators[i](
- prompt_kwargs={
- "context": self.context_to_str(context),
- "question": question,
- },
- id=id,
- )
-
- query = gen_out.data.query if gen_out.data and gen_out.data.query else None
+ def call(self, *, input: str, id: str = None) -> List[adal.RetrieverOutput]:
+ # assemble the fundamental building blocks
+ printc(f"question: {input}", "yellow")
+ out = self.forward(input=input, id=id)
- print(f"query {i}: {query}")
+ if not isinstance(out, adal.Parameter):
+ raise ValueError("The output should be a parameter")
- retrieve_out = self.retriever.call(input=query)
- passages = retrieve_out[0].documents
- context = self.deduplicate(context + passages)
- out = [adal.RetrieverOutput(documents=context, query=query, doc_indices=[])]
- return out
+ return out.data # or full response its up to users
- def forward(self, *, question: str, id: str = None) -> adal.Parameter:
+ def forward(self, *, input: str, id: str = None) -> adal.Parameter:
# assemble the foundamental building blocks
context = []
- print(f"question: {question}")
+ # queries: List[str] = []
+ print(f"question: {input}")
# 1. make question a parameter as generator does not have it yet
# can create the parameter at the leaf, but not the intermediate nodes
question_param = adal.Parameter(
name="question",
- data=question,
+ data=input,
role_desc="The question to be answered",
requires_opt=True,
param_type=ParameterType.INPUT,
)
- context_param = adal.Parameter(
- name="context",
- data=context,
- role_desc="The context to be used for the query",
- requires_opt=True,
- param_type=ParameterType.INPUT,
- )
- context_param.add_successor_map_fn(
- successor=self.query_generators[0],
- map_fn=lambda x: self.context_to_str(x.data),
- )
+ # context_param = adal.Parameter(
+ # name="context",
+ # data=context,
+ # role_desc="The context to be used for the query",
+ # requires_opt=True,
+ # param_type=ParameterType.INPUT,
+ # )
+ # context_param.add_successor_map_fn(
+ # successor=self.query_generator,
+ # map_fn=lambda x: self.context_to_str(x.data),
+ # )
for i in range(self.max_hops):
- gen_out = self.query_generators[i].forward(
+ gen_out = self.query_generator.forward(
prompt_kwargs={
- "context": context_param,
+ "context": context,
"question": question_param,
},
id=id,
)
-
+ # extract the query from the generator output
success_map_fn = lambda x: ( # noqa E731
x.full_response.data.query
if x.full_response
and x.full_response.data
and x.full_response.data.query
- else None
+ else (
+ x.full_response.raw_response
+ if x.full_response and x.full_response.raw_response
+ else None
+ )
)
print(f"query {i}: {success_map_fn(gen_out)}")
@@ -208,27 +343,41 @@ def forward(self, *, question: str, id: str = None) -> adal.Parameter:
successor=self.retriever, map_fn=success_map_fn
)
- retrieve_out = self.retriever.forward(input=gen_out)
+ # retrieve the passages
+ retrieve_out: adal.Parameter = self.retriever.forward(input=gen_out, id=id)
def retrieve_out_map_fn(x: adal.Parameter):
return x.data[0].documents if x.data and x.data[0].documents else []
- print(f"retrieve_out: {retrieve_out}")
-
+ # add the map function to the retrieve_out
retrieve_out.add_successor_map_fn(
successor=self.deduplicater, map_fn=retrieve_out_map_fn
)
- context_param = self.deduplicater.forward(
- exisiting_list=context_param, new_list=retrieve_out
+ # combine the context + deduplicated passages
+ context = self.deduplicater.forward(
+ exisiting_list=context, new_list=retrieve_out, id=id
)
- context_param.param_type = ParameterType.RETRIEVER_OUTPUT
+ context.param_type = ParameterType.RETRIEVER_OUTPUT
+ # used as the final output
- return context_param
+ # convert the context to the retriever output
+ def context_to_retrover_output(x):
+ return [
+ adal.RetrieverOutput(
+ documents=x.data,
+ query=[input] + [success_map_fn(gen_out)],
+ doc_indices=[],
+ )
+ ]
+
+ context.data = context_to_retrover_output(context)
+
+ return context
-class MultiHopRetriever2(adal.Retriever):
+class MultiHopRetriever(adal.Retriever):
def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
super().__init__()
@@ -395,13 +544,13 @@ def context_to_retrover_output(x):
context.data = context_to_retrover_output(context)
- printc(f"MultiHopRetriever2 grad fn: {context.grad_fn}", "yellow")
+ printc(f"MultiHopRetriever grad fn: {context.grad_fn}", "yellow")
return context
def backward(self, *args, **kwargs):
- printc(f"MultiHopRetriever2 backward: {args}", "yellow")
+ printc(f"MultiHopRetriever backward: {args}", "yellow")
super().backward(*args, **kwargs)
return
@@ -418,7 +567,24 @@ def __init__(
model_client=model_client,
model_kwargs=model_kwargs,
)
- self.retriever = MultiHopRetriever2(
+ self.retriever = MultiHopRetriever(
+ model_client=model_client,
+ model_kwargs=model_kwargs,
+ passages_per_hop=passages_per_hop,
+ max_hops=max_hops,
+ )
+
+
+class MultiHopRAGCycle(VanillaRAG):
+ def __init__(
+ self, passages_per_hop=3, max_hops=2, model_client=None, model_kwargs=None
+ ):
+ super().__init__(
+ passages_per_hop=passages_per_hop,
+ model_client=model_client,
+ model_kwargs=model_kwargs,
+ )
+ self.retriever = MultiHopRetrieverCycle(
model_client=model_client,
model_kwargs=model_kwargs,
passages_per_hop=passages_per_hop,
@@ -451,13 +617,40 @@ def test_multi_hop_retriever():
output.draw_graph()
+def test_multi_hop_retriever_cycle():
+    """Smoke-run MultiHopRetrieverCycle (call and forward) and draw its graphs."""
+    from use_cases.config import gpt_3_model
+
+    retriever = MultiHopRetrieverCycle(
+        passages_per_hop=3,
+        max_hops=2,
+        **gpt_3_model,
+    )
+    sample_id = "1"
+    question = "How many storeys are in the castle that David Gregory inherited?"
+
+    # call(): unwrapped data of the forward result
+    eval_output = retriever.call(input=question, id=sample_id)
+    print(eval_output)
+
+    # switch to training mode, then forward(): traced Parameter output
+    retriever.train()
+    train_output = retriever.forward(input=question, id=sample_id)
+    print(train_output)
+
+    # render the full trace graph and the two simplified views
+    train_output.draw_graph()
+    train_output.draw_output_subgraph()
+    train_output.draw_component_subgraph()
+
+
def test_multi_hop_retriever2():
from use_cases.config import (
gpt_3_model,
)
- multi_hop_retriever = MultiHopRetriever2(
+ multi_hop_retriever = MultiHopRetriever(
**gpt_3_model,
passages_per_hop=3,
max_hops=2,
@@ -531,4 +724,6 @@ def test_multi_hop_rag():
# get_logger(level="DEBUG")
# test_multi_hop_retriever()
# test_multi_hop_retriever2()
- test_multi_hop_rag()
+
+ test_multi_hop_retriever_cycle()
+ # test_multi_hop_rag()
diff --git a/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py b/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py
index 3eae0598..fd4a1dcd 100644
--- a/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py
+++ b/benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py
@@ -161,7 +161,7 @@ def __init__(self, passages_per_hop=3, model_client=None, model_kwargs=None):
),
"few_shot_demos": adal.Parameter(
data=None,
- requires_opt=True,
+ requires_opt=None,
role_desc="To provide few shot demos to the language model",
param_type=adal.ParameterType.DEMOS,
),
diff --git a/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py
new file mode 100644
index 00000000..376ef664
--- /dev/null
+++ b/benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag_cycle.py
@@ -0,0 +1,163 @@
+from typing import Any, Callable, Dict, Tuple
+
+import adalflow as adal
+from adalflow.eval.answer_match_acc import AnswerMatchAcc
+from adalflow.datasets.types import HotPotQAData
+
+from benchmarks.hotpot_qa._adal_train import load_datasets
+from benchmarks.hotpot_qa.adal_exp.build_multi_hop_rag import MultiHopRAGCycle
+from use_cases.config import gpt_3_model, gpt_4o_model
+
+
+# TODO: look more into the loss function
+# TODO: test LLM judge too.
+class MultiHopRAGAdal(adal.AdalComponent):
+ def __init__(
+ self,
+ model_client: adal.ModelClient,
+ model_kwargs: Dict,
+ backward_engine_model_config: Dict | None = None,
+ teacher_model_config: Dict | None = None,
+ text_optimizer_model_config: Dict | None = None,
+ ):
+ task = MultiHopRAGCycle(
+ model_client=model_client,
+ model_kwargs=model_kwargs,
+ passages_per_hop=3,
+ max_hops=2,
+ )
+ eval_fn = AnswerMatchAcc(type="fuzzy_match").compute_single_item
+ loss_fn = adal.EvalFnToTextLoss(
+ eval_fn=eval_fn, eval_fn_desc="fuzzy_match: 1 if str(y) in str(y_gt) else 0"
+ )
+ super().__init__(
+ task=task,
+ eval_fn=eval_fn,
+ loss_fn=loss_fn,
+ backward_engine_model_config=backward_engine_model_config,
+ teacher_model_config=teacher_model_config,
+ text_optimizer_model_config=text_optimizer_model_config,
+ )
+
+ # tell the trainer how to call the task
+ def prepare_task(self, sample: HotPotQAData) -> Tuple[Callable[..., Any], Dict]:
+ if self.task.training:
+ return self.task.forward, {"question": sample.question, "id": sample.id}
+ else:
+ return self.task.call, {"question": sample.question, "id": sample.id}
+
+ # TODO: use two map fn to make the code even simpler
+
+ # eval mode: get the generator output, directly engage with the eval_fn
+ def prepare_eval(self, sample: HotPotQAData, y_pred: adal.GeneratorOutput) -> float:
+ y_label = ""
+ if y_pred and y_pred.data and y_pred.data.answer:
+ y_label = y_pred.data.answer
+ return self.eval_fn, {"y": y_label, "y_gt": sample.answer}
+
+ # train mode: get the loss and get the data from the full_response
+ def prepare_loss(self, sample: HotPotQAData, pred: adal.Parameter):
+ # prepare gt parameter
+ y_gt = adal.Parameter(
+ name="y_gt",
+ data=sample.answer,
+ eval_input=sample.answer,
+ requires_opt=False,
+ )
+
+ # pred's full_response is the output of the task pipeline which is GeneratorOutput
+ pred.eval_input = (
+ pred.full_response.data.answer
+ if pred.full_response
+ and pred.full_response.data
+ and pred.full_response.data.answer
+ else ""
+ )
+ return self.loss_fn, {"kwargs": {"y": pred, "y_gt": y_gt}}
+
+
+# Note: diagnose is quite helpful; it lets you quickly check whether the eval function is the right metric.
+# I checked the eval, which does fuzzy matching, and found that some "yes" and "Yes" were not matched; after converting both strings to lowercase,
+# the performance went up from 0.15 to 0.4.
+def train_diagnose(
+ model_client: adal.ModelClient,
+ model_kwargs: Dict,
+) -> Dict:
+
+ trainset, valset, testset = load_datasets()
+
+ adal_component = MultiHopRAGAdal(
+ model_client,
+ model_kwargs,
+ backward_engine_model_config=gpt_4o_model,
+ teacher_model_config=gpt_3_model,
+ text_optimizer_model_config=gpt_3_model,
+ )
+ trainer = adal.Trainer(adaltask=adal_component)
+ trainer.diagnose(dataset=trainset, split="train")
+ # trainer.diagnose(dataset=valset, split="val")
+ # trainer.diagnose(dataset=testset, split="test")
+
+
+def train(
+ train_batch_size=4, # a larger batch size is not much more effective, probably because of the LLM's lost-in-the-middle problem
+ raw_shots: int = 0,
+ bootstrap_shots: int = 4,
+ max_steps=1,
+ num_workers=4,
+ strategy="constrained",
+ optimization_order="sequential",
+ debug=False,
+ resume_from_ckpt=None,
+ exclude_input_fields_from_bootstrap_demos=True,
+):
+ adal_component = MultiHopRAGAdal(
+ **gpt_3_model,
+ teacher_model_config=gpt_3_model,
+ text_optimizer_model_config=gpt_4o_model, # gpt-3.5 is not strong enough to be a good optimizer; it struggles with long context
+ backward_engine_model_config=gpt_4o_model,
+ )
+ print(adal_component)
+ trainer = adal.Trainer(
+ train_batch_size=train_batch_size,
+ adaltask=adal_component,
+ strategy=strategy,
+ max_steps=max_steps,
+ num_workers=num_workers,
+ raw_shots=raw_shots,
+ bootstrap_shots=bootstrap_shots,
+ debug=debug,
+ weighted_sampling=True,
+ optimization_order=optimization_order,
+ exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos,
+ sequential_order=["text", "demo"],
+ )
+ print(trainer)
+
+ train_dataset, val_dataset, test_dataset = load_datasets()
+ trainer.fit(
+ train_dataset=train_dataset,
+ val_dataset=val_dataset,
+ test_dataset=test_dataset,
+ resume_from_ckpt=resume_from_ckpt,
+ )
+
+
+if __name__ == "__main__":
+ from use_cases.config import gpt_3_model
+
+ log = adal.get_logger(level="DEBUG", enable_console=False)
+
+ adal.setup_env()
+
+ # task = MultiHopRAGAdal(**gpt_3_model)
+ # print(task)
+
+ # train_diagnose(**gpt_3_model)
+
+ # train: 0.15 before the evaluator lowercased both strings, 0.4 after the conversion
+ train(
+ debug=False,
+ max_steps=12,
+ # resume_from_ckpt="/Users/liyin/.adalflow/ckpt/ValinaRAGAdal/random_max_steps_12_7c091_run_1.json",
+ )