diff --git a/CHANGELOG.md b/CHANGELOG.md index cd7144a1..3bc00e9a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## Next +### Added +- Made `relations` and `potential_schema` optional in `SchemaBuilder`. + ## 1.1.0 ### Added diff --git a/README.md b/README.md index 8b43b3f7..9ee88ba8 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ kg_builder = SimpleKGPipeline( # Run the pipeline on a piece of text text = ( - "The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House" + "The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House " "Atreides, an aristocratic family that rules the planet Caladan." ) asyncio.run(kg_builder.run_async(text=text)) @@ -164,7 +164,7 @@ embedder = OpenAIEmbeddings(model="text-embedding-3-large") # Generate an embedding for some text text = ( - "The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House" + "The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House " "Atreides, an aristocratic family that rules the planet Caladan." ) vector = embedder.embed_query(text) diff --git a/src/neo4j_graphrag/experimental/components/schema.py b/src/neo4j_graphrag/experimental/components/schema.py index c82a7d5f..64e908ed 100644 --- a/src/neo4j_graphrag/experimental/components/schema.py +++ b/src/neo4j_graphrag/experimental/components/schema.py @@ -14,7 +14,7 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Dict, List, Literal, Tuple +from typing import Any, Dict, List, Literal, Optional, Tuple from pydantic import BaseModel, ValidationError, model_validator, validate_call @@ -72,8 +72,8 @@ class SchemaConfig(DataModel): """ entities: Dict[str, Dict[str, Any]] - relations: Dict[str, Dict[str, Any]] - potential_schema: List[Tuple[str, str, str]] + relations: Optional[Dict[str, Dict[str, Any]]] + potential_schema: Optional[List[Tuple[str, str, str]]] @model_validator(mode="before") def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]: @@ -81,19 +81,24 @@ def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]: relations = data.get("relations", {}).keys() potential_schema = data.get("potential_schema", []) - for entity1, relation, entity2 in potential_schema: - if entity1 not in entities: + if potential_schema: + if not relations: raise SchemaValidationError( - f"Entity '{entity1}' is not defined in the provided entities." - ) - if relation not in relations: - raise SchemaValidationError( - f"Relation '{relation}' is not defined in the provided relations." - ) - if entity2 not in entities: - raise SchemaValidationError( - f"Entity '{entity2}' is not defined in the provided entities." + "Relations must also be provided when using a potential schema." ) + for entity1, relation, entity2 in potential_schema: + if entity1 not in entities: + raise SchemaValidationError( + f"Entity '{entity1}' is not defined in the provided entities." + ) + if relation not in relations: + raise SchemaValidationError( + f"Relation '{relation}' is not defined in the provided relations." + ) + if entity2 not in entities: + raise SchemaValidationError( + f"Entity '{entity2}' is not defined in the provided entities." + ) return data @@ -160,8 +165,8 @@ class SchemaBuilder(Component): @staticmethod def create_schema_model( entities: List[SchemaEntity], - relations: List[SchemaRelation], - potential_schema: List[Tuple[str, str, str]], + relations: Optional[List[SchemaRelation]] = None, + potential_schema: Optional[List[Tuple[str, str, str]]] = None, ) -> SchemaConfig: """ Creates a SchemaConfig object from Lists of Entity and Relation objects @@ -176,9 +181,11 @@ def create_schema_model( SchemaConfig: A configured schema object. """ entity_dict = {entity.label: entity.model_dump() for entity in entities} - relation_dict = { - relation.label: relation.model_dump() for relation in relations - } + relation_dict = ( + {relation.label: relation.model_dump() for relation in relations} + if relations + else {} + ) try: return SchemaConfig( @@ -193,8 +200,8 @@ def create_schema_model( async def run( self, entities: List[SchemaEntity], - relations: List[SchemaRelation], - potential_schema: List[Tuple[str, str, str]], + relations: Optional[List[SchemaRelation]] = None, + potential_schema: Optional[List[Tuple[str, str, str]]] = None, ) -> SchemaConfig: """ Asynchronously constructs and returns a SchemaConfig object. diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 94032b75..3c4d380d 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -167,7 +167,7 @@ class ERExtractionTemplate(PromptTemplate): {{"nodes": [ {{"id": "0", "label": "Person", "properties": {{"name": "John"}} }}], "relationships": [{{"type": "KNOWS", "start_node_id": "0", "end_node_id": "1", "properties": {{"since": "2024-08-01"}} }}] }} -Use only fhe following nodes and relationships (if provided): +Use only the following nodes and relationships (if provided): {schema} Assign a unique ID (string) to each node, and reuse it to define relationships. diff --git a/tests/unit/experimental/components/test_schema.py b/tests/unit/experimental/components/test_schema.py index 5d0e3450..6ff257a1 100644 --- a/tests/unit/experimental/components/test_schema.py +++ b/tests/unit/experimental/components/test_schema.py @@ -117,6 +117,7 @@ def test_create_schema_model_valid_data( ) assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." + assert schema_instance.relations assert ( schema_instance.relations["EMPLOYED_BY"]["description"] == "Indicates employment relationship." @@ -134,6 +135,7 @@ def test_create_schema_model_valid_data( {"description": "", "name": "end_time", "type": "LOCAL_DATETIME"}, ] + assert schema_instance.potential_schema assert schema_instance.potential_schema == potential_schema @@ -159,6 +161,7 @@ def test_create_schema_model_missing_description( assert schema_instance.entities["ORGANIZATION"]["description"] == "" assert schema_instance.entities["AGE"]["description"] == "" + assert schema_instance.relations assert schema_instance.relations["ORGANIZED_BY"]["description"] == "" assert schema_instance.relations["ATTENDED_BY"]["description"] == "" @@ -242,6 +245,7 @@ async def test_run_method( ) assert schema.entities["AGE"]["description"] == "Age of a person in years." + assert schema.relations assert ( schema.relations["EMPLOYED_BY"]["description"] == "Indicates employment relationship." @@ -255,6 +259,7 @@ async def test_run_method( == "Indicates attendance at an event." ) + assert schema.potential_schema assert schema.potential_schema == potential_schema @@ -327,6 +332,7 @@ def test_create_schema_model_missing_properties( schema_instance.entities["AGE"]["properties"] == [] ), "Expected empty properties for AGE" + assert schema_instance.relations assert ( schema_instance.relations["EMPLOYED_BY"]["properties"] == [] ), "Expected empty properties for EMPLOYED_BY" @@ -336,3 +342,80 @@ def test_create_schema_model_missing_properties( assert ( schema_instance.relations["ATTENDED_BY"]["properties"] == [] ), "Expected empty properties for ATTENDED_BY" + + +def test_create_schema_model_no_potential_schema( + schema_builder: SchemaBuilder, + valid_entities: list[SchemaEntity], + valid_relations: list[SchemaRelation], +) -> None: + schema_instance = schema_builder.create_schema_model( + valid_entities, valid_relations + ) + + assert ( + schema_instance.entities["PERSON"]["description"] + == "An individual human being." + ) + assert schema_instance.entities["PERSON"]["properties"] == [ + {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, + {"description": "", "name": "name", "type": "STRING"}, + ] + assert ( + schema_instance.entities["ORGANIZATION"]["description"] + == "A structured group of people with a common purpose." + ) + assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." + + assert schema_instance.relations + assert ( + schema_instance.relations["EMPLOYED_BY"]["description"] + == "Indicates employment relationship." + ) + assert ( + schema_instance.relations["ORGANIZED_BY"]["description"] + == "Indicates organization responsible for an event." + ) + assert ( + schema_instance.relations["ATTENDED_BY"]["description"] + == "Indicates attendance at an event." + ) + assert schema_instance.relations["EMPLOYED_BY"]["properties"] == [ + {"description": "", "name": "start_time", "type": "LOCAL_DATETIME"}, + {"description": "", "name": "end_time", "type": "LOCAL_DATETIME"}, + ] + + +def test_create_schema_model_no_relations_or_potential_schema( + schema_builder: SchemaBuilder, + valid_entities: list[SchemaEntity], +) -> None: + schema_instance = schema_builder.create_schema_model(valid_entities) + + assert ( + schema_instance.entities["PERSON"]["description"] + == "An individual human being." + ) + assert schema_instance.entities["PERSON"]["properties"] == [ + {"description": "", "name": "birth date", "type": "ZONED_DATETIME"}, + {"description": "", "name": "name", "type": "STRING"}, + ] + assert ( + schema_instance.entities["ORGANIZATION"]["description"] + == "A structured group of people with a common purpose." + ) + assert schema_instance.entities["AGE"]["description"] == "Age of a person in years." + + +def test_create_schema_model_missing_relations( + schema_builder: SchemaBuilder, + valid_entities: list[SchemaEntity], + potential_schema: list[tuple[str, str, str]], +) -> None: + with pytest.raises(SchemaValidationError) as exc_info: + schema_builder.create_schema_model( + entities=valid_entities, potential_schema=potential_schema + ) + assert "Relations must also be provided when using a potential schema." in str( + exc_info.value + ), "Should fail due to missing relations"