Use image prompt template #106

Closed
wants to merge 2 commits
@@ -105,7 +105,7 @@
"</table>"
],
"text/plain": [
"Registry(tasks=[RetrievalTask(name='LangChain Docs Q&A', dataset_id='https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d', description=\"Questions and answers based on a snapshot of the LangChain python docs.\\n\\nThe environment provides the documents and the retriever information.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_cached_docs at 0x1066098a0>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x1266289a0>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x126628a40>, 'hyde': <function _chroma_hyde_retriever_factory at 0x126628ae0>}, architecture_factories={'conversational-retrieval-qa': <function default_response_chain at 0x10fea2660>}), RetrievalTask(name='Semi-structured Reports', dataset_id='https://smith.langchain.com/public/c47d9617-ab99-4d6e-a6e6-92b8daf85a7d/d', description=\"Questions and answers based on PDFs containing tables and charts.\\n\\nThe task provides the raw documents as well as factory methods to easily index them\\nand create a retriever.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_docs at 0x126629620>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x1266296c0>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x126629760>, 'hyde': <function _chroma_hyde_retriever_factory at 0x126629800>}, architecture_factories={}), RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})])"
"Registry(tasks=[RetrievalTask(name='LangChain Docs Q&A', dataset_id='https://smith.langchain.com/public/452ccafc-18e1-4314-885b-edd735f17b9d/d', description=\"Questions and answers based on a snapshot of the LangChain python docs.\\n\\nThe environment provides the documents and the retriever information.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_cached_docs at 0x128811430>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x12a852550>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x12a8525e0>, 'hyde': <function _chroma_hyde_retriever_factory at 0x12a852670>}, architecture_factories={'conversational-retrieval-qa': <function default_response_chain at 0x1288115e0>}), RetrievalTask(name='Semi-structured Reports', dataset_id='https://smith.langchain.com/public/c47d9617-ab99-4d6e-a6e6-92b8daf85a7d/d', description=\"Questions and answers based on PDFs containing tables and charts.\\n\\nThe task provides the raw documents as well as factory methods to easily index them\\nand create a retriever.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\nWe also measure the faithfulness of the model's response relative to the retrieved documents (if any).\\n\", get_docs=<function load_docs at 0x12a852b80>, retriever_factories={'basic': <function _chroma_retriever_factory at 0x12a852c10>, 'parent-doc': <function _chroma_parent_document_retriever_factory at 0x12a852ca0>, 'hyde': <function _chroma_hyde_retriever_factory at 0x12a852d30>}, architecture_factories={}), RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})])"
]
},
"execution_count": 1,
@@ -130,7 +130,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "02ba8543-2c77-4b92-ae32-651b1699d0df",
"metadata": {
"tags": []
@@ -161,7 +161,7 @@
"RetrievalTask(name='Multi-modal slide decks', dataset_id='https://smith.langchain.com/public/40afc8e7-9d7e-44ed-8971-2cae1eb59731/d', description='This public dataset is a work-in-progress and will be extended over time.\\n \\nQuestions and answers based on slide decks containing visual tables and charts.\\n\\nEach example is composed of a question and reference answer.\\n\\nSuccess is measured based on the accuracy of the answer relative to the reference answer.\\n', get_docs={}, retriever_factories={}, architecture_factories={})"
]
},
"execution_count": 3,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
@@ -225,7 +225,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"id": "f0d56161-3672-4c23-9653-deefb3e340a2",
"metadata": {
"tags": []
@@ -249,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"id": "119c6cb3-71e0-414e-b8ae-275c6b07cbef",
"metadata": {
"tags": []
@@ -319,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"id": "a28a9c74-835a-43b2-be25-fc26f1daca0f",
"metadata": {
"tags": []
@@ -447,7 +447,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 5,
"id": "288a3818-b7f2-430c-b624-0c7d5317dc52",
"metadata": {
"tags": []
@@ -501,42 +501,53 @@
},
{
"cell_type": "code",
"execution_count": 10,
"id": "72dca949-a255-4543-a290-f898ffb50962",
"metadata": {
"tags": []
},
"execution_count": 9,
"id": "1974c726-820e-4109-b5ef-7cf69eefd2ff",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage\n",
"from langchain_core.prompts.chat import ChatPromptTemplate\n",
"\n",
"\n",
"def image_summarize(img_base64, prompt):\n",
"def image_summarize(img_base64):\n",
" \"\"\"\n",
" Make image summary\n",
"\n",
" :param img_base64: Base64 encoded string for image\n",
" :param prompt: Text prompt for summarizatiomn\n",
" :return: Image summarization prompt\n",
"\n",
" \"\"\"\n",
" chat = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
"\n",
" msg = chat.invoke(\n",
" # Prompt\n",
" prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n",
" These summaries will be embedded and used to retrieve the raw image. \\\n",
" Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n",
"\n",
" summarization_prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": prompt},\n",
" (\n",
" \"system\",\n",
" \"You are an analyst tasked with summarizing images. \\n\"\n",
" \"You will be give an image to summarize.\\n\",\n",
" ),\n",
" (\n",
" \"human\",\n",
" [\n",
" {\"type\": \"text\", \"text\": \"{prompt}\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n",
" \"image_url\": \"data:image/jpeg;base64,{img}\",\n",
" },\n",
" ]\n",
" )\n",
" ],\n",
" ),\n",
" ]\n",
" )\n",
" return msg.content\n",
"\n",
" llm = ChatOpenAI(model=\"gpt-4-vision-preview\", max_tokens=1024)\n",
" chain = summarization_prompt | llm\n",
" summary = chain.invoke({\"prompt\": prompt, \"img\": img_base64})\n",
" return summary\n",
"\n",
"\n",
"def generate_img_summaries(img_base64_list):\n",
@@ -551,15 +562,10 @@
" image_summaries = []\n",
" processed_images = []\n",
"\n",
" # Prompt\n",
" prompt = \"\"\"You are an assistant tasked with summarizing images for retrieval. \\\n",
" These summaries will be embedded and used to retrieve the raw image. \\\n",
" Give a concise summary of the image that is well optimized for retrieval.\"\"\"\n",
"\n",
" # Apply summarization to images\n",
" for i, base64_image in enumerate(img_base64_list):\n",
" try:\n",
" image_summaries.append(image_summarize(base64_image, prompt))\n",
" image_summaries.append(image_summarize(base64_image))\n",
" processed_images.append(base64_image)\n",
" except:\n",
" print(f\"BadRequestError with image {i+1}\")\n",
@@ -568,7 +574,7 @@
"\n",
"\n",
"# Image summaries\n",
"image_summaries, images_base_64_processed = generate_img_summaries(images_base_64)"
"# image_summaries, images_base_64_processed = generate_img_summaries(images_base_64[0:1])"
]
},
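For reference outside the notebook, here is a minimal standalone sketch of the prompt-template pattern this cell adopts. The model name matches the cell above, while encode_image and the slide_1.jpg file name are hypothetical and not part of the PR:

import base64

from langchain.chat_models import ChatOpenAI
from langchain_core.prompts.chat import ChatPromptTemplate


def encode_image(path):
    # Hypothetical helper for this sketch; the notebook builds its
    # images_base_64 list from the slide deck in an earlier cell.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


# Mirrors the pattern introduced above: {prompt} and {img} are filled in at
# invoke time, and the data URL is resolved by the image prompt template
# rather than by an f-string inside the message.
summarization_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are an analyst tasked with summarizing images."),
        (
            "human",
            [
                {"type": "text", "text": "{prompt}"},
                {"type": "image_url", "image_url": "data:image/jpeg;base64,{img}"},
            ],
        ),
    ]
)

chain = summarization_prompt | ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1024)
summary = chain.invoke(
    {
        "prompt": "Give a concise summary of the image, optimized for retrieval.",
        "img": encode_image("slide_1.jpg"),  # hypothetical file name
    }
)
print(summary.content)

Keeping the base64 payload in the invoke dict, rather than in an f-string inside the message, is what lets the same template be reused for every image in generate_img_summaries.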
{
@@ -663,7 +669,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 36,
"id": "b032d4e0-63d5-4e7f-bafc-e0550b5c2db0",
"metadata": {
"tags": []
@@ -696,25 +702,41 @@
" :param num_images: Number of images to include in the prompt.\n",
" :return: A list containing message objects for each image and the text prompt.\n",
" \"\"\"\n",
" messages = []\n",
" if data_dict[\"context\"][\"images\"]:\n",
" for image in data_dict[\"context\"][\"images\"][:num_images]:\n",
" image_message = {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{image}\"},\n",
" }\n",
" messages.append(image_message)\n",
" text_message = {\n",
" \"type\": \"text\",\n",
" \"text\": (\n",
" \"You are an analyst tasked with answering questions about visual content.\\n\"\n",
" \"You will be give a set of image(s) from a slide deck / presentation.\\n\"\n",
" \"Use this information to answer the user question. \\n\"\n",
" f\"User-provided question: {data_dict['question']}\\n\\n\"\n",
"\n",
" # Base template\n",
" template_messages = [\n",
" (\n",
" \"system\",\n",
" \"You are an analyst tasked with answering questions about visual content. \\n\"\n",
" \"You will be given a set of image(s) from a slide deck / presentation.\\n\",\n",
" ),\n",
" (\n",
" \"human\",\n",
" [\n",
" {\n",
" \"type\": \"text\",\n",
" \"text\": \"Answer the question using the images. Question: {question}\",\n",
" }\n",
" ],\n",
" ),\n",
" }\n",
" messages.append(text_message)\n",
" return [HumanMessage(content=messages)]\n",
" ]\n",
"\n",
" # Add images\n",
" images = data_dict[\"context\"][\"images\"]\n",
" for i in range(min(num_images, len(images))):\n",
" image_message = {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": f\"data:image/jpeg;base64,{images[i]}\"},\n",
" }\n",
" template_messages[1][1].append(image_message)\n",
"\n",
" # Format\n",
" rag_prompt = ChatPromptTemplate.from_messages(template_messages)\n",
" rag_prompt_formatted = rag_prompt.format_messages(\n",
" question=data_dict[\"question\"],\n",
" )\n",
"\n",
" return rag_prompt_formatted\n",
"\n",
"\n",
"def multi_modal_rag_chain(retriever):\n",
@@ -744,6 +766,27 @@
"chain_multimodal_rag_mmembd = multi_modal_rag_chain(retriever_mmembd)"
]
},
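The body of multi_modal_rag_chain is collapsed in this diff, so the following is only a sketch of how the reworked prompt builder would typically be wired into a chain. The names img_prompt_func (standing in for the prompt-building function whose def line is collapsed above) and split_image_text_types are assumptions, not code from the PR:

from langchain.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough


def split_image_text_types(docs):
    # Assumed helper: treat each retrieved Document's page_content as a
    # base64-encoded slide image; the real notebook may also carry text.
    return {"images": [d.page_content for d in docs], "texts": []}


def multi_modal_rag_chain_sketch(retriever, img_prompt_func):
    # Retrieve -> sort into images/texts -> build the multimodal prompt ->
    # call the vision model -> return a plain string.
    model = ChatOpenAI(model="gpt-4-vision-preview", max_tokens=1024)
    return (
        {
            "context": retriever | RunnableLambda(split_image_text_types),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )

The key point is that the prompt builder returns fully formatted messages, so it can sit directly in the chain as a RunnableLambda between retrieval and the model.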
{
"cell_type": "code",
"execution_count": 37,
"id": "2b912808-ee59-4d56-b591-a4bd7a91b96e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"The Total Addressable Market (TAM) for Datadog's observability is projected to grow over time as follows:\\n\\n- In 2022, the TAM is $41 billion.\\n- In 2023, it increases to $45 billion.\\n- In 2024, the TAM is expected to be $51 billion.\\n- In 2025, it is projected to reach $56 billion.\\n- By 2026, the TAM is forecasted to grow to $62 billion.\\n\\nThese figures are presented in billions of dollars and indicate a steady growth in the TAM for Datadog's observability services.\""
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chain_multimodal_rag_mmembd.invoke(\"What is the TAM for Datadog over time?\")"
]
},
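A quick way to spot-check which slides ground that answer is to query the retriever directly. This is a sketch rather than part of the PR, and it assumes retriever_mmembd returns Documents whose page_content is a base64-encoded slide image:

question = "What is the TAM for Datadog over time?"
docs = retriever_mmembd.get_relevant_documents(question)
for i, doc in enumerate(docs):
    # Print only a short prefix; the payload is a long base64 string.
    print(i, doc.page_content[:60], "...")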
{
"cell_type": "markdown",
"id": "f7c5f379-317c-4a2e-9190-61f6dfbbc77d",