minor change
1. Modified the output shapes of BiInteractionPooling and InnerProductLayer so that,
where possible, a layer's output has the same number of dimensions as its input.

2. Minimized the nesting of other layers inside custom layers, because the parameter statistics reported by model.summary() are incorrect when other layers are nested this way (see the sketch after the file list below).
shenweichen committed Dec 7, 2018
1 parent 34e8fa5 commit 82efc6f
Showing 7 changed files with 61 additions and 43 deletions.
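A minimal sketch of the second point (not part of the commit; class and layer names here are illustrative). In TF 1.x Keras, a layer instantiated inside another layer's call() is not registered on the outer layer, so its weights escape the parameter statistics of model.summary(); weights created with add_weight() in build() are tracked correctly:

import tensorflow as tf
from tensorflow.python.keras.layers import Layer, Dense
from tensorflow.python.keras.initializers import Zeros, glorot_normal


class NestedDense(Layer):
    # The pattern this commit removes: the Dense created inside call()
    # is invisible to this layer's weight tracking in TF 1.x, so
    # model.summary() mis-counts its parameters.
    def call(self, inputs, **kwargs):
        return Dense(8)(inputs)


class OwnWeights(Layer):
    # The pattern this commit adopts: the weights belong to the layer
    # itself, so they are counted correctly.
    def build(self, input_shape):
        self.kernel = self.add_weight(name='kernel',
                                      shape=(int(input_shape[-1]), 8),
                                      initializer=glorot_normal(seed=1024))
        self.bias = self.add_weight(name='bias', shape=(8,),
                                    initializer=Zeros())
        super(OwnWeights, self).build(input_shape)

    def call(self, inputs, **kwargs):
        # Same computation as Dense(8) with no activation.
        return tf.nn.bias_add(
            tf.tensordot(inputs, self.kernel, axes=(-1, 0)), self.bias)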
3 changes: 1 addition & 2 deletions deepctr/activations.py
@@ -34,8 +34,7 @@ def build(self, input_shape):

def call(self, inputs, **kwargs):

inputs_normed = BatchNormalization(
axis=self.axis, epsilon=self.epsilon, center=False, scale=False)(inputs)
inputs_normed = tf.layers.batch_normalization(inputs,axis=self.axis, epsilon=self.epsilon, center=False, scale=False)
x_p = tf.sigmoid(inputs_normed)
return self.alphas * (1.0 - x_p) * inputs + x_p * inputs
def get_config(self,):
87 changes: 53 additions & 34 deletions deepctr/layers.py
@@ -1,8 +1,8 @@
from tensorflow.python.keras.layers import Layer,Dense,Activation,Dropout,BatchNormalization,concatenate
from tensorflow.python.keras.layers import Layer,Activation,BatchNormalization
from tensorflow.python.keras.regularizers import l2
from tensorflow.python.keras.initializers import RandomNormal,Zeros,glorot_normal,glorot_uniform
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.activations import softmax

import tensorflow as tf


@@ -101,12 +101,14 @@ def build(self, input_shape):

embedding_size = input_shape[0][-1]

#self.attention_W = self.add_weight(shape=(embedding_size, self.attention_factor), initializer=glorot_normal(seed=self.seed),
# name="attention_W")
#self.attention_b = self.add_weight(shape=(self.attention_factor,), initializer=Zeros(), name="attention_b")
self.attention_W = self.add_weight(shape=(embedding_size, self.attention_factor), initializer=glorot_normal(seed=self.seed),regularizer=l2(self.l2_reg_w),
name="attention_W")
self.attention_b = self.add_weight(shape=(self.attention_factor,), initializer=Zeros(), name="attention_b")
self.projection_h = self.add_weight(shape=(self.attention_factor, 1), initializer=glorot_normal(seed=self.seed),
name="projection_h")
self.projection_p = self.add_weight(shape=(embedding_size, 1), initializer=glorot_normal(seed=self.seed), name="projection_p")


super(AFMLayer, self).build(input_shape) # Be sure to call this somewhere!

def call(self, inputs,**kwargs):
@@ -127,14 +129,14 @@ def call(self, inputs,**kwargs):
inner_product = p * q

bi_interaction = inner_product
attention_temp = tf.nn.relu(tf.nn.bias_add(tf.tensordot(bi_interaction,self.attention_W,axes=(-1,0)),self.attention_b))
# Dense(self.attention_factor,'relu',kernel_regularizer=l2(self.l2_reg_w))(bi_interaction)
attention_weight =tf.nn.softmax(tf.tensordot(attention_temp,self.projection_h,axes=(-1,0)),dim=1)
attention_output = tf.reduce_sum(attention_weight*bi_interaction,axis=1)

attention_temp = Dense(self.attention_factor,'relu',kernel_regularizer=l2(self.l2_reg_w))(bi_interaction)
attention_weight = softmax(K.dot(attention_temp, self.projection_h),axis=1)

attention_output = K.sum(attention_weight*bi_interaction,axis=1)
attention_output = tf.nn.dropout(attention_output,self.keep_prob,seed=1024)
# Dropout(1-self.keep_prob)(attention_output)
afm_out = K.dot(attention_output, self.projection_p)
afm_out = tf.tensordot(attention_output,self.projection_p,axes=(-1,0))

return afm_out
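
For reference, the rewritten block above is the attention step of AFM (Attentional Factorization Machines), with W = attention_W, b = attention_b, h = projection_h, p = projection_p, and dropout applied before the final projection (a restatement for orientation, not part of the commit):

e_{ij} = h^\top \mathrm{ReLU}\big( W (v_i \odot v_j) + b \big), \qquad
a_{ij} = \mathrm{softmax}(e_{ij}), \qquad
\hat{y}_{\mathrm{AFM}} = p^\top \sum_{i<j} a_{ij} \, (v_i \odot v_j)

where the softmax runs over the interaction pairs (the axis=1, or deprecated dim=1, argument in the code above).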

@@ -169,13 +171,14 @@ def build(self, input_shape):
def call(self, inputs,**kwargs):
x = inputs
if self.use_bias:
x = K.bias_add(x, self.global_bias, data_format='channels_last')
x = tf.nn.bias_add(x,self.global_bias,data_format='NHWC')

if isinstance(self.activation,str):
output = Activation(self.activation)(x)
else:
output = self.activation(x)

output = K.reshape(output,(-1,1))
output = tf.reshape(output,(-1,1))

return output

@@ -282,17 +285,28 @@ def __init__(self, hidden_size, activation,l2_reg, keep_prob, use_bn,seed,**kwargs):
super(MLP, self).__init__(**kwargs)

def build(self, input_shape):
input_size = input_shape[-1]
hidden_units = [int(input_size)] + self.hidden_size
self.kernels = [self.add_weight(name='kernel' + str(i),
shape=(hidden_units[i], hidden_units[i+1]),
initializer=glorot_normal(seed=self.seed),
regularizer=l2(self.l2_reg),
trainable=True) for i in range(len(self.hidden_size))]
self.bias = [self.add_weight(name='bias' + str(i),
shape=(self.hidden_size[i],),
initializer=Zeros(),
trainable=True) for i in range(len(self.hidden_size))]

super(MLP, self).build(input_shape) # Be sure to call this somewhere!

def call(self, inputs,**kwargs):
deep_input = inputs
#deep_input = Dropout(1 - self.keep_prob)(deep_input)

for l in range(len(self.hidden_size)):
fc = Dense(self.hidden_size[l], activation=None, \
kernel_initializer=glorot_normal(seed=self.seed), \
kernel_regularizer=l2(self.l2_reg))(deep_input)
for i in range(len(self.hidden_size)):
fc = tf.nn.bias_add(tf.tensordot(deep_input,self.kernels[i],axes=(-1,0)),self.bias[i])
#fc = Dense(self.hidden_size[i], activation=None, \
# kernel_initializer=glorot_normal(seed=self.seed), \
# kernel_regularizer=l2(self.l2_reg))(deep_input)
if self.use_bn:
fc = BatchNormalization()(fc)

@@ -302,7 +316,7 @@ def call(self, inputs,**kwargs):
fc = self.activation()(fc)
else:
raise ValueError("Invalid activation of MLP,found %s.You should use a str or a Activation Layer Class."%(self.activation))
fc = Dropout(1 - self.keep_prob)(fc)
fc = tf.nn.dropout(fc,self.keep_prob)

deep_input = fc
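
One detail worth noting in the dropout swap above (a clarification, not part of the commit): Keras Dropout takes the fraction of units to drop, while tf.nn.dropout in TF 1.x takes the probability of keeping a unit, which is why the old line passed 1 - self.keep_prob and the new line passes self.keep_prob directly. A minimal sketch with an illustrative keep_prob:

import tensorflow as tf
from tensorflow.python.keras.layers import Dropout

keep_prob = 0.8
x = tf.random_normal([4, 16])
a = Dropout(1 - keep_prob)(x, training=True)  # Keras: rate = fraction dropped
b = tf.nn.dropout(x, keep_prob)               # TF 1.x: fraction kept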

@@ -327,7 +341,7 @@ class BiInteractionPooling(Layer):
- A 3D tensor with shape: ``(batch_size,field_size,embedding_size)``.
Output shape
- 2D tensor with shape: ``(batch_size, embedding_size)``.
- 3D tensor with shape: ``(batch_size,1,embedding_size)``.
References
- [Neural Factorization Machines for Sparse Predictive Analytics](http://arxiv.org/abs/1708.05027)
@@ -350,14 +364,14 @@ def call(self, inputs,**kwargs):
raise ValueError("Unexpected inputs dimensions %d, expect to be 3 dimensions"% (K.ndim(inputs)))

concated_embeds_value = inputs
square_of_sum = K.square(K.sum(concated_embeds_value, axis=1, keepdims=True))
sum_of_square = K.sum(concated_embeds_value * concated_embeds_value, axis=1, keepdims=True)
square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keep_dims=True))
sum_of_square = tf.reduce_sum(concated_embeds_value * concated_embeds_value, axis=1, keep_dims=True)
cross_term = 0.5*(square_of_sum - sum_of_square)
cross_term = K.reshape(cross_term,(-1,inputs.get_shape()[-1]))

return cross_term

def compute_output_shape(self, input_shape):
return (None, input_shape[-1])
return (None, 1, input_shape[-1])
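
The square_of_sum / sum_of_square pair above implements the standard factorization-machine identity, applied element-wise along the embedding dimension of the field embeddings e_1, ..., e_n:

\frac{1}{2}\Big[\Big(\sum_{i=1}^{n} e_i\Big)^2 - \sum_{i=1}^{n} e_i^2\Big] = \sum_{i=1}^{n} \sum_{j=i+1}^{n} e_i \odot e_j

so all pairwise interactions are obtained in O(n) rather than O(n^2) work, and keep_dims=True preserves the middle axis, which is what produces the (batch_size, 1, embedding_size) output shape documented above.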

class OutterProductLayer(Layer):
"""OutterProduct Layer used in PNN.This implemention is adapted from code that the author of the paper published on https://github.com/Atomu2014/product-nets.
@@ -366,7 +380,7 @@ class OutterProductLayer(Layer):
- A list of N 3D tensor with shape: ``(batch_size,1,embedding_size)``.
Output shape
- 2D tensor with shape:``(batch_size, N*(N-1)/2 )``.
- 2D tensor with shape:``(batch_size,N*(N-1)/2 )``.
Arguments
- **kernel_type**: str. The kernel weight matrix type to use, can be mat, vec or num.
@@ -434,8 +448,8 @@ def call(self, inputs,**kwargs):
for j in range(i + 1, num_inputs):
row.append(i)
col.append(j)
p = K.concatenate([embed_list[idx] for idx in row],axis=1) # batch num_pairs k
q = K.concatenate([embed_list[idx] for idx in col],axis=1) # Reshape([num_pairs, self.embedding_size])
p = tf.concat([embed_list[idx] for idx in row],axis=1) # batch num_pairs k
q = tf.concat([embed_list[idx] for idx in col],axis=1) # Reshape([num_pairs, self.embedding_size])

#-------------------------
if self.kernel_type == 'mat':
@@ -499,7 +513,7 @@ class InnerProductLayer(Layer):
- A list of N 3D tensor with shape: ``(batch_size,1,embedding_size)``.
Output shape
- 2D tensor with shape: ``(batch_size, N*(N-1)/2 )`` if use reduce_sum. or 3D tensor with shape: ``(batch_size, N*(N-1)/2, embedding_size )`` if not use reduce_sum.
- 3D tensor with shape: ``(batch_size, N*(N-1)/2 ,1)`` if use reduce_sum. or 3D tensor with shape: ``(batch_size, N*(N-1)/2, embedding_size )`` if not use reduce_sum.
Arguments
- **reduce_sum**: bool. Whether to return the inner product or the element-wise product.
@@ -550,11 +564,11 @@ def call(self, inputs,**kwargs):
for j in range(i + 1, num_inputs):
row.append(i)
col.append(j)
p = K.concatenate([embed_list[idx] for idx in row],axis=1)# batch num_pairs k
q = K.concatenate([embed_list[idx] for idx in col],axis=1) # Reshape([num_pairs, self.embedding_size])
p = tf.concat([embed_list[idx] for idx in row],axis=1)# batch num_pairs k
q = tf.concat([embed_list[idx] for idx in col],axis=1) # Reshape([num_pairs, self.embedding_size])
inner_product = p * q
if self.reduce_sum:
inner_product = K.sum(inner_product, axis=2, keepdims=False)
inner_product = tf.reduce_sum(inner_product, axis=2, keep_dims=True)
return inner_product


@@ -564,7 +578,7 @@ def compute_output_shape(self, input_shape):
input_shape = input_shape[0]
embed_size = input_shape[-1]
if self.reduce_sum:
return (input_shape[0],num_pairs)
return (input_shape[0],num_pairs,1)
else:
return (input_shape[0],num_pairs,embed_size)
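
A quick plain-Python sanity check of the pair bookkeeping above (illustrative, not part of the commit): with N inputs, the row/col loops enumerate all N*(N-1)/2 unordered pairs, and with reduce_sum the keep_dims=True call yields the (batch_size, num_pairs, 1) output shape now documented:

N = 4
row, col = [], []
for i in range(N - 1):
    for j in range(i + 1, N):
        row.append(i)
        col.append(j)
pairs = list(zip(row, col))
assert len(pairs) == N * (N - 1) // 2
print(pairs)  # [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]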

@@ -623,6 +637,11 @@ def build(self, input_shape):
raise ValueError('A `LocalActivationUnit` layer requires '
'two inputs with shapes (None,1,embedding_size) and (None,T,embedding_size). '
'Got different shapes: %s,%s' % (input_shape[0], input_shape[1]))
size = 4*int(input_shape[0][-1]) if len(self.hidden_size) == 0 else self.hidden_size[-1]
self.kernel = self.add_weight(shape=(size, 1),
initializer=glorot_normal(seed=self.seed),
name="kernel")
self.bias = self.add_weight(shape=(1,), initializer=Zeros(), name="bias")
super(LocalActivationUnit, self).build(input_shape) # Be sure to call this somewhere!

def call(self, inputs,**kwargs):
@@ -634,9 +653,9 @@ def call(self, inputs,**kwargs):
queries = K.repeat_elements(query,keys_len,1)

att_input = tf.concat([queries, keys, queries - keys, queries * keys], axis=-1)
att_input = BatchNormalization()(att_input)
att_input = tf.layers.batch_normalization(att_input)
att_out = MLP(self.hidden_size, self.activation, self.l2_reg, self.keep_prob, self.use_bn, seed=self.seed)(att_input)
attention_score = Dense(1, 'linear')(att_out)
attention_score = tf.nn.bias_add(tf.tensordot(att_out,self.kernel,axes=(-1,0)),self.bias)

return attention_score

2 changes: 1 addition & 1 deletion deepctr/models/din.py
@@ -30,7 +30,7 @@ def get_input(feature_dim_dict, seq_feature_list, seq_max_len):


def DIN(feature_dim_dict, seq_feature_list, embedding_size=8, hist_len_max=16,
use_din=True, use_bn=False, hidden_size=[200, 80], activation='relu', att_hidden_size=[80, 40], att_activation='sigmoid', att_weight_normalization=True,
use_din=True, use_bn=False, hidden_size=[200, 80], activation='relu', att_hidden_size=[80, 40], att_activation=Dice, att_weight_normalization=False,
l2_reg_deep=0, l2_reg_embedding=1e-5, final_activation='sigmoid', keep_prob=1, init_std=0.0001, seed=1024, ):
"""Instantiates the Deep Interest Network architecture.
4 changes: 2 additions & 2 deletions deepctr/models/pnn.py
@@ -7,7 +7,7 @@
[1] Qu, Yanru, et al. "Product-based neural networks for user response prediction." Data Mining (ICDM), 2016 IEEE 16th International Conference on. IEEE, 2016.(https://arxiv.org/pdf/1611.00144.pdf)
"""

from tensorflow.python.keras.layers import Dense, Embedding, Concatenate, Reshape
from tensorflow.python.keras.layers import Dense, Embedding, Concatenate, Reshape,Flatten
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.initializers import RandomNormal
from tensorflow.python.keras.regularizers import l2
@@ -63,7 +63,7 @@ def PNN(feature_dim_dict, embedding_size=8, hidden_size=[128, 128], l2_reg_embed
map(Reshape((1, embedding_size)), continuous_embedding_list))
embed_list += continuous_embedding_list

inner_product = InnerProductLayer()(embed_list)
inner_product = Flatten()(InnerProductLayer()(embed_list))
outter_product = OutterProductLayer(kernel_type)(embed_list)

# ipnn deep input
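The added Flatten() is the downstream counterpart of the InnerProductLayer change above: the layer now returns a 3D tensor of shape (batch_size, N*(N-1)/2, 1), so PNN flattens it back to 2D before it is concatenated into the deep input.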
4 changes: 2 additions & 2 deletions deepctr/sequence.py
@@ -23,7 +23,7 @@ class SequencePoolingLayer(Layer):
- **mode**: str. Pooling operation to be used, can be sum, mean or max.
"""

def __init__(self, seq_len_max, mode='sum', **kwargs):
def __init__(self, seq_len_max, mode='mean', **kwargs):

if mode not in ['sum', 'mean', 'max']:
raise ValueError("mode must be sum or mean")
Expand Down Expand Up @@ -91,7 +91,7 @@ class AttentionSequencePoolingLayer(Layer):
- [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)
"""

def __init__(self, hidden_size=(80, 40), activation='sigmoid', weight_normalization=True, **kwargs):
def __init__(self, hidden_size=(80, 40), activation='sigmoid', weight_normalization=False, **kwargs):

self.hidden_size = hidden_size
self.activation = activation
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -26,7 +26,7 @@
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = '0.1.4'
release = '0.1.5'


# -- General configuration ---------------------------------------------------
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="deepctr",
version="0.1.4",
version="0.1.5",
author="Weichen Shen",
author_email="[email protected]",
description="DeepCTR is a Easy-to-use,Modular and Extendible package of deep-learning based CTR models ,including serval DNN-based CTR models and lots of core components layer of the models which can be used to build your own custom model.",
