Skip to content

Commit

Permalink
Merge pull request #2 from Maazaowski/feature/UnitTesting
Browse files Browse the repository at this point in the history
Merging UnitTesting feature to main
  • Loading branch information
Maazaowski authored Dec 10, 2024
2 parents cd89cb4 + e8d50b4 commit 715f7f8
Show file tree
Hide file tree
Showing 9 changed files with 214 additions and 5 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: Python Tests

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.8, 3.9, '3.10']

    steps:
      # v2 runs on the deprecated Node 12 runtime; v4 is the supported release.
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install dependencies
        run: |
          # Refresh the package index first so apt-get install can't 404 on a
          # stale runner image; -y keeps the step non-interactive.
          sudo apt-get update
          sudo apt-get install -y python3-tk xvfb
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install pytest pytest-cov
      - name: Run tests with coverage
        run: |
          # xvfb provides a virtual X display so tkinter UI tests run headless.
          xvfb-run --auto-servernum pytest tests/ --cov=src/ --cov-report=xml
      - name: Upload coverage reports
        uses: codecov/codecov-action@v3
        with:
          # `file` is deprecated in codecov-action; `files` is the current input.
          files: ./coverage.xml
          flags: unittests
9 changes: 9 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Runtime dependencies
transformers[torch]
accelerate>=0.26.0
datasets
ttkbootstrap
psutil
torch
# Test dependencies (also installed explicitly in the CI workflow)
pytest
pytest-cov
pytest-mock
15 changes: 10 additions & 5 deletions src/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@ def __init__(self, window=None):
self.current_progress = 0

def update_progress(self, increment: float):
    """Advance training progress by *increment* percent, clamped to 100.

    Clamps BEFORE storing so ``self.current_progress`` itself can never
    exceed 100 (the pre-fix version only clamped the value passed to the
    logger, letting the stored counter drift past 100).
    """
    self.current_progress = min(self.current_progress + increment, 100)
    self.logger.update_progress(self.current_progress)

def load_dataset(self):
    """Load the wikitext-103-raw-v1 dataset into ``self.dataset``.

    Logs progress via ``self.logger``; on failure an ERROR entry is
    written and the original exception is re-raised for the caller.
    """
    try:
        self.logger.log("INFO", "Loading and preprocessing dataset...")
        # `load_dataset` here resolves to the module-level loader
        # (presumably the `datasets` library import — verify at file top),
        # not this method: same name, different scope.
        self.dataset = load_dataset('wikitext', 'wikitext-103-raw-v1')
        self.logger.log("INFO", f"Dataset loaded. {len(self.dataset['train'])} training examples.")
    except Exception as e:
        # Include the cause in the log; previously `e` was bound but unused
        # and the message gave no hint of what failed.
        self.logger.log("ERROR", f"Error loading dataset: {e}")
        raise

def initialize_model(self):
self.logger.log("INFO", "Initializing model...")
Expand Down
7 changes: 7 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Re-export the test cases at package level so runners importing the
# `tests` package (and `tests/run_tests.py` discovery) see them directly.
from .training.test_trainer import TestModelTrainer
from .ui.test_training_window import TestTrainingWindow

__all__ = [
    'TestModelTrainer',
    'TestTrainingWindow'
]
19 changes: 19 additions & 0 deletions tests/run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import unittest
import sys
import os

# Put the project root (parent of tests/) on sys.path so the test modules
# can `import src.*` when this script is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

