Merge branch 'bedrock-embedders' into br-text-embedder
anakin87 committed Feb 22, 2024
2 parents 80173dd + 274767d commit bed1996
Showing 23 changed files with 1,662 additions and 89 deletions.
5 changes: 5 additions & 0 deletions .github/labeler.yml
@@ -84,6 +84,11 @@ integration:opensearch:
   - any-glob-to-any-file: "integrations/opensearch/**/*"
   - any-glob-to-any-file: ".github/workflows/opensearch.yml"

+integration:optimum:
+- changed-files:
+  - any-glob-to-any-file: "integrations/optimum/**/*"
+  - any-glob-to-any-file: ".github/workflows/optimum.yml"
+
 integration:pgvector:
 - changed-files:
   - any-glob-to-any-file: "integrations/pgvector/**/*"
60 changes: 60 additions & 0 deletions .github/workflows/optimum.yml
@@ -0,0 +1,60 @@
# This workflow comes from https://github.com/ofek/hatch-mypyc
# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml
name: Test / optimum

on:
  schedule:
    - cron: "0 0 * * *"
  pull_request:
    paths:
      - "integrations/optimum/**"
      - ".github/workflows/optimum.yml"

defaults:
  run:
    working-directory: integrations/optimum

concurrency:
  group: optimum-${{ github.head_ref }}
  cancel-in-progress: true

env:
  PYTHONUNBUFFERED: "1"
  FORCE_COLOR: "1"

jobs:
  run:
    name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [ubuntu-latest, windows-latest, macos-latest]
        python-version: ["3.9", "3.10"]

    steps:
      - name: Support longpaths
        if: matrix.os == 'windows-latest'
        working-directory: .
        run: git config --system core.longpaths true

      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install Hatch
        run: pip install --upgrade hatch

      - name: Lint
        if: matrix.python-version == '3.9' && runner.os == 'Linux'
        run: hatch run lint:all

      - name: Generate docs
        if: matrix.python-version == '3.9' && runner.os == 'Linux'
        run: hatch run docs

      - name: Run tests
        run: hatch run cov
integrations/amazon_bedrock/src/haystack_integrations/components/generators/amazon_bedrock/chat/chat_generator.py
@@ -5,7 +5,7 @@

 from botocore.exceptions import ClientError
 from haystack import component, default_from_dict, default_to_dict
-from haystack.components.generators.utils import deserialize_callback_handler
+from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler
 from haystack.dataclasses import ChatMessage, StreamingChunk
 from haystack.utils.auth import Secret, deserialize_secrets_inplace

@@ -192,7 +192,7 @@ def to_dict(self) -> Dict[str, Any]:
             model=self.model,
             stop_words=self.stop_words,
             generation_kwargs=self.model_adapter.generation_kwargs,
-            streaming_callback=self.streaming_callback,
+            streaming_callback=serialize_callback_handler(self.streaming_callback),
         )

     @classmethod
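Note on the change above: `to_dict` previously stored the raw `streaming_callback` callable, which cannot be serialized to JSON. `serialize_callback_handler` converts the callable to its dotted import path, and `deserialize_callback_handler` (already imported for `from_dict`) resolves it back. A minimal sketch of the round trip these helpers are expected to provide, matching the string the updated tests below assert on:

```python
from haystack.components.generators.utils import (
    deserialize_callback_handler,
    print_streaming_chunk,
    serialize_callback_handler,
)

# Serialization: the callable becomes its dotted import path, which is
# what now lands in the component's to_dict() output.
name = serialize_callback_handler(print_streaming_chunk)
assert name == "haystack.components.generators.utils.print_streaming_chunk"

# Deserialization: the path resolves back to the original function, so a
# component restored via from_dict() streams exactly as before.
assert deserialize_callback_handler(name) is print_streaming_chunk
```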
10 changes: 1 addition & 9 deletions integrations/amazon_bedrock/tests/conftest.py
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch

 import pytest

@@ -12,14 +12,6 @@ def set_env_variables(monkeypatch):
     monkeypatch.setenv("AWS_PROFILE", "some_fake_profile")


-@pytest.fixture
-def mock_auto_tokenizer():
-    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
-        mock_tokenizer = MagicMock()
-        mock_from_pretrained.return_value = mock_tokenizer
-        yield mock_tokenizer
-
-
 # create a fixture with mocked boto3 client and session
 @pytest.fixture
 def mock_boto3_session():
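With the tokenizer mock gone, the shared `conftest.py` only needs to patch `boto3.Session`, and every test module in the package can request the fixture by name instead of redefining it. A minimal sketch of how such a fixture is consumed (the test body is illustrative, not part of this diff):

```python
from unittest.mock import patch

import boto3
import pytest


@pytest.fixture
def mock_boto3_session():
    # Patch boto3.Session so tests need no real AWS credentials or network.
    with patch("boto3.Session") as mock_client:
        yield mock_client


def test_uses_mocked_session(mock_boto3_session):
    # pytest injects the fixture by parameter name, so any boto3.Session(...)
    # call made while the test runs hits the mock instead of AWS.
    boto3.Session(profile_name="some_fake_profile")
    mock_boto3_session.assert_called_once()
```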
42 changes: 8 additions & 34 deletions integrations/amazon_bedrock/tests/test_chat_generator.py
@@ -1,5 +1,4 @@
 from typing import Optional, Type
-from unittest.mock import MagicMock, patch

 import pytest
 from haystack.components.generators.utils import print_streaming_chunk
@@ -15,30 +14,7 @@
 clazz = "haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator.AmazonBedrockChatGenerator"


-@pytest.fixture
-def mock_auto_tokenizer():
-    with patch("transformers.AutoTokenizer.from_pretrained", autospec=True) as mock_from_pretrained:
-        mock_tokenizer = MagicMock()
-        mock_from_pretrained.return_value = mock_tokenizer
-        yield mock_tokenizer
-
-
-# create a fixture with mocked boto3 client and session
-@pytest.fixture
-def mock_boto3_session():
-    with patch("boto3.Session") as mock_client:
-        yield mock_client
-
-
-@pytest.fixture
-def mock_prompt_handler():
-    with patch(
-        "haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
-    ) as mock_prompt_handler:
-        yield mock_prompt_handler
-
-
-def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
+def test_to_dict(mock_boto3_session):
     """
     Test that the to_dict method returns the correct dictionary without aws credentials
     """
@@ -58,14 +34,14 @@ def test_to_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
             "model": "anthropic.claude-v2",
             "generation_kwargs": {"temperature": 0.7},
             "stop_words": [],
-            "streaming_callback": print_streaming_chunk,
+            "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
         },
     }

     assert generator.to_dict() == expected_dict


-def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
+def test_from_dict(mock_boto3_session):
     """
     Test that the from_dict method returns the correct object
     """
@@ -89,7 +65,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session):
     assert generator.streaming_callback == print_streaming_chunk


-def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
+def test_default_constructor(mock_boto3_session, set_env_variables):
     """
     Test that the default constructor sets the correct values
     """
@@ -116,7 +92,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
     )


-def test_constructor_with_generation_kwargs(mock_auto_tokenizer, mock_boto3_session):
+def test_constructor_with_generation_kwargs(mock_boto3_session):
     """
     Test that model_kwargs are correctly set in the constructor
     """
@@ -135,8 +111,7 @@ def test_constructor_with_empty_model():
         AmazonBedrockChatGenerator(model="")


-@pytest.mark.unit
-def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
+def test_invoke_with_no_kwargs(mock_boto3_session):
     """
     Test invoke raises an error if no messages are provided
     """
@@ -145,7 +120,6 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
         layer.invoke()


-@pytest.mark.unit
 @pytest.mark.parametrize(
     "model, expected_model_adapter",
     [
@@ -168,7 +142,7 @@ def test_get_model_adapter(model: str, expected_model_adapter: Optional[Type[Bed


 class TestAnthropicClaudeAdapter:
-    def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None:
+    def test_prepare_body_with_default_params(self) -> None:
         layer = AnthropicClaudeChatAdapter(generation_kwargs={})
         prompt = "Hello, how are you?"
         expected_body = {
@@ -181,7 +155,7 @@ def test_prepare_body_with_default_params(self, mock_auto_tokenizer) -> None:

         assert body == expected_body

-    def test_prepare_body_with_custom_inference_params(self, mock_auto_tokenizer) -> None:
+    def test_prepare_body_with_custom_inference_params(self) -> None:
         layer = AnthropicClaudeChatAdapter(generation_kwargs={"temperature": 0.7, "top_p": 0.8, "top_k": 4})
         prompt = "Hello, how are you?"
         expected_body = {
22 changes: 6 additions & 16 deletions integrations/amazon_bedrock/tests/test_generator.py
@@ -14,8 +14,7 @@
 )


-@pytest.mark.unit
-def test_to_dict(mock_boto3_session, set_env_variables):
+def test_to_dict(mock_boto3_session):
     """
     Test that the to_dict method returns the correct dictionary without aws credentials
     """
@@ -40,8 +39,7 @@ def test_to_dict(mock_boto3_session, set_env_variables):
     assert generator.to_dict() == expected_dict


-@pytest.mark.unit
-def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
+def test_from_dict(mock_boto3_session):
     """
     Test that the from_dict method returns the correct object
     """
@@ -64,8 +62,7 @@ def test_from_dict(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
     assert generator.model == "anthropic.claude-v2"


-@pytest.mark.unit
-def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
+def test_default_constructor(mock_boto3_session, set_env_variables):
     """
     Test that the default constructor sets the correct values
     """
@@ -94,8 +91,7 @@ def test_default_constructor(mock_auto_tokenizer, mock_boto3_session, set_env_variables):
     )


-@pytest.mark.unit
-def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler):
+def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_handler):
     """
     Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
     """
@@ -104,8 +100,7 @@ def test_constructor_prompt_handler_initialized(mock_auto_tokenizer, mock_boto3_session, mock_prompt_handler):
     assert layer.prompt_handler.model_max_length == 4096


-@pytest.mark.unit
-def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session):
+def test_constructor_with_model_kwargs(mock_boto3_session):
     """
     Test that model_kwargs are correctly set in the constructor
     """
@@ -116,7 +111,6 @@ def test_constructor_with_model_kwargs(mock_auto_tokenizer, mock_boto3_session):
     assert layer.model_adapter.model_kwargs["temperature"] == 0.7


-@pytest.mark.unit
 def test_constructor_with_empty_model():
     """
     Test that the constructor raises an error when the model is empty
@@ -125,8 +119,7 @@ def test_constructor_with_empty_model():
         AmazonBedrockGenerator(model="")


-@pytest.mark.unit
-def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
+def test_invoke_with_no_kwargs(mock_boto3_session):
     """
     Test invoke raises an error if no prompt is provided
     """
@@ -135,7 +128,6 @@ def test_invoke_with_no_kwargs(mock_auto_tokenizer, mock_boto3_session):
         layer.invoke()


-@pytest.mark.unit
 def test_short_prompt_is_not_truncated(mock_boto3_session):
     """
     Test that a short prompt is not truncated
@@ -166,7 +158,6 @@ def test_short_prompt_is_not_truncated(mock_boto3_session):
     assert prompt_after_resize == mock_prompt_text


-@pytest.mark.unit
 def test_long_prompt_is_truncated(mock_boto3_session):
     """
     Test that a long prompt is truncated
@@ -201,7 +192,6 @@ def test_long_prompt_is_truncated(mock_boto3_session):
     assert prompt_after_resize == truncated_prompt_text


-@pytest.mark.unit
 @pytest.mark.parametrize(
     "model, expected_model_adapter",
     [
integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py
@@ -2,10 +2,10 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import logging
-import sys
 from typing import Any, Callable, Dict, List, Optional, cast

-from haystack import DeserializationError, component, default_from_dict, default_to_dict
+from haystack import component, default_from_dict, default_to_dict
+from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler
 from haystack.dataclasses import StreamingChunk
 from haystack.utils import Secret, deserialize_secrets_inplace

@@ -89,19 +89,10 @@ def to_dict(self) -> Dict[str, Any]:
         """
         Serialize this component to a dictionary.
         """
-        if self.streaming_callback:
-            module = self.streaming_callback.__module__
-            if module == "builtins":
-                callback_name = self.streaming_callback.__name__
-            else:
-                callback_name = f"{module}.{self.streaming_callback.__name__}"
-        else:
-            callback_name = None
-
         return default_to_dict(
             self,
             model=self.model,
-            streaming_callback=callback_name,
+            streaming_callback=serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None,
             api_base_url=self.api_base_url,
             api_key=self.api_key.to_dict(),
             **self.model_parameters,
@@ -114,20 +105,10 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator":
         """
         init_params = data.get("init_parameters", {})
         deserialize_secrets_inplace(init_params, ["api_key"])
-        streaming_callback = None
-        if "streaming_callback" in init_params and init_params["streaming_callback"] is not None:
-            parts = init_params["streaming_callback"].split(".")
-            module_name = ".".join(parts[:-1])
-            function_name = parts[-1]
-            module = sys.modules.get(module_name, None)
-            if not module:
-                msg = f"Could not locate the module of the streaming callback: {module_name}"
-                raise DeserializationError(msg)
-            streaming_callback = getattr(module, function_name, None)
-            if not streaming_callback:
-                msg = f"Could not locate the streaming callback: {function_name}"
-                raise DeserializationError(msg)
-        data["init_parameters"]["streaming_callback"] = streaming_callback
+        data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(
+            init_params["streaming_callback"]
+        )
         return default_from_dict(cls, data)

     @component.output_types(replies=List[str], meta=List[Dict[str, Any]])
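The removed block hand-rolled what `deserialize_callback_handler` now does: split the dotted path, look the module up, and raise if the callable cannot be found. A sketch of the round trip this refactor preserves, assuming the integration's usual import path and that the API key Secret defaults to the `COHERE_API_KEY` environment variable (both assumptions, not shown in this diff):

```python
import os

from haystack.components.generators.utils import print_streaming_chunk
from haystack_integrations.components.generators.cohere import CohereGenerator

# Assumption: the default Secret reads COHERE_API_KEY from the environment;
# a placeholder is enough here, since the key is not resolved during (de)serialization.
os.environ["COHERE_API_KEY"] = "placeholder-for-serialization-only"

generator = CohereGenerator(streaming_callback=print_streaming_chunk)
data = generator.to_dict()

# The callback is stored as a dotted path string rather than a raw callable...
expected = "haystack.components.generators.utils.print_streaming_chunk"
assert data["init_parameters"]["streaming_callback"] == expected

# ...and from_dict() resolves it back to the original function.
restored = CohereGenerator.from_dict(data)
assert restored.streaming_callback is print_streaming_chunk
```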