From 77074f8ea974d471362a88c90d72798ccc6c24dc Mon Sep 17 00:00:00 2001 From: phborba Date: Mon, 21 Sep 2020 18:25:14 -0300 Subject: [PATCH 1/3] minor refactor and addiing callback, metric and loss to experiment --- .../callbacks_loader/callback.py | 44 ++++++++-- .../dataset_loader/dataset.py | 2 +- .../experiment_builder/experiment.py | 38 ++++----- .../experiment_builder/loss.py | 85 +++++++++++++++++++ .../experiment_builder/metric.py | 84 ++++++++++++++++++ tests/test_callbacks.py | 4 +- tests/test_experiment.py | 7 +- 7 files changed, 232 insertions(+), 32 deletions(-) create mode 100644 segmentation_models_trainer/experiment_builder/loss.py create mode 100644 segmentation_models_trainer/experiment_builder/metric.py diff --git a/segmentation_models_trainer/callbacks_loader/callback.py b/segmentation_models_trainer/callbacks_loader/callback.py index ee3be7a..cb2f8af 100644 --- a/segmentation_models_trainer/callbacks_loader/callback.py +++ b/segmentation_models_trainer/callbacks_loader/callback.py @@ -20,6 +20,7 @@ **** """ import tensorflow as tf +from typing import Any, List from segmentation_models_trainer.callbacks_loader.callback_factory import CallbackFactory from dataclasses import dataclass from dataclasses_jsonschema import JsonSchemaMixin @@ -27,9 +28,10 @@ @dataclass class Callback(JsonSchemaMixin): name: str - parameters: dict + config: dict + def __post_init__(self): - pass + self.callback_obj = self.get_callback() @staticmethod def validate_callback_name(name): @@ -42,15 +44,25 @@ def validate_callback_name(name): def get_callback(self): return CallbackFactory.get_callback( self.name, - self.parameters + self.config ) +@dataclass +class CallbackList(JsonSchemaMixin): + items: List[Callback] + + def get_tf_objects(self): + return [ + i.get_callback() for i in items + ] + if __name__ == '__main__': + import json x = Callback( name='ReduceLROnPlateau', - parameters= { + config= { 'monitor' : 'val_loss', 'factor' : 0.2, 'patience' : 5, @@ -58,5 +70,27 @@ def get_callback(self): } ) print(x.to_json()) + print(json.dumps([x.to_json()])) x.get_callback() - x \ No newline at end of file + y= [ + Callback.from_dict( + { + 'name' : 'ReduceLROnPlateau', + 'config' : { + 'monitor' : 'val_loss', + 'factor' : 0.2, + 'patience' : 5, + 'min_lr' : 0.001 + } + } + ), + Callback.from_dict( + { + 'name' : 'ModelCheckpoint', + 'config' : {'filepath' : '/data/teste'} + } + ) + ] + y + z = CallbackList(y) + print(z.to_json()) \ No newline at end of file diff --git a/segmentation_models_trainer/dataset_loader/dataset.py b/segmentation_models_trainer/dataset_loader/dataset.py index ecb349a..41607ea 100644 --- a/segmentation_models_trainer/dataset_loader/dataset.py +++ b/segmentation_models_trainer/dataset_loader/dataset.py @@ -20,7 +20,7 @@ **** """ -from dataclasses import dataclass +from dataclasses import dataclass, field from dataclasses_jsonschema import JsonSchemaMixin from typing import Any, List from collections import OrderedDict diff --git a/segmentation_models_trainer/experiment_builder/experiment.py b/segmentation_models_trainer/experiment_builder/experiment.py index e228a4b..3201086 100644 --- a/segmentation_models_trainer/experiment_builder/experiment.py +++ b/segmentation_models_trainer/experiment_builder/experiment.py @@ -22,11 +22,16 @@ import segmentation_models as sm import os import numpy as np +import importlib +from typing import Any, List from dataclasses import dataclass from dataclasses_jsonschema import JsonSchemaMixin from segmentation_models_trainer.model_builder.segmentation_model import SegmentationModel from segmentation_models_trainer.hyperparameter_builder.hyperparameters import Hyperparameters from segmentation_models_trainer.dataset_loader.dataset import Dataset +from segmentation_models_trainer.callbacks_loader.callback import Callback, CallbackList +from segmentation_models_trainer.experiment_builder.loss import Loss +from segmentation_models_trainer.experiment_builder.metric import Metric, MetricList @dataclass class Experiment(JsonSchemaMixin): @@ -42,6 +47,9 @@ class Experiment(JsonSchemaMixin): train_dataset: Dataset test_dataset: Dataset model: SegmentationModel + loss: Loss + callbacks: CallbackList + metrics: MetricList def train(self): gpu_devices = tf.config.experimental.list_physical_devices('GPU') @@ -65,6 +73,7 @@ def train(self): test_steps_per_epoch = int( np.ceil(self.test_dataset.dataset_size / BATCH_SIZE) ) def train_model(epochs, save_weights_path, encoder_freeze, load_weights=None): + callback_list = self.callbacks.get_tf_objects() with strategy.scope(): model = self.model.get_model( n_classes, @@ -73,18 +82,11 @@ def train_model(epochs, save_weights_path, encoder_freeze, load_weights=None): ) opt = self.hyperparameters.optimizer.tf_object #TODO metrics and loss fields into compile + metric_list = self.metrics.get_tf_objects() model.compile( opt, loss=sm.losses.bce_jaccard_loss, - metrics=[ - 'accuracy', - 'binary_crossentropy', - sm.metrics.iou_score, - sm.metrics.precision, - sm.metrics.recall, - sm.metrics.f1_score, - sm.metrics.f2_score - ] + metrics=metric_list ) model.fit( train_ds, @@ -93,7 +95,7 @@ def train_model(epochs, save_weights_path, encoder_freeze, load_weights=None): epochs=epochs, validation_data=test_ds, validation_steps=test_steps_per_epoch, - callbacks=[] + callbacks=callback_list ) model.save_weights( save_weights_path @@ -151,17 +153,6 @@ def test_and_create_folder(path): return path os.makedirs(path, exist_ok=True) return path - - - -@dataclass -class Callbacks(JsonSchemaMixin): - name: str - keras_callback: str - -@dataclass -class Metric(JsonSchemaMixin): - name: str if __name__ == "__main__": import json @@ -186,7 +177,10 @@ class Metric(JsonSchemaMixin): 'test_dataset' : json.loads('{"name": "test_ds", "file_path": "/data/test_ds.csv", "n_classes": 1, "dataset_size": 1000, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": true, "shuffle": true, "shuffle_buffer_size": 10000, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 4, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}'), - 'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}') + 'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}'), + 'loss' : json.loads('{"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}'), + 'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste"}}]}'), + 'metrics' : json.loads('{"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}'), } ) diff --git a/segmentation_models_trainer/experiment_builder/loss.py b/segmentation_models_trainer/experiment_builder/loss.py new file mode 100644 index 0000000..5ec1c25 --- /dev/null +++ b/segmentation_models_trainer/experiment_builder/loss.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +""" +/*************************************************************************** + segmentation_models_trainer + ------------------- + begin : 2020-09-21 + git sha : $Format:%H$ + copyright : (C) 2020 by Philipe Borba - Cartographic Engineer @ Brazilian Army + email : philipeborba at gmail dot com + ***************************************************************************/ + +/*************************************************************************** + * * + * This program is free software; you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation; either version 2 of the License, or * + * (at your option) any later version. * + * * + **** +""" +import tensorflow as tf +import segmentation_models as sm +from dataclasses import dataclass +from dataclasses_jsonschema import JsonSchemaMixin + +@dataclass +class Loss(JsonSchemaMixin): + class_name: str + config: dict + framework: str + + def __post_init__(self): + if self.framework == 'sm': + self.loss_obj = self.get_sm_loss(self.class_name) + elif self.framework == 'tf.keras': + identifier = { + "class_name" : self.class_name, + "config" : self.config + } + self.loss_obj = tf.keras.losses.get(identifier) + else: + raise ValueError("Loss not implemented") + + + def get_sm_loss(self, name): + if self.class_name == 'jaccard_loss': + return sm.losses.JaccardLoss(**self.config) + elif self.class_name == 'dice_loss': + return sm.losses.DiceLoss(**self.config) + elif self.class_name == 'binary_focal_loss': + return sm.losses.BinaryFocalLoss(**self.config) + elif self.class_name == 'categorical_focal_loss': + return sm.losses.CategoricalFocalLoss(**self.config) + elif self.class_name == 'binary_crossentropy': + return sm.losses.BinaryCELoss(**self.config) + elif self.class_name == 'categorical_crossentropy': + return sm.losses.CategoricalCELoss(**self.config) + elif self.class_name == 'bce_dice_loss': + return sm.losses.BinaryCELoss(**self.config) + sm.losses.DiceLoss(**self.config) + elif self.class_name == 'bce_jaccard_loss': + return sm.losses.BinaryCELoss(**self.config) + sm.losses.JaccardLoss(**self.config) + elif self.class_name == 'cce_dice_loss': + return sm.losses.CategoricalCELoss(**self.config) + sm.losses.DiceLoss(**self.config) + elif self.class_name == 'cce_jaccard_loss': + return sm.losses.CategoricalCELoss(**self.config) + sm.losses.JaccardLoss(**self.config) + elif self.class_name == 'binary_focal_dice_loss': + return sm.losses.BinaryFocalLoss(**self.config) + sm.losses.DiceLoss(**self.config) + elif self.class_name == 'binary_focal_jaccard_loss': + return sm.losses.BinaryFocalLoss(**self.config) + sm.losses.JaccardLoss(**self.config) + elif self.class_name == 'categorical_focal_dice_loss': + return sm.losses.CategoricalFocalLoss(**self.config) + sm.losses.DiceLoss(**self.config) + elif self.class_name == 'categorical_focal_jaccard_loss': + return sm.losses.CategoricalFocalLoss(**self.config) + sm.losses.JaccardLoss(**self.config) + else: + raise ValueError("SM Loss not implemented") + +if __name__ == "__main__": + import json + x = Loss( + class_name='bce_dice_loss', + config={}, + framework='sm' + ) + print(x.to_json()) + x \ No newline at end of file diff --git a/segmentation_models_trainer/experiment_builder/metric.py b/segmentation_models_trainer/experiment_builder/metric.py new file mode 100644 index 0000000..d79a5a2 --- /dev/null +++ b/segmentation_models_trainer/experiment_builder/metric.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +""" +/*************************************************************************** + segmentation_models_trainer + ------------------- + begin : 2020-09-21 + git sha : $Format:%H$ + copyright : (C) 2020 by Philipe Borba - Cartographic Engineer @ Brazilian Army + email : philipeborba at gmail dot com + ***************************************************************************/ + +/*************************************************************************** + * * + * This program is free software; you can redistribute it and/or modify * + * it under the terms of the GNU General Public License as published by * + * the Free Software Foundation; either version 2 of the License, or * + * (at your option) any later version. * + * * + **** +""" +import tensorflow as tf +import segmentation_models as sm +from typing import Any, List +from dataclasses import dataclass +from dataclasses_jsonschema import JsonSchemaMixin + +@dataclass +class Metric(JsonSchemaMixin): + class_name: str + config: dict + framework: str + + def __post_init__(self): + if self.framework == 'sm': + self.metric_obj = self.get_sm_metric(self.class_name) + elif self.framework == 'tf.keras': + identifier = { + "class_name" : self.class_name, + "config" : self.config + } + self.metric_obj = tf.keras.metrics.get(identifier) + else: + raise ValueError("Metric not implemented") + + def get_sm_metric(self, name): + if self.class_name == 'iou_score': + return sm.metrics.iou_score + elif self.class_name == 'precision': + return sm.metrics.precision + elif self.class_name == 'recall': + return sm.metrics.recall + elif self.class_name == 'f1_score': + return sm.metrics.f1_score + elif self.class_name == 'f2_score': + return sm.metrics.f2_score + else: + raise ValueError("SM metric not implemented") + +@dataclass +class MetricList(JsonSchemaMixin): + items: List[Metric] + + def get_tf_objects(self): + return [ + i.metric_obj for i in items + ] + +if __name__ == "__main__": + import json + metric_list = [ + Metric( + class_name=i, + config={}, + framework='sm' + ) for i in ['iou_score', 'precision', 'recall', 'f1_score', 'f2_score'] + ] + [ + Metric( + class_name=i, + config={} if i != 'MeanIoU' else {'num_classes':2}, + framework='tf.keras' + ) for i in ['LogCoshError', 'KLDivergence', 'MeanIoU'] + ] + x=MetricList(metric_list) + print(x.to_json()) \ No newline at end of file diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 89b3357..7c9daff 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -31,14 +31,14 @@ class Test_TestCallbacks(unittest.TestCase): callback = Callback( name='ReduceLROnPlateau', - parameters= { + config= { 'monitor' : 'val_loss', 'factor' : 0.2, 'patience' : 5, 'min_lr' : 0.001 } ) - json_dict = json.loads('{"name": "ReduceLROnPlateau", "parameters": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}') + json_dict = json.loads('{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}') def test_create_instance(self): """[summary] diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 1fe580a..f3eb5c2 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -45,10 +45,13 @@ class Test_TestExperiment(unittest.TestCase): 'test_dataset' : json.loads('{"name": "test_ds", "file_path": "/data/test_ds.csv", "n_classes": 1, "dataset_size": 1000, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": true, "shuffle": true, "shuffle_buffer_size": 10000, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 4, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}'), - 'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}') + 'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}'), + 'loss' : json.loads('{"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}'), + 'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste"}}]}'), + 'metrics' : json.loads('{"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}'), } ) - json_dict = json.loads('{"name": "test", "epochs": 2, "experiment_data_path": "/data/test", "checkpoint_frequency": 10, "warmup_epochs": 2, "use_multiple_gpus": false, "hyperparameters": {"batch_size": 16, "optimizer": {"name": "Adam", "config": {"learning_rate": 0.01}}}, "train_dataset": {"name": "train_ds", "file_path": "/data/train_ds.csv", "n_classes": 1, "dataset_size": 1000, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": true, "shuffle": true, "shuffle_buffer_size": 10000, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 4, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}, "test_dataset": {"name": "test_ds", "file_path": "/data/test_ds.csv", "n_classes": 1, "dataset_size": 1000, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": true, "shuffle": true, "shuffle_buffer_size": 10000, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 4, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}, "model": {"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}}') + json_dict = json.loads('{"name": "test", "epochs": 2, "experiment_data_path": "/data/test", "checkpoint_frequency": 10, "warmup_epochs": 2, "use_multiple_gpus": false, "hyperparameters": {"batch_size": 16, "optimizer": {"name": "Adam", "config": {"learning_rate": 0.01}}}, "train_dataset": {"name": "train_ds", "file_path": "/data/train_ds.csv", "n_classes": 1, "dataset_size": 1000, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": true, "shuffle": true, "shuffle_buffer_size": 10000, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 4, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}, "test_dataset": {"name": "test_ds", "file_path": "/data/test_ds.csv", "n_classes": 1, "dataset_size": 1000, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": true, "shuffle": true, "shuffle_buffer_size": 10000, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 4, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}, "model": {"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}, "loss": {"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}, "callbacks": {"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste"}}]}, "metrics": {"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}}') def test_create_instance(self): """[summary] From 8ce888955bf228026459640ad8062bfc3758fd69 Mon Sep 17 00:00:00 2001 From: phborba Date: Mon, 21 Sep 2020 18:27:32 -0300 Subject: [PATCH 2/3] syntax fix --- segmentation_models_trainer/callbacks_loader/callback.py | 2 +- segmentation_models_trainer/experiment_builder/metric.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/segmentation_models_trainer/callbacks_loader/callback.py b/segmentation_models_trainer/callbacks_loader/callback.py index cb2f8af..bf8dc20 100644 --- a/segmentation_models_trainer/callbacks_loader/callback.py +++ b/segmentation_models_trainer/callbacks_loader/callback.py @@ -53,7 +53,7 @@ class CallbackList(JsonSchemaMixin): def get_tf_objects(self): return [ - i.get_callback() for i in items + i.get_callback() for i in self.items ] diff --git a/segmentation_models_trainer/experiment_builder/metric.py b/segmentation_models_trainer/experiment_builder/metric.py index d79a5a2..6a2ef93 100644 --- a/segmentation_models_trainer/experiment_builder/metric.py +++ b/segmentation_models_trainer/experiment_builder/metric.py @@ -62,7 +62,7 @@ class MetricList(JsonSchemaMixin): def get_tf_objects(self): return [ - i.metric_obj for i in items + i.metric_obj for i in self.items ] if __name__ == "__main__": From 69d38ffd312254deeaebff50cd7c5b709bd46e7c Mon Sep 17 00:00:00 2001 From: phborba Date: Mon, 21 Sep 2020 18:57:17 -0300 Subject: [PATCH 3/3] tests fix --- .../experiment_builder/experiment.py | 2 +- .../experiment_builder/metric.py | 10 +++++++--- tests/test_training_script.py | 5 ++++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/segmentation_models_trainer/experiment_builder/experiment.py b/segmentation_models_trainer/experiment_builder/experiment.py index 3201086..0de81c4 100644 --- a/segmentation_models_trainer/experiment_builder/experiment.py +++ b/segmentation_models_trainer/experiment_builder/experiment.py @@ -179,7 +179,7 @@ def test_and_create_folder(path): 'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}'), 'loss' : json.loads('{"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}'), - 'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste"}}]}'), + 'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste/checkpoint.hdf5"}}]}'), 'metrics' : json.loads('{"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}'), } ) diff --git a/segmentation_models_trainer/experiment_builder/metric.py b/segmentation_models_trainer/experiment_builder/metric.py index 6a2ef93..27f008a 100644 --- a/segmentation_models_trainer/experiment_builder/metric.py +++ b/segmentation_models_trainer/experiment_builder/metric.py @@ -31,14 +31,18 @@ class Metric(JsonSchemaMixin): framework: str def __post_init__(self): + if self.framework not in ['sm', 'tf.keras']: + raise ValueError("Metric not implemented") + + def get_metric(self): if self.framework == 'sm': - self.metric_obj = self.get_sm_metric(self.class_name) + return self.get_sm_metric(self.class_name) elif self.framework == 'tf.keras': identifier = { "class_name" : self.class_name, "config" : self.config } - self.metric_obj = tf.keras.metrics.get(identifier) + return tf.keras.metrics.get(identifier) else: raise ValueError("Metric not implemented") @@ -62,7 +66,7 @@ class MetricList(JsonSchemaMixin): def get_tf_objects(self): return [ - i.metric_obj for i in self.items + i.get_metric() for i in self.items ] if __name__ == "__main__": diff --git a/tests/test_training_script.py b/tests/test_training_script.py index bf75c53..8381259 100644 --- a/tests/test_training_script.py +++ b/tests/test_training_script.py @@ -98,7 +98,10 @@ def setUp(self): '''{"name": "test_ds", "file_path": "'''+self.csv_test_ds_file+'''", "n_classes": 1, "dataset_size": 1, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": false, "shuffle": false, "shuffle_buffer_size": 1, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 1, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}''' ), - 'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}') + 'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}'), + 'loss' : json.loads('{"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}'), + 'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste/checkpoint.hdf5"}}]}'), + 'metrics' : json.loads('{"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}'), } self.settings_json = os.path.join( current_dir, 'testing_data', 'settings.json'