Skip to content

Commit

Permalink
Add support for external vector integration via Weaviate
Browse files Browse the repository at this point in the history
  • Loading branch information
oskarhane committed May 16, 2024
1 parent e58d68e commit 35136f5
Show file tree
Hide file tree
Showing 16 changed files with 18,092 additions and 8 deletions.
332 changes: 330 additions & 2 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pre-commit = { version = "^3.6.2", python = "^3.9" }
coverage = "^7.4.3"
ruff = "^0.3.0"
langchain-openai = "^0.1.1"
weaviate-client = "^4.6.1"

[build-system]
requires = ["poetry-core"]
Expand Down
14 changes: 14 additions & 0 deletions src/neo4j_genai/neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@
"YIELD node, score"
)

# Cypher template used for the MATCH search type.
# `$match_params` is a query parameter that is unwound into pairs: element 0
# is bound to `match_id_value` and element 1 to `score`.
# `{match_query}` is a str.format placeholder that must be filled in with the
# WHERE predicate before the query is run.
MATCH_QUERY = (
"UNWIND $match_params AS match_param "
"WITH match_param[0] AS match_id_value, match_param[1] AS score "
"MATCH (node) "
"WHERE {match_query}"
)


def _get_hybrid_query() -> str:
return (
Expand Down Expand Up @@ -117,6 +124,7 @@ def get_search_query(
embedding_node_property: Optional[str] = None,
embedding_dimension: Optional[int] = None,
filters: Optional[dict[str, Any]] = None,
match_query: Optional[str] = None,
) -> tuple[str, dict[str, Any]]:
"""Build the search query, including pre-filtering if needed, and return clause.
Expand All @@ -130,6 +138,7 @@ def get_search_query(
embedding_node_property (str): the name of the property holding the embeddings
embedding_dimension (int): the dimension of the embeddings
filters (dict[str, Any]): filters used to pre-filter the nodes before vector search
match_query: Optional[str]: the query to use to match the search results
Returns:
tuple[str, dict[str, Any]]: query and parameters
Expand All @@ -144,6 +153,11 @@ def get_search_query(
query, params = _get_vector_query(
filters, node_label, embedding_node_property, embedding_dimension
)
elif search_type == SearchType.MATCH:
if not match_query:
raise ValueError("Match query is required for MATCH search type")
query = MATCH_QUERY.format(match_query=match_query)
params = {}
else:
raise ValueError(f"Search type is not supported: {search_type}")
query_tail = _get_query_tail(
Expand Down
37 changes: 35 additions & 2 deletions src/neo4j_genai/retrievers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any

from typing import Optional, Any
import neo4j


Expand Down Expand Up @@ -75,3 +74,37 @@ def _fetch_index_infos(self):
self._node_label = result["labels"][0]
self._embedding_node_property = result["properties"][0]
self._embedding_dimension = result["dimensions"]


class ExternalRetriever(ABC):
    """Abstract base class for retrievers backed by an external vector store.

    Subclasses must define the ``id_property_external`` and
    ``id_property_neo4j`` properties and implement ``search``.
    """

    def __init__(self):
        """Initialise the retriever; the base class keeps no state."""

    @property
    @abstractmethod
    def id_property_external(self):
        """Identifier property on the external store side.

        NOTE(review): presumably the name of the external-store property that
        holds the id of the matching Neo4j node -- confirm against subclasses.
        """

    @property
    @abstractmethod
    def id_property_neo4j(self):
        """Identifier property on the Neo4j side.

        NOTE(review): presumably the name of the node property that the
        external id is matched against -- confirm against subclasses.
        """

    @abstractmethod
    def search(
        self,
        query_vector: Optional[list[float]] = None,
        query_text: Optional[str] = None,
        top_k: int = 5,
    ) -> list[neo4j.Record]:
        """Run a search and return the results as Neo4j records.

        Args:
            query_vector (Optional[list[float]]): vector to search with.
            query_text (Optional[str]): text to search with.
            top_k (int): presumably the number of results to return;
                defaults to 5.

        Returns:
            list[neo4j.Record]: List of Neo4j Records
        """
18 changes: 18 additions & 0 deletions src/neo4j_genai/retrievers/external/weaviate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .weaviate import WeaviateNeo4jRetriever

__all__ = ["WeaviateNeo4jRetriever"]
54 changes: 54 additions & 0 deletions src/neo4j_genai/retrievers/external/weaviate/examples/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
### Start services locally

This is a manual task you need to do in the terminal.

```bash
docker run \
--name testweaviate \
--rm \
-p8080:8080 -p 50051:50051 \
cr.weaviate.io/semitechnologies/weaviate:1.24.11

docker run \
--name testneo4j \
--rm \
-p7474:7474 -p7687:7687 \
--env NEO4J_ACCEPT_LICENSE_AGREEMENT=eval \
--env NEO4J_AUTH=neo4j/password \
neo4j:enterprise
```

To run Weaviate with the OpenAI vectorizer module (`text2vec-openai`) enabled, start it like this instead:

```bash
docker run \
--name testweaviate \
--rm \
-p8080:8080 -p 50051:50051 \
--env ENABLE_MODULES=text2vec-openai \
--env DEFAULT_VECTORIZER_MODULE=text2vec-openai \
cr.weaviate.io/semitechnologies/weaviate:1.24.11
```

### Write data (once)

Run this from the project root to write data to both databases (Weaviate and Neo4j):

```
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/populate_dbs.py
```

### Search

To run the text search examples, you need to create a `.env` file that contains `OPENAI_API_KEY=<your-api-key>`.

```
# search by vector
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/vector_search.py
# search by text, with embeddings generated locally (via embedder argument)
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/text_search_local_embedder.py
# search by text, with embeddings generated on the Weaviate side, via configured vectorizer
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/text_search_remote_embedder.py
```
161 changes: 161 additions & 0 deletions src/neo4j_genai/retrievers/external/weaviate/examples/populate_dbs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import weaviate.classes as wvc
import weaviate
from neo4j import GraphDatabase
import os.path
import json
import hashlib

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
NEO4J_URL = "neo4j://localhost:7687"
NEO4J_AUTH = ("neo4j", "password")


def main():
    """Create the Weaviate ``Question`` collection and load both databases.

    The data objects are built first, then a Neo4j driver and a local
    Weaviate client are opened as context managers. The collection is
    created before Weaviate and Neo4j are populated.
    """
    neo4j_objects, w_question_objs = build_data_objects()

    with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver:
        with weaviate.connect_to_local() as w_client:
            configure = wvc.config.Configure

            openai_vectorizer = configure.Vectorizer.text2vec_openai(
                model="ada", model_version="002"
            )
            hnsw_index = configure.VectorIndex.hnsw(
                # select preferred distance metric
                distance_metric=wvc.config.VectorDistances.COSINE
            )
            # Each Weaviate object stores the id of its Neo4j counterpart.
            neo4j_id_property = wvc.config.Property(
                name="neo4j_id", data_type=wvc.config.DataType.TEXT
            )

            w_client.collections.create(
                "Question",
                vectorizer_config=openai_vectorizer,
                vector_index_config=hnsw_index,
                properties=[neo4j_id_property],
            )

            # Populate Weaviate
            populate_weaviate(w_client, w_question_objs)

            # Populate Neo4j
            populate_neo4j(neo4j_driver, neo4j_objects)


def populate_neo4j(neo4j_driver, neo4j_objs):
    """Write the prepared nodes and relationships to Neo4j.

    Args:
        neo4j_driver: object exposing ``execute_query(query, parameters)``.
        neo4j_objs (dict): has a ``"nodes"`` list (each entry carries a
            ``"label"`` key) and a ``"relationships"`` list (each entry
            carries a ``"type"`` key).

    Five queries are run in order: Question nodes, Answer nodes, Category
    nodes, BELONGS_TO relationships and HAS_ANSWER relationships.
    """
    all_nodes = neo4j_objs["nodes"]
    questions = [node for node in all_nodes if node["label"] == "Question"]
    answers = [node for node in all_nodes if node["label"] == "Answer"]
    categories = [node for node in all_nodes if node["label"] == "Category"]

    all_relationships = neo4j_objs["relationships"]
    belongs_to = [rel for rel in all_relationships if rel["type"] == "BELONGS_TO"]
    has_answer = [rel for rel in all_relationships if rel["type"] == "HAS_ANSWER"]

    # (cypher, parameters) pairs, executed in this exact order so that the
    # nodes exist before the relationships that MATCH them.
    statements = [
        (
            "UNWIND $nodes as node "
            "MERGE (n:Question {id: node.properties.id}) "
            "ON CREATE SET n = node.properties",
            {"nodes": questions},
        ),
        (
            "UNWIND $nodes as node "
            "MERGE (n:Answer {id: node.properties.id}) "
            "ON CREATE SET n = node.properties",
            {"nodes": answers},
        ),
        (
            # Category entries are flat dicts, so the id is read from the
            # entry itself rather than from a nested "properties" map.
            "UNWIND $nodes as node "
            "MERGE (n:Category {id: node.id}) "
            "ON CREATE SET n = node",
            {"nodes": categories},
        ),
        (
            "UNWIND $relationships as rel "
            "MATCH (q:Question {id: rel.start_node_id}), "
            "(c:Category {id: rel.end_node_id}) "
            "MERGE (q)-[r:BELONGS_TO]->(c)",
            {"relationships": belongs_to},
        ),
        (
            "UNWIND $relationships as rel "
            "MATCH (q:Question {id: rel.start_node_id}), "
            "(a:Answer {id: rel.end_node_id}) "
            "MERGE (q)-[r:HAS_ANSWER]->(a)",
            {"relationships": has_answer},
        ),
    ]
    for cypher, parameters in statements:
        neo4j_driver.execute_query(cypher, parameters)


def populate_weaviate(w_client, w_question_objs):
    """Insert the given objects into the Weaviate ``Question`` collection.

    Args:
        w_client: Weaviate client exposing ``collections.get``.
        w_question_objs: objects passed as-is to ``data.insert_many``.
    """
    w_client.collections.get("Question").data.insert_many(w_question_objs)


def build_data_objects():
    """Build the Neo4j and Weaviate payloads from the bundled Jeopardy data.

    Reads the JSON data file from the e2e test data directory (resolved
    relative to ``BASE_DIR``). Each entry is expected to provide the keys
    ``"Category"``, ``"Question"``, ``"Answer"`` and ``"vector"``.

    Returns:
        tuple: ``(neo4j_objs, w_question_objs)`` where ``neo4j_objs`` is a
        dict with ``"nodes"`` and ``"relationships"`` lists and
        ``w_question_objs`` is a list of ``wvc.data.DataObject`` carrying the
        question vector and the id of the matching Neo4j Question node.
    """
    # read file from disk
    # this file is from https://github.com/weaviate-tutorials/quickstart/tree/main/data
    # MIT License
    file_name = os.path.join(
        BASE_DIR,
        "../../../../../tests/e2e/data/jeopardy_tiny_with_vectors_all-OpenAI-ada-002.json",
    )
    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(file_name, "r", encoding="utf-8") as f:
        data = json.load(f)

    w_question_objs = []
    neo4j_objs = {"nodes": [], "relationships": []}

    # One Category node per distinct category, using the category name as its
    # id. Sorted so the node order is the same on every run.
    unique_category_names = sorted({entry["Category"] for entry in data})
    neo4j_objs["nodes"] += [
        {"label": "Category", "name": name, "id": name}
        for name in unique_category_names
    ]

    for entry in data:
        # The MD5 of the question text yields a stable id shared by the
        # Question node, its Answer node and the Weaviate object.
        question_hash = hashlib.md5(entry["Question"].encode()).hexdigest()
        question_id = f"question_{question_hash}"
        answer_id = f"answer_{question_hash}"

        neo4j_objs["nodes"].append(
            {
                "label": "Question",
                "properties": {
                    "id": question_id,
                    "question": entry["Question"],
                },
            }
        )
        neo4j_objs["nodes"].append(
            {
                "label": "Answer",
                "properties": {
                    "id": answer_id,
                    "answer": entry["Answer"],
                },
            }
        )
        neo4j_objs["relationships"].append(
            {
                "start_node_id": question_id,
                "end_node_id": answer_id,
                "type": "HAS_ANSWER",
                "properties": {},
            }
        )
        neo4j_objs["relationships"].append(
            {
                "start_node_id": question_id,
                "end_node_id": entry["Category"],
                "type": "BELONGS_TO",
                "properties": {},
            }
        )
        w_question_objs.append(
            wvc.data.DataObject(
                properties={
                    "neo4j_id": question_id,
                },
                vector=entry["vector"],
            )
        )

    return neo4j_objs, w_question_objs


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
# #
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# #
# https://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from neo4j_genai.retrievers.external.weaviate import WeaviateNeo4jRetriever
from neo4j import GraphDatabase
import weaviate
from langchain_openai import OpenAIEmbeddings

NEO4J_URL = "neo4j://localhost:7687"
NEO4J_AUTH = ("neo4j", "password")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")


def main():
    """Search by text, with the embedding generated locally via ``embedder``.

    Opens a Neo4j driver and a local Weaviate client, builds a
    ``WeaviateNeo4jRetriever`` over the ``Question`` collection, runs a
    text search for "biology" and prints the result.
    """
    # Both connections are context managers, so they are closed even when the
    # search raises (the driver used to be closed only on the success path).
    with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver:
        with weaviate.connect_to_local() as w_client:
            embedder = OpenAIEmbeddings(
                api_key=OPENAI_API_KEY, model="text-embedding-ada-002"
            )
            retriever = WeaviateNeo4jRetriever(
                driver=neo4j_driver,
                client=w_client,
                collection="Question",
                id_property_external="neo4j_id",
                id_property_neo4j="id",
                embedder=embedder,
            )

            res = retriever.search(query_text="biology", top_k=2)
            print(res)


if __name__ == "__main__":
main()
Loading

0 comments on commit 35136f5

Please sign in to comment.