Commit
Merge pull request #58 from dgraph-io/raphael/vectors
Add Vector notebooks and product data
Showing 6 changed files with 592 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Vector datasets | ||
| ||
This section contains datasets and scripts to test vector indexing and similarity search. |
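For orientation, here is a minimal sketch of the kind of similarity search these datasets are meant to exercise. It is not part of this commit: it assumes a Dgraph cluster reachable at localhost:9080 with the products loaded and an HNSW index on the `Product.embedding` predicate (as set up in the notebook below), it uses the DQL `similar_to()` vector-search function as described in the Dgraph docs, and the query vector is only a placeholder whose real dimension must match the stored embeddings.

```python
import json
import pydgraph

# Connect to a local Dgraph alpha (assumed address; adjust for your cluster).
client_stub = pydgraph.DgraphClientStub(addr="localhost:9080")
client = pydgraph.DgraphClient(client_stub)

# similar_to(predicate, topK, vector) returns the topK approximate nearest
# neighbors of the query vector under the index on Product.embedding.
# The 3-dimensional vector below is only a placeholder.
query = """
{
  similar(func: similar_to(Product.embedding, 5, "[0.1, 0.2, 0.3]")) {
    name: Product.name
  }
}
"""

txn = client.txn(read_only=True)
try:
    res = txn.query(query)
    print(json.dumps(json.loads(res.json), indent=2))
finally:
    txn.discard()
```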
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,356 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "ZtVJnO3SzTgt" | ||
}, | ||
"source": [ | ||
"\n", | ||
"Notebook to explore data similarity in DGraph cluster\n", | ||
"\n", | ||
"\n", | ||
"**pre-requesite**\n", | ||
"- Dgraph\n", | ||
" - Get a [Dgraph Cloud account](https://cloud.dgraph.io/)\n", | ||
" - Have your account user name and password available\n", | ||
" - Have a Dgraph cluster running in your Dgraph Cloud account\n", | ||
" - Obtain the GraphQL Endpoint of the Dgraph cluster from the [cloud dashboard](https://cloud.dgraph.io/_/dashboard)\n", | ||
" - Obtain an Admin API key for the Dgraph Cluster from the [settings](https://cloud.dgraph.io/_/settings?tab=api-keys) tab.\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
" The first step is to import the packages needed.\n", | ||
"\n", | ||
"- ``pydgraph``, the official [python client library for Dgraph Query Language](https://dgraph.io/docs/dql/clients/python/)\n", | ||
"- ``GraphqlClient``, a GraphQL client to invoke the GraphQL API generated from your schema and the GraphQL admin API of Dgraph.\n", | ||
"\n", | ||
"**Make sure to update the endpoints with the correct values for your Dgraph cluster!**\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## use a local Learning Environment\n", | ||
"\n", | ||
"docker pull dgraph/standalone\n", | ||
"docker run --name dgraph_learn -d -p \"8080:8080\" -p \"9080:9080\" -v <local path to /dgraph-data>:/dgraph dgraph/standalone:latest\n", | ||
"\n", | ||
"### load data set\n", | ||
"cp products.rdf.gz <local path to /dgraph-data>\n", | ||
"cp products.schema <local path to /dgraph-data>\n", | ||
"\n", | ||
"docker exec -it dgraph_learn dgraph live -f /dgraph/products.rdf.gz -s /dgraph/products.schema" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "_qEDD3UC7uqF" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"!pip install pydgraph python-graphql-client ipycytoscape\n", | ||
"import pydgraph\n", | ||
"import json\n", | ||
"import base64\n", | ||
"import getpass\n", | ||
"import pandas as pd \n", | ||
"\n", | ||
"from python_graphql_client import GraphqlClient\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"# copy your Dgraph cloud endpoints\n", | ||
"# The GraphQL endpoint is found at https://cloud.dgraph.io/_/dashboard\n", | ||
"# dgraph_graphql_endpoint = \"https://black-grass.us-east-1.aws.cloud.dgraph.io/graphql\"\n", | ||
"# dgraph_graphql_endpoint = \"https://withered-bird.us-east-1.aws.cloud.dgraph.io/graphql\"\n", | ||
"dgraph_graphql_endpoint = \"http://localhost:8080/graphql\"\n", | ||
"\n", | ||
"\n", | ||
"# The gRPC endpoint is found at https://cloud.dgraph.io/_/settings\n", | ||
"# dgraph_grpc = \"black-grass.grpc.us-east-1.aws.cloud.dgraph.io:443\"\n", | ||
"dgraph_grpc = \"withered-bird.grpc.us-east-1.aws.cloud.dgraph.io:443\"\n", | ||
"dgraph_grpc = \"localhost:9080\"\n", | ||
"\n", | ||
"# graph admin endpoint is /admin\n", | ||
"dgraph_graphql_admin = dgraph_graphql_endpoint.replace(\"/graphql\", \"/admin\")\n", | ||
"# graph health endpoint is /health\n", | ||
"dgraph_graphql_health = dgraph_graphql_endpoint.replace(\"/graphql\", \"/health\")\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": { | ||
"id": "_qnQFABQNBXO" | ||
}, | ||
"source": [ | ||
"Enter your credentials and test the different clients\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "E7EvVHCqXzfV" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"# Cloud credentials\n", | ||
"# we need the cloud login credential to upload the Lambda code.\n", | ||
"# we need the an Admin API key generated at https://cloud.dgraph.io/_/settings?tab=api-keys for DQL alter and query\n", | ||
"\n", | ||
"\n", | ||
"API_KEY = getpass.getpass(\"DGRAPH API KEY?\")\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"# DQL Client\n", | ||
"if dgraph_grpc.find(\"cloud\") > 0:\n", | ||
" client_stub = pydgraph.DgraphClientStub.from_cloud(dgraph_grpc,API_KEY )\n", | ||
"else:\n", | ||
" client_stub = pydgraph.DgraphClientStub(addr=dgraph_grpc) \n", | ||
"\n", | ||
"client = pydgraph.DgraphClient(client_stub)\n", | ||
"\n", | ||
"# GraphQL client and admin client\n", | ||
"gql_client = GraphqlClient(endpoint=dgraph_graphql_endpoint)\n", | ||
"headers = { \"Dg-Auth\": API_KEY }\n", | ||
"gql_admin_client = GraphqlClient(endpoint=dgraph_graphql_admin, headers=headers)\n", | ||
"gql_health_client = GraphqlClient(endpoint=dgraph_graphql_health)\n", | ||
"\n", | ||
"#\n", | ||
"# Testing the connection\n", | ||
"#\n", | ||
"data = gql_health_client.execute(query=\"\")\n", | ||
"if 'errors' in data:\n", | ||
" raise Exception(data['errors'][0]['message'])\n", | ||
"\n", | ||
"print(\"Check cluster health:\", json.dumps(data, indent=2))\n", | ||
"\n", | ||
"#\n", | ||
"# Testing the DQL connection\n", | ||
"#\n", | ||
"txn = client.txn(read_only=True)\n", | ||
"query = \"schema{}\"\n", | ||
"res = txn.query(query)\n", | ||
"dqlschema = json.loads(res.json)\n", | ||
"txn.discard()\n", | ||
"print(\"get DQL schema - succeeded\")\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Deploy the GraphQL Schema\n", | ||
"\n", | ||
"graphql_schema = \"\"\"\n", | ||
"type Product {\n", | ||
" id: ID!\n", | ||
" name: String @id @search(by: [hash,term])\n", | ||
" embedding: [Float!] @embedding\n", | ||
"}\n", | ||
"\"\"\"\n", | ||
"mutation = \"\"\"\n", | ||
"mutation($sch: String!) {\n", | ||
" updateGQLSchema(input: { set: { schema: $sch}})\n", | ||
" {\n", | ||
" gqlSchema {\n", | ||
" schema\n", | ||
" generatedSchema\n", | ||
" }\n", | ||
" }\n", | ||
"}\n", | ||
"\"\"\"\n", | ||
"variables = {\"sch\": graphql_schema}\n", | ||
"schemadata = gql_admin_client.execute(query=mutation, variables=variables)\n", | ||
"print(\"GraphQL Schema after Update\")\n", | ||
"print(schemadata['data']['updateGQLSchema']['gqlSchema']['schema'])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# reset the index\n", | ||
"def reset_index(predicate, index):\n", | ||
" print(f\"remove index for {predicate}\")\n", | ||
" schema = f\"{predicate}: float32vector .\"\n", | ||
" op = pydgraph.Operation(schema=schema)\n", | ||
" alter = client.alter(op)\n", | ||
" print(alter)\n", | ||
" print(f\"create index for {predicate} {index}\")\n", | ||
" schema = f\"{predicate}: float32vector @index({index}) .\"\n", | ||
" op = pydgraph.Operation(schema=schema)\n", | ||
" alter = client.alter(op)\n", | ||
" print(alter) " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"data = pd.read_csv('products_with_embedding.csv.gz', compression='gzip') \n", | ||
"data['embedding'] = [json.loads(data.iloc[i]['embedding']) for i in range(len(data))]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "2-YUuHqv61tT" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# build a ground truth similarity index (KNN)\n", | ||
"from sklearn import metrics\n", | ||
"embeddings = data['embedding'].tolist()\n", | ||
"distances = metrics.pairwise_distances(embeddings, embeddings, metric='euclidean')\n", | ||
"knn_all = distances.argsort(axis=1)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"id": "UvkyYXf3VTWj" | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"print(data.iloc[3]['name'])\n", | ||
"data.iloc[knn_all[3][:10]]['name'].tolist()\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# get topK ANN from Graph - Approximate Nearest Neighbors\n", | ||
"def topK_ANN(name, k=10):\n", | ||
" query = \"\"\"\n", | ||
" query QuerySimilarById($name: String!, $k: Int!) {{\n", | ||
" list:querySimilar{}ById(by: embedding, topK: $k, name: $name) {{\n", | ||
" uid:id\n", | ||
" name:name\n", | ||
" }}\n", | ||
" }}\n", | ||
" \"\"\".format(\"Product\")\n", | ||
" variables={\"name\":name,\"k\":k}\n", | ||
" res = gql_client.execute(query, variables)\n", | ||
" return pd.json_normalize(res['data']['list'])\n", | ||
"\n", | ||
"topK_ANN(data.iloc[3]['name'], 10)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# compute precision\n", | ||
"def compute_item_precision(index, k, verbose = False):\n", | ||
" name = data.iloc[index]['name']\n", | ||
" if verbose:\n", | ||
" print(\"compute_precision: \",index,name)\n", | ||
" print(\"Ground Truth:\",name)\n", | ||
" knn = data.iloc[knn_all[index][:k]]['name'].tolist()\n", | ||
" if verbose:\n", | ||
" print(\"KNN:\",knn)\n", | ||
" knn = set(knn)\n", | ||
" ann = topK_ANN(name,k)['name'].tolist()\n", | ||
" if verbose:\n", | ||
" print(\"ANN:\",ann)\n", | ||
" ann = set(ann)\n", | ||
" precision_at_k = len(knn.intersection(ann))/len(knn)\n", | ||
" return precision_at_k\n", | ||
"\n", | ||
"# compute precision for first m products\n", | ||
"def average_precision(first = 10, k = 10):\n", | ||
" precision = 0.0\n", | ||
" m = first\n", | ||
" for i in range(first):\n", | ||
" p = compute_item_precision(i,k) \n", | ||
" precision += p\n", | ||
" precision = precision/m\n", | ||
" return precision\n", | ||
"print (compute_item_precision(3,5))\n", | ||
"print (average_precision(10,5))\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"for index in ['hnsw(metric: \"euclidian\")',\n", | ||
" 'hnsw(metric: \"euclidian\", maxLevels: \"3\", efSearch: \"40\", efConstruction: \"100\")',\n", | ||
" 'hnsw(metric: \"euclidian\", maxLevels: \"3\", efSearch: \"40\", efConstruction: \"400\")']:\n", | ||
" reset_index(\"Product.embedding\",index)\n", | ||
" print(\"index = \",index)\n", | ||
" for k in [1,3,5,10]:\n", | ||
" p_at_k = average_precision(100,k)\n", | ||
" print(\"\"\"Precision@{} = {}\"\"\".format(k, p_at_k))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 0 | ||
} |