
Commit

corrected tests and linting
Signed-off-by: tanishq-ids <[email protected]>
tanishq-ids committed Nov 7, 2024
1 parent 4666dab commit 6a33e65
Showing 2 changed files with 45 additions and 47 deletions.

@@ -34,7 +34,6 @@
 )
 import torch
 import numpy as np
-from datetime import datetime
 from sklearn.model_selection import train_test_split


(second changed file)

@@ -9,10 +9,11 @@
 from unittest.mock import patch, mock_open, MagicMock
 from pathlib import Path
 import pandas as pd
+import numpy as np
 import torch
 import pytest
 from osc_transformer_based_extractor.relevance_detector.inference import (
-    get_inference,
+    get_batch_inference,
     run_full_inference,
 )

@@ -33,54 +34,43 @@
 @patch(
     "osc_transformer_based_extractor.relevance_detector.inference.AutoTokenizer.from_pretrained"
 )
-def test_get_inference(mock_tokenizer, mock_model):
-    """Test the get_inference function for inference correctness."""
+def test_get_batch_inference():
+    """Test the get_batch_inference function for inference correctness."""

     # Mock tokenizer and model
     tokenizer_mock = MagicMock()
     model_mock = MagicMock()
-    mock_tokenizer.return_value = tokenizer_mock
-    mock_model.return_value = model_mock

-    # Configure the tokenizer mock
-    tokenizer_mock.encode_plus.return_value = {
-        "input_ids": torch.tensor([[101, 102]]),
-        "attention_mask": torch.tensor([[1, 1]]),
+    # Configure the tokenizer mock for batch inputs
+    tokenizer_mock.return_value = {
+        "input_ids": torch.tensor([[101, 102], [101, 103]]),
+        "attention_mask": torch.tensor([[1, 1], [1, 1]]),
     }

     # Configure the model mock to return logits tensor
     model_output_mock = MagicMock()
-    model_output_mock.logits = torch.tensor([[0.1, 0.9]])
+    model_output_mock.logits = torch.tensor([[0.1, 0.9], [0.7, 0.3]])
     model_mock.return_value = model_output_mock

-    # Dummy question and context
-    question = "What is the capital of France?"
-    context = "Paris is the capital of France."
+    # Dummy questions and contexts
+    questions = ["What is the capital of France?", "What is the capital of Germany?"]
+    contexts = ["Paris is the capital of France.", "Berlin is the capital of Germany."]

     # Test inference
-    predicted_label_id, class_prob = get_inference(
-        question, context, model_path_valid, tokenizer_path_valid, threshold=0.7
+    labels, positive_class_probs = get_batch_inference(
+        questions, contexts, model_mock, tokenizer_mock, device="cpu", threshold=0.7
     )

     # Assertions
-    assert isinstance(predicted_label_id, int)
-    assert isinstance(class_prob, float)
-
-    # Test different inputs
-    tokenizer_mock.encode_plus.return_value = {
-        "input_ids": torch.tensor([[101, 103]]),
-        "attention_mask": torch.tensor([[1, 1]]),
-    }
-    model_output_mock.logits = torch.tensor([[0.7, 0.3]])
-    predicted_label_id, class_prob = get_inference(
-        "What is the capital of Germany?",
-        "Berlin is the capital of Germany.",
-        model_path_valid,
-        tokenizer_path_valid,
-        threshold=0.7,
+    assert isinstance(labels, np.ndarray) and labels.dtype == int
+    assert (
+        isinstance(positive_class_probs, np.ndarray)
+        and positive_class_probs.dtype == float
     )

-    assert isinstance(predicted_label_id, int)
-    assert isinstance(class_prob, float)
+    # Check correct shapes and values
+    assert labels.shape[0] == len(questions)
+    assert positive_class_probs.shape[0] == len(questions)
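
For orientation, the updated test above implies a batch-level helper roughly shaped like the sketch below. Only the argument order and the (labels, positive-class probabilities) return pair are taken from the test; the body (paired tokenization, softmax over the logits, thresholding on the positive class) is assumed and may differ from the actual get_batch_inference in this repository.

import numpy as np
import torch


def get_batch_inference(questions, contexts, model, tokenizer, device="cpu", threshold=0.5):
    """Hypothetical sketch: score question/context pairs in a single batch."""
    # Tokenize all pairs together; padding keeps the batch rectangular.
    inputs = tokenizer(questions, contexts, padding=True, truncation=True, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    # Forward pass without gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Probability of the positive (relevant) class for each pair.
    probs = torch.softmax(logits, dim=1)
    positive_class_probs = probs[:, 1].cpu().numpy().astype(float)
    # Predict 1 when the positive-class probability clears the threshold.
    labels = (positive_class_probs >= threshold).astype(int)
    return labels, positive_class_probs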


@pytest.fixture
@@ -135,7 +125,7 @@ def sample_merged_dataframe(sample_dataframe, sample_kpi_mapping):
 @patch("pandas.DataFrame.to_excel")
 def test_run_full_inference(
     mock_to_excel,
-    mock_get_inference,
+    mock_get_batch_inference,
     mock_listdir,
     mock_read_csv,
     mock_open,
@@ -149,32 +139,41 @@ def test_run_full_inference(
     output_path = "output_folder"
     model_path = "model_path"
     tokenizer_path = "tokenizer_path"
+    batch_size = 2
     threshold = 0.5

     # Set up mock returns
     mock_read_csv.return_value = sample_kpi_mapping
     mock_listdir.return_value = ["test_file.json"]
     mock_open.return_value.read = json.dumps(sample_json_data)
-    mock_get_inference.return_value = (1, 0.95)
+    mock_get_batch_inference.return_value = (
+        [1, 0],
+        [0.95, 0.3],
+    ) # mock labels and probabilities

     with patch("json.load", return_value=sample_json_data):
         with patch("pandas.DataFrame.merge", return_value=sample_merged_dataframe):
-            run_full_inference(
-                folder_path,
-                kpi_mapping_path,
-                output_path,
-                model_path,
-                tokenizer_path,
-                threshold,
-            )
+            with patch("pandas.DataFrame.to_excel", mock_to_excel):
+                run_full_inference(
+                    folder_path,
+                    kpi_mapping_path,
+                    output_path,
+                    model_path,
+                    tokenizer_path,
+                    batch_size,
+                    threshold,
+                )

     # Assertions
     mock_read_csv.assert_called_once_with(kpi_mapping_path)
     mock_listdir.assert_called_once_with(folder_path)
     mock_open.assert_called_once_with(Path(folder_path) / "test_file.json", "r")

-    assert mock_get_inference.call_count == len(sample_merged_dataframe)
+    # Check if batch inference was called with expected count
+    assert (
+        mock_get_batch_inference.call_count
+        == (len(sample_merged_dataframe) // batch_size) + 1
+    )
     assert mock_to_excel.call_count == 1
     output_file_path = Path(output_path) / "test_file.xlsx"
     mock_to_excel.assert_called_once_with(output_file_path, index=False)

     # Ensure no files were created
     mock_open().write.assert_not_called()
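
As a quick sanity check on the expected call count in the assertion above: (len(df) // batch_size) + 1 equals ceil(len(df) / batch_size) whenever the row count is not an exact multiple of the batch size. A small worked example, assuming a five-row merged dataframe (the fixture size is not visible in this diff):

rows = 5  # assumed size of sample_merged_dataframe; not stated in this commit
batch_size = 2
expected_calls = (rows // batch_size) + 1  # (5 // 2) + 1 == 3 batches: rows 0-1, 2-3, 4
assert expected_calls == -(-rows // batch_size)  # same result as ceiling division here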
