From c004d7b9f7765755dc10e5e04b039be94e877781 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Thu, 31 Aug 2023 15:29:27 +0200 Subject: [PATCH 01/27] llama-2-chat integration [wip] --- docs/extras/integrations/chat/llama_2.ipynb | 309 ++++++++++++++++++ .../langchain/chat_models/__init__.py | 2 + .../langchain/chat_models/llama_2.py | 96 ++++++ 3 files changed, 407 insertions(+) create mode 100644 docs/extras/integrations/chat/llama_2.ipynb create mode 100644 libs/langchain/langchain/chat_models/llama_2.py diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/llama_2.ipynb new file mode 100644 index 0000000000000..757091031c8bc --- /dev/null +++ b/docs/extras/integrations/chat/llama_2.ipynb @@ -0,0 +1,309 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO:\n", + "# temporary workaround:\n", + "# remove cell before PR\n", + "import os\n", + "import sys\n", + "os.getcwd()\n", + "os.chdir(\"/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/\")\n", + "os.getcwd()\n", + "\n", + "# sys.path.append('/mnt/ml-team/homes/eryk.mazus/langchain/')\n", + "sys.path.append('/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/')\n" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "# HF imports\n", + "import torch\n", + "from transformers import (\n", + " AutoModelForCausalLM,\n", + " AutoTokenizer,\n", + " BitsAndBytesConfig,\n", + " HfArgumentParser,\n", + " pipeline,\n", + ")\n", + "\n", + "# LangChain imports\n", + "from langchain.chat_models import ChatLlama2\n", + "from langchain.prompts.chat import (\n", + " ChatPromptTemplate,\n", + " SystemMessagePromptTemplate,\n", + " AIMessagePromptTemplate,\n", + " HumanMessagePromptTemplate,\n", + ")\n", + "from langchain.schema import AIMessage, HumanMessage, SystemMessage" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "env: HUGGINGFACE_TOKEN=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\n" + ] + } + ], + "source": [ + "%env HUGGINGFACE_TOKEN=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" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n", + "Token is valid (permission: read).\n", + "Your token has been saved to /home/eryk.mazus/.cache/huggingface/token\n", + "Login successful\n" + ] + } + ], + "source": [ + "!huggingface-cli login --token $HUGGINGFACE_TOKEN" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"meta-llama/Llama-2-7b-chat-hf\"" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [], + "source": [ + "bnb_config = BitsAndBytesConfig(\n", + " load_in_4bit=True,\n", + " bnb_4bit_quant_type=\"nf4\",\n", + " bnb_4bit_use_double_quant=True,\n", + " bnb_4bit_compute_dtype=torch.bfloat16,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.30it/s]\n" + ] + } + ], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "model_4bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map=\"auto\")" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [], + "source": [ + "pipe = pipeline(\n", + " \"text-generation\",\n", + " model=model_4bit,\n", + " tokenizer=tokenizer,\n", + " torch_dtype=torch.float16,\n", + " device_map=\"auto\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'meta-llama/Llama-2-7b-chat-hf'" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pipe.model.name_or_path" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [], + "source": [ + "# pipeline generation kwargs\n", + "pipeline_kwargs = {\n", + " \"do_sample\": True,\n", + " \"top_p\": 0.95,\n", + " \"temperature\": 0.7,\n", + " \"eos_token_id\": tokenizer.eos_token_id,\n", + " \"max_length\": 200, \n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "chat = ChatLlama2(pipeline=pipe)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'ChatLlama2' object has no attribute '_pipeline'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[56], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m chat\u001b[39m.\u001b[39;49mpipeline\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/llama_2.py:32\u001b[0m, in \u001b[0;36mChatLlama2.pipeline\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mpipeline\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[1;32m 31\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Getter for the pipeline.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 32\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_pipeline\n", + "\u001b[0;31mAttributeError\u001b[0m: 'ChatLlama2' object has no attribute '_pipeline'" + ] + } + ], + "source": [ + "chat.pipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[INST] <>\n", + "You are a helpful assistant that translates English to French.\n", + "<>\n", + "\n", + "Translate this sentence from English to French. I love programming. [/INST] \n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'ChatLlama2' object has no attribute '_pipeline'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[52], line 9\u001b[0m\n\u001b[1;32m 1\u001b[0m messages \u001b[39m=\u001b[39m [\n\u001b[1;32m 2\u001b[0m SystemMessage(\n\u001b[1;32m 3\u001b[0m content\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mYou are a helpful assistant that translates English to French.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 7\u001b[0m ),\n\u001b[1;32m 8\u001b[0m ]\n\u001b[0;32m----> 9\u001b[0m chat(messages)\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:551\u001b[0m, in \u001b[0;36mBaseChatModel.__call__\u001b[0;34m(self, messages, stop, callbacks, **kwargs)\u001b[0m\n\u001b[1;32m 544\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 545\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 546\u001b[0m messages: List[BaseMessage],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any,\n\u001b[1;32m 550\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m BaseMessage:\n\u001b[0;32m--> 551\u001b[0m generation \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgenerate(\n\u001b[1;32m 552\u001b[0m [messages], stop\u001b[39m=\u001b[39;49mstop, callbacks\u001b[39m=\u001b[39;49mcallbacks, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs\n\u001b[1;32m 553\u001b[0m )\u001b[39m.\u001b[39mgenerations[\u001b[39m0\u001b[39m][\u001b[39m0\u001b[39m]\n\u001b[1;32m 554\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(generation, ChatGeneration):\n\u001b[1;32m 555\u001b[0m \u001b[39mreturn\u001b[39;00m generation\u001b[39m.\u001b[39mmessage\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:309\u001b[0m, in \u001b[0;36mBaseChatModel.generate\u001b[0;34m(self, messages, stop, callbacks, tags, metadata, **kwargs)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mif\u001b[39;00m run_managers:\n\u001b[1;32m 308\u001b[0m run_managers[i]\u001b[39m.\u001b[39mon_llm_error(e)\n\u001b[0;32m--> 309\u001b[0m \u001b[39mraise\u001b[39;00m e\n\u001b[1;32m 310\u001b[0m flattened_outputs \u001b[39m=\u001b[39m [\n\u001b[1;32m 311\u001b[0m LLMResult(generations\u001b[39m=\u001b[39m[res\u001b[39m.\u001b[39mgenerations], llm_output\u001b[39m=\u001b[39mres\u001b[39m.\u001b[39mllm_output)\n\u001b[1;32m 312\u001b[0m \u001b[39mfor\u001b[39;00m res \u001b[39min\u001b[39;00m results\n\u001b[1;32m 313\u001b[0m ]\n\u001b[1;32m 314\u001b[0m llm_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_combine_llm_outputs([res\u001b[39m.\u001b[39mllm_output \u001b[39mfor\u001b[39;00m res \u001b[39min\u001b[39;00m results])\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:299\u001b[0m, in \u001b[0;36mBaseChatModel.generate\u001b[0;34m(self, messages, stop, callbacks, tags, metadata, **kwargs)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[39mfor\u001b[39;00m i, m \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(messages):\n\u001b[1;32m 297\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 298\u001b[0m results\u001b[39m.\u001b[39mappend(\n\u001b[0;32m--> 299\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_generate_with_cache(\n\u001b[1;32m 300\u001b[0m m,\n\u001b[1;32m 301\u001b[0m stop\u001b[39m=\u001b[39;49mstop,\n\u001b[1;32m 302\u001b[0m run_manager\u001b[39m=\u001b[39;49mrun_managers[i] \u001b[39mif\u001b[39;49;00m run_managers \u001b[39melse\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 303\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 304\u001b[0m )\n\u001b[1;32m 305\u001b[0m )\n\u001b[1;32m 306\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mKeyboardInterrupt\u001b[39;00m, \u001b[39mException\u001b[39;00m) \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 307\u001b[0m \u001b[39mif\u001b[39;00m run_managers:\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:446\u001b[0m, in \u001b[0;36mBaseChatModel._generate_with_cache\u001b[0;34m(self, messages, stop, run_manager, **kwargs)\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 443\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mAsked to cache, but no cache found at `langchain.cache`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 444\u001b[0m )\n\u001b[1;32m 445\u001b[0m \u001b[39mif\u001b[39;00m new_arg_supported:\n\u001b[0;32m--> 446\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_generate(\n\u001b[1;32m 447\u001b[0m messages, stop\u001b[39m=\u001b[39;49mstop, run_manager\u001b[39m=\u001b[39;49mrun_manager, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs\n\u001b[1;32m 448\u001b[0m )\n\u001b[1;32m 449\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 450\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_generate(messages, stop\u001b[39m=\u001b[39mstop, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/llama_2.py:86\u001b[0m, in \u001b[0;36mChatLlama2._generate\u001b[0;34m(self, messages, stop, run_manager, **kwargs)\u001b[0m\n\u001b[1;32m 83\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mreturn_full_text\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 84\u001b[0m \u001b[39m# num_return_sequences ? ~ is it possible to pass multiple conversations ?\u001b[39;00m\n\u001b[0;32m---> 86\u001b[0m response \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpipeline(prompt, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpipeline_params)[\u001b[39m\"\u001b[39m\u001b[39mgenerated_text\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[1;32m 87\u001b[0m \u001b[39mprint\u001b[39m(response)\n\u001b[1;32m 88\u001b[0m \u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/llama_2.py:32\u001b[0m, in \u001b[0;36mChatLlama2.pipeline\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mpipeline\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[1;32m 31\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Getter for the pipeline.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 32\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_pipeline\n", + "\u001b[0;31mAttributeError\u001b[0m: 'ChatLlama2' object has no attribute '_pipeline'" + ] + } + ], + "source": [ + "messages = [\n", + " SystemMessage(\n", + " content=\"You are a helpful assistant that translates English to French.\"\n", + " ),\n", + " HumanMessage(\n", + " content=\"Translate this sentence from English to French. I love programming.\"\n", + " ),\n", + "]\n", + "chat(messages)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.12 ('langchain_venv': venv)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "3372ef96e068313d34c91eab0f20d815c93d37110de821968e5d598f73bfb74c" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index ee21a2377eb4d..c50fa87d1c7c1 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -31,6 +31,7 @@ from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.chat_models.vertexai import ChatVertexAI +from langchain.chat_models.llama_2 import ChatLlama2 __all__ = [ "ChatOpenAI", @@ -41,6 +42,7 @@ "ChatGooglePalm", "ChatMLflowAIGateway", "ChatOllama", + "ChatLlama2", "ChatVertexAI", "JinaChat", "HumanInputChatModel", diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py new file mode 100644 index 0000000000000..55fb5d181d42a --- /dev/null +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -0,0 +1,96 @@ +from typing import List, Any, Optional + +from langchain.chat_models.base import BaseChatModel +from langchain.schema import ChatResult +from langchain.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>", "<>" + + +class ChatLlama2(BaseChatModel): + _pipeline: Any #: :meta private: + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "llama-2-chat-hf" + + @property + def pipeline(self) -> Any: + """Getter for the pipeline.""" + return self._pipeline + + @pipeline.setter + def pipeline(self, value: Any): + """Setter for the pipeline.""" + if not hasattr(value, "task") or value.task != "text-generation": + raise ValueError("The pipeline task should be 'text-generation'.") + + valid_models = ( + "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-2-13b-chat-hf", + "meta-llama/Llama-2-70b-chat-hf", + ) + + if not hasattr(value, "model") or value.model.name_or_path not in valid_models: + raise ValueError(f"The pipeline model name or path should be one of {valid_models}.") + + self._pipeline = value + + def _format_messages_as_text(self, messages: List[BaseMessage]) -> str: + """ https://huggingface.co/blog/llama2 """ + prompt = "" + + for i, message in enumerate(messages): + if i != 0 and isinstance(message, SystemMessage): + raise ... + elif i == 0 and isinstance(message, SystemMessage): + prompt += f"{B_INST} {B_SYS}\n{message.content}\n{E_SYS}\n\n" + elif isinstance(message, HumanMessage) and i > 0: + prompt += f"{message.content} {E_INST} " + elif i == 0 and isinstance(message, HumanMessage): + prompt += f"{B_INST} {message.content} {E_INST} " + elif isinstance(message, AIMessage): + prompt += f"{message.content} {B_INST} " + + return prompt + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + prompt = self._format_messages_as_text(messages) + # TODO: remove + print(prompt) + + pipeline_params = kwargs + # ensure that: + kwargs["return_text"] = True + kwargs["return_full_text"] = False + # num_return_sequences ? ~ is it possible to pass multiple conversations ? + + response = self.pipeline(prompt, **pipeline_params)["generated_text"] + print(response) + ... + return response + +# TODO: +# fix problem with getter +# correct output from _generate +# try to add stopping criteria +# handle batch requests +# handle ChatMessage, AIMessageChunk ? From 3337d6a4f40b5a605d66c98488abdb16e4324505 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Thu, 31 Aug 2023 17:51:17 +0200 Subject: [PATCH 02/27] fixing problem with private pipeline --- docs/extras/integrations/chat/llama_2.ipynb | 111 +++++++----------- .../langchain/chat_models/llama_2.py | 29 ++--- 2 files changed, 55 insertions(+), 85 deletions(-) diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/llama_2.ipynb index 757091031c8bc..9be8e3afb761b 100644 --- a/docs/extras/integrations/chat/llama_2.ipynb +++ b/docs/extras/integrations/chat/llama_2.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 40, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -21,18 +21,9 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" @@ -40,9 +31,18 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/mnt/ml-team/homes/eryk.mazus/langchain/langchain_venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "# HF imports\n", "import torch\n", @@ -67,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -84,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -104,7 +104,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -127,14 +127,14 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.30it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.29it/s]\n" ] } ], @@ -145,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -160,7 +160,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -169,7 +169,7 @@ "'meta-llama/Llama-2-7b-chat-hf'" ] }, - "execution_count": 49, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -180,7 +180,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -196,69 +196,42 @@ }, { "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [], - "source": [ - "chat = ChatLlama2(pipeline=pipe)" - ] - }, - { - "cell_type": "code", - "execution_count": 56, + "execution_count": 12, "metadata": {}, "outputs": [ { "ename": "AttributeError", - "evalue": "'ChatLlama2' object has no attribute '_pipeline'", + "evalue": "'NoneType' object has no attribute 'get'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[56], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m chat\u001b[39m.\u001b[39;49mpipeline\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/llama_2.py:32\u001b[0m, in \u001b[0;36mChatLlama2.pipeline\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mpipeline\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[1;32m 31\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Getter for the pipeline.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 32\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_pipeline\n", - "\u001b[0;31mAttributeError\u001b[0m: 'ChatLlama2' object has no attribute '_pipeline'" + "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m chat \u001b[39m=\u001b[39m ChatLlama2(pipeline\u001b[39m=\u001b[39;49mpipe)\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/load/serializable.py:74\u001b[0m, in \u001b[0;36mSerializable.__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m---> 74\u001b[0m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 75\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_lc_kwargs \u001b[39m=\u001b[39m kwargs\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/langchain_venv/lib/python3.10/site-packages/pydantic/v1/main.py:339\u001b[0m, in \u001b[0;36mBaseModel.__init__\u001b[0;34m(__pydantic_self__, **data)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[39mCreate a new model by parsing and validating input data from keyword arguments.\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[39mRaises ValidationError if the input data cannot be parsed to form a valid model.\u001b[39;00m\n\u001b[1;32m 337\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 338\u001b[0m \u001b[39m# Uses something other than `self` the first arg to allow \"self\" as a settable attribute\u001b[39;00m\n\u001b[0;32m--> 339\u001b[0m values, fields_set, validation_error \u001b[39m=\u001b[39m validate_model(__pydantic_self__\u001b[39m.\u001b[39;49m\u001b[39m__class__\u001b[39;49m, data)\n\u001b[1;32m 340\u001b[0m \u001b[39mif\u001b[39;00m validation_error:\n\u001b[1;32m 341\u001b[0m \u001b[39mraise\u001b[39;00m validation_error\n", + "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/langchain_venv/lib/python3.10/site-packages/pydantic/v1/main.py:1055\u001b[0m, in \u001b[0;36mvalidate_model\u001b[0;34m(model, input_data, cls)\u001b[0m\n\u001b[1;32m 1052\u001b[0m \u001b[39mreturn\u001b[39;00m {}, \u001b[39mset\u001b[39m(), ValidationError([ErrorWrapper(exc, loc\u001b[39m=\u001b[39mROOT_KEY)], cls_)\n\u001b[1;32m 1054\u001b[0m \u001b[39mfor\u001b[39;00m name, field \u001b[39min\u001b[39;00m model\u001b[39m.\u001b[39m__fields__\u001b[39m.\u001b[39mitems():\n\u001b[0;32m-> 1055\u001b[0m value \u001b[39m=\u001b[39m input_data\u001b[39m.\u001b[39;49mget(field\u001b[39m.\u001b[39malias, _missing)\n\u001b[1;32m 1056\u001b[0m using_name \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 1057\u001b[0m \u001b[39mif\u001b[39;00m value \u001b[39mis\u001b[39;00m _missing \u001b[39mand\u001b[39;00m config\u001b[39m.\u001b[39mallow_population_by_field_name \u001b[39mand\u001b[39;00m field\u001b[39m.\u001b[39malt_alias:\n", + "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'get'" ] } ], + "source": [ + "chat = ChatLlama2(pipeline=pipe)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "chat.pipeline" ] }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[INST] <>\n", - "You are a helpful assistant that translates English to French.\n", - "<>\n", - "\n", - "Translate this sentence from English to French. I love programming. [/INST] \n" - ] - }, - { - "ename": "AttributeError", - "evalue": "'ChatLlama2' object has no attribute '_pipeline'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[52], line 9\u001b[0m\n\u001b[1;32m 1\u001b[0m messages \u001b[39m=\u001b[39m [\n\u001b[1;32m 2\u001b[0m SystemMessage(\n\u001b[1;32m 3\u001b[0m content\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mYou are a helpful assistant that translates English to French.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 7\u001b[0m ),\n\u001b[1;32m 8\u001b[0m ]\n\u001b[0;32m----> 9\u001b[0m chat(messages)\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:551\u001b[0m, in \u001b[0;36mBaseChatModel.__call__\u001b[0;34m(self, messages, stop, callbacks, **kwargs)\u001b[0m\n\u001b[1;32m 544\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\n\u001b[1;32m 545\u001b[0m \u001b[39mself\u001b[39m,\n\u001b[1;32m 546\u001b[0m messages: List[BaseMessage],\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any,\n\u001b[1;32m 550\u001b[0m ) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m BaseMessage:\n\u001b[0;32m--> 551\u001b[0m generation \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mgenerate(\n\u001b[1;32m 552\u001b[0m [messages], stop\u001b[39m=\u001b[39;49mstop, callbacks\u001b[39m=\u001b[39;49mcallbacks, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs\n\u001b[1;32m 553\u001b[0m )\u001b[39m.\u001b[39mgenerations[\u001b[39m0\u001b[39m][\u001b[39m0\u001b[39m]\n\u001b[1;32m 554\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(generation, ChatGeneration):\n\u001b[1;32m 555\u001b[0m \u001b[39mreturn\u001b[39;00m generation\u001b[39m.\u001b[39mmessage\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:309\u001b[0m, in \u001b[0;36mBaseChatModel.generate\u001b[0;34m(self, messages, stop, callbacks, tags, metadata, **kwargs)\u001b[0m\n\u001b[1;32m 307\u001b[0m \u001b[39mif\u001b[39;00m run_managers:\n\u001b[1;32m 308\u001b[0m run_managers[i]\u001b[39m.\u001b[39mon_llm_error(e)\n\u001b[0;32m--> 309\u001b[0m \u001b[39mraise\u001b[39;00m e\n\u001b[1;32m 310\u001b[0m flattened_outputs \u001b[39m=\u001b[39m [\n\u001b[1;32m 311\u001b[0m LLMResult(generations\u001b[39m=\u001b[39m[res\u001b[39m.\u001b[39mgenerations], llm_output\u001b[39m=\u001b[39mres\u001b[39m.\u001b[39mllm_output)\n\u001b[1;32m 312\u001b[0m \u001b[39mfor\u001b[39;00m res \u001b[39min\u001b[39;00m results\n\u001b[1;32m 313\u001b[0m ]\n\u001b[1;32m 314\u001b[0m llm_output \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_combine_llm_outputs([res\u001b[39m.\u001b[39mllm_output \u001b[39mfor\u001b[39;00m res \u001b[39min\u001b[39;00m results])\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:299\u001b[0m, in \u001b[0;36mBaseChatModel.generate\u001b[0;34m(self, messages, stop, callbacks, tags, metadata, **kwargs)\u001b[0m\n\u001b[1;32m 296\u001b[0m \u001b[39mfor\u001b[39;00m i, m \u001b[39min\u001b[39;00m \u001b[39menumerate\u001b[39m(messages):\n\u001b[1;32m 297\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 298\u001b[0m results\u001b[39m.\u001b[39mappend(\n\u001b[0;32m--> 299\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_generate_with_cache(\n\u001b[1;32m 300\u001b[0m m,\n\u001b[1;32m 301\u001b[0m stop\u001b[39m=\u001b[39;49mstop,\n\u001b[1;32m 302\u001b[0m run_manager\u001b[39m=\u001b[39;49mrun_managers[i] \u001b[39mif\u001b[39;49;00m run_managers \u001b[39melse\u001b[39;49;00m \u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 303\u001b[0m \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs,\n\u001b[1;32m 304\u001b[0m )\n\u001b[1;32m 305\u001b[0m )\n\u001b[1;32m 306\u001b[0m \u001b[39mexcept\u001b[39;00m (\u001b[39mKeyboardInterrupt\u001b[39;00m, \u001b[39mException\u001b[39;00m) \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 307\u001b[0m \u001b[39mif\u001b[39;00m run_managers:\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/base.py:446\u001b[0m, in \u001b[0;36mBaseChatModel._generate_with_cache\u001b[0;34m(self, messages, stop, run_manager, **kwargs)\u001b[0m\n\u001b[1;32m 442\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 443\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mAsked to cache, but no cache found at `langchain.cache`.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 444\u001b[0m )\n\u001b[1;32m 445\u001b[0m \u001b[39mif\u001b[39;00m new_arg_supported:\n\u001b[0;32m--> 446\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_generate(\n\u001b[1;32m 447\u001b[0m messages, stop\u001b[39m=\u001b[39;49mstop, run_manager\u001b[39m=\u001b[39;49mrun_manager, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs\n\u001b[1;32m 448\u001b[0m )\n\u001b[1;32m 449\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 450\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_generate(messages, stop\u001b[39m=\u001b[39mstop, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs)\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/llama_2.py:86\u001b[0m, in \u001b[0;36mChatLlama2._generate\u001b[0;34m(self, messages, stop, run_manager, **kwargs)\u001b[0m\n\u001b[1;32m 83\u001b[0m kwargs[\u001b[39m\"\u001b[39m\u001b[39mreturn_full_text\u001b[39m\u001b[39m\"\u001b[39m] \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 84\u001b[0m \u001b[39m# num_return_sequences ? ~ is it possible to pass multiple conversations ?\u001b[39;00m\n\u001b[0;32m---> 86\u001b[0m response \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpipeline(prompt, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mpipeline_params)[\u001b[39m\"\u001b[39m\u001b[39mgenerated_text\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[1;32m 87\u001b[0m \u001b[39mprint\u001b[39m(response)\n\u001b[1;32m 88\u001b[0m \u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\u001b[39m.\u001b[39m\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/chat_models/llama_2.py:32\u001b[0m, in \u001b[0;36mChatLlama2.pipeline\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[39m@property\u001b[39m\n\u001b[1;32m 30\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mpipeline\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[1;32m 31\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"Getter for the pipeline.\"\"\"\u001b[39;00m\n\u001b[0;32m---> 32\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_pipeline\n", - "\u001b[0;31mAttributeError\u001b[0m: 'ChatLlama2' object has no attribute '_pipeline'" - ] - } - ], + "outputs": [], "source": [ "messages = [\n", " SystemMessage(\n", diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py index 55fb5d181d42a..0ee2a3dc483d9 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -1,4 +1,4 @@ -from typing import List, Any, Optional +from typing import List, Any, Optional, Dict from langchain.chat_models.base import BaseChatModel from langchain.schema import ChatResult @@ -14,27 +14,24 @@ CallbackManagerForLLMRun, ) +from langchain.pydantic_v1 import Field, root_validator +from transformers.pipelines import TextGenerationPipeline + B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>", "<>" class ChatLlama2(BaseChatModel): - _pipeline: Any #: :meta private: + pipeline: TextGenerationPipeline @property def _llm_type(self) -> str: """Return type of chat model.""" return "llama-2-chat-hf" - @property - def pipeline(self) -> Any: - """Getter for the pipeline.""" - return self._pipeline - - @pipeline.setter - def pipeline(self, value: Any): - """Setter for the pipeline.""" - if not hasattr(value, "task") or value.task != "text-generation": + @root_validator(pre=True) + def validate_environment(cls, values: Dict) -> Dict: + if not hasattr(values["pipeline"], "task") or values["pipeline"].task != "text-generation": raise ValueError("The pipeline task should be 'text-generation'.") valid_models = ( @@ -43,10 +40,10 @@ def pipeline(self, value: Any): "meta-llama/Llama-2-70b-chat-hf", ) - if not hasattr(value, "model") or value.model.name_or_path not in valid_models: + if not hasattr(values["pipeline"], "model") or values["pipeline"].model.name_or_path not in valid_models: raise ValueError(f"The pipeline model name or path should be one of {valid_models}.") - - self._pipeline = value + + return values def _format_messages_as_text(self, messages: List[BaseMessage]) -> str: """ https://huggingface.co/blog/llama2 """ @@ -79,11 +76,11 @@ def _generate( pipeline_params = kwargs # ensure that: - kwargs["return_text"] = True + # kwargs["return_text"] = True kwargs["return_full_text"] = False # num_return_sequences ? ~ is it possible to pass multiple conversations ? - response = self.pipeline(prompt, **pipeline_params)["generated_text"] + response = self.pipeline(prompt, **pipeline_params)[0]['generated_text'] print(response) ... return response From 26e10fcd1c46b61c053b35950785b3e91f60e00b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Fri, 1 Sep 2023 09:26:39 +0200 Subject: [PATCH 03/27] _generate returns ChatResult --- docs/extras/integrations/chat/llama_2.ipynb | 43 +++++++------------ .../langchain/chat_models/llama_2.py | 23 +++++----- 2 files changed, 27 insertions(+), 39 deletions(-) diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/llama_2.ipynb index 9be8e3afb761b..b36ff60bd2355 100644 --- a/docs/extras/integrations/chat/llama_2.ipynb +++ b/docs/extras/integrations/chat/llama_2.ipynb @@ -134,7 +134,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.29it/s]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.36it/s]\n" ] } ], @@ -198,40 +198,27 @@ "cell_type": "code", "execution_count": 12, "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'NoneType' object has no attribute 'get'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[12], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m chat \u001b[39m=\u001b[39m ChatLlama2(pipeline\u001b[39m=\u001b[39;49mpipe)\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/langchain/load/serializable.py:74\u001b[0m, in \u001b[0;36mSerializable.__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__init__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[0;32m---> 74\u001b[0m \u001b[39msuper\u001b[39;49m()\u001b[39m.\u001b[39;49m\u001b[39m__init__\u001b[39;49m(\u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n\u001b[1;32m 75\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_lc_kwargs \u001b[39m=\u001b[39m kwargs\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/langchain_venv/lib/python3.10/site-packages/pydantic/v1/main.py:339\u001b[0m, in \u001b[0;36mBaseModel.__init__\u001b[0;34m(__pydantic_self__, **data)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[39mCreate a new model by parsing and validating input data from keyword arguments.\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[39mRaises ValidationError if the input data cannot be parsed to form a valid model.\u001b[39;00m\n\u001b[1;32m 337\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 338\u001b[0m \u001b[39m# Uses something other than `self` the first arg to allow \"self\" as a settable attribute\u001b[39;00m\n\u001b[0;32m--> 339\u001b[0m values, fields_set, validation_error \u001b[39m=\u001b[39m validate_model(__pydantic_self__\u001b[39m.\u001b[39;49m\u001b[39m__class__\u001b[39;49m, data)\n\u001b[1;32m 340\u001b[0m \u001b[39mif\u001b[39;00m validation_error:\n\u001b[1;32m 341\u001b[0m \u001b[39mraise\u001b[39;00m validation_error\n", - "File \u001b[0;32m/mnt/ml-team/homes/eryk.mazus/langchain/langchain_venv/lib/python3.10/site-packages/pydantic/v1/main.py:1055\u001b[0m, in \u001b[0;36mvalidate_model\u001b[0;34m(model, input_data, cls)\u001b[0m\n\u001b[1;32m 1052\u001b[0m \u001b[39mreturn\u001b[39;00m {}, \u001b[39mset\u001b[39m(), ValidationError([ErrorWrapper(exc, loc\u001b[39m=\u001b[39mROOT_KEY)], cls_)\n\u001b[1;32m 1054\u001b[0m \u001b[39mfor\u001b[39;00m name, field \u001b[39min\u001b[39;00m model\u001b[39m.\u001b[39m__fields__\u001b[39m.\u001b[39mitems():\n\u001b[0;32m-> 1055\u001b[0m value \u001b[39m=\u001b[39m input_data\u001b[39m.\u001b[39;49mget(field\u001b[39m.\u001b[39malias, _missing)\n\u001b[1;32m 1056\u001b[0m using_name \u001b[39m=\u001b[39m \u001b[39mFalse\u001b[39;00m\n\u001b[1;32m 1057\u001b[0m \u001b[39mif\u001b[39;00m value \u001b[39mis\u001b[39;00m _missing \u001b[39mand\u001b[39;00m config\u001b[39m.\u001b[39mallow_population_by_field_name \u001b[39mand\u001b[39;00m field\u001b[39m.\u001b[39malt_alias:\n", - "\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'get'" - ] - } - ], - "source": [ - "chat = ChatLlama2(pipeline=pipe)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, "outputs": [], "source": [ - "chat.pipeline" + "chat = ChatLlama2(pipeline=pipe)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' Sure! Here\\'s the translation of \"I love programming\" from English to French:\\nJe adore le programming.\\n\\nI hope that helps! Let me know if you have any other sentences you\\'d like me to translate.', additional_kwargs={}, example=False)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "messages = [\n", " SystemMessage(\n", diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py index 0ee2a3dc483d9..94953100fde5d 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -10,6 +10,7 @@ HumanMessage, SystemMessage, ) +from langchain.schema.output import ChatGeneration from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) @@ -71,23 +72,23 @@ def _generate( **kwargs: Any, ) -> ChatResult: prompt = self._format_messages_as_text(messages) - # TODO: remove - print(prompt) pipeline_params = kwargs - # ensure that: - # kwargs["return_text"] = True + # make sure that `return_full_text` is set to False + # otherwise, pipeline will return prompt + generation kwargs["return_full_text"] = False - # num_return_sequences ? ~ is it possible to pass multiple conversations ? response = self.pipeline(prompt, **pipeline_params)[0]['generated_text'] - print(response) - ... - return response + chat_generation = ChatGeneration( + message=AIMessage(content=response), + ) + return ChatResult(generations=[chat_generation]) + # TODO: -# fix problem with getter -# correct output from _generate +# generation kwargs # try to add stopping criteria # handle batch requests -# handle ChatMessage, AIMessageChunk ? + # num_return_sequences ? ~ is it possible to pass multiple conversations ? +# handle ChatMessage +# AIMessageChunk ? From 2aba4cd204219e14dff424493cc2121538eb3386 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Fri, 1 Sep 2023 10:53:54 +0200 Subject: [PATCH 04/27] batch calls example in notebook, handling chat messages --- docs/extras/integrations/chat/llama_2.ipynb | 160 ++++++++++-------- .../langchain/chat_models/llama_2.py | 31 ++-- 2 files changed, 106 insertions(+), 85 deletions(-) diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/llama_2.ipynb index b36ff60bd2355..d826368fec7d5 100644 --- a/docs/extras/integrations/chat/llama_2.ipynb +++ b/docs/extras/integrations/chat/llama_2.ipynb @@ -19,6 +19,13 @@ "sys.path.append('/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/')\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# LLama Chat" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -33,18 +40,9 @@ "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/ml-team/homes/eryk.mazus/langchain/langchain_venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ - "# HF imports\n", + "# Hugging Face imports:\n", "import torch\n", "from transformers import (\n", " AutoModelForCausalLM,\n", @@ -54,7 +52,7 @@ " pipeline,\n", ")\n", "\n", - "# LangChain imports\n", + "# LangChain imports:\n", "from langchain.chat_models import ChatLlama2\n", "from langchain.prompts.chat import (\n", " ChatPromptTemplate,\n", @@ -66,45 +64,40 @@ ] }, { - "cell_type": "code", - "execution_count": 4, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "env: HUGGINGFACE_TOKEN=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\n" - ] - } - ], "source": [ - "%env HUGGINGFACE_TOKEN=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" + "This notebook assumes that you were granted with access to the Llama 2 models in the Hugging Face models hub. To use the model locally, you need to be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) with a Hugging Face account." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n", - "Token is valid (permission: read).\n", - "Your token has been saved to /home/eryk.mazus/.cache/huggingface/token\n", - "Login successful\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ac4d7153742a44fb91506b16684e952d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
Dict: return values def _format_messages_as_text(self, messages: List[BaseMessage]) -> str: - """ https://huggingface.co/blog/llama2 """ + """ + # TODO: docstring + https://huggingface.co/blog/llama2 + """ prompt = "" for i, message in enumerate(messages): - if i != 0 and isinstance(message, SystemMessage): - raise ... - elif i == 0 and isinstance(message, SystemMessage): + if isinstance(message, SystemMessage) and i != 0: + raise ValueError("SystemMessage can only appear as the first message in the list.") + elif isinstance(message, SystemMessage) and i == 0: prompt += f"{B_INST} {B_SYS}\n{message.content}\n{E_SYS}\n\n" elif isinstance(message, HumanMessage) and i > 0: prompt += f"{message.content} {E_INST} " - elif i == 0 and isinstance(message, HumanMessage): + elif isinstance(message, HumanMessage) and i == 0: prompt += f"{B_INST} {message.content} {E_INST} " elif isinstance(message, AIMessage): prompt += f"{message.content} {B_INST} " + elif isinstance(message, ChatMessage) and i == 0: + prompt += f"{B_INST} {message.role.capitalize()}: {message.content} {E_INST} " + elif isinstance(message, ChatMessage) and i > 0: + prompt += f"{message.role.capitalize()}: {message.content} {E_INST} " return prompt @@ -73,12 +79,12 @@ def _generate( ) -> ChatResult: prompt = self._format_messages_as_text(messages) - pipeline_params = kwargs # make sure that `return_full_text` is set to False # otherwise, pipeline will return prompt + generation kwargs["return_full_text"] = False + kwargs["num_return_sequences"] = 1 - response = self.pipeline(prompt, **pipeline_params)[0]['generated_text'] + response = self.pipeline(prompt, **kwargs)[0]['generated_text'] chat_generation = ChatGeneration( message=AIMessage(content=response), ) @@ -86,9 +92,6 @@ def _generate( # TODO: -# generation kwargs -# try to add stopping criteria -# handle batch requests - # num_return_sequences ? ~ is it possible to pass multiple conversations ? -# handle ChatMessage -# AIMessageChunk ? +# try adding stopping criteria +# tests for prompt generation +# streaming ? From 82f71f1a4e2dcf506c51b1f4b6e994cae4efeafd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Fri, 1 Sep 2023 13:26:48 +0200 Subject: [PATCH 05/27] _format_messages_as_text docstring --- .../langchain/chat_models/llama_2.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py index 89174f8ee67d1..5fcf783cd8401 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -46,9 +46,22 @@ def validate_environment(cls, values: Dict) -> Dict: return values def _format_messages_as_text(self, messages: List[BaseMessage]) -> str: - """ - # TODO: docstring - https://huggingface.co/blog/llama2 + """ + Transform List of Chat Messages to text following Meta's prompt guidelines. + + Prompt template with System Message: + ``` + [INST] <> + {{ system_prompt }} + <> + + {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] + ``` + + Prompt template without System Message: + ``` + [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] + ``` """ prompt = "" @@ -94,4 +107,3 @@ def _generate( # TODO: # try adding stopping criteria # tests for prompt generation -# streaming ? From 933f5f4ed139d2dd435a255931fbd98ec95df8fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Fri, 1 Sep 2023 13:47:03 +0200 Subject: [PATCH 06/27] you can pass stop words now --- docs/extras/integrations/chat/llama_2.ipynb | 58 ++++++++++--------- .../langchain/chat_models/llama_2.py | 41 +++++++++++-- 2 files changed, 67 insertions(+), 32 deletions(-) diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/llama_2.ipynb index d826368fec7d5..f55bce94e121e 100644 --- a/docs/extras/integrations/chat/llama_2.ipynb +++ b/docs/extras/integrations/chat/llama_2.ipynb @@ -78,7 +78,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "ac4d7153742a44fb91506b16684e952d", + "model_id": "0caf2986239049f094b991683930f3b9", "version_major": 2, "version_minor": 0 }, @@ -126,7 +126,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "23396232e97c4d36a341a6f93c8d83d5", + "model_id": "d9ad8b146c464fa7a43a6541e6a61cd3", "version_major": 2, "version_minor": 0 }, @@ -192,20 +192,9 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content=' Sure! Here is the translation of \"I love programming\" from English to French:\\nJe adore le programming.\\n\\nI hope this helps! Let me know if you have any other sentences you would like me to translate.', additional_kwargs={}, example=False)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "messages = [\n", " SystemMessage(\n", @@ -218,6 +207,30 @@ "chat(messages, **pipeline_kwargs)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Single calls with stop words" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " SystemMessage(\n", + " content=\"You are a helpful assistant.\"\n", + " ),\n", + " HumanMessage(\n", + " content=\"Tell me the history of AI.\"\n", + " ),\n", + "]\n", + "chat(messages, stop=[\"Artificial\"], **pipeline_kwargs)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -227,20 +240,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "LLMResult(generations=[[ChatGeneration(text=' Great! \"Programming\" is \"le programming\" in French.\\n\\nSo, you love programming? \"J\\'aime le programming\" in French.', generation_info=None, message=AIMessage(content=' Great! \"Programming\" is \"le programming\" in French.\\n\\nSo, you love programming? \"J\\'aime le programming\" in French.', additional_kwargs={}, example=False))], [ChatGeneration(text=\" Je suis ravi d'être votre assistant de traduction pour l'anglais à français. Vous pouvez maintenant dire ce que vous voulez en français, et je serai heureux de le traduire pour vous.\\nMerci de me dire que vous aimez l'intelligence artificielle. C'est un sujet très intéressant et qui a des applications nombreuses dans différents domaines, tels que la robotique, l'apprentissage automatique, la reconnaissance faciale et de la parole, et bien plus encore.\\nPourriez-vous me dire ce que vous voulez savoir sur l'intelligence artificielle? Je suis là pour vous aider et vous fournir des informations précises et utiles.\", generation_info=None, message=AIMessage(content=\" Je suis ravi d'être votre assistant de traduction pour l'anglais à français. Vous pouvez maintenant dire ce que vous voulez en français, et je serai heureux de le traduire pour vous.\\nMerci de me dire que vous aimez l'intelligence artificielle. C'est un sujet très intéressant et qui a des applications nombreuses dans différents domaines, tels que la robotique, l'apprentissage automatique, la reconnaissance faciale et de la parole, et bien plus encore.\\nPourriez-vous me dire ce que vous voulez savoir sur l'intelligence artificielle? Je suis là pour vous aider et vous fournir des informations précises et utiles.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('ee799572-e9cb-4c32-be0d-8eca0f19aa01')), RunInfo(run_id=UUID('94372bd7-e03c-48ba-8be2-aaad5a125868'))])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "batch_messages = [\n", " [\n", diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py index 5fcf783cd8401..5c5b0b6d0d402 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -1,5 +1,6 @@ -from typing import List, Any, Optional, Dict +from typing import List, Any, Optional, Dict, Union +import torch from langchain.chat_models.base import BaseChatModel from langchain.schema import ChatResult from langchain.schema.messages import ( @@ -16,6 +17,7 @@ from langchain.pydantic_v1 import Field, root_validator from transformers.pipelines import TextGenerationPipeline +from transformers import StoppingCriteria, StoppingCriteriaList B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>", "<>" @@ -97,13 +99,44 @@ def _generate( kwargs["return_full_text"] = False kwargs["num_return_sequences"] = 1 - response = self.pipeline(prompt, **kwargs)[0]['generated_text'] + if stop: + class StoppingCriteriaSub(StoppingCriteria): + """ Subclass of StoppingCriteria to allow for custom stopping criteria """ + def __init__(self, stops: Optional[List] = None, device: Union[torch.device, str, None] = None): + super().__init__() + stops = stops or [] + if device: + self.stops = [stop.to(device) for stop in stops] + else: + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Dict) -> bool: + for stop_id in self.stops: + if (input_ids[0][-torch.numel(stop_id) :] == stop_id).all(): + return True + return False + + stopping_criteria_tokenized = [ + self.pipeline.tokenizer(stopping_criterion, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze() + for stopping_criterion in stop + ] + + stopping_criteria = StoppingCriteriaList( + [ + StoppingCriteriaSub( + stops=stopping_criteria_tokenized, device="cuda:0", + ) + ] + ) + else: + stopping_criteria = None + + + response = self.pipeline(prompt, stopping_criteria=stopping_criteria, **kwargs)[0]['generated_text'] chat_generation = ChatGeneration( message=AIMessage(content=response), ) return ChatResult(generations=[chat_generation]) - # TODO: -# try adding stopping criteria # tests for prompt generation From 36519e01a2b08ab7446ba2a5872ad190944c138d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Fri, 1 Sep 2023 14:55:33 +0200 Subject: [PATCH 07/27] format_messages_as_text test --- docs/extras/integrations/chat/llama_2.ipynb | 49 ++++++++++++++++--- .../langchain/chat_models/llama_2.py | 8 ++- .../unit_tests/chat_models/text_llama_2.py | 36 ++++++++++++++ 3 files changed, 80 insertions(+), 13 deletions(-) create mode 100644 libs/langchain/tests/unit_tests/chat_models/text_llama_2.py diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/llama_2.ipynb index f55bce94e121e..dcd485ca7a729 100644 --- a/docs/extras/integrations/chat/llama_2.ipynb +++ b/docs/extras/integrations/chat/llama_2.ipynb @@ -78,7 +78,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "0caf2986239049f094b991683930f3b9", + "model_id": "55b402029a06429fb287dc2485b4a881", "version_major": 2, "version_minor": 0 }, @@ -126,7 +126,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d9ad8b146c464fa7a43a6541e6a61cd3", + "model_id": "95a4c12cb61541ba9c843741e190c3be", "version_major": 2, "version_minor": 0 }, @@ -192,9 +192,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' Sure! Here is the translation of \"I love programming\" from English to French:\\n\\nJe adore le programming.\\n\\nI hope this helps! Let me know if you have any other questions.', additional_kwargs={}, example=False)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "messages = [\n", " SystemMessage(\n", @@ -216,9 +227,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=' Of course! The history of Artificial', additional_kwargs={}, example=False)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "messages = [\n", " SystemMessage(\n", @@ -240,9 +262,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "LLMResult(generations=[[ChatGeneration(text=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nWould you like me to translate something else for you?', generation_info=None, message=AIMessage(content=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nWould you like me to translate something else for you?', additional_kwargs={}, example=False))], [ChatGeneration(text=' \"Je suis heureux que vous aimiez l\\'intelligence artificielle.\" (I am happy that you like artificial intelligence.)', generation_info=None, message=AIMessage(content=' \"Je suis heureux que vous aimiez l\\'intelligence artificielle.\" (I am happy that you like artificial intelligence.)', additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('da8a7be7-24b3-473a-a904-7f822cb26a13')), RunInfo(run_id=UUID('ec7217bd-94e6-4871-b056-6918234a39c1'))])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "batch_messages = [\n", " [\n", diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py index 5c5b0b6d0d402..1c9e060314ead 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -47,7 +47,8 @@ def validate_environment(cls, values: Dict) -> Dict: return values - def _format_messages_as_text(self, messages: List[BaseMessage]) -> str: + @staticmethod + def format_messages_as_text(messages: List[BaseMessage]) -> str: """ Transform List of Chat Messages to text following Meta's prompt guidelines. @@ -92,7 +93,7 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - prompt = self._format_messages_as_text(messages) + prompt = self.format_messages_as_text(messages) # make sure that `return_full_text` is set to False # otherwise, pipeline will return prompt + generation @@ -137,6 +138,3 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwa message=AIMessage(content=response), ) return ChatResult(generations=[chat_generation]) - -# TODO: -# tests for prompt generation diff --git a/libs/langchain/tests/unit_tests/chat_models/text_llama_2.py b/libs/langchain/tests/unit_tests/chat_models/text_llama_2.py new file mode 100644 index 0000000000000..d6a5c62269aec --- /dev/null +++ b/libs/langchain/tests/unit_tests/chat_models/text_llama_2.py @@ -0,0 +1,36 @@ +"""Test Llama-2 Chat model.""" + +from langchain.chat_models.llama_2 import ChatLlama2 +from langchain.schema.messages import ( + AIMessage, + HumanMessage, + SystemMessage, +) + +def test_format_messages_as_text_with_system() -> None: + messages = [ + SystemMessage(content="System Prompt."), + HumanMessage(content="Human Message."), + AIMessage(content="AI response."), + HumanMessage(content="Second Human Message."), + ] + + ground_truth = "[INST] <>\nSystem Prompt.\n<>\n\nHuman Message. [/INST] AI response. [INST] Second Human Message. [/INST] " + + messages_as_str = ChatLlama2.format_messages_as_text(messages=messages) + assert messages_as_str == ground_truth + + +def test_format_messages_as_text_without_system() -> None: + messages = [ + HumanMessage(content="Human Message."), + AIMessage(content="AI response."), + HumanMessage(content="Second Human Message."), + AIMessage(content="Second AI response."), + ] + + ground_truth = "[INST] Human Message. [/INST] AI response. [INST] Second Human Message. [/INST] Second AI response. [INST] " + + messages_as_str = ChatLlama2.format_messages_as_text(messages=messages) + assert messages_as_str == ground_truth + \ No newline at end of file From 9a92e08a5b147f05a56a01dff1f48bc4901cd254 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Fri, 1 Sep 2023 15:02:36 +0200 Subject: [PATCH 08/27] formatter --- .../langchain/chat_models/__init__.py | 2 +- .../langchain/chat_models/llama_2.py | 71 +++++++++++++------ .../{text_llama_2.py => test_llama_2.py} | 2 +- 3 files changed, 50 insertions(+), 25 deletions(-) rename libs/langchain/tests/unit_tests/chat_models/{text_llama_2.py => test_llama_2.py} (99%) diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index c50fa87d1c7c1..f8bf68f877d47 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -26,12 +26,12 @@ from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.jinachat import JinaChat from langchain.chat_models.litellm import ChatLiteLLM +from langchain.chat_models.llama_2 import ChatLlama2 from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway from langchain.chat_models.ollama import ChatOllama from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.chat_models.vertexai import ChatVertexAI -from langchain.chat_models.llama_2 import ChatLlama2 __all__ = [ "ChatOpenAI", diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py index 1c9e060314ead..1df82d459bed0 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -1,7 +1,14 @@ -from typing import List, Any, Optional, Dict, Union +from typing import Any, Dict, List, Optional, Union import torch +from transformers import StoppingCriteria, StoppingCriteriaList +from transformers.pipelines import TextGenerationPipeline + +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) from langchain.chat_models.base import BaseChatModel +from langchain.pydantic_v1 import Field, root_validator from langchain.schema import ChatResult from langchain.schema.messages import ( AIMessage, @@ -11,13 +18,6 @@ SystemMessage, ) from langchain.schema.output import ChatGeneration -from langchain.callbacks.manager import ( - CallbackManagerForLLMRun, -) - -from langchain.pydantic_v1 import Field, root_validator -from transformers.pipelines import TextGenerationPipeline -from transformers import StoppingCriteria, StoppingCriteriaList B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>", "<>" @@ -33,7 +33,10 @@ def _llm_type(self) -> str: @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: - if not hasattr(values["pipeline"], "task") or values["pipeline"].task != "text-generation": + if ( + not hasattr(values["pipeline"], "task") + or values["pipeline"].task != "text-generation" + ): raise ValueError("The pipeline task should be 'text-generation'.") valid_models = ( @@ -42,9 +45,14 @@ def validate_environment(cls, values: Dict) -> Dict: "meta-llama/Llama-2-70b-chat-hf", ) - if not hasattr(values["pipeline"], "model") or values["pipeline"].model.name_or_path not in valid_models: - raise ValueError(f"The pipeline model name or path should be one of {valid_models}.") - + if ( + not hasattr(values["pipeline"], "model") + or values["pipeline"].model.name_or_path not in valid_models + ): + raise ValueError( + f"The pipeline model name or path should be one of {valid_models}." + ) + return values @staticmethod @@ -64,13 +72,15 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: Prompt template without System Message: ``` [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] - ``` + ``` """ prompt = "" for i, message in enumerate(messages): if isinstance(message, SystemMessage) and i != 0: - raise ValueError("SystemMessage can only appear as the first message in the list.") + raise ValueError( + "SystemMessage can only appear as the first message in the list." + ) elif isinstance(message, SystemMessage) and i == 0: prompt += f"{B_INST} {B_SYS}\n{message.content}\n{E_SYS}\n\n" elif isinstance(message, HumanMessage) and i > 0: @@ -83,7 +93,7 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: prompt += f"{B_INST} {message.role.capitalize()}: {message.content} {E_INST} " elif isinstance(message, ChatMessage) and i > 0: prompt += f"{message.role.capitalize()}: {message.content} {E_INST} " - + return prompt def _generate( @@ -101,9 +111,15 @@ def _generate( kwargs["num_return_sequences"] = 1 if stop: + class StoppingCriteriaSub(StoppingCriteria): - """ Subclass of StoppingCriteria to allow for custom stopping criteria """ - def __init__(self, stops: Optional[List] = None, device: Union[torch.device, str, None] = None): + """Subclass of StoppingCriteria to allow for custom stopping criteria""" + + def __init__( + self, + stops: Optional[List] = None, + device: Union[torch.device, str, None] = None, + ): super().__init__() stops = stops or [] if device: @@ -111,29 +127,38 @@ def __init__(self, stops: Optional[List] = None, device: Union[torch.device, str else: self.stops = stops - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Dict) -> bool: + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs: Dict, + ) -> bool: for stop_id in self.stops: if (input_ids[0][-torch.numel(stop_id) :] == stop_id).all(): return True return False stopping_criteria_tokenized = [ - self.pipeline.tokenizer(stopping_criterion, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze() + self.pipeline.tokenizer( + stopping_criterion, return_tensors="pt", add_special_tokens=False + )["input_ids"].squeeze() for stopping_criterion in stop ] - + stopping_criteria = StoppingCriteriaList( [ StoppingCriteriaSub( - stops=stopping_criteria_tokenized, device="cuda:0", + stops=stopping_criteria_tokenized, + device="cuda:0", ) ] ) else: stopping_criteria = None - - response = self.pipeline(prompt, stopping_criteria=stopping_criteria, **kwargs)[0]['generated_text'] + response = self.pipeline(prompt, stopping_criteria=stopping_criteria, **kwargs)[ + 0 + ]["generated_text"] chat_generation = ChatGeneration( message=AIMessage(content=response), ) diff --git a/libs/langchain/tests/unit_tests/chat_models/text_llama_2.py b/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py similarity index 99% rename from libs/langchain/tests/unit_tests/chat_models/text_llama_2.py rename to libs/langchain/tests/unit_tests/chat_models/test_llama_2.py index d6a5c62269aec..6926f88e20a36 100644 --- a/libs/langchain/tests/unit_tests/chat_models/text_llama_2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py @@ -7,6 +7,7 @@ SystemMessage, ) + def test_format_messages_as_text_with_system() -> None: messages = [ SystemMessage(content="System Prompt."), @@ -33,4 +34,3 @@ def test_format_messages_as_text_without_system() -> None: messages_as_str = ChatLlama2.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth - \ No newline at end of file From cc53251128ee1aea7a90786959f418e372965b2c Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Fri, 1 Sep 2023 15:25:38 +0200 Subject: [PATCH 09/27] fix lint issues --- libs/langchain/langchain/chat_models/llama_2.py | 11 +++++++---- .../tests/unit_tests/chat_models/test_llama_2.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/llama_2.py index 1df82d459bed0..4ac4790793c1a 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/llama_2.py @@ -8,7 +8,7 @@ CallbackManagerForLLMRun, ) from langchain.chat_models.base import BaseChatModel -from langchain.pydantic_v1 import Field, root_validator +from langchain.pydantic_v1 import root_validator from langchain.schema import ChatResult from langchain.schema.messages import ( AIMessage, @@ -66,12 +66,14 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: {{ system_prompt }} <> - {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] + {{ user_msg_1 }} [/INST] {{ model_answer_1 }} + [INST] {{ user_msg_2 }} [/INST] ``` Prompt template without System Message: ``` - [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] + [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} + [INST] {{ user_msg_2 }} [/INST] {{ model_answer_2}} ``` """ prompt = "" @@ -90,7 +92,8 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: elif isinstance(message, AIMessage): prompt += f"{message.content} {B_INST} " elif isinstance(message, ChatMessage) and i == 0: - prompt += f"{B_INST} {message.role.capitalize()}: {message.content} {E_INST} " + prompt += f"{B_INST} {message.role.capitalize()}:\ +{message.content} {E_INST} " elif isinstance(message, ChatMessage) and i > 0: prompt += f"{message.role.capitalize()}: {message.content} {E_INST} " diff --git a/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py b/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py index 6926f88e20a36..ee571d1109380 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py @@ -16,7 +16,7 @@ def test_format_messages_as_text_with_system() -> None: HumanMessage(content="Second Human Message."), ] - ground_truth = "[INST] <>\nSystem Prompt.\n<>\n\nHuman Message. [/INST] AI response. [INST] Second Human Message. [/INST] " + ground_truth = "[INST] <>\nSystem Prompt.\n<>\n\nHuman Message. [/INST] AI response. [INST] Second Human Message. [/INST] " # noqa: E501 messages_as_str = ChatLlama2.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth @@ -30,7 +30,7 @@ def test_format_messages_as_text_without_system() -> None: AIMessage(content="Second AI response."), ] - ground_truth = "[INST] Human Message. [/INST] AI response. [INST] Second Human Message. [/INST] Second AI response. [INST] " + ground_truth = "[INST] Human Message. [/INST] AI response. [INST] Second Human Message. [/INST] Second AI response. [INST] " # noqa: E501 messages_as_str = ChatLlama2.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth From 6a0cd87477e17288f22cdd0ebfce8170d19f0c23 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Fri, 1 Sep 2023 16:41:34 +0200 Subject: [PATCH 10/27] removal of redundant notebook cell --- docs/extras/integrations/chat/llama_2.ipynb | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/llama_2.ipynb index dcd485ca7a729..a386a6d16f7c7 100644 --- a/docs/extras/integrations/chat/llama_2.ipynb +++ b/docs/extras/integrations/chat/llama_2.ipynb @@ -1,24 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# TODO:\n", - "# temporary workaround:\n", - "# remove cell before PR\n", - "import os\n", - "import sys\n", - "os.getcwd()\n", - "os.chdir(\"/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/\")\n", - "os.getcwd()\n", - "\n", - "# sys.path.append('/mnt/ml-team/homes/eryk.mazus/langchain/')\n", - "sys.path.append('/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/')\n" - ] - }, { "cell_type": "markdown", "metadata": {}, From cc578b525f3d97323c3e1b08f5289af7585d53e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Mon, 4 Sep 2023 10:17:18 +0200 Subject: [PATCH 11/27] refactor: update naming to indicate Hugging Face usage --- ...llama_2.ipynb => huggingface_llama2.ipynb} | 23 +++++++------------ .../langchain/chat_models/__init__.py | 4 ++-- .../{llama_2.py => huggingface_llama2.py} | 8 +++---- ..._llama_2.py => test_huggingface_llama2.py} | 8 +++---- 4 files changed, 17 insertions(+), 26 deletions(-) rename docs/extras/integrations/chat/{llama_2.ipynb => huggingface_llama2.ipynb} (72%) rename libs/langchain/langchain/chat_models/{llama_2.py => huggingface_llama2.py} (96%) rename libs/langchain/tests/unit_tests/chat_models/{test_llama_2.py => test_huggingface_llama2.py} (80%) diff --git a/docs/extras/integrations/chat/llama_2.ipynb b/docs/extras/integrations/chat/huggingface_llama2.ipynb similarity index 72% rename from docs/extras/integrations/chat/llama_2.ipynb rename to docs/extras/integrations/chat/huggingface_llama2.ipynb index a386a6d16f7c7..5474071a9562d 100644 --- a/docs/extras/integrations/chat/llama_2.ipynb +++ b/docs/extras/integrations/chat/huggingface_llama2.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# LLama Chat" + "# Llama-2-Chat Model from Hugging Face" ] }, { @@ -29,18 +29,11 @@ " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " BitsAndBytesConfig,\n", - " HfArgumentParser,\n", " pipeline,\n", ")\n", "\n", "# LangChain imports:\n", - "from langchain.chat_models import ChatLlama2\n", - "from langchain.prompts.chat import (\n", - " ChatPromptTemplate,\n", - " SystemMessagePromptTemplate,\n", - " AIMessagePromptTemplate,\n", - " HumanMessagePromptTemplate,\n", - ")\n", + "from langchain.chat_models import ChatLlama2Hf\n", "from langchain.schema import AIMessage, HumanMessage, SystemMessage" ] }, @@ -59,7 +52,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "55b402029a06429fb287dc2485b4a881", + "model_id": "f01ae0ae007e4e6e885c8bf6b05eb813", "version_major": 2, "version_minor": 0 }, @@ -107,7 +100,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "95a4c12cb61541ba9c843741e190c3be", + "model_id": "e6b0cbeb58c942f2b320eb771787b773", "version_major": 2, "version_minor": 0 }, @@ -145,7 +138,7 @@ "metadata": {}, "outputs": [], "source": [ - "chat = ChatLlama2(pipeline=pipe)" + "chat = ChatLlama2Hf(pipeline=pipe)" ] }, { @@ -179,7 +172,7 @@ { "data": { "text/plain": [ - "AIMessage(content=' Sure! Here is the translation of \"I love programming\" from English to French:\\n\\nJe adore le programming.\\n\\nI hope this helps! Let me know if you have any other questions.', additional_kwargs={}, example=False)" + "AIMessage(content=' Sure, I\\'d be happy to help! The French translation of \"I love programming\" is \"J\\'aime le programming.\"', additional_kwargs={}, example=False)" ] }, "execution_count": 11, @@ -214,7 +207,7 @@ { "data": { "text/plain": [ - "AIMessage(content=' Of course! The history of Artificial', additional_kwargs={}, example=False)" + "AIMessage(content=\" Of course, I'd be happy to help! The history of Artificial\", additional_kwargs={}, example=False)" ] }, "execution_count": 12, @@ -249,7 +242,7 @@ { "data": { "text/plain": [ - "LLMResult(generations=[[ChatGeneration(text=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nWould you like me to translate something else for you?', generation_info=None, message=AIMessage(content=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nWould you like me to translate something else for you?', additional_kwargs={}, example=False))], [ChatGeneration(text=' \"Je suis heureux que vous aimiez l\\'intelligence artificielle.\" (I am happy that you like artificial intelligence.)', generation_info=None, message=AIMessage(content=' \"Je suis heureux que vous aimiez l\\'intelligence artificielle.\" (I am happy that you like artificial intelligence.)', additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('da8a7be7-24b3-473a-a904-7f822cb26a13')), RunInfo(run_id=UUID('ec7217bd-94e6-4871-b056-6918234a39c1'))])" + "LLMResult(generations=[[ChatGeneration(text=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nSo, you love programmation? (You love programming)', generation_info=None, message=AIMessage(content=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nSo, you love programmation? (You love programming)', additional_kwargs={}, example=False))], [ChatGeneration(text=\" Bonjour! Je suis heureux d'être votre assistante de traduction pour les phrases en anglais et en français.\\nMerci de me dire que vous aimez l'intelligence artificielle. C'est un sujet très intéressant et en constante évolution. Les avancées dans le domaine de l'IA ont des applications nombreuses dans différents secteurs, tels que la robotique, les systems embarqués, les réseaux de communication, les applications mobiles, les sistèmes de recommendation, les centres de données, etc.\\nVous pouvez me poser des questions ou des questions sur l'IA, et je ferai de mon mieux pour vous fournir des informations précises et utiles.\", generation_info=None, message=AIMessage(content=\" Bonjour! Je suis heureux d'être votre assistante de traduction pour les phrases en anglais et en français.\\nMerci de me dire que vous aimez l'intelligence artificielle. C'est un sujet très intéressant et en constante évolution. Les avancées dans le domaine de l'IA ont des applications nombreuses dans différents secteurs, tels que la robotique, les systems embarqués, les réseaux de communication, les applications mobiles, les sistèmes de recommendation, les centres de données, etc.\\nVous pouvez me poser des questions ou des questions sur l'IA, et je ferai de mon mieux pour vous fournir des informations précises et utiles.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('92316e4b-4ced-4e30-b015-9a35ea990476')), RunInfo(run_id=UUID('f91df63e-c604-4bd2-9da2-8fe3a3112981'))])" ] }, "execution_count": 13, diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index f8bf68f877d47..a3d56fe71a2d2 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -26,7 +26,7 @@ from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.jinachat import JinaChat from langchain.chat_models.litellm import ChatLiteLLM -from langchain.chat_models.llama_2 import ChatLlama2 +from langchain.chat_models.huggingface_llama2 import ChatLlama2Hf from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway from langchain.chat_models.ollama import ChatOllama from langchain.chat_models.openai import ChatOpenAI @@ -42,7 +42,7 @@ "ChatGooglePalm", "ChatMLflowAIGateway", "ChatOllama", - "ChatLlama2", + "ChatLlama2Hf", "ChatVertexAI", "JinaChat", "HumanInputChatModel", diff --git a/libs/langchain/langchain/chat_models/llama_2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py similarity index 96% rename from libs/langchain/langchain/chat_models/llama_2.py rename to libs/langchain/langchain/chat_models/huggingface_llama2.py index 4ac4790793c1a..0db8414ba25a7 100644 --- a/libs/langchain/langchain/chat_models/llama_2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -23,7 +23,7 @@ B_SYS, E_SYS = "<>", "<>" -class ChatLlama2(BaseChatModel): +class ChatLlama2Hf(BaseChatModel): pipeline: TextGenerationPipeline @property @@ -66,14 +66,12 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: {{ system_prompt }} <> - {{ user_msg_1 }} [/INST] {{ model_answer_1 }} - [INST] {{ user_msg_2 }} [/INST] + {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] ``` Prompt template without System Message: ``` - [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} - [INST] {{ user_msg_2 }} [/INST] {{ model_answer_2}} + [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] {{ model_answer_2}} ``` """ prompt = "" diff --git a/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py similarity index 80% rename from libs/langchain/tests/unit_tests/chat_models/test_llama_2.py rename to libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py index ee571d1109380..90fe828b13ffc 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_llama_2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py @@ -1,6 +1,6 @@ -"""Test Llama-2 Chat model.""" +"""Test Hugging Face Llama-2 Chat model.""" -from langchain.chat_models.llama_2 import ChatLlama2 +from langchain.chat_models.huggingface_llama2 import ChatLlama2Hf from langchain.schema.messages import ( AIMessage, HumanMessage, @@ -18,7 +18,7 @@ def test_format_messages_as_text_with_system() -> None: ground_truth = "[INST] <>\nSystem Prompt.\n<>\n\nHuman Message. [/INST] AI response. [INST] Second Human Message. [/INST] " # noqa: E501 - messages_as_str = ChatLlama2.format_messages_as_text(messages=messages) + messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth @@ -32,5 +32,5 @@ def test_format_messages_as_text_without_system() -> None: ground_truth = "[INST] Human Message. [/INST] AI response. [INST] Second Human Message. [/INST] Second AI response. [INST] " # noqa: E501 - messages_as_str = ChatLlama2.format_messages_as_text(messages=messages) + messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth From cb33d483d662e22ef85c43c5f18f58a8c5fc3f0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Mon, 4 Sep 2023 13:38:11 +0200 Subject: [PATCH 12/27] small refactor --- docs/extras/integrations/chat/huggingface_llama2.ipynb | 10 +++++----- .../langchain/chat_models/huggingface_llama2.py | 10 +++------- .../unit_tests/chat_models/test_huggingface_llama2.py | 3 ++- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_llama2.ipynb b/docs/extras/integrations/chat/huggingface_llama2.ipynb index 5474071a9562d..22d883605d5c4 100644 --- a/docs/extras/integrations/chat/huggingface_llama2.ipynb +++ b/docs/extras/integrations/chat/huggingface_llama2.ipynb @@ -52,7 +52,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "f01ae0ae007e4e6e885c8bf6b05eb813", + "model_id": "fdfe7c3faf0c40d0bac0fd22fd9ebd38", "version_major": 2, "version_minor": 0 }, @@ -100,7 +100,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e6b0cbeb58c942f2b320eb771787b773", + "model_id": "4c3b0838de6140019682c5e5d17f1f37", "version_major": 2, "version_minor": 0 }, @@ -172,7 +172,7 @@ { "data": { "text/plain": [ - "AIMessage(content=' Sure, I\\'d be happy to help! The French translation of \"I love programming\" is \"J\\'aime le programming.\"', additional_kwargs={}, example=False)" + "AIMessage(content=' Sure! Here\\'s the translation of \"I love programming\" from English to French:\\nJe adore le programming.\\n\\nI hope that helps! Let me know if you have any other sentences you\\'d like me to translate.', additional_kwargs={}, example=False)" ] }, "execution_count": 11, @@ -207,7 +207,7 @@ { "data": { "text/plain": [ - "AIMessage(content=\" Of course, I'd be happy to help! The history of Artificial\", additional_kwargs={}, example=False)" + "AIMessage(content=\" Of course, I'd be happy to help! Artificial\", additional_kwargs={}, example=False)" ] }, "execution_count": 12, @@ -242,7 +242,7 @@ { "data": { "text/plain": [ - "LLMResult(generations=[[ChatGeneration(text=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nSo, you love programmation? (You love programming)', generation_info=None, message=AIMessage(content=' Great! \"Programmation\" is the French word for \"programming\".\\n\\nSo, you love programmation? (You love programming)', additional_kwargs={}, example=False))], [ChatGeneration(text=\" Bonjour! Je suis heureux d'être votre assistante de traduction pour les phrases en anglais et en français.\\nMerci de me dire que vous aimez l'intelligence artificielle. C'est un sujet très intéressant et en constante évolution. Les avancées dans le domaine de l'IA ont des applications nombreuses dans différents secteurs, tels que la robotique, les systems embarqués, les réseaux de communication, les applications mobiles, les sistèmes de recommendation, les centres de données, etc.\\nVous pouvez me poser des questions ou des questions sur l'IA, et je ferai de mon mieux pour vous fournir des informations précises et utiles.\", generation_info=None, message=AIMessage(content=\" Bonjour! Je suis heureux d'être votre assistante de traduction pour les phrases en anglais et en français.\\nMerci de me dire que vous aimez l'intelligence artificielle. C'est un sujet très intéressant et en constante évolution. Les avancées dans le domaine de l'IA ont des applications nombreuses dans différents secteurs, tels que la robotique, les systems embarqués, les réseaux de communication, les applications mobiles, les sistèmes de recommendation, les centres de données, etc.\\nVous pouvez me poser des questions ou des questions sur l'IA, et je ferai de mon mieux pour vous fournir des informations précises et utiles.\", additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('92316e4b-4ced-4e30-b015-9a35ea990476')), RunInfo(run_id=UUID('f91df63e-c604-4bd2-9da2-8fe3a3112981'))])" + "LLMResult(generations=[[ChatGeneration(text=' Great! \"Programming\" in English can be translated to \"le programming\" in French.\\n\\nSo, you love programming? \"Aimez-vous le programming\" in French.', generation_info=None, message=AIMessage(content=' Great! \"Programming\" in English can be translated to \"le programming\" in French.\\n\\nSo, you love programming? \"Aimez-vous le programming\" in French.', additional_kwargs={}, example=False))], [ChatGeneration(text=' Bonjour! Je suis heureux de vous aider avec la translation de votre phrase en français.\\n\\nVous aimez l\\'intelligence artificielle.\\n\\n(Note: I used the phrase \"Bonjour!\" to greet you in French, as it is a common way to start a conversation in France. \"Je suis heureux de vous aider\" means \"I am happy to help you\" in French.)', generation_info=None, message=AIMessage(content=' Bonjour! Je suis heureux de vous aider avec la translation de votre phrase en français.\\n\\nVous aimez l\\'intelligence artificielle.\\n\\n(Note: I used the phrase \"Bonjour!\" to greet you in French, as it is a common way to start a conversation in France. \"Je suis heureux de vous aider\" means \"I am happy to help you\" in French.)', additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('844a9416-e941-478b-ac61-82b5c7c15fb7')), RunInfo(run_id=UUID('57636053-2e2c-41a4-89ad-7ef70d38fe12'))])" ] }, "execution_count": 13, diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 0db8414ba25a7..0c8be7e2ca9c3 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -13,7 +13,6 @@ from langchain.schema.messages import ( AIMessage, BaseMessage, - ChatMessage, HumanMessage, SystemMessage, ) @@ -71,7 +70,7 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: Prompt template without System Message: ``` - [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] {{ model_answer_2}} + [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] ``` """ prompt = "" @@ -89,11 +88,8 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: prompt += f"{B_INST} {message.content} {E_INST} " elif isinstance(message, AIMessage): prompt += f"{message.content} {B_INST} " - elif isinstance(message, ChatMessage) and i == 0: - prompt += f"{B_INST} {message.role.capitalize()}:\ -{message.content} {E_INST} " - elif isinstance(message, ChatMessage) and i > 0: - prompt += f"{message.role.capitalize()}: {message.content} {E_INST} " + else: + raise ValueError(f"Unsupported Message type:") return prompt diff --git a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py index 90fe828b13ffc..681378e2cbadf 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py @@ -14,9 +14,10 @@ def test_format_messages_as_text_with_system() -> None: HumanMessage(content="Human Message."), AIMessage(content="AI response."), HumanMessage(content="Second Human Message."), + AIMessage(content="AI response."), ] - ground_truth = "[INST] <>\nSystem Prompt.\n<>\n\nHuman Message. [/INST] AI response. [INST] Second Human Message. [/INST] " # noqa: E501 + ground_truth = "[INST] <>\nSystem Prompt.\n<>\n\nHuman Message. [/INST] AI response. [INST] Second Human Message. [/INST] AI response. [INST] " # noqa: E501 messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth From 440571d58c5c541819e4f0ce4fc4684769ea5606 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Mon, 4 Sep 2023 13:49:18 +0200 Subject: [PATCH 13/27] fix lint errors, running formatter --- libs/langchain/langchain/chat_models/__init__.py | 2 +- libs/langchain/langchain/chat_models/huggingface_llama2.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index a3d56fe71a2d2..44e5d15dd76de 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -23,10 +23,10 @@ from langchain.chat_models.ernie import ErnieBotChat from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.google_palm import ChatGooglePalm +from langchain.chat_models.huggingface_llama2 import ChatLlama2Hf from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.jinachat import JinaChat from langchain.chat_models.litellm import ChatLiteLLM -from langchain.chat_models.huggingface_llama2 import ChatLlama2Hf from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway from langchain.chat_models.ollama import ChatOllama from langchain.chat_models.openai import ChatOpenAI diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 0c8be7e2ca9c3..59b930c447cdc 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -65,12 +65,12 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: {{ system_prompt }} <> - {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] + {{ user_msg_1 }} [/INST] {{ model_answer_1 }} ``` Prompt template without System Message: ``` - [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} [INST] {{ user_msg_2 }} [/INST] + [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} ``` """ prompt = "" @@ -89,7 +89,7 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: elif isinstance(message, AIMessage): prompt += f"{message.content} {B_INST} " else: - raise ValueError(f"Unsupported Message type:") + raise ValueError(f"Unsupported Message type: {type(message)}") return prompt From f860326c532dea1555450acbf39fe29218b33290 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Mon, 4 Sep 2023 15:00:30 +0200 Subject: [PATCH 14/27] moving stopping criteria class out of function, correct typing --- .../chat_models/huggingface_llama2.py | 54 +++++++++---------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 59b930c447cdc..fae5028bdee55 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -22,6 +22,33 @@ B_SYS, E_SYS = "<>", "<>" +class StoppingCriteriaSub(StoppingCriteria): + """Subclass of StoppingCriteria to allow for custom stopping criteria""" + + def __init__( + self, + stops: Optional[List[torch.Tensor]] = None, + device: Union[torch.device, str, None] = None, + ): + super().__init__() + stops = stops or [] + if device: + self.stops = [stop.to(device) for stop in stops] + else: + self.stops = stops + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs: Dict, + ) -> bool: + for stop_id in self.stops: + if (input_ids[0][-torch.numel(stop_id) :] == stop_id).all(): + return True + return False + + class ChatLlama2Hf(BaseChatModel): pipeline: TextGenerationPipeline @@ -108,33 +135,6 @@ def _generate( kwargs["num_return_sequences"] = 1 if stop: - - class StoppingCriteriaSub(StoppingCriteria): - """Subclass of StoppingCriteria to allow for custom stopping criteria""" - - def __init__( - self, - stops: Optional[List] = None, - device: Union[torch.device, str, None] = None, - ): - super().__init__() - stops = stops or [] - if device: - self.stops = [stop.to(device) for stop in stops] - else: - self.stops = stops - - def __call__( - self, - input_ids: torch.LongTensor, - scores: torch.FloatTensor, - **kwargs: Dict, - ) -> bool: - for stop_id in self.stops: - if (input_ids[0][-torch.numel(stop_id) :] == stop_id).all(): - return True - return False - stopping_criteria_tokenized = [ self.pipeline.tokenizer( stopping_criterion, return_tensors="pt", add_special_tokens=False From 9b6e79d5e6c5992a0df41514d076f3bbe00d63f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Mon, 4 Sep 2023 17:21:05 +0200 Subject: [PATCH 15/27] code review suggestions --- .../chat/huggingface_llama2.ipynb | 114 +++++++++--------- .../chat_models/huggingface_llama2.py | 55 +++++---- .../chat_models/test_huggingface_llama2.py | 41 ++++++- 3 files changed, 124 insertions(+), 86 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_llama2.ipynb b/docs/extras/integrations/chat/huggingface_llama2.ipynb index 22d883605d5c4..71b6fde1a9565 100644 --- a/docs/extras/integrations/chat/huggingface_llama2.ipynb +++ b/docs/extras/integrations/chat/huggingface_llama2.ipynb @@ -9,17 +9,7 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 3, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -46,13 +36,13 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fdfe7c3faf0c40d0bac0fd22fd9ebd38", + "model_id": "c054dde5caa04223ab6dfa3588dc3418", "version_major": 2, "version_minor": 0 }, @@ -71,7 +61,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -80,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -94,13 +84,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4c3b0838de6140019682c5e5d17f1f37", + "model_id": "a6d5f47650404075b9d885fba799cfe6", "version_major": 2, "version_minor": 0 }, @@ -119,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -134,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -143,7 +133,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -166,18 +156,18 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 39, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "AIMessage(content=' Sure! Here\\'s the translation of \"I love programming\" from English to French:\\nJe adore le programming.\\n\\nI hope that helps! Let me know if you have any other sentences you\\'d like me to translate.', additional_kwargs={}, example=False)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + " Sure, I'd be happy to help! Here is the translation of \"I love programming\" from English to French:\n", + "Je aime le programming.\n", + "\n", + "I hope this helps! Let me know if you have any other questions.\n" + ] } ], "source": [ @@ -189,7 +179,8 @@ " content=\"Translate this sentence from English to French. I love programming.\"\n", " ),\n", "]\n", - "chat(messages, **pipeline_kwargs)" + "result = chat(messages, **pipeline_kwargs)\n", + "print(result.content)" ] }, { @@ -201,18 +192,15 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 40, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "AIMessage(content=\" Of course, I'd be happy to help! Artificial\", additional_kwargs={}, example=False)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + " Of course! Artificial\n" + ] } ], "source": [ @@ -224,7 +212,8 @@ " content=\"Tell me the history of AI.\"\n", " ),\n", "]\n", - "chat(messages, stop=[\"Artificial\"], **pipeline_kwargs)" + "result = chat(messages, stop=[\"Artificial\", \"Inteligence\"], **pipeline_kwargs)\n", + "print(result.content)" ] }, { @@ -236,20 +225,9 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 41, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "LLMResult(generations=[[ChatGeneration(text=' Great! \"Programming\" in English can be translated to \"le programming\" in French.\\n\\nSo, you love programming? \"Aimez-vous le programming\" in French.', generation_info=None, message=AIMessage(content=' Great! \"Programming\" in English can be translated to \"le programming\" in French.\\n\\nSo, you love programming? \"Aimez-vous le programming\" in French.', additional_kwargs={}, example=False))], [ChatGeneration(text=' Bonjour! Je suis heureux de vous aider avec la translation de votre phrase en français.\\n\\nVous aimez l\\'intelligence artificielle.\\n\\n(Note: I used the phrase \"Bonjour!\" to greet you in French, as it is a common way to start a conversation in France. \"Je suis heureux de vous aider\" means \"I am happy to help you\" in French.)', generation_info=None, message=AIMessage(content=' Bonjour! Je suis heureux de vous aider avec la translation de votre phrase en français.\\n\\nVous aimez l\\'intelligence artificielle.\\n\\n(Note: I used the phrase \"Bonjour!\" to greet you in French, as it is a common way to start a conversation in France. \"Je suis heureux de vous aider\" means \"I am happy to help you\" in French.)', additional_kwargs={}, example=False))]], llm_output={}, run=[RunInfo(run_id=UUID('844a9416-e941-478b-ac61-82b5c7c15fb7')), RunInfo(run_id=UUID('57636053-2e2c-41a4-89ad-7ef70d38fe12'))])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "batch_messages = [\n", " [\n", @@ -261,8 +239,34 @@ " HumanMessage(content=\"I love artificial intelligence.\")\n", " ],\n", "]\n", - "result = chat.generate(batch_messages)\n", - "result" + "result = chat.generate(batch_messages)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Response #0:\n", + " Great! \"Programmation\" is the French word for \"programming\". So, \"Je adore le programming.\" (I love programming.)\n", + "\n", + "Response #1:\n", + " Bonjour! Je suis heureux d'être votre assistant de traduction pour l'anglais à français.\n", + "\n", + "You said: \"I love artificial intelligence.\"\n", + "\n", + "In French: \"Je suis ravi d'artificial intelligence.\" (Note: \"artificial\" should be pronounced \"artificial\" in French, not \"ar-ti-fi-cial\")\n", + "\n" + ] + } + ], + "source": [ + "for i, generation in enumerate(result.generations):\n", + " print(f\"Response #{i}:\\n{generation[0].text}\", end=\"\\n\\n\")" ] } ], diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index fae5028bdee55..93bfb45215f9f 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -1,6 +1,13 @@ from typing import Any, Dict, List, Optional, Union -import torch +try: + import torch +except ImportError: + raise ImportError( + "torch package not found, please install it with " "`pip install torch`" + ) + +from enum import Enum from transformers import StoppingCriteria, StoppingCriteriaList from transformers.pipelines import TextGenerationPipeline @@ -18,8 +25,21 @@ ) from langchain.schema.output import ChatGeneration -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>", "<>" + +class InstructionTokens(Enum): + def __str__(self) -> str: + return self.value + + B_INST = "[INST]" + E_INST = "[/INST]" + + +class SystemTokens(Enum): + def __str__(self) -> str: + return self.value + + B_SYS = "<>" + E_SYS = "<>" class StoppingCriteriaSub(StoppingCriteria): @@ -44,7 +64,7 @@ def __call__( **kwargs: Dict, ) -> bool: for stop_id in self.stops: - if (input_ids[0][-torch.numel(stop_id) :] == stop_id).all(): + if (input_ids[0][-len(stop_id) :] == stop_id).all(): return True return False @@ -65,20 +85,6 @@ def validate_environment(cls, values: Dict) -> Dict: ): raise ValueError("The pipeline task should be 'text-generation'.") - valid_models = ( - "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-2-13b-chat-hf", - "meta-llama/Llama-2-70b-chat-hf", - ) - - if ( - not hasattr(values["pipeline"], "model") - or values["pipeline"].model.name_or_path not in valid_models - ): - raise ValueError( - f"The pipeline model name or path should be one of {valid_models}." - ) - return values @staticmethod @@ -108,13 +114,13 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: "SystemMessage can only appear as the first message in the list." ) elif isinstance(message, SystemMessage) and i == 0: - prompt += f"{B_INST} {B_SYS}\n{message.content}\n{E_SYS}\n\n" + prompt += f"{InstructionTokens.B_INST} {SystemTokens.B_SYS}\n{message.content}\n{SystemTokens.E_SYS}\n\n" elif isinstance(message, HumanMessage) and i > 0: - prompt += f"{message.content} {E_INST} " + prompt += f"{message.content} {InstructionTokens.E_INST} " elif isinstance(message, HumanMessage) and i == 0: - prompt += f"{B_INST} {message.content} {E_INST} " + prompt += f"{InstructionTokens.B_INST} {message.content} {InstructionTokens.E_INST} " elif isinstance(message, AIMessage): - prompt += f"{message.content} {B_INST} " + prompt += f"{message.content} {InstructionTokens.B_INST} " else: raise ValueError(f"Unsupported Message type: {type(message)}") @@ -153,9 +159,8 @@ def _generate( else: stopping_criteria = None - response = self.pipeline(prompt, stopping_criteria=stopping_criteria, **kwargs)[ - 0 - ]["generated_text"] + response = self.pipeline(prompt, stopping_criteria=stopping_criteria, **kwargs) + response = response[0]["generated_text"] chat_generation = ChatGeneration( message=AIMessage(content=response), ) diff --git a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py index 681378e2cbadf..606588c60d279 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py @@ -1,13 +1,16 @@ """Test Hugging Face Llama-2 Chat model.""" -from langchain.chat_models.huggingface_llama2 import ChatLlama2Hf +from langchain.chat_models.huggingface_llama2 import ( + InstructionTokens, + SystemTokens, + ChatLlama2Hf +) from langchain.schema.messages import ( AIMessage, HumanMessage, SystemMessage, ) - def test_format_messages_as_text_with_system() -> None: messages = [ SystemMessage(content="System Prompt."), @@ -17,10 +20,23 @@ def test_format_messages_as_text_with_system() -> None: AIMessage(content="AI response."), ] - ground_truth = "[INST] <>\nSystem Prompt.\n<>\n\nHuman Message. [/INST] AI response. [INST] Second Human Message. [/INST] AI response. [INST] " # noqa: E501 + assert InstructionTokens.B_INST == "[INST]" + assert InstructionTokens.E_INST == "[/INST]" + assert SystemTokens.B_SYS == "<>" + assert SystemTokens.E_SYS == "<>" + + ground_truth = ( + "[INST] <>\nSystem Prompt.\n<>\n\n" + "Human Message. [/INST] AI response. " + "[INST] Second Human Message. [/INST] " + "AI response. [INST] " + ) messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) - assert messages_as_str == ground_truth + assert messages_as_str == ground_truth, ( + f"Prediction:\n```{messages_as_str}\n" + "```\nExpected:\n```{ground_truth}\n```" + ) def test_format_messages_as_text_without_system() -> None: @@ -31,7 +47,20 @@ def test_format_messages_as_text_without_system() -> None: AIMessage(content="Second AI response."), ] - ground_truth = "[INST] Human Message. [/INST] AI response. [INST] Second Human Message. [/INST] Second AI response. [INST] " # noqa: E501 + assert InstructionTokens.B_INST == "[INST]" + assert InstructionTokens.E_INST == "[/INST]" + assert SystemTokens.B_SYS == "<>" + assert SystemTokens.E_SYS == "<>" + + ground_truth = ( + "[INST] Human Message. [/INST] " + "AI response. [INST] " + "Second Human Message. [/INST] " + "Second AI response. [INST] " + ) messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) - assert messages_as_str == ground_truth + assert messages_as_str == ground_truth, ( + f"Prediction:\n```{messages_as_str}\n" + "```\nExpected:\n```{ground_truth}\n```" + ) From 900d55c09a8579871be5f3b84502ab244e255ffa Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Mon, 4 Sep 2023 17:25:20 +0200 Subject: [PATCH 16/27] run formatter and lint --- .../langchain/chat_models/huggingface_llama2.py | 12 ++++++++++-- .../chat_models/test_huggingface_llama2.py | 9 ++++----- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 93bfb45215f9f..28d5379875755 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -8,6 +8,7 @@ ) from enum import Enum + from transformers import StoppingCriteria, StoppingCriteriaList from transformers.pipelines import TextGenerationPipeline @@ -114,11 +115,18 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: "SystemMessage can only appear as the first message in the list." ) elif isinstance(message, SystemMessage) and i == 0: - prompt += f"{InstructionTokens.B_INST} {SystemTokens.B_SYS}\n{message.content}\n{SystemTokens.E_SYS}\n\n" + prompt += ( + f"{InstructionTokens.B_INST} " + f"{SystemTokens.B_SYS}\n{message.content}\n" + f"{SystemTokens.E_SYS}\n\n" + ) elif isinstance(message, HumanMessage) and i > 0: prompt += f"{message.content} {InstructionTokens.E_INST} " elif isinstance(message, HumanMessage) and i == 0: - prompt += f"{InstructionTokens.B_INST} {message.content} {InstructionTokens.E_INST} " + prompt += ( + f"{InstructionTokens.B_INST} " + f"{message.content} {InstructionTokens.E_INST} " + ) elif isinstance(message, AIMessage): prompt += f"{message.content} {InstructionTokens.B_INST} " else: diff --git a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py index 606588c60d279..42ea7e03280a0 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py @@ -1,9 +1,9 @@ """Test Hugging Face Llama-2 Chat model.""" from langchain.chat_models.huggingface_llama2 import ( + ChatLlama2Hf, InstructionTokens, SystemTokens, - ChatLlama2Hf ) from langchain.schema.messages import ( AIMessage, @@ -11,6 +11,7 @@ SystemMessage, ) + def test_format_messages_as_text_with_system() -> None: messages = [ SystemMessage(content="System Prompt."), @@ -34,8 +35,7 @@ def test_format_messages_as_text_with_system() -> None: messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth, ( - f"Prediction:\n```{messages_as_str}\n" - "```\nExpected:\n```{ground_truth}\n```" + f"Prediction:\n```{messages_as_str}\n" "```\nExpected:\n```{ground_truth}\n```" ) @@ -61,6 +61,5 @@ def test_format_messages_as_text_without_system() -> None: messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth, ( - f"Prediction:\n```{messages_as_str}\n" - "```\nExpected:\n```{ground_truth}\n```" + f"Prediction:\n```{messages_as_str}\n" "```\nExpected:\n```{ground_truth}\n```" ) From 3b9c03089eceb6a19a3503862f562975f76cbb09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Tue, 5 Sep 2023 10:26:07 +0200 Subject: [PATCH 17/27] StoppingCriteria are correctly placed on the same device as pipeline --- .../chat/huggingface_llama2.ipynb | 58 ++++++---- .../chat_models/huggingface_llama2.py | 100 ++++++++++-------- .../chat_models/test_huggingface_llama2.py | 9 +- 3 files changed, 97 insertions(+), 70 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_llama2.ipynb b/docs/extras/integrations/chat/huggingface_llama2.ipynb index 71b6fde1a9565..16bb7c35e3be6 100644 --- a/docs/extras/integrations/chat/huggingface_llama2.ipynb +++ b/docs/extras/integrations/chat/huggingface_llama2.ipynb @@ -1,5 +1,19 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "os.getcwd()\n", + "os.chdir(\"/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/\")\n", + "os.getcwd()\n", + "sys.path.append('/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/')" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -9,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -36,13 +50,13 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c054dde5caa04223ab6dfa3588dc3418", + "model_id": "4b4b6e0544c94c248f85fbc6103efbbb", "version_major": 2, "version_minor": 0 }, @@ -61,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -70,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -84,13 +98,13 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a6d5f47650404075b9d885fba799cfe6", + "model_id": "5d912b06f1954ab287ea0f73828422a6", "version_major": 2, "version_minor": 0 }, @@ -109,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -124,7 +138,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -156,17 +170,18 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " Sure, I'd be happy to help! Here is the translation of \"I love programming\" from English to French:\n", - "Je aime le programming.\n", + " Sure! Here is the translation of \"I love programming\" from English to French:\n", + "\n", + "Je suis passionné par le programming.\n", "\n", - "I hope this helps! Let me know if you have any other questions.\n" + "I hope this helps! Let me know if you have any other sentences you would like me to translate.\n" ] } ], @@ -192,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 24, "metadata": {}, "outputs": [ { @@ -225,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -244,7 +259,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -252,14 +267,11 @@ "output_type": "stream", "text": [ "Response #0:\n", - " Great! \"Programmation\" is the French word for \"programming\". So, \"Je adore le programming.\" (I love programming.)\n", + " Great! \"Programming\" is translated to French as \"programmation\". So, \"I love programmation\" would be the correct way to express your sentiment in French.\n", "\n", "Response #1:\n", - " Bonjour! Je suis heureux d'être votre assistant de traduction pour l'anglais à français.\n", - "\n", - "You said: \"I love artificial intelligence.\"\n", - "\n", - "In French: \"Je suis ravi d'artificial intelligence.\" (Note: \"artificial\" should be pronounced \"artificial\" in French, not \"ar-ti-fi-cial\")\n", + " Bonjour! Je suis heureux d'être votre assistant de traduction. Vous aimez l'intelligence artificielle, n'est-ce pas? (You love artificial intelligence, don't you?)\n", + "In French, the phrase \"intelligence artificielle\" can be translated to \"artificial intelligence\" in English.\n", "\n" ] } diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 93bfb45215f9f..2b10410944010 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -1,11 +1,5 @@ -from typing import Any, Dict, List, Optional, Union - -try: - import torch -except ImportError: - raise ImportError( - "torch package not found, please install it with " "`pip install torch`" - ) +import importlib.util +from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING from enum import Enum from transformers import StoppingCriteria, StoppingCriteriaList @@ -25,6 +19,9 @@ ) from langchain.schema.output import ChatGeneration +if TYPE_CHECKING: + import torch + class InstructionTokens(Enum): def __str__(self) -> str: @@ -42,33 +39,6 @@ def __str__(self) -> str: E_SYS = "<>" -class StoppingCriteriaSub(StoppingCriteria): - """Subclass of StoppingCriteria to allow for custom stopping criteria""" - - def __init__( - self, - stops: Optional[List[torch.Tensor]] = None, - device: Union[torch.device, str, None] = None, - ): - super().__init__() - stops = stops or [] - if device: - self.stops = [stop.to(device) for stop in stops] - else: - self.stops = stops - - def __call__( - self, - input_ids: torch.LongTensor, - scores: torch.FloatTensor, - **kwargs: Dict, - ) -> bool: - for stop_id in self.stops: - if (input_ids[0][-len(stop_id) :] == stop_id).all(): - return True - return False - - class ChatLlama2Hf(BaseChatModel): pipeline: TextGenerationPipeline @@ -140,21 +110,67 @@ def _generate( kwargs["return_full_text"] = False kwargs["num_return_sequences"] = 1 + if importlib.util.find_spec("torch") is not None: + import torch + + device = self.pipeline.device.type + if device == "cuda": + # in the multi-gpu case, stopping criteria tokens + # need to be on the same device: + device = f"{device}:{self.pipeline.device.index}" + + class StoppingCriteriaSub(StoppingCriteria): + """ + A subclass of StoppingCriteria, used for defining custom stopping criteria + for the generation process, apart from the standard End Of Sentence (EOS) + token generation. + + This class allows for generation to be halted based on a list of specified + token sequences, which might signify the end of a meaningful segment + or passage within the generated text. + """ + + def __init__( + self, + stops: Optional[List[torch.Tensor]] = None, + device: Union[torch.device, str, None] = None, + ): + """ + Args: + stops: A list of tensor sequences with individual, tokenized stopping words. + device: The device (e.g., 'cpu', 'cuda', 'cuda:0') on which to keep the + stopping words tokens + """ + super().__init__() + stops = stops or [] + if device: + self.stops = [stop.to(device) for stop in stops] + else: + self.stops = stops + + def __call__( + self, + input_ids: torch.LongTensor, + scores: torch.FloatTensor, + **kwargs: Dict, + ) -> bool: + for stop_id in self.stops: + if (input_ids[0][-len(stop_id) :] == stop_id).all(): + return True + return False + if stop: stopping_criteria_tokenized = [ self.pipeline.tokenizer( stopping_criterion, return_tensors="pt", add_special_tokens=False - )["input_ids"].squeeze() + )["input_ids"] + .squeeze() + .to(device) for stopping_criterion in stop ] stopping_criteria = StoppingCriteriaList( - [ - StoppingCriteriaSub( - stops=stopping_criteria_tokenized, - device="cuda:0", - ) - ] + [StoppingCriteriaSub(stops=stopping_criteria_tokenized, device=device)] ) else: stopping_criteria = None diff --git a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py index 606588c60d279..02dc631a4a3e8 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py @@ -3,7 +3,7 @@ from langchain.chat_models.huggingface_llama2 import ( InstructionTokens, SystemTokens, - ChatLlama2Hf + ChatLlama2Hf, ) from langchain.schema.messages import ( AIMessage, @@ -11,6 +11,7 @@ SystemMessage, ) + def test_format_messages_as_text_with_system() -> None: messages = [ SystemMessage(content="System Prompt."), @@ -34,8 +35,7 @@ def test_format_messages_as_text_with_system() -> None: messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth, ( - f"Prediction:\n```{messages_as_str}\n" - "```\nExpected:\n```{ground_truth}\n```" + f"Prediction:\n```{messages_as_str}\n" "```\nExpected:\n```{ground_truth}\n```" ) @@ -61,6 +61,5 @@ def test_format_messages_as_text_without_system() -> None: messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) assert messages_as_str == ground_truth, ( - f"Prediction:\n```{messages_as_str}\n" - "```\nExpected:\n```{ground_truth}\n```" + f"Prediction:\n```{messages_as_str}\n" "```\nExpected:\n```{ground_truth}\n```" ) From 3885de5dff84aedc8ad9c284db76711152ac55e6 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Tue, 5 Sep 2023 10:37:54 +0200 Subject: [PATCH 18/27] run formatter, lint --- .../langchain/chat_models/huggingface_llama2.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 3e72916561304..9f5b7b3ecec78 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -1,7 +1,6 @@ import importlib.util -from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING - from enum import Enum +from typing import Any, Dict, List, Optional, Union from transformers import StoppingCriteria, StoppingCriteriaList from transformers.pipelines import TextGenerationPipeline @@ -20,9 +19,6 @@ ) from langchain.schema.output import ChatGeneration -if TYPE_CHECKING: - import torch - class InstructionTokens(Enum): def __str__(self) -> str: @@ -145,9 +141,10 @@ def __init__( ): """ Args: - stops: A list of tensor sequences with individual, tokenized stopping words. - device: The device (e.g., 'cpu', 'cuda', 'cuda:0') on which to keep the - stopping words tokens + stops: A list of tensor sequences with individual, + tokenized stopping words. + device: The device (e.g., 'cpu', 'cuda', 'cuda:0') + on which to keep the stopping words tokens """ super().__init__() stops = stops or [] From 814131d79becf129a0c6e19b44fbff9e71d7c0a6 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Tue, 5 Sep 2023 10:39:30 +0200 Subject: [PATCH 19/27] removal of the redundant notebook cell --- .../integrations/chat/huggingface_llama2.ipynb | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_llama2.ipynb b/docs/extras/integrations/chat/huggingface_llama2.ipynb index 16bb7c35e3be6..7bf68a1d12f7a 100644 --- a/docs/extras/integrations/chat/huggingface_llama2.ipynb +++ b/docs/extras/integrations/chat/huggingface_llama2.ipynb @@ -1,19 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "os.getcwd()\n", - "os.chdir(\"/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/\")\n", - "os.getcwd()\n", - "sys.path.append('/mnt/ml-team/homes/eryk.mazus/langchain/libs/langchain/')" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -303,7 +289,7 @@ "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "3372ef96e068313d34c91eab0f20d815c93d37110de821968e5d598f73bfb74c" + "hash": "d1d3a3c58a58885896c5459933a599607cdbb9917d7e1ad7516c8786c51f2dd2" } } }, From bd6e2fe1d9480bd01c18f88a20c8001c6d5d54f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Tue, 5 Sep 2023 11:40:17 +0200 Subject: [PATCH 20/27] moved StoppingCriteria import to method --- libs/langchain/langchain/chat_models/huggingface_llama2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 9f5b7b3ecec78..70d4e03db402e 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from transformers import StoppingCriteria, StoppingCriteriaList from transformers.pipelines import TextGenerationPipeline from langchain.callbacks.manager import ( @@ -117,6 +116,9 @@ def _generate( if importlib.util.find_spec("torch") is not None: import torch + if importlib.util.find_spec("transformers") is not None: + from transformers import StoppingCriteria, StoppingCriteriaList + device = self.pipeline.device.type if device == "cuda": # in the multi-gpu case, stopping criteria tokens From 35b5e09557354d241c52dde8edf9b4b8d75bbdbd Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Tue, 5 Sep 2023 11:52:34 +0200 Subject: [PATCH 21/27] fixing type annotation --- libs/langchain/langchain/chat_models/huggingface_llama2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index 70d4e03db402e..b479d4b7246cb 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -2,8 +2,6 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from transformers.pipelines import TextGenerationPipeline - from langchain.callbacks.manager import ( CallbackManagerForLLMRun, ) @@ -36,7 +34,7 @@ def __str__(self) -> str: class ChatLlama2Hf(BaseChatModel): - pipeline: TextGenerationPipeline + pipeline: Any @property def _llm_type(self) -> str: From 1228bfc0234d4f6656a993754bd6a56d44b0340c Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Tue, 5 Sep 2023 12:01:09 +0200 Subject: [PATCH 22/27] fixing Enum tests --- .../chat_models/test_huggingface_llama2.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py index 42ea7e03280a0..c18bc6efa7c0a 100644 --- a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py +++ b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py @@ -21,10 +21,10 @@ def test_format_messages_as_text_with_system() -> None: AIMessage(content="AI response."), ] - assert InstructionTokens.B_INST == "[INST]" - assert InstructionTokens.E_INST == "[/INST]" - assert SystemTokens.B_SYS == "<>" - assert SystemTokens.E_SYS == "<>" + assert str(InstructionTokens.B_INST) == "[INST]" + assert str(InstructionTokens.E_INST) == "[/INST]" + assert str(SystemTokens.B_SYS) == "<>" + assert str(SystemTokens.E_SYS) == "<>" ground_truth = ( "[INST] <>\nSystem Prompt.\n<>\n\n" @@ -47,10 +47,10 @@ def test_format_messages_as_text_without_system() -> None: AIMessage(content="Second AI response."), ] - assert InstructionTokens.B_INST == "[INST]" - assert InstructionTokens.E_INST == "[/INST]" - assert SystemTokens.B_SYS == "<>" - assert SystemTokens.E_SYS == "<>" + assert str(InstructionTokens.B_INST) == "[INST]" + assert str(InstructionTokens.E_INST) == "[/INST]" + assert str(SystemTokens.B_SYS) == "<>" + assert str(SystemTokens.E_SYS) == "<>" ground_truth = ( "[INST] Human Message. [/INST] " From 926c02fb14674617584de6e6f5dc5c4060da3cdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eryk=20Mazu=C5=9B?= Date: Tue, 5 Sep 2023 14:22:03 +0200 Subject: [PATCH 23/27] Editing the huggingface llama 2 notebook --- .../chat/huggingface_llama2.ipynb | 153 +++++++++++++++--- 1 file changed, 130 insertions(+), 23 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_llama2.ipynb b/docs/extras/integrations/chat/huggingface_llama2.ipynb index 7bf68a1d12f7a..0a96d89a98b4f 100644 --- a/docs/extras/integrations/chat/huggingface_llama2.ipynb +++ b/docs/extras/integrations/chat/huggingface_llama2.ipynb @@ -4,12 +4,31 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Llama-2-Chat Model from Hugging Face" + "# Llama-2-Chat Models from Hugging Face" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook illustrates Hugging Face version of Llama 2 Chat models used as LangChain's Chat model.\n", + "\n", + "Using the Hugging Face Llama 2 Chat models integration has several benefits:\n", + "- **Flexibility in setting model inference parameters:** For instance, it is possible to define the HF Pipeline with the `meta-llama/Llama-2-7b-chat-hf` model and `BitsAndBytesConfig` that loads the model in 4-bits, and run it locally on a GPU with less than 6GB of VRAM.\n", + "- **Automatic prompt formatting that adheres to Meta's guidelines:** The `ChatLlama2Hf` class includes the `format_messages_as_text method`, which converts LangChain's messages into prompts that comply with [Meta's guidelines](https://huggingface.co/blog/llama2) for interacting with Llama 2 models.\n", + "- **Customization:** By overwriting static method `format_messages_as_text` one can also use other instruction-tuned models with LangChain's \"chat messages\" interface." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -31,18 +50,34 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This notebook assumes that you were granted with access to the Llama 2 models in the Hugging Face models hub. To use the model locally, you need to be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) with a Hugging Face account." + "This notebook assumes that you were granted with access to the Llama 2 models in the Hugging Face models hub. To use the model locally, you need to be [logged in](https://huggingface.co/docs/huggingface_hub/quick-start#login) with a Hugging Face account. \n", + "\n", + "To log in using CLI run the following command in your terminal:\n", + "```\n", + "huggingface-cli login\n", + "```\n", + "or using an environment variable\n", + "```\n", + "huggingface-cli login --token $HUGGINGFACE_TOKEN\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also log in programmatically in notebook:" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4b4b6e0544c94c248f85fbc6103efbbb", + "model_id": "b8848cee932844e0a4b1bf3a1e9ae8b1", "version_major": 2, "version_minor": 0 }, @@ -59,18 +94,53 @@ "login()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating Hugging Face Pipeline instance:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following section loads the 7b version of the Llama 2 Chat model and uses the `bitandbytes` library to load a model in 4bit using NF4 quantization with double quantization and compute dtype bfloat16, which speeds up the underlying matrix multiplications.\n", + "\n", + "More information about these techniques can be found at: [link](https://huggingface.co/blog/4bit-transformers-bitsandbytes)" + ] + }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "model_name = \"meta-llama/Llama-2-7b-chat-hf\"" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To load the model in 4bit, make sure that the `accelerate`, `transformers` and `bitsandbytes` libraries are installed:" + ] + }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install -q -U bitsandbytes\n", + "# !pip install -q -U git+https://github.com/huggingface/transformers.git\n", + "# !pip install -q -U git+https://github.com/huggingface/peft.git\n", + "# !pip install -q -U git+https://github.com/huggingface/accelerate.git" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -84,13 +154,13 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "5d912b06f1954ab287ea0f73828422a6", + "model_id": "d6dfde9985d24ee0bcc9509b2d300728", "version_major": 2, "version_minor": 0 }, @@ -109,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -122,18 +192,32 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Defining LangChain Llama-2-Chat model:" + ] + }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "chat = ChatLlama2Hf(pipeline=pipe)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Besides defining arguments for `Pipeline` initialization, we can also control the generation process, by enabling sampling, chaning temperature or defining maximum length of single generation." + ] + }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -143,7 +227,7 @@ " \"top_p\": 0.95,\n", " \"temperature\": 0.7,\n", " \"eos_token_id\": tokenizer.eos_token_id,\n", - " \"max_length\": 256, \n", + " \"max_length\": 512, \n", "}" ] }, @@ -154,9 +238,16 @@ "### Single calls:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can get chat completions by passing one or more messages to the chat model. The response will be a message:" + ] + }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -165,7 +256,7 @@ "text": [ " Sure! Here is the translation of \"I love programming\" from English to French:\n", "\n", - "Je suis passionné par le programming.\n", + "Je adore le programming.\n", "\n", "I hope this helps! Let me know if you have any other sentences you would like me to translate.\n" ] @@ -191,16 +282,23 @@ "### Single calls with stop words" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By utilizing Hugging Face [Stopping Criteria](https://huggingface.co/docs/transformers/internal/generation_utils#transformers.StoppingCriteria) under the hood, we can provide phrases that, if generated by the model, will cause the generation process to stop." + ] + }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " Of course! Artificial\n" + " Of course, I'd be happy to help! The history of Artificial\n" ] } ], @@ -224,9 +322,16 @@ "### Batch calls:" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also go one step further and generate completions for multiple sets of messages:" + ] + }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -245,7 +350,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -253,11 +358,13 @@ "output_type": "stream", "text": [ "Response #0:\n", - " Great! \"Programming\" is translated to French as \"programmation\". So, \"I love programmation\" would be the correct way to express your sentiment in French.\n", + " Great! \"Programmation\" is the French word for \"programming\". So, \"J'aime le programmation\" (pronounced \"zhem-ahy lah-pree-moh\").\n", + "Would you like me to translate anything else for you?\n", "\n", "Response #1:\n", - " Bonjour! Je suis heureux d'être votre assistant de traduction. Vous aimez l'intelligence artificielle, n'est-ce pas? (You love artificial intelligence, don't you?)\n", - "In French, the phrase \"intelligence artificielle\" can be translated to \"artificial intelligence\" in English.\n", + " Bonjour! Je suis heureux d'être votre assistant de traduction pour l'anglais à français.\n", + "Vous aimez l'intelligence artificielle? C'est une technologie très intéressante, n'est-ce pas? Elle permet de automatiser de nombreux processus et de faire des choses qui auraient autrement été très difficiles ou impossibles.\n", + "Si vous avez d'autres questions ou des phrases à traduire, n'hésitez pas à me demander de l'aide. Je suis là pour vous aider!\n", "\n" ] } @@ -289,7 +396,7 @@ "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "d1d3a3c58a58885896c5459933a599607cdbb9917d7e1ad7516c8786c51f2dd2" + "hash": "3372ef96e068313d34c91eab0f20d815c93d37110de821968e5d598f73bfb74c" } } }, From 964b5792b6790af92fc06ffafee11abe8be1e5d0 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Wed, 13 Sep 2023 10:59:15 +0200 Subject: [PATCH 24/27] typos, better name for customg Stopping Critieria subclass --- docs/extras/integrations/chat/huggingface_llama2.ipynb | 6 +++--- .../langchain/chat_models/huggingface_llama2.py | 10 ++++++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_llama2.ipynb b/docs/extras/integrations/chat/huggingface_llama2.ipynb index 0a96d89a98b4f..ffca2842327e8 100644 --- a/docs/extras/integrations/chat/huggingface_llama2.ipynb +++ b/docs/extras/integrations/chat/huggingface_llama2.ipynb @@ -105,7 +105,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The following section loads the 7b version of the Llama 2 Chat model and uses the `bitandbytes` library to load a model in 4bit using NF4 quantization with double quantization and compute dtype bfloat16, which speeds up the underlying matrix multiplications.\n", + "The following section loads the 7b version of the Llama 2 Chat model and uses the `bitsandbytes` library to load a model in 4bit using NF4 quantization with double quantization and compute dtype bfloat16, which speeds up the underlying matrix multiplications.\n", "\n", "More information about these techniques can be found at: [link](https://huggingface.co/blog/4bit-transformers-bitsandbytes)" ] @@ -196,7 +196,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Defining LangChain Llama-2-Chat model:" + "Initializing a LangChain Llama-2-Chat instance" ] }, { @@ -396,7 +396,7 @@ "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "3372ef96e068313d34c91eab0f20d815c93d37110de821968e5d598f73bfb74c" + "hash": "d1d3a3c58a58885896c5459933a599607cdbb9917d7e1ad7516c8786c51f2dd2" } } }, diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_llama2.py index b479d4b7246cb..5821c84bd8695 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_llama2.py @@ -69,6 +69,8 @@ def format_messages_as_text(messages: List[BaseMessage]) -> str: ``` [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} ``` + Source: + https://github.com/facebookresearch/llama-recipes/blob/df77625e48c3994aef19702fb331215f7fb83494/docs/inference.md?plain=1#L124 """ prompt = "" @@ -123,7 +125,7 @@ def _generate( # need to be on the same device: device = f"{device}:{self.pipeline.device.index}" - class StoppingCriteriaSub(StoppingCriteria): + class CustomStoppingCriteria(StoppingCriteria): """ A subclass of StoppingCriteria, used for defining custom stopping criteria for the generation process, apart from the standard End Of Sentence (EOS) @@ -175,7 +177,11 @@ def __call__( ] stopping_criteria = StoppingCriteriaList( - [StoppingCriteriaSub(stops=stopping_criteria_tokenized, device=device)] + [ + CustomStoppingCriteria( + stops=stopping_criteria_tokenized, device=device + ) + ] ) else: stopping_criteria = None From 87597a10fac62ae342f7d137b22429a530849f28 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Wed, 13 Sep 2023 15:50:47 +0200 Subject: [PATCH 25/27] Generic Hugging Face Pipeline Chat Model --- ...lama2.ipynb => huggingface_pipeline.ipynb} | 70 ++++----- .../langchain/chat_models/__init__.py | 4 +- ...face_llama2.py => huggingface_pipeline.py} | 137 +++++++++--------- .../chat_models/test_huggingface_llama2.py | 65 --------- 4 files changed, 108 insertions(+), 168 deletions(-) rename docs/extras/integrations/chat/{huggingface_llama2.ipynb => huggingface_pipeline.ipynb} (80%) rename libs/langchain/langchain/chat_models/{huggingface_llama2.py => huggingface_pipeline.py} (82%) delete mode 100644 libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py diff --git a/docs/extras/integrations/chat/huggingface_llama2.ipynb b/docs/extras/integrations/chat/huggingface_pipeline.ipynb similarity index 80% rename from docs/extras/integrations/chat/huggingface_llama2.ipynb rename to docs/extras/integrations/chat/huggingface_pipeline.ipynb index ffca2842327e8..3319f727379f9 100644 --- a/docs/extras/integrations/chat/huggingface_llama2.ipynb +++ b/docs/extras/integrations/chat/huggingface_pipeline.ipynb @@ -4,31 +4,33 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Llama-2-Chat Models from Hugging Face" + "# Hugging Face Pipelines as LangChain Chat Model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "This notebook illustrates Hugging Face version of Llama 2 Chat models used as LangChain's Chat model.\n", - "\n", - "Using the Hugging Face Llama 2 Chat models integration has several benefits:\n", - "- **Flexibility in setting model inference parameters:** For instance, it is possible to define the HF Pipeline with the `meta-llama/Llama-2-7b-chat-hf` model and `BitsAndBytesConfig` that loads the model in 4-bits, and run it locally on a GPU with less than 6GB of VRAM.\n", - "- **Automatic prompt formatting that adheres to Meta's guidelines:** The `ChatLlama2Hf` class includes the `format_messages_as_text method`, which converts LangChain's messages into prompts that comply with [Meta's guidelines](https://huggingface.co/blog/llama2) for interacting with Llama 2 models.\n", - "- **Customization:** By overwriting static method `format_messages_as_text` one can also use other instruction-tuned models with LangChain's \"chat messages\" interface." + "This notebook demonstrates the use of Hugging Face models as LangChain Chat models. We support Llama 2 Chat models out of the box because the prompt templates for instruction-tuned models differ from model to model. To handle any other Hugging Face model, simply create a class that inherits from the `ChatHuggingFacePipeline` class and implement a custom `format_messages_as_text` that parses the List of Messages to string." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Llama-2-Chat" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Setup" + "### Setup" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -42,7 +44,7 @@ ")\n", "\n", "# LangChain imports:\n", - "from langchain.chat_models import ChatLlama2Hf\n", + "from langchain.chat_models import ChatHFLlama2Pipeline\n", "from langchain.schema import AIMessage, HumanMessage, SystemMessage" ] }, @@ -71,13 +73,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b8848cee932844e0a4b1bf3a1e9ae8b1", + "model_id": "6608d88ce05f441aaef3e640732bfd3e", "version_major": 2, "version_minor": 0 }, @@ -98,7 +100,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Creating Hugging Face Pipeline instance:" + "### Creating Hugging Face Pipeline instance:" ] }, { @@ -112,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -128,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -140,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -154,13 +156,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "d6dfde9985d24ee0bcc9509b2d300728", + "model_id": "8f5fbc7100b445f98d363702e53692fd", "version_major": 2, "version_minor": 0 }, @@ -179,7 +181,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -201,11 +203,11 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "chat = ChatLlama2Hf(pipeline=pipe)" + "chat = ChatHFLlama2Pipeline(pipeline=pipe)" ] }, { @@ -217,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -247,18 +249,17 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " Sure! Here is the translation of \"I love programming\" from English to French:\n", - "\n", + " Sure, I'd be happy to help! Here's the translation of \"I love programming\" from English to French:\n", "Je adore le programming.\n", "\n", - "I hope this helps! Let me know if you have any other sentences you would like me to translate.\n" + "I hope that helps! Let me know if you have any other sentences you'd like me to translate.\n" ] } ], @@ -291,14 +292,14 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - " Of course, I'd be happy to help! The history of Artificial\n" + " Of course! Artificial\n" ] } ], @@ -331,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -350,7 +351,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -358,13 +359,12 @@ "output_type": "stream", "text": [ "Response #0:\n", - " Great! \"Programmation\" is the French word for \"programming\". So, \"J'aime le programmation\" (pronounced \"zhem-ahy lah-pree-moh\").\n", - "Would you like me to translate anything else for you?\n", + " Great! \"Programmation\" is the French word for \"programming\".\n", + "\n", + "So, you love programmation? (programme)\n", "\n", "Response #1:\n", - " Bonjour! Je suis heureux d'être votre assistant de traduction pour l'anglais à français.\n", - "Vous aimez l'intelligence artificielle? C'est une technologie très intéressante, n'est-ce pas? Elle permet de automatiser de nombreux processus et de faire des choses qui auraient autrement été très difficiles ou impossibles.\n", - "Si vous avez d'autres questions ou des phrases à traduire, n'hésitez pas à me demander de l'aide. Je suis là pour vous aider!\n", + " \"Je suis heureux que vous aimiez l'intelligence artificielle.\" (I am happy that you love artificial intelligence.)\n", "\n" ] } diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index e25509ddac172..357258c3ef9c3 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -24,7 +24,7 @@ from langchain.chat_models.ernie import ErnieBotChat from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.google_palm import ChatGooglePalm -from langchain.chat_models.huggingface_llama2 import ChatLlama2Hf +from langchain.chat_models.huggingface_pipeline import ChatHFLlama2Pipeline from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.jinachat import JinaChat from langchain.chat_models.litellm import ChatLiteLLM @@ -44,7 +44,7 @@ "ChatGooglePalm", "ChatMLflowAIGateway", "ChatOllama", - "ChatLlama2Hf", + "ChatHFLlama2Pipeline", "ChatVertexAI", "JinaChat", "HumanInputChatModel", diff --git a/libs/langchain/langchain/chat_models/huggingface_llama2.py b/libs/langchain/langchain/chat_models/huggingface_pipeline.py similarity index 82% rename from libs/langchain/langchain/chat_models/huggingface_llama2.py rename to libs/langchain/langchain/chat_models/huggingface_pipeline.py index 5821c84bd8695..f7d55f9406b8d 100644 --- a/libs/langchain/langchain/chat_models/huggingface_llama2.py +++ b/libs/langchain/langchain/chat_models/huggingface_pipeline.py @@ -1,4 +1,5 @@ import importlib.util +from abc import ABC, abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -17,29 +18,18 @@ from langchain.schema.output import ChatGeneration -class InstructionTokens(Enum): - def __str__(self) -> str: - return self.value - - B_INST = "[INST]" - E_INST = "[/INST]" - - -class SystemTokens(Enum): - def __str__(self) -> str: - return self.value - - B_SYS = "<>" - E_SYS = "<>" - - -class ChatLlama2Hf(BaseChatModel): +class ChatHuggingFacePipeline(BaseChatModel, ABC): pipeline: Any @property def _llm_type(self) -> str: """Return type of chat model.""" - return "llama-2-chat-hf" + return "huggingface_pipeline_chat" + + @abstractmethod + def format_messages_as_text(self, messages: List[BaseMessage]) -> str: + """Method for parsing the list of LangChain Messages into string""" + ... @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: @@ -51,54 +41,6 @@ def validate_environment(cls, values: Dict) -> Dict: return values - @staticmethod - def format_messages_as_text(messages: List[BaseMessage]) -> str: - """ - Transform List of Chat Messages to text following Meta's prompt guidelines. - - Prompt template with System Message: - ``` - [INST] <> - {{ system_prompt }} - <> - - {{ user_msg_1 }} [/INST] {{ model_answer_1 }} - ``` - - Prompt template without System Message: - ``` - [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} - ``` - Source: - https://github.com/facebookresearch/llama-recipes/blob/df77625e48c3994aef19702fb331215f7fb83494/docs/inference.md?plain=1#L124 - """ - prompt = "" - - for i, message in enumerate(messages): - if isinstance(message, SystemMessage) and i != 0: - raise ValueError( - "SystemMessage can only appear as the first message in the list." - ) - elif isinstance(message, SystemMessage) and i == 0: - prompt += ( - f"{InstructionTokens.B_INST} " - f"{SystemTokens.B_SYS}\n{message.content}\n" - f"{SystemTokens.E_SYS}\n\n" - ) - elif isinstance(message, HumanMessage) and i > 0: - prompt += f"{message.content} {InstructionTokens.E_INST} " - elif isinstance(message, HumanMessage) and i == 0: - prompt += ( - f"{InstructionTokens.B_INST} " - f"{message.content} {InstructionTokens.E_INST} " - ) - elif isinstance(message, AIMessage): - prompt += f"{message.content} {InstructionTokens.B_INST} " - else: - raise ValueError(f"Unsupported Message type: {type(message)}") - - return prompt - def _generate( self, messages: List[BaseMessage], @@ -192,3 +134,66 @@ def __call__( message=AIMessage(content=response), ) return ChatResult(generations=[chat_generation]) + + +class ChatHFLlama2Pipeline(ChatHuggingFacePipeline): + class InstructionTokens(Enum): + def __str__(self) -> str: + return self.value + + B_INST = "[INST]" + E_INST = "[/INST]" + + class SystemTokens(Enum): + def __str__(self) -> str: + return self.value + + B_SYS = "<>" + E_SYS = "<>" + + def format_messages_as_text(self, messages: List[BaseMessage]) -> str: + """ + Transform List of Chat Messages to text following Meta's prompt guidelines. + + Prompt template with System Message: + ``` + [INST] <> + {{ system_prompt }} + <> + + {{ user_msg_1 }} [/INST] {{ model_answer_1 }} + ``` + + Prompt template without System Message: + ``` + [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} + ``` + Source: + https://github.com/facebookresearch/llama-recipes/blob/df77625e48c3994aef19702fb331215f7fb83494/docs/inference.md?plain=1#L124 + """ + prompt = "" + + for i, message in enumerate(messages): + if isinstance(message, SystemMessage) and i != 0: + raise ValueError( + "SystemMessage can only appear as the first message in the list." + ) + elif isinstance(message, SystemMessage) and i == 0: + prompt += ( + f"{self.InstructionTokens.B_INST} " + f"{self.SystemTokens.B_SYS}\n{message.content}\n" + f"{self.SystemTokens.E_SYS}\n\n" + ) + elif isinstance(message, HumanMessage) and i > 0: + prompt += f"{message.content} {self.InstructionTokens.E_INST} " + elif isinstance(message, HumanMessage) and i == 0: + prompt += ( + f"{self.InstructionTokens.B_INST} " + f"{message.content} {self.InstructionTokens.E_INST} " + ) + elif isinstance(message, AIMessage): + prompt += f"{message.content} {self.InstructionTokens.B_INST} " + else: + raise ValueError(f"Unsupported Message type: {type(message)}") + + return prompt diff --git a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py b/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py deleted file mode 100644 index c18bc6efa7c0a..0000000000000 --- a/libs/langchain/tests/unit_tests/chat_models/test_huggingface_llama2.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Test Hugging Face Llama-2 Chat model.""" - -from langchain.chat_models.huggingface_llama2 import ( - ChatLlama2Hf, - InstructionTokens, - SystemTokens, -) -from langchain.schema.messages import ( - AIMessage, - HumanMessage, - SystemMessage, -) - - -def test_format_messages_as_text_with_system() -> None: - messages = [ - SystemMessage(content="System Prompt."), - HumanMessage(content="Human Message."), - AIMessage(content="AI response."), - HumanMessage(content="Second Human Message."), - AIMessage(content="AI response."), - ] - - assert str(InstructionTokens.B_INST) == "[INST]" - assert str(InstructionTokens.E_INST) == "[/INST]" - assert str(SystemTokens.B_SYS) == "<>" - assert str(SystemTokens.E_SYS) == "<>" - - ground_truth = ( - "[INST] <>\nSystem Prompt.\n<>\n\n" - "Human Message. [/INST] AI response. " - "[INST] Second Human Message. [/INST] " - "AI response. [INST] " - ) - - messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) - assert messages_as_str == ground_truth, ( - f"Prediction:\n```{messages_as_str}\n" "```\nExpected:\n```{ground_truth}\n```" - ) - - -def test_format_messages_as_text_without_system() -> None: - messages = [ - HumanMessage(content="Human Message."), - AIMessage(content="AI response."), - HumanMessage(content="Second Human Message."), - AIMessage(content="Second AI response."), - ] - - assert str(InstructionTokens.B_INST) == "[INST]" - assert str(InstructionTokens.E_INST) == "[/INST]" - assert str(SystemTokens.B_SYS) == "<>" - assert str(SystemTokens.E_SYS) == "<>" - - ground_truth = ( - "[INST] Human Message. [/INST] " - "AI response. [INST] " - "Second Human Message. [/INST] " - "Second AI response. [INST] " - ) - - messages_as_str = ChatLlama2Hf.format_messages_as_text(messages=messages) - assert messages_as_str == ground_truth, ( - f"Prediction:\n```{messages_as_str}\n" "```\nExpected:\n```{ground_truth}\n```" - ) From d9e9ef38152474f0ed761eef7967e7ae864b33f9 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Tue, 10 Oct 2023 12:31:35 +0200 Subject: [PATCH 26/27] simplifying HF Chat Model, by making use of HF Chat Templates --- .../chat/huggingface_pipeline.ipynb | 22 ++-- .../langchain/chat_models/__init__.py | 4 +- .../chat_models/huggingface_pipeline.py | 105 ++++++------------ 3 files changed, 45 insertions(+), 86 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_pipeline.ipynb b/docs/extras/integrations/chat/huggingface_pipeline.ipynb index 3319f727379f9..731a352ed9573 100644 --- a/docs/extras/integrations/chat/huggingface_pipeline.ipynb +++ b/docs/extras/integrations/chat/huggingface_pipeline.ipynb @@ -11,21 +11,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "This notebook demonstrates the use of Hugging Face models as LangChain Chat models. We support Llama 2 Chat models out of the box because the prompt templates for instruction-tuned models differ from model to model. To handle any other Hugging Face model, simply create a class that inherits from the `ChatHuggingFacePipeline` class and implement a custom `format_messages_as_text` that parses the List of Messages to string." + "This notebook demonstrates how to use Hugging Face models as LangChain Chat models, using the Llama 2 Chat model as an example. We use the Hugging Face tokenizer's 'apply_chat_template' method to handle different instruction tuned models with different prompting templates. If you want to change the prompt templateing behavior, you can find instructions in the Hugging Face [guide](https://huggingface.co/docs/transformers/main/en/chat_templating)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Llama-2-Chat" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setup" + "## Setup" ] }, { @@ -44,7 +37,7 @@ ")\n", "\n", "# LangChain imports:\n", - "from langchain.chat_models import ChatHFLlama2Pipeline\n", + "from langchain.chat_models import ChatHuggingFacePipeline\n", "from langchain.schema import AIMessage, HumanMessage, SystemMessage" ] }, @@ -100,7 +93,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Creating Hugging Face Pipeline instance:" + "## Creating Hugging Face Pipeline instance:" ] }, { @@ -176,6 +169,9 @@ ], "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "# disabling the default System Message of the Llama model \n", + "tokenizer.use_default_system_prompt = False\n", + "\n", "model_4bit = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, device_map=\"auto\")" ] }, @@ -198,7 +194,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Initializing a LangChain Llama-2-Chat instance" + "Initializing the Chat Model instance" ] }, { @@ -207,7 +203,7 @@ "metadata": {}, "outputs": [], "source": [ - "chat = ChatHFLlama2Pipeline(pipeline=pipe)" + "chat = ChatHuggingFacePipeline(pipeline=pipe)" ] }, { diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index b3436c8c83ddf..670cd8e9a7b22 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -27,7 +27,7 @@ from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.fireworks import ChatFireworks from langchain.chat_models.google_palm import ChatGooglePalm -from langchain.chat_models.huggingface_pipeline import ChatHFLlama2Pipeline +from langchain.chat_models.huggingface_pipeline import ChatHuggingFacePipeline from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.javelin_ai_gateway import ChatJavelinAIGateway from langchain.chat_models.jinachat import JinaChat @@ -51,7 +51,7 @@ "ChatGooglePalm", "ChatMLflowAIGateway", "ChatOllama", - "ChatHFLlama2Pipeline", + "ChatHuggingFacePipeline", "ChatVertexAI", "JinaChat", "HumanInputChatModel", diff --git a/libs/langchain/langchain/chat_models/huggingface_pipeline.py b/libs/langchain/langchain/chat_models/huggingface_pipeline.py index f7d55f9406b8d..5cf54539f0770 100644 --- a/libs/langchain/langchain/chat_models/huggingface_pipeline.py +++ b/libs/langchain/langchain/chat_models/huggingface_pipeline.py @@ -1,6 +1,4 @@ import importlib.util -from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Dict, List, Optional, Union from langchain.callbacks.manager import ( @@ -18,7 +16,7 @@ from langchain.schema.output import ChatGeneration -class ChatHuggingFacePipeline(BaseChatModel, ABC): +class ChatHuggingFacePipeline(BaseChatModel): pipeline: Any @property @@ -26,10 +24,30 @@ def _llm_type(self) -> str: """Return type of chat model.""" return "huggingface_pipeline_chat" - @abstractmethod - def format_messages_as_text(self, messages: List[BaseMessage]) -> str: - """Method for parsing the list of LangChain Messages into string""" - ... + @staticmethod + def convert_lc_messages_to_hf_messages( + messages: List[BaseMessage], + ) -> List[Dict[str, str]]: + """ + Method for converting the list of LangChain Messages into + format required by Hugging Face. + """ + output = [] + + for message in messages: + if isinstance(message, SystemMessage): + output.append({"role": "system", "content": message.content}) + elif isinstance(message, HumanMessage): + output.append({"role": "user", "content": message.content}) + elif isinstance(message, AIMessage): + output.append({"role": "assistant", "content": message.content}) + else: + raise ValueError( + f"Unexpected message type: {type(message)}. " + "Expected one of [SystemMessage, HumanMessage, AIMessage]." + ) + + return output @root_validator(pre=True) def validate_environment(cls, values: Dict) -> Dict: @@ -39,6 +57,13 @@ def validate_environment(cls, values: Dict) -> Dict: ): raise ValueError("The pipeline task should be 'text-generation'.") + if not hasattr(values["pipeline"], "apply_chat_template"): + raise ValueError( + "Your transformers module might be outdated. " + "Please update it to ensure that tokenizer has the " + "'apply_chat_template' method." + ) + return values def _generate( @@ -48,7 +73,8 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - prompt = self.format_messages_as_text(messages) + chat = self.convert_lc_messages_to_hf_messages(messages) + prompt = self.pipeline.tokenizer.apply_chat_template(chat, tokenize=False) # make sure that `return_full_text` is set to False # otherwise, pipeline will return prompt + generation @@ -134,66 +160,3 @@ def __call__( message=AIMessage(content=response), ) return ChatResult(generations=[chat_generation]) - - -class ChatHFLlama2Pipeline(ChatHuggingFacePipeline): - class InstructionTokens(Enum): - def __str__(self) -> str: - return self.value - - B_INST = "[INST]" - E_INST = "[/INST]" - - class SystemTokens(Enum): - def __str__(self) -> str: - return self.value - - B_SYS = "<>" - E_SYS = "<>" - - def format_messages_as_text(self, messages: List[BaseMessage]) -> str: - """ - Transform List of Chat Messages to text following Meta's prompt guidelines. - - Prompt template with System Message: - ``` - [INST] <> - {{ system_prompt }} - <> - - {{ user_msg_1 }} [/INST] {{ model_answer_1 }} - ``` - - Prompt template without System Message: - ``` - [INST] {{ user_msg_1 }} [/INST] {{ model_answer_1 }} - ``` - Source: - https://github.com/facebookresearch/llama-recipes/blob/df77625e48c3994aef19702fb331215f7fb83494/docs/inference.md?plain=1#L124 - """ - prompt = "" - - for i, message in enumerate(messages): - if isinstance(message, SystemMessage) and i != 0: - raise ValueError( - "SystemMessage can only appear as the first message in the list." - ) - elif isinstance(message, SystemMessage) and i == 0: - prompt += ( - f"{self.InstructionTokens.B_INST} " - f"{self.SystemTokens.B_SYS}\n{message.content}\n" - f"{self.SystemTokens.E_SYS}\n\n" - ) - elif isinstance(message, HumanMessage) and i > 0: - prompt += f"{message.content} {self.InstructionTokens.E_INST} " - elif isinstance(message, HumanMessage) and i == 0: - prompt += ( - f"{self.InstructionTokens.B_INST} " - f"{message.content} {self.InstructionTokens.E_INST} " - ) - elif isinstance(message, AIMessage): - prompt += f"{message.content} {self.InstructionTokens.B_INST} " - else: - raise ValueError(f"Unsupported Message type: {type(message)}") - - return prompt From cf203b93dc108ee7462e28cb4485fba594ecdaf4 Mon Sep 17 00:00:00 2001 From: eryk-dsai Date: Tue, 10 Oct 2023 13:00:14 +0200 Subject: [PATCH 27/27] removing incorrect check from validate_environment method --- .../chat/huggingface_pipeline.ipynb | 32 ------------------- .../chat_models/huggingface_pipeline.py | 7 ---- 2 files changed, 39 deletions(-) diff --git a/docs/extras/integrations/chat/huggingface_pipeline.ipynb b/docs/extras/integrations/chat/huggingface_pipeline.ipynb index 731a352ed9573..faa90baa9df2b 100644 --- a/docs/extras/integrations/chat/huggingface_pipeline.ipynb +++ b/docs/extras/integrations/chat/huggingface_pipeline.ipynb @@ -57,38 +57,6 @@ "```" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can also log in programmatically in notebook:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6608d88ce05f441aaef3e640732bfd3e", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "VBox(children=(HTML(value='
Dict: ): raise ValueError("The pipeline task should be 'text-generation'.") - if not hasattr(values["pipeline"], "apply_chat_template"): - raise ValueError( - "Your transformers module might be outdated. " - "Please update it to ensure that tokenizer has the " - "'apply_chat_template' method." - ) - return values def _generate(