Skip to content

Commit

Permalink
core: Add support to accept PDF in ChatPromptTemplate.
Browse files Browse the repository at this point in the history
>
>
Co-authored-by: czhu24 <[email protected]>
  • Loading branch information
rkwan05 committed Nov 9, 2024
1 parent b509747 commit 1279507
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions libs/core/extended_testing_deps.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
jinja2>=3,<4
pypdf>=5,<6
29 changes: 29 additions & 0 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import base64
import io
from abc import ABC, abstractmethod
from collections.abc import Sequence
from pathlib import Path
Expand All @@ -22,6 +24,7 @@
SkipValidation,
model_validator,
)
from pypdf import PdfReader

from langchain_core._api import deprecated
from langchain_core.load import Serializable
Expand All @@ -48,6 +51,22 @@
from langchain_core.utils.interactive_env import is_interactive_env


# extract pdf into bytes
def extract_pdf_text(pdf_data: bytes) -> str:
# Decode the base64 back into bytes
pdf_bytes = base64.b64decode(pdf_data)
pdf_text = ""

# Read the PDF and extract text
with io.BytesIO(pdf_bytes) as pdf_file:
reader = PdfReader(pdf_file)
for page in reader.pages:
extracted_text = page.extract_text()
pdf_text += extracted_text + "\n"

return pdf_text


class BaseMessagePromptTemplate(Serializable, ABC):
"""Base class for message prompt templates."""

Expand Down Expand Up @@ -468,6 +487,10 @@ class _ImageTemplateParam(TypedDict, total=False):
image_url: Union[str, dict]


class _PdfTemplateParam(TypedDict, total=False):
pdf: Union[str, dict]


class _StringImageMessagePromptTemplate(BaseMessagePromptTemplate):
"""Human message prompt template. This is a message sent from the user."""

Expand Down Expand Up @@ -571,6 +594,12 @@ def from_template(
msg = f"Invalid image template: {tmpl}"
raise ValueError(msg)
prompt.append(img_template_obj)
elif isinstance(tmpl, dict) and "data" in tmpl:
if tmpl.get("mime_type") == "application/pdf":
pdf_template = cast(_PdfTemplateParam, tmpl)["data"]
pdf_text = extract_pdf_text(pdf_template)

prompt.append(PromptTemplate.from_template(pdf_text))
else:
msg = f"Invalid template: {tmpl}"
raise ValueError(msg)
Expand Down
30 changes: 30 additions & 0 deletions libs/core/tests/unit_tests/prompts/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,36 @@
from tests.unit_tests.pydantic_utils import _normalize_schema


def test_create_pdf_chat_prompt() -> None:
"""Test chat prompt with pdf data as bytes."""
file_path = (
Path(__file__).parent.parent.parent.parent.parent
/ "community/tests/examples/hello.pdf"
)

with open(file_path, "rb") as file:
file_data = file.read()

pdf_data = base64.b64encode(file_data).decode("utf-8")
prompt = ChatPromptTemplate(
[
(
"human",
[
{"type": "media", "mime_type": "application/pdf", "data": pdf_data},
],
)
]
)

expected_prompt = PromptTemplate(template="Hello world!\n1\n")

assert len(prompt.messages) == 1
output_prompt = prompt.messages[0]
assert isinstance(output_prompt, HumanMessagePromptTemplate)
assert output_prompt.prompt == [expected_prompt]


@pytest.fixture
def messages() -> list[BaseMessagePromptTemplate]:
"""Create messages."""
Expand Down

0 comments on commit 1279507

Please sign in to comment.