Skip to content

Commit

Permalink
chore!: Rename model_path to model in the Llama.cpp integration (#243)
Browse files Browse the repository at this point in the history
* rename model_path to model

* fix tests

* black
  • Loading branch information
ZanSara authored Jan 22, 2024
1 parent cd78080 commit ed92810
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class LlamaCppGenerator:
Usage example:
```python
from llama_cpp_haystack import LlamaCppGenerator
generator = LlamaCppGenerator(model_path="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)
generator = LlamaCppGenerator(model="zephyr-7b-beta.Q4_0.gguf", n_ctx=2048, n_batch=512)
print(generator.run("Who is the best American actor?", generation_kwargs={"max_tokens": 128}))
# {'replies': ['John Cusack'], 'meta': [{"object": "text_completion", ...}]}
Expand All @@ -26,23 +26,23 @@ class LlamaCppGenerator:

def __init__(
self,
model_path: str,
model: str,
n_ctx: Optional[int] = 0,
n_batch: Optional[int] = 512,
model_kwargs: Optional[Dict[str, Any]] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
):
"""
:param model_path: The path of a quantized model for text generation,
:param model: The path of a quantized model for text generation,
for example, "zephyr-7b-beta.Q4_0.gguf".
If the model_path is also specified in the `model_kwargs`, this parameter will be ignored.
If `model_path` is also specified in the `model_kwargs`, this parameter will be ignored.
:param n_ctx: The number of tokens in the context. When set to 0, the context will be taken from the model.
If the n_ctx is also specified in the `model_kwargs`, this parameter will be ignored.
:param n_batch: Prompt processing maximum batch size. Defaults to 512.
If the n_batch is also specified in the `model_kwargs`, this parameter will be ignored.
:param model_kwargs: Dictionary containing keyword arguments used to initialize the LLM for text generation.
These keyword arguments provide fine-grained control over the model loading.
In case of duplication, these kwargs override `model_path`, `n_ctx`, and `n_batch` init parameters.
In case of duplication, these kwargs override `model`, `n_ctx`, and `n_batch` init parameters.
See Llama.cpp's [documentation](https://llama-cpp-python.readthedocs.io/en/latest/api-reference/#llama_cpp.Llama.__init__)
for more information on the available kwargs.
:param generation_kwargs: A dictionary containing keyword arguments to customize text generation.
Expand All @@ -56,11 +56,11 @@ def __init__(

# check if the model_kwargs contain the essential parameters
# otherwise, populate them with values from init parameters
model_kwargs.setdefault("model_path", model_path)
model_kwargs.setdefault("model_path", model)
model_kwargs.setdefault("n_ctx", n_ctx)
model_kwargs.setdefault("n_batch", n_batch)

self.model_path = model_path
self.model_path = model
self.n_ctx = n_ctx
self.n_batch = n_batch
self.model_kwargs = model_kwargs
Expand Down
20 changes: 8 additions & 12 deletions integrations/llama_cpp/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,22 @@ def generator(self, model_path, capsys):
download_file(ggml_model_path, str(model_path / filename), capsys)

model_path = str(model_path / filename)
generator = LlamaCppGenerator(model_path=model_path, n_ctx=128, n_batch=128)
generator = LlamaCppGenerator(model=model_path, n_ctx=128, n_batch=128)
generator.warm_up()
return generator

@pytest.fixture
def generator_mock(self):
mock_model = MagicMock()
generator = LlamaCppGenerator(model_path="test_model.gguf", n_ctx=2048, n_batch=512)
generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=2048, n_batch=512)
generator.model = mock_model
return generator, mock_model

def test_default_init(self):
"""
Test default initialization parameters.
"""
generator = LlamaCppGenerator(model_path="test_model.gguf")
generator = LlamaCppGenerator(model="test_model.gguf")

assert generator.model_path == "test_model.gguf"
assert generator.n_ctx == 0
Expand All @@ -68,7 +68,7 @@ def test_custom_init(self):
Test custom initialization parameters.
"""
generator = LlamaCppGenerator(
model_path="test_model.gguf",
model="test_model.gguf",
n_ctx=2048,
n_batch=512,
)
Expand All @@ -84,7 +84,7 @@ def test_ignores_model_path_if_specified_in_model_kwargs(self):
Test that the `model` init parameter is ignored if `model_path` is already specified in model_kwargs.
"""
generator = LlamaCppGenerator(
model_path="test_model.gguf",
model="test_model.gguf",
n_ctx=512,
n_batch=512,
model_kwargs={"model_path": "other_model.gguf"},
Expand All @@ -95,25 +95,21 @@ def test_ignores_n_ctx_if_specified_in_model_kwargs(self):
"""
Test that n_ctx is ignored if already specified in model_kwargs.
"""
generator = LlamaCppGenerator(
model_path="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024}
)
generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_ctx": 1024})
assert generator.model_kwargs["n_ctx"] == 1024

def test_ignores_n_batch_if_specified_in_model_kwargs(self):
"""
Test that n_batch is ignored if already specified in model_kwargs.
"""
generator = LlamaCppGenerator(
model_path="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024}
)
generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512, model_kwargs={"n_batch": 1024})
assert generator.model_kwargs["n_batch"] == 1024

def test_raises_error_without_warm_up(self):
"""
Test that the generator raises an error if warm_up() is not called before running.
"""
generator = LlamaCppGenerator(model_path="test_model.gguf", n_ctx=512, n_batch=512)
generator = LlamaCppGenerator(model="test_model.gguf", n_ctx=512, n_batch=512)
with pytest.raises(RuntimeError):
generator.run("What is the capital of China?")

Expand Down

0 comments on commit ed92810

Please sign in to comment.