Skip to content

Commit

Permalink
refactor code, add train cmd line args
Browse files Browse the repository at this point in the history
  • Loading branch information
cwmeijer committed Oct 30, 2024
1 parent de7bb56 commit e38cf24
Show file tree
Hide file tree
Showing 4 changed files with 286 additions and 0 deletions.
36 changes: 36 additions & 0 deletions src/segmentmytiff/utils/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class MonochromeFlairDataset(Dataset):
def __init__(self, root_path, limit=None, split="train"):
self.root_path = root_path
self.limit = limit
self.images = sorted([str(p) for p in (Path(root_path) / split / "input").glob("*.tif")])[:self.limit]

def image_path_to_mask_path(image_path: Path) -> Path:
return image_path.parent.parent / "labels" / f"MSK{image_path.stem[3:-2]}_0{image_path.suffix}" # -2 for "_b" where b is band#

self.masks = [str(image_path_to_mask_path(Path(p))) for p in self.images][:self.limit]
non_existing_masks = [p for p in self.masks if Path(p).exists() == False]
if non_existing_masks:
print(f"{len(non_existing_masks)} of a total of {len(self.masks)} masks not found.")

self.transform = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor()])

if self.limit is None:
self.limit = len(self.images)

def __getitem__(self, index):
img = Image.open(self.images[index]).convert("L")
mask = Image.open(self.masks[index]).convert("L")

return self.transform(img), self.transform(mask)

def __len__(self):
return min(len(self.images), self.limit)
75 changes: 75 additions & 0 deletions src/segmentmytiff/utils/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
from torch import nn


class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv_op = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True)
)

def forward(self, x):
return self.conv_op(x)


