Skip to content

Commit

Permalink
minor refactor and addiing callback, metric and loss to experiment
Browse files Browse the repository at this point in the history
  • Loading branch information
phborba committed Sep 21, 2020
1 parent 8a5f3ff commit 77074f8
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 32 deletions.
44 changes: 39 additions & 5 deletions segmentation_models_trainer/callbacks_loader/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
****
"""
import tensorflow as tf
from typing import Any, List
from segmentation_models_trainer.callbacks_loader.callback_factory import CallbackFactory
from dataclasses import dataclass
from dataclasses_jsonschema import JsonSchemaMixin

@dataclass
class Callback(JsonSchemaMixin):
name: str
parameters: dict
config: dict

def __post_init__(self):
pass
self.callback_obj = self.get_callback()

@staticmethod
def validate_callback_name(name):
Expand All @@ -42,21 +44,53 @@ def validate_callback_name(name):
def get_callback(self):
return CallbackFactory.get_callback(
self.name,
self.parameters
self.config
)

@dataclass
class CallbackList(JsonSchemaMixin):
items: List[Callback]

def get_tf_objects(self):
return [
i.get_callback() for i in items
]



if __name__ == '__main__':
import json
x = Callback(
name='ReduceLROnPlateau',
parameters= {
config= {
'monitor' : 'val_loss',
'factor' : 0.2,
'patience' : 5,
'min_lr' : 0.001
}
)
print(x.to_json())
print(json.dumps([x.to_json()]))
x.get_callback()
x
y= [
Callback.from_dict(
{
'name' : 'ReduceLROnPlateau',
'config' : {
'monitor' : 'val_loss',
'factor' : 0.2,
'patience' : 5,
'min_lr' : 0.001
}
}
),
Callback.from_dict(
{
'name' : 'ModelCheckpoint',
'config' : {'filepath' : '/data/teste'}
}
)
]
y
z = CallbackList(y)
print(z.to_json())
2 changes: 1 addition & 1 deletion segmentation_models_trainer/dataset_loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
****
"""

from dataclasses import dataclass
from dataclasses import dataclass, field
from dataclasses_jsonschema import JsonSchemaMixin
from typing import Any, List
from collections import OrderedDict
Expand Down
38 changes: 16 additions & 22 deletions segmentation_models_trainer/experiment_builder/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,16 @@
import segmentation_models as sm
import os
import numpy as np
import importlib
from typing import Any, List
from dataclasses import dataclass
from dataclasses_jsonschema import JsonSchemaMixin
from segmentation_models_trainer.model_builder.segmentation_model import SegmentationModel
from segmentation_models_trainer.hyperparameter_builder.hyperparameters import Hyperparameters
from segmentation_models_trainer.dataset_loader.dataset import Dataset
from segmentation_models_trainer.callbacks_loader.callback import Callback, CallbackList
from segmentation_models_trainer.experiment_builder.loss import Loss
from segmentation_models_trainer.experiment_builder.metric import Metric, MetricList

@dataclass
class Experiment(JsonSchemaMixin):
Expand All @@ -42,6 +47,9 @@ class Experiment(JsonSchemaMixin):
train_dataset: Dataset
test_dataset: Dataset
model: SegmentationModel
loss: Loss
callbacks: CallbackList
metrics: MetricList

def train(self):
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
Expand All @@ -65,6 +73,7 @@ def train(self):
test_steps_per_epoch = int( np.ceil(self.test_dataset.dataset_size / BATCH_SIZE) )

def train_model(epochs, save_weights_path, encoder_freeze, load_weights=None):
callback_list = self.callbacks.get_tf_objects()
with strategy.scope():
model = self.model.get_model(
n_classes,
Expand All @@ -73,18 +82,11 @@ def train_model(epochs, save_weights_path, encoder_freeze, load_weights=None):
)
opt = self.hyperparameters.optimizer.tf_object
#TODO metrics and loss fields into compile
metric_list = self.metrics.get_tf_objects()
model.compile(
opt,
loss=sm.losses.bce_jaccard_loss,
metrics=[
'accuracy',
'binary_crossentropy',
sm.metrics.iou_score,
sm.metrics.precision,
sm.metrics.recall,
sm.metrics.f1_score,
sm.metrics.f2_score
]
metrics=metric_list
)
model.fit(
train_ds,
Expand All @@ -93,7 +95,7 @@ def train_model(epochs, save_weights_path, encoder_freeze, load_weights=None):
epochs=epochs,
validation_data=test_ds,
validation_steps=test_steps_per_epoch,
callbacks=[]
callbacks=callback_list
)
model.save_weights(
save_weights_path
Expand Down Expand Up @@ -151,17 +153,6 @@ def test_and_create_folder(path):
return path
os.makedirs(path, exist_ok=True)
return path



@dataclass
class Callbacks(JsonSchemaMixin):
name: str
keras_callback: str

@dataclass
class Metric(JsonSchemaMixin):
name: str

if __name__ == "__main__":
import json
Expand All @@ -186,7 +177,10 @@ class Metric(JsonSchemaMixin):

'test_dataset' : json.loads('{"name": "test_ds", "file_path": "/data/test_ds.csv", "n_classes": 1, "dataset_size": 1000, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": true, "shuffle": true, "shuffle_buffer_size": 10000, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 4, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}'),

'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}')
'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}'),
'loss' : json.loads('{"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}'),
'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste"}}]}'),
'metrics' : json.loads('{"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}'),
}
)

