diff --git a/vector/README.md b/vector/README.md
new file mode 100644
index 0000000..efa5445
--- /dev/null
+++ b/vector/README.md
@@ -0,0 +1,3 @@
+# Vector datasets
+
+This section contains datasets and scripts to test vector indexing and similarity search.
diff --git a/vector/ann_eval.ipynb b/vector/ann_eval.ipynb
new file mode 100644
index 0000000..fa3cb96
--- /dev/null
+++ b/vector/ann_eval.ipynb
@@ -0,0 +1,356 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "ZtVJnO3SzTgt"
+   },
+   "source": [
+    "Notebook to explore data similarity in a Dgraph cluster.\n",
+    "\n",
+    "**Prerequisites**\n",
+    "- Dgraph\n",
+    "  - Get a [Dgraph Cloud account](https://cloud.dgraph.io/)\n",
+    "  - Have your account user name and password available\n",
+    "  - Have a Dgraph cluster running in your Dgraph Cloud account\n",
+    "  - Obtain the GraphQL endpoint of the Dgraph cluster from the [cloud dashboard](https://cloud.dgraph.io/_/dashboard)\n",
+    "  - Obtain an Admin API key for the Dgraph cluster from the [settings](https://cloud.dgraph.io/_/settings?tab=api-keys) tab.\n",
+    "\n",
+    "The first step is to import the packages needed:\n",
+    "\n",
+    "- ``pydgraph``, the official [Python client library for Dgraph Query Language](https://dgraph.io/docs/dql/clients/python/)\n",
+    "- ``GraphqlClient``, a GraphQL client to invoke the GraphQL API generated from your schema and the GraphQL admin API of Dgraph.\n",
+    "\n",
+    "**Make sure to update the endpoints with the correct values for your Dgraph cluster!**\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Use a local learning environment\n",
+    "\n",
+    "```sh\n",
+    "docker pull dgraph/standalone\n",
+    "docker run --name dgraph_learn -d -p \"8080:8080\" -p \"9080:9080\" -v <local-path>:/dgraph dgraph/standalone:latest\n",
+    "```\n",
+    "\n",
+    "### Load the dataset\n",
+    "\n",
+    "```sh\n",
+    "cp products.rdf.gz <local-path>\n",
+    "cp products.schema <local-path>\n",
+    "\n",
+    "docker exec -it dgraph_learn dgraph live -f /dgraph/products.rdf.gz -s /dgraph/products.schema\n",
+    "```\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "_qEDD3UC7uqF"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install pydgraph python-graphql-client\n",
+    "import pydgraph\n",
+    "import json\n",
+    "import getpass\n",
+    "import pandas as pd\n",
+    "\n",
+    "from python_graphql_client import GraphqlClient\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# copy your Dgraph Cloud endpoints\n",
+    "# The GraphQL endpoint is found at https://cloud.dgraph.io/_/dashboard\n",
+    "# dgraph_graphql_endpoint = \"https://black-grass.us-east-1.aws.cloud.dgraph.io/graphql\"\n",
+    "# dgraph_graphql_endpoint = \"https://withered-bird.us-east-1.aws.cloud.dgraph.io/graphql\"\n",
+    "dgraph_graphql_endpoint = \"http://localhost:8080/graphql\"\n",
+    "\n",
+    "# The gRPC endpoint is found at https://cloud.dgraph.io/_/settings\n",
+    "# dgraph_grpc = \"black-grass.grpc.us-east-1.aws.cloud.dgraph.io:443\"\n",
+    "# dgraph_grpc = \"withered-bird.grpc.us-east-1.aws.cloud.dgraph.io:443\"\n",
+    "dgraph_grpc = \"localhost:9080\"\n",
+    "\n",
+    "# the GraphQL admin endpoint is /admin\n",
+    "dgraph_graphql_admin = dgraph_graphql_endpoint.replace(\"/graphql\", \"/admin\")\n",
+    "# the health endpoint is /health\n",
+    "dgraph_graphql_health = dgraph_graphql_endpoint.replace(\"/graphql\", \"/health\")\n"
+   ]
+  },
+ "source": [ + "Enter your credentials and test the different clients\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "E7EvVHCqXzfV" + }, + "outputs": [], + "source": [ + "\n", + "# Cloud credentials\n", + "# we need the cloud login credential to upload the Lambda code.\n", + "# we need the an Admin API key generated at https://cloud.dgraph.io/_/settings?tab=api-keys for DQL alter and query\n", + "\n", + "\n", + "API_KEY = getpass.getpass(\"DGRAPH API KEY?\")\n", + "\n", + "\n", + "\n", + "# DQL Client\n", + "if dgraph_grpc.find(\"cloud\") > 0:\n", + " client_stub = pydgraph.DgraphClientStub.from_cloud(dgraph_grpc,API_KEY )\n", + "else:\n", + " client_stub = pydgraph.DgraphClientStub(addr=dgraph_grpc) \n", + "\n", + "client = pydgraph.DgraphClient(client_stub)\n", + "\n", + "# GraphQL client and admin client\n", + "gql_client = GraphqlClient(endpoint=dgraph_graphql_endpoint)\n", + "headers = { \"Dg-Auth\": API_KEY }\n", + "gql_admin_client = GraphqlClient(endpoint=dgraph_graphql_admin, headers=headers)\n", + "gql_health_client = GraphqlClient(endpoint=dgraph_graphql_health)\n", + "\n", + "#\n", + "# Testing the connection\n", + "#\n", + "data = gql_health_client.execute(query=\"\")\n", + "if 'errors' in data:\n", + " raise Exception(data['errors'][0]['message'])\n", + "\n", + "print(\"Check cluster health:\", json.dumps(data, indent=2))\n", + "\n", + "#\n", + "# Testing the DQL connection\n", + "#\n", + "txn = client.txn(read_only=True)\n", + "query = \"schema{}\"\n", + "res = txn.query(query)\n", + "dqlschema = json.loads(res.json)\n", + "txn.discard()\n", + "print(\"get DQL schema - succeeded\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Deploy the GraphQL Schema\n", + "\n", + "graphql_schema = \"\"\"\n", + "type Product {\n", + " id: ID!\n", + " name: String @id @search(by: [hash,term])\n", + " embedding: [Float!] @embedding\n", + "}\n", + "\"\"\"\n", + "mutation = \"\"\"\n", + "mutation($sch: String!) 
{\n", + " updateGQLSchema(input: { set: { schema: $sch}})\n", + " {\n", + " gqlSchema {\n", + " schema\n", + " generatedSchema\n", + " }\n", + " }\n", + "}\n", + "\"\"\"\n", + "variables = {\"sch\": graphql_schema}\n", + "schemadata = gql_admin_client.execute(query=mutation, variables=variables)\n", + "print(\"GraphQL Schema after Update\")\n", + "print(schemadata['data']['updateGQLSchema']['gqlSchema']['schema'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# reset the index\n", + "def reset_index(predicate, index):\n", + " print(f\"remove index for {predicate}\")\n", + " schema = f\"{predicate}: float32vector .\"\n", + " op = pydgraph.Operation(schema=schema)\n", + " alter = client.alter(op)\n", + " print(alter)\n", + " print(f\"create index for {predicate} {index}\")\n", + " schema = f\"{predicate}: float32vector @index({index}) .\"\n", + " op = pydgraph.Operation(schema=schema)\n", + " alter = client.alter(op)\n", + " print(alter) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.read_csv('products_with_embedding.csv.gz', compression='gzip') \n", + "data['embedding'] = [json.loads(data.iloc[i]['embedding']) for i in range(len(data))]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2-YUuHqv61tT" + }, + "outputs": [], + "source": [ + "# build a ground truth similarity index (KNN)\n", + "from sklearn import metrics\n", + "embeddings = data['embedding'].tolist()\n", + "distances = metrics.pairwise_distances(embeddings, embeddings, metric='euclidean')\n", + "knn_all = distances.argsort(axis=1)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "UvkyYXf3VTWj" + }, + "outputs": [], + "source": [ + "print(data.iloc[3]['name'])\n", + "data.iloc[knn_all[3][:10]]['name'].tolist()\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# get topK ANN from Graph - Approximate Nearest Neighbors\n", + "def topK_ANN(name, k=10):\n", + " query = \"\"\"\n", + " query QuerySimilarById($name: String!, $k: Int!) 
{{\n", + " list:querySimilar{}ById(by: embedding, topK: $k, name: $name) {{\n", + " uid:id\n", + " name:name\n", + " }}\n", + " }}\n", + " \"\"\".format(\"Product\")\n", + " variables={\"name\":name,\"k\":k}\n", + " res = gql_client.execute(query, variables)\n", + " return pd.json_normalize(res['data']['list'])\n", + "\n", + "topK_ANN(data.iloc[3]['name'], 10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# compute precision\n", + "def compute_item_precision(index, k, verbose = False):\n", + " name = data.iloc[index]['name']\n", + " if verbose:\n", + " print(\"compute_precision: \",index,name)\n", + " print(\"Ground Truth:\",name)\n", + " knn = data.iloc[knn_all[index][:k]]['name'].tolist()\n", + " if verbose:\n", + " print(\"KNN:\",knn)\n", + " knn = set(knn)\n", + " ann = topK_ANN(name,k)['name'].tolist()\n", + " if verbose:\n", + " print(\"ANN:\",ann)\n", + " ann = set(ann)\n", + " precision_at_k = len(knn.intersection(ann))/len(knn)\n", + " return precision_at_k\n", + "\n", + "# compute precision for first m products\n", + "def average_precision(first = 10, k = 10):\n", + " precision = 0.0\n", + " m = first\n", + " for i in range(first):\n", + " p = compute_item_precision(i,k) \n", + " precision += p\n", + " precision = precision/m\n", + " return precision\n", + "print (compute_item_precision(3,5))\n", + "print (average_precision(10,5))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for index in ['hnsw(metric: \"euclidian\")',\n", + " 'hnsw(metric: \"euclidian\", maxLevels: \"3\", efSearch: \"40\", efConstruction: \"100\")',\n", + " 'hnsw(metric: \"euclidian\", maxLevels: \"3\", efSearch: \"40\", efConstruction: \"400\")']:\n", + " reset_index(\"Product.embedding\",index)\n", + " print(\"index = \",index)\n", + " for k in [1,3,5,10]:\n", + " p_at_k = average_precision(100,k)\n", + " print(\"\"\"Precision@{} = {}\"\"\".format(k, p_at_k))" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/vector/fetch_embeddings.ipynb b/vector/fetch_embeddings.ipynb new file mode 100644 index 0000000..6a2f4b3 --- /dev/null +++ b/vector/fetch_embeddings.ipynb @@ -0,0 +1,225 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ZtVJnO3SzTgt" + }, + "source": [ + "\n", + "Notebook to explore data similarity in DGraph cluster\n", + "\n", + "\n", + "**pre-requesite**\n", + "- Dgraph\n", + " - Get a [Dgraph Cloud account](https://cloud.dgraph.io/)\n", + " - Have your account user name and password available\n", + " - Have a Dgraph cluster running in your Dgraph Cloud account\n", + " - Obtain the GraphQL Endpoint of the Dgraph cluster from the [cloud dashboard](https://cloud.dgraph.io/_/dashboard)\n", + " - Obtain an Admin API key for the Dgraph Cluster from the [settings](https://cloud.dgraph.io/_/settings?tab=api-keys) tab.\n", + "\n", + "\n", + "\n", + " The first step is to import the packages 
needed.\n", + "\n", + "- ``pydgraph``, the official [python client library for Dgraph Query Language](https://dgraph.io/docs/dql/clients/python/)\n", + "- ``GraphqlClient``, a GraphQL client to invoke the GraphQL API generated from your schema and the GraphQL admin API of Dgraph.\n", + "\n", + "**Make sure to update the endpoints with the correct values for your Dgraph cluster!**\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_qEDD3UC7uqF" + }, + "outputs": [], + "source": [ + "!pip install pydgraph python-graphql-client ipycytoscape\n", + "import sys\n", + "import pydgraph\n", + "import json\n", + "import base64\n", + "import getpass\n", + "import pandas as pd \n", + "from python_graphql_client import GraphqlClient\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "# copy your Dgraph cloud endpoints\n", + "# The GraphQL endpoint is found at https://cloud.dgraph.io/_/dashboard\n", + "# dgraph_graphql_endpoint = \"https://xyz.us-east-1.aws.cloud.dgraph.io/graphql\"\n", + "dgraph_graphql_endpoint = \"http://localhost:8080/graphql\"\n", + "\n", + "\n", + "# The gRPC endpoint is found at https://cloud.dgraph.io/_/settings\n", + "# dgraph_grpc = \"xyz.grpc.us-east-1.aws.cloud.dgraph.io:443\"\n", + "dgraph_grpc = \"localhost:9080\"\n", + "\n", + "# graph admin endpoint is /admin\n", + "dgraph_graphql_admin = dgraph_graphql_endpoint.replace(\"/graphql\", \"/admin\")\n", + "# graph health endpoint is /health\n", + "dgraph_graphql_health = dgraph_graphql_endpoint.replace(\"/graphql\", \"/health\")\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_qnQFABQNBXO" + }, + "source": [ + "Enter your credentials and test the clients\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "E7EvVHCqXzfV" + }, + "outputs": [], + "source": [ + "\n", + "# Cloud credentials\n", + "# we need the cloud login credential to upload the Lambda code.\n", + "# we need the an Admin API key generated at https://cloud.dgraph.io/_/settings?tab=api-keys for DQL alter and query\n", + "\n", + "\n", + "API_KEY = getpass.getpass(\"DGRAPH API KEY?\")\n", + "\n", + "# DQL Client\n", + "if dgraph_grpc.find(\"cloud\") > 0:\n", + " client_stub = pydgraph.DgraphClientStub.from_cloud(dgraph_grpc,API_KEY )\n", + "else:\n", + " client_stub = pydgraph.DgraphClientStub(addr=dgraph_grpc) \n", + "\n", + "client = pydgraph.DgraphClient(client_stub)\n", + "\n", + "# GraphQL client and admin client\n", + "gql_client = GraphqlClient(endpoint=dgraph_graphql_endpoint)\n", + "headers = { \"Dg-Auth\": API_KEY }\n", + "gql_admin_client = GraphqlClient(endpoint=dgraph_graphql_admin, headers=headers)\n", + "gql_health_client = GraphqlClient(endpoint=dgraph_graphql_health)\n", + "\n", + "\n", + "#\n", + "# Testing the connection to the Dgraph cluster\n", + "#\n", + "data = gql_health_client.execute(query=\"\")\n", + "if 'errors' in data:\n", + " raise Exception(data['errors'][0]['message'])\n", + "\n", + "print(\"Check cluster health:\", json.dumps(data, indent=2))\n", + "\n", + "#\n", + "# Testing the DQL connection\n", + "#\n", + "txn = client.txn(read_only=True)\n", + "query = \"schema{}\"\n", + "res = txn.query(query)\n", + "dqlschema = json.loads(res.json)\n", + "txn.discard()\n", + "print(\"get DQL schema - succeeded\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XeaFJyVJ8LlE" + }, + "outputs": [], + "source": [ + "# 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "XeaFJyVJ8LlE"
+   },
+   "outputs": [],
+   "source": [
+    "# fetch embeddings from an existing Dgraph cluster\n",
+    "# update the parameters to match your schema, entity and embedding predicate\n",
+    "\n",
+    "def fetchEmbeddings(entity, embedding_predicate, name_predicate):\n",
+    "    # for example fetchEmbeddings(\"Product\", \"Product.embedding\", \"Product.name\")\n",
+    "    # returns a dataframe with uid, name, and embedding,\n",
+    "    # paging through the cluster 500 nodes at a time\n",
+    "    txn = client.txn(read_only=True)\n",
+    "    after = \"\"\n",
+    "    data = pd.DataFrame(columns=['uid', 'name', 'embedding'])\n",
+    "    while True:\n",
+    "        print(\"fetching products after\", after)\n",
+    "        query = \"\"\"\n",
+    "        {{\n",
+    "          items(func: type({entity}), first: 500{after}) @filter(has({embedding})) {{\n",
+    "            uid\n",
+    "            name: {name}\n",
+    "            embedding: {embedding}\n",
+    "          }}\n",
+    "        }}\n",
+    "        \"\"\".format(after=after, entity=entity, embedding=embedding_predicate, name=name_predicate)\n",
+    "\n",
+    "        res = txn.query(query)\n",
+    "        jdata = json.loads(res.json)\n",
+    "        items = jdata.get('items', [])\n",
+    "        print(len(items))\n",
+    "        if len(items) == 0:\n",
+    "            break\n",
+    "        data = pd.concat([data, pd.json_normalize(items)], ignore_index=True)\n",
+    "        # page with a uid cursor: continue after the last uid seen\n",
+    "        after = \", after: {}\".format(items[-1]['uid'])\n",
+    "    txn.discard()\n",
+    "    return data\n",
+    "\n",
+    "\n",
+    "def dataframe_to_rdf(data, filehandle=sys.stdout):\n",
+    "    # serialize each row as RDF triples for the live loader\n",
+    "    for _, row in data.iterrows():\n",
+    "        rdf = \"\"\n",
+    "        rdf += \"<_:{}> <Product.embedding> \\\"{}\\\" .\\n\".format(row['uid'], row['embedding'])\n",
+    "        rdf += \"<_:{}> <Product.name> \\\"{}\\\" .\\n\".format(row['uid'], row['name'])\n",
+    "        rdf += \"<_:{}> <dgraph.type> \\\"Product\\\" .\\n\".format(row['uid'])\n",
+    "        filehandle.write(rdf)\n",
+    "    return\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gzip\n",
+    "data = fetchEmbeddings(\"Product\", \"Product.embedding\", \"Product.name\")\n",
+    "data.to_csv(\"products_with_embedding.csv.gz\", index=False, compression='gzip', header=True)\n",
+    "# the gzip file must be opened with 'wt' to write text\n",
+    "with gzip.open(\"./products.rdf.gz\", \"wt\") as f:\n",
+    "    dataframe_to_rdf(data, f)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/vector/products.rdf.gz b/vector/products.rdf.gz
new file mode 100644
index 0000000..9806f88
--- /dev/null
+++ b/vector/products.rdf.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e922ea1a6999301b6d9ed1d3242fbd2443c27be42b12079902e36448ee9929ff
+size 6348109
diff --git a/vector/products.schema b/vector/products.schema
new file mode 100644
index 0000000..db1ed75
--- /dev/null
+++ b/vector/products.schema
@@ -0,0 +1,2 @@
+Product.embedding: float32vector .
+Product.name: string @index(hash, term) @upsert .
\ No newline at end of file
diff --git a/vector/products_with_embedding.csv.gz b/vector/products_with_embedding.csv.gz
new file mode 100644
index 0000000..387c54b
--- /dev/null
+++ b/vector/products_with_embedding.csv.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7fe302596c6ac2eb1bb25d75ccdc87e4fc350afe2c440fb51d149a0fc289cb1d
+size 6297038