Summary Snack on 4k Context #18

Merged
merged 2 commits into from Oct 15, 2024

Changes from all commits
141 changes: 89 additions & 52 deletions recipes/Summarize/Summarize.ipynb
@@ -44,7 +44,7 @@
"source": [
"## Select your model\n",
"\n",
"Select a Granite Code model from the [`ibm-granite`](https://replicate.com/ibm-granite) org on Replicate. Here we use the Replicate Langchain client to connect to the model.\n",
"Select a Granite model from the [`ibm-granite`](https://replicate.com/ibm-granite) org on Replicate. Here we use the Replicate Langchain client to connect to the model.\n",
"\n",
"To get set up with Replicate, see [Getting Started with Replicate](https://github.com/ibm-granite-community/granite-kitchen/blob/main/recipes/Getting_Started/Getting_Started_with_Replicate.ipynb).\n",
"\n",
@@ -145,7 +145,7 @@
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"print(\"Your model uses the tokenizer \" + type(tokenizer).__name__)\n",
"\n",
"print(f\"Your document has {len(tokenizer(contents, return_tensors='pt')['input_ids'][0])} tokens. \")"
"print(f\"Your document has {len(tokenizer.tokenize(contents))} tokens. \")"
]
},
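A note on the token-counting change above: `tokenizer.tokenize()` returns the token strings directly, so there is no need to build a tensor just to take its length. A sketch of the two counting styles side by side, using the `tokenizer` and `contents` already defined in the notebook (the counts can differ by any special tokens the encoding form adds):

```python
# Old style: encode to tensors, then measure the input_ids sequence.
n_ids = len(tokenizer(contents, return_tensors="pt")["input_ids"][0])

# New style: tokenize to plain token strings and count them.
n_tok = len(tokenizer.tokenize(contents))

print(n_ids, n_tok)
```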
{
@@ -168,7 +168,7 @@
"outputs": [],
"source": [
"prompt = f\"\"\"\n",
"Summarize the following text from \"Walden\" by Henry David Thoreau:\n",
"Summarize the following text:\n",
"{contents}\n",
"\"\"\"\n",
"\n",
@@ -193,7 +193,7 @@
"source": [
"## Summary of Summaries\n",
"\n",
"Here we use an iterative summarization technique to adapt to the context length of the model."
"Here we use a hierarchical abstractive summarization technique to adapt to the context length of the model. Our approach is naïve, in that it takes equal-width chunks and groups of chunks from the document. A more sophisticated approach would be to create a document hierarchical structure that accounts for the document's structure and features, and groups text passages by topic or section. "
]
},
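The "more sophisticated approach" mentioned above is not implemented in this PR. As a sketch of one structure-aware variant, a recursive splitter can prefer paragraph and sentence boundaries before falling back to fixed windows (the splitter choice and settings here are assumptions, applied to the notebook's `tokenizer` and `book_contents`):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Prefer paragraph, then line, then sentence boundaries, so chunks tend
# to follow the document's own structure instead of arbitrary cut points.
structured_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer=tokenizer,
    chunk_size=3000,
    chunk_overlap=0,
    separators=["\n\n", "\n", ". ", " "],
)
structured_chunks = structured_splitter.split_text(book_contents)
```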
{
@@ -202,7 +202,9 @@
"source": [
"### Chunk the text\n",
"\n",
"Divide the full text into smaller passages for separate processing."
"Divide the full text into smaller passages for separate processing. The `chunk_size` (given in tokens) must account for the size of both the messages (input) and the completions (output). The resulting chunk size may sometimes exceed the `chunk_size` provided, so we give it additional headroom. \n",
"\n",
"The `chunk_overlap` parameter allows us to overlap chunks by a certain number of tokens, to help preserve coherence between chunks."
]
},
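To make the `chunk_overlap` parameter concrete (the new code below uses `chunk_overlap=0`; the values here are illustrative only):

```python
from langchain.text_splitter import TokenTextSplitter

# With chunk_overlap=50, roughly the last 50 tokens of each chunk are
# repeated at the start of the next, so text cut at a boundary still
# appears intact in one of the two neighboring chunks.
demo_splitter = TokenTextSplitter.from_huggingface_tokenizer(
    tokenizer=tokenizer, chunk_size=200, chunk_overlap=50)
demo_chunks = demo_splitter.split_text(book_contents)
```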
{
@@ -215,14 +217,16 @@
"from langchain.docstore.document import Document\n",
"\n",
"excerpt_length = 20000\n",
"doc = Document(page_content=book_contents[:excerpt_length], metadata={\"source\": \"local\"})\n",
"print(f\"The text is {len(doc.page_content)} chars\")\n",
"text = book_contents # [:excerpt_length]\n",
"print(f\"The text is {len(tokenizer.tokenize(text))} tokens.\")\n",
"\n",
"# Split the documents into chunks\n",
"chunk_char_limit = 1000\n",
"text_splitter = TokenTextSplitter.from_huggingface_tokenizer(tokenizer=tokenizer, chunk_size=chunk_char_limit, chunk_overlap=50)\n",
"chunks = text_splitter.split_documents([doc])\n",
"print(\"Chunk count: \" + str(len(chunks)))"
"chunk_token_limit = 3000 # In tokens: 3000 message + 512 completion + ~350 padding < 4000 context length\n",
"text_splitter = TokenTextSplitter.from_huggingface_tokenizer(tokenizer=tokenizer, chunk_size=chunk_token_limit, chunk_overlap=0)\n",
"chunks = text_splitter.split_text(text)\n",
"\n",
"print(\"Chunk count: \" + str(len(chunks)))\n",
"print(\"Max chunk length: \" + str(max([len(tokenizer.tokenize(chunk)) for chunk in chunks])))"
]
},
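The comment in the new code spells out the token budget: 3000 tokens of message plus 512 of completion plus roughly 350 of padding stays under a 4000-token context. A sanity check along those lines (a sketch; the 4000-token context length is an assumption from the PR title):

```python
context_length = 4000     # assumed 4k context window
completion_budget = 512   # tokens reserved for the model's output
overhead = 350            # rough allowance for the prompt template and special tokens

for chunk in chunks:
    n = len(tokenizer.tokenize(chunk))
    assert n + completion_budget + overhead <= context_length, f"chunk of {n} tokens won't fit"
```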
{
@@ -240,29 +244,40 @@
"metadata": {},
"outputs": [],
"source": [
"summaries = []\n",
"\n",
"for i, chunk in enumerate(chunks):\n",
" prompt = f\"\"\"\n",
" Summarize the following text from \"Walden\" by Henry David Thoreau:\n",
" {chunk}\n",
" \"\"\"\n",
" output = model.invoke(\n",
" prompt,\n",
" model_kwargs={\n",
" \"max_tokens\": 10000, # Set the maximum number of tokens to generate as output.\n",
" \"min_tokens\": 200, # Set the minimum number of tokens to generate as output.\n",
" \"temperature\": 0.75,\n",
" \"system_prompt\": \"You are a helpful assistant.\",\n",
" \"presence_penalty\": 0,\n",
" \"frequency_penalty\": 0\n",
" }\n",
" )\n",
" summary = f\"Summary {i+1}:\\n{output}\\n\\n\"\n",
" summaries.append(summary)\n",
" print(summary)\n",
"\n",
"print(\"Summary count: \" + str(len(summaries)))\n"
"def summarize(texts, prompt_template, min, max):\n",
" summaries = []\n",
" for i, text in enumerate(texts):\n",
" print(f\"{i + 1}. Input size: {len(tokenizer.tokenize(text))} tokens\")\n",
" prompt = prompt_template.format(text=text)\n",
" output = model.invoke(\n",
" prompt,\n",
" model_kwargs={\n",
" \"max_tokens\": 2000, # Set the maximum number of tokens to generate as output.\n",
" \"min_tokens\": 200, # Set the minimum number of tokens to generate as output.\n",
" \"temperature\": 0.75,\n",
" \"system_prompt\": \"You are a helpful assistant.\",\n",
" \"presence_penalty\": 0,\n",
" \"frequency_penalty\": 0\n",
" }\n",
" )\n",
" print(f\"{i + 1}. Output size: {len(tokenizer.tokenize(output))} tokens\")\n",
" summary = f\"Summary {i+1}:\\n{output}\\n\\n\"\n",
" summaries.append(summary)\n",
" print(summary)\n",
"\n",
" print(\"Summary count: \" + str(len(summaries)))\n",
" summary_contents = \"\\n\\n\".join(summaries)\n",
" print(f\"Total: {len(tokenizer.tokenize(summary_contents))} tokens\")\n",
"\n",
" return summaries\n",
"\n",
"\n",
"prompt = \"\"\"\n",
" Summarize the following text using only the information found in the text:\n",
" {text}\n",
" \"\"\"\n",
"\n",
"summaries_lvl_1 = summarize(chunks, prompt, 200, 2000)\n"
]
},
{
@@ -280,30 +295,52 @@
"metadata": {},
"outputs": [],
"source": [
"summary_contents = \"\\n\\n\".join(summaries)\n",
"print(len(summary_contents))\n",
"def group_array(arr, n):\n",
" # Calculate the size of each chunk\n",
" avg_len = len(arr) // n\n",
" remainder = len(arr) % n\n",
" result = []\n",
" start = 0\n",
"\n",
"prompt = f\"\"\"\n",
"The text of \"Walden\", by Henry David Thoreau, was summarized in separate passages; those passage summaries are provided below. \n",
" for i in range(n):\n",
" # Distribute the remainder elements across the first chunks\n",
" end = start + avg_len + (1 if i < remainder else 0)\n",
" result.append(arr[start:end])\n",
" start = end\n",
"\n",
" return result\n",
"\n",
"summary_groups = group_array(summaries_lvl_1, 4)\n",
"texts_lvl_2 = [\"\\n\\n\".join(summary_group) for summary_group in summary_groups]\n",
"\n",
"prompt = \"\"\"\n",
"A text was summarized in separate passages; those passage summaries are provided below. \n",
"\n",
"{summary_contents}\n",
"{text}\n",
"\n",
"From these summaries, compose a single lengthy, unified summary of the original text.\n",
"From these summaries alone, compose a single, unified summary of the text.\n",
"\"\"\"\n",
"\n",
"output = model.invoke(\n",
" prompt,\n",
" model_kwargs={\n",
" \"max_tokens\": 100000, # Set the maximum number of tokens to generate as output.\n",
" \"min_tokens\": 5000, # Set the minimum number of tokens to generate as output.\n",
" \"temperature\": 0.75,\n",
" \"system_prompt\": \"You are a helpful assistant.\",\n",
" \"presence_penalty\": 0,\n",
" \"frequency_penalty\": 0\n",
" }\n",
" )\n",
"summaries_lvl_2 = summarize(texts_lvl_2, prompt, 500, 1000)"
]
},
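For intuition about `group_array` above: it partitions a list into `n` nearly equal groups, spreading the remainder over the leading groups. A quick worked check:

```python
print(group_array(list(range(10)), 4))
# [[0, 1, 2], [3, 4, 5], [6, 7], [8, 9]]  (the first 10 % 4 = 2 groups get one extra item)
```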
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Create the Final Summary\n",
+"\n",
-"print(output)"
+"Generate a single summary from the passage summaries generated above."
]
},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"texts_lvl_3 = [\"\\n\\n\".join(summaries_lvl_2)]\n",
+"final_summary = summarize(texts_lvl_3, prompt, 500, 1000)[0]"
]
}
],