Skip to content

Commit

Permalink
chore: fixing pylint issues (#8610)
Browse files Browse the repository at this point in the history
* initial import

* fixing internal methods

* fixing some internal methods

* modify _preprocess

* fixed internal methods

---------

Co-authored-by: anakin87 <[email protected]>
  • Loading branch information
2 people authored and julian-risch committed Jan 9, 2025
1 parent d3b39b7 commit 30fc80a
Show file tree
Hide file tree
Showing 20 changed files with 58 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class TransformersZeroShotDocumentClassifier:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str,
labels: List[str],
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/embedders/azure_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class AzureOpenAIDocumentEmbedder:
```
"""

def __init__( # noqa: PLR0913 (too-many-arguments)
def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/embedders/azure_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class AzureOpenAITextEmbedder:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/embedders/openai_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class OpenAITextEmbedder:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
model: str = "text-embedding-ada-002",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/evaluators/context_relevance.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class ContextRelevanceEvaluator(LLMEvaluator):
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
examples: Optional[List[Dict[str, Any]]] = None,
progress_bar: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/evaluators/faithfulness.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class FaithfulnessEvaluator(LLMEvaluator):
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
examples: Optional[List[Dict[str, Any]]] = None,
progress_bar: bool = True,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/evaluators/llm_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class LLMEvaluator:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
instructions: str,
inputs: List[Tuple[str, Type[List]]],
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class AzureOpenAIGenerator(OpenAIGenerator):
"""

# pylint: disable=super-init-not-called
def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator):
"""

# pylint: disable=super-init-not-called
def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
azure_endpoint: Optional[str] = None,
api_version: Optional[str] = "2023-05-15",
Expand Down
4 changes: 2 additions & 2 deletions haystack/components/generators/chat/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class HuggingFaceLocalChatGenerator:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str = "HuggingFaceH4/zephyr-7b-beta",
task: Optional[Literal["text-generation", "text2text-generation"]] = None,
Expand Down Expand Up @@ -295,7 +295,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
]
return {"replies": chat_messages}

def create_message(
def create_message( # pylint: disable=too-many-positional-arguments
self,
text: str,
index: int,
Expand Down
5 changes: 3 additions & 2 deletions haystack/components/rankers/meta_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class MetaFieldRanker:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
meta_field: str,
weight: float = 1.0,
Expand Down Expand Up @@ -106,6 +106,7 @@ def __init__(

def _validate_params(
self,
*,
weight: float,
top_k: Optional[int],
ranking_mode: Literal["reciprocal_rank_fusion", "linear_score"],
Expand Down Expand Up @@ -156,7 +157,7 @@ def _validate_params(
)

@component.output_types(documents=List[Document])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
documents: List[Document],
top_k: Optional[int] = None,
Expand Down
34 changes: 24 additions & 10 deletions haystack/components/readers/extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class ExtractiveReader:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: Union[Path, str] = "deepset/roberta-base-squad2-distilled",
device: Optional[ComponentDevice] = None,
Expand Down Expand Up @@ -192,8 +192,9 @@ def warm_up(self):
)
self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map))

@staticmethod
def _flatten_documents(
self, queries: List[str], documents: List[List[Document]]
queries: List[str], documents: List[List[Document]]
) -> Tuple[List[str], List[Document], List[int]]:
"""
Flattens queries and Documents so all query-document pairs are arranged along one batch axis.
Expand All @@ -203,8 +204,8 @@ def _flatten_documents(
query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_]
return flattened_queries, flattened_documents, query_ids

def _preprocess(
self, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
def _preprocess( # pylint: disable=too-many-positional-arguments
self, *, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", List["Encoding"], List[int], List[int]]:
"""
Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs.
Expand Down Expand Up @@ -256,6 +257,7 @@ def _preprocess(

def _postprocess(
self,
*,
start: "torch.Tensor",
end: "torch.Tensor",
sequence_ids: "torch.Tensor",
Expand Down Expand Up @@ -285,9 +287,9 @@ def _postprocess(
masked_logits = torch.where(mask, logits, -torch.inf)
probabilities = torch.sigmoid(masked_logits * self.calibration_factor)

flat_probabilities = probabilities.flatten(-2, -1) # necessary for topk
flat_probabilities = probabilities.flatten(-2, -1) # necessary for top-k

# topk can return invalid candidates as well if answers_per_seq > num_valid_candidates
# top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates
# We only keep probability > 0 candidates later on
candidates = torch.topk(flat_probabilities, answers_per_seq)
seq_length = logits.shape[-1]
Expand Down Expand Up @@ -343,6 +345,7 @@ def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:

def _nest_answers(
self,
*,
start: List[List[int]],
end: List[List[int]],
probabilities: "torch.Tensor",
Expand Down Expand Up @@ -526,7 +529,7 @@ def deduplicate_by_overlap(
return deduplicated_answers

@component.output_types(answers=List[ExtractedAnswer])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
query: str,
documents: List[Document],
Expand Down Expand Up @@ -594,9 +597,15 @@ def run(
no_answer = no_answer if no_answer is not None else self.no_answer
overlap_threshold = overlap_threshold or self.overlap_threshold

flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents)
flattened_queries, flattened_documents, query_ids = ExtractiveReader._flatten_documents(
queries, nested_documents
)
input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
flattened_queries, flattened_documents, max_seq_length, query_ids, stride
queries=flattened_queries,
documents=flattened_documents,
max_seq_length=max_seq_length,
query_ids=query_ids,
stride=stride,
)

num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1
Expand Down Expand Up @@ -625,7 +634,12 @@ def run(
end_logits = torch.cat(end_logits_list)

start, end, probabilities = self._postprocess(
start_logits, end_logits, sequence_ids, attention_mask, answers_per_seq, encodings
start=start_logits,
end=end_logits,
sequence_ids=sequence_ids,
attention_mask=attention_mask,
answers_per_seq=answers_per_seq,
encodings=encodings,
)

answers = self._nest_answers(
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/routers/transformers_text_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class TransformersTextRouter:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
model: str,
labels: Optional[List[str]] = None,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/routers/zero_shot_text_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class TransformersZeroShotTextRouter:
```
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
labels: List[str],
multi_label: bool = False,
Expand Down
4 changes: 2 additions & 2 deletions haystack/document_stores/in_memory/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class InMemoryDocumentStore:
Stores data in-memory. It's ephemeral and cannot be saved to disk.
"""

def __init__(
def __init__( # pylint: disable=too-many-positional-arguments
self,
bm25_tokenization_regex: str = r"(?u)\b\w\w+\b",
bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L",
Expand Down Expand Up @@ -539,7 +539,7 @@ def bm25_retrieval(

return return_documents

def embedding_retrieval(
def embedding_retrieval( # pylint: disable=too-many-positional-arguments
self,
query_embedding: List[float],
filters: Optional[Dict[str, Any]] = None,
Expand Down
2 changes: 1 addition & 1 deletion haystack/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def patch_make_records_to_use_kwarg_string_interpolation(original_make_records:
"""A decorator to ensure string interpolation is used."""

@functools.wraps(original_make_records)
def _wrapper(name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None) -> Any:
def _wrapper(name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None) -> Any: # pylint: disable=too-many-positional-arguments
safe_extra = extra or {}
try:
interpolated_msg = msg.format(**safe_extra)
Expand Down
2 changes: 1 addition & 1 deletion haystack/testing/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def to_dict(self) -> Dict[str, Any]:
return cls


def component_class(
def component_class( # pylint: disable=too-many-positional-arguments
name: str,
input_types: Optional[Dict[str, Any]] = None,
output_types: Optional[Dict[str, Any]] = None,
Expand Down
2 changes: 1 addition & 1 deletion haystack/utils/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio
return model_kwargs


def resolve_hf_pipeline_kwargs(
def resolve_hf_pipeline_kwargs( # pylint: disable=too-many-positional-arguments
huggingface_pipeline_kwargs: Dict[str, Any],
model: str,
task: Optional[str],
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ max-locals = 45 # Default is 15
max-module-lines = 2468 # Default is 1000
max-nested-blocks = 9 # Default is 5
max-statements = 206 # Default is 50

[tool.pylint.'SIMILARITIES']
min-similarity-lines = 6

Expand Down
15 changes: 12 additions & 3 deletions test/components/readers/test_extractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ def test_flatten_documents(mock_reader: ExtractiveReader):

def test_preprocess(mock_reader: ExtractiveReader):
_, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess(
example_queries * 3, example_documents[0], 384, [1, 1, 1], 0
queries=example_queries * 3, documents=example_documents[0], max_seq_length=384, query_ids=[1, 1, 1], stride=0
)
expected_seq_ids = torch.full((3, 384), -1, dtype=torch.int)
expected_seq_ids[:, :16] = 0
Expand All @@ -333,7 +333,11 @@ def test_preprocess(mock_reader: ExtractiveReader):

def test_preprocess_splitting(mock_reader: ExtractiveReader):
_, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess(
example_queries * 4, example_documents[0] + [Document(content="a" * 64)], 96, [1, 1, 1, 1], 0
queries=example_queries * 4,
documents=example_documents[0] + [Document(content="a" * 64)],
max_seq_length=96,
query_ids=[1, 1, 1, 1],
stride=0,
)
assert seq_ids.shape[0] == 5
assert query_ids == [1, 1, 1, 1, 1]
Expand Down Expand Up @@ -362,7 +366,12 @@ def test_postprocess(mock_reader: ExtractiveReader):
encoding.token_to_chars = lambda i: (int(i), int(i) + 1)

start_candidates, end_candidates, probs = mock_reader._postprocess(
start, end, sequence_ids, attention_mask, 3, [encoding, encoding]
start=start,
end=end,
sequence_ids=sequence_ids,
attention_mask=attention_mask,
answers_per_seq=3,
encodings=[encoding, encoding],
)

assert len(start_candidates) == len(end_candidates) == len(probs) == 2
Expand Down

0 comments on commit 30fc80a

Please sign in to comment.