From 196a7665c798e4f51fe708969360b068b5dda168 Mon Sep 17 00:00:00 2001 From: rkwan05 Date: Thu, 7 Nov 2024 19:44:50 -0500 Subject: [PATCH] core: Add support to accept PDF in ChatPromptTemplate. > > Co-authored-by: czhu24 --- libs/core/extended_testing_deps.txt | 1 + libs/core/langchain_core/prompts/chat.py | 29 ++++++++++++++++++ .../tests/unit_tests/prompts/test_chat.py | 30 +++++++++++++++++++ 3 files changed, 60 insertions(+) diff --git a/libs/core/extended_testing_deps.txt b/libs/core/extended_testing_deps.txt index 5ad9c8930daf9..eb408cd10ad8a 100644 --- a/libs/core/extended_testing_deps.txt +++ b/libs/core/extended_testing_deps.txt @@ -1 +1,2 @@ jinja2>=3,<4 +pypdf>=5,<6 \ No newline at end of file diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 1629962ba1333..7b367248861d1 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -2,6 +2,8 @@ from __future__ import annotations +import base64 +import io from abc import ABC, abstractmethod from collections.abc import Sequence from pathlib import Path @@ -22,6 +24,7 @@ SkipValidation, model_validator, ) +from pypdf import PdfReader from langchain_core._api import deprecated from langchain_core.load import Serializable @@ -48,6 +51,22 @@ from langchain_core.utils.interactive_env import is_interactive_env +# extract pdf into bytes +def extract_pdf_text(pdf_data: bytes) -> str: + # Decode the base64 back into bytes + pdf_bytes = base64.b64decode(pdf_data) + pdf_text = "" + + # Read the PDF and extract text + with io.BytesIO(pdf_bytes) as pdf_file: + reader = PdfReader(pdf_file) + for page in reader.pages: + extracted_text = page.extract_text() + pdf_text += extracted_text + "\n" + + return pdf_text + + class BaseMessagePromptTemplate(Serializable, ABC): """Base class for message prompt templates.""" @@ -468,6 +487,10 @@ class _ImageTemplateParam(TypedDict, total=False): image_url: Union[str, dict] +class _PdfTemplateParam(TypedDict, total=False): + pdf: Union[str, dict] + + class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate): """Human message prompt template. This is a message sent from the user.""" @@ -571,6 +594,12 @@ def from_template( msg = f"Invalid image template: {tmpl}" raise ValueError(msg) prompt.append(img_template_obj) + elif isinstance(tmpl, dict) and "data" in tmpl: + if tmpl.get("mime_type") == "application/pdf": + pdf_template = cast(_PdfTemplateParam, tmpl)["data"] + pdf_text = extract_pdf_text(pdf_template) + + prompt.append(PromptTemplate.from_template(pdf_text)) else: msg = f"Invalid template: {tmpl}" raise ValueError(msg) diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index 6249aa6f47893..7861d99bc7bbb 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -37,6 +37,36 @@ from tests.unit_tests.pydantic_utils import _normalize_schema +def test_create_pdf_chat_prompt() -> None: + """Test chat prompt with pdf data as bytes.""" + file_path = ( + Path(__file__).parent.parent.parent.parent.parent + / "community/tests/examples/hello.pdf" + ) + + with open(file_path, "rb") as file: + file_data = file.read() + + pdf_data = base64.b64encode(file_data).decode("utf-8") + prompt = ChatPromptTemplate( + [ + ( + "human", + [ + {"type": "media", "mime_type": "application/pdf", "data": pdf_data}, + ], + ) + ] + ) + + expected_prompt = PromptTemplate(template="Hello world!\n1\n") + + assert len(prompt.messages) == 1 + output_prompt = prompt.messages[0] + assert isinstance(output_prompt, HumanMessagePromptTemplate) + assert output_prompt.prompt == [expected_prompt] + + @pytest.fixture def messages() -> list[BaseMessagePromptTemplate]: """Create messages."""