forked from ZiyaoGeng/RecLearn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodules.py
29 lines (23 loc) · 799 Bytes
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""
Created on May 19, 2021
modules of NFM: DNN
@author: Ziyao Geng([email protected])
"""
import tensorflow as tf
from tensorflow.keras.layers import Dense, Dropout, Layer
class DNN(Layer):
    """NFM's DNN module: a stack of Dense layers followed by one Dropout layer."""

    def __init__(self, hidden_units, activation='relu', dropout=0., **kwargs):
        """Deep Neural Network
        :param hidden_units: A list. Number of units of each hidden Dense layer, in order.
        :param activation: A string (or any value accepted by Dense). Activation of every Dense layer.
        :param dropout: A scalar. Dropout rate, applied once to the output of the last Dense layer.
        :param kwargs: Extra keyword arguments forwarded to `Layer` (e.g. `name`, `dtype`).
        """
        super(DNN, self).__init__(**kwargs)
        # Kept so that get_config() can reproduce the constructor arguments.
        self.activation = activation
        self.dnn_network = [Dense(units=unit, activation=activation) for unit in hidden_units]
        self.dropout = Dropout(dropout)

    def call(self, inputs, **kwargs):
        """Run `inputs` through every Dense layer in order, then through the Dropout layer.

        :param inputs: Input tensor fed to the first Dense layer.
        :param kwargs: Extra call arguments; not used directly in this method.
        :return: Output of the Dropout layer. If `hidden_units` is empty, this is `inputs` after dropout.
        """
        x = inputs
        for dnn in self.dnn_network:
            x = dnn(x)
        # Dropout is applied a single time after the whole stack, not between layers.
        x = self.dropout(x)
        return x

    def get_config(self):
        """Return the constructor arguments so Keras can serialize and rebuild this layer."""
        config = super(DNN, self).get_config()
        config.update({
            # Rebuilt from the sublayers so no separate (tracked) list attribute is needed.
            'hidden_units': [layer.units for layer in self.dnn_network],
            'activation': tf.keras.activations.serialize(tf.keras.activations.get(self.activation)),
            'dropout': self.dropout.rate,
        })
        return config