From 6f0e9bd3f3d01b6c19bcac80b438fcd3f8e2f1e0 Mon Sep 17 00:00:00 2001
From: haranrk
Date: Thu, 8 Feb 2024 05:47:32 -0500
Subject: [PATCH 1/2] add replicate as an llm provider

---
 pyproject.toml                                |   4 +
 src/autolabel/models/__init__.py              |   2 +
 src/autolabel/models/replicate.py             | 149 ++++++++++++++++++
 src/autolabel/schema.py                       |   1 +
 .../banking/config_banking_replicate.json     |  99 ++++++++++++
 tests/unit/__init__.py                        |   1 +
 tests/unit/llm_test.py                        |  65 ++++++++
 7 files changed, 321 insertions(+)
 create mode 100644 src/autolabel/models/replicate.py
 create mode 100644 tests/assets/banking/config_banking_replicate.json

diff --git a/pyproject.toml b/pyproject.toml
index 7eb63a97..6a8c131f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,6 +77,10 @@ google = [
     "tiktoken >= 0.3.3",
     "google-cloud-aiplatform>=1.25.0"
 ]
+replicate = [
+    "replicate >= 0.23.1",
+    "transformers >= 4.25.0",
+]
 cohere = [
     "cohere>=4.11.2"
 ]
diff --git a/src/autolabel/models/__init__.py b/src/autolabel/models/__init__.py
index d28a7dde..733fcbea 100644
--- a/src/autolabel/models/__init__.py
+++ b/src/autolabel/models/__init__.py
@@ -13,6 +13,7 @@ from autolabel.models.palm import PaLMLLM
 from autolabel.models.hf_pipeline import HFPipelineLLM
 from autolabel.models.hf_pipeline_vision import HFPipelineMultimodal
+from autolabel.models.replicate import ReplicateLLM
 from autolabel.models.refuel import RefuelLLM
 
 MODEL_REGISTRY = {
@@ -23,6 +24,7 @@
     ModelProvider.HUGGINGFACE_PIPELINE: HFPipelineLLM,
     ModelProvider.HUGGINGFACE_PIPELINE_VISION: HFPipelineMultimodal,
     ModelProvider.GOOGLE: PaLMLLM,
+    ModelProvider.REPLICATE: ReplicateLLM,
     ModelProvider.REFUEL: RefuelLLM,
 }
diff --git a/src/autolabel/models/replicate.py b/src/autolabel/models/replicate.py
new file mode 100644
index 00000000..268cf98d
--- /dev/null
+++ b/src/autolabel/models/replicate.py
@@ -0,0 +1,149 @@
+from typing import List, Optional
+from time import time
+import logging
+import requests
+
+from autolabel.models import BaseModel
+from autolabel.configs import AutolabelConfig
+from autolabel.cache import BaseCache
+from autolabel.schema import RefuelLLMResult
+
+
+import os
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicateLLM(BaseModel):
+    REPLICATE_MAINTAINED_MODELS = [
+        "meta/llama-2-70b",
+        "meta/llama-2-13b",
+        "meta/llama-2-7b",
+        "meta/llama-2-70b-chat",
+        "meta/llama-2-13b-chat",
+        "meta/llama-2-7b-chat",
+        "mistralai/mistral-7b-v0.1",
+        "mistralai/mistral-7b-instruct-v0.2",
+        "mistralai/mixtral-8x7b-instruct-v0.1",
+    ]
+
+    # Default parameters for ReplicateLLM
+    DEFAULT_MODEL = "meta/llama-2-7b-chat"
+
+    DEFAULT_PARAMS_COMPLETION_ENGINE = {
+        "max_tokens": 1000,
+        "temperature": 0.01,
+        "model_kwargs": {"logprobs": 1},
+        "request_timeout": 30,
+    }
+
+    # Reference: https://replicate.com/docs/billing
+    COST_PER_PROMPT_TOKEN = {
+        "meta/llama-2-70b": 0.65 / 1e6,
+        "meta/llama-2-13b": 0.10 / 1e6,
+        "meta/llama-2-7b": 0.05 / 1e6,
+        "meta/llama-2-70b-chat": 0.65 / 1e6,
+        "meta/llama-2-13b-chat": 0.10 / 1e6,
+        "meta/llama-2-7b-chat": 0.05 / 1e6,
+        "mistralai/mistral-7b-v0.1": 0.05 / 1e6,
+        "mistralai/mistral-7b-instruct-v0.2": 0.05 / 1e6,
+        "mistralai/mixtral-8x7b-instruct-v0.1": 0.30 / 1e6,
+    }
+    COST_PER_COMPLETION_TOKEN = {
+        "meta/llama-2-70b": 2.75 / 1e6,
+        "meta/llama-2-13b": 0.50 / 1e6,
+        "meta/llama-2-7b": 0.25 / 1e6,
+        "meta/llama-2-70b-chat": 2.75 / 1e6,
+        "meta/llama-2-13b-chat": 0.50 / 1e6,
+        "meta/llama-2-7b-chat": 0.25 / 1e6,
+        "mistralai/mistral-7b-v0.1": 0.25 / 1e6,
+        "mistralai/mistral-7b-instruct-v0.2": 0.25 / 1e6,
"mistralai/mixtral-8x7b-instruct-v0.1": 1.00 / 1e6, + } + + def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None: + super().__init__(config, cache) + try: + from langchain_community.llms import Replicate + from transformers import LlamaTokenizerFast + except ImportError: + raise ImportError( + "replicate is required to use the ReplicateLLM. Please install it with the following command: pip install 'refuel-autolabel[replicate]'" + ) + + if os.getenv("REPLICATE_API_TOKEN") is None: + raise ValueError("REPLICATE_API_TOKEN environment variable not set") + + # populate model name + self.model_name = config.model_name() or self.DEFAULT_MODEL + + # populate model params and initialize the LLM + model_params = config.model_params() + + self.model_params = { + **self.DEFAULT_PARAMS_COMPLETION_ENGINE, + **model_params, + } + + # get latest model version, required by langchain to process replicate generations + response = requests.get( + f"https://api.replicate.com/v1/models/{self.model_name}", + headers={"Authorization": f"Token {os.environ['REPLICATE_API_TOKEN']}"}, + ) + if response.status_code == 404: + raise ValueError(f"Model {self.model_name} not found on Replicate") + latest_model_version = response.json()["latest_version"]["id"] + + self.llm = Replicate( + model=f"{self.model_name}:{latest_model_version}", + verbose=False, + **self.model_params, + ) + + self.tokenizer = LlamaTokenizerFast.from_pretrained( + "hf-internal-testing/llama-tokenizer" + ) + + def is_model_mangaed_by_replicate(self) -> bool: + return self.model_name in self.REPLICATE_MAINTAINED_MODELS + + def _label(self, prompts: List[str]) -> RefuelLLMResult: + try: + start_time = time() + result = self.llm.generate(prompts) + generations = result.generations + end_time = time() + return RefuelLLMResult( + generations=generations, + errors=[None] * len(generations), + latencies=[end_time - start_time] * len(generations), + ) + except Exception as e: + return self._label_individually(prompts) + + def get_cost(self, prompt: str, label: Optional[str] = "") -> float: + if self.is_model_mangaed_by_replicate(): + num_prompt_toks = len(self.tokenizer.encode(prompt)) + if label: + num_label_toks = len(self.tokenizer.encode(label)) + else: + # get an upper bound + num_label_toks = self.model_params["max_tokens"] + + cost_per_prompt_token = self.COST_PER_PROMPT_TOKEN[self.model_name] + cost_per_completion_token = self.COST_PER_COMPLETION_TOKEN[self.model_name] + return (num_prompt_toks * cost_per_prompt_token) + ( + num_label_toks * cost_per_completion_token + ) + else: + # TODO - at the moment it's not possible to calculate it https://github.com/replicate/replicate-python/issues/243 + return 0 + + def returns_token_probs(self) -> bool: + return ( + self.model_name is not None + and self.model_name in self.MODELS_WITH_TOKEN_PROBS + ) + + def get_num_tokens(self, prompt: str) -> int: + return len(self.tokenizer.encode(prompt)) diff --git a/src/autolabel/schema.py b/src/autolabel/schema.py index bfc990a5..60fd342a 100644 --- a/src/autolabel/schema.py +++ b/src/autolabel/schema.py @@ -22,6 +22,7 @@ class ModelProvider(str, Enum): REFUEL = "refuel" GOOGLE = "google" COHERE = "cohere" + REPLICATE = "replicate" CUSTOM = "custom" diff --git a/tests/assets/banking/config_banking_replicate.json b/tests/assets/banking/config_banking_replicate.json new file mode 100644 index 00000000..20575e10 --- /dev/null +++ b/tests/assets/banking/config_banking_replicate.json @@ -0,0 +1,99 @@ +{ + "task_name": 
"BankingComplaintsClassification", + "task_type": "classification", + "dataset": { + "label_column": "label", + "delimiter": "," + }, + "model": { + "provider": "replicate", + "name": "meta/llama-2-70b-chat" + }, + "prompt": { + "task_guidelines": "You are an expert at understanding bank customers support complaints and queries.\nYour job is to correctly classify the provided input example into one of the following categories.\nCategories:\n{labels}", + "output_guidelines": "You will answer with just the the correct output label and nothing else.", + "labels": [ + "activate_my_card", + "age_limit", + "apple_pay_or_google_pay", + "atm_support", + "automatic_top_up", + "balance_not_updated_after_bank_transfer", + "balance_not_updated_after_cheque_or_cash_deposit", + "beneficiary_not_allowed", + "cancel_transfer", + "card_about_to_expire", + "card_acceptance", + "card_arrival", + "card_delivery_estimate", + "card_linking", + "card_not_working", + "card_payment_fee_charged", + "card_payment_not_recognised", + "card_payment_wrong_exchange_rate", + "card_swallowed", + "cash_withdrawal_charge", + "cash_withdrawal_not_recognised", + "change_pin", + "compromised_card", + "contactless_not_working", + "country_support", + "declined_card_payment", + "declined_cash_withdrawal", + "declined_transfer", + "direct_debit_payment_not_recognised", + "disposable_card_limits", + "edit_personal_details", + "exchange_charge", + "exchange_rate", + "exchange_via_app", + "extra_charge_on_statement", + "failed_transfer", + "fiat_currency_support", + "get_disposable_virtual_card", + "get_physical_card", + "getting_spare_card", + "getting_virtual_card", + "lost_or_stolen_card", + "lost_or_stolen_phone", + "order_physical_card", + "passcode_forgotten", + "pending_card_payment", + "pending_cash_withdrawal", + "pending_top_up", + "pending_transfer", + "pin_blocked", + "receiving_money", + "Refund_not_showing_up", + "request_refund", + "reverted_card_payment?", + "supported_cards_and_currencies", + "terminate_account", + "top_up_by_bank_transfer_charge", + "top_up_by_card_charge", + "top_up_by_cash_or_cheque", + "top_up_failed", + "top_up_limits", + "top_up_reverted", + "topping_up_by_card", + "transaction_charged_twice", + "transfer_fee_charged", + "transfer_into_account", + "transfer_not_received_by_recipient", + "transfer_timing", + "unable_to_verify_identity", + "verify_my_identity", + "verify_source_of_funds", + "verify_top_up", + "virtual_card_not_working", + "visa_or_mastercard", + "why_verify_identity", + "wrong_amount_of_cash_received", + "wrong_exchange_rate_for_cash_withdrawal" + ], + "few_shot_examples": "seed.csv", + "few_shot_selection": "semantic_similarity", + "few_shot_num": 10, + "example_template": "Input: {example}\nOutput: {label}" + } +} \ No newline at end of file diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 1ef85049..b491787b 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -6,6 +6,7 @@ { "REFUEL_API_KEY": "dummy_refuel_api_key", "OPENAI_API_KEY": "dummy_open_api_key", + "REPLICATE_API_TOKEN": "dummy_replicate_api_token", "ANTHROPIC_API_KEY": "dummy_anthropic_api_key", } ) diff --git a/tests/unit/llm_test.py b/tests/unit/llm_test.py index 45c8d9e7..f50d09fe 100644 --- a/tests/unit/llm_test.py +++ b/tests/unit/llm_test.py @@ -4,11 +4,13 @@ from autolabel.models.openai import OpenAILLM from autolabel.models.openai_vision import OpenAIVisionLLM from autolabel.models.palm import PaLMLLM +from autolabel.models.replicate import ReplicateLLM from 
 from autolabel.models.refuel import RefuelLLM
 from langchain.schema import Generation, LLMResult
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
 from pytest import approx
+import pytest
 
 
 ################### ANTHROPIC TESTS #######################
@@ -198,6 +200,69 @@ def test_gpt4V_return_probs():
 
 
 ################### OPENAI GPT 4V TESTS #######################
+################### REPLICATE TESTS #######################
+class MockResponse:
+    def __init__(self, json_data, status_code):
+        self.json_data = json_data
+        self.status_code = status_code
+
+    def json(self):
+        return self.json_data
+
+
+@pytest.fixture
+def replicate_model(mocker):
+    mocker.patch(
+        "requests.get",
+        return_value=MockResponse({"latest_version": {"id": "valid_id"}}, 200),
+    )
+    return ReplicateLLM(
+        config=AutolabelConfig(
+            config="tests/assets/banking/config_banking_replicate.json"
+        )
+    )
+
+
+def test_replicate_initialization(replicate_model):
+    assert isinstance(replicate_model, ReplicateLLM)
+
+
+def test_replicate_invalid_model_initialization(mocker):
+    config = AutolabelConfig(
+        config="tests/assets/banking/config_banking_replicate.json"
+    )
+    mocker.patch(
+        "requests.get",
+        return_value=MockResponse({"error": "model not found"}, 404),
+    )
+
+    with pytest.raises(ValueError) as excinfo:
+        model = ReplicateLLM(config)
+
+    assert "Model meta/llama-2-70b-chat not found on Replicate" in str(excinfo.value)
+
+
+def test_replicate_label(mocker, replicate_model):
+    prompts = ["test1", "test2"]
+    mocker.patch(
+        "langchain_community.llms.Replicate.generate",
+        return_value=LLMResult(
+            generations=[[Generation(text="Answers")] for _ in prompts]
+        ),
+    )
+    x = replicate_model.label(prompts)
+    assert [i[0].text for i in x.generations] == ["Answers", "Answers"]
+
+
+def test_replicate_get_cost(replicate_model):
+    example_prompt = "TestingExamplePrompt"
+    curr_cost = replicate_model.get_cost(example_prompt)
+    assert curr_cost == approx(0.00275389, rel=1e-3)
+
+
+################### REPLICATE TESTS #######################
+
+
 ################### REFUEL TESTS #######################
 def test_refuel_initialization():
     model = RefuelLLM(

From b063ec7f6636cf26a4bafc29df9c8f70ccdbe4e5 Mon Sep 17 00:00:00 2001
From: haranrk
Date: Thu, 8 Feb 2024 06:26:23 -0500
Subject: [PATCH 2/2] fix typo

---
 src/autolabel/models/replicate.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/autolabel/models/replicate.py b/src/autolabel/models/replicate.py
index 268cf98d..b4cb3203 100644
--- a/src/autolabel/models/replicate.py
+++ b/src/autolabel/models/replicate.py
@@ -104,7 +104,7 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
         "hf-internal-testing/llama-tokenizer"
     )
 
-    def is_model_mangaed_by_replicate(self) -> bool:
+    def is_model_managed_by_replicate(self) -> bool:
        return self.model_name in self.REPLICATE_MAINTAINED_MODELS
 
     def _label(self, prompts: List[str]) -> RefuelLLMResult:
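
Example usage of the new provider, given below as a minimal sketch rather than part of either patch. It assumes the package is installed with the new extra (pip install 'refuel-autolabel[replicate]'), that a real REPLICATE_API_TOKEN is set, and that the config asset added above is available on disk; the sample prompt string is purely illustrative.

    import os

    from autolabel.configs import AutolabelConfig
    from autolabel.models.replicate import ReplicateLLM

    # A real token is required: __init__ raises ValueError without it and also
    # queries the Replicate API for the latest version of the chosen model.
    assert os.getenv("REPLICATE_API_TOKEN"), "set REPLICATE_API_TOKEN first"

    # Reuse the config asset introduced by this patch (meta/llama-2-70b-chat).
    config = AutolabelConfig(config="tests/assets/banking/config_banking_replicate.json")
    llm = ReplicateLLM(config=config)

    # Illustrative prompt; any string works here.
    prompt = "Input: My card has not arrived yet.\nOutput:"

    print(llm.get_num_tokens(prompt))  # token count from the LLaMA tokenizer
    print(llm.get_cost(prompt))        # upper-bound cost estimate (no label given)

    result = llm.label([prompt])       # calls the Replicate API
    print(result.generations[0][0].text)

The cost estimate uses max_tokens as the completion length when no label is passed, which is the same upper-bound behavior exercised by test_replicate_get_cost above.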