Commit

add docstrings for FineTuner class
borauyar committed Jun 27, 2024
1 parent 2af6142 commit 4fd9742
Showing 1 changed file with 27 additions and 0 deletions.
flexynesis/main.py (27 additions, 0 deletions)
@@ -324,6 +324,33 @@ def load_and_convert_config(self, config_path):
import random, copy, logging

class FineTuner(pl.LightningModule):
"""
FineTuner class is designed for fine-tuning trained flexynesis models with flexible control over parameters such as
learning rates and component freezing, utilizing cross-validation to optimize generalization.
This class allows the application of different configuration strategies to either freeze or unfreeze specific
model components, while also exploring different learning rates to find the optimal setting.
It carries out cross-validation to find the best combination of parameter freezing strategies and learning rates.
Attributes:
model (pl.LightningModule): The model instance to be fine-tuned.
dataset (Dataset): The dataset used for training and validation.
n_splits (int): Number of cross-validation splits.
batch_size (int): Batch size for training and validation.
learning_rates (list): List of learning rates to try during fine-tuning.
max_epoch (int): Maximum number of epochs for training.
freeze_configs (list of dicts): Configurations specifying which components of the model to freeze.
Methods:
apply_freeze_config(config): Apply a freezing configuration to the model components.
train_dataloader(): Returns a DataLoader for the training data of the current fold.
val_dataloader(): Returns a DataLoader for the validation data of the current fold.
training_step(batch, batch_idx): Executes a training step using the model's internal training logic.
validation_step(batch, batch_idx): Executes a validation step using the model's internal validation logic.
configure_optimizers(): Sets up the optimizer with the current learning rate and filtered trainable parameters.
run_experiments(): Executes the finetuning process across all configurations and learning rates, evaluates
using cross-validation, and selects the best configuration based on validation loss.
"""
    def __init__(self, model, dataset, n_splits=5, batch_size=32, learning_rates=None, max_epoch=50, freeze_configs=None):
        super().__init__()
        logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)
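For context, a minimal usage sketch of the class documented in this commit, based only on the constructor signature and docstring shown in the diff above. It is not part of the commit; trained_model and dataset are hypothetical placeholders for a fitted flexynesis model and the data it was trained on.

    # Hypothetical usage sketch, not part of this commit.
    finetuner = FineTuner(
        model=trained_model,          # a trained pl.LightningModule (placeholder name)
        dataset=dataset,              # Dataset used for the CV folds (placeholder name)
        n_splits=5,                   # defaults match the signature above
        batch_size=32,
        learning_rates=[1e-4, 1e-5],  # candidate rates to sweep
        max_epoch=50,
    )
    # Per the docstring, run_experiments() cross-validates every
    # freeze-config / learning-rate combination and keeps the one
    # with the lowest validation loss.
    finetuner.run_experiments()

The freeze_configs argument is left at its default here, since the diff does not show the structure of a configuration dict.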