Skip to content

Commit

Permalink
add unit-tests back by mocking build_client
Browse files Browse the repository at this point in the history
  • Loading branch information
shroominic committed Nov 10, 2024
1 parent 0b8a0aa commit 3e45658
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 48 deletions.
47 changes: 24 additions & 23 deletions libs/community/langchain_community/chat_models/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,17 +199,6 @@ class ChatOutlines(BaseChatModel):
@model_validator(mode="after")
def validate_environment(self) -> "ChatOutlines":
"""Validate that outlines is installed and create a model instance."""
try:
import outlines.models as models
except ImportError:
raise ImportError(
"Could not import the Outlines library. "
"Please install it with `pip install outlines`."
)

model: str = self.model
backend: str = self.backend

num_constraints = sum(
[
bool(self.regex),
Expand All @@ -223,6 +212,16 @@ def validate_environment(self) -> "ChatOutlines":
"Either none or exactly one of regex, type_constraints, "
"json_schema, or grammar can be provided."
)
return self.build_client()

def build_client(self) -> "ChatOutlines":
try:
import outlines.models as models
except ImportError:
raise ImportError(
"Could not import the Outlines library. "
"Please install it with `pip install outlines`."
)

def check_packages_installed(
packages: List[Union[str, Tuple[str, str]]],
Expand All @@ -240,38 +239,40 @@ def check_packages_installed(
f" pip install {' '.join(missing_packages)}"
)

if backend == "llamacpp":
if self.backend == "llamacpp":
check_packages_installed([("llama-cpp-python", "llama_cpp")])
if ".gguf" in model:
creator, repo_name, file_name = model.split("/", 2)
if ".gguf" in self.model:
creator, repo_name, file_name = self.model.split("/", 2)
repo_id = f"{creator}/{repo_name}"
else:
raise ValueError("GGUF file_name must be provided for llama.cpp.")
self.client = models.llamacpp(repo_id, file_name, **self.model_kwargs)
elif backend == "transformers":
elif self.backend == "transformers":
check_packages_installed(["transformers", "torch", "datasets"])
self.client = models.transformers(model_name=model, **self.model_kwargs)
elif backend == "transformers_vision":
self.client = models.transformers(
model_name=self.model, **self.model_kwargs
)
elif self.backend == "transformers_vision":
if hasattr(models, "transformers_vision"):
from transformers import LlavaNextForConditionalGeneration

self.client = models.transformers_vision(
model,
self.model,
model_class=LlavaNextForConditionalGeneration,
**self.model_kwargs,
)
else:
raise ValueError("transformers_vision backend is not supported")
elif backend == "vllm":
elif self.backend == "vllm":
if platform.system() == "Darwin":
raise ValueError("vLLM backend is not supported on macOS.")
check_packages_installed(["vllm"])
self.client = models.vllm(model, **self.model_kwargs)
elif backend == "mlxlm":
self.client = models.vllm(self.model, **self.model_kwargs)
elif self.backend == "mlxlm":
check_packages_installed(["mlx"])
self.client = models.mlxlm(model, **self.model_kwargs)
self.client = models.mlxlm(self.model, **self.model_kwargs)
else:
raise ValueError(f"Unsupported backend: {backend}")
raise ValueError(f"Unsupported backend: {self.backend}")
return self

@property
Expand Down
47 changes: 22 additions & 25 deletions libs/community/langchain_community/llms/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,17 +142,6 @@ class Outlines(LLM):
@model_validator(mode="after")
def validate_environment(self) -> "Outlines":
"""Validate that outlines is installed and create a model instance."""
try:
import outlines.models as models
except ImportError:
raise ImportError(
"Could not import the Outlines library. "
"Please install it with `pip install outlines`."
)

model: str = self.model
backend: str = self.backend

num_constraints = sum(
[
bool(self.regex),
Expand All @@ -166,6 +155,16 @@ def validate_environment(self) -> "Outlines":
"Either none or exactly one of regex, type_constraints, "
"json_schema, or grammar can be provided."
)
return self.build_client()

def build_client(self) -> "Outlines":
try:
import outlines.models as models
except ImportError:
raise ImportError(
"Could not import the Outlines library. "
"Please install it with `pip install outlines`."
)

def check_packages_installed(
packages: List[Union[str, Tuple[str, str]]],
Expand All @@ -183,18 +182,18 @@ def check_packages_installed(
f" pip install {' '.join(missing_packages)}"
)

if backend == "llamacpp":
if ".gguf" in model:
creator, repo_name, file_name = model.split("/", 2)
if self.backend == "llamacpp":
if ".gguf" in self.model:
creator, repo_name, file_name = self.model.split("/", 2)
repo_id = f"{creator}/{repo_name}"
else: # todo add auto-file-selection if no file is given
raise ValueError("GGUF file_name must be provided for llama.cpp.")
check_packages_installed([("llama-cpp-python", "llama_cpp")])
self.client = models.llamacpp(repo_id, file_name, **self.model_kwargs)
elif backend == "transformers":
elif self.backend == "transformers":
check_packages_installed(["transformers", "torch", "datasets"])
self.client = models.transformers(model, **self.model_kwargs)
elif backend == "transformers_vision":
self.client = models.transformers(self.model, **self.model_kwargs)
elif self.backend == "transformers_vision":
check_packages_installed(
["transformers", "datasets", "torchvision", "PIL", "flash_attn"]
)
Expand All @@ -206,22 +205,20 @@ def check_packages_installed(
"please install the correct outlines version."
)
self.client = models.transformers_vision(
model,
self.model,
model_class=LlavaNextForConditionalGeneration,
**self.model_kwargs,
)
elif backend == "vllm":
elif self.backend == "vllm":
if platform.system() == "Darwin":
raise ValueError("vLLM backend is not supported on macOS.")
check_packages_installed(["vllm"])
self.client = models.vllm(model, **self.model_kwargs)
elif backend == "mlxlm":
self.client = models.vllm(self.model, **self.model_kwargs)
elif self.backend == "mlxlm":
check_packages_installed(["mlx"])
self.client = models.mlxlm(model, **self.model_kwargs)
self.client = models.mlxlm(self.model, **self.model_kwargs)
else:
raise ValueError(f"Unsupported backend: {backend}")

return self
raise ValueError(f"Unsupported backend: {self.backend}")

@property
def _llm_type(self) -> str:
Expand Down
90 changes: 90 additions & 0 deletions libs/community/tests/unit_tests/chat_models/test_outlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
from pydantic import BaseModel, Field

from langchain_community.chat_models.outlines import ChatOutlines


def test_chat_outlines_initialization(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

chat = ChatOutlines(
model="microsoft/Phi-3-mini-4k-instruct",
max_tokens=42,
stop=["\n"],
)
assert chat.model == "microsoft/Phi-3-mini-4k-instruct"
assert chat.max_tokens == 42
assert chat.backend == "transformers"
assert chat.stop == ["\n"]


def test_chat_outlines_backend_llamacpp(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)
chat = ChatOutlines(
model="TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q4_K_M.gguf",
backend="llamacpp",
)
assert chat.backend == "llamacpp"


def test_chat_outlines_backend_vllm(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)
chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", backend="vllm")
assert chat.backend == "vllm"


def test_chat_outlines_backend_mlxlm(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)
chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", backend="mlxlm")
assert chat.backend == "mlxlm"


def test_chat_outlines_with_regex(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)
regex = r"\d{3}-\d{3}-\d{4}"
chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", regex=regex)
assert chat.regex == regex


def test_chat_outlines_with_type_constraints(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)
chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", type_constraints=int)
assert chat.type_constraints == int # noqa


def test_chat_outlines_with_json_schema(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

class TestSchema(BaseModel):
name: str = Field(description="A person's name")
age: int = Field(description="A person's age")

chat = ChatOutlines(
model="microsoft/Phi-3-mini-4k-instruct", json_schema=TestSchema
)
assert chat.json_schema == TestSchema


def test_chat_outlines_with_grammar(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

grammar = """
?start: expression
?expression: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "-" factor | "(" expression ")"
%import common.NUMBER
"""
chat = ChatOutlines(model="microsoft/Phi-3-mini-4k-instruct", grammar=grammar)
assert chat.grammar == grammar


def test_raise_for_multiple_output_constraints(monkeypatch):
monkeypatch.setattr(ChatOutlines, "build_client", lambda self: self)

with pytest.raises(ValueError):
ChatOutlines(
model="microsoft/Phi-3-mini-4k-instruct",
type_constraints=int,
regex=r"\d{3}-\d{3}-\d{4}",
)
91 changes: 91 additions & 0 deletions libs/community/tests/unit_tests/llms/test_outlines.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import pytest

from langchain_community.llms.outlines import Outlines


def test_outlines_initialization(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)

llm = Outlines(
model="microsoft/Phi-3-mini-4k-instruct",
max_tokens=42,
stop=["\n"],
)
assert llm.model == "microsoft/Phi-3-mini-4k-instruct"
assert llm.max_tokens == 42
assert llm.backend == "transformers"
assert llm.stop == ["\n"]


def test_outlines_backend_llamacpp(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
llm = Outlines(
model="TheBloke/Llama-2-7B-Chat-GGUF/llama-2-7b-chat.Q4_K_M.gguf",
backend="llamacpp",
)
assert llm.backend == "llamacpp"


def test_outlines_backend_vllm(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", backend="vllm")
assert llm.backend == "vllm"


def test_outlines_backend_mlxlm(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", backend="mlxlm")
assert llm.backend == "mlxlm"


def test_outlines_with_regex(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
regex = r"\d{3}-\d{3}-\d{4}"
llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", regex=regex)
assert llm.regex == regex


def test_outlines_with_type_constraints(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", type_constraints=int)
assert llm.type_constraints == int # noqa


def test_outlines_with_json_schema(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
from pydantic import BaseModel, Field

class TestSchema(BaseModel):
name: str = Field(description="A person's name")
age: int = Field(description="A person's age")

llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", json_schema=TestSchema)
assert llm.json_schema == TestSchema


def test_outlines_with_grammar(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
grammar = """
?start: expression
?expression: term (("+" | "-") term)*
?term: factor (("*" | "/") factor)*
?factor: NUMBER | "-" factor | "(" expression ")"
%import common.NUMBER
"""
llm = Outlines(model="microsoft/Phi-3-mini-4k-instruct", grammar=grammar)
assert llm.grammar == grammar


def test_raise_for_multiple_output_constraints(monkeypatch):
monkeypatch.setattr(Outlines, "build_client", lambda self: self)
with pytest.raises(ValueError):
Outlines(
model="microsoft/Phi-3-mini-4k-instruct",
type_constraints=int,
regex=r"\d{3}-\d{3}-\d{4}",
)
Outlines(
model="microsoft/Phi-3-mini-4k-instruct",
type_constraints=int,
regex=r"\d{3}-\d{3}-\d{4}",
)

0 comments on commit 3e45658

Please sign in to comment.