Skip to content

Commit

Permalink
core: allow artifact in create_retriever_tool
Browse files Browse the repository at this point in the history
  • Loading branch information
ianchi committed Dec 24, 2024
1 parent cb4e6ac commit eef8b40
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions libs/core/langchain_core/tools/retriever.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Optional
from typing import Any, Literal, Optional, Union

from pydantic import BaseModel, Field

Expand All @@ -28,11 +28,16 @@ def _get_relevant_documents(
document_prompt: BasePromptTemplate,
document_separator: str,
callbacks: Callbacks = None,
) -> str:
response_format: Literal["content", "content_and_artifact"] = "content",
) -> Union[str, tuple[str, list[dict[str, Any]]]]:
docs = retriever.invoke(query, config={"callbacks": callbacks})
return document_separator.join(
content = document_separator.join(
format_document(doc, document_prompt) for doc in docs
)
if response_format == "content_and_artifact":
return (content, [doc.model_dump() for doc in docs])

return content


async def _aget_relevant_documents(
Expand All @@ -41,12 +46,18 @@ async def _aget_relevant_documents(
document_prompt: BasePromptTemplate,
document_separator: str,
callbacks: Callbacks = None,
) -> str:
response_format: Literal["content", "content_and_artifact"] = "content",
) -> Union[str, tuple[str, list[dict[str, Any]]]]:
docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
return document_separator.join(
content = document_separator.join(
[await aformat_document(doc, document_prompt) for doc in docs]
)

if response_format == "content_and_artifact":
return (content, [doc.model_dump() for doc in docs])

return content


def create_retriever_tool(
retriever: BaseRetriever,
Expand All @@ -55,6 +66,7 @@ def create_retriever_tool(
*,
document_prompt: Optional[BasePromptTemplate] = None,
document_separator: str = "\n\n",
response_format: Literal["content", "content_and_artifact"] = "content",
) -> Tool:
"""Create a tool to do retrieval of documents.
Expand All @@ -76,17 +88,20 @@ def create_retriever_tool(
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
response_format=response_format,
)
afunc = partial(
_aget_relevant_documents,
retriever=retriever,
document_prompt=document_prompt,
document_separator=document_separator,
response_format=response_format,
)
return Tool(
name=name,
description=description,
func=func,
coroutine=afunc,
args_schema=RetrieverInput,
response_format=response_format
)

0 comments on commit eef8b40

Please sign in to comment.