From 3e456581983e9eb764431b80ac8424a25140749c Mon Sep 17 00:00:00 2001 From: Shroominic Date: Sun, 10 Nov 2024 15:47:26 +0800 Subject: [PATCH] add unit-tests back by mocking build_client --- .../chat_models/outlines.py | 47 +++++----- .../langchain_community/llms/outlines.py | 47 +++++----- .../unit_tests/chat_models/test_outlines.py | 90 ++++++++++++++++++ .../tests/unit_tests/llms/test_outlines.py | 91 +++++++++++++++++++ 4 files changed, 227 insertions(+), 48 deletions(-) create mode 100644 libs/community/tests/unit_tests/chat_models/test_outlines.py create mode 100644 libs/community/tests/unit_tests/llms/test_outlines.py diff --git a/libs/community/langchain_community/chat_models/outlines.py b/libs/community/langchain_community/chat_models/outlines.py index 2f98eb0d568fc..00aa05adfe1e6 100644 --- a/libs/community/langchain_community/chat_models/outlines.py +++ b/libs/community/langchain_community/chat_models/outlines.py @@ -199,17 +199,6 @@ class ChatOutlines(BaseChatModel): @model_validator(mode="after") def validate_environment(self) -> "ChatOutlines": """Validate that outlines is installed and create a model instance.""" - try: - import outlines.models as models - except ImportError: - raise ImportError( - "Could not import the Outlines library. " - "Please install it with `pip install outlines`." - ) - - model: str = self.model - backend: str = self.backend - num_constraints = sum( [ bool(self.regex), @@ -223,6 +212,16 @@ def validate_environment(self) -> "ChatOutlines": "Either none or exactly one of regex, type_constraints, " "json_schema, or grammar can be provided." ) + return self.build_client() + + def build_client(self) -> "ChatOutlines": + try: + import outlines.models as models + except ImportError: + raise ImportError( + "Could not import the Outlines library. " + "Please install it with `pip install outlines`." + ) def check_packages_installed( packages: List[Union[str, Tuple[str, str]]], @@ -240,38 +239,40 @@ def check_packages_installed( f" pip install {' '.join(missing_packages)}" ) - if backend == "llamacpp": + if self.backend == "llamacpp": check_packages_installed([("llama-cpp-python", "llama_cpp")]) - if ".gguf" in model: - creator, repo_name, file_name = model.split("/", 2) + if ".gguf" in self.model: + creator, repo_name, file_name = self.model.split("/", 2) repo_id = f"{creator}/{repo_name}" else: raise ValueError("GGUF file_name must be provided for llama.cpp.") self.client = models.llamacpp(repo_id, file_name, **self.model_kwargs) - elif backend == "transformers": + elif self.backend == "transformers": check_packages_installed(["transformers", "torch", "datasets"]) - self.client = models.transformers(model_name=model, **self.model_kwargs) - elif backend == "transformers_vision": + self.client = models.transformers( + model_name=self.model, **self.model_kwargs + ) + elif self.backend == "transformers_vision": if hasattr(models, "transformers_vision"): from transformers import LlavaNextForConditionalGeneration self.client = models.transformers_vision( - model, + self.model, model_class=LlavaNextForConditionalGeneration, **self.model_kwargs, ) else: raise ValueError("transformers_vision backend is not supported") - elif backend == "vllm": + elif self.backend == "vllm": if platform.system() == "Darwin": raise ValueError("vLLM backend is not supported on macOS.") check_packages_installed(["vllm"]) - self.client = models.vllm(model, **self.model_kwargs) - elif backend == "mlxlm": + self.client = models.vllm(self.model, **self.model_kwargs) + elif self.backend == "mlxlm": check_packages_installed(["mlx"]) - self.client = models.mlxlm(model, **self.model_kwargs) + self.client = models.mlxlm(self.model, **self.model_kwargs) else: - raise ValueError(f"Unsupported backend: {backend}") + raise ValueError(f"Unsupported backend: {self.backend}") return self @property diff --git a/libs/community/langchain_community/llms/outlines.py b/libs/community/langchain_community/llms/outlines.py index cf38e6ccf3195..df03150c11490 100644 --- a/libs/community/langchain_community/llms/outlines.py +++ b/libs/community/langchain_community/llms/outlines.py @@ -142,17 +142,6 @@ class Outlines(LLM): @model_validator(mode="after") def validate_environment(self) -> "Outlines": """Validate that outlines is installed and create a model instance.""" - try: - import outlines.models as models - except ImportError: - raise ImportError( - "Could not import the Outlines library. " - "Please install it with `pip install outlines`." - ) - - model: str = self.model - backend: str = self.backend - num_constraints = sum( [ bool(self.regex), @@ -166,6 +155,16 @@ def validate_environment(self) -> "Outlines": "Either none or exactly one of regex, type_constraints, " "json_schema, or grammar can be provided." ) + return self.build_client() + + def build_client(self) -> "Outlines": + try: + import outlines.models as models + except ImportError: + raise ImportError( + "Could not import the Outlines library. " + "Please install it with `pip install outlines`." + ) def check_packages_installed( packages: List[Union[str, Tuple[str, str]]], @@ -183,18 +182,18 @@ def check_packages_installed( f" pip install {' '.join(missing_packages)}" ) - if backend == "llamacpp": - if ".gguf" in model: - creator, repo_name, file_name = model.split("/", 2) + if self.backend == "llamacpp": + if ".gguf" in self.model: + creator, repo_name, file_name = self.model.split("/", 2) repo_id = f"{creator}/{repo_name}" else: # todo add auto-file-selection if no file is given raise ValueError("GGUF file_name must be provided for llama.cpp.") check_packages_installed([("llama-cpp-python", "llama_cpp")]) self.client = models.llamacpp(repo_id, file_name, **self.model_kwargs) - elif backend == "transformers": + elif self.backend == "transformers": check_packages_installed(["transformers", "torch", "datasets"]) - self.client = models.transformers(model, **self.model_kwargs) - elif backend == "transformers_vision": + self.client = models.transformers(self.model, **self.model_kwargs) + elif self.backend == "transformers_vision": check_packages_installed( ["transformers", "datasets", "torchvision", "PIL", "flash_attn"] ) @@ -206,22 +205,20 @@ def check_packages_installed( "please install the correct outlines version." ) self.client = models.transformers_vision( - model, + self.model, model_class=LlavaNextForConditionalGeneration, **self.model_kwargs, ) - elif backend == "vllm": + elif self.backend == "vllm": if platform.system() == "Darwin": raise ValueError("vLLM backend is not supported on macOS.") check_packages_installed(["vllm"]) - self.client = models.vllm(model, **self.model_kwargs) - elif backend == "mlxlm": + self.client = models.vllm(self.model, **self.model_kwargs) + elif self.backend == "mlxlm": check_packages_installed(["mlx"]) - self.client = models.mlxlm(model, **self.model_kwargs) + self.client = models.mlxlm(self.model, **self.model_kwargs) else: - raise ValueError(f"Unsupported backend: {backend}") - - return self + raise ValueError(f"Unsupported backend: {self.backend}") @property def _llm_type(self) -> str: diff --git a/libs/community/tests/unit_tests/chat_models/test_outlines.py b/libs/community/tests/unit_tests/chat_models/test_outlines.py new file mode 100644 index 0000000000000..84f6f22355a96 --- /dev/null +++ b/libs/community/tests/unit_tests/chat_models/test_outlines.py @@ -0,0 +1,90 @@ +import pytest +from pydantic import BaseModel, Field + +from langchain_community.chat_models.outlines import ChatOutlines + + +def test_chat_outlines_initialization(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + + chat = ChatOutlines( + model="microsoft/Phi-3-mini-4k-instruct", + max_tokens=42, + stop=["\n"], + ) + assert chat.model == "microsoft/Phi-3-mini-4k-instruct" + assert chat.max_tokens == 42 + assert chat.backend == "transformers" + assert chat.stop == ["\n"] + + +def test_chat_outlines_backend_llamacpp(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + chat = ChatOutlines( + model="TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q4_K_M.gguf", + backend="llamacpp", + ) + assert chat.backend == "llamacpp" + + +def test_chat_outlines_backend_vllm(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", backend="vllm") + assert chat.backend == "vllm" + + +def test_chat_outlines_backend_mlxlm(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", backend="mlxlm") + assert chat.backend == "mlxlm" + + +def test_chat_outlines_with_regex(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + regex = r"\d{3}-\d{3}-\d{4}" + chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", regex=regex) + assert chat.regex == regex + + +def test_chat_outlines_with_type_constraints(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", type_constraints=int) + assert chat.type_constraints == int # noqa + + +def test_chat_outlines_with_json_schema(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + + class TestSchema(BaseModel): + name: str = Field(description="A person's name") + age: int = Field(description="A person's age") + + chat = ChatOutlines( + model="microsoft/Phi-3-mini-4k-instruct", json_schema=TestSchema + ) + assert chat.json_schema == TestSchema + + +def test_chat_outlines_with_grammar(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + + grammar = """ +?start: expression +?expression: term (("+" | "-") term)* +?term: factor (("*" | "/") factor)* +?factor: NUMBER | "-" factor | "(" expression ")" +%import common.NUMBER + """ + chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", grammar=grammar) + assert chat.grammar == grammar + + +def test_raise_for_multiple_output_constraints(monkeypatch): + monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self) + + with pytest.raises(ValueError): + ChatOutlines( + model="microsoft/Phi-3-mini-4k-instruct", + type_constraints=int, + regex=r"\d{3}-\d{3}-\d{4}", + ) diff --git a/libs/community/tests/unit_tests/llms/test_outlines.py b/libs/community/tests/unit_tests/llms/test_outlines.py new file mode 100644 index 0000000000000..dc6294608e695 --- /dev/null +++ b/libs/community/tests/unit_tests/llms/test_outlines.py @@ -0,0 +1,91 @@ +import pytest + +from langchain_community.llms.outlines import Outlines + + +def test_outlines_initialization(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + + llm = Outlines( + model="microsoft/Phi-3-mini-4k-instruct", + max_tokens=42, + stop=["\n"], + ) + assert llm.model == "microsoft/Phi-3-mini-4k-instruct" + assert llm.max_tokens == 42 + assert llm.backend == "transformers" + assert llm.stop == ["\n"] + + +def test_outlines_backend_llamacpp(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + llm = Outlines( + model="TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q4_K_M.gguf", + backend="llamacpp", + ) + assert llm.backend == "llamacpp" + + +def test_outlines_backend_vllm(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", backend="vllm") + assert llm.backend == "vllm" + + +def test_outlines_backend_mlxlm(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", backend="mlxlm") + assert llm.backend == "mlxlm" + + +def test_outlines_with_regex(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + regex = r"\d{3}-\d{3}-\d{4}" + llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", regex=regex) + assert llm.regex == regex + + +def test_outlines_with_type_constraints(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", type_constraints=int) + assert llm.type_constraints == int # noqa + + +def test_outlines_with_json_schema(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + from pydantic import BaseModel, Field + + class TestSchema(BaseModel): + name: str = Field(description="A person's name") + age: int = Field(description="A person's age") + + llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", json_schema=TestSchema) + assert llm.json_schema == TestSchema + + +def test_outlines_with_grammar(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + grammar = """ + ?start: expression + ?expression: term (("+" | "-") term)* + ?term: factor (("*" | "/") factor)* + ?factor: NUMBER | "-" factor | "(" expression ")" + %import common.NUMBER + """ + llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", grammar=grammar) + assert llm.grammar == grammar + + +def test_raise_for_multiple_output_constraints(monkeypatch): + monkeypatch.setattr(Outlines, "build_client", lambda self: self) + with pytest.raises(ValueError): + Outlines( + model="microsoft/Phi-3-mini-4k-instruct", + type_constraints=int, + regex=r"\d{3}-\d{3}-\d{4}", + ) + Outlines( + model="microsoft/Phi-3-mini-4k-instruct", + type_constraints=int, + regex=r"\d{3}-\d{3}-\d{4}", + )