forked from puf3zin/deep_dehazing
-
Notifications
You must be signed in to change notification settings - Fork 0
/
architecture.py
71 lines (56 loc) · 2.18 KB
/
architecture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#import numpy as np
#import tensorflow as tf
import json
import sys
import abc
import base
class Architecture(base.Base):
    """Interface that every network architecture is expected to implement.

    Subclasses provide the forward pass (``prediction``) and the three
    period getters. ``get_layer`` is an optional hook with a default
    implementation that returns ``None``.

    NOTE(review): ``abc.abstractmethod`` is only enforced at instantiation
    time when the class's metaclass derives from ``abc.ABCMeta``. Whether
    ``base.Base`` supplies such a metaclass is not visible from this
    file -- confirm.
    """

    @abc.abstractmethod
    def prediction(self, sample, training=False):
        """Compute the architecture's output for an input sample.

        Every architecture must implement this method. The output shape
        varies with the implementation, so the loss has to be chosen in
        accordance with the architecture. Likewise, the architecture
        implementation depends on the shape of the dataset.

        Args:
            sample: the network's input tensor.
            training: boolean indicating whether this prediction is being
                used for training or not.

        Returns:
            The network's output tensor.
        """

    @abc.abstractmethod
    def get_validation_period(self):
        """Return the validation period; must be implemented by subclasses.

        NOTE(review): the unit and exact meaning of the period are decided
        by the caller and are not visible here -- confirm.
        """

    @abc.abstractmethod
    def get_model_saving_period(self):
        """Return the model saving period; must be implemented by subclasses.

        NOTE(review): the unit and exact meaning of the period are decided
        by the caller and are not visible here -- confirm.
        """

    @abc.abstractmethod
    def get_summary_writing_period(self):
        """Return the summary writing period; must be implemented by subclasses.

        NOTE(review): the unit and exact meaning of the period are decided
        by the caller and are not visible here -- confirm.
        """

    def get_layer(self, layer_name):
        """Return a reference to a layer of the architecture.

        Override this method to make the hidden layers of the network
        available for visualization; it does not need to be implemented
        otherwise. This default implementation ignores ``layer_name``.

        Args:
            layer_name: the name of the desired layer.

        Returns:
            A reference to the layer's tensor. The default implementation
            always returns ``None``.
        """
        return None
# def verify_config(self, parameters_list, config_dict):
# for parameter in parameters_list:
# if parameter not in config_dict:
# raise Exception('Config: ' + parameter + ' is necessary for ' +
# self.__class__.__name__ + ' execution.')
# def open_config(self, parameters_list=[], config_filename=None):
# if config_filename is None:
# config_filename = sys.modules[self.__module__].__file__[:-3]+'.json'
# with open(config_filename) as config_file:
# config_dict = json.load(config_file)
# self.verify_config(parameters_list, config_dict)
# return config_dict