🔥 pyro

Lightweight Machine Learning framework allowing plug-and-play training for PyTorch models

  • ⚡ Lightning-inspired
  • 💾 Support for wandb and checkpoints out-of-the-box
  • 📊 Pretty logs, plots and support for metrics
  • ✨ Fully type-safe
  • 🪶 Lightweight and easy to use

Examples

See 📓 notebooks for examples using pyro. In particular, you can find:

  • Iris: the simplest example, training a small MLP on the Iris dataset.
  • SmolVLM on Flowers102: features from the SmolVLM vision model are extracted and used to train a linear classifier on the Flowers102 dataset, reaching a test accuracy of 98.6%. The general recipe is sketched below.
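
The gist of that second notebook is a common recipe: freeze a pretrained vision backbone, cache its features once, then train a small linear head on top. Here is a rough, self-contained sketch of the feature-caching half, using a frozen torchvision ResNet as a stand-in for SmolVLM (the notebook's actual model, transforms and hyperparameters differ):

import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset

# Frozen pretrained backbone, standing in for the SmolVLM vision tower
backbone = torchvision.models.resnet18(weights="IMAGENET1K_V1")
backbone.fc = torch.nn.Identity()      # keep the 512-d features, drop the classifier
backbone.eval().requires_grad_(False)

@torch.no_grad()
def cache_features(dataset) -> TensorDataset:
    # The dataset should yield (image_tensor, label) pairs,
    # e.g. torchvision.datasets.Flowers102 with a resize + to-tensor transform
    feats, labels = [], []
    for x, y in DataLoader(dataset, batch_size=64):
        feats.append(backbone(x))
        labels.append(y)
    return TensorDataset(torch.cat(feats), torch.cat(labels))

The cached features then go through a small linear-head model and trainer, exactly like the MLP pattern shown in the Usage section below.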

Usage

You can use 🔥 pyro with minimal code changes and never have to write a training loop again. Here is an example of a pyro model and training script to get you started.

1. Define your Model

import torch
import pyroml as p

class MySOTAModel(p.PyroModel):
    def __init__(self):
        super().__init__()
        # Define your layers as usual, and replace MyLossFunction with your loss of choice
        self.loss_fn = torch.nn.MyLossFunction()

    # Optionally, configure your own optimizer and scheduler, see more in the docs
    def configure_optimizers(self, trainer):
        self.optimizer = torch.optim.AdamW(self.parameters(), lr=trainer.lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=1, gamma=0.99
        )

    def step(self, batch, stage: p.Stage):
        # Extract data from your dataset batch
        # Batches and the model are moved to the appropriate device automatically
        x, y = batch
        # Forward the model (forward() is defined as in any torch.nn.Module)
        preds = self(x)
        # Compute the loss
        loss = self.loss_fn(preds, y)
        # Optionally, register some metrics (compute_accuracy is your own helper here)
        self.log(loss=loss.item(), accuracy=compute_accuracy(preds, y))
        # Return the loss when training, otherwise return the predictions
        if stage == p.Stage.TRAIN:
            return loss
        return preds

2. Instantiate a Trainer

trainer = p.Trainer(
    lr=0.01,
    max_epochs=32,
    batch_size=16,
    # And many other options such as device, precision, callbacks, ...
)
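
The fit, evaluate and predict calls in the next step take plain PyTorch datasets. As a purely illustrative sketch (the shapes and sizes below are arbitrary assumptions, not part of pyroml), synthetic TensorDatasets are enough to try everything out:

import torch
from torch.utils.data import TensorDataset

# Arbitrary toy data: 16 input features, 4 classes
def make_dataset(n_samples: int) -> TensorDataset:
    x = torch.randn(n_samples, 16)
    y = torch.randint(0, 4, (n_samples,))
    return TensorDataset(x, y)  # yields (x, y) pairs, matching step() above

training_dataset = make_dataset(256)
validation_dataset = make_dataset(64)
test_dataset = make_dataset(64)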

3. Run training, validation and testing

# Fit the model on the given training set and evaluate it on the validation set during training
train_tracker = trainer.fit(model, training_dataset, validation_dataset)
print(train_tracker.records)

# Plot metric curves registered during training 
train_tracker.plot(epoch=True)

# Evaluate your model after training
validation_tracker = trainer.evaluate(model, validation_dataset)
print(validation_tracker.records)

# Test your model on some testing set
_, test_preds = trainer.predict(model, test_dataset)
print("Test Predictions", test_preds)

Requirements

  • Python ^3.10 | ^3.11 | ^3.12
  • Recommended: Poetry v2 (docs)

Installation

pip

# CPU only version
pip install pyroml 
# OR with CUDA-enabled PyTorch and torchvision
pip install pyroml[cuda]
# Additional dependencies that you might require
pip install pyroml[extra] 

poetry

# CPU only version
poetry add pyroml 
# OR with CUDA-enabled PyTorch and torchvision
poetry add pyroml[cuda] --source pytorch-cu124 
# Additional dependencies that you might require
poetry add pyroml --extras extra

Locally

# Clone the repo
git clone https://github.com/peacefulotter/pyroml.git
cd pyroml

# Install dependencies
poetry config virtualenvs.in-project true  
poetry install --with dev

Tests

Tests are run with pytest. First install the test dependencies, then run the script:

poetry install --with test
./run_tests.sh
