diff --git a/libs/experimental/langchain_experimental/graph_transformers/llm.py b/libs/experimental/langchain_experimental/graph_transformers/llm.py index d4624c1..262bcbf 100644 --- a/libs/experimental/langchain_experimental/graph_transformers/llm.py +++ b/libs/experimental/langchain_experimental/graph_transformers/llm.py @@ -150,25 +150,31 @@ def _get_additional_info(input_type: str) -> str: def optional_enum_field( - enum_values: Optional[List[str]] = None, + enum_values: Optional[Union[List[str], List[Tuple[str, str, str]]]] = None, description: str = "", input_type: str = "node", llm_type: Optional[str] = None, + relationship_type: Optional[str] = None, **field_kwargs: Any, ) -> Any: """Utility function to conditionally create a field with an enum constraint.""" + parsed_enum_values = enum_values + # We have to extract enum types from tuples + if relationship_type == "tuple": + parsed_enum_values = list({el[1] for el in enum_values}) # type: ignore + # Only openai supports enum param if enum_values and llm_type == "openai-chat": return Field( ..., - enum=enum_values, # type: ignore[call-arg] - description=f"{description}. Available options are {enum_values}", + enum=parsed_enum_values, # type: ignore[call-arg] + description=f"{description}. Available options are {parsed_enum_values}", **field_kwargs, ) elif enum_values: return Field( ..., - description=f"{description}. Available options are {enum_values}", + description=f"{description}. 
Available options are {parsed_enum_values}", **field_kwargs, ) else: @@ -204,10 +210,18 @@ class UnstructuredRelation(BaseModel): def create_unstructured_prompt( - node_labels: Optional[List[str]] = None, rel_types: Optional[List[str]] = None + node_labels: Optional[List[str]] = None, + rel_types: Optional[Union[List[str], List[Tuple[str, str, str]]]] = None, + relationship_type: Optional[str] = None, ) -> ChatPromptTemplate: node_labels_str = str(node_labels) if node_labels else "" - rel_types_str = str(rel_types) if rel_types else "" + if rel_types: + if relationship_type == "tuple": + rel_types_str = str(list({item[1] for item in rel_types})) + else: + rel_types_str = str(rel_types) + else: + rel_types_str = "" base_string_parts = [ "You are a top-tier algorithm designed for extracting information in " "structured formats to build a knowledge graph. Your task is to identify " @@ -230,6 +244,13 @@ def create_unstructured_prompt( f"of the tail entity from {node_labels_str}." if node_labels else "", + "Your task is to extract relationships from text strictly adhering " + "to the provided schema. The relationships can only appear " + "between specific node types are presented in the schema format " + "like: (Entity1Type, RELATIONSHIP_TYPE, Entity2Type) /n" + f"Provided schema is {rel_types}" + if relationship_type == "tuple" + else "", "Attempt to extract as many entities and relations as you can. Maintain " "Entity Consistency: When extracting entities, it's vital to ensure " 'consistency. If an entity, such as "John Doe", is mentioned multiple ' @@ -260,6 +281,13 @@ def create_unstructured_prompt( "{rel_types}" if rel_types else "", + "Your task is to extract relationships from text strictly adhering " + "to the provided schema. 
The relationships can only appear " + "between specific node types are presented in the schema format " + "like: (Entity1Type, RELATIONSHIP_TYPE, Entity2Type) /n" + f"Provided schema is {rel_types}" + if relationship_type == "tuple" + else "", "Below are a number of examples of text and their extracted " "entities and relationships." "{examples}\n" @@ -289,10 +317,11 @@ def create_unstructured_prompt( def create_simple_model( node_labels: Optional[List[str]] = None, - rel_types: Optional[List[str]] = None, + rel_types: Optional[Union[List[str], List[Tuple[str, str, str]]]] = None, node_properties: Union[bool, List[str]] = False, llm_type: Optional[str] = None, relationship_properties: Union[bool, List[str]] = False, + relationship_type: Optional[str] = None, ) -> Type[_Graph]: """ Create a simple graph model with optional constraints on node @@ -353,7 +382,13 @@ class Property(BaseModel): input_type="property", llm_type=llm_type, ) - value: str = Field(..., description="value") + value: str = Field( + ..., + description=( + "Extracted value. Any date value " + "should be formatted as yyyy-mm-dd." + ), + ) node_fields["properties"] = ( Optional[List[Property]], @@ -401,6 +436,7 @@ class Property(BaseModel): description="The type of the relationship.", input_type="relationship", llm_type=llm_type, + relationship_type=relationship_type, ), ), } @@ -426,13 +462,28 @@ class RelationshipProperty(BaseModel): input_type="property", llm_type=llm_type, ) - value: str = Field(..., description="value") + value: str = Field( + ..., + description=( + "Extracted value. Any date value " + "should be formatted as yyyy-mm-dd." 
def validate_and_get_relationship_type(
    allowed_relationships: Union[List[str], List[Tuple[str, str, str]]],
    allowed_nodes: Optional[List[str]],
) -> Optional[str]:
    """Validate ``allowed_relationships`` and report which form it uses.

    Two forms are accepted:

    * a list of relationship type names, e.g. ``["WORKS_AT", "KNOWS"]``;
    * a list of ``(source_type, relationship_type, target_type)`` 3-tuples of
      strings, where both endpoint types are members of ``allowed_nodes``.

    Args:
        allowed_relationships: The relationship constraint to validate.
        allowed_nodes: Node types that tuple endpoints are checked against
            (exact, case-sensitive membership). May be ``None``.

    Returns:
        ``None`` when ``allowed_relationships`` is empty/falsy, ``"string"``
        for a list of strings, ``"tuple"`` for a list of valid 3-tuples.

    Raises:
        ValueError: If a non-empty ``allowed_relationships`` is not a list, or
            is a list that matches neither accepted form (including tuples
            whose endpoints are not in ``allowed_nodes``).
    """
    # Nothing configured means "no relationship constraint".
    if not allowed_relationships:
        return None
    if not isinstance(allowed_relationships, list):
        raise ValueError("`allowed_relationships` attribute must be a list.")

    # All plain strings: relationship types only, no endpoint checks needed.
    if all(isinstance(item, str) for item in allowed_relationships):
        return "string"

    # A missing node list can never satisfy the endpoint membership check.
    # Treat it as empty so the caller gets the ValueError below rather than a
    # TypeError from evaluating `x in None`.
    known_nodes = allowed_nodes if allowed_nodes is not None else []

    def _is_valid_triple(item: object) -> bool:
        # A valid entry is a 3-tuple of strings whose first and last elements
        # are both known node types.
        return (
            isinstance(item, tuple)
            and len(item) == 3
            and all(isinstance(part, str) for part in item)
            and item[0] in known_nodes
            and item[2] in known_nodes
        )

    if all(_is_valid_triple(item) for item in allowed_relationships):
        return "tuple"

    # Mixed content, malformed tuples, or endpoints outside `allowed_nodes`.
    raise ValueError(
        "`allowed_relationships` must be list of strings or a list of 3-item tuples. "
        "For tuples, the first and last elements must be in the `allowed_nodes` list."
    )
) prompt = prompt or create_unstructured_prompt( - allowed_nodes, allowed_relationships + allowed_nodes, allowed_relationships, self._relationship_type ) self.chain = prompt | llm else: @@ -725,6 +813,7 @@ def __init__( node_properties, llm_type, relationship_properties, + self._relationship_type, ) structured_llm = llm.with_structured_output(schema, include_raw=True) prompt = prompt or default_prompt @@ -792,12 +881,30 @@ def process_response( and rel.target.type.lower() in lower_allowed_nodes ] if self.allowed_relationships: - relationships = [ - rel - for rel in relationships - if rel.type.lower() - in [el.lower() for el in self.allowed_relationships] - ] + # Filter by type and direction + if self._relationship_type == "tuple": + relationships = [ + rel + for rel in relationships + if ( + ( + rel.source.type.lower(), + rel.type.lower(), + rel.target.type.lower(), + ) + in [ # type: ignore + (s_t.lower(), r_t.lower(), t_t.lower()) + for s_t, r_t, t_t in self.allowed_relationships + ] + ) + ] + else: # Filter by type only + relationships = [ + rel + for rel in relationships + if rel.type.lower() + in [el.lower() for el in self.allowed_relationships] # type: ignore + ] return GraphDocument(nodes=nodes, relationships=relationships, source=document) @@ -875,12 +982,30 @@ async def aprocess_response( and rel.target.type.lower() in lower_allowed_nodes ] if self.allowed_relationships: - relationships = [ - rel - for rel in relationships - if rel.type.lower() - in [el.lower() for el in self.allowed_relationships] - ] + # Filter by type and direction + if self._relationship_type == "tuple": + relationships = [ + rel + for rel in relationships + if ( + ( + rel.source.type.lower(), + rel.type.lower(), + rel.target.type.lower(), + ) + in [ # type: ignore + (s_t.lower(), r_t.lower(), t_t.lower()) + for s_t, r_t, t_t in self.allowed_relationships + ] + ) + ] + else: # Filter by type only + relationships = [ + rel + for rel in relationships + if rel.type.lower() + in 
[el.lower() for el in self.allowed_relationships] # type: ignore + ] return GraphDocument(nodes=nodes, relationships=relationships, source=document)