Skip to content

Commit

Permalink
tests fix
Browse files Browse the repository at this point in the history
  • Loading branch information
phborba committed Sep 21, 2020
1 parent 8ce8889 commit 69d38ff
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_and_create_folder(path):

'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}'),
'loss' : json.loads('{"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}'),
'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste"}}]}'),
'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste/checkpoint.hdf5"}}]}'),
'metrics' : json.loads('{"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}'),
}
)
Expand Down
10 changes: 7 additions & 3 deletions segmentation_models_trainer/experiment_builder/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,18 @@ class Metric(JsonSchemaMixin):
framework: str

def __post_init__(self):
if self.framework not in ['sm', 'tf.keras']:
raise ValueError("Metric not implemented")

def get_metric(self):
if self.framework == 'sm':
self.metric_obj = self.get_sm_metric(self.class_name)
return self.get_sm_metric(self.class_name)
elif self.framework == 'tf.keras':
identifier = {
"class_name" : self.class_name,
"config" : self.config
}
self.metric_obj = tf.keras.metrics.get(identifier)
return tf.keras.metrics.get(identifier)
else:
raise ValueError("Metric not implemented")

Expand All @@ -62,7 +66,7 @@ class MetricList(JsonSchemaMixin):

def get_tf_objects(self):
return [
i.metric_obj for i in self.items
i.get_metric() for i in self.items
]

if __name__ == "__main__":
Expand Down
5 changes: 4 additions & 1 deletion tests/test_training_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ def setUp(self):
'''{"name": "test_ds", "file_path": "'''+self.csv_test_ds_file+'''", "n_classes": 1, "dataset_size": 1, "augmentation_list": [{"name": "random_crop", "parameters": {"crop_width": 256, "crop_height": 256}}, {"name": "per_image_standardization", "parameters": {}}], "cache": false, "shuffle": false, "shuffle_buffer_size": 1, "shuffle_csv": true, "ignore_errors": true, "num_paralel_reads": 1, "img_dtype": "float32", "img_format": "png", "img_width": 256, "img_length": 256, "img_bands": 3, "mask_bands": 1, "use_ds_width_len": false, "autotune": -1, "distributed_training": false}'''
),

'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}')
'model' : json.loads('{"description": "test case", "backbone": "resnet18", "architecture": "Unet", "activation": "sigmoid", "use_imagenet_weights": true}'),
'loss' : json.loads('{"class_name": "bce_dice_loss", "config": {}, "framework": "sm"}'),
'callbacks' : json.loads('{"items": [{"name": "ReduceLROnPlateau", "config": {"monitor": "val_loss", "factor": 0.2, "patience": 5, "min_lr": 0.001}}, {"name": "ModelCheckpoint", "config": {"filepath": "/data/teste/checkpoint.hdf5"}}]}'),
'metrics' : json.loads('{"items": [{"class_name": "iou_score", "config": {}, "framework": "sm"}, {"class_name": "precision", "config": {}, "framework": "sm"}, {"class_name": "recall", "config": {}, "framework": "sm"}, {"class_name": "f1_score", "config": {}, "framework": "sm"}, {"class_name": "f2_score", "config": {}, "framework": "sm"}, {"class_name": "LogCoshError", "config": {}, "framework": "tf.keras"}, {"class_name": "KLDivergence", "config": {}, "framework": "tf.keras"}, {"class_name": "MeanIoU", "config": {"num_classes": 2}, "framework": "tf.keras"}]}'),
}
self.settings_json = os.path.join(
current_dir, 'testing_data', 'settings.json'
Expand Down

0 comments on commit 69d38ff

Please sign in to comment.