From 30fc80a13d23bd25257a523bab070af19cda816b Mon Sep 17 00:00:00 2001 From: "David S. Batista" Date: Mon, 9 Dec 2024 17:53:37 +0100 Subject: [PATCH] chore: fixing `pylint` issues (#8610) * initial import * fixing internal methods * fixing some internal methods * modify _preprocess * fixed internal methods --------- Co-authored-by: anakin87 --- .../zero_shot_document_classifier.py | 2 +- .../embedders/azure_document_embedder.py | 2 +- .../embedders/azure_text_embedder.py | 2 +- .../embedders/openai_text_embedder.py | 2 +- .../evaluators/context_relevance.py | 2 +- .../components/evaluators/faithfulness.py | 2 +- .../components/evaluators/llm_evaluator.py | 2 +- haystack/components/generators/azure.py | 2 +- haystack/components/generators/chat/azure.py | 2 +- .../generators/chat/hugging_face_local.py | 4 +-- haystack/components/rankers/meta_field.py | 5 +-- haystack/components/readers/extractive.py | 34 +++++++++++++------ .../routers/transformers_text_router.py | 2 +- .../routers/zero_shot_text_router.py | 2 +- .../in_memory/document_store.py | 4 +-- haystack/logging.py | 2 +- haystack/testing/factory.py | 2 +- haystack/utils/hf.py | 2 +- pyproject.toml | 1 + test/components/readers/test_extractive.py | 15 ++++++-- 20 files changed, 58 insertions(+), 33 deletions(-) diff --git a/haystack/components/classifiers/zero_shot_document_classifier.py b/haystack/components/classifiers/zero_shot_document_classifier.py index 5aa52fde80..4be0a66d44 100644 --- a/haystack/components/classifiers/zero_shot_document_classifier.py +++ b/haystack/components/classifiers/zero_shot_document_classifier.py @@ -73,7 +73,7 @@ class TransformersZeroShotDocumentClassifier: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, model: str, labels: List[str], diff --git a/haystack/components/embedders/azure_document_embedder.py b/haystack/components/embedders/azure_document_embedder.py index e60c8781b6..b28fc1fda1 100644 --- a/haystack/components/embedders/azure_document_embedder.py +++ b/haystack/components/embedders/azure_document_embedder.py @@ -34,7 +34,7 @@ class AzureOpenAIDocumentEmbedder: ``` """ - def __init__( # noqa: PLR0913 (too-many-arguments) + def __init__( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments self, azure_endpoint: Optional[str] = None, api_version: Optional[str] = "2023-05-15", diff --git a/haystack/components/embedders/azure_text_embedder.py b/haystack/components/embedders/azure_text_embedder.py index 961cd910ad..bef34d6c3f 100644 --- a/haystack/components/embedders/azure_text_embedder.py +++ b/haystack/components/embedders/azure_text_embedder.py @@ -33,7 +33,7 @@ class AzureOpenAITextEmbedder: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, azure_endpoint: Optional[str] = None, api_version: Optional[str] = "2023-05-15", diff --git a/haystack/components/embedders/openai_text_embedder.py b/haystack/components/embedders/openai_text_embedder.py index 4a2d9d3bee..ba73be2122 100644 --- a/haystack/components/embedders/openai_text_embedder.py +++ b/haystack/components/embedders/openai_text_embedder.py @@ -38,7 +38,7 @@ class OpenAITextEmbedder: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"), model: str = "text-embedding-ada-002", diff --git a/haystack/components/evaluators/context_relevance.py b/haystack/components/evaluators/context_relevance.py index c60bd0bd53..b91b8272a1 100644 --- a/haystack/components/evaluators/context_relevance.py +++ b/haystack/components/evaluators/context_relevance.py @@ -95,7 +95,7 @@ class ContextRelevanceEvaluator(LLMEvaluator): ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, examples: Optional[List[Dict[str, Any]]] = None, progress_bar: bool = True, diff --git a/haystack/components/evaluators/faithfulness.py b/haystack/components/evaluators/faithfulness.py index 8daf9dc0e0..fa58aa060f 100644 --- a/haystack/components/evaluators/faithfulness.py +++ b/haystack/components/evaluators/faithfulness.py @@ -82,7 +82,7 @@ class FaithfulnessEvaluator(LLMEvaluator): ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, examples: Optional[List[Dict[str, Any]]] = None, progress_bar: bool = True, diff --git a/haystack/components/evaluators/llm_evaluator.py b/haystack/components/evaluators/llm_evaluator.py index 39a80b93a5..458b0b5452 100644 --- a/haystack/components/evaluators/llm_evaluator.py +++ b/haystack/components/evaluators/llm_evaluator.py @@ -47,7 +47,7 @@ class LLMEvaluator: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, instructions: str, inputs: List[Tuple[str, Type[List]]], diff --git a/haystack/components/generators/azure.py b/haystack/components/generators/azure.py index 20bb2cda8e..f51a089895 100644 --- a/haystack/components/generators/azure.py +++ b/haystack/components/generators/azure.py @@ -55,7 +55,7 @@ class AzureOpenAIGenerator(OpenAIGenerator): """ # pylint: disable=super-init-not-called - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, azure_endpoint: Optional[str] = None, api_version: Optional[str] = "2023-05-15", diff --git a/haystack/components/generators/chat/azure.py b/haystack/components/generators/chat/azure.py index 445e580402..b74be533dc 100644 --- a/haystack/components/generators/chat/azure.py +++ b/haystack/components/generators/chat/azure.py @@ -62,7 +62,7 @@ class AzureOpenAIChatGenerator(OpenAIChatGenerator): """ # pylint: disable=super-init-not-called - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, azure_endpoint: Optional[str] = None, api_version: Optional[str] = "2023-05-15", diff --git a/haystack/components/generators/chat/hugging_face_local.py b/haystack/components/generators/chat/hugging_face_local.py index 1304ebd408..1244e3b954 100644 --- a/haystack/components/generators/chat/hugging_face_local.py +++ b/haystack/components/generators/chat/hugging_face_local.py @@ -71,7 +71,7 @@ class HuggingFaceLocalChatGenerator: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, model: str = "HuggingFaceH4/zephyr-7b-beta", task: Optional[Literal["text-generation", "text2text-generation"]] = None, @@ -295,7 +295,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str, ] return {"replies": chat_messages} - def create_message( + def create_message( # pylint: disable=too-many-positional-arguments self, text: str, index: int, diff --git a/haystack/components/rankers/meta_field.py b/haystack/components/rankers/meta_field.py index 3b28742a05..b5c46b9ec3 100644 --- a/haystack/components/rankers/meta_field.py +++ b/haystack/components/rankers/meta_field.py @@ -38,7 +38,7 @@ class MetaFieldRanker: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, meta_field: str, weight: float = 1.0, @@ -106,6 +106,7 @@ def __init__( def _validate_params( self, + *, weight: float, top_k: Optional[int], ranking_mode: Literal["reciprocal_rank_fusion", "linear_score"], @@ -156,7 +157,7 @@ def _validate_params( ) @component.output_types(documents=List[Document]) - def run( + def run( # pylint: disable=too-many-positional-arguments self, documents: List[Document], top_k: Optional[int] = None, diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 85d90a74cf..e9a376436e 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -51,7 +51,7 @@ class ExtractiveReader: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, model: Union[Path, str] = "deepset/roberta-base-squad2-distilled", device: Optional[ComponentDevice] = None, @@ -192,8 +192,9 @@ def warm_up(self): ) self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(self.model.hf_device_map)) + @staticmethod def _flatten_documents( - self, queries: List[str], documents: List[List[Document]] + queries: List[str], documents: List[List[Document]] ) -> Tuple[List[str], List[Document], List[int]]: """ Flattens queries and Documents so all query-document pairs are arranged along one batch axis. @@ -203,8 +204,8 @@ def _flatten_documents( query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_] return flattened_queries, flattened_documents, query_ids - def _preprocess( - self, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int + def _preprocess( # pylint: disable=too-many-positional-arguments + self, *, queries: List[str], documents: List[Document], max_seq_length: int, query_ids: List[int], stride: int ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", List["Encoding"], List[int], List[int]]: """ Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs. @@ -256,6 +257,7 @@ def _preprocess( def _postprocess( self, + *, start: "torch.Tensor", end: "torch.Tensor", sequence_ids: "torch.Tensor", @@ -285,9 +287,9 @@ def _postprocess( masked_logits = torch.where(mask, logits, -torch.inf) probabilities = torch.sigmoid(masked_logits * self.calibration_factor) - flat_probabilities = probabilities.flatten(-2, -1) # necessary for topk + flat_probabilities = probabilities.flatten(-2, -1) # necessary for top-k - # topk can return invalid candidates as well if answers_per_seq > num_valid_candidates + # top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates # We only keep probability > 0 candidates later on candidates = torch.topk(flat_probabilities, answers_per_seq) seq_length = logits.shape[-1] @@ -343,6 +345,7 @@ def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer: def _nest_answers( self, + *, start: List[List[int]], end: List[List[int]], probabilities: "torch.Tensor", @@ -526,7 +529,7 @@ def deduplicate_by_overlap( return deduplicated_answers @component.output_types(answers=List[ExtractedAnswer]) - def run( + def run( # pylint: disable=too-many-positional-arguments self, query: str, documents: List[Document], @@ -594,9 +597,15 @@ def run( no_answer = no_answer if no_answer is not None else self.no_answer overlap_threshold = overlap_threshold or self.overlap_threshold - flattened_queries, flattened_documents, query_ids = self._flatten_documents(queries, nested_documents) + flattened_queries, flattened_documents, query_ids = ExtractiveReader._flatten_documents( + queries, nested_documents + ) input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess( - flattened_queries, flattened_documents, max_seq_length, query_ids, stride + queries=flattened_queries, + documents=flattened_documents, + max_seq_length=max_seq_length, + query_ids=query_ids, + stride=stride, ) num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1 @@ -625,7 +634,12 @@ def run( end_logits = torch.cat(end_logits_list) start, end, probabilities = self._postprocess( - start_logits, end_logits, sequence_ids, attention_mask, answers_per_seq, encodings + start=start_logits, + end=end_logits, + sequence_ids=sequence_ids, + attention_mask=attention_mask, + answers_per_seq=answers_per_seq, + encodings=encodings, ) answers = self._nest_answers( diff --git a/haystack/components/routers/transformers_text_router.py b/haystack/components/routers/transformers_text_router.py index f5a9f99612..61b3dc9ce1 100644 --- a/haystack/components/routers/transformers_text_router.py +++ b/haystack/components/routers/transformers_text_router.py @@ -72,7 +72,7 @@ class TransformersTextRouter: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, model: str, labels: Optional[List[str]] = None, diff --git a/haystack/components/routers/zero_shot_text_router.py b/haystack/components/routers/zero_shot_text_router.py index 619862f2f2..7c551b5dc3 100644 --- a/haystack/components/routers/zero_shot_text_router.py +++ b/haystack/components/routers/zero_shot_text_router.py @@ -95,7 +95,7 @@ class TransformersZeroShotTextRouter: ``` """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, labels: List[str], multi_label: bool = False, diff --git a/haystack/document_stores/in_memory/document_store.py b/haystack/document_stores/in_memory/document_store.py index 31199ee9d4..ad469a002c 100644 --- a/haystack/document_stores/in_memory/document_store.py +++ b/haystack/document_stores/in_memory/document_store.py @@ -58,7 +58,7 @@ class InMemoryDocumentStore: Stores data in-memory. It's ephemeral and cannot be saved to disk. """ - def __init__( + def __init__( # pylint: disable=too-many-positional-arguments self, bm25_tokenization_regex: str = r"(?u)\b\w\w+\b", bm25_algorithm: Literal["BM25Okapi", "BM25L", "BM25Plus"] = "BM25L", @@ -539,7 +539,7 @@ def bm25_retrieval( return return_documents - def embedding_retrieval( + def embedding_retrieval( # pylint: disable=too-many-positional-arguments self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, diff --git a/haystack/logging.py b/haystack/logging.py index fb0149a032..53387af72c 100644 --- a/haystack/logging.py +++ b/haystack/logging.py @@ -188,7 +188,7 @@ def patch_make_records_to_use_kwarg_string_interpolation(original_make_records: """A decorator to ensure string interpolation is used.""" @functools.wraps(original_make_records) - def _wrapper(name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None) -> Any: + def _wrapper(name, level, fn, lno, msg, args, exc_info, func=None, extra=None, sinfo=None) -> Any: # pylint: disable=too-many-positional-arguments safe_extra = extra or {} try: interpolated_msg = msg.format(**safe_extra) diff --git a/haystack/testing/factory.py b/haystack/testing/factory.py index 29d8fa7738..94ee5da85f 100644 --- a/haystack/testing/factory.py +++ b/haystack/testing/factory.py @@ -124,7 +124,7 @@ def to_dict(self) -> Dict[str, Any]: return cls -def component_class( +def component_class( # pylint: disable=too-many-positional-arguments name: str, input_types: Optional[Dict[str, Any]] = None, output_types: Optional[Dict[str, Any]] = None, diff --git a/haystack/utils/hf.py b/haystack/utils/hf.py index de815245d3..8ef68065f5 100644 --- a/haystack/utils/hf.py +++ b/haystack/utils/hf.py @@ -166,7 +166,7 @@ def resolve_hf_device_map(device: Optional[ComponentDevice], model_kwargs: Optio return model_kwargs -def resolve_hf_pipeline_kwargs( +def resolve_hf_pipeline_kwargs( # pylint: disable=too-many-positional-arguments huggingface_pipeline_kwargs: Dict[str, Any], model: str, task: Optional[str], diff --git a/pyproject.toml b/pyproject.toml index f909ed3ae9..7c64f8a6f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -246,6 +246,7 @@ max-locals = 45 # Default is 15 max-module-lines = 2468 # Default is 1000 max-nested-blocks = 9 # Default is 5 max-statements = 206 # Default is 50 + [tool.pylint.'SIMILARITIES'] min-similarity-lines = 6 diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index 9c42c44254..aedfaa13bc 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -321,7 +321,7 @@ def test_flatten_documents(mock_reader: ExtractiveReader): def test_preprocess(mock_reader: ExtractiveReader): _, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess( - example_queries * 3, example_documents[0], 384, [1, 1, 1], 0 + queries=example_queries * 3, documents=example_documents[0], max_seq_length=384, query_ids=[1, 1, 1], stride=0 ) expected_seq_ids = torch.full((3, 384), -1, dtype=torch.int) expected_seq_ids[:, :16] = 0 @@ -333,7 +333,11 @@ def test_preprocess(mock_reader: ExtractiveReader): def test_preprocess_splitting(mock_reader: ExtractiveReader): _, _, seq_ids, _, query_ids, doc_ids = mock_reader._preprocess( - example_queries * 4, example_documents[0] + [Document(content="a" * 64)], 96, [1, 1, 1, 1], 0 + queries=example_queries * 4, + documents=example_documents[0] + [Document(content="a" * 64)], + max_seq_length=96, + query_ids=[1, 1, 1, 1], + stride=0, ) assert seq_ids.shape[0] == 5 assert query_ids == [1, 1, 1, 1, 1] @@ -362,7 +366,12 @@ def test_postprocess(mock_reader: ExtractiveReader): encoding.token_to_chars = lambda i: (int(i), int(i) + 1) start_candidates, end_candidates, probs = mock_reader._postprocess( - start, end, sequence_ids, attention_mask, 3, [encoding, encoding] + start=start, + end=end, + sequence_ids=sequence_ids, + attention_mask=attention_mask, + answers_per_seq=3, + encodings=[encoding, encoding], ) assert len(start_candidates) == len(end_candidates) == len(probs) == 2