Expand Down
85 changes: 85 additions & 0 deletions segmentation_models_trainer/experiment_builder/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
"""
/***************************************************************************
segmentation_models_trainer
-------------------
begin : 2020-09-21
git sha : $Format:%H$
copyright : (C) 2020 by Philipe Borba - Cartographic Engineer @ Brazilian Army
email : philipeborba at gmail dot com
***************************************************************************/
/***************************************************************************
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
****
"""
import tensorflow as tf
import segmentation_models as sm
from dataclasses import dataclass
from dataclasses_jsonschema import JsonSchemaMixin

@dataclass
class Loss(JsonSchemaMixin):
class_name: str
config: dict
framework: str

def __post_init__(self):
if self.framework == 'sm':
self.loss_obj = self.get_sm_loss(self.class_name)
elif self.framework == 'tf.keras':
identifier = {
"class_name" : self.class_name,
"config" : self.config
}
self.loss_obj = tf.keras.losses.get(identifier)
else:
raise ValueError("Loss not implemented")


def get_sm_loss(self, name):
if self.class_name == 'jaccard_loss':
return sm.losses.JaccardLoss(**self.config)
elif self.class_name == 'dice_loss':
return sm.losses.DiceLoss(**self.config)
elif self.class_name == 'binary_focal_loss':
return sm.losses.BinaryFocalLoss(**self.config)
elif self.class_name == 'categorical_focal_loss':
return sm.losses.CategoricalFocalLoss(**self.config)
elif self.class_name == 'binary_crossentropy':
return sm.losses.BinaryCELoss(**self.config)
elif self.class_name == 'categorical_crossentropy':
return sm.losses.CategoricalCELoss(**self.config)
elif self.class_name == 'bce_dice_loss':
return sm.losses.BinaryCELoss(**self.config) + sm.losses.DiceLoss(**self.config)
elif self.class_name == 'bce_jaccard_loss':
return sm.losses.BinaryCELoss(**self.config) + sm.losses.JaccardLoss(**self.config)
elif self.class_name == 'cce_dice_loss':
return sm.losses.CategoricalCELoss(**self.config) + sm.losses.DiceLoss(**self.config)
elif self.class_name == 'cce_jaccard_loss':
return sm.losses.CategoricalCELoss(**self.config) + sm.losses.JaccardLoss(**self.config)
elif self.class_name == 'binary_focal_dice_loss':
return sm.losses.BinaryFocalLoss(**self.config) + sm.losses.DiceLoss(**self.config)
elif self.class_name == 'binary_focal_jaccard_loss':
return sm.losses.BinaryFocalLoss(**self.config) + sm.losses.JaccardLoss(**self.config)
elif self.class_name == 'categorical_focal_dice_loss':
return sm.losses.CategoricalFocalLoss(**self.config) + sm.losses.DiceLoss(**self.config)
elif self.class_name == 'categorical_focal_jaccard_loss':
return sm.losses.CategoricalFocalLoss(**self.config) + sm.losses.JaccardLoss(**self.config)
else:
raise ValueError("SM Loss not implemented")

if __name__ == "__main__":
import json
x = Loss(
class_name='bce_dice_loss',
config={},
framework='sm'
)
print(x.to_json())
x
84 changes: 84 additions & 0 deletions segmentation_models_trainer/experiment_builder/metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
/***************************************************************************
segmentation_models_trainer
-------------------
begin : 2020-09-21
git sha : $Format:%H$
copyright : (C) 2020 by Philipe Borba - Cartographic Engineer @ Brazilian Army
email : philipeborba at gmail dot com
***************************************************************************/
/***************************************************************************
* *
* This program is free software; you can redistribute it and/or modify *
* it under the terms of the GNU General Public License as published by *
* the Free Software Foundation; either version 2 of the License, or *
* (at your option) any later version. *
* *
****
"""
import tensorflow as tf
import segmentation_models as sm
from typing import Any, List
from dataclasses import dataclass
from dataclasses_jsonschema import JsonSchemaMixin

@dataclass
class Metric(JsonSchemaMixin):
class_name: str
config: dict
framework: str

def __post_init__(self):
if self.framework == 'sm':
self.metric_obj = self.get_sm_metric(self.class_name)
elif self.framework == 'tf.keras':
identifier = {
"class_name" : self.class_name,
"config" : self.config
}
self.metric_obj = tf.keras.metrics.get(identifier)
else:
raise ValueError("Metric not implemented")

def get_sm_metric(self, name):
if self.class_name == 'iou_score':
return sm.metrics.iou_score
elif self.class_name == 'precision':
return sm.metrics.precision
elif self.class_name == 'recall':
return sm.metrics.recall
elif self.class_name == 'f1_score':
return sm.metrics.f1_score
elif self.class_name == 'f2_score':
return sm.metrics.f2_score
else:
raise ValueError("SM metric not implemented")

@dataclass
class MetricList(JsonSchemaMixin):
items: List[Metric]

def get_tf_objects(self):
return [
i.metric_obj for i in items
]

if __name__ == "__main__":
import json
metric_list = [
Metric(
class_name=i,
config={},
framework='sm'
) for i in ['iou_score', 'precision', 'recall', 'f1_score', 'f2_score']
] + [
Metric(
class_name=i,
config={} if i != 'MeanIoU' else {'num_classes':2},
framework='tf.keras'
) for i in ['LogCoshError', 'KLDivergence', 'MeanIoU']
]
x=MetricList(metric_list)
print(x.to_json())
4 changes: 2 additions & 2 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ class Test_TestCallbacks(unittest.TestCase):

callback = Callback(
name='ReduceLROnPlateau',
parameters= {
config= {
'monitor' : 'val_loss',
'factor' : 0.2,
'patience' : 5,
'min_lr' : 0.001
}
)
json_dict = json.loads('{"name": "ReduceLROnPlateau", "parameters": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}')
json_dict = json.loads('{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}')

def test_create_instance(self):
"""[summary]
Expand Down
Loading

0 comments on commit 77074f8

Please sign in to comment.