diff --git a/cookbook/docugami_xml_kg_rag.ipynb b/cookbook/docugami_xml_kg_rag.ipynb index 85c53190a5387..43383a47493f6 100644 --- a/cookbook/docugami_xml_kg_rag.ipynb +++ b/cookbook/docugami_xml_kg_rag.ipynb @@ -34,12 +34,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "5740fc70-c513-4ff4-9d72-cfc098f85fef", "metadata": {}, "outputs": [], "source": [ - "! pip install langchain docugami==0.0.4 dgml-utils==0.2.0 pydantic langchainhub chromadb --upgrade --quiet" + "! pip install langchain docugami==0.0.8 dgml-utils==0.3.0 pydantic langchainhub chromadb hnswlib --upgrade --quiet" ] }, { @@ -76,98 +76,7 @@ }, { "cell_type": "code", - "execution_count": 45, - "id": "fc0767d4-9155-4591-855c-ef2e14e0e10f", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "from pathlib import Path\n", - "from pprint import pprint\n", - "from time import sleep\n", - "from typing import Dict, List\n", - "\n", - "import requests\n", - "from docugami import Docugami\n", - "from docugami.types import Document as DocugamiDocument\n", - "\n", - "api_key = os.environ.get(\"DOCUGAMI_API_KEY\")\n", - "if not api_key:\n", - " raise Exception(\"Please set Docugami API key environment variable\")\n", - "\n", - "client = Docugami()\n", - "\n", - "\n", - "def upload_files(local_paths: List[str], docset_name: str) -> List[DocugamiDocument]:\n", - " docset_list_response = client.docsets.list(name=docset_name)\n", - " if docset_list_response and docset_list_response.docsets:\n", - " # Docset already exists with this name\n", - " docset_id = docset_list_response.docsets[0]\n", - " else:\n", - " dg_docset = client.docsets.create(name=docset_name)\n", - " docset_id = dg_docset.id\n", - "\n", - " document_list_response = client.documents.list(limit=int(1e5))\n", - " dg_docs: List[DocugamiDocument] = []\n", - " if document_list_response and document_list_response.documents:\n", - " new_names = [Path(f).name for f in local_paths]\n", - "\n", - " dg_docs = [\n", - " d\n", - " for d in document_list_response.documents\n", - " if Path(d.name).name in new_names\n", - " ]\n", - " existing_names = [Path(d.name).name for d in dg_docs]\n", - "\n", - " # Upload any files not previously uploaded\n", - " for f in local_paths:\n", - " if Path(f).name not in existing_names:\n", - " dg_docs.append(\n", - " client.documents.contents.upload(\n", - " file=Path(f).absolute(),\n", - " docset_id=docset_id,\n", - " )\n", - " )\n", - " return dg_docs\n", - "\n", - "\n", - "def wait_for_xml(dg_docs: List[DocugamiDocument]) -> dict[str, str]:\n", - " dgml_paths: dict[str, str] = {}\n", - " while len(dgml_paths) < len(dg_docs):\n", - " for doc in dg_docs:\n", - " doc = client.documents.retrieve(doc.id) # update with latest\n", - " current_status = doc.status\n", - " if current_status == \"Error\":\n", - " raise Exception(\n", - " \"Document could not be processed, please confirm it is not a zero length, corrupt or password protected file\"\n", - " )\n", - " elif current_status == \"Ready\":\n", - " dgml_url = doc.docset.url + f\"/documents/{doc.id}/dgml\"\n", - " headers = {\"Authorization\": f\"Bearer {api_key}\"}\n", - " dgml_response = requests.get(dgml_url, headers=headers)\n", - " if not dgml_response.ok:\n", - " raise Exception(\n", - " f\"Could not download DGML artifact {dgml_url}: {dgml_response.status_code}\"\n", - " )\n", - " dgml_contents = dgml_response.text\n", - " with tempfile.NamedTemporaryFile(delete=False, mode=\"w\") as temp_file:\n", - " temp_file.write(dgml_contents)\n", - " temp_file_path = temp_file.name\n", - " dgml_paths[doc.name] = temp_file_path\n", - "\n", - " print(f\"{len(dgml_paths)} docs done processing out of {len(dg_docs)}...\")\n", - "\n", - " if len(dgml_paths) == len(dg_docs):\n", - " # done\n", - " return dgml_paths\n", - " else:\n", - " sleep(30) # try again in a bit" - ] - }, - { - "cell_type": "code", - "execution_count": 46, + "execution_count": 3, "id": "ce0b2b21-7623-46e7-ae2c-3a9f67e8b9b9", "metadata": {}, "outputs": [ @@ -175,18 +84,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "6 docs done processing out of 6...\n", - "{'Report_CEN23LA277_192541.pdf': '/var/folders/0h/6cchx4k528bdj8cfcsdm0dqr0000gn/T/tmpel3o0rpg',\n", - " 'Report_CEN23LA338_192753.pdf': '/var/folders/0h/6cchx4k528bdj8cfcsdm0dqr0000gn/T/tmpgugb9ut1',\n", - " 'Report_CEN23LA363_192876.pdf': '/var/folders/0h/6cchx4k528bdj8cfcsdm0dqr0000gn/T/tmp3_gf2sky',\n", - " 'Report_CEN23LA394_192995.pdf': '/var/folders/0h/6cchx4k528bdj8cfcsdm0dqr0000gn/T/tmpwmfgoxkl',\n", - " 'Report_ERA23LA114_106615.pdf': '/var/folders/0h/6cchx4k528bdj8cfcsdm0dqr0000gn/T/tmptibrz2yu',\n", - " 'Report_WPR23LA254_192532.pdf': '/var/folders/0h/6cchx4k528bdj8cfcsdm0dqr0000gn/T/tmpvazrbbsi'}\n" + "{'Report_CEN23LA277_192541.pdf': '/tmp/tmpa0c77x46',\n", + " 'Report_CEN23LA338_192753.pdf': '/tmp/tmpaftfld2w',\n", + " 'Report_CEN23LA363_192876.pdf': '/tmp/tmpn7gp6be2',\n", + " 'Report_CEN23LA394_192995.pdf': '/tmp/tmp9udymprf',\n", + " 'Report_ERA23LA114_106615.pdf': '/tmp/tmpxdjbh4r_',\n", + " 'Report_WPR23LA254_192532.pdf': '/tmp/tmpz6h75a0h'}\n" ] } ], "source": [ - "#### START DOCSET INFO (please change)\n", + "from pprint import pprint\n", + "\n", + "from docugami import Docugami\n", + "from docugami.lib.upload import upload_to_named_docset, wait_for_dgml\n", + "\n", + "#### START DOCSET INFO (please change this values as needed)\n", "DOCSET_NAME = \"NTSB Aviation Incident Reports\"\n", "FILE_PATHS = [\n", " \"/Users/tjaffri/ntsb/Report_CEN23LA277_192541.pdf\",\n", @@ -197,13 +110,15 @@ " \"/Users/tjaffri/ntsb/Report_WPR23LA254_192532.pdf\",\n", "]\n", "\n", - "assert (\n", - " len(FILE_PATHS) > 5\n", - ") # Please specify ~6 (or more!) similar files to process together as a document set\n", + "# Note: Please specify ~6 (or more!) similar files to process together as a document set\n", + "# This is currently a requirement for Docugami to automatically detect motifs\n", + "# across the document set to generate a semantic XML Knowledge Graph.\n", + "assert len(FILE_PATHS) > 5, \"Please provide at least 6 files\"\n", "#### END DOCSET INFO\n", "\n", - "dg_docs = upload_files(FILE_PATHS, DOCSET_NAME)\n", - "dgml_paths = wait_for_xml(dg_docs)\n", + "dg_client = Docugami()\n", + "dg_docs = upload_to_named_docset(dg_client, FILE_PATHS, DOCSET_NAME)\n", + "dgml_paths = wait_for_dgml(dg_client, dg_docs)\n", "\n", "pprint(dgml_paths)" ] @@ -228,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 4, "id": "05fcdd57-090f-44bf-a1fb-2c3609c80e34", "metadata": {}, "outputs": [ @@ -237,13 +152,13 @@ "output_type": "stream", "text": [ "found 30 chunks, here are the first few\n", - "Aviation Investigation Final Report\n", - "
Location: Elbert, Colorado Accident Number: CEN23LA277
Date & Time: June 26, 2023, 11:00 Local Registration: N23161
Aircraft: Piper J3C-50 Aircraft Damage: Substantial
Defining Event: Nose over/nose down Injuries: 1 Minor
Flight Conducted Under: Part 91: General aviation - Personal
\n", + "Aviation Investigation Final Report\n", + "
Location: Elbert, Colorado Accident Number: CEN23LA277
Date & Time: June 26, 2023, 11:00 Local Registration: N23161
Aircraft: Piper J3C-50 Aircraft Damage: Substantial
Defining Event: Nose over/nose down Injuries: 1 Minor
Flight Conducted Under: Part 91: General aviation - Personal
\n", "Analysis\n", - " The pilot reported that, as the tail lifted during takeoff, the airplane veered left. He attempted to correct with full right rudder and full brakes. However, the airplane subsequently nosed over resulting in substantial damage to the fuselage, lift struts, rudder, and vertical stabilizer. \n", + " The pilot reported that, as the tail lifted during takeoff, the airplane veered left. He attempted to correct with full right rudder and full brakes. However, the airplane subsequently nosed over resulting in substantial damage to the fuselage, lift struts, rudder, and vertical stabilizer. \n", " The pilot reported that there were no preaccident mechanical malfunctions or anomalies with the airplane that would have precluded normal operation. \n", " At about the time of the accident, wind was from 180° at 5 knots. The pilot decided to depart on runway 35 due to the prevailing airport traffic. He stated that departing with “more favorable wind conditions” may have prevented the accident. \n", - "Probable Cause and Findings\n", + "Probable Cause and Findings \n", " The National Transportation Safety Board determines the probable cause(s) of this accident to be: \n", " The pilot's loss of directional control during takeoff and subsequent excessive use of brakes which resulted in a nose-over. Contributing to the accident was his decision to takeoff downwind. \n", "Page 1 of 5 \n" @@ -251,6 +166,8 @@ } ], "source": [ + "from pathlib import Path\n", + "\n", "from dgml_utils.segmentation import get_chunks_str\n", "\n", "# Here we just read the first file, you can do the same for others\n", @@ -283,7 +200,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 5, "id": "8a4b49e0-de78-4790-a930-ad7cf324697a", "metadata": {}, "outputs": [ @@ -343,7 +260,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 6, "id": "7b697d30-1e94-47f0-87e8-f81d4b180da2", "metadata": {}, "outputs": [ @@ -353,12 +270,14 @@ "39" ] }, - "execution_count": 109, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "import requests\n", + "\n", "# Download XML from known URL\n", "dgml = requests.get(\n", " \"https://raw.githubusercontent.com/docugami/dgml-utils/main/python/tests/test_data/article/Jane%20Doe.xml\"\n", @@ -369,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 7, "id": "14714576-6e1d-499b-bcc8-39140bb2fd78", "metadata": {}, "outputs": [ @@ -379,7 +298,7 @@ "{'h1': 9, 'div': 12, 'p': 3, 'lim h1': 9, 'lim': 1, 'table': 1, 'h1 div': 4}" ] }, - "execution_count": 98, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -400,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 8, "id": "5462f29e-fd59-4e0e-9493-ea3b560e523e", "metadata": {}, "outputs": [ @@ -433,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 9, "id": "2b4ece00-2e43-4254-adc9-66dbb79139a6", "metadata": {}, "outputs": [ @@ -471,7 +390,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 10, "id": "08350119-aa22-4ec1-8f65-b1316a0d4123", "metadata": {}, "outputs": [ @@ -499,7 +418,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": 11, "id": "bcac8294-c54a-4b6e-af9d-3911a69620b2", "metadata": {}, "outputs": [ @@ -546,7 +465,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 12, "id": "8e275736-3408-4d7a-990e-4362c88e81f8", "metadata": {}, "outputs": [], @@ -577,7 +496,7 @@ }, { "cell_type": "code", - "execution_count": 114, + "execution_count": 13, "id": "1b12536a-1303-41ad-9948-4eb5a5f32614", "metadata": {}, "outputs": [], @@ -594,7 +513,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": 14, "id": "8d8b567c-b442-4bf0-b639-04bd89effc62", "metadata": {}, "outputs": [], @@ -619,7 +538,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 17, "id": "346c3a02-8fea-4f75-a69e-fc9542b99dbc", "metadata": {}, "outputs": [], @@ -681,7 +600,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 18, "id": "f2489de4-51e3-48b4-bbcd-ed9171deadf3", "metadata": {}, "outputs": [], @@ -725,10 +644,17 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 19, "id": "636e992f-823b-496b-a082-8b4fcd479de5", "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Number of requested results 4 is greater than number of elements in index 1, updating n_results = 1\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -770,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 20, "id": "0e4a2f43-dd48-4ae3-8e27-7e87d169965f", "metadata": {}, "outputs": [ @@ -780,7 +706,7 @@ "669" ] }, - "execution_count": 121, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -795,7 +721,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 21, "id": "56b78fb3-603d-4343-ae72-be54a3c5dd72", "metadata": {}, "outputs": [ @@ -820,7 +746,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 22, "id": "d3cc5ba9-8553-4eda-a5d1-b799751186af", "metadata": {}, "outputs": [], @@ -832,7 +758,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 23, "id": "d7c73faf-74cb-400d-8059-b69e2493de38", "metadata": {}, "outputs": [], @@ -844,7 +770,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 24, "id": "4c553722-be42-42ce-83b8-76a17f323f1c", "metadata": {}, "outputs": [], @@ -854,7 +780,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 25, "id": "65dce40b-f1c3-494a-949e-69a9c9544ddb", "metadata": {}, "outputs": [ @@ -864,7 +790,7 @@ "'The number of training tokens for LLaMA2 is 2.0T for all parameter sizes.'" ] }, - "execution_count": 128, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -959,14 +885,37 @@ " \n", " \n", "\n", - "``" + "```" ] }, { "cell_type": "markdown", - "id": "0879349e-7298-4f2c-b246-f1142e97a8e5", + "id": "867f8e11-384c-4aa1-8b3e-c59fb8d5fd7d", + "metadata": {}, + "source": [ + "Finally, you can ask other questions that rely on more subtle parsing of the table, e.g.:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d38f1459-7d2b-40df-8dcd-e747f85eb144", "metadata": {}, - "source": [] + "outputs": [ + { + "data": { + "text/plain": [ + "'The learning rate for LLaMA2 was 3.0 × 10−4 for the 7B and 13B models, and 1.5 × 10−4 for the 34B and 70B models.'" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llama2_chain.invoke(\"What was the learning rate for LLaMA2?\")" + ] } ], "metadata": {