-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsfedu_conv_model.py
69 lines (45 loc) · 1.76 KB
/
sfedu_conv_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import logging
from tensorflow.keras.layers import Dense # noqa
from tensorflow.keras.models import Model # noqa
from gns.layer.global_average_pool import global_average_pool_layer_fabric
from gns.layer.gcs_convolution import gsn_convolutional_general_layer_fabric
from gns.model.model_folder import MODEL_FOLDER
from gns.config.settings import settings_fabric
# Project settings object; this module reads activation identifiers from
# ``settings.activations`` (relu / softmax) when building layers.
settings = settings_fabric()
# Module-level logger used for progress messages while the model is built.
logger = logging.getLogger(__name__)
class SfeduModel(Model):
    """Graph-classification model for the SFEDU example.

    Architecture, in order:
      * three stacked graph-convolution layers, each created with
        ``gsn_convolutional_general_layer_fabric(32, activation=relu)``;
      * a global average pooling layer;
      * a ``Dense`` head with softmax activation and ``data.n_labels`` units.

    NOTE(review): the previous docstring described this class as the
    "Industry example - find candidates vacancies" model, which looks copied
    from a sibling model; confirm the intended use case.
    """

    def __init__(self, data, **kwargs):
        """Create all layers of the model.

        Args:
            data: dataset-like object. Only ``data.n_labels`` is read here
                (it sets the width of the output layer); the object itself is
                also stored on ``self.data``.
            **kwargs: forwarded to the Keras ``Model`` constructor (e.g.
                ``name``). Accepting them lets ``sfedu_model_fabric``, which
                passes ``**kwargs`` through, work without a ``TypeError``.
        """
        super().__init__(**kwargs)
        logger.info("Mount data...")
        self.data = data
        logger.info("Create convolutional layers...")
        # The positional 32 is presumably the number of output channels of
        # each convolution -- TODO confirm against the layer fabric.
        self.conv1 = gsn_convolutional_general_layer_fabric(
            32, activation=settings.activations.relu
        )
        self.conv2 = gsn_convolutional_general_layer_fabric(
            32, activation=settings.activations.relu
        )
        self.conv3 = gsn_convolutional_general_layer_fabric(
            32, activation=settings.activations.relu
        )
        logger.info("Create pooling...")
        self.global_pool = global_average_pool_layer_fabric()
        logger.info("Create Dense layer...")
        self.dense = Dense(data.n_labels, activation=settings.activations.softmax)

    def call(self, inputs):
        """Run the forward pass.

        Args:
            inputs: 3-element sequence ``(x, a, i)``. Every convolution
                receives ``[x, a]``, and the pooling layer receives
                ``[x, i]``. Presumably node features, adjacency matrix and
                per-node graph index respectively -- TODO confirm.

        Returns:
            The output of the softmax ``Dense`` head, which has
            ``data.n_labels`` units.
        """
        x, a, i = inputs
        x = self.conv1([x, a])
        x = self.conv2([x, a])
        x = self.conv3([x, a])
        # Pool node-level features into one vector per graph (``i`` selects
        # the graph each node belongs to -- assumption, see docstring).
        output = self.global_pool([x, i])
        output = self.dense(output)
        return output


def sfedu_model_fabric(data, **kwargs):
    """Build and return an ``SfeduModel`` for ``data``.

    Args:
        data: passed as the first positional argument to ``SfeduModel``.
        **kwargs: passed unchanged to the ``SfeduModel`` constructor.

    Returns:
        The newly created ``SfeduModel`` instance.
    """
    model = SfeduModel(data, **kwargs)  # noqa
    return model


def path():
    """Return ``MODEL_FOLDER`` joined with ``'SfeduModel'`` via ``os.path.join``.

    Presumably the location where this model is saved / loaded -- confirm
    with callers.
    """
    model_entry_name = 'SfeduModel'
    return os.path.join(MODEL_FOLDER, model_entry_name)