Skip to content

Commit

Permalink
feat: Add model_kwargs to ExtractiveReader to customize model loading (#…
Browse files Browse the repository at this point in the history
…6257)

* Add ability to pass model_kwargs to AutoModelForQuestionAnswering

* Add testing for new model_kwargs

* Add spacing

* Add release notes

* Update haystack/preview/components/readers/extractive.py

Co-authored-by: Stefano Fiorucci <[email protected]>

* Make changes suggested by Stefano

---------

Co-authored-by: Stefano Fiorucci <[email protected]>
  • Loading branch information
sjrl and anakin87 authored Nov 9, 2023
1 parent cd429a7 commit 71d0d92
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 4 deletions.
14 changes: 10 additions & 4 deletions haystack/preview/components/readers/extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,11 @@ def __init__(
answers_per_seq: Optional[int] = None,
no_answer: bool = True,
calibration_factor: float = 0.1,
model_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""
Creates an ExtractiveReader
:param model: A HuggingFace transformers question answering model.
:param model_name_or_path: A HuggingFace transformers question answering model.
Can either be a path to a folder containing the model files or an identifier for the HF hub
Default: `'deepset/roberta-base-squad2-distilled'`
:param device: Pytorch device string. Uses GPU by default if available
Expand All @@ -68,6 +69,8 @@ def __init__(
This is relevant when a document has been split into multiple sequence due to max_seq_length.
:param no_answer: Whether to return no answer scores
:param calibration_factor: Factor used for calibrating confidence scores
:param model_kwargs: Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
when loading the model specified in `model_name_or_path`.
"""
torch_and_transformers_import.check()
self.model_name_or_path = str(model_name_or_path)
Expand All @@ -82,6 +85,7 @@ def __init__(
self.answers_per_seq = answers_per_seq
self.no_answer = no_answer
self.calibration_factor = calibration_factor
self.model_kwargs = model_kwargs or {}

def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Expand All @@ -106,6 +110,7 @@ def to_dict(self) -> Dict[str, Any]:
answers_per_seq=self.answers_per_seq,
no_answer=self.no_answer,
calibration_factor=self.calibration_factor,
model_kwargs=self.model_kwargs,
)

def warm_up(self):
Expand All @@ -120,9 +125,10 @@ def warm_up(self):
self.device = self.device or "mps:0"
else:
self.device = self.device or "cpu:0"
self.model = AutoModelForQuestionAnswering.from_pretrained(self.model_name_or_path, token=self.token).to(
self.device
)

self.model = AutoModelForQuestionAnswering.from_pretrained(
self.model_name_or_path, token=self.token, **self.model_kwargs
).to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=self.token)

def _flatten_documents(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
    Add a new parameter, `model_kwargs`, to the `ExtractiveReader` so that different model-loading
    options supported by Hugging Face can be passed through.
25 changes: 25 additions & 0 deletions test/preview/components/readers/test_extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,30 @@ def forward(self, input_ids, attention_mask, *args, **kwargs):

@pytest.mark.unit
def test_to_dict():
component = ExtractiveReader("my-model", token="secret-token", model_kwargs={"torch_dtype": "auto"})
data = component.to_dict()

assert data == {
"type": "ExtractiveReader",
"init_parameters": {
"model_name_or_path": "my-model",
"device": None,
"token": None, # don't serialize valid tokens
"top_k": 20,
"confidence_threshold": None,
"max_seq_length": 384,
"stride": 128,
"max_batch_size": None,
"answers_per_seq": None,
"no_answer": True,
"calibration_factor": 0.1,
"model_kwargs": {"torch_dtype": "auto"},
},
}


@pytest.mark.unit
def test_to_dict_empty_model_kwargs():
component = ExtractiveReader("my-model", token="secret-token")
data = component.to_dict()

Expand All @@ -106,6 +130,7 @@ def test_to_dict():
"answers_per_seq": None,
"no_answer": True,
"calibration_factor": 0.1,
"model_kwargs": {},
},
}

Expand Down

0 comments on commit 71d0d92

Please sign in to comment.