feat: Add model kwargs to SentenceTransformersRanker (#6627)
* Add model_kwargs to SentenceTransformersRanker

* Add unit test

* Add unit tests for the torch dtype extraction

* Add release notes

* Fix formatting

* Fix patch

* Make function more explicit
sjrl authored Dec 22, 2023
1 parent 28c0c01 commit 5ff81c2
Showing 6 changed files with 81 additions and 2 deletions.
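In short: SentenceTransformersRanker gains a model_kwargs parameter that is forwarded to Hugging Face's from_pretrained, with any torch_dtype entry first normalized by the new shared helper resolve_torch_dtype. A minimal usage sketch (the cross-encoder name is illustrative, not part of this commit):

import torch
from haystack.nodes import SentenceTransformersRanker

# Load the ranker's cross-encoder in half precision; other from_pretrained
# loading options can be passed through model_kwargs the same way.
ranker = SentenceTransformersRanker(
    model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2",
    model_kwargs={"torch_dtype": torch.float16},
)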
3 changes: 2 additions & 1 deletion haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -27,6 +27,7 @@
 )
 from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports
 from haystack.nodes.prompt.invocation_layer.handlers import HFTokenStreamingHandler
+from haystack.utils.torch_utils import resolve_torch_dtype

 class StopWordsCriteria(StoppingCriteria):
     """
@@ -177,7 +178,7 @@ def _prepare_pipeline_kwargs(self, **kwargs) -> Dict[str, Any]:
         device_map = kwargs.get("device_map", None)
         device = kwargs.get("device") if device_map is None else None
         # prepare torch_dtype for pipeline invocation
-        torch_dtype = self._extract_torch_dtype(**kwargs)
+        torch_dtype = resolve_torch_dtype(kwargs.get("torch_dtype"))
         # and the model (prefer model instance over model_name_or_path str identifier)
         model = kwargs.get("model") or kwargs.get("model_name_or_path")
         trust_remote_code = kwargs.get("trust_remote_code", False)
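Since the invocation layer now routes through the same shared helper, string dtypes also work for prompt models. A hedged sketch of how that might look from the caller's side, assuming the usual Haystack v1 PromptModel wiring (the model name is illustrative, and this snippet is not part of the diff):

from haystack.nodes import PromptModel

# "torch.bfloat16" is resolved to the real torch.bfloat16 dtype by
# resolve_torch_dtype() inside _prepare_pipeline_kwargs before the
# transformers pipeline is constructed.
prompt_model = PromptModel(
    model_name_or_path="google/flan-t5-base",
    model_kwargs={"torch_dtype": "torch.bfloat16"},
)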
12 changes: 11 additions & 1 deletion haystack/nodes/ranker/sentence_transformers.py
@@ -18,6 +18,7 @@
 from torch.nn import DataParallel
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports
+from haystack.utils.torch_utils import resolve_torch_dtype


 class SentenceTransformersRanker(BaseRanker):
@@ -57,6 +58,7 @@ def __init__(
         progress_bar: bool = True,
         use_auth_token: Optional[Union[str, bool]] = None,
         embed_meta_fields: Optional[List[str]] = None,
+        model_kwargs: Optional[dict] = None,
     ):
         """
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -90,8 +92,16 @@ def __init__(
         self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True)

         self.progress_bar = progress_bar
+        self.model_kwargs = model_kwargs
+        kwargs = model_kwargs if model_kwargs else {}
+        torch_dtype = resolve_torch_dtype(kwargs.get("torch_dtype"))
+        if torch_dtype:
+            kwargs["torch_dtype"] = torch_dtype
         self.transformer_model = AutoModelForSequenceClassification.from_pretrained(
-            pretrained_model_name_or_path=model_name_or_path, revision=model_version, use_auth_token=use_auth_token
+            pretrained_model_name_or_path=model_name_or_path,
+            revision=model_version,
+            use_auth_token=use_auth_token,
+            **kwargs,
         )
         self.transformer_model.to(str(self.devices[0]))
         self.transformer_tokenizer = AutoTokenizer.from_pretrained(
22 changes: 22 additions & 0 deletions haystack/utils/torch_utils.py
@@ -52,3 +52,25 @@ def get_devices(devices: Optional[List[Union[str, torch.device]]]) -> List[torch.device]:
     ):
         return [torch.device("mps")]
     return [torch.device("cpu")]
+
+
+def resolve_torch_dtype(torch_dtype: Optional[Union[str, "torch.dtype"]] = None) -> Optional["torch.dtype"]:
+    """
+    Resolve torch_dtype to a `torch.dtype`, converting "torch.*" strings and passing "auto" through unchanged.
+    """
+    torch_dtype_resolved = None
+    if torch_dtype is not None:
+        if isinstance(torch_dtype, str):
+            if torch_dtype.startswith("torch."):
+                torch_dtype_resolved = getattr(torch, torch_dtype[len("torch.") :])
+            elif torch_dtype == "auto":
+                torch_dtype_resolved = torch_dtype
+            else:
+                raise ValueError(
+                    f"torch_dtype should be a torch.dtype, a string with 'torch.' prefix or the string 'auto', got {torch_dtype}"
+                )
+        elif isinstance(torch_dtype, torch.dtype):
+            torch_dtype_resolved = torch_dtype
+        else:
+            raise ValueError(f"Invalid torch_dtype value {torch_dtype}")
+    return torch_dtype_resolved
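The helper's contract at a glance, as exercised by the new tests below (illustrative snippet, not part of the diff):

import torch
from haystack.utils.torch_utils import resolve_torch_dtype

assert resolve_torch_dtype(None) is None                      # unset stays unset
assert resolve_torch_dtype("auto") == "auto"                  # "auto" is passed through for HF to handle
assert resolve_torch_dtype(torch.float16) == torch.float16    # real dtypes pass through unchanged
assert resolve_torch_dtype("torch.float16") == torch.float16  # "torch.*" strings become real dtypes
# resolve_torch_dtype("float16") raises ValueError: neither a "torch." prefix nor "auto"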
5 changes: 5 additions & 0 deletions releasenotes/notes/ranker-model-kwargs-0f60508b69d7d46e.yaml
@@ -0,0 +1,5 @@
+---
+enhancements:
+  - |
+    Add `model_kwargs` argument to `SentenceTransformersRanker` so that Hugging Face Transformers loading options (e.g. `torch_dtype`) can be passed through to `from_pretrained`.
12 changes: 12 additions & 0 deletions test/nodes/test_ranker.py
@@ -234,6 +234,18 @@ def test_ranker(docs, mock_transformer_model, mock_transformer_tokenizer):
     assert results[0] == docs[4]


+@pytest.mark.unit
+def test_init_called_with():
+    with patch("haystack.nodes.SentenceTransformersRanker.__init__") as mock_ranker_init:
+        mock_ranker_init.return_value = None
+        _ = SentenceTransformersRanker(
+            model_name_or_path="fake_model", use_gpu=False, model_kwargs={"torch_dtype": torch.float16}
+        )
+        mock_ranker_init.assert_called_once_with(
+            model_name_or_path="fake_model", use_gpu=False, model_kwargs={"torch_dtype": torch.float16}
+        )
+
+
 @pytest.mark.unit
 def test_ranker_run(docs, mock_transformer_model, mock_transformer_tokenizer):
     with patch("torch.nn.DataParallel"):
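A note on the pattern in test_init_called_with: because __init__ itself is patched, no model is downloaded and the constructor body never runs; the test only verifies that model_kwargs reaches the constructor intact. A hypothetical integration-style check (not in this diff; assumes network access and a real cross-encoder) could instead assert the dtype on the loaded model:

ranker = SentenceTransformersRanker(
    model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2",
    use_gpu=False,
    model_kwargs={"torch_dtype": torch.float16},
)
assert ranker.transformer_model.dtype == torch.float16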
29 changes: 29 additions & 0 deletions test/utils/test_torch_utils.py
@@ -0,0 +1,29 @@
+import pytest
+import torch
+
+from haystack.utils.torch_utils import resolve_torch_dtype
+
+
+def test_extract_torch_dtype() -> None:
+    torch_dtype = resolve_torch_dtype(**{"torch_dtype": torch.float16})
+    assert torch_dtype == torch.float16
+
+
+def test_extract_torch_dtype_none() -> None:
+    torch_dtype = resolve_torch_dtype(**{})
+    assert torch_dtype is None
+
+
+def test_extract_torch_dtype_str() -> None:
+    torch_dtype = resolve_torch_dtype(**{"torch_dtype": "torch.float16"})
+    assert torch_dtype == torch.float16
+
+
+def test_extract_torch_dtype_auto() -> None:
+    torch_dtype = resolve_torch_dtype(**{"torch_dtype": "auto"})
+    assert torch_dtype == "auto"
+
+
+def test_extract_torch_dtype_invalid() -> None:
+    with pytest.raises(ValueError):
+        _ = resolve_torch_dtype(**{"torch_dtype": "random string"})
