Commit

Summary Snack on 4k Context (#18)
* Ensure model receives inputs less than 4k tokens

---------

Signed-off-by: Fayvor Love <[email protected]>
fayvor authored and adampingel committed Oct 18, 2024
1 parent fdeeaf2 commit cf68a20
Showing 1 changed file with 89 additions and 52 deletions.
141 changes: 89 additions & 52 deletions recipes/Summarize/Summarize.ipynb
@@ -44,7 +44,7 @@
"source": [
"## Select your model\n",
"\n",
"Select a Granite Code model from the [`ibm-granite`](https://replicate.com/ibm-granite) org on Replicate. Here we use the Replicate Langchain client to connect to the model.\n",
"Select a Granite model from the [`ibm-granite`](https://replicate.com/ibm-granite) org on Replicate. Here we use the Replicate Langchain client to connect to the model.\n",
"\n",
"To get set up with Replicate, see [Getting Started with Replicate](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Getting_Started/Getting_Started_with_Replicate.ipynb).\n",
"\n",
@@ -145,7 +145,7 @@
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"print(\"Your model uses the tokenizer \" + type(tokenizer).__name__)\n",
"\n",
"print(f\"Your document has {len(tokenizer(contents, return_tensors='pt')['input_ids'][0])} tokens. \")"
"print(f\"Your document has {len(tokenizer.tokenize(contents))} tokens. \")"
]
},
{
@@ -168,7 +168,7 @@
"outputs": [],
"source": [
"prompt = f\"\"\"\n",
"Summarize the following text from \"Walden\" by Henry David Thoreau:\n",
"Summarize the following text:\n",
"{contents}\n",
"\"\"\"\n",
"\n",
@@ -193,7 +193,7 @@
"source": [
"## Summary of Summaries\n",
"\n",
"Here we use an iterative summarization technique to adapt to the context length of the model."
"Here we use a hierarchical abstractive summarization technique to adapt to the context length of the model. Our approach is naïve, in that it takes equal-width chunks and groups of chunks from the document. A more sophisticated approach would be to create a document hierarchical structure that accounts for the document's structure and features, and groups text passages by topic or section. "
]
},
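The overall flow (summarize each chunk, then summarize groups of summaries, repeating until one summary remains) can be sketched as follows; the `summarize_text` helper here is a hypothetical stand-in for the model call developed below:

```python
# A minimal sketch of hierarchical abstractive summarization.
# summarize_text(text) -> str is a hypothetical helper wrapping the model call.
def hierarchical_summary(chunks, group_size=4):
    # Level 1: summarize each chunk independently.
    level = [summarize_text(chunk) for chunk in chunks]
    # Higher levels: join groups of summaries and summarize each group,
    # until a single summary remains.
    while len(level) > 1:
        groups = [level[i:i + group_size] for i in range(0, len(level), group_size)]
        level = [summarize_text("\n\n".join(group)) for group in groups]
    return level[0]
```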
{
@@ -202,7 +202,9 @@
"source": [
"### Chunk the text\n",
"\n",
"Divide the full text into smaller passages for separate processing."
"Divide the full text into smaller passages for separate processing. The `chunk_size` (given in tokens) must account for the size of both the messages (input) and the completions (output). The resulting chunk size may sometimes exceed the `chunk_size` provided, so we give it additional headroom. \n",
"\n",
"The `chunk_overlap` parameter allows us to overlap chunks by a certain number of tokens, to help preserve coherence between chunks."
]
},
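Before running the splitter, it can help to sanity-check the token budget. A rough check, using the figures from the code cell below (the ~350-token padding allowance for the prompt template and special tokens is an estimate):

```python
# Rough token-budget check for a 4k-context model.
# Figures mirror the comment in the chunking cell; padding is an estimate.
context_length = 4096
chunk_token_limit = 3000   # input chunk
completion_tokens = 512    # generated summary
prompt_padding = 350       # prompt template, system prompt, special tokens

assert chunk_token_limit + completion_tokens + prompt_padding <= context_length
```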
{
@@ -215,14 +217,16 @@
"from langchain.docstore.document import Document\n",
"\n",
"excerpt_length = 20000\n",
"doc = Document(page_content=book_contents[:excerpt_length], metadata={\"source\": \"local\"})\n",
"print(f\"The text is {len(doc.page_content)} chars\")\n",
"text = book_contents # [:excerpt_length]\n",
"print(f\"The text is {len(tokenizer.tokenize(text))} tokens.\")\n",
"\n",
"# Split the documents into chunks\n",
"chunk_char_limit = 1000\n",
"text_splitter = TokenTextSplitter.from_huggingface_tokenizer(tokenizer=tokenizer, chunk_size=chunk_char_limit, chunk_overlap=50)\n",
"chunks = text_splitter.split_documents([doc])\n",
"print(\"Chunk count: \" + str(len(chunks)))"
"chunk_token_limit = 3000 # In tokens: 3000 message + 512 completion + ~350 padding < 4000 context length\n",
"text_splitter = TokenTextSplitter.from_huggingface_tokenizer(tokenizer=tokenizer, chunk_size=chunk_token_limit, chunk_overlap=0)\n",
"chunks = text_splitter.split_text(text)\n",
"\n",
"print(\"Chunk count: \" + str(len(chunks)))\n",
"print(\"Max chunk length: \" + str(max([len(tokenizer.tokenize(chunk)) for chunk in chunks])))"
]
},
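To see what `chunk_overlap` does in practice, a small illustrative experiment (this notebook's cell uses `chunk_overlap=0`, so the values below are hypothetical):

```python
# Illustrative: with a nonzero overlap, consecutive chunks share tokens.
overlap_splitter = TokenTextSplitter.from_huggingface_tokenizer(
    tokenizer=tokenizer, chunk_size=100, chunk_overlap=20)
demo_chunks = overlap_splitter.split_text(text)
# The last tokens of chunk i reappear at the start of chunk i+1.
print(demo_chunks[0][-80:])
print(demo_chunks[1][:80])
```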
{
@@ -240,29 +244,40 @@
"metadata": {},
"outputs": [],
"source": [
"summaries = []\n",
"\n",
"for i, chunk in enumerate(chunks):\n",
" prompt = f\"\"\"\n",
" Summarize the following text from \"Walden\" by Henry David Thoreau:\n",
" {chunk}\n",
" \"\"\"\n",
" output = model.invoke(\n",
" prompt,\n",
" model_kwargs={\n",
" \"max_tokens\": 10000, # Set the maximum number of tokens to generate as output.\n",
" \"min_tokens\": 200, # Set the minimum number of tokens to generate as output.\n",
" \"temperature\": 0.75,\n",
" \"system_prompt\": \"You are a helpful assistant.\",\n",
" \"presence_penalty\": 0,\n",
" \"frequency_penalty\": 0\n",
" }\n",
" )\n",
" summary = f\"Summary {i+1}:\\n{output}\\n\\n\"\n",
" summaries.append(summary)\n",
" print(summary)\n",
"\n",
"print(\"Summary count: \" + str(len(summaries)))\n"
"def summarize(texts, prompt_template, min, max):\n",
" summaries = []\n",
" for i, text in enumerate(texts):\n",
" print(f\"{i + 1}. Input size: {len(tokenizer.tokenize(text))} tokens\")\n",
" prompt = prompt_template.format(text=text)\n",
" output = model.invoke(\n",
" prompt,\n",
" model_kwargs={\n",
" \"max_tokens\": 2000, # Set the maximum number of tokens to generate as output.\n",
" \"min_tokens\": 200, # Set the minimum number of tokens to generate as output.\n",
" \"temperature\": 0.75,\n",
" \"system_prompt\": \"You are a helpful assistant.\",\n",
" \"presence_penalty\": 0,\n",
" \"frequency_penalty\": 0\n",
" }\n",
" )\n",
" print(f\"{i + 1}. Output size: {len(tokenizer.tokenize(output))} tokens\")\n",
" summary = f\"Summary {i+1}:\\n{output}\\n\\n\"\n",
" summaries.append(summary)\n",
" print(summary)\n",
"\n",
" print(\"Summary count: \" + str(len(summaries)))\n",
" summary_contents = \"\\n\\n\".join(summaries)\n",
" print(f\"Total: {len(tokenizer.tokenize(summary_contents))} tokens\")\n",
"\n",
" return summaries\n",
"\n",
"\n",
"prompt = \"\"\"\n",
" Summarize the following text using only the information found in the text:\n",
" {text}\n",
" \"\"\"\n",
"\n",
"summaries_lvl_1 = summarize(chunks, prompt, 200, 2000)\n"
]
},
{
@@ -280,30 +295,52 @@
"metadata": {},
"outputs": [],
"source": [
"summary_contents = \"\\n\\n\".join(summaries)\n",
"print(len(summary_contents))\n",
"def group_array(arr, n):\n",
" # Calculate the size of each chunk\n",
" avg_len = len(arr) // n\n",
" remainder = len(arr) % n\n",
" result = []\n",
" start = 0\n",
"\n",
"prompt = f\"\"\"\n",
"The text of \"Walden\", by Henry David Thoreau, was summarized in separate passages; those passage summaries are provided below. \n",
" for i in range(n):\n",
" # Distribute the remainder elements across the first chunks\n",
" end = start + avg_len + (1 if i < remainder else 0)\n",
" result.append(arr[start:end])\n",
" start = end\n",
"\n",
" return result\n",
"\n",
"summary_groups = group_array(summaries_lvl_1, 4)\n",
"texts_lvl_2 = [\"\\n\\n\".join(summary_group) for summary_group in summary_groups]\n",
"\n",
"prompt = \"\"\"\n",
"A text was summarized in separate passages; those passage summaries are provided below. \n",
"\n",
"{summary_contents}\n",
"{text}\n",
"\n",
"From these summaries, compose a single lengthy, unified summary of the original text.\n",
"From these summaries alone, compose a single, unified summary of the text.\n",
"\"\"\"\n",
"\n",
"output = model.invoke(\n",
" prompt,\n",
" model_kwargs={\n",
" \"max_tokens\": 100000, # Set the maximum number of tokens to generate as output.\n",
" \"min_tokens\": 5000, # Set the minimum number of tokens to generate as output.\n",
" \"temperature\": 0.75,\n",
" \"system_prompt\": \"You are a helpful assistant.\",\n",
" \"presence_penalty\": 0,\n",
" \"frequency_penalty\": 0\n",
" }\n",
" )\n",
"summaries_lvl_2 = summarize(texts_lvl_2, prompt, 500, 1000)"
]
},
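For intuition, `group_array` splits a list into `n` nearly equal groups, spreading any remainder across the first groups. A quick illustrative check (values are hypothetical):

```python
# 10 items into 4 groups: sizes [3, 3, 2, 2]; the remainder of 2
# is spread over the first two groups.
print([len(g) for g in group_array(list(range(10)), 4)])
```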
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create the Final Summary\n",
"\n",
"print(output)"
"Generate a single summary from the passage summaries generated above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"texts_lvl_3 = [\"\\n\\n\".join(summaries_lvl_2)]\n",
"final_summary = summarize(texts_lvl_3, prompt, 500, 1000)[0]"
]
}
],