From c04c85db4d416ffeb1b6d61fdfada5f7b6be6f89 Mon Sep 17 00:00:00 2001 From: Adrian Panella Date: Mon, 23 Dec 2024 22:52:28 -0300 Subject: [PATCH] core: allow artifact in create_retriever_tool --- libs/core/langchain_core/tools/retriever.py | 25 ++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/libs/core/langchain_core/tools/retriever.py b/libs/core/langchain_core/tools/retriever.py index b59b49a3d317a..23b4af29593bd 100644 --- a/libs/core/langchain_core/tools/retriever.py +++ b/libs/core/langchain_core/tools/retriever.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import partial -from typing import Optional +from typing import Any, Literal, Optional, Union from pydantic import BaseModel, Field @@ -28,11 +28,16 @@ def _get_relevant_documents( document_prompt: BasePromptTemplate, document_separator: str, callbacks: Callbacks = None, -) -> str: + response_format: Literal["content", "content_and_artifact"] = "content", +) -> Union[str, tuple[str, list[dict[str, Any]]]]: docs = retriever.invoke(query, config={"callbacks": callbacks}) - return document_separator.join( + content = document_separator.join( format_document(doc, document_prompt) for doc in docs ) + if response_format == "content_and_artifact": + return (content, [doc.model_dump() for doc in docs]) + + return content async def _aget_relevant_documents( @@ -41,12 +46,18 @@ async def _aget_relevant_documents( document_prompt: BasePromptTemplate, document_separator: str, callbacks: Callbacks = None, -) -> str: + response_format: Literal["content", "content_and_artifact"] = "content", +) -> Union[str, tuple[str, list[dict[str, Any]]]]: docs = await retriever.ainvoke(query, config={"callbacks": callbacks}) - return document_separator.join( + content = document_separator.join( [await aformat_document(doc, document_prompt) for doc in docs] ) + if response_format == "content_and_artifact": + return (content, [doc.model_dump() for doc in docs]) + + return content + def create_retriever_tool( retriever: BaseRetriever, @@ -55,6 +66,7 @@ def create_retriever_tool( *, document_prompt: Optional[BasePromptTemplate] = None, document_separator: str = "\n\n", + response_format: Literal["content", "content_and_artifact"] = "content", ) -> Tool: """Create a tool to do retrieval of documents. @@ -76,12 +88,14 @@ def create_retriever_tool( retriever=retriever, document_prompt=document_prompt, document_separator=document_separator, + response_format=response_format, ) afunc = partial( _aget_relevant_documents, retriever=retriever, document_prompt=document_prompt, document_separator=document_separator, + response_format=response_format, ) return Tool( name=name, @@ -89,4 +103,5 @@ def create_retriever_tool( func=func, coroutine=afunc, args_schema=RetrieverInput, + response_format=response_format )