Refactor & Add sequence input support
* Refactor input & embedding handling

* Support sequence (multi-value) input for the AFM, AutoInt, DCN, DeepFM, FNN, NFM, PNN, and xDeepFM models
Weichen Shen authored Jan 1, 2019
1 parent cc844f3 commit d524c86
Showing 33 changed files with 722 additions and 767 deletions.
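
For orientation, the practical effect of this change: feature_dim_dict may now carry a 'sequence' entry describing padded multi-value fields. A minimal usage sketch follows; the SeqFeat holder is an illustrative stand-in for the sequence-feature spec (not shown in this excerpt), supplying only the .name, .dimension, .maxlen and .combiner attributes the new code reads, and DeepFM is assumed to follow the constructor pattern the AFM/AutoInt diffs below show:

from collections import namedtuple
from deepctr.models import DeepFM

# Illustrative stand-in for the sequence feature spec; the new input
# pipeline only reads .name, .dimension, .maxlen and .combiner from it.
SeqFeat = namedtuple('SeqFeat', ['name', 'dimension', 'maxlen', 'combiner'])

feature_dim_dict = {
    'sparse': {'user': 300, 'item': 500},                 # name -> vocab size
    'dense': ['score'],                                   # dense feature names
    'sequence': [SeqFeat('hist_item', 500, 10, 'mean')],  # padded id lists
}
model = DeepFM(feature_dim_dict, embedding_size=8)
model.compile('adam', 'binary_crossentropy')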
2 changes: 1 addition & 1 deletion deepctr/__init__.py
@@ -3,5 +3,5 @@
 from . import sequence
 from . import models
 from .utils import check_version
-__version__ = '0.2.1'
+__version__ = '0.2.2'
 check_version(__version__)
164 changes: 164 additions & 0 deletions deepctr/input_embedding.py
@@ -0,0 +1,164 @@
from itertools import chain

from tensorflow.python.keras import Input
from tensorflow.python.keras.initializers import RandomNormal
from tensorflow.python.keras.layers import Embedding, Dense, Reshape, Concatenate
from tensorflow.python.keras.regularizers import l2
from .sequence import SequencePoolingLayer
from .utils import get_linear_logit


def create_input_dict(feature_dim_dict, prefix=''):
    sparse_input = {feat: Input(shape=(1,), name=prefix + 'sparse_' + str(i) + '-' + feat)
                    for i, feat in enumerate(feature_dim_dict["sparse"])}
    dense_input = {feat: Input(shape=(1,), name=prefix + 'dense_' + str(i) + '-' + feat)
                   for i, feat in enumerate(feature_dim_dict["dense"])}
    return sparse_input, dense_input
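
# Illustration: with feature_dim_dict = {'sparse': {'user': 300, 'item': 500},
# 'dense': ['score']}, create_input_dict returns Input placeholders named
# 'sparse_0-user', 'sparse_1-item' and 'dense_0-score', each of shape
# (batch_size, 1), keyed by feature name.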


def create_sequence_input_dict(feature_dim_dict):
    sequence_dim_dict = feature_dim_dict.get('sequence', [])
    sequence_input_dict = {feat.name: Input(shape=(feat.maxlen,), name='seq_' + str(i) + '-' + feat.name)
                           for i, feat in enumerate(sequence_dim_dict)}
    sequence_pooling_dict = {feat.name: feat.combiner
                             for feat in sequence_dim_dict}
    sequence_len_dict = {feat.name: Input(shape=(1,), name='seq_length' + str(i) + '-' + feat.name)
                         for i, feat in enumerate(sequence_dim_dict)}
    sequence_max_len_dict = {feat.name: feat.maxlen
                             for feat in sequence_dim_dict}
    return sequence_input_dict, sequence_pooling_dict, sequence_len_dict, sequence_max_len_dict
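
# Illustration: a sequence feature named 'hist_item' with maxlen=10 and
# combiner='mean' yields an id Input 'seq_0-hist_item' of shape
# (batch_size, 10), the combiner string 'mean', a length Input
# 'seq_length0-hist_item' of shape (batch_size, 1), and the maxlen 10,
# each stored in its own dict keyed by feature name.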


def create_embedding_dict(feature_dim_dict, embedding_size, init_std, seed, l2_reg, prefix='sparse'):
    if embedding_size == 'auto':
        sparse_embedding = {feat: Embedding(feature_dim_dict["sparse"][feat],
                                            6 * int(pow(feature_dim_dict["sparse"][feat], 0.25)),
                                            embeddings_initializer=RandomNormal(
                                                mean=0.0, stddev=init_std, seed=seed),
                                            embeddings_regularizer=l2(l2_reg),
                                            name=prefix + '_emb_' + str(i) + '-' + feat)
                            for i, feat in enumerate(feature_dim_dict["sparse"])}
    else:
        sparse_embedding = {feat: Embedding(feature_dim_dict["sparse"][feat], embedding_size,
                                            embeddings_initializer=RandomNormal(
                                                mean=0.0, stddev=init_std, seed=seed),
                                            embeddings_regularizer=l2(l2_reg),
                                            name=prefix + '_emb_' + str(i) + '-' + feat)
                            for i, feat in enumerate(feature_dim_dict["sparse"])}

    if 'sequence' in feature_dim_dict:
        count = len(sparse_embedding)
        sequence_dim_list = feature_dim_dict['sequence']
        for feat in sequence_dim_list:
            if feat.name not in sparse_embedding:
                if embedding_size == "auto":
                    sparse_embedding[feat.name] = Embedding(feat.dimension,
                                                            6 * int(pow(feat.dimension, 0.25)),
                                                            embeddings_initializer=RandomNormal(
                                                                mean=0.0, stddev=init_std, seed=seed),
                                                            embeddings_regularizer=l2(l2_reg),
                                                            name=prefix + '_emb_' + str(count) + '-' + feat.name)
                else:
                    sparse_embedding[feat.name] = Embedding(feat.dimension, embedding_size,
                                                            embeddings_initializer=RandomNormal(
                                                                mean=0.0, stddev=init_std, seed=seed),
                                                            embeddings_regularizer=l2(l2_reg),
                                                            name=prefix + '_emb_' + str(count) + '-' + feat.name)
                count += 1

    return sparse_embedding
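
# Illustration: with embedding_size='auto', a feature with vocabulary size
# 10000 gets dimension 6 * int(10000 ** 0.25) = 60. A sequence feature
# shares the embedding table of a sparse feature only when their names
# match; otherwise it gets its own table over feat.dimension ids.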


def merge_dense_input(dense_input_, embed_list, embedding_size, l2_reg):
    dense_input = list(dense_input_.values())
    if len(dense_input) > 0:
        if embedding_size == "auto":
            if len(dense_input) == 1:
                continuous_embedding_list = dense_input[0]
            else:
                continuous_embedding_list = Concatenate()(dense_input)
            continuous_embedding_list = Reshape(
                [1, len(dense_input)])(continuous_embedding_list)
            embed_list.append(continuous_embedding_list)
        else:
            continuous_embedding_list = list(
                map(Dense(embedding_size, use_bias=False, kernel_regularizer=l2(l2_reg)),
                    dense_input))
            continuous_embedding_list = list(
                map(Reshape((1, embedding_size)), continuous_embedding_list))
            embed_list += continuous_embedding_list

    return embed_list
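
# Illustration: in 'auto' mode the raw dense values are concatenated and
# reshaped to (batch_size, 1, n_dense), forming a single field; otherwise
# each dense scalar is projected by a bias-free Dense(embedding_size) and
# reshaped to (batch_size, 1, embedding_size), one field per dense feature.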


def merge_sequence_input(embedding_dict, embed_list, sequence_input_dict, sequence_len_dict,
                         sequence_max_len_dict, sequence_pooling_dict):
    if len(sequence_input_dict) > 0:
        sequence_embed_dict = get_varlen_embedding_vec_dict(
            embedding_dict, sequence_input_dict)
        sequence_embed_list = get_pooling_vec_list(
            sequence_embed_dict, sequence_len_dict, sequence_max_len_dict, sequence_pooling_dict)
        embed_list += sequence_embed_list

    return embed_list


def get_embedding_vec_list(embedding_dict, input_dict):
    return [embedding_dict[feat](v) for feat, v in input_dict.items()]


def get_varlen_embedding_vec_dict(embedding_dict, input_dict):
    return {feat: embedding_dict[feat](v) for feat, v in input_dict.items()}


def get_pooling_vec_list(sequence_embed_dict, sequence_len_dict, sequence_max_len_dict, sequence_pooling_dict):
    return [SequencePoolingLayer(sequence_max_len_dict[feat], sequence_pooling_dict[feat])(
        [v, sequence_len_dict[feat]]) for feat, v in sequence_embed_dict.items()]
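
# Illustration: each returned tensor has shape (batch_size, 1,
# embedding_size); SequencePoolingLayer pools the maxlen embedded steps
# with the feature's combiner (e.g. 'sum', 'mean' or 'max'), masking
# positions beyond the true sequence length.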


def get_inputs_list(inputs):
    return list(chain(*list(map(lambda x: x.values(), inputs))))


def get_inputs_embedding(feature_dim_dict, embedding_size, l2_reg_embedding, l2_reg_linear,
                         init_std, seed, include_linear=True):
    sparse_input_dict, dense_input_dict = create_input_dict(feature_dim_dict)
    sequence_input_dict, sequence_pooling_dict, sequence_input_len_dict, sequence_max_len_dict = \
        create_sequence_input_dict(feature_dim_dict)

    deep_sparse_emb_dict = create_embedding_dict(
        feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding)
    deep_emb_list = get_embedding_vec_list(
        deep_sparse_emb_dict, sparse_input_dict)
    deep_emb_list = merge_sequence_input(deep_sparse_emb_dict, deep_emb_list, sequence_input_dict,
                                         sequence_input_len_dict, sequence_max_len_dict,
                                         sequence_pooling_dict)
    deep_emb_list = merge_dense_input(
        dense_input_dict, deep_emb_list, embedding_size, l2_reg_embedding)

    if include_linear:
        linear_sparse_emb_dict = create_embedding_dict(
            feature_dim_dict, 1, init_std, seed, l2_reg_linear, 'linear')
        linear_emb_list = get_embedding_vec_list(
            linear_sparse_emb_dict, sparse_input_dict)
        linear_emb_list = merge_sequence_input(linear_sparse_emb_dict, linear_emb_list,
                                               sequence_input_dict, sequence_input_len_dict,
                                               sequence_max_len_dict, sequence_pooling_dict)
        linear_logit = get_linear_logit(
            linear_emb_list, dense_input_dict, l2_reg_linear)
    else:
        linear_logit = None

    inputs_list = get_inputs_list(
        [sparse_input_dict, dense_input_dict, sequence_input_dict, sequence_input_len_dict])
    return deep_emb_list, linear_logit, inputs_list
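
For reference, a minimal sketch of driving this entry point directly, reusing the illustrative SeqFeat holder and feature_dim_dict from the sketch near the top:

from deepctr.input_embedding import get_inputs_embedding

deep_emb_list, linear_logit, inputs_list = get_inputs_embedding(
    feature_dim_dict, embedding_size=8, l2_reg_embedding=1e-5,
    l2_reg_linear=1e-5, init_std=1e-4, seed=1024)
# deep_emb_list: one (batch, 1, 8) tensor per sparse, sequence and dense
# field; linear_logit: the first-order (linear) term; inputs_list: every
# Keras Input, in sparse/dense/sequence/sequence-length order.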
50 changes: 12 additions & 38 deletions deepctr/models/afm.py
@@ -9,13 +9,10 @@
 (https://arxiv.org/abs/1708.04617)
 """

-from tensorflow.python.keras.layers import Dense, Concatenate, Reshape, add
-from tensorflow.python.keras.models import Model
-from tensorflow.python.keras.regularizers import l2
-
-from ..utils import get_input, get_share_embeddings
+import tensorflow as tf
+from ..input_embedding import get_inputs_embedding
 from ..layers import PredictionLayer, AFMLayer, FM
+from ..utils import concat_fun


 def AFM(feature_dim_dict, embedding_size=8, use_attention=True, attention_factor=8,
@@ -48,41 +45,18 @@ def AFM(feature_dim_dict, ...)
         raise ValueError("feature_dim_dict['dense'] must be a list,cur is", type(
             feature_dim_dict['dense']))

-    sparse_input, dense_input = get_input(feature_dim_dict, None)
-    sparse_embedding, linear_embedding, = get_share_embeddings(
-        feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding, l2_reg_linear)
-
-    embed_list = [sparse_embedding[i](sparse_input[i])
-                  for i in range(len(sparse_input))]
-    linear_term = [linear_embedding[i](sparse_input[i])
-                   for i in range(len(sparse_input))]
-    if len(linear_term) > 1:
-        linear_term = add(linear_term)
-    elif len(linear_term) == 1:
-        linear_term = linear_term[0]
-
-    if len(dense_input) > 0:
-        continuous_embedding_list = list(
-            map(Dense(embedding_size, use_bias=False, kernel_regularizer=l2(l2_reg_embedding), ),
-                dense_input))
-        continuous_embedding_list = list(
-            map(Reshape((1, embedding_size)), continuous_embedding_list))
-        embed_list += continuous_embedding_list
-
-        dense_input_ = dense_input[0] if len(
-            dense_input) == 1 else Concatenate()(dense_input)
-        linear_dense_logit = Dense(
-            1, activation=None, use_bias=False, kernel_regularizer=l2(l2_reg_linear))(dense_input_)
-        linear_term = add([linear_dense_logit, linear_term])
+    deep_emb_list, linear_logit, inputs_list = get_inputs_embedding(
+        feature_dim_dict, embedding_size, l2_reg_embedding, l2_reg_linear, init_std, seed)

-    fm_input = Concatenate(axis=1)(embed_list)
+    fm_input = concat_fun(deep_emb_list, axis=1)
     if use_attention:
-        fm_out = AFMLayer(attention_factor, l2_reg_att,
-                          keep_prob, seed)(embed_list)
+        fm_logit = AFMLayer(attention_factor, l2_reg_att,
+                            keep_prob, seed)(deep_emb_list)
     else:
-        fm_out = FM()(fm_input)
+        fm_logit = FM()(fm_input)

-    final_logit = add([linear_term, fm_out])
+    final_logit = tf.keras.layers.add([linear_logit, fm_logit])
     output = PredictionLayer(final_activation)(final_logit)
-    model = Model(inputs=sparse_input + dense_input, outputs=output)
-
+    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)
     return model
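
Behavioral note on the two branches above: with use_attention=True the pairwise interactions go through AFMLayer's attention network, otherwise the model reduces to a plain FM second-order term plus the linear logit. A quick sketch, reusing feature_dim_dict from the first example:

from deepctr.models import AFM

afm_model = AFM(feature_dim_dict, use_attention=True, attention_factor=8)
fm_model = AFM(feature_dim_dict, use_attention=False)  # plain FM + linear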
55 changes: 16 additions & 39 deletions deepctr/models/autoint.py
@@ -9,14 +9,10 @@
"""

from tensorflow.python.keras.layers import Dense, Embedding, Concatenate
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.initializers import RandomNormal
from tensorflow.python.keras.regularizers import l2
import tensorflow as tf

from ..utils import get_input
from ..input_embedding import get_inputs_embedding
from ..layers import PredictionLayer, MLP, InteractingLayer
from ..utils import concat_fun


def AutoInt(feature_dim_dict, embedding_size=8, att_layer_num=3, att_embedding_size=8, att_head_num=2, att_res=True, hidden_size=(256, 256), activation='relu',
@@ -48,56 +44,37 @@ def AutoInt(feature_dim_dict, ...)
         raise ValueError(
             "feature_dim must be a dict like {'sparse':{'field_1':4,'field_2':3,'field_3':2},'dense':['field_5',]}")

-    sparse_input, dense_input = get_input(feature_dim_dict, None,)
-    sparse_embedding = get_embeddings(
-        feature_dim_dict, embedding_size, init_std, seed, l2_reg_embedding)
-    embed_list = [sparse_embedding[i](sparse_input[i])
-                  for i in range(len(sparse_input))]
+    deep_emb_list, _, inputs_list = get_inputs_embedding(
+        feature_dim_dict, embedding_size, l2_reg_embedding, 0, init_std, seed, False)

-    att_input = Concatenate(axis=1)(embed_list) if len(
-        embed_list) > 1 else embed_list[0]
+    att_input = concat_fun(deep_emb_list, axis=1)

-    for i in range(att_layer_num):
+    for _ in range(att_layer_num):
         att_input = InteractingLayer(
             att_embedding_size, att_head_num, att_res)(att_input)
     att_output = tf.keras.layers.Flatten()(att_input)

-    deep_input = tf.keras.layers.Flatten()(Concatenate()(embed_list)
-                                           if len(embed_list) > 1 else embed_list[0])
-    if len(dense_input) > 0:
-        if len(dense_input) == 1:
-            continuous_list = dense_input[0]
-        else:
-            continuous_list = Concatenate()(dense_input)
-
-        deep_input = Concatenate()([deep_input, continuous_list])
+    deep_input = tf.keras.layers.Flatten()(concat_fun(deep_emb_list))

     if len(hidden_size) > 0 and att_layer_num > 0:  # Deep & Interacting Layer
         deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob,
                        use_bn, seed)(deep_input)
-        stack_out = Concatenate()([att_output, deep_out])
-        final_logit = Dense(1, use_bias=False, activation=None)(stack_out)
+        stack_out = tf.keras.layers.Concatenate()([att_output, deep_out])
+        final_logit = tf.keras.layers.Dense(
+            1, use_bias=False, activation=None)(stack_out)
     elif len(hidden_size) > 0:  # Only Deep
         deep_out = MLP(hidden_size, activation, l2_reg_deep, keep_prob,
                        use_bn, seed)(deep_input)
-        final_logit = Dense(1, use_bias=False, activation=None)(deep_out)
+        final_logit = tf.keras.layers.Dense(
+            1, use_bias=False, activation=None)(deep_out)
     elif att_layer_num > 0:  # Only Interacting Layer
-        final_logit = Dense(1, use_bias=False, activation=None)(att_output)
+        final_logit = tf.keras.layers.Dense(
+            1, use_bias=False, activation=None)(att_output)
     else:  # Error
         raise NotImplementedError

     output = PredictionLayer(final_activation)(final_logit)
-    model = Model(inputs=sparse_input + dense_input, outputs=output)
-
-    return model
-
-
-def get_embeddings(feature_dim_dict, embedding_size, init_std, seed, l2_rev_V):
-    sparse_embedding = [Embedding(feature_dim_dict["sparse"][feat], embedding_size,
-                                  embeddings_initializer=RandomNormal(
-                                      mean=0.0, stddev=init_std, seed=seed),
-                                  embeddings_regularizer=l2(l2_rev_V),
-                                  name='sparse_emb_' + str(i) + '-' + feat) for i, feat in
-                        enumerate(feature_dim_dict["sparse"])]
-
-    return sparse_embedding
+    model = tf.keras.models.Model(inputs=inputs_list, outputs=output)
+
+    return model
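
For orientation, the three structural variants the if/elif chain above encodes, again reusing feature_dim_dict from the first example:

from deepctr.models import AutoInt

both = AutoInt(feature_dim_dict, att_layer_num=3, hidden_size=(256, 256))  # Deep & Interacting
deep_only = AutoInt(feature_dim_dict, att_layer_num=0, hidden_size=(256, 256))
att_only = AutoInt(feature_dim_dict, att_layer_num=3, hidden_size=())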
[Diffs for the remaining 29 changed files are not shown.]
