Skip to content

Commit

Permalink
langchain[patch]: return formatted SPARQL query on demand (langchain-…
Browse files Browse the repository at this point in the history
…ai#11263)

- **Description:** Added the `return_sparql_query` feature to the
`GraphSparqlQAChain` class, allowing users to get the formatted SPARQL
query along with the chain's result.
  - **Issue:** NA
  - **Dependencies:** None

Note: I've ensured that the PR passes linting and testing by running
make format, make lint, and make test locally.

I have added a test for the integration (which relies on network access)
and I have added an example to the notebook showing its use.
  • Loading branch information
reidfalconer authored and al1p-R committed Feb 27, 2024
1 parent 830b39a commit 9430b11
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 11 deletions.
106 changes: 96 additions & 10 deletions docs/docs/use_cases/graph/graph_sparql_qa.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"id": "62812aad",
"metadata": {
"pycharm": {
Expand All @@ -37,7 +37,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 4,
"id": "0928915d",
"metadata": {
"pycharm": {
Expand Down Expand Up @@ -74,7 +74,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"id": "4e3de44f",
"metadata": {
"pycharm": {
Expand All @@ -88,7 +88,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 6,
"id": "1fe76ccd",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -121,7 +121,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "7476ce98",
"metadata": {
"pycharm": {
Expand Down Expand Up @@ -250,7 +250,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 27,
"id": "f874171b",
"metadata": {},
"outputs": [
Expand All @@ -277,13 +277,99 @@
")\n",
"graph.query(query)"
]
},
{
"cell_type": "markdown",
"id": "eb00a625-a6c9-4766-b3f0-eaed024851c9",
"metadata": {},
"source": [
"## Return SQARQL query\n",
"You can return the SPARQL query step from the Sparql QA Chain using the `return_sparql_query` parameter"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f13e2865-176a-4417-95e6-db818b214d08",
"metadata": {},
"outputs": [],
"source": [
"chain = GraphSparqlQAChain.from_llm(\n",
" ChatOpenAI(temperature=0), graph=graph, verbose=True, return_sparql_query=True\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "4f4d47b6-4202-4e74-8c88-aeaac5344c04",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n",
"\u001b[1m> Entering new GraphSparqlQAChain chain...\u001b[0m\n",
"Identified intent:\n",
"\u001b[32;1m\u001b[1;3mSELECT\u001b[0m\n",
"Generated SPARQL:\n",
"\u001b[32;1m\u001b[1;3mPREFIX foaf: <http://xmlns.com/foaf/0.1/>\n",
"SELECT ?workHomepage\n",
"WHERE {\n",
" ?person foaf:name \"Tim Berners-Lee\" .\n",
" ?person foaf:workplaceHomepage ?workHomepage .\n",
"}\u001b[0m\n",
"Full Context:\n",
"\u001b[32;1m\u001b[1;3m[]\u001b[0m\n",
"\n",
"\u001b[1m> Finished chain.\u001b[0m\n",
"SQARQL query: PREFIX foaf: <http://xmlns.com/foaf/0.1/>\n",
"SELECT ?workHomepage\n",
"WHERE {\n",
" ?person foaf:name \"Tim Berners-Lee\" .\n",
" ?person foaf:workplaceHomepage ?workHomepage .\n",
"}\n",
"Final answer: Tim Berners-Lee's work homepage is http://www.w3.org/People/Berners-Lee/.\n"
]
}
],
"source": [
"result = chain(\"What is Tim Berners-Lee's work homepage?\")\n",
"print(f\"SQARQL query: {result['sparql_query']}\")\n",
"print(f\"Final answer: {result['result']}\")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "be3d9ff7-dc00-47d0-857d-fd40437a3f22",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PREFIX foaf: <http://xmlns.com/foaf/0.1/>\n",
"SELECT ?workHomepage\n",
"WHERE {\n",
" ?person foaf:name \"Tim Berners-Lee\" .\n",
" ?person foaf:workplaceHomepage ?workHomepage .\n",
"}\n"
]
}
],
"source": [
"print(result[\"sparql_query\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lc",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "lc"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -295,9 +381,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}
16 changes: 15 additions & 1 deletion libs/langchain/langchain/chains/graph_qa/sparql.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,25 @@ class GraphSparqlQAChain(Chain):
sparql_generation_update_chain: LLMChain
sparql_intent_chain: LLMChain
qa_chain: LLMChain
return_sparql_query: bool = False
input_key: str = "query" #: :meta private:
output_key: str = "result" #: :meta private:
sparql_query_key: str = "sparql_query" #: :meta private:

@property
def input_keys(self) -> List[str]:
"""Return the input keys.
:meta private:
"""
return [self.input_key]

@property
def output_keys(self) -> List[str]:
"""Return the output keys.
:meta private:
"""
_output_keys = [self.output_key]
return _output_keys

Expand Down Expand Up @@ -135,4 +145,8 @@ def _call(
res = "Successfully inserted triples into the graph."
else:
raise ValueError("Unsupported SPARQL query type.")
return {self.output_key: res}

chain_result: Dict[str, Any] = {self.output_key: res}
if self.return_sparql_query:
chain_result[self.sparql_query_key] = generated_sparql
return chain_result
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,29 @@ def test_sparql_insert() -> None:
os.remove(_local_copy)
except OSError:
pass


def test_sparql_select_return_query() -> None:
"""
Test for generating and executing simple SPARQL SELECT query
and returning the generated SPARQL query.
"""
berners_lee_card = "http://www.w3.org/People/Berners-Lee/card"

graph = RdfGraph(
source_file=berners_lee_card,
standard="rdf",
)

chain = GraphSparqlQAChain.from_llm(
OpenAI(temperature=0), graph=graph, return_sparql_query=True
)
output = chain("What is Tim Berners-Lee's work homepage?")

# Verify the expected answer
expected_output = (
" The work homepage of Tim Berners-Lee is "
"http://www.w3.org/People/Berners-Lee/."
)
assert output["result"] == expected_output
assert "sparql_query" in output

0 comments on commit 9430b11

Please sign in to comment.