From 6a33e6565a3a0f63c35beaba6426d132dc63d89e Mon Sep 17 00:00:00 2001 From: tanishq-ids Date: Thu, 7 Nov 2024 21:09:26 +0530 Subject: [PATCH] corrected tests and linting Signed-off-by: tanishq-ids --- .../kpi_detection/train_kpi_detection.py | 1 - .../relevance_detector/test_inference.py | 91 +++++++++---------- 2 files changed, 45 insertions(+), 47 deletions(-) diff --git a/src/osc_transformer_based_extractor/kpi_detection/train_kpi_detection.py b/src/osc_transformer_based_extractor/kpi_detection/train_kpi_detection.py index 5060b5e..16c334c 100644 --- a/src/osc_transformer_based_extractor/kpi_detection/train_kpi_detection.py +++ b/src/osc_transformer_based_extractor/kpi_detection/train_kpi_detection.py @@ -34,7 +34,6 @@ ) import torch import numpy as np -from datetime import datetime from sklearn.model_selection import train_test_split diff --git a/tests/osc_transformer_based_extractor/relevance_detector/test_inference.py b/tests/osc_transformer_based_extractor/relevance_detector/test_inference.py index d32c90f..733e7bb 100644 --- a/tests/osc_transformer_based_extractor/relevance_detector/test_inference.py +++ b/tests/osc_transformer_based_extractor/relevance_detector/test_inference.py @@ -9,10 +9,11 @@ from unittest.mock import patch, mock_open, MagicMock from pathlib import Path import pandas as pd +import numpy as np import torch import pytest from osc_transformer_based_extractor.relevance_detector.inference import ( - get_inference, + get_batch_inference, run_full_inference, ) @@ -33,54 +34,43 @@ @patch( "osc_transformer_based_extractor.relevance_detector.inference.AutoTokenizer.from_pretrained" ) -def test_get_inference(mock_tokenizer, mock_model): - """Test the get_inference function for inference correctness.""" +def test_get_batch_inference(): + """Test the get_batch_inference function for inference correctness.""" + # Mock tokenizer and model tokenizer_mock = MagicMock() model_mock = MagicMock() - mock_tokenizer.return_value = tokenizer_mock - mock_model.return_value = model_mock - # Configure the tokenizer mock - tokenizer_mock.encode_plus.return_value = { - "input_ids": torch.tensor([[101, 102]]), - "attention_mask": torch.tensor([[1, 1]]), + # Configure the tokenizer mock for batch inputs + tokenizer_mock.return_value = { + "input_ids": torch.tensor([[101, 102], [101, 103]]), + "attention_mask": torch.tensor([[1, 1], [1, 1]]), } # Configure the model mock to return logits tensor model_output_mock = MagicMock() - model_output_mock.logits = torch.tensor([[0.1, 0.9]]) + model_output_mock.logits = torch.tensor([[0.1, 0.9], [0.7, 0.3]]) model_mock.return_value = model_output_mock - # Dummy question and context - question = "What is the capital of France?" - context = "Paris is the capital of France." + # Dummy questions and contexts + questions = ["What is the capital of France?", "What is the capital of Germany?"] + contexts = ["Paris is the capital of France.", "Berlin is the capital of Germany."] # Test inference - predicted_label_id, class_prob = get_inference( - question, context, model_path_valid, tokenizer_path_valid, threshold=0.7 + labels, positive_class_probs = get_batch_inference( + questions, contexts, model_mock, tokenizer_mock, device="cpu", threshold=0.7 ) # Assertions - assert isinstance(predicted_label_id, int) - assert isinstance(class_prob, float) - - # Test different inputs - tokenizer_mock.encode_plus.return_value = { - "input_ids": torch.tensor([[101, 103]]), - "attention_mask": torch.tensor([[1, 1]]), - } - model_output_mock.logits = torch.tensor([[0.7, 0.3]]) - predicted_label_id, class_prob = get_inference( - "What is the capital of Germany?", - "Berlin is the capital of Germany.", - model_path_valid, - tokenizer_path_valid, - threshold=0.7, + assert isinstance(labels, np.ndarray) and labels.dtype == int + assert ( + isinstance(positive_class_probs, np.ndarray) + and positive_class_probs.dtype == float ) - assert isinstance(predicted_label_id, int) - assert isinstance(class_prob, float) + # Check correct shapes and values + assert labels.shape[0] == len(questions) + assert positive_class_probs.shape[0] == len(questions) @pytest.fixture @@ -135,7 +125,7 @@ def sample_merged_dataframe(sample_dataframe, sample_kpi_mapping): @patch("pandas.DataFrame.to_excel") def test_run_full_inference( mock_to_excel, - mock_get_inference, + mock_get_batch_inference, mock_listdir, mock_read_csv, mock_open, @@ -149,32 +139,41 @@ def test_run_full_inference( output_path = "output_folder" model_path = "model_path" tokenizer_path = "tokenizer_path" + batch_size = 2 threshold = 0.5 + # Set up mock returns mock_read_csv.return_value = sample_kpi_mapping mock_listdir.return_value = ["test_file.json"] mock_open.return_value.read = json.dumps(sample_json_data) - mock_get_inference.return_value = (1, 0.95) + mock_get_batch_inference.return_value = ( + [1, 0], + [0.95, 0.3], + ) # mock labels and probabilities with patch("json.load", return_value=sample_json_data): with patch("pandas.DataFrame.merge", return_value=sample_merged_dataframe): - run_full_inference( - folder_path, - kpi_mapping_path, - output_path, - model_path, - tokenizer_path, - threshold, - ) + with patch("pandas.DataFrame.to_excel", mock_to_excel): + run_full_inference( + folder_path, + kpi_mapping_path, + output_path, + model_path, + tokenizer_path, + batch_size, + threshold, + ) + # Assertions mock_read_csv.assert_called_once_with(kpi_mapping_path) mock_listdir.assert_called_once_with(folder_path) mock_open.assert_called_once_with(Path(folder_path) / "test_file.json", "r") - assert mock_get_inference.call_count == len(sample_merged_dataframe) + # Check if batch inference was called with expected count + assert ( + mock_get_batch_inference.call_count + == (len(sample_merged_dataframe) // batch_size) + 1 + ) assert mock_to_excel.call_count == 1 output_file_path = Path(output_path) / "test_file.xlsx" mock_to_excel.assert_called_once_with(output_file_path, index=False) - - # Ensure no files were created - mock_open().write.assert_not_called()