class DownSample(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = DoubleConv(in_channels, out_channels)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

def forward(self, x):
down = self.conv(x)
p = self.pool(down)

return down, p


class UpSample(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)

def forward(self, x1, x2):
x1 = self.up(x1)
x = torch.cat([x1, x2], 1)
return self.conv(x)


class UNet(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.down_convolution_1 = DownSample(in_channels, 64)
self.down_convolution_2 = DownSample(64, 128)
self.down_convolution_3 = DownSample(128, 256)
self.down_convolution_4 = DownSample(256, 512)

self.bottle_neck = DoubleConv(512, 1024)

self.up_convolution_1 = UpSample(1024, 512)
self.up_convolution_2 = UpSample(512, 256)
self.up_convolution_3 = UpSample(256, 128)
self.up_convolution_4 = UpSample(128, 64)

self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

def forward(self, x):
down_1, p1 = self.down_convolution_1(x)
down_2, p2 = self.down_convolution_2(p1)
down_3, p3 = self.down_convolution_3(p2)
down_4, p4 = self.down_convolution_4(p3)

b = self.bottle_neck(p4)

up_1 = self.up_convolution_1(b, down_4)
up_2 = self.up_convolution_2(up_1, down_3)
up_3 = self.up_convolution_3(up_2, down_2)
up_4 = self.up_convolution_4(up_3, down_1)

out = self.out(up_4)
return out
14 changes: 14 additions & 0 deletions src/segmentmytiff/utils/performance_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch


def dice_coefficient(prediction, target, epsilon=1e-07):
prediction_copy = prediction.clone()

prediction_copy[prediction_copy < 0] = 0
prediction_copy[prediction_copy > 0] = 1

intersection = abs(torch.sum(prediction_copy * target))
union = abs(torch.sum(prediction_copy) + torch.sum(target))
dice = (2. * intersection + epsilon) / (union + epsilon)

return dice
161 changes: 161 additions & 0 deletions src/segmentmytiff/utils/train_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Based on, and parts copied from https://medium.com/@fernandopalominocobo/mastering-u-net-a-step-by-step-guide-to-segmentation-from-scratch-with-pytorch-6a17c5916114"""
import argparse
import time
from pathlib import Path

import torch
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

try:
import mlflow
mlflow_installed = True
except ImportError:
mlflow_installed = False

from segmentmytiff.utils.datasets import MonochromeFlairDataset
from segmentmytiff.utils.models import UNet
from segmentmytiff.utils.performance_metrics import dice_coefficient


def main(root_path, use_mlflow=True, train_set_limit=None):
if use_mlflow and not mlflow_installed:
raise Exception("Please install mlflow first or specify to run without mlflow.")

if use_mlflow:
mlflow.set_tracking_uri(uri="http://127.0.0.1:80")
mlflow.set_experiment("MLflow Quickstart")

if use_mlflow:
mlflow.start_run()
mlflow.set_tag("Training Info", "Segmentation model for ortho photos.")
train_val_dataset = MonochromeFlairDataset(root_path, limit=train_set_limit)
test_dataset = MonochromeFlairDataset(root_path, split="test")
generator = torch.Generator().manual_seed(0)
train_dataset, validation_dataset = random_split(train_val_dataset, [0.8, 0.2], generator=generator)

device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
num_workers = torch.cuda.device_count() * 4
else:
num_workers = 1

LEARNING_RATE = 3e-4
BATCH_SIZE = 8

if use_mlflow:
mlflow.log_param("LEARNING_RATE", LEARNING_RATE)
mlflow.log_param("BATCH_SIZE", BATCH_SIZE)
mlflow.log_param("device", device)
mlflow.log_param("num_workers", num_workers)

train_dataloader = DataLoader(dataset=train_dataset,
num_workers=num_workers, pin_memory=False,
batch_size=BATCH_SIZE,
shuffle=True)
validation_dataloader = DataLoader(dataset=validation_dataset,
num_workers=num_workers, pin_memory=False,
batch_size=BATCH_SIZE,
shuffle=True)

_test_dataloader = DataLoader(dataset=test_dataset,
num_workers=num_workers, pin_memory=False,
batch_size=BATCH_SIZE,
shuffle=True)

model = UNet(in_channels=1, num_classes=1).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()

torch.cuda.empty_cache()

train(model, train_dataloader, validation_dataloader, criterion, optimizer, device)

if use_mlflow:
mlflow.end_run()


def train(model, train_dataloader, validation_dataloader, criterion, optimizer, device, epochs = 10):
for epoch in tqdm(range(epochs)):
train_loss, train_dice = train_one_step(model, train_dataloader, optimizer, criterion, device)
val_loss, val_dice = validate(model, validation_dataloader, criterion, device)

print("-" * 30)
print(f"Training Loss EPOCH {epoch}: {train_loss:.4f}")
print(f"Training DICE EPOCH {epoch}: {train_dice:.4f}")
print("\n")
print(f"Validation Loss EPOCH {epoch}: {val_loss:.4f}")
print(f"Validation DICE EPOCH {epoch}: {val_dice:.4f}")
print("-" * 30)

metrics = {
"train_loss": train_loss,
"train_dc": train_dice,
"val_loss": val_loss,
"val_dc": val_dice,
}

mlflow.log_metrics(metrics, step=epoch, timestamp=int(round(time.time())))
# Saving the model
torch.save(model.state_dict(), 'my_checkpoint.pth')


def validate(model, validation_dataloader, criterion, device):
model.eval()
val_running_loss = 0
val_running_dc = 0
with torch.no_grad():
for idx, img_mask in enumerate(tqdm(validation_dataloader, position=0, leave=True)):
img = img_mask[0].float().to(device)
mask = img_mask[1].float().to(device)

y_pred = model(img)
loss = criterion(y_pred, mask)
dc = dice_coefficient(y_pred, mask)

val_running_loss += loss.item()
val_running_dc += dc.item()

val_loss = val_running_loss / (idx + 1)
val_dc = val_running_dc / (idx + 1)
return val_dc, val_loss


def train_one_step(model, train_dataloader, optimizer, criterion, device):
model.train()
train_running_loss = 0
train_running_dice = 0
for idx, img_mask in enumerate(tqdm(train_dataloader, position=0, leave=True)):
img = img_mask[0].float().to(device)
mask = img_mask[1].float().to(device)

y_pred = model(img)
optimizer.zero_grad()

dice = dice_coefficient(y_pred, mask)
loss = criterion(y_pred, mask)

train_running_loss += loss.item()
train_running_dice += dice.item()

loss.backward()
optimizer.step()
train_loss = train_running_loss / (idx + 1)
train_dice = train_running_dice / (idx + 1)
return train_loss, train_dice,


def parse_args():
parser = argparse.ArgumentParser(description="Process input and output TIFF files.")
parser.add_argument('-r', '--root', type=Path, required=True, help='Root to the dataset')
parser.add_argument('--no_mlflow', action='store_true', help='Flag for enabling or disabling MLflow')
parser.add_argument('--train_set_limit', type=int, default=None, help='Limit for the size of the train set (default: no limit)')
return parser.parse_args()


if __name__ == '__main__':
args = parse_args()
root_path = args.root
main(root_path, use_mlflow=not args.no_mlflow, train_set_limit=args.train_set_limit)

0 comments on commit e38cf24

Please sign in to comment.