diff --git a/docs/docs/integrations/graphs/memgraph.ipynb b/docs/docs/integrations/graphs/memgraph.ipynb index 4dc8d33be4b86..1040d43f9e50f 100644 --- a/docs/docs/integrations/graphs/memgraph.ipynb +++ b/docs/docs/integrations/graphs/memgraph.ipynb @@ -2,24 +2,26 @@ "cells": [ { "cell_type": "markdown", - "id": "311b3061", "metadata": {}, "source": [ "# Memgraph\n", "\n", - ">[Memgraph](https://github.com/memgraph/memgraph) is the open-source graph database, compatible with `Neo4j`.\n", - ">The database is using the `Cypher` graph query language, \n", - ">\n", - ">[Cypher](https://en.wikipedia.org/wiki/Cypher_(query_language)) is a declarative graph query language that allows for expressive and efficient data querying in a property graph.\n", - "\n", - "This notebook shows how to use LLMs to provide a natural language interface to a [Memgraph](https://github.com/memgraph/memgraph) database.\n", + "Memgraph is an open-source graph database, tuned for dynamic analytics environments and compatible with Neo4j. To query the database, Memgraph uses Cypher - the most widely adopted, fully-specified, and open query language for property graph databases.\n", "\n", + "This notebook will show you how to [query Memgraph with natural language](#natural-language-querying) and how to [construct a knowledge graph](#constructing-knowledge-graph) from your unstructured data. \n", "\n", + "But first, make sure to [set everything up](#setting-up)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "## Setting up\n", "\n", - "To complete this tutorial, you will need [Docker](https://www.docker.com/get-started/) and [Python 3.x](https://www.python.org/) installed.\n", + "To go over this guide, you will need [Docker](https://www.docker.com/get-started/) and [Python 3.x](https://www.python.org/) installed.\n", "\n", - "Ensure you have a running Memgraph instance. To quickly run Memgraph Platform (Memgraph database + MAGE library + Memgraph Lab) for the first time, do the following:\n", + "To quickly run **Memgraph Platform** (Memgraph database + MAGE library + Memgraph Lab) for the first time, do the following:\n", "\n", "On Linux/MacOS:\n", "```\n", @@ -31,89 +33,90 @@ "iwr https://windows.memgraph.com | iex\n", "```\n", "\n", - "Both commands run a script that downloads a Docker Compose file to your system, builds and starts `memgraph-mage` and `memgraph-lab` Docker services in two separate containers. \n", + "Both commands run a script that downloads a Docker Compose file to your system, builds and starts `memgraph-mage` and `memgraph-lab` Docker services in two separate containers. Now you have Memgraph up and running! Read more about the installation process on [Memgraph documentation](https://memgraph.com/docs/getting-started/install-memgraph).\n", "\n", - "Read more about the installation process on [Memgraph documentation](https://memgraph.com/docs/getting-started/install-memgraph).\n", + "To use LangChain, install and import all the necessary packages. We'll use the package manager [pip](https://pip.pypa.io/en/stable/installation/), along with the `--user` flag, to ensure proper permissions. If you've installed Python 3.4 or a later version, `pip` is included by default. You can install all the required packages using the following command:\n", "\n", - "Now you can start playing with `Memgraph`!" - ] - }, - { - "cell_type": "markdown", - "id": "45ee105e", - "metadata": {}, - "source": [ - "Begin by installing and importing all the necessary packages. We'll use the package manager called [pip](https://pip.pypa.io/en/stable/installation/), along with the `--user` flag, to ensure proper permissions. If you've installed Python 3.4 or a later version, pip is included by default. You can install all the required packages using the following command:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fd6b9672", - "metadata": {}, - "outputs": [], - "source": [ - "pip install langchain langchain-neo4j langchain-openai neo4j gqlalchemy --user" + "```\n", + "pip install langchain langchain-openai neo4j --user\n", + "```\n", + "\n", + "You can either run the provided code blocks in this notebook or use a separate Python file to experiment with Memgraph and LangChain." ] }, { "cell_type": "markdown", - "id": "ec969a02", "metadata": {}, "source": [ - "You can either run the provided code blocks in this notebook or use a separate Python file to experiment with Memgraph and LangChain." + "## Natural language querying\n", + "\n", + "Memgraph's integration with LangChain includes natural language querying. To utilized it, first do all the necessary imports. We will discuss them as they appear in the code.\n", + "\n", + "First, instantiate `MemgraphGraph`. This object holds the connection to the running Memgraph instance. Make sure to set up all the environment variables properly." ] }, { "cell_type": "code", - "execution_count": null, - "id": "8206f90d", + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", - "from gqlalchemy import Memgraph\n", + "from langchain_community.chains.graph_qa.memgraph import MemgraphQAChain\n", "from langchain_community.graphs import MemgraphGraph\n", "from langchain_core.prompts import PromptTemplate\n", - "from langchain_neo4j import GraphCypherQAChain\n", - "from langchain_openai import ChatOpenAI" + "from langchain_openai import ChatOpenAI\n", + "\n", + "url = os.environ.get(\"MEMGRAPH_URI\", \"bolt://localhost:7687\")\n", + "username = os.environ.get(\"MEMGRAPH_USERNAME\", \"\")\n", + "password = os.environ.get(\"MEMGRAPH_PASSWORD\", \"\")\n", + "\n", + "graph = MemgraphGraph(\n", + " url=url, username=username, password=password, refresh_schema=False\n", + ")" ] }, { "cell_type": "markdown", - "id": "95ba37a4", "metadata": {}, "source": [ - "We're utilizing the Python library [GQLAlchemy](https://github.com/memgraph/gqlalchemy) to establish a connection between our Memgraph database and Python script. You can establish the connection to a running Memgraph instance with the Neo4j driver as well, since it's compatible with Memgraph. To execute queries with GQLAlchemy, we can set up a Memgraph instance as follows:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b90c9cf8", - "metadata": {}, - "outputs": [], - "source": [ - "memgraph = Memgraph(host=\"127.0.0.1\", port=7687)" + "The `refresh_schema` is initially set to `False` because there is still no data in the database and we want to avoid unnecessary database calls. " ] }, { "cell_type": "markdown", - "id": "4c379d16", "metadata": {}, "source": [ - "## Populating the database\n", - "You can effortlessly populate your new, empty database using the Cypher query language. Don't worry if you don't grasp every line just yet, you can learn Cypher from the documentation [here](https://memgraph.com/docs/cypher-manual/). Running the following script will execute a seeding query on the database, giving us data about a video game, including details like the publisher, available platforms, and genres. This data will serve as a basis for our work." + "### Populating the database\n", + "\n", + "To populate the database, first make sure it's empty. The most efficient way to do that is to switch to the in-memory analytical storage mode, drop the graph and go back to the in-memory transactional mode. Learn more about Memgraph's [storage modes](https://memgraph.com/docs/fundamentals/storage-memory-usage#storage-modes).\n", + "\n", + "The data we'll add to the database is about video games of different genres available on various platforms and related to publishers." ] }, { "cell_type": "code", - "execution_count": null, - "id": "11922bdf", - "metadata": {}, - "outputs": [], - "source": [ + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Drop graph\n", + "graph.query(\"STORAGE MODE IN_MEMORY_ANALYTICAL\")\n", + "graph.query(\"DROP GRAPH\")\n", + "graph.query(\"STORAGE MODE IN_MEMORY_TRANSACTIONAL\")\n", + "\n", "# Creating and executing the seeding query\n", "query = \"\"\"\n", " MERGE (g:Game {name: \"Baldur's Gate 3\"})\n", @@ -131,576 +134,660 @@ " MERGE (g)-[:PUBLISHED_BY]->(p);\n", "\"\"\"\n", "\n", - "memgraph.execute(query)" + "graph.query(query)" ] }, { "cell_type": "markdown", - "id": "378db965", "metadata": {}, "source": [ - "## Refresh graph schema" + "Notice how the `graph` object holds the `query` method. That method executes query in Memgraph and it is also used by the `MemgraphQAChain` to query the database." ] }, { "cell_type": "markdown", - "id": "d6b37df3", "metadata": {}, "source": [ - "You're all set to instantiate the Memgraph-LangChain graph using the following script. This interface will allow us to query our database using LangChain, automatically creating the required graph schema for generating Cypher queries through LLM." + "### Refresh graph schema\n", + "\n", + "Since the new data is created in Memgraph, it is necessary to refresh the schema. The generated schema will be used by the `MemgraphQAChain` to instruct LLM to better generate Cypher queries. " ] }, { "cell_type": "code", - "execution_count": null, - "id": "f38bbe83", + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ - "graph = MemgraphGraph(url=\"bolt://localhost:7687\", username=\"\", password=\"\")" + "graph.refresh_schema()" ] }, { "cell_type": "markdown", - "id": "846c32a8", "metadata": {}, "source": [ - "If necessary, you can manually refresh the graph schema as follows." + "To familiarize yourself with the data and verify the updated graph schema, you can print it using the following statement:" ] }, { "cell_type": "code", - "execution_count": null, - "id": "b561026e", - "metadata": {}, - "outputs": [], - "source": [ - "graph.refresh_schema()" - ] - }, - { - "cell_type": "markdown", - "id": "c51b7948", - "metadata": {}, - "source": [ - "To familiarize yourself with the data and verify the updated graph schema, you can print it using the following statement." + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Node labels and properties (name and type) are:\n", + "- labels: (:Platform)\n", + " properties:\n", + " - name: string\n", + "- labels: (:Genre)\n", + " properties:\n", + " - name: string\n", + "- labels: (:Game)\n", + " properties:\n", + " - name: string\n", + "- labels: (:Publisher)\n", + " properties:\n", + " - name: string\n", + "\n", + "Nodes are connected with the following relationships:\n", + "(:Game)-[:HAS_GENRE]->(:Genre)\n", + "(:Game)-[:PUBLISHED_BY]->(:Publisher)\n", + "(:Game)-[:AVAILABLE_ON]->(:Platform)\n", + "\n" + ] + } + ], + "source": [ + "print(graph.get_schema)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Querying the database\n", + "\n", + "To interact with the OpenAI API, you must configure your API key as an environment variable. This ensures proper authorization for your requests. You can find more information on obtaining your API key [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key). To configure the API key, you can use Python [os](https://docs.python.org/3/library/os.html) package:\n", + "\n", + "```\n", + "os.environ[\"OPENAI_API_KEY\"] = \"your-key-here\"\n", + "```\n", + "\n", + "Run the above code snippet if you're running the code within the Jupyter notebook. \n", + "\n", + "Next, create `MemgraphQAChain`, which will be utilized in the question-answering process based on your graph data. The `temperature parameter` is set to zero to ensure predictable and consistent answers. You can set `verbose` parameter to `True` to receive more detailed messages regarding query generation." ] }, { "cell_type": "code", - "execution_count": null, - "id": "f2e0ec3e", + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "print(graph.schema)" - ] - }, - { - "cell_type": "markdown", - "id": "a0c2a556", - "metadata": {}, - "source": [ - "```\n", - "Node properties are the following:\n", - "Node name: 'Game', Node properties: [{'property': 'name', 'type': 'str'}]\n", - "Node name: 'Platform', Node properties: [{'property': 'name', 'type': 'str'}]\n", - "Node name: 'Genre', Node properties: [{'property': 'name', 'type': 'str'}]\n", - "Node name: 'Publisher', Node properties: [{'property': 'name', 'type': 'str'}]\n", - "\n", - "Relationship properties are the following:\n", - "\n", - "The relationships are the following:\n", - "['(:Game)-[:AVAILABLE_ON]->(:Platform)']\n", - "['(:Game)-[:HAS_GENRE]->(:Genre)']\n", - "['(:Game)-[:PUBLISHED_BY]->(:Publisher)']\n", - "```" + "chain = MemgraphQAChain.from_llm(\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " model_name=\"gpt-4-turbo\",\n", + " allow_dangerous_requests=True,\n", + ")" ] }, { "cell_type": "markdown", - "id": "44d3a1da", "metadata": {}, "source": [ - "## Querying the database" + "Now you can start asking questions!" ] }, { - "cell_type": "markdown", - "id": "8aedfd63", + "cell_type": "code", + "execution_count": 7, "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (:Game{name: \"Baldur's Gate 3\"})-[:AVAILABLE_ON]->(platform:Platform)\n", + "RETURN platform.name\n", + "Baldur's Gate 3 is available on PlayStation 5, Mac OS, Windows, and Xbox Series X/S.\n" + ] + } + ], "source": [ - "To interact with the OpenAI API, you must configure your API key as an environment variable using the Python [os](https://docs.python.org/3/library/os.html) package. This ensures proper authorization for your requests. You can find more information on obtaining your API key [here](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key)." + "response = chain.invoke(\"Which platforms is Baldur's Gate 3 available on?\")\n", + "print(response[\"result\"])" ] }, { "cell_type": "code", - "execution_count": null, - "id": "b8385c72", + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (:Game{name: \"Baldur's Gate 3\"})-[:AVAILABLE_ON]->(:Platform{name: \"Windows\"})\n", + "RETURN \"Yes\"\n", + "Yes, Baldur's Gate 3 is available on Windows.\n" + ] + } + ], "source": [ - "os.environ[\"OPENAI_API_KEY\"] = \"your-key-here\"" + "response = chain.invoke(\"Is Baldur's Gate 3 available on Windows?\")\n", + "print(response[\"result\"])" ] }, { "cell_type": "markdown", - "id": "5a74565a", "metadata": {}, "source": [ - "You should create the graph chain using the following script, which will be utilized in the question-answering process based on your graph data. While it defaults to GPT-3.5-turbo, you might also consider experimenting with other models like [GPT-4](https://help.openai.com/en/articles/7102672-how-can-i-access-gpt-4) for notably improved Cypher queries and outcomes. We'll utilize the OpenAI chat, utilizing the key you previously configured. We'll set the temperature to zero, ensuring predictable and consistent answers. Additionally, we'll use our Memgraph-LangChain graph and set the verbose parameter, which defaults to False, to True to receive more detailed messages regarding query generation." + "### Chain modifiers\n", + "\n", + "To modify the behavior of your chain and obtain more context or additional information, you can modify the chain's parameters.\n", + "\n", + "#### Return direct query results\n", + "The `return_direct` modifier specifies whether to return the direct results of the executed Cypher query or the processed natural language response." ] }, { "cell_type": "code", - "execution_count": null, - "id": "4a3a5f2e", - "metadata": {}, - "outputs": [], + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (g:Game {name: \"Baldur's Gate 3\"})-[:PUBLISHED_BY]->(p:Publisher)\n", + "RETURN p.name\n", + "[{'p.name': 'Larian Studios'}]\n" + ] + } + ], "source": [ - "chain = GraphCypherQAChain.from_llm(\n", + "# Return the result of querying the graph directly\n", + "chain = MemgraphQAChain.from_llm(\n", " ChatOpenAI(temperature=0),\n", " graph=graph,\n", - " verbose=True,\n", - " model_name=\"gpt-3.5-turbo\",\n", + " return_direct=True,\n", " allow_dangerous_requests=True,\n", - ")" + " model_name=\"gpt-4-turbo\",\n", + ")\n", + "\n", + "response = chain.invoke(\"Which studio published Baldur's Gate 3?\")\n", + "print(response[\"result\"])" ] }, { "cell_type": "markdown", - "id": "949de4f3", "metadata": {}, "source": [ - "Now you can start asking questions!" + "#### Return query intermediate steps\n", + "The `return_intermediate_steps` chain modifier enhances the returned response by including the intermediate steps of the query in addition to the initial query result." ] }, { "cell_type": "code", - "execution_count": null, - "id": "b7aea263", - "metadata": {}, - "outputs": [], + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (:Game {name: \"Baldur's Gate 3\"})-[:HAS_GENRE]->(:Genre {name: \"Adventure\"})\n", + "RETURN \"Yes\"\n", + "Intermediate steps: [{'query': 'MATCH (:Game {name: \"Baldur\\'s Gate 3\"})-[:HAS_GENRE]->(:Genre {name: \"Adventure\"})\\nRETURN \"Yes\"'}, {'context': [{'\"Yes\"': 'Yes'}]}]\n", + "Final response: Yes.\n" + ] + } + ], "source": [ - "response = chain.run(\"Which platforms is Baldur's Gate 3 available on?\")\n", - "print(response)" + "# Return all the intermediate steps of query execution\n", + "chain = MemgraphQAChain.from_llm(\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " allow_dangerous_requests=True,\n", + " return_intermediate_steps=True,\n", + " model_name=\"gpt-4-turbo\",\n", + ")\n", + "\n", + "response = chain.invoke(\"Is Baldur's Gate 3 an Adventure game?\")\n", + "print(f\"Intermediate steps: {response['intermediate_steps']}\")\n", + "print(f\"Final response: {response['result']}\")" ] }, { "cell_type": "markdown", - "id": "a06a8164", "metadata": {}, "source": [ - "```\n", - "> Entering new GraphCypherQAChain chain...\n", - "Generated Cypher:\n", - "MATCH (g:Game {name: 'Baldur\\'s Gate 3'})-[:AVAILABLE_ON]->(p:Platform)\n", - "RETURN p.name\n", - "Full Context:\n", - "[{'p.name': 'PlayStation 5'}, {'p.name': 'Mac OS'}, {'p.name': 'Windows'}, {'p.name': 'Xbox Series X/S'}]\n", + "#### Limit the number of query results\n", "\n", - "> Finished chain.\n", - "Baldur's Gate 3 is available on PlayStation 5, Mac OS, Windows, and Xbox Series X/S.\n", - "```" + "The `top_k` modifier can be used when you want to restrict the maximum number of query results." ] }, { "cell_type": "code", - "execution_count": null, - "id": "59d298d5", - "metadata": {}, - "outputs": [], - "source": [ - "response = chain.run(\"Is Baldur's Gate 3 available on Windows?\")\n", - "print(response)" - ] - }, - { - "cell_type": "markdown", - "id": "99dd783c", - "metadata": {}, + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (:Game {name: \"Baldur's Gate 3\"})-[:HAS_GENRE]->(g:Genre)\n", + "RETURN g.name;\n", + "Adventure, Role-Playing Game\n" + ] + } + ], "source": [ - "```\n", - "> Entering new GraphCypherQAChain chain...\n", - "Generated Cypher:\n", - "MATCH (:Game {name: 'Baldur\\'s Gate 3'})-[:AVAILABLE_ON]->(:Platform {name: 'Windows'})\n", - "RETURN true\n", - "Full Context:\n", - "[{'true': True}]\n", + "# Limit the maximum number of results returned by query\n", + "chain = MemgraphQAChain.from_llm(\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " top_k=2,\n", + " allow_dangerous_requests=True,\n", + " model_name=\"gpt-4-turbo\",\n", + ")\n", "\n", - "> Finished chain.\n", - "Yes, Baldur's Gate 3 is available on Windows.\n", - "```" - ] - }, - { - "cell_type": "markdown", - "id": "08620465", - "metadata": {}, - "source": [ - "## Chain modifiers" - ] - }, - { - "cell_type": "markdown", - "id": "6603e6c8", - "metadata": {}, - "source": [ - "To modify the behavior of your chain and obtain more context or additional information, you can modify the chain's parameters." + "response = chain.invoke(\"What genres are associated with Baldur's Gate 3?\")\n", + "print(response[\"result\"])" ] }, { "cell_type": "markdown", - "id": "8d187a83", "metadata": {}, "source": [ - "#### Return direct query results\n", - "The `return_direct` modifier specifies whether to return the direct results of the executed Cypher query or the processed natural language response." + "### Advanced querying\n", + "\n", + "As the complexity of your solution grows, you might encounter different use-cases that require careful handling. Ensuring your application's scalability is essential to maintain a smooth user flow without any hitches.\n", + "\n", + "Let's instantiate our chain once again and attempt to ask some questions that users might potentially ask." ] }, { "cell_type": "code", - "execution_count": null, - "id": "0533847d", - "metadata": {}, - "outputs": [], - "source": [ - "# Return the result of querying the graph directly\n", - "chain = GraphCypherQAChain.from_llm(\n", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (:Game{name: \"Baldur's Gate 3\"})-[:AVAILABLE_ON]->(:Platform{name: \"PS5\"})\n", + "RETURN \"Yes\"\n", + "I don't know the answer.\n" + ] + } + ], + "source": [ + "chain = MemgraphQAChain.from_llm(\n", " ChatOpenAI(temperature=0),\n", " graph=graph,\n", - " verbose=True,\n", - " return_direct=True,\n", + " model_name=\"gpt-4-turbo\",\n", " allow_dangerous_requests=True,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "afbe96fb", - "metadata": {}, - "outputs": [], - "source": [ - "response = chain.run(\"Which studio published Baldur's Gate 3?\")\n", - "print(response)" + ")\n", + "\n", + "response = chain.invoke(\"Is Baldur's Gate 3 available on PS5?\")\n", + "print(response[\"result\"])" ] }, { "cell_type": "markdown", - "id": "94b32b6e", "metadata": {}, "source": [ - "```\n", - "> Entering new GraphCypherQAChain chain...\n", - "Generated Cypher:\n", - "MATCH (:Game {name: 'Baldur\\'s Gate 3'})-[:PUBLISHED_BY]->(p:Publisher)\n", - "RETURN p.name\n", - "\n", - "> Finished chain.\n", - "[{'p.name': 'Larian Studios'}]\n", - "```" + "The generated Cypher query looks fine, but we didn't receive any information in response. This illustrates a common challenge when working with LLMs - the misalignment between how users phrase queries and how data is stored. In this case, the difference between user perception and the actual data storage can cause mismatches. Prompt refinement, the process of honing the model's prompts to better grasp these distinctions, is an efficient solution that tackles this issue. Through prompt refinement, the model gains increased proficiency in generating precise and pertinent queries, leading to the successful retrieval of the desired data." ] }, { "cell_type": "markdown", - "id": "5c97ab3a", "metadata": {}, "source": [ - "#### Return query intermediate steps\n", - "The `return_intermediate_steps` chain modifier enhances the returned response by including the intermediate steps of the query in addition to the initial query result." + "#### Prompt refinement\n", + "\n", + "To address this, we can adjust the initial Cypher prompt of the QA chain. This involves adding guidance to the LLM on how users can refer to specific platforms, such as PS5 in our case. We achieve this using the LangChain [PromptTemplate](/docs/how_to#prompt-templates), creating a modified initial prompt. This modified prompt is then supplied as an argument to our refined `MemgraphQAChain` instance." ] }, { "cell_type": "code", - "execution_count": null, - "id": "82f673c8", - "metadata": {}, - "outputs": [], - "source": [ - "# Return all the intermediate steps of query execution\n", - "chain = GraphCypherQAChain.from_llm(\n", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (:Game{name: \"Baldur's Gate 3\"})-[:AVAILABLE_ON]->(:Platform{name: \"PlayStation 5\"})\n", + "RETURN \"Yes\"\n", + "Yes, Baldur's Gate 3 is available on PS5.\n" + ] + } + ], + "source": [ + "MEMGRAPH_GENERATION_TEMPLATE = \"\"\"Your task is to directly translate natural language inquiry into precise and executable Cypher query for Memgraph database. \n", + "You will utilize a provided database schema to understand the structure, nodes and relationships within the Memgraph database.\n", + "Instructions: \n", + "- Use provided node and relationship labels and property names from the\n", + "schema which describes the database's structure. Upon receiving a user\n", + "question, synthesize the schema to craft a precise Cypher query that\n", + "directly corresponds to the user's intent. \n", + "- Generate valid executable Cypher queries on top of Memgraph database. \n", + "Any explanation, context, or additional information that is not a part \n", + "of the Cypher query syntax should be omitted entirely. \n", + "- Use Memgraph MAGE procedures instead of Neo4j APOC procedures. \n", + "- Do not include any explanations or apologies in your responses. \n", + "- Do not include any text except the generated Cypher statement.\n", + "- For queries that ask for information or functionalities outside the direct\n", + "generation of Cypher queries, use the Cypher query format to communicate\n", + "limitations or capabilities. For example: RETURN \"I am designed to generate\n", + "Cypher queries based on the provided schema only.\"\n", + "Schema: \n", + "{schema}\n", + "\n", + "With all the above information and instructions, generate Cypher query for the\n", + "user question. \n", + "If the user asks about PS5, Play Station 5 or PS 5, that is the platform called PlayStation 5.\n", + "\n", + "The question is:\n", + "{question}\"\"\"\n", + "\n", + "MEMGRAPH_GENERATION_PROMPT = PromptTemplate(\n", + " input_variables=[\"schema\", \"question\"], template=MEMGRAPH_GENERATION_TEMPLATE\n", + ")\n", + "\n", + "chain = MemgraphQAChain.from_llm(\n", " ChatOpenAI(temperature=0),\n", + " cypher_prompt=MEMGRAPH_GENERATION_PROMPT,\n", " graph=graph,\n", - " verbose=True,\n", - " return_intermediate_steps=True,\n", + " model_name=\"gpt-4-turbo\",\n", " allow_dangerous_requests=True,\n", - ")" + ")\n", + "\n", + "response = chain.invoke(\"Is Baldur's Gate 3 available on PS5?\")\n", + "print(response[\"result\"])" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "d87e0976", + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "response = chain(\"Is Baldur's Gate 3 an Adventure game?\")\n", - "print(f\"Intermediate steps: {response['intermediate_steps']}\")\n", - "print(f\"Final response: {response['result']}\")" + "Now, with the revised initial Cypher prompt that includes guidance on platform naming, we are obtaining accurate and relevant results that align more closely with user queries. \n", + "\n", + "This approach allows for further improvement of your QA chain. You can effortlessly integrate extra prompt refinement data into your chain, thereby enhancing the overall user experience of your app." ] }, { "cell_type": "markdown", - "id": "df12b3da", "metadata": {}, "source": [ - "```\n", - "> Entering new GraphCypherQAChain chain...\n", - "Generated Cypher:\n", - "MATCH (g:Game {name: 'Baldur\\'s Gate 3'})-[:HAS_GENRE]->(genre:Genre {name: 'Adventure'})\n", - "RETURN g, genre\n", - "Full Context:\n", - "[{'g': {'name': \"Baldur's Gate 3\"}, 'genre': {'name': 'Adventure'}}]\n", + "## Constructing knowledge graph\n", + "\n", + "Transforming unstructured data to structured is not an easy or straightforward task. This guide will show how LLMs can be utilized to help us there and how to construct a knowledge graph in Memgraph. After knowledge graph is created, you can use it for your GraphRAG application.\n", "\n", - "> Finished chain.\n", - "Intermediate steps: [{'query': \"MATCH (g:Game {name: 'Baldur\\\\'s Gate 3'})-[:HAS_GENRE]->(genre:Genre {name: 'Adventure'})\\nRETURN g, genre\"}, {'context': [{'g': {'name': \"Baldur's Gate 3\"}, 'genre': {'name': 'Adventure'}}]}]\n", - "Final response: Yes, Baldur's Gate 3 is an Adventure game.\n", - "```" + "The steps of constructing a knowledge graph from the text are:\n", + "\n", + "- [Extracting structured information from text](#extracting-structured-information-from-text): LLM is used to extract structured graph information from text in a form of nodes and relationships.\n", + "- [Storing into Memgraph](#storing-into-memgraph): Storing the extracted structured graph information into Memgraph." ] }, { "cell_type": "markdown", - "id": "41124485", "metadata": {}, "source": [ - "#### Limit the number of query results\n", - "The `top_k` modifier can be used when you want to restrict the maximum number of query results." + "### Extracting structured information from text\n", + "\n", + "Besides all the imports in the [setup section](#setting-up), import `LLMGraphTransformer` and `Document` which will be used to extract structured information from text." ] }, { "cell_type": "code", - "execution_count": null, - "id": "7340fc87", + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "# Limit the maximum number of results returned by query\n", - "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0),\n", - " graph=graph,\n", - " verbose=True,\n", - " top_k=2,\n", - " allow_dangerous_requests=True,\n", - ")" + "from langchain_core.documents import Document\n", + "from langchain_experimental.graph_transformers import LLMGraphTransformer" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "3a17cdc6", + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "response = chain.run(\"What genres are associated with Baldur's Gate 3?\")\n", - "print(response)" + "Below is an example text about Charles Darwin ([source](https://en.wikipedia.org/wiki/Charles_Darwin)) from which knowledge graph will be constructed." ] }, { - "cell_type": "markdown", - "id": "dcff33ed", + "cell_type": "code", + "execution_count": 15, "metadata": {}, + "outputs": [], "source": [ - "```\n", - "> Entering new GraphCypherQAChain chain...\n", - "Generated Cypher:\n", - "MATCH (:Game {name: 'Baldur\\'s Gate 3'})-[:HAS_GENRE]->(g:Genre)\n", - "RETURN g.name\n", - "Full Context:\n", - "[{'g.name': 'Adventure'}, {'g.name': 'Role-Playing Game'}]\n", - "\n", - "> Finished chain.\n", - "Baldur's Gate 3 is associated with the genres Adventure and Role-Playing Game.\n", - "```" + "text = \"\"\"\n", + " Charles Robert Darwin was an English naturalist, geologist, and biologist,\n", + " widely known for his contributions to evolutionary biology. His proposition that\n", + " all species of life have descended from a common ancestor is now generally\n", + " accepted and considered a fundamental scientific concept. In a joint\n", + " publication with Alfred Russel Wallace, he introduced his scientific theory that\n", + " this branching pattern of evolution resulted from a process he called natural\n", + " selection, in which the struggle for existence has a similar effect to the\n", + " artificial selection involved in selective breeding. Darwin has been\n", + " described as one of the most influential figures in human history and was\n", + " honoured by burial in Westminster Abbey.\n", + "\"\"\"" ] }, { "cell_type": "markdown", - "id": "2eb524a1", "metadata": {}, "source": [ - "# Advanced querying" + "The next step is to initialize `LLMGraphTransformer` from the desired LLM and convert the document to the graph structure. " ] }, { - "cell_type": "markdown", - "id": "113be997", + "cell_type": "code", + "execution_count": 16, "metadata": {}, + "outputs": [], "source": [ - "As the complexity of your solution grows, you might encounter different use-cases that require careful handling. Ensuring your application's scalability is essential to maintain a smooth user flow without any hitches." + "llm = ChatOpenAI(temperature=0, model_name=\"gpt-4-turbo\")\n", + "llm_transformer = LLMGraphTransformer(llm=llm)\n", + "documents = [Document(page_content=text)]\n", + "graph_documents = llm_transformer.convert_to_graph_documents(documents)" ] }, { "cell_type": "markdown", - "id": "e0b2db17", "metadata": {}, "source": [ - "Let's instantiate our chain once again and attempt to ask some questions that users might potentially ask." + "Under the hood, LLM extracts important entities from the text and returns them as a list of nodes and relationships. Here's how it looks like:" ] }, { "cell_type": "code", - "execution_count": null, - "id": "fc544d0b", + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[GraphDocument(nodes=[Node(id='Charles Robert Darwin', type='Person', properties={}), Node(id='English', type='Nationality', properties={}), Node(id='Naturalist', type='Profession', properties={}), Node(id='Geologist', type='Profession', properties={}), Node(id='Biologist', type='Profession', properties={}), Node(id='Evolutionary Biology', type='Field', properties={}), Node(id='Common Ancestor', type='Concept', properties={}), Node(id='Scientific Concept', type='Concept', properties={}), Node(id='Alfred Russel Wallace', type='Person', properties={}), Node(id='Natural Selection', type='Concept', properties={}), Node(id='Selective Breeding', type='Concept', properties={}), Node(id='Westminster Abbey', type='Location', properties={})], relationships=[Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='English', type='Nationality', properties={}), type='NATIONALITY', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Naturalist', type='Profession', properties={}), type='PROFESSION', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Geologist', type='Profession', properties={}), type='PROFESSION', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Biologist', type='Profession', properties={}), type='PROFESSION', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Evolutionary Biology', type='Field', properties={}), type='CONTRIBUTION', properties={}), Relationship(source=Node(id='Common Ancestor', type='Concept', properties={}), target=Node(id='Scientific Concept', type='Concept', properties={}), type='BASIS', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Alfred Russel Wallace', type='Person', properties={}), type='COLLABORATION', properties={}), Relationship(source=Node(id='Natural Selection', type='Concept', properties={}), target=Node(id='Selective Breeding', type='Concept', properties={}), type='COMPARISON', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Westminster Abbey', type='Location', properties={}), type='BURIAL', properties={})], source=Document(metadata={}, page_content='\\n Charles Robert Darwin was an English naturalist, geologist, and biologist,\\n widely known for his contributions to evolutionary biology. His proposition that\\n all species of life have descended from a common ancestor is now generally\\n accepted and considered a fundamental scientific concept. In a joint\\n publication with Alfred Russel Wallace, he introduced his scientific theory that\\n this branching pattern of evolution resulted from a process he called natural\\n selection, in which the struggle for existence has a similar effect to the\\n artificial selection involved in selective breeding. Darwin has been\\n described as one of the most influential figures in human history and was\\n honoured by burial in Westminster Abbey.\\n'))]\n" + ] + } + ], "source": [ - "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0),\n", - " graph=graph,\n", - " verbose=True,\n", - " model_name=\"gpt-3.5-turbo\",\n", - " allow_dangerous_requests=True,\n", - ")" + "print(graph_documents)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "e2abde2d", + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "response = chain.run(\"Is Baldur's Gate 3 available on PS5?\")\n", - "print(response)" + "### Storing into Memgraph\n", + "\n", + "Once you have the data ready in a format of `GraphDocument`, that is, nodes and relationships, you can use `add_graph_documents` method to import it into Memgraph. That method transforms the list of `graph_documents` into appropriate Cypher queries that need to be executed in Memgraph. Once that's done, a knowledge graph is stored in Memgraph. " ] }, { - "cell_type": "markdown", - "id": "cf22dc48", + "cell_type": "code", + "execution_count": 18, "metadata": {}, + "outputs": [], "source": [ - "```\n", - "> Entering new GraphCypherQAChain chain...\n", - "Generated Cypher:\n", - "MATCH (g:Game {name: 'Baldur\\'s Gate 3'})-[:AVAILABLE_ON]->(p:Platform {name: 'PS5'})\n", - "RETURN g.name, p.name\n", - "Full Context:\n", - "[]\n", + "# Empty the database\n", + "graph.query(\"STORAGE MODE IN_MEMORY_ANALYTICAL\")\n", + "graph.query(\"DROP GRAPH\")\n", + "graph.query(\"STORAGE MODE IN_MEMORY_TRANSACTIONAL\")\n", "\n", - "> Finished chain.\n", - "I'm sorry, but I don't have the information to answer your question.\n", - "```" + "# Create KG\n", + "graph.add_graph_documents(graph_documents)" ] }, { "cell_type": "markdown", - "id": "293aa1c9", "metadata": {}, "source": [ - "The generated Cypher query looks fine, but we didn't receive any information in response. This illustrates a common challenge when working with LLMs - the misalignment between how users phrase queries and how data is stored. In this case, the difference between user perception and the actual data storage can cause mismatches. Prompt refinement, the process of honing the model's prompts to better grasp these distinctions, is an efficient solution that tackles this issue. Through prompt refinement, the model gains increased proficiency in generating precise and pertinent queries, leading to the successful retrieval of the desired data." + "Here is how the graph looks like in Memgraph Lab (check on `localhost:3000`):\n", + "\n", + "![memgraph-kg](../../../static/img/memgraph_kg.png)\n", + "\n", + "In case you tried this out and got a different graph, that is expected behavior. The graph construction process is non-deterministic, since LLM which is used to generate nodes and relationships from unstructured data in non-deterministic.\n", + "\n", + "### Additional options\n", + "\n", + "Additionally, you have the flexibility to define specific types of nodes and relationships for extraction according to your requirements." ] }, { - "cell_type": "markdown", - "id": "a87b2f1b", - "metadata": {}, - "source": [ - "### Prompt refinement" + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Nodes:[Node(id='Charles Robert Darwin', type='Person', properties={}), Node(id='English', type='Nationality', properties={}), Node(id='Evolutionary Biology', type='Concept', properties={}), Node(id='Natural Selection', type='Concept', properties={}), Node(id='Alfred Russel Wallace', type='Person', properties={})]\n", + "Relationships:[Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='English', type='Nationality', properties={}), type='NATIONALITY', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Evolutionary Biology', type='Concept', properties={}), type='INVOLVED_IN', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Natural Selection', type='Concept', properties={}), type='INVOLVED_IN', properties={}), Relationship(source=Node(id='Charles Robert Darwin', type='Person', properties={}), target=Node(id='Alfred Russel Wallace', type='Person', properties={}), type='COLLABORATES_WITH', properties={})]\n" + ] + } + ], + "source": [ + "llm_transformer_filtered = LLMGraphTransformer(\n", + " llm=llm,\n", + " allowed_nodes=[\"Person\", \"Nationality\", \"Concept\"],\n", + " allowed_relationships=[\"NATIONALITY\", \"INVOLVED_IN\", \"COLLABORATES_WITH\"],\n", + ")\n", + "graph_documents_filtered = llm_transformer_filtered.convert_to_graph_documents(\n", + " documents\n", + ")\n", + "\n", + "print(f\"Nodes:{graph_documents_filtered[0].nodes}\")\n", + "print(f\"Relationships:{graph_documents_filtered[0].relationships}\")" ] }, { "cell_type": "markdown", - "id": "8edb9976", "metadata": {}, "source": [ - "To address this, we can adjust the initial Cypher prompt of the QA chain. This involves adding guidance to the LLM on how users can refer to specific platforms, such as PS5 in our case. We achieve this using the LangChain [PromptTemplate](/docs/how_to#prompt-templates), creating a modified initial prompt. This modified prompt is then supplied as an argument to our refined Memgraph-LangChain instance." + "Here's how the graph would like in such case:\n", + "\n", + "![memgraph-kg-2](../../../static/img/memgraph_kg_2.png)\n", + "\n", + "Your graph can also have `__Entity__` labels on all nodes which will be indexed for faster retrieval. " ] }, { "cell_type": "code", - "execution_count": null, - "id": "312dad05", + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ - "CYPHER_GENERATION_TEMPLATE = \"\"\"\n", - "Task:Generate Cypher statement to query a graph database.\n", - "Instructions:\n", - "Use only the provided relationship types and properties in the schema.\n", - "Do not use any other relationship types or properties that are not provided.\n", - "Schema:\n", - "{schema}\n", - "Note: Do not include any explanations or apologies in your responses.\n", - "Do not respond to any questions that might ask anything else than for you to construct a Cypher statement.\n", - "Do not include any text except the generated Cypher statement.\n", - "If the user asks about PS5, Play Station 5 or PS 5, that is the platform called PlayStation 5.\n", - "\n", - "The question is:\n", - "{question}\n", - "\"\"\"\n", + "# Drop graph\n", + "graph.query(\"STORAGE MODE IN_MEMORY_ANALYTICAL\")\n", + "graph.query(\"DROP GRAPH\")\n", + "graph.query(\"STORAGE MODE IN_MEMORY_TRANSACTIONAL\")\n", "\n", - "CYPHER_GENERATION_PROMPT = PromptTemplate(\n", - " input_variables=[\"schema\", \"question\"], template=CYPHER_GENERATION_TEMPLATE\n", - ")" + "# Store to Memgraph with Entity label\n", + "graph.add_graph_documents(graph_documents, baseEntityLabel=True)" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "2c297245", + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "chain = GraphCypherQAChain.from_llm(\n", - " ChatOpenAI(temperature=0),\n", - " cypher_prompt=CYPHER_GENERATION_PROMPT,\n", - " graph=graph,\n", - " verbose=True,\n", - " model_name=\"gpt-3.5-turbo\",\n", - " allow_dangerous_requests=True,\n", - ")" + "Here's how the graph would look like:\n", + "\n", + "![memgraph-kg-3](../../../static/img/memgraph_kg_3.png)\n", + "\n", + "There is also an option to include the source of the information that's obtained in the graph. To do that, set `include_source` to `True` and then the source document is stored and it is linked to the nodes in the graph using the `MENTIONS` relationship." ] }, { "cell_type": "code", - "execution_count": null, - "id": "7efb11a0", + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ - "response = chain.run(\"Is Baldur's Gate 3 available on PS5?\")\n", - "print(response)" - ] - }, - { - "cell_type": "markdown", - "id": "289db07f", - "metadata": {}, - "source": [ - "```\n", - "> Entering new GraphCypherQAChain chain...\n", - "Generated Cypher:\n", - "MATCH (g:Game {name: 'Baldur\\'s Gate 3'})-[:AVAILABLE_ON]->(p:Platform {name: 'PlayStation 5'})\n", - "RETURN g.name, p.name\n", - "Full Context:\n", - "[{'g.name': \"Baldur's Gate 3\", 'p.name': 'PlayStation 5'}]\n", + "# Drop graph\n", + "graph.query(\"STORAGE MODE IN_MEMORY_ANALYTICAL\")\n", + "graph.query(\"DROP GRAPH\")\n", + "graph.query(\"STORAGE MODE IN_MEMORY_TRANSACTIONAL\")\n", "\n", - "> Finished chain.\n", - "Yes, Baldur's Gate 3 is available on PlayStation 5.\n", - "```" + "# Store to Memgraph with source included\n", + "graph.add_graph_documents(graph_documents, include_source=True)" ] }, { "cell_type": "markdown", - "id": "84b5f6af", "metadata": {}, "source": [ - "Now, with the revised initial Cypher prompt that includes guidance on platform naming, we are obtaining accurate and relevant results that align more closely with user queries. " + "The constructed graph would look like this:\n", + "\n", + "![memgraph-kg-4](../../../static/img/memgraph_kg_4.png)\n", + "\n", + "Notice how the content of the source is stored and `id` property is generated since the document didn't have any `id`.\n", + "You can combine having both `__Entity__` label and document source. Still, be aware that both take up memory, especially source included due to long strings for content.\n", + "\n", + "In the end, you can query the knowledge graph, as explained in the section before:" ] }, { - "cell_type": "markdown", - "id": "a21108ad", - "metadata": {}, - "source": [ - "This approach allows for further improvement of your QA chain. You can effortlessly integrate extra prompt refinement data into your chain, thereby enhancing the overall user experience of your app." + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MATCH (:Person {id: \"Charles Robert Darwin\"})-[:COLLABORATION]->(collaborator)\n", + "RETURN collaborator;\n", + "Alfred Russel Wallace\n" + ] + } + ], + "source": [ + "chain = MemgraphQAChain.from_llm(\n", + " ChatOpenAI(temperature=0),\n", + " graph=graph,\n", + " model_name=\"gpt-4-turbo\",\n", + " allow_dangerous_requests=True,\n", + ")\n", + "print(chain.invoke(\"Who Charles Robert Darwin collaborated with?\")[\"result\"])" ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "langchain", "language": "python", "name": "python3" }, @@ -714,9 +801,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.9.19" } }, "nbformat": 4, - "nbformat_minor": 5 + "nbformat_minor": 2 } diff --git a/docs/static/img/memgraph_kg.png b/docs/static/img/memgraph_kg.png new file mode 100644 index 0000000000000..3f518a18710ea Binary files /dev/null and b/docs/static/img/memgraph_kg.png differ diff --git a/docs/static/img/memgraph_kg_2.png b/docs/static/img/memgraph_kg_2.png new file mode 100644 index 0000000000000..3b69ad9a9048f Binary files /dev/null and b/docs/static/img/memgraph_kg_2.png differ diff --git a/docs/static/img/memgraph_kg_3.png b/docs/static/img/memgraph_kg_3.png new file mode 100644 index 0000000000000..ba6ac94489a11 Binary files /dev/null and b/docs/static/img/memgraph_kg_3.png differ diff --git a/docs/static/img/memgraph_kg_4.png b/docs/static/img/memgraph_kg_4.png new file mode 100644 index 0000000000000..e731941c38514 Binary files /dev/null and b/docs/static/img/memgraph_kg_4.png differ diff --git a/libs/community/langchain_community/chains/graph_qa/memgraph.py b/libs/community/langchain_community/chains/graph_qa/memgraph.py new file mode 100644 index 0000000000000..02fa66992a2fd --- /dev/null +++ b/libs/community/langchain_community/chains/graph_qa/memgraph.py @@ -0,0 +1,316 @@ +"""Question answering over a graph.""" + +from __future__ import annotations + +import re +from typing import Any, Dict, List, Optional, Union + +from langchain.chains.base import Chain +from langchain_core.callbacks import CallbackManagerForChainRun +from langchain_core.language_models import BaseLanguageModel +from langchain_core.messages import ( + AIMessage, + BaseMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import ( + BasePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + MessagesPlaceholder, +) +from langchain_core.runnables import Runnable +from pydantic import Field + +from langchain_community.chains.graph_qa.prompts import ( + MEMGRAPH_GENERATION_PROMPT, + MEMGRAPH_QA_PROMPT, +) +from langchain_community.graphs.memgraph_graph import MemgraphGraph + +INTERMEDIATE_STEPS_KEY = "intermediate_steps" + +FUNCTION_RESPONSE_SYSTEM = """You are an assistant that helps to form nice and human +understandable answers based on the provided information from tools. +Do not add any other information that wasn't present in the tools, and use +very concise style in interpreting results! +""" + + +def extract_cypher(text: str) -> str: + """Extract Cypher code from a text. + + Args: + text: Text to extract Cypher code from. + + Returns: + Cypher code extracted from the text. + """ + # The pattern to find Cypher code enclosed in triple backticks + pattern = r"```(.*?)```" + + # Find all matches in the input text + matches = re.findall(pattern, text, re.DOTALL) + + return matches[0] if matches else text + + +def get_function_response( + question: str, context: List[Dict[str, Any]] +) -> List[BaseMessage]: + TOOL_ID = "call_H7fABDuzEau48T10Qn0Lsh0D" + messages = [ + AIMessage( + content="", + additional_kwargs={ + "tool_calls": [ + { + "id": TOOL_ID, + "function": { + "arguments": '{"question":"' + question + '"}', + "name": "GetInformation", + }, + "type": "function", + } + ] + }, + ), + ToolMessage(content=str(context), tool_call_id=TOOL_ID), + ] + return messages + + +class MemgraphQAChain(Chain): + """Chain for question-answering against a graph by generating Cypher statements. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + graph: MemgraphGraph = Field(exclude=True) + cypher_generation_chain: Runnable + qa_chain: Runnable + graph_schema: str + input_key: str = "query" #: :meta private: + output_key: str = "result" #: :meta private: + top_k: int = 10 + """Number of results to return from the query""" + return_intermediate_steps: bool = False + """Whether or not to return the intermediate steps along with the final answer.""" + return_direct: bool = False + """Optional cypher validation tool""" + use_function_response: bool = False + """Whether to wrap the database context as tool/function response""" + allow_dangerous_requests: bool = False + """Forced user opt-in to acknowledge that the chain can make dangerous requests. + + *Security note*: Make sure that the database connection uses credentials + that are narrowly-scoped to only include necessary permissions. + Failure to do so may result in data corruption or loss, since the calling + code may attempt commands that would result in deletion, mutation + of data if appropriately prompted or reading sensitive data if such + data is present in the database. + The best way to guard against such negative outcomes is to (as appropriate) + limit the permissions granted to the credentials used with this tool. + + See https://python.langchain.com/docs/security for more information. + """ + + def __init__(self, **kwargs: Any) -> None: + """Initialize the chain.""" + super().__init__(**kwargs) + if self.allow_dangerous_requests is not True: + raise ValueError( + "In order to use this chain, you must acknowledge that it can make " + "dangerous requests by setting `allow_dangerous_requests` to `True`." + "You must narrowly scope the permissions of the database connection " + "to only include necessary permissions. Failure to do so may result " + "in data corruption or loss or reading sensitive data if such data is " + "present in the database." + "Only use this chain if you understand the risks and have taken the " + "necessary precautions. " + "See https://python.langchain.com/docs/security for more information." + ) + + @property + def input_keys(self) -> List[str]: + """Return the input keys. + + :meta private: + """ + return [self.input_key] + + @property + def output_keys(self) -> List[str]: + """Return the output keys. + + :meta private: + """ + _output_keys = [self.output_key] + return _output_keys + + @property + def _chain_type(self) -> str: + return "graph_cypher_chain" + + @classmethod + def from_llm( + cls, + llm: Optional[BaseLanguageModel] = None, + *, + qa_prompt: Optional[BasePromptTemplate] = None, + cypher_prompt: Optional[BasePromptTemplate] = None, + cypher_llm: Optional[BaseLanguageModel] = None, + qa_llm: Optional[Union[BaseLanguageModel, Any]] = None, + qa_llm_kwargs: Optional[Dict[str, Any]] = None, + cypher_llm_kwargs: Optional[Dict[str, Any]] = None, + use_function_response: bool = False, + function_response_system: str = FUNCTION_RESPONSE_SYSTEM, + **kwargs: Any, + ) -> MemgraphQAChain: + """Initialize from LLM.""" + + if not cypher_llm and not llm: + raise ValueError("Either `llm` or `cypher_llm` parameters must be provided") + if not qa_llm and not llm: + raise ValueError("Either `llm` or `qa_llm` parameters must be provided") + if cypher_llm and qa_llm and llm: + raise ValueError( + "You can specify up to two of 'cypher_llm', 'qa_llm'" + ", and 'llm', but not all three simultaneously." + ) + if cypher_prompt and cypher_llm_kwargs: + raise ValueError( + "Specifying cypher_prompt and cypher_llm_kwargs together is" + " not allowed. Please pass prompt via cypher_llm_kwargs." + ) + if qa_prompt and qa_llm_kwargs: + raise ValueError( + "Specifying qa_prompt and qa_llm_kwargs together is" + " not allowed. Please pass prompt via qa_llm_kwargs." + ) + use_qa_llm_kwargs = qa_llm_kwargs if qa_llm_kwargs is not None else {} + use_cypher_llm_kwargs = ( + cypher_llm_kwargs if cypher_llm_kwargs is not None else {} + ) + if "prompt" not in use_qa_llm_kwargs: + use_qa_llm_kwargs["prompt"] = ( + qa_prompt if qa_prompt is not None else MEMGRAPH_QA_PROMPT + ) + if "prompt" not in use_cypher_llm_kwargs: + use_cypher_llm_kwargs["prompt"] = ( + cypher_prompt + if cypher_prompt is not None + else MEMGRAPH_GENERATION_PROMPT + ) + + qa_llm = qa_llm or llm + if use_function_response: + try: + qa_llm.bind_tools({}) # type: ignore[union-attr] + response_prompt = ChatPromptTemplate.from_messages( + [ + SystemMessage(content=function_response_system), + HumanMessagePromptTemplate.from_template("{question}"), + MessagesPlaceholder(variable_name="function_response"), + ] + ) + qa_chain = response_prompt | qa_llm | StrOutputParser() # type: ignore + except (NotImplementedError, AttributeError): + raise ValueError("Provided LLM does not support native tools/functions") + else: + qa_chain = use_qa_llm_kwargs["prompt"] | qa_llm | StrOutputParser() # type: ignore + + prompt = use_cypher_llm_kwargs["prompt"] + llm_to_use = cypher_llm if cypher_llm is not None else llm + + if prompt is not None and llm_to_use is not None: + cypher_generation_chain = prompt | llm_to_use | StrOutputParser() # type: ignore[arg-type] + else: + raise ValueError( + "Missing required components for the cypher generation chain: " + "'prompt' or 'llm'" + ) + + graph_schema = kwargs["graph"].get_schema + + return cls( + graph_schema=graph_schema, + qa_chain=qa_chain, + cypher_generation_chain=cypher_generation_chain, + use_function_response=use_function_response, + **kwargs, + ) + + def _call( + self, + inputs: Dict[str, Any], + run_manager: Optional[CallbackManagerForChainRun] = None, + ) -> Dict[str, Any]: + """Generate Cypher statement, use it to look up in db and answer question.""" + _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() + callbacks = _run_manager.get_child() + question = inputs[self.input_key] + args = { + "question": question, + "schema": self.graph_schema, + } + args.update(inputs) + + intermediate_steps: List = [] + + generated_cypher = self.cypher_generation_chain.invoke( + args, callbacks=callbacks + ) + # Extract Cypher code if it is wrapped in backticks + generated_cypher = extract_cypher(generated_cypher) + + _run_manager.on_text("Generated Cypher:", end="\n", verbose=self.verbose) + _run_manager.on_text( + generated_cypher, color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"query": generated_cypher}) + + # Retrieve and limit the number of results + # Generated Cypher be null if query corrector identifies invalid schema + if generated_cypher: + context = self.graph.query(generated_cypher)[: self.top_k] + else: + context = [] + + if self.return_direct: + result = context + else: + _run_manager.on_text("Full Context:", end="\n", verbose=self.verbose) + _run_manager.on_text( + str(context), color="green", end="\n", verbose=self.verbose + ) + + intermediate_steps.append({"context": context}) + if self.use_function_response: + function_response = get_function_response(question, context) + result = self.qa_chain.invoke( # type: ignore + {"question": question, "function_response": function_response}, + ) + else: + result = self.qa_chain.invoke( # type: ignore + {"question": question, "context": context}, + callbacks=callbacks, + ) + + chain_result: Dict[str, Any] = {"result": result} + if self.return_intermediate_steps: + chain_result[INTERMEDIATE_STEPS_KEY] = intermediate_steps + + return chain_result diff --git a/libs/community/langchain_community/chains/graph_qa/prompts.py b/libs/community/langchain_community/chains/graph_qa/prompts.py index ec4d9a4c750a2..9077da3e00e3d 100644 --- a/libs/community/langchain_community/chains/graph_qa/prompts.py +++ b/libs/community/langchain_community/chains/graph_qa/prompts.py @@ -411,3 +411,58 @@ input_variables=["schema", "question", "extra_instructions"], template=NEPTUNE_OPENCYPHER_GENERATION_SIMPLE_TEMPLATE, ) + +MEMGRAPH_GENERATION_TEMPLATE = """Your task is to directly translate natural language inquiry into precise and executable Cypher query for Memgraph database. +You will utilize a provided database schema to understand the structure, nodes and relationships within the Memgraph database. +Instructions: +- Use provided node and relationship labels and property names from the +schema which describes the database's structure. Upon receiving a user +question, synthesize the schema to craft a precise Cypher query that +directly corresponds to the user's intent. +- Generate valid executable Cypher queries on top of Memgraph database. +Any explanation, context, or additional information that is not a part +of the Cypher query syntax should be omitted entirely. +- Use Memgraph MAGE procedures instead of Neo4j APOC procedures. +- Do not include any explanations or apologies in your responses. +- Do not include any text except the generated Cypher statement. +- For queries that ask for information or functionalities outside the direct +generation of Cypher queries, use the Cypher query format to communicate +limitations or capabilities. For example: RETURN "I am designed to generate +Cypher queries based on the provided schema only." +Schema: +{schema} + +With all the above information and instructions, generate Cypher query for the +user question. + +The question is: +{question}""" + +MEMGRAPH_GENERATION_PROMPT = PromptTemplate( + input_variables=["schema", "question"], template=MEMGRAPH_GENERATION_TEMPLATE +) + + +MEMGRAPH_QA_TEMPLATE = """Your task is to form nice and human +understandable answers. The information part contains the provided +information that you must use to construct an answer. +The provided information is authoritative, you must never doubt it or try to +use your internal knowledge to correct it. Make the answer sound as a +response to the question. Do not mention that you based the result on the +given information. Here is an example: + +Question: Which managers own Neo4j stocks? +Context:[manager:CTL LLC, manager:JANE STREET GROUP LLC] +Helpful Answer: CTL LLC, JANE STREET GROUP LLC owns Neo4j stocks. + +Follow this example when generating answers. If the provided information is +empty, say that you don't know the answer. + +Information: +{context} + +Question: {question} +Helpful Answer:""" +MEMGRAPH_QA_PROMPT = PromptTemplate( + input_variables=["context", "question"], template=MEMGRAPH_QA_TEMPLATE +) diff --git a/libs/community/langchain_community/graphs/memgraph_graph.py b/libs/community/langchain_community/graphs/memgraph_graph.py index 34e9f7145bb22..fa829a0e5db81 100644 --- a/libs/community/langchain_community/graphs/memgraph_graph.py +++ b/libs/community/langchain_community/graphs/memgraph_graph.py @@ -1,15 +1,272 @@ -from langchain_community.graphs.neo4j_graph import Neo4jGraph +import logging +from hashlib import md5 +from typing import Any, Dict, List, Optional + +from langchain_core.utils import get_from_dict_or_env + +from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship +from langchain_community.graphs.graph_store import GraphStore + +logger = logging.getLogger(__name__) + + +BASE_ENTITY_LABEL = "__Entity__" SCHEMA_QUERY = """ -CALL llm_util.schema("raw") -YIELD * -RETURN * +SHOW SCHEMA INFO +""" + +NODE_PROPERTIES_QUERY = """ +CALL schema.node_type_properties() +YIELD nodeType AS label, propertyName AS property, propertyTypes AS type +WITH label AS nodeLabels, collect({key: property, types: type}) AS properties +RETURN {labels: nodeLabels, properties: properties} AS output +""" + +REL_QUERY = """ +MATCH (n)-[e]->(m) +WITH DISTINCT + labels(n) AS start_node_labels, + type(e) AS rel_type, + labels(m) AS end_node_labels, + e, + keys(e) AS properties +UNWIND CASE WHEN size(properties) > 0 THEN properties ELSE [null] END AS prop +WITH + start_node_labels, + rel_type, + end_node_labels, + CASE WHEN prop IS NULL THEN [] ELSE [prop, valueType(e[prop])] END AS property_info +RETURN + start_node_labels, + rel_type, + end_node_labels, + COLLECT(DISTINCT CASE + WHEN property_info <> [] + THEN property_info + ELSE null END) AS properties_info +""" + +NODE_IMPORT_QUERY = """ +UNWIND $data AS row +CALL merge.node(row.label, row.properties, {}, {}) +YIELD node +RETURN distinct 'done' AS result +""" + +REL_NODES_IMPORT_QUERY = """ +UNWIND $data AS row +MERGE (source {id: row.source_id}) +MERGE (target {id: row.target_id}) +RETURN distinct 'done' AS result +""" + +REL_IMPORT_QUERY = """ +UNWIND $data AS row +MATCH (source {id: row.source_id}) +MATCH (target {id: row.target_id}) +WITH source, target, row +CALL merge.relationship(source, row.type, {}, {}, target, {}) +YIELD rel +RETURN distinct 'done' AS result +""" + +INCLUDE_DOCS_QUERY = """ +MERGE (d:Document {id:$document.metadata.id}) +SET d.content = $document.page_content +SET d += $document.metadata +RETURN distinct 'done' AS result +""" + +INCLUDE_DOCS_SOURCE_QUERY = """ +UNWIND $data AS row +MATCH (source {id: row.source_id}), (d:Document {id: $document.metadata.id}) +MERGE (d)-[:MENTIONS]->(source) +RETURN distinct 'done' AS result +""" + +NODE_PROPS_TEXT = """ +Node labels and properties (name and type) are: """ +REL_PROPS_TEXT = """ +Relationship labels and properties are: +""" + +REL_TEXT = """ +Nodes are connected with the following relationships: +""" + + +def get_schema_subset(data: Dict[str, Any]) -> Dict[str, Any]: + return { + "edges": [ + { + "end_node_labels": edge["end_node_labels"], + "properties": [ + { + "key": prop["key"], + "types": [ + {"type": type_item["type"].lower()} + for type_item in prop["types"] + ], + } + for prop in edge["properties"] + ], + "start_node_labels": edge["start_node_labels"], + "type": edge["type"], + } + for edge in data["edges"] + ], + "nodes": [ + { + "labels": node["labels"], + "properties": [ + { + "key": prop["key"], + "types": [ + {"type": type_item["type"].lower()} + for type_item in prop["types"] + ], + } + for prop in node["properties"] + ], + } + for node in data["nodes"] + ], + } + + +def get_reformated_schema( + nodes: List[Dict[str, Any]], rels: List[Dict[str, Any]] +) -> Dict[str, Any]: + return { + "edges": [ + { + "end_node_labels": rel["end_node_labels"], + "properties": [ + {"key": prop[0], "types": [{"type": prop[1].lower()}]} + for prop in rel["properties_info"] + ], + "start_node_labels": rel["start_node_labels"], + "type": rel["rel_type"], + } + for rel in rels + ], + "nodes": [ + { + "labels": [_remove_backticks(node["labels"])[1:]], + "properties": [ + { + "key": prop["key"], + "types": [ + {"type": type_item.lower()} for type_item in prop["types"] + ], + } + for prop in node["properties"] + if node["properties"][0]["key"] != "" + ], + } + for node in nodes + ], + } + + +def transform_schema_to_text(schema: Dict[str, Any]) -> str: + node_props_data = "" + rel_props_data = "" + rel_data = "" + + for node in schema["nodes"]: + node_props_data += f"- labels: (:{':'.join(node['labels'])})\n" + if node["properties"] == []: + continue + node_props_data += " properties:\n" + for prop in node["properties"]: + prop_types_str = " or ".join( + {prop_types["type"] for prop_types in prop["types"]} + ) + node_props_data += f" - {prop['key']}: {prop_types_str}\n" + + for rel in schema["edges"]: + rel_type = rel["type"] + start_labels = ":".join(rel["start_node_labels"]) + end_labels = ":".join(rel["end_node_labels"]) + rel_data += f"(:{start_labels})-[:{rel_type}]->(:{end_labels})\n" + + if rel["properties"] == []: + continue + + rel_props_data += f"- labels: {rel_type}\n properties:\n" + for prop in rel["properties"]: + prop_types_str = " or ".join( + {prop_types["type"].lower() for prop_types in prop["types"]} + ) + rel_props_data += f" - {prop['key']}: {prop_types_str}\n" + + return "".join( + [ + NODE_PROPS_TEXT + node_props_data if node_props_data else "", + REL_PROPS_TEXT + rel_props_data if rel_props_data else "", + REL_TEXT + rel_data if rel_data else "", + ] + ) + + +def _remove_backticks(text: str) -> str: + return text.replace("`", "") + + +def _transform_nodes(nodes: list[Node], baseEntityLabel: bool) -> List[dict]: + transformed_nodes = [] + for node in nodes: + properties_dict = node.properties | {"id": node.id} + label = ( + [_remove_backticks(node.type), BASE_ENTITY_LABEL] + if baseEntityLabel + else [_remove_backticks(node.type)] + ) + node_dict = {"label": label, "properties": properties_dict} + transformed_nodes.append(node_dict) + return transformed_nodes + + +def _transform_relationships( + relationships: list[Relationship], baseEntityLabel: bool +) -> List[dict]: + transformed_relationships = [] + for rel in relationships: + rel_dict = { + "type": _remove_backticks(rel.type), + "source_label": ( + [BASE_ENTITY_LABEL] + if baseEntityLabel + else [_remove_backticks(rel.source.type)] + ), + "source_id": rel.source.id, + "target_label": ( + [BASE_ENTITY_LABEL] + if baseEntityLabel + else [_remove_backticks(rel.target.type)] + ), + "target_id": rel.target.id, + } + transformed_relationships.append(rel_dict) + return transformed_relationships -class MemgraphGraph(Neo4jGraph): + +class MemgraphGraph(GraphStore): """Memgraph wrapper for graph operations. + Parameters: + url (Optional[str]): The URL of the Memgraph database server. + username (Optional[str]): The username for database authentication. + password (Optional[str]): The password for database authentication. + database (str): The name of the database to connect to. Default is 'memgraph'. + refresh_schema (bool): A flag whether to refresh schema information + at initialization. Default is True. + driver_config (Dict): Configuration passed to Neo4j Driver. + *Security note*: Make sure that the database connection uses credentials that are narrowly-scoped to only include necessary permissions. Failure to do so may result in data corruption or loss, since the calling @@ -23,49 +280,247 @@ class MemgraphGraph(Neo4jGraph): """ def __init__( - self, url: str, username: str, password: str, *, database: str = "memgraph" + self, + url: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + database: Optional[str] = None, + refresh_schema: bool = True, + *, + driver_config: Optional[Dict] = None, ) -> None: """Create a new Memgraph graph wrapper instance.""" - super().__init__(url, username, password, database=database) + try: + import neo4j + except ImportError: + raise ImportError( + "Could not import neo4j python package. " + "Please install it with `pip install neo4j`." + ) + + url = get_from_dict_or_env({"url": url}, "url", "MEMGRAPH_URI") + + # if username and password are "", assume auth is disabled + if username == "" and password == "": + auth = None + else: + username = get_from_dict_or_env( + {"username": username}, + "username", + "MEMGRAPH_USERNAME", + ) + password = get_from_dict_or_env( + {"password": password}, + "password", + "MEMGRAPH_PASSWORD", + ) + auth = (username, password) + database = get_from_dict_or_env( + {"database": database}, "database", "MEMGRAPH_DATABASE", "memgraph" + ) + + self._driver = neo4j.GraphDatabase.driver( + url, auth=auth, **(driver_config or {}) + ) + + self._database = database + self.schema: str = "" + self.structured_schema: Dict[str, Any] = {} + + # Verify connection + try: + self._driver.verify_connectivity() + except neo4j.exceptions.ServiceUnavailable: + raise ValueError( + "Could not connect to Memgraph database. " + "Please ensure that the url is correct" + ) + except neo4j.exceptions.AuthError: + raise ValueError( + "Could not connect to Memgraph database. " + "Please ensure that the username and password are correct" + ) + + # Set schema + if refresh_schema: + try: + self.refresh_schema() + except neo4j.exceptions.ClientError as e: + raise e + + def close(self) -> None: + if self._driver: + logger.info("Closing the driver connection.") + self._driver.close() + self._driver = None + + @property + def get_schema(self) -> str: + """Returns the schema of the Graph database""" + return self.schema + + @property + def get_structured_schema(self) -> Dict[str, Any]: + """Returns the structured schema of the Graph database""" + return self.structured_schema + + def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: + """Query the graph. + + Args: + query (str): The Cypher query to execute. + params (dict): The parameters to pass to the query. + + Returns: + List[Dict[str, Any]]: The list of dictionaries containing the query results. + """ + from neo4j.exceptions import Neo4jError + + try: + data, _, _ = self._driver.execute_query( + query, + database_=self._database, + parameters_=params, + ) + json_data = [r.data() for r in data] + return json_data + except Neo4jError as e: + if not ( + ( + ( # isCallInTransactionError + e.code == "Neo.DatabaseError.Statement.ExecutionFailed" + or e.code + == "Neo.DatabaseError.Transaction.TransactionStartFailed" + ) + and "in an implicit transaction" in e.message + ) + or ( # isPeriodicCommitError + e.code == "Neo.ClientError.Statement.SemanticError" + and ( + "in an open transaction is not possible" in e.message + or "tried to execute in an explicit transaction" in e.message + ) + ) + or ( + e.code == "Memgraph.ClientError.MemgraphError.MemgraphError" + and ("in multicommand transactions" in e.message) + ) + or ( + e.code == "Memgraph.ClientError.MemgraphError.MemgraphError" + and "SchemaInfo disabled" in e.message + ) + ): + raise + + # fallback to allow implicit transactions + with self._driver.session(database=self._database) as session: + data = session.run(query, params) + json_data = [r.data() for r in data] + return json_data def refresh_schema(self) -> None: """ Refreshes the Memgraph graph schema information. """ + import ast + + from neo4j.exceptions import Neo4jError + + # leave schema empty if db is empty + if self.query("MATCH (n) RETURN n LIMIT 1") == []: + return - db_structured_schema = self.query(SCHEMA_QUERY)[0].get("schema") - assert db_structured_schema is not None - self.structured_schema = db_structured_schema + # first try with SHOW SCHEMA INFO + try: + result = self.query(SCHEMA_QUERY)[0].get("schema") + if result is not None and isinstance(result, (str, ast.AST)): + schema_result = ast.literal_eval(result) + else: + schema_result = result + assert schema_result is not None + structured_schema = get_schema_subset(schema_result) + self.structured_schema = structured_schema + self.schema = transform_schema_to_text(structured_schema) + return + except Neo4jError as e: + if ( + e.code == "Memgraph.ClientError.MemgraphError.MemgraphError" + and "SchemaInfo disabled" in e.message + ): + logger.info( + "Schema generation with SHOW SCHEMA INFO query failed. " + "Set --schema-info-enabled=true to use SHOW SCHEMA INFO query. " + "Falling back to alternative queries." + ) - # Format node properties - formatted_node_props = [] + # fallback on Cypher without SHOW SCHEMA INFO + nodes = [query["output"] for query in self.query(NODE_PROPERTIES_QUERY)] + rels = self.query(REL_QUERY) - for node_name, properties in db_structured_schema["node_props"].items(): - formatted_node_props.append( - f"Node name: '{node_name}', Node properties: {properties}" + structured_schema = get_reformated_schema(nodes, rels) + self.structured_schema = structured_schema + self.schema = transform_schema_to_text(structured_schema) + + def add_graph_documents( + self, + graph_documents: List[GraphDocument], + include_source: bool = False, + baseEntityLabel: bool = False, + ) -> None: + """ + Take GraphDocument as input as uses it to construct a graph in Memgraph. + + Parameters: + - graph_documents (List[GraphDocument]): A list of GraphDocument objects + that contain the nodes and relationships to be added to the graph. Each + GraphDocument should encapsulate the structure of part of the graph, + including nodes, relationships, and the source document information. + - include_source (bool, optional): If True, stores the source document + and links it to nodes in the graph using the MENTIONS relationship. + This is useful for tracing back the origin of data. Merges source + documents based on the `id` property from the source document metadata + if available; otherwise it calculates the MD5 hash of `page_content` + for merging process. Defaults to False. + - baseEntityLabel (bool, optional): If True, each newly created node + gets a secondary __Entity__ label, which is indexed and improves import + speed and performance. Defaults to False. + """ + + if baseEntityLabel: + self.query( + f"CREATE CONSTRAINT ON (b:{BASE_ENTITY_LABEL}) " + "ASSERT b.id IS UNIQUE;" ) + self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL}(id);") + self.query(f"CREATE INDEX ON :{BASE_ENTITY_LABEL};") + + for document in graph_documents: + if include_source: + if not document.source.metadata.get("id"): + document.source.metadata["id"] = md5( + document.source.page_content.encode("utf-8") + ).hexdigest() - # Format relationship properties - formatted_rel_props = [] - for rel_name, properties in db_structured_schema["rel_props"].items(): - formatted_rel_props.append( - f"Relationship name: '{rel_name}', " - f"Relationship properties: {properties}" + self.query(INCLUDE_DOCS_QUERY, {"document": document.source.__dict__}) + + self.query( + NODE_IMPORT_QUERY, + {"data": _transform_nodes(document.nodes, baseEntityLabel)}, ) - # Format relationships - formatted_rels = [ - f"(:{rel['start']})-[:{rel['type']}]->(:{rel['end']})" - for rel in db_structured_schema["relationships"] - ] + rel_data = _transform_relationships(document.relationships, baseEntityLabel) + self.query( + REL_NODES_IMPORT_QUERY, + {"data": rel_data}, + ) + self.query( + REL_IMPORT_QUERY, + {"data": rel_data}, + ) - self.schema = "\n".join( - [ - "Node properties are the following:", - *formatted_node_props, - "Relationship properties are the following:", - *formatted_rel_props, - "The relationships are the following:", - *formatted_rels, - ] - ) + if include_source: + self.query( + INCLUDE_DOCS_SOURCE_QUERY, + {"data": rel_data, "document": document.source.__dict__}, + ) + self.refresh_schema() diff --git a/libs/community/tests/integration_tests/graphs/test_memgraph.py b/libs/community/tests/integration_tests/graphs/test_memgraph.py index 663f974d3f106..c005eec55ef3a 100644 --- a/libs/community/tests/integration_tests/graphs/test_memgraph.py +++ b/libs/community/tests/integration_tests/graphs/test_memgraph.py @@ -1,24 +1,44 @@ import os +from langchain_core.documents import Document + from langchain_community.graphs import MemgraphGraph +from langchain_community.graphs.graph_document import GraphDocument, Node, Relationship +from langchain_community.graphs.memgraph_graph import NODE_PROPERTIES_QUERY, REL_QUERY + +test_data = [ + GraphDocument( + nodes=[Node(id="foo", type="foo"), Node(id="bar", type="bar")], + relationships=[ + Relationship( + source=Node(id="foo", type="foo"), + target=Node(id="bar", type="bar"), + type="REL", + ) + ], + source=Document(page_content="source document"), + ) +] def test_cypher_return_correct_schema() -> None: """Test that chain returns direct results.""" + url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687") username = os.environ.get("MEMGRAPH_USERNAME", "") password = os.environ.get("MEMGRAPH_PASSWORD", "") + assert url is not None assert username is not None assert password is not None - graph = MemgraphGraph( - url=url, - username=username, - password=password, - ) - # Delete all nodes in the graph - graph.query("MATCH (n) DETACH DELETE n") + graph = MemgraphGraph(url=url, username=username, password=password) + + # Drop graph + graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL") + graph.query("DROP GRAPH") + graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") + # Create two nodes and a relationship graph.query( """ @@ -31,32 +51,123 @@ def test_cypher_return_correct_schema() -> None: ) # Refresh schema information graph.refresh_schema() - relationships = graph.query( - "CALL llm_util.schema('raw') YIELD schema " - "WITH schema.relationships AS relationships " - "UNWIND relationships AS relationship " - "RETURN relationship['start'] AS start, " - "relationship['type'] AS type, " - "relationship['end'] AS end " - "ORDER BY start, type, end;" - ) - node_props = graph.query( - "CALL llm_util.schema('raw') YIELD schema " - "WITH schema.node_props AS nodes " - "WITH nodes['LabelA'] AS properties " - "UNWIND properties AS property " - "RETURN property['property'] AS prop, " - "property['type'] AS type " - "ORDER BY prop ASC;" - ) + node_properties = graph.query(NODE_PROPERTIES_QUERY) + relationships = graph.query(REL_QUERY) + + expected_node_properties = [ + { + "output": { + "labels": ":`LabelA`", + "properties": [{"key": "property_a", "types": ["String"]}], + } + }, + {"output": {"labels": ":`LabelB`", "properties": [{"key": "", "types": []}]}}, + {"output": {"labels": ":`LabelC`", "properties": [{"key": "", "types": []}]}}, + ] expected_relationships = [ - {"start": "LabelA", "type": "REL_TYPE", "end": "LabelB"}, - {"start": "LabelA", "type": "REL_TYPE", "end": "LabelC"}, + { + "start_node_labels": ["LabelA"], + "rel_type": "REL_TYPE", + "end_node_labels": ["LabelC"], + "properties_info": [["rel_prop", "STRING"]], + }, + { + "start_node_labels": ["LabelA"], + "rel_type": "REL_TYPE", + "end_node_labels": ["LabelB"], + "properties_info": [], + }, ] - expected_node_props = [{"prop": "property_a", "type": "str"}] + graph.close() + assert node_properties == expected_node_properties assert relationships == expected_relationships - assert node_props == expected_node_props + + +def test_add_graph_documents() -> None: + """Test that Memgraph correctly imports graph document.""" + url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687") + username = os.environ.get("MEMGRAPH_USERNAME", "") + password = os.environ.get("MEMGRAPH_PASSWORD", "") + + assert url is not None + assert username is not None + assert password is not None + + graph = MemgraphGraph( + url=url, username=username, password=password, refresh_schema=False + ) + # Drop graph + graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL") + graph.query("DROP GRAPH") + graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") + # Create KG + graph.add_graph_documents(test_data) + output = graph.query("MATCH (n) RETURN labels(n) AS label, count(*) AS count") + # Close the connection + graph.close() + assert output == [{"label": ["bar"], "count": 1}, {"label": ["foo"], "count": 1}] + + +def test_add_graph_documents_base_entity() -> None: + """Test that Memgraph correctly imports graph document with Entity label.""" + url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687") + username = os.environ.get("MEMGRAPH_USERNAME", "") + password = os.environ.get("MEMGRAPH_PASSWORD", "") + + assert url is not None + assert username is not None + assert password is not None + + graph = MemgraphGraph( + url=url, username=username, password=password, refresh_schema=False + ) + # Drop graph + graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL") + graph.query("DROP GRAPH") + graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") + # Create KG + graph.add_graph_documents(test_data, baseEntityLabel=True) + output = graph.query("MATCH (n) RETURN labels(n) AS label, count(*) AS count") + + # Close the connection + graph.close() + + assert output == [ + {"label": ["__Entity__", "bar"], "count": 1}, + {"label": ["__Entity__", "foo"], "count": 1}, + ] + + +def test_add_graph_documents_include_source() -> None: + """Test that Memgraph correctly imports graph document with source included.""" + url = os.environ.get("MEMGRAPH_URI", "bolt://localhost:7687") + username = os.environ.get("MEMGRAPH_USERNAME", "") + password = os.environ.get("MEMGRAPH_PASSWORD", "") + + assert url is not None + assert username is not None + assert password is not None + + graph = MemgraphGraph( + url=url, username=username, password=password, refresh_schema=False + ) + # Drop graph + graph.query("STORAGE MODE IN_MEMORY_ANALYTICAL") + graph.query("DROP GRAPH") + graph.query("STORAGE MODE IN_MEMORY_TRANSACTIONAL") + # Create KG + graph.add_graph_documents(test_data, include_source=True) + output = graph.query("MATCH (n) RETURN labels(n) AS label, count(*) AS count") + + # Close the connection + graph.close() + + assert output == [ + {"label": ["bar"], "count": 1}, + {"label": ["foo"], "count": 1}, + {"label": ["Document"], "count": 1}, + ]