Add support for external vector integration via Weaviate
Showing 16 changed files with 18,092 additions and 8 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from .weaviate import WeaviateNeo4jRetriever | ||
|
||
__all__ = ["WeaviateNeo4jRetriever"] |
54 changes: 54 additions & 0 deletions
src/neo4j_genai/retrievers/external/weaviate/examples/README.md
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
### Start services locally | ||
|
||
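Start the services manually from a terminal. Neither command uses `-d`, so each `docker run` stays attached in the foreground and needs its own terminal. | ||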
|
||
```bash | ||
docker run \ | ||
--name testweaviate \ | ||
--rm \ | ||
-p8080:8080 -p 50051:50051 \ | ||
cr.weaviate.io/semitechnologies/weaviate:1.24.11 | ||
|
||
docker run \ | ||
--name testneo4j \ | ||
--rm \ | ||
-p7474:7474 -p7687:7687 \ | ||
--env NEO4J_ACCEPT_LICENSE_AGREEMENT=eval \ | ||
--env NEO4J_AUTH=neo4j/password \ | ||
neo4j:enterprise | ||
``` | ||
|
||
To run Weaviate with the OpenAI vectorizer enabled: | ||
|
||
```bash | ||
docker run \ | ||
--name testweaviate \ | ||
--rm \ | ||
-p8080:8080 -p 50051:50051 \ | ||
--env ENABLE_MODULES=text2vec-openai \ | ||
--env DEFAULT_VECTORIZER_MODULE=text2vec-openai \ | ||
cr.weaviate.io/semitechnologies/weaviate:1.24.11 | ||
``` | ||
|
||
### Write data (once) | ||
|
||
Run this from the project root to write data to both databases: | ||
|
||
```bash | ||
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/populate_dbs.py | ||
``` | ||
|
||
### Search | ||
|
||
To run the text search examples, create a `.env` file containing `OPENAI_API_KEY=<your-api-key>`. | ||
|
||
```bash | ||
# search by vector | ||
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/vector_search.py | ||
# search by text, with embeddings generated locally (via embedder argument) | ||
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/text_search_local_embedder.py | ||
# search by text, with embeddings generated on the Weaviate side, via configured vectorizer | ||
poetry run python src/neo4j_genai/retrievers/external/weaviate/examples/text_search_remote_embedder.py | ||
``` |
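The Search section asks for a `.env` file holding `OPENAI_API_KEY`, and the local-embedder example reads that key with `os.getenv`. Whether anything in this commit loads the file into the environment is not visible in the rendered diff, so the following is only a minimal standard-library sketch of one way to do it. `load_env_file` is a hypothetical helper, not part of `neo4j_genai`; a package such as `python-dotenv` would be the more usual choice.

```python
import os
from pathlib import Path


def load_env_file(path: str = ".env") -> dict[str, str]:
    """Hypothetical helper: copy KEY=VALUE pairs from a file into os.environ.

    Variables that are already set in the environment are left untouched.
    Returns only the pairs that were actually added.
    """
    loaded: dict[str, str] = {}
    env_path = Path(path)
    if not env_path.is_file():
        return loaded
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        # Skip blank lines, comments, and anything that is not KEY=VALUE.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        key = key.strip()
        # Drop surrounding whitespace and simple quoting around the value.
        value = value.strip().strip('"').strip("'")
        if key and key not in os.environ:
            os.environ[key] = value
            loaded[key] = value
    return loaded


if __name__ == "__main__":
    load_env_file()
    print("OPENAI_API_KEY set:", os.getenv("OPENAI_API_KEY") is not None)
```

Calling such a helper at the top of a search script, before `os.getenv("OPENAI_API_KEY")` is evaluated, would make the key available without exporting it in the shell.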
161 changes: 161 additions & 0 deletions
src/neo4j_genai/retrievers/external/weaviate/examples/populate_dbs.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import weaviate.classes as wvc | ||
import weaviate | ||
from neo4j import GraphDatabase | ||
import os.path | ||
import json | ||
import hashlib | ||
|
||
BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | ||
NEO4J_URL = "neo4j://localhost:7687" | ||
NEO4J_AUTH = ("neo4j", "password") | ||
|
||
|
||
def main(): | ||
    neo4j_objects, w_question_objs = build_data_objects() | ||
    with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver: | ||
        with weaviate.connect_to_local() as w_client: | ||
            w_client.collections.create( | ||
                "Question", | ||
                vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_openai( | ||
                    model="ada", model_version="002" | ||
                ), | ||
                vector_index_config=wvc.config.Configure.VectorIndex.hnsw( | ||
                    distance_metric=wvc.config.VectorDistances.COSINE  # select preferred distance metric | ||
                ), | ||
                properties=[ | ||
                    wvc.config.Property( | ||
                        name="neo4j_id", data_type=wvc.config.DataType.TEXT | ||
                    ), | ||
                ], | ||
            ) | ||
|
||
            # Populate Weaviate | ||
            populate_weaviate(w_client, w_question_objs) | ||
|
||
            # Populate Neo4j | ||
            populate_neo4j(neo4j_driver, neo4j_objects) | ||
|
||
|
||
def populate_neo4j(neo4j_driver, neo4j_objs): | ||
    question_nodes = list( | ||
        filter(lambda x: x["label"] == "Question", neo4j_objs["nodes"]) | ||
    ) | ||
    answer_nodes = list(filter(lambda x: x["label"] == "Answer", neo4j_objs["nodes"])) | ||
    category_nodes = list( | ||
        filter(lambda x: x["label"] == "Category", neo4j_objs["nodes"]) | ||
    ) | ||
    belongs_to_relationships = list( | ||
        filter(lambda x: x["type"] == "BELONGS_TO", neo4j_objs["relationships"]) | ||
    ) | ||
    has_answer_relationships = list( | ||
        filter(lambda x: x["type"] == "HAS_ANSWER", neo4j_objs["relationships"]) | ||
    ) | ||
    question_nodes_cypher = "UNWIND $nodes as node MERGE (n:Question {id: node.properties.id}) ON CREATE SET n = node.properties" | ||
    answer_nodes_cypher = "UNWIND $nodes as node MERGE (n:Answer {id: node.properties.id}) ON CREATE SET n = node.properties" | ||
    category_nodes_cypher = ( | ||
        "UNWIND $nodes as node MERGE (n:Category {id: node.id}) ON CREATE SET n = node" | ||
    ) | ||
    belongs_to_relationships_cypher = "UNWIND $relationships as rel MATCH (q:Question {id: rel.start_node_id}), (c:Category {id: rel.end_node_id}) MERGE (q)-[r:BELONGS_TO]->(c)" | ||
    has_answer_relationships_cypher = "UNWIND $relationships as rel MATCH (q:Question {id: rel.start_node_id}), (a:Answer {id: rel.end_node_id}) MERGE (q)-[r:HAS_ANSWER]->(a)" | ||
    neo4j_driver.execute_query(question_nodes_cypher, {"nodes": question_nodes}) | ||
    neo4j_driver.execute_query(answer_nodes_cypher, {"nodes": answer_nodes}) | ||
    neo4j_driver.execute_query(category_nodes_cypher, {"nodes": category_nodes}) | ||
    neo4j_driver.execute_query( | ||
        belongs_to_relationships_cypher, {"relationships": belongs_to_relationships} | ||
    ) | ||
    neo4j_driver.execute_query( | ||
        has_answer_relationships_cypher, {"relationships": has_answer_relationships} | ||
    ) | ||
|
||
|
||
def populate_weaviate(w_client, w_question_objs): | ||
    questions = w_client.collections.get("Question") | ||
    questions.data.insert_many(w_question_objs) | ||
|
||
|
||
def build_data_objects(): | ||
    # read file from disk | ||
    # this file is from https://github.com/weaviate-tutorials/quickstart/tree/main/data | ||
    # MIT License | ||
    file_name = os.path.join( | ||
        BASE_DIR, | ||
        "../../../../../tests/e2e/data/jeopardy_tiny_with_vectors_all-OpenAI-ada-002.json", | ||
    ) | ||
    with open(file_name, "r") as f: | ||
        data = json.load(f) | ||
|
||
    w_question_objs = list() | ||
    neo4j_objs = {"nodes": [], "relationships": []} | ||
|
||
    # only unique categories and ids for them | ||
    unique_categories_list = list(set([c["Category"] for c in data])) | ||
    unique_categories = [ | ||
        {"label": "Category", "name": c, "id": c} for c in unique_categories_list | ||
    ] | ||
    neo4j_objs["nodes"] += unique_categories | ||
|
||
    for d in data: | ||
        # the MD5 digest of the question text links the Neo4j nodes to the Weaviate object | ||
        digest = hashlib.md5(d["Question"].encode()).hexdigest() | ||
        neo4j_objs["nodes"].append( | ||
            { | ||
                "label": "Question", | ||
                "properties": { | ||
                    "id": f"question_{digest}", | ||
                    "question": d["Question"], | ||
                }, | ||
            } | ||
        ) | ||
        neo4j_objs["nodes"].append( | ||
            { | ||
                "label": "Answer", | ||
                "properties": { | ||
                    "id": f"answer_{digest}", | ||
                    "answer": d["Answer"], | ||
                }, | ||
            } | ||
        ) | ||
        neo4j_objs["relationships"].append( | ||
            { | ||
                "start_node_id": f"question_{digest}", | ||
                "end_node_id": f"answer_{digest}", | ||
                "type": "HAS_ANSWER", | ||
                "properties": {}, | ||
            } | ||
        ) | ||
        neo4j_objs["relationships"].append( | ||
            { | ||
                "start_node_id": f"question_{digest}", | ||
                "end_node_id": d["Category"], | ||
                "type": "BELONGS_TO", | ||
                "properties": {}, | ||
            } | ||
        ) | ||
        w_question_objs.append( | ||
            wvc.data.DataObject( | ||
                properties={ | ||
                    "neo4j_id": f"question_{digest}", | ||
                }, | ||
                vector=d["vector"], | ||
            ) | ||
        ) | ||
|
||
    return neo4j_objs, w_question_objs | ||
|
||
|
||
if __name__ == "__main__": | ||
    main() |
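`build_data_objects` links the two stores through a shared identifier: the MD5 hex digest of the question text, prefixed with `question_` or `answer_`, is written to Neo4j as `id` and to Weaviate as `neo4j_id`. The sketch below isolates that scheme using only the standard library. `make_ids` and the sample question are illustrative names and data, not part of the commit.

```python
import hashlib


def make_ids(question_text: str) -> tuple[str, str]:
    """Return the (question id, answer id) pair derived from the question text."""
    digest = hashlib.md5(question_text.encode()).hexdigest()
    return f"question_{digest}", f"answer_{digest}"


if __name__ == "__main__":
    # Illustrative input, not taken from the dataset.
    question_id, answer_id = make_ids("What is the capital of France?")
    # The same text always maps to the same ids, so a second run of the
    # population script would MERGE onto existing nodes instead of adding new ones.
    print(question_id)
    print(answer_id)
```

Because the digest depends only on the question text, two rows with identical text would share one id and therefore collapse into a single `Question` node under `MERGE`.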
48 changes: 48 additions & 0 deletions
src/neo4j_genai/retrievers/external/weaviate/examples/text_search_local_embedder.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) "Neo4j" | ||
# Neo4j Sweden AB [https://neo4j.com] | ||
# # | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# # | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# # | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
from neo4j_genai.retrievers.external.weaviate import WeaviateNeo4jRetriever | ||
from neo4j import GraphDatabase | ||
import weaviate | ||
from langchain_openai import OpenAIEmbeddings | ||
|
||
NEO4J_URL = "neo4j://localhost:7687" | ||
NEO4J_AUTH = ("neo4j", "password") | ||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | ||
|
||
|
||
def main(): | ||
    # Context managers close both connections even if the search raises. | ||
    with GraphDatabase.driver(NEO4J_URL, auth=NEO4J_AUTH) as neo4j_driver: | ||
        with weaviate.connect_to_local() as w_client: | ||
            embedder = OpenAIEmbeddings( | ||
                api_key=OPENAI_API_KEY, model="text-embedding-ada-002" | ||
            ) | ||
            retriever = WeaviateNeo4jRetriever( | ||
                driver=neo4j_driver, | ||
                client=w_client, | ||
                collection="Question", | ||
                id_property_external="neo4j_id", | ||
                id_property_neo4j="id", | ||
                embedder=embedder, | ||
            ) | ||
|
||
            res = retriever.search(query_text="biology", top_k=2) | ||
            print(res) | ||
|
||
|
||
if __name__ == "__main__": | ||
    main() |
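The example wires a Weaviate collection to Neo4j through `id_property_external` and `id_property_neo4j`. The retriever's implementation is not rendered in this part of the diff, so the following is only a conceptual, in-memory sketch of the lookup those two parameters suggest: rank stored vectors by cosine similarity, then use each hit's external id to fetch the matching graph node. All names and data are made up for illustration, and nothing here calls Weaviate or Neo4j.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def search(query_vector, vector_store, graph_nodes, top_k=2):
    """Hypothetical stand-in: vector hits are joined to graph nodes by id."""
    ranked = sorted(
        vector_store,
        key=lambda obj: cosine_similarity(query_vector, obj["vector"]),
        reverse=True,
    )
    results = []
    for obj in ranked[:top_k]:
        # "neo4j_id" plays the role of id_property_external; the dict key
        # plays the role of id_property_neo4j.
        node = graph_nodes.get(obj["neo4j_id"])
        if node is not None:
            results.append(node)
    return results


# Toy stand-ins for the Weaviate collection and the Neo4j nodes.
VECTOR_STORE = [
    {"neo4j_id": "question_a", "vector": [1.0, 0.0]},
    {"neo4j_id": "question_b", "vector": [0.0, 1.0]},
    {"neo4j_id": "question_c", "vector": [0.9, 0.1]},
]
GRAPH_NODES = {
    "question_a": {"id": "question_a", "question": "toy question A"},
    "question_b": {"id": "question_b", "question": "toy question B"},
    "question_c": {"id": "question_c", "question": "toy question C"},
}

if __name__ == "__main__":
    for node in search([1.0, 0.0], VECTOR_STORE, GRAPH_NODES, top_k=2):
        print(node["id"], node["question"])
```

In the real setup the ranking would happen inside Weaviate and the join would be a Cypher match on the `id` property; the sketch only shows why both stores need to carry the same identifier.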