Skip to content

Commit

Permalink
Implement ModelTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
ti1uan committed Feb 17, 2024
1 parent 71a50c1 commit 261492b
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/model_trainer/model_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from snorkel.labeling.model import LabelModel
from snorkel.labeling import PandasLFApplier
from typing import Dict, Any
import logging
from labeling_function_lib.labeling_function_lib import LabelingFunctionLib
from pandas import DataFrame


class ModelTrainer:
    """Train a Snorkel ``LabelModel`` from a labeling function library.

    The optional configuration dict is consulted for a single key,
    ``n_epochs``; when it is absent, ``DEFAULT_N_EPOCHS`` is used.
    """

    # Epoch count used when the config does not provide 'n_epochs'.
    DEFAULT_N_EPOCHS = 500

    def __init__(self, config: Dict[str, Any] | None = None) -> None:
        """Store the configuration and create a logger named after the class.

        Args:
            config: Optional settings. ``None`` (or any other falsy value)
                is replaced by an empty dict.
        """
        self.config = config or {}
        self.logger = logging.getLogger(self.__class__.__name__)

    def train(
        self,
        data: DataFrame,
        lf_lib: LabelingFunctionLib,
        cardinality: int,
    ) -> LabelModel:
        """
        Trains a LabelModel using the provided data and labeling function library.

        The labeling functions returned by ``lf_lib.get_all()`` are applied to
        ``data`` with a ``PandasLFApplier``; the result is passed to
        ``LabelModel.fit`` as ``L_train``.

        Args:
            data (DataFrame): The input data for training.
            lf_lib (LabelingFunctionLib): The labeling function library.
            cardinality (int): The cardinality of the label space.
        Returns:
            LabelModel: The trained LabelModel.
        """
        self.logger.info("Applying labeling functions...")

        applier = PandasLFApplier(lfs=lf_lib.get_all())
        L_train = applier.apply(data)
        model = LabelModel(cardinality=cardinality)
        n_epochs = self.config.get('n_epochs', self.DEFAULT_N_EPOCHS)

        # State the unit explicitly; a bare number in the log is ambiguous.
        self.logger.info("Training label model for %s epochs...", n_epochs)

        model.fit(L_train=L_train, n_epochs=n_epochs)

        self.logger.info("Training completed.")

        return model
Empty file removed tests/model_trainer/.gitkeep
Empty file.
49 changes: 49 additions & 0 deletions tests/model_trainer/test_model_trainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest
from model_trainer.model_trainer import ModelTrainer
from pandas import DataFrame
from snorkel.labeling import labeling_function
from snorkel.labeling.model import LabelModel
from unittest.mock import patch, MagicMock


# Sample labeling functions

@labeling_function("func_1")
def labeling_function_1(x):
    """Sample labeling function named "func_1": always returns label 1.

    The data point ``x`` is ignored.
    """
    return 1

@labeling_function("func_2")
def labeling_function_2(x):
    """Sample labeling function named "func_2": always returns label 0.

    The data point ``x`` is ignored.
    """
    return 0

@labeling_function("func_3")
def labeling_function_3(x):
    """Sample labeling function named "func_3": always returns label 1.

    The data point ``x`` is ignored.
    """
    return 1

@pytest.fixture
def lf_lib_mock():
    """Provide a stand-in labeling function library.

    Only ``get_all()`` is configured: it returns the three sample labeling
    functions defined in this module. The mock is built directly with
    ``MagicMock``, so the fixture does not depend on the pytest-mock plugin.
    """
    lf_lib = MagicMock()
    lf_lib.get_all.return_value = [
        labeling_function_1,
        labeling_function_2,
        labeling_function_3,
    ]
    return lf_lib

@pytest.fixture
def data_mock():
    """Provide a small two-row DataFrame with a single ``text`` column."""
    frame = DataFrame({"text": ["example1", "example2"]})
    return frame

@pytest.fixture
def config_mock():
    """Provide a trainer configuration that sets ``n_epochs`` to 100."""
    settings = {"n_epochs": 100}
    return settings

def test_train_fits_label_model(lf_lib_mock, data_mock, config_mock):
    """``ModelTrainer.train`` should complete and return a ``LabelModel``."""
    # Initialize
    trainer = ModelTrainer(config=config_mock)

    # No try/except wrapper: an exception raised by train() already fails the
    # test, and letting it propagate keeps the full traceback in the report.
    trained_model = trainer.train(
        data=data_mock,
        lf_lib=lf_lib_mock,
        cardinality=2,
    )

    # Verify return instance
    assert isinstance(trained_model, LabelModel), (
        "train method should return a LabelModel instance"
    )

0 comments on commit 261492b

Please sign in to comment.