def run_tests():
    """Discover and run every test_*.py beside this script.

    Returns True when the whole suite passed, False otherwise.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    suite = unittest.TestLoader().discover(here, pattern='test_*.py')
    outcome = unittest.TextTestRunner(verbosity=2).run(suite)
    return outcome.wasSuccessful()

# Script entry point: propagate the suite outcome as the process exit code
# (0 = all tests passed, 1 = at least one failure).
if __name__ == '__main__':
    raise SystemExit(0 if run_tests() else 1)
Empty file added tests/training/__init__.py
Empty file.
64 changes: 64 additions & 0 deletions tests/training/test_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import unittest
from unittest.mock import MagicMock, patch
from src.training.trainer import ModelTrainer
from transformers import TrainingArguments
import warnings

# Suppress DeprecationWarning noise (e.g. from transformers) during test runs.
warnings.filterwarnings("ignore", category=DeprecationWarning)

class TestModelTrainer(unittest.TestCase):
    """Unit tests for ModelTrainer.

    The UI window is replaced with a MagicMock so no tkinter display is
    required; heavyweight collaborators (dataset loader, GPT-2 model,
    HF Trainer) are patched out per-test at their import site in
    src.training.trainer.
    """

    def setUp(self):
        # Fresh trainer per test; the mock window absorbs all UI callbacks.
        # NOTE(review): ModelTrainer.__init__ appears to build a real
        # tokenizer/logger (see test_logger_initialization below) — confirm
        # setUp needs no network access in CI.
        self.mock_window = MagicMock()
        self.trainer = ModelTrainer(self.mock_window)

    def test_progress_tracking(self):
        """Test progress updates during training"""
        initial_progress = self.trainer.current_progress
        self.trainer.update_progress(25.0)
        self.assertEqual(self.trainer.current_progress, initial_progress + 25.0)

    @patch('src.training.trainer.load_dataset')
    def test_dataset_loading(self, mock_load_dataset):
        """Test dataset loading functionality"""
        # The mocked dataset must support __getitem__ because the trainer
        # reads dataset['train'] after loading.
        mock_dataset = MagicMock()
        mock_dataset.__getitem__.return_value = MagicMock()
        mock_load_dataset.return_value = mock_dataset

        self.trainer.load_dataset()
        mock_load_dataset.assert_called_once()

    @patch('src.training.trainer.GPT2LMHeadModel')
    def test_model_initialization(self, mock_gpt2):
        """Test model initialization"""
        self.trainer.initialize_model()
        mock_gpt2.from_pretrained.assert_called_once_with("gpt2")

    def test_training_args_integration(self):
        """Test training arguments integration"""
        args = TrainingArguments(
            output_dir="./test_output",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8
        )

        # Patch the HF Trainer class so no real training run starts.
        with patch('src.training.trainer.Trainer') as mock_trainer:
            self.trainer.train(args)
            mock_trainer.assert_called()

    def test_logger_initialization(self):
        """Test logger initialization"""
        self.assertIsNotNone(self.trainer.logger)
        self.assertEqual(self.trainer.current_progress, 0)

    def test_tokenizer_initialization(self):
        """Test tokenizer initialization"""
        self.assertIsNotNone(self.trainer.tokenizer)

    def test_progress_limits(self):
        """Test progress update boundaries"""
        # Overshooting (150 > 100) must clamp the stored progress at 100.
        self.trainer.current_progress = 0
        self.trainer.update_progress(150)
        self.assertEqual(self.trainer.current_progress, 100)

Empty file added tests/ui/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions tests/ui/test_training_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
from src.ui.training_window import TrainingProgressWindow
import tkinter as tk

class TestTrainingWindow(unittest.TestCase):
    """Unit tests for TrainingProgressWindow.

    One shared window is created for the whole class (tkinter windows are
    expensive to build); individual tests drive its update methods. CI runs
    these under xvfb, so a display is always available.
    """

    @classmethod
    def setUpClass(cls):
        cls.window = TrainingProgressWindow()

    @classmethod
    def tearDownClass(cls):
        try:
            cls.window.close()
        except Exception:
            # Window might already be destroyed. Catch Exception, not a bare
            # `except:`, so KeyboardInterrupt/SystemExit still propagate.
            pass

    def test_window_creation(self):
        """A second window can be created and closed independently."""
        window = TrainingProgressWindow()
        self.assertIsNotNone(window)
        window.close()

    def test_metrics_update(self):
        """Test if training metrics update correctly"""
        self.window.update_training_metrics(
            epoch=1,
            loss=0.5,
            accuracy=75.5,
            learning_rate=0.001,
            total_epochs=3
        )

        # Expected strings pin the window's display formatting contract.
        self.assertEqual(self.window.epoch_label.cget("text"), "1.00/3")
        self.assertEqual(self.window.training_loss_label.cget("text"), "0.5000")
        self.assertEqual(self.window.accuracy_label.cget("text"), "75.50%")
        self.assertEqual(self.window.learning_rate_label.cget("text"), "0.0010")

    def test_progress_bar(self):
        """Test progress bar updates"""
        test_values = [0, 25, 50, 75, 100]
        for value in test_values:
            self.window.update_progress(value)
            self.assertEqual(self.window.progress_bar['value'], value)

    def test_log_messages(self):
        """Test log message functionality"""
        test_message = "Test training started"
        self.window.update_log("TEST", test_message)
        log_content = self.window.log_box.get("1.0", tk.END)
        self.assertIn(test_message, log_content)

    def test_system_metrics_initialization(self):
        """Test system metrics display initialization"""
        self.assertIsNotNone(self.window.cpu_label)
        self.assertIsNotNone(self.window.memory_label)
        self.assertIsNotNone(self.window.disk_label)

    def test_training_suspension(self):
        """Test training suspension functionality"""
        self.window.suspend_training()
        self.assertTrue(self.window.training_suspended)

    def test_training_termination(self):
        """Test training termination functionality"""
        self.window.terminate_training()
        self.assertTrue(self.window.training_terminated)

0 comments on commit 715f7f8

Please sign in to comment.