Skip to content

Commit

Permalink
Makes relations and potential_schema optional in SchemaBuilder (neo4j…
Browse files Browse the repository at this point in the history
…#184)

* Fixed small typo in README

* Fixed small typo in KG creation prompt

* Made relations and potential schema optional in schema component

* Updated unit tests

* Updated changelog
  • Loading branch information
alexthomas93 authored Oct 16, 2024
1 parent 20b374d commit 06d9889
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## Next

### Added
- Made `relations` and `potential_schema` optional in `SchemaBuilder`.

## 1.1.0

### Added
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ kg_builder = SimpleKGPipeline(

# Run the pipeline on a piece of text
text = (
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House"
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House "
"Atreides, an aristocratic family that rules the planet Caladan."
)
asyncio.run(kg_builder.run_async(text=text))
Expand Down Expand Up @@ -164,7 +164,7 @@ embedder = OpenAIEmbeddings(model="text-embedding-3-large")

# Generate an embedding for some text
text = (
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House"
"The son of Duke Leto Atreides and the Lady Jessica, Paul is the heir of House "
"Atreides, an aristocratic family that rules the planet Caladan."
)
vector = embedder.embed_query(text)
Expand Down
49 changes: 28 additions & 21 deletions src/neo4j_graphrag/experimental/components/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Dict, List, Literal, Tuple
from typing import Any, Dict, List, Literal, Optional, Tuple

from pydantic import BaseModel, ValidationError, model_validator, validate_call

Expand Down Expand Up @@ -72,28 +72,33 @@ class SchemaConfig(DataModel):
"""

entities: Dict[str, Dict[str, Any]]
relations: Dict[str, Dict[str, Any]]
potential_schema: List[Tuple[str, str, str]]
relations: Optional[Dict[str, Dict[str, Any]]]
potential_schema: Optional[List[Tuple[str, str, str]]]

@model_validator(mode="before")
def check_schema(cls, data: Dict[str, Any]) -> Dict[str, Any]:
entities = data.get("entities", {}).keys()
relations = data.get("relations", {}).keys()
potential_schema = data.get("potential_schema", [])

for entity1, relation, entity2 in potential_schema:
if entity1 not in entities:
if potential_schema:
if not relations:
raise SchemaValidationError(
f"Entity '{entity1}' is not defined in the provided entities."
)
if relation not in relations:
raise SchemaValidationError(
f"Relation '{relation}' is not defined in the provided relations."
)
if entity2 not in entities:
raise SchemaValidationError(
f"Entity '{entity2}' is not defined in the provided entities."
"Relations must also be provided when using a potential schema."
)
for entity1, relation, entity2 in potential_schema:
if entity1 not in entities:
raise SchemaValidationError(
f"Entity '{entity1}' is not defined in the provided entities."
)
if relation not in relations:
raise SchemaValidationError(
f"Relation '{relation}' is not defined in the provided relations."
)
if entity2 not in entities:
raise SchemaValidationError(
f"Entity '{entity2}' is not defined in the provided entities."
)

return data

Expand Down Expand Up @@ -160,8 +165,8 @@ class SchemaBuilder(Component):
@staticmethod
def create_schema_model(
entities: List[SchemaEntity],
relations: List[SchemaRelation],
potential_schema: List[Tuple[str, str, str]],
relations: Optional[List[SchemaRelation]] = None,
potential_schema: Optional[List[Tuple[str, str, str]]] = None,
) -> SchemaConfig:
"""
Creates a SchemaConfig object from Lists of Entity and Relation objects
Expand All @@ -176,9 +181,11 @@ def create_schema_model(
SchemaConfig: A configured schema object.
"""
entity_dict = {entity.label: entity.model_dump() for entity in entities}
relation_dict = {
relation.label: relation.model_dump() for relation in relations
}
relation_dict = (
{relation.label: relation.model_dump() for relation in relations}
if relations
else {}
)

try:
return SchemaConfig(
Expand All @@ -193,8 +200,8 @@ def create_schema_model(
async def run(
self,
entities: List[SchemaEntity],
relations: List[SchemaRelation],
potential_schema: List[Tuple[str, str, str]],
relations: Optional[List[SchemaRelation]] = None,
potential_schema: Optional[List[Tuple[str, str, str]]] = None,
) -> SchemaConfig:
"""
Asynchronously constructs and returns a SchemaConfig object.
Expand Down
2 changes: 1 addition & 1 deletion src/neo4j_graphrag/generation/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ class ERExtractionTemplate(PromptTemplate):
{{"nodes": [ {{"id": "0", "label": "Person", "properties": {{"name": "John"}} }}],
"relationships": [{{"type": "KNOWS", "start_node_id": "0", "end_node_id": "1", "properties": {{"since": "2024-08-01"}} }}] }}
Use only fhe following nodes and relationships (if provided):
Use only the following nodes and relationships (if provided):
{schema}
Assign a unique ID (string) to each node, and reuse it to define relationships.
Expand Down
83 changes: 83 additions & 0 deletions tests/unit/experimental/components/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def test_create_schema_model_valid_data(
)
assert schema_instance.entities["AGE"]["description"] == "Age of a person in years."

assert schema_instance.relations
assert (
schema_instance.relations["EMPLOYED_BY"]["description"]
== "Indicates employment relationship."
Expand All @@ -134,6 +135,7 @@ def test_create_schema_model_valid_data(
{"description": "", "name": "end_time", "type": "LOCAL_DATETIME"},
]

assert schema_instance.potential_schema
assert schema_instance.potential_schema == potential_schema


Expand All @@ -159,6 +161,7 @@ def test_create_schema_model_missing_description(

assert schema_instance.entities["ORGANIZATION"]["description"] == ""
assert schema_instance.entities["AGE"]["description"] == ""
assert schema_instance.relations
assert schema_instance.relations["ORGANIZED_BY"]["description"] == ""
assert schema_instance.relations["ATTENDED_BY"]["description"] == ""

Expand Down Expand Up @@ -242,6 +245,7 @@ async def test_run_method(
)
assert schema.entities["AGE"]["description"] == "Age of a person in years."

assert schema.relations
assert (
schema.relations["EMPLOYED_BY"]["description"]
== "Indicates employment relationship."
Expand All @@ -255,6 +259,7 @@ async def test_run_method(
== "Indicates attendance at an event."
)

assert schema.potential_schema
assert schema.potential_schema == potential_schema


Expand Down Expand Up @@ -327,6 +332,7 @@ def test_create_schema_model_missing_properties(
schema_instance.entities["AGE"]["properties"] == []
), "Expected empty properties for AGE"

assert schema_instance.relations
assert (
schema_instance.relations["EMPLOYED_BY"]["properties"] == []
), "Expected empty properties for EMPLOYED_BY"
Expand All @@ -336,3 +342,80 @@ def test_create_schema_model_missing_properties(
assert (
schema_instance.relations["ATTENDED_BY"]["properties"] == []
), "Expected empty properties for ATTENDED_BY"


def test_create_schema_model_no_potential_schema(
schema_builder: SchemaBuilder,
valid_entities: list[SchemaEntity],
valid_relations: list[SchemaRelation],
) -> None:
schema_instance = schema_builder.create_schema_model(
valid_entities, valid_relations
)

assert (
schema_instance.entities["PERSON"]["description"]
== "An individual human being."
)
assert schema_instance.entities["PERSON"]["properties"] == [
{"description": "", "name": "birth date", "type": "ZONED_DATETIME"},
{"description": "", "name": "name", "type": "STRING"},
]
assert (
schema_instance.entities["ORGANIZATION"]["description"]
== "A structured group of people with a common purpose."
)
assert schema_instance.entities["AGE"]["description"] == "Age of a person in years."

assert schema_instance.relations
assert (
schema_instance.relations["EMPLOYED_BY"]["description"]
== "Indicates employment relationship."
)
assert (
schema_instance.relations["ORGANIZED_BY"]["description"]
== "Indicates organization responsible for an event."
)
assert (
schema_instance.relations["ATTENDED_BY"]["description"]
== "Indicates attendance at an event."
)
assert schema_instance.relations["EMPLOYED_BY"]["properties"] == [
{"description": "", "name": "start_time", "type": "LOCAL_DATETIME"},
{"description": "", "name": "end_time", "type": "LOCAL_DATETIME"},
]


def test_create_schema_model_no_relations_or_potential_schema(
schema_builder: SchemaBuilder,
valid_entities: list[SchemaEntity],
) -> None:
schema_instance = schema_builder.create_schema_model(valid_entities)

assert (
schema_instance.entities["PERSON"]["description"]
== "An individual human being."
)
assert schema_instance.entities["PERSON"]["properties"] == [
{"description": "", "name": "birth date", "type": "ZONED_DATETIME"},
{"description": "", "name": "name", "type": "STRING"},
]
assert (
schema_instance.entities["ORGANIZATION"]["description"]
== "A structured group of people with a common purpose."
)
assert schema_instance.entities["AGE"]["description"] == "Age of a person in years."


def test_create_schema_model_missing_relations(
schema_builder: SchemaBuilder,
valid_entities: list[SchemaEntity],
potential_schema: list[tuple[str, str, str]],
) -> None:
with pytest.raises(SchemaValidationError) as exc_info:
schema_builder.create_schema_model(
entities=valid_entities, potential_schema=potential_schema
)
assert "Relations must also be provided when using a potential schema." in str(
exc_info.value
), "Should fail due to missing relations"

0 comments on commit 06d9889

Please sign in to comment.