Skip to content

Commit

Permalink
Merge pull request #58 from dgraph-io/raphael/vectors
Browse files Browse the repository at this point in the history
Add Vector notebooks and product data
  • Loading branch information
harshil-goel authored Sep 12, 2024
2 parents 57bff48 + 189157a commit d47fe8c
Show file tree
Hide file tree
Showing 6 changed files with 592 additions and 0 deletions.
3 changes: 3 additions & 0 deletions vector/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Vector datasets

This section contains datasets and script to test vector indexing and similarity search.
356 changes: 356 additions & 0 deletions vector/ann_eval.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,356 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "ZtVJnO3SzTgt"
},
"source": [
"\n",
"Notebook to explore data similarity in DGraph cluster\n",
"\n",
"\n",
"**pre-requesite**\n",
"- Dgraph\n",
" - Get a [Dgraph Cloud account](https://cloud.dgraph.io/)\n",
" - Have your account user name and password available\n",
" - Have a Dgraph cluster running in your Dgraph Cloud account\n",
" - Obtain the GraphQL Endpoint of the Dgraph cluster from the [cloud dashboard](https://cloud.dgraph.io/_/dashboard)\n",
" - Obtain an Admin API key for the Dgraph Cluster from the [settings](https://cloud.dgraph.io/_/settings?tab=api-keys) tab.\n",
"\n",
"\n",
"\n",
" The first step is to import the packages needed.\n",
"\n",
"- ``pydgraph``, the official [python client library for Dgraph Query Language](https://dgraph.io/docs/dql/clients/python/)\n",
"- ``GraphqlClient``, a GraphQL client to invoke the GraphQL API generated from your schema and the GraphQL admin API of Dgraph.\n",
"\n",
"**Make sure to update the endpoints with the correct values for your Dgraph cluster!**\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## use a local Learning Environment\n",
"\n",
"docker pull dgraph/standalone\n",
"docker run --name dgraph_learn -d -p \"8080:8080\" -p \"9080:9080\" -v <local path to /dgraph-data>:/dgraph dgraph/standalone:latest\n",
"\n",
"### load data set\n",
"cp products.rdf.gz <local path to /dgraph-data>\n",
"cp products.schema <local path to /dgraph-data>\n",
"\n",
"docker exec -it dgraph_learn dgraph live -f /dgraph/products.rdf.gz -s /dgraph/products.schema"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_qEDD3UC7uqF"
},
"outputs": [],
"source": [
"!pip install pydgraph python-graphql-client ipycytoscape\n",
"import pydgraph\n",
"import json\n",
"import base64\n",
"import getpass\n",
"import pandas as pd \n",
"\n",
"from python_graphql_client import GraphqlClient\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# copy your Dgraph cloud endpoints\n",
"# The GraphQL endpoint is found at https://cloud.dgraph.io/_/dashboard\n",
"# dgraph_graphql_endpoint = \"https://black-grass.us-east-1.aws.cloud.dgraph.io/graphql\"\n",
"# dgraph_graphql_endpoint = \"https://withered-bird.us-east-1.aws.cloud.dgraph.io/graphql\"\n",
"dgraph_graphql_endpoint = \"http://localhost:8080/graphql\"\n",
"\n",
"\n",
"# The gRPC endpoint is found at https://cloud.dgraph.io/_/settings\n",
"# dgraph_grpc = \"black-grass.grpc.us-east-1.aws.cloud.dgraph.io:443\"\n",
"dgraph_grpc = \"withered-bird.grpc.us-east-1.aws.cloud.dgraph.io:443\"\n",
"dgraph_grpc = \"localhost:9080\"\n",
"\n",
"# graph admin endpoint is /admin\n",
"dgraph_graphql_admin = dgraph_graphql_endpoint.replace(\"/graphql\", \"/admin\")\n",
"# graph health endpoint is /health\n",
"dgraph_graphql_health = dgraph_graphql_endpoint.replace(\"/graphql\", \"/health\")\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_qnQFABQNBXO"
},
"source": [
"Enter your credentials and test the different clients\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "E7EvVHCqXzfV"
},
"outputs": [],
"source": [
"\n",
"# Cloud credentials\n",
"# we need the cloud login credential to upload the Lambda code.\n",
"# we need the an Admin API key generated at https://cloud.dgraph.io/_/settings?tab=api-keys for DQL alter and query\n",
"\n",
"\n",
"API_KEY = getpass.getpass(\"DGRAPH API KEY?\")\n",
"\n",
"\n",
"\n",
"# DQL Client\n",
"if dgraph_grpc.find(\"cloud\") > 0:\n",
" client_stub = pydgraph.DgraphClientStub.from_cloud(dgraph_grpc,API_KEY )\n",
"else:\n",
" client_stub = pydgraph.DgraphClientStub(addr=dgraph_grpc) \n",
"\n",
"client = pydgraph.DgraphClient(client_stub)\n",
"\n",
"# GraphQL client and admin client\n",
"gql_client = GraphqlClient(endpoint=dgraph_graphql_endpoint)\n",
"headers = { \"Dg-Auth\": API_KEY }\n",
"gql_admin_client = GraphqlClient(endpoint=dgraph_graphql_admin, headers=headers)\n",
"gql_health_client = GraphqlClient(endpoint=dgraph_graphql_health)\n",
"\n",
"#\n",
"# Testing the connection\n",
"#\n",
"data = gql_health_client.execute(query=\"\")\n",
"if 'errors' in data:\n",
" raise Exception(data['errors'][0]['message'])\n",
"\n",
"print(\"Check cluster health:\", json.dumps(data, indent=2))\n",
"\n",
"#\n",
"# Testing the DQL connection\n",
"#\n",
"txn = client.txn(read_only=True)\n",
"query = \"schema{}\"\n",
"res = txn.query(query)\n",
"dqlschema = json.loads(res.json)\n",
"txn.discard()\n",
"print(\"get DQL schema - succeeded\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Deploy the GraphQL Schema\n",
"\n",
"graphql_schema = \"\"\"\n",
"type Product {\n",
" id: ID!\n",
" name: String @id @search(by: [hash,term])\n",
" embedding: [Float!] @embedding\n",
"}\n",
"\"\"\"\n",
"mutation = \"\"\"\n",
"mutation($sch: String!) {\n",
" updateGQLSchema(input: { set: { schema: $sch}})\n",
" {\n",
" gqlSchema {\n",
" schema\n",
" generatedSchema\n",
" }\n",
" }\n",
"}\n",
"\"\"\"\n",
"variables = {\"sch\": graphql_schema}\n",
"schemadata = gql_admin_client.execute(query=mutation, variables=variables)\n",
"print(\"GraphQL Schema after Update\")\n",
"print(schemadata['data']['updateGQLSchema']['gqlSchema']['schema'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reset the index\n",
"def reset_index(predicate, index):\n",
" print(f\"remove index for {predicate}\")\n",
" schema = f\"{predicate}: float32vector .\"\n",
" op = pydgraph.Operation(schema=schema)\n",
" alter = client.alter(op)\n",
" print(alter)\n",
" print(f\"create index for {predicate} {index}\")\n",
" schema = f\"{predicate}: float32vector @index({index}) .\"\n",
" op = pydgraph.Operation(schema=schema)\n",
" alter = client.alter(op)\n",
" print(alter) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"data = pd.read_csv('products_with_embedding.csv.gz', compression='gzip') \n",
"data['embedding'] = [json.loads(data.iloc[i]['embedding']) for i in range(len(data))]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2-YUuHqv61tT"
},
"outputs": [],
"source": [
"# build a ground truth similarity index (KNN)\n",
"from sklearn import metrics\n",
"embeddings = data['embedding'].tolist()\n",
"distances = metrics.pairwise_distances(embeddings, embeddings, metric='euclidean')\n",
"knn_all = distances.argsort(axis=1)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UvkyYXf3VTWj"
},
"outputs": [],
"source": [
"print(data.iloc[3]['name'])\n",
"data.iloc[knn_all[3][:10]]['name'].tolist()\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get topK ANN from Graph - Approximate Nearest Neighbors\n",
"def topK_ANN(name, k=10):\n",
" query = \"\"\"\n",
" query QuerySimilarById($name: String!, $k: Int!) {{\n",
" list:querySimilar{}ById(by: embedding, topK: $k, name: $name) {{\n",
" uid:id\n",
" name:name\n",
" }}\n",
" }}\n",
" \"\"\".format(\"Product\")\n",
" variables={\"name\":name,\"k\":k}\n",
" res = gql_client.execute(query, variables)\n",
" return pd.json_normalize(res['data']['list'])\n",
"\n",
"topK_ANN(data.iloc[3]['name'], 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# compute precision\n",
"def compute_item_precision(index, k, verbose = False):\n",
" name = data.iloc[index]['name']\n",
" if verbose:\n",
" print(\"compute_precision: \",index,name)\n",
" print(\"Ground Truth:\",name)\n",
" knn = data.iloc[knn_all[index][:k]]['name'].tolist()\n",
" if verbose:\n",
" print(\"KNN:\",knn)\n",
" knn = set(knn)\n",
" ann = topK_ANN(name,k)['name'].tolist()\n",
" if verbose:\n",
" print(\"ANN:\",ann)\n",
" ann = set(ann)\n",
" precision_at_k = len(knn.intersection(ann))/len(knn)\n",
" return precision_at_k\n",
"\n",
"# compute precision for first m products\n",
"def average_precision(first = 10, k = 10):\n",
" precision = 0.0\n",
" m = first\n",
" for i in range(first):\n",
" p = compute_item_precision(i,k) \n",
" precision += p\n",
" precision = precision/m\n",
" return precision\n",
"print (compute_item_precision(3,5))\n",
"print (average_precision(10,5))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for index in ['hnsw(metric: \"euclidian\")',\n",
" 'hnsw(metric: \"euclidian\", maxLevels: \"3\", efSearch: \"40\", efConstruction: \"100\")',\n",
" 'hnsw(metric: \"euclidian\", maxLevels: \"3\", efSearch: \"40\", efConstruction: \"400\")']:\n",
" reset_index(\"Product.embedding\",index)\n",
" print(\"index = \",index)\n",
" for k in [1,3,5,10]:\n",
" p_at_k = average_precision(100,k)\n",
" print(\"\"\"Precision@{} = {}\"\"\".format(k, p_at_k))"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Loading

0 comments on commit d47fe8c

Please sign in to comment.