chore!: Rename model_path to model in the Llama.cpp integration #243

Merged · 3 commits · Jan 22, 2024
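This is a breaking change (hence the `chore!` prefix): the `model_path` init parameter of `LlamaCppGenerator` becomes `model`. Only the parameter is renamed; the instance attribute is still `self.model_path`, as the updated tests below assert. A minimal migration sketch for downstream code (package and model names as used in the diff):

```python
from llama_cpp_haystack import LlamaCppGenerator

# Before this PR:
# generator = LlamaCppGenerator(model_path="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)

# After this PR:
generator = LlamaCppGenerator(model="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)
```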
@@ -17,7 +17,7 @@ class LlamaCppGenerator:
     Usage example:
     ```python
     from llama_cpp_haystack import LlamaCppGenerator
-    generator = LlamaCppGenerator(model_path="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)
+    generator = LlamaCppGenerator(model="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)
 
     print(generator.run("Who is the best American actor?", generation_kwargs={"max_tokens": 128}))
     # {'replies': ['John Cusack'], 'meta': [{"object": "text_completion", ...}]}
@@ -26,23 +26,23 @@ class LlamaCppGenerator:
 
     def __init__(
         self,
-        model_path: str,
+        model: str,
         n_ctx: Optional[int] = 0,
         n_batch: Optional[int] = 512,
         model_kwargs: Optional[Dict[str, Any]] = None,
         generation_kwargs: Optional[Dict[str, Any]] = None,
     ):
"""
:param model_path: The path of a quantized model for text generation,
:param model: The path of a quantized model for text generation,
for example, "zephyr-7b-beta.Q4_0.gguf".
If the model_path is also specified in the `model_kwargs`, this parameter will be ignored.
If the model path is also specified in the `model_kwargs`, this parameter will be ignored.
:param n_ctx: The number of tokens in the context. When set to 0, the context will be taken from the model.
If the n_ctx is also specified in the `model_kwargs`, this parameter will be ignored.
:param n_batch: Prompt processing maximum batch size. Defaults to 512.
If the n_batch is also specified in the `model_kwargs`, this parameter will be ignored.
:param model_kwargs: Dictionary containing keyword arguments used to initialize the LLM for text generation.
These keyword arguments provide fine-grained control over the model loading.
In case of duplication, these kwargs override `model_path`, `n_ctx`, and `n_batch` init parameters.
In case of duplication, these kwargs override `model`, `n_ctx`, and `n_batch` init parameters.
See Llama.cpp's [documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__)
for more information on the available kwargs.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
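As an aside, a hedged usage sketch for `generation_kwargs` at init, assuming (as the docstring suggests) they act as default completion settings; `max_tokens` and `temperature` are standard llama-cpp-python `create_completion` options, not values taken from this diff:

```python
from llama_cpp_haystack import LlamaCppGenerator

# Default generation settings, assumed to apply to every run() call.
generator = LlamaCppGenerator(
    model="zephyr-7b-beta.Q4_0.gguf",
    n_ctx=2048,
    generation_kwargs={"max_tokens": 128, "temperature": 0.1},
)
```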
@@ -56,11 +56,11 @@ def __init__(
 
         # check if model_kwargs contains the essential parameters;
         # otherwise, populate them with values from init parameters
-        model_kwargs.setdefault("model_path", model_path)
+        model_kwargs.setdefault("model_path", model)
         model_kwargs.setdefault("n_ctx", n_ctx)
         model_kwargs.setdefault("n_batch", n_batch)
 
-        self.model_path = model_path
+        self.model_path = model
         self.n_ctx = n_ctx
         self.n_batch = n_batch
         self.model_kwargs = model_kwargs
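The precedence rule documented above is implemented with `dict.setdefault`, which only fills in missing keys, so user-supplied `model_kwargs` win over the init parameters. A self-contained sketch of that behavior:

```python
# dict.setdefault keeps an existing value and only fills in missing keys,
# which is why kwargs passed via model_kwargs override the init parameters.
model_kwargs = {"model_path": "other_model.gguf"}  # user-supplied

model_kwargs.setdefault("model_path", "test_model.gguf")  # ignored: key already set
model_kwargs.setdefault("n_ctx", 512)                     # added: key was missing

assert model_kwargs == {"model_path": "other_model.gguf", "n_ctx": 512}
```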
integrations/llama_cpp/tests/test_generator.py (20 changes: 8 additions & 12 deletions)

@@ -40,22 +40,22 @@ def generator(self, model_path, capsys):
         download_file(ggml_model_path, str(model_path / filename), capsys)
 
         model_path = str(model_path / filename)
-        generator = LlamaCppGenerator(model_path=model_path, n_ctx=128, n_batch=128)
+        generator = LlamaCppGenerator(model=model_path, n_ctx=128, n_batch=128)
         generator.warm_up()
         return generator

     @pytest.fixture
     def generator_mock(self):
         mock_model = MagicMock()
-        generator = LlamaCppGenerator(model_path="test_model.gguf", n_ctx=2048, n_batch=512)
+        generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=2048, n_batch=512)
         generator.model = mock_model
         return generator, mock_model
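For context, a hedged sketch of how such a fixture could be exercised; the delegation to `Llama.create_completion` and the exact reply parsing are assumptions, not shown in this diff:

```python
# Hypothetical test using the generator_mock fixture, assuming run()
# wraps llama-cpp-python's Llama.create_completion.
def test_run_with_mocked_model(self, generator_mock):
    generator, mock_model = generator_mock
    mock_model.create_completion.return_value = {
        "object": "text_completion",
        "choices": [{"text": "John Cusack"}],
    }
    result = generator.run("Who is the best American actor?")
    mock_model.create_completion.assert_called_once()
    assert "replies" in result  # return shape per the class docstring
```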

     def test_default_init(self):
         """
         Test default initialization parameters.
         """
-        generator = LlamaCppGenerator(model_path="test_model.gguf")
+        generator = LlamaCppGenerator(model="test_model.gguf")
 
         assert generator.model_path == "test_model.gguf"
         assert generator.n_ctx == 0
@@ -68,7 +68,7 @@ def test_custom_init(self):
         Test custom initialization parameters.
         """
         generator = LlamaCppGenerator(
-            model_path="test_model.gguf",
+            model="test_model.gguf",
             n_ctx=2048,
             n_batch=512,
         )
@@ -84,7 +84,7 @@ def test_ignores_model_path_if_specified_in_model_kwargs(self):
         Test that model_path is ignored if already specified in model_kwargs.
         """
         generator = LlamaCppGenerator(
-            model_path="test_model.gguf",
+            model="test_model.gguf",
             n_ctx=512,
             n_batch=512,
             model_kwargs={"model_path": "other_model.gguf"},
@@ -95,25 +95,21 @@ def test_ignores_n_ctx_if_specified_in_model_kwargs(self):
         """
         Test that n_ctx is ignored if already specified in model_kwargs.
         """
-        generator = LlamaCppGenerator(
-            model_path="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024}
-        )
+        generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024})
         assert generator.model_kwargs["n_ctx"] == 1024

     def test_ignores_n_batch_if_specified_in_model_kwargs(self):
         """
         Test that n_batch is ignored if already specified in model_kwargs.
         """
-        generator = LlamaCppGenerator(
-            model_path="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024}
-        )
+        generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024})
         assert generator.model_kwargs["n_batch"] == 1024

     def test_raises_error_without_warm_up(self):
         """
         Test that the generator raises an error if warm_up() is not called before running.
         """
-        generator = LlamaCppGenerator(model_path="test_model.gguf", n_ctx=512, n_batch=512)
+        generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512)
         with pytest.raises(RuntimeError):
             generator.run("What is the capital of China?")
