Skip to content

Commit

Permalink
Add more granular relationship definition to LLM Graph Transformer (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo authored Nov 2, 2024
1 parent 8f36c7a commit 93f2dc5
Showing 1 changed file with 149 additions and 24 deletions.
173 changes: 149 additions & 24 deletions libs/experimental/langchain_experimental/graph_transformers/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,25 +150,31 @@ def _get_additional_info(input_type: str) -> str:


def optional_enum_field(
enum_values: Optional[List[str]] = None,
enum_values: Optional[Union[List[str], List[Tuple[str, str, str]]]] = None,
description: str = "",
input_type: str = "node",
llm_type: Optional[str] = None,
relationship_type: Optional[str] = None,
**field_kwargs: Any,
) -> Any:
"""Utility function to conditionally create a field with an enum constraint."""
parsed_enum_values = enum_values
# We have to extract enum types from tuples
if relationship_type == "tuple":
parsed_enum_values = list({el[1] for el in enum_values}) # type: ignore

# Only openai supports enum param
if enum_values and llm_type == "openai-chat":
return Field(
...,
enum=enum_values, # type: ignore[call-arg]
description=f"{description}. Available options are {enum_values}",
enum=parsed_enum_values, # type: ignore[call-arg]
description=f"{description}. Available options are {parsed_enum_values}",
**field_kwargs,
)
elif enum_values:
return Field(
...,
description=f"{description}. Available options are {enum_values}",
description=f"{description}. Available options are {parsed_enum_values}",
**field_kwargs,
)
else:
Expand Down Expand Up @@ -204,10 +210,18 @@ class UnstructuredRelation(BaseModel):


def create_unstructured_prompt(
node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None
node_labels: Optional[List[str]] = None,
rel_types: Optional[Union[List[str], List[Tuple[str, str, str]]]] = None,
relationship_type: Optional[str] = None,
) -> ChatPromptTemplate:
node_labels_str = str(node_labels) if node_labels else ""
rel_types_str = str(rel_types) if rel_types else ""
if rel_types:
if relationship_type == "tuple":
rel_types_str = str(list({item[1] for item in rel_types}))
else:
rel_types_str = str(rel_types)
else:
rel_types_str = ""
base_string_parts = [
"You are a top-tier algorithm designed for extracting information in "
"structured formats to build a knowledge graph. Your task is to identify "
Expand All @@ -230,6 +244,13 @@ def create_unstructured_prompt(
f"of the tail entity from {node_labels_str}."
if node_labels
else "",
"Your task is to extract relationships from text strictly adhering "
"to the provided schema. The relationships can only appear "
"between specific node types are presented in the schema format "
"like: (Entity1Type, RELATIONSHIP_TYPE, Entity2Type) /n"
f"Provided schema is {rel_types}"
if relationship_type == "tuple"
else "",
"Attempt to extract as many entities and relations as you can. Maintain "
"Entity Consistency: When extracting entities, it's vital to ensure "
'consistency. If an entity, such as "John Doe", is mentioned multiple '
Expand Down Expand Up @@ -260,6 +281,13 @@ def create_unstructured_prompt(
"{rel_types}"
if rel_types
else "",
"Your task is to extract relationships from text strictly adhering "
"to the provided schema. The relationships can only appear "
"between specific node types are presented in the schema format "
"like: (Entity1Type, RELATIONSHIP_TYPE, Entity2Type) /n"
f"Provided schema is {rel_types}"
if relationship_type == "tuple"
else "",
"Below are a number of examples of text and their extracted "
"entities and relationships."
"{examples}\n"
Expand Down Expand Up @@ -289,10 +317,11 @@ def create_unstructured_prompt(

def create_simple_model(
node_labels: Optional[List[str]] = None,
rel_types: Optional[List[str]] = None,
rel_types: Optional[Union[List[str], List[Tuple[str, str, str]]]] = None,
node_properties: Union[bool, List[str]] = False,
llm_type: Optional[str] = None,
relationship_properties: Union[bool, List[str]] = False,
relationship_type: Optional[str] = None,
) -> Type[_Graph]:
"""
Create a simple graph model with optional constraints on node
Expand Down Expand Up @@ -353,7 +382,13 @@ class Property(BaseModel):
input_type="property",
llm_type=llm_type,
)
value: str = Field(..., description="value")
value: str = Field(
...,
description=(
"Extracted value. Any date value "
"should be formatted as yyyy-mm-dd."
),
)

node_fields["properties"] = (
Optional[List[Property]],
Expand Down Expand Up @@ -401,6 +436,7 @@ class Property(BaseModel):
description="The type of the relationship.",
input_type="relationship",
llm_type=llm_type,
relationship_type=relationship_type,
),
),
}
Expand All @@ -426,13 +462,28 @@ class RelationshipProperty(BaseModel):
input_type="property",
llm_type=llm_type,
)
value: str = Field(..., description="value")
value: str = Field(
...,
description=(
"Extracted value. Any date value "
"should be formatted as yyyy-mm-dd."
),
)

relationship_fields["properties"] = (
Optional[List[RelationshipProperty]],
Field(None, description="List of relationship properties"),
)
SimpleRelationship = create_model("SimpleRelationship", **relationship_fields) # type: ignore
# Add a docstring to the dynamically created model
if relationship_type == "tuple":
SimpleRelationship.__doc__ = (
"Your task is to extract relationships from text strictly adhering "
"to the provided schema. The relationships can only appear "
"between specific node types are presented in the schema format "
"like: (Entity1Type, RELATIONSHIP_TYPE, Entity2Type) /n"
f"Provided schema is {rel_types}"
)

class DynamicGraph(_Graph):
"""Represents a graph document consisting of nodes and relationships."""
Expand Down Expand Up @@ -600,7 +651,6 @@ def _convert_to_graph_document(
argument_json["relationships"] = json.loads(
argument_json["relationships"]
)

nodes, relationships = _parse_and_clean_json(argument_json)
except Exception: # If we can't parse JSON
return ([], [])
Expand All @@ -625,6 +675,39 @@ def _convert_to_graph_document(
return _format_nodes(nodes), _format_relationships(relationships)


def validate_and_get_relationship_type(
allowed_relationships: Union[List[str], List[Tuple[str, str, str]]],
allowed_nodes: Optional[List[str]],
) -> Optional[str]:
if allowed_relationships and not isinstance(allowed_relationships, list):
raise ValueError("`allowed_relationships` attribute must be a list.")
# If it's an empty list
if not allowed_relationships:
return None
# Validate list of strings
if all(isinstance(item, str) for item in allowed_relationships):
# Valid: all items are strings, no further checks needed.
return "string"

# Validate list of 3-tuples and check if first/last elements are in allowed_nodes
if all(
isinstance(item, tuple)
and len(item) == 3
and all(isinstance(subitem, str) for subitem in item)
and item[0] in allowed_nodes # type: ignore
and item[2] in allowed_nodes # type: ignore
for item in allowed_relationships
):
# all items are 3-tuples, and the first/last elements are in allowed_nodes.
return "tuple"

# If the input doesn't match any of the valid cases, raise a ValueError
raise ValueError(
"`allowed_relationships` must be list of strings or a list of 3-item tuples. "
"For tuples, the first and last elements must be in the `allowed_nodes` list."
)


class LLMGraphTransformer:
"""Transform documents into graph-based documents using a LLM.
Expand Down Expand Up @@ -676,13 +759,18 @@ def __init__(
self,
llm: BaseLanguageModel,
allowed_nodes: List[str] = [],
allowed_relationships: List[str] = [],
allowed_relationships: Union[List[str], List[Tuple[str, str, str]]] = [],
prompt: Optional[ChatPromptTemplate] = None,
strict_mode: bool = True,
node_properties: Union[bool, List[str]] = False,
relationship_properties: Union[bool, List[str]] = False,
ignore_tool_usage: bool = False,
) -> None:
# Validate and check allowed relationships input
self._relationship_type = validate_and_get_relationship_type(
allowed_relationships, allowed_nodes
)

self.allowed_nodes = allowed_nodes
self.allowed_relationships = allowed_relationships
self.strict_mode = strict_mode
Expand Down Expand Up @@ -710,7 +798,7 @@ def __init__(
"Please install it with `pip install json-repair`."
)
prompt = prompt or create_unstructured_prompt(
allowed_nodes, allowed_relationships
allowed_nodes, allowed_relationships, self._relationship_type
)
self.chain = prompt | llm
else:
Expand All @@ -725,6 +813,7 @@ def __init__(
node_properties,
llm_type,
relationship_properties,
self._relationship_type,
)
structured_llm = llm.with_structured_output(schema, include_raw=True)
prompt = prompt or default_prompt
Expand Down Expand Up @@ -792,12 +881,30 @@ def process_response(
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
# Filter by type and direction
if self._relationship_type == "tuple":
relationships = [
rel
for rel in relationships
if (
(
rel.source.type.lower(),
rel.type.lower(),
rel.target.type.lower(),
)
in [ # type: ignore
(s_t.lower(), r_t.lower(), t_t.lower())
for s_t, r_t, t_t in self.allowed_relationships
]
)
]
else: # Filter by type only
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships] # type: ignore
]

return GraphDocument(nodes=nodes, relationships=relationships, source=document)

Expand Down Expand Up @@ -875,12 +982,30 @@ async def aprocess_response(
and rel.target.type.lower() in lower_allowed_nodes
]
if self.allowed_relationships:
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships]
]
# Filter by type and direction
if self._relationship_type == "tuple":
relationships = [
rel
for rel in relationships
if (
(
rel.source.type.lower(),
rel.type.lower(),
rel.target.type.lower(),
)
in [ # type: ignore
(s_t.lower(), r_t.lower(), t_t.lower())
for s_t, r_t, t_t in self.allowed_relationships
]
)
]
else: # Filter by type only
relationships = [
rel
for rel in relationships
if rel.type.lower()
in [el.lower() for el in self.allowed_relationships] # type: ignore
]

return GraphDocument(nodes=nodes, relationships=relationships, source=document)

Expand Down

0 comments on commit 93f2dc5

Please sign in to comment.