-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from Maazaowski/feature/UnitTesting
Merging UnitTesting feature to main
- Loading branch information
Showing
9 changed files
with
214 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
name: Python Tests | ||
|
||
on: | ||
push: | ||
branches: [ main ] | ||
pull_request: | ||
branches: [ main ] | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
matrix: | ||
python-version: [3.8, 3.9, '3.10'] | ||
|
||
steps: | ||
- uses: actions/checkout@v2 | ||
|
||
- name: Set up Python ${{ matrix.python-version }} | ||
uses: actions/setup-python@v2 | ||
with: | ||
python-version: ${{ matrix.python-version }} | ||
|
||
- name: Install dependencies | ||
run: | | ||
sudo apt-get install python3-tk | ||
sudo apt-get install xvfb | ||
python -m pip install --upgrade pip | ||
pip install -r requirements.txt | ||
pip install pytest pytest-cov | ||
- name: Run tests with coverage | ||
run: | | ||
xvfb-run --auto-servernum pytest tests/ --cov=src/ --cov-report=xml | ||
- name: Upload coverage reports | ||
uses: codecov/codecov-action@v3 | ||
with: | ||
file: ./coverage.xml | ||
flags: unittests |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
transformers[torch] | ||
accelerate>=0.26.0 | ||
datasets | ||
ttkbootstrap | ||
psutil | ||
torch | ||
pytest | ||
pytest-cov | ||
pytest-mock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .training.test_trainer import TestModelTrainer | ||
from .ui.test_training_window import TestTrainingWindow | ||
|
||
__all__ = [ | ||
'TestModelTrainer', | ||
'TestTrainingWindow' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import unittest | ||
import sys | ||
import os | ||
|
||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | ||
|
||
def run_tests(): | ||
# Discover and run tests | ||
loader = unittest.TestLoader() | ||
start_dir = os.path.dirname(os.path.abspath(__file__)) | ||
suite = loader.discover(start_dir, pattern='test_*.py') | ||
|
||
runner = unittest.TextTestRunner(verbosity=2) | ||
result = runner.run(suite) | ||
return result.wasSuccessful() | ||
|
||
if __name__ == '__main__': | ||
success = run_tests() | ||
sys.exit(0 if success else 1) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import unittest | ||
from unittest.mock import MagicMock, patch | ||
from src.training.trainer import ModelTrainer | ||
from transformers import TrainingArguments | ||
import warnings | ||
|
||
# Add this at the top of your test file | ||
warnings.filterwarnings("ignore", category=DeprecationWarning) | ||
|
||
class TestModelTrainer(unittest.TestCase): | ||
def setUp(self): | ||
self.mock_window = MagicMock() | ||
self.trainer = ModelTrainer(self.mock_window) | ||
|
||
def test_progress_tracking(self): | ||
"""Test progress updates during training""" | ||
initial_progress = self.trainer.current_progress | ||
self.trainer.update_progress(25.0) | ||
self.assertEqual(self.trainer.current_progress, initial_progress + 25.0) | ||
|
||
@patch('src.training.trainer.load_dataset') | ||
def test_dataset_loading(self, mock_load_dataset): | ||
"""Test dataset loading functionality""" | ||
mock_dataset = MagicMock() | ||
mock_dataset.__getitem__.return_value = MagicMock() | ||
mock_load_dataset.return_value = mock_dataset | ||
|
||
self.trainer.load_dataset() | ||
mock_load_dataset.assert_called_once() | ||
|
||
@patch('src.training.trainer.GPT2LMHeadModel') | ||
def test_model_initialization(self, mock_gpt2): | ||
"""Test model initialization""" | ||
self.trainer.initialize_model() | ||
mock_gpt2.from_pretrained.assert_called_once_with("gpt2") | ||
|
||
def test_training_args_integration(self): | ||
"""Test training arguments integration""" | ||
args = TrainingArguments( | ||
output_dir="./test_output", | ||
num_train_epochs=3, | ||
per_device_train_batch_size=8, | ||
per_device_eval_batch_size=8 | ||
) | ||
|
||
with patch('src.training.trainer.Trainer') as mock_trainer: | ||
self.trainer.train(args) | ||
mock_trainer.assert_called() | ||
|
||
def test_logger_initialization(self): | ||
"""Test logger initialization""" | ||
self.assertIsNotNone(self.trainer.logger) | ||
self.assertEqual(self.trainer.current_progress, 0) | ||
|
||
def test_tokenizer_initialization(self): | ||
"""Test tokenizer initialization""" | ||
self.assertIsNotNone(self.trainer.tokenizer) | ||
|
||
def test_progress_limits(self): | ||
"""Test progress update boundaries""" | ||
self.trainer.current_progress = 0 | ||
self.trainer.update_progress(150) | ||
self.assertEqual(self.trainer.current_progress, 100) | ||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import unittest | ||
from src.ui.training_window import TrainingProgressWindow | ||
import tkinter as tk | ||
|
||
class TestTrainingWindow(unittest.TestCase): | ||
def test_window_creation(self): | ||
window = TrainingProgressWindow() | ||
self.assertIsNotNone(window) | ||
window.close() | ||
|
||
@classmethod | ||
def setUpClass(cls): | ||
cls.window = TrainingProgressWindow() | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
try: | ||
cls.window.close() | ||
except: | ||
pass # Window might already be destroyed | ||
|
||
def test_metrics_update(self): | ||
"""Test if training metrics update correctly""" | ||
self.window.update_training_metrics( | ||
epoch=1, | ||
loss=0.5, | ||
accuracy=75.5, | ||
learning_rate=0.001, | ||
total_epochs=3 | ||
) | ||
|
||
self.assertEqual(self.window.epoch_label.cget("text"), "1.00/3") | ||
self.assertEqual(self.window.training_loss_label.cget("text"), "0.5000") | ||
self.assertEqual(self.window.accuracy_label.cget("text"), "75.50%") | ||
self.assertEqual(self.window.learning_rate_label.cget("text"), "0.0010") | ||
|
||
def test_progress_bar(self): | ||
"""Test progress bar updates""" | ||
test_values = [0, 25, 50, 75, 100] | ||
for value in test_values: | ||
self.window.update_progress(value) | ||
self.assertEqual(self.window.progress_bar['value'], value) | ||
|
||
def test_log_messages(self): | ||
"""Test log message functionality""" | ||
test_message = "Test training started" | ||
self.window.update_log("TEST", test_message) | ||
log_content = self.window.log_box.get("1.0", tk.END) | ||
self.assertIn(test_message, log_content) | ||
|
||
def test_system_metrics_initialization(self): | ||
"""Test system metrics display initialization""" | ||
self.assertIsNotNone(self.window.cpu_label) | ||
self.assertIsNotNone(self.window.memory_label) | ||
self.assertIsNotNone(self.window.disk_label) | ||
|
||
def test_training_suspension(self): | ||
"""Test training suspension functionality""" | ||
self.window.suspend_training() | ||
self.assertTrue(self.window.training_suspended) | ||
|
||
def test_training_termination(self): | ||
"""Test training termination functionality""" | ||
self.window.terminate_training() | ||
self.assertTrue(self.window.training_terminated) |