Skip to content

Commit

Permalink
Merge from dev (#15)
Browse files Browse the repository at this point in the history
* Fix bugs when using tensorflow version higher than 1.6.0.
* Now support tf version from 1.4.0 - 1.12.0 except for 1.7.* and 1.8.*
* Update docs
  • Loading branch information
Weichen Shen authored Dec 19, 2018
1 parent c641875 commit adabf33
Show file tree
Hide file tree
Showing 16 changed files with 152 additions and 64 deletions.
25 changes: 18 additions & 7 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
sudo: required
dist: trusty
#sudo: required
#dist: trusty xenial
language: python

python:
Expand All @@ -9,25 +9,34 @@ python:

env:
- TF_VERSION=1.4.0
- TF_VERSION=1.5.0
- TF_VERSION=1.5.1
- TF_VERSION=1.5.0 #- TF_VERSION=1.5.1
- TF_VERSION=1.6.0
#Not Support- TF_VERSION=1.7.0
#Not Support- TF_VERSION=1.7.1
#Not Support- TF_VERSION=1.8.0
- TF_VERSION=1.9.0
- TF_VERSION=1.10.0 #- TF_VERSION=1.10.1
- TF_VERSION=1.11.0
- TF_VERSION=1.12.0

matrix:
allow_failures:
- python: "3.4"
- python: "3.5"
#- env: TF_VERSION=1.5.0 #local is ok
- env: TF_VERSION=1.7.0
- env: TF_VERSION=1.5.0
- env: TF_VERSION=1.7.1
- env: TF_VERSION=1.8.0
fast_finish: true


cache: pip
# command to install dependencies
install:
- pip install -q pytest-cov==2.4.0
#>=2.4.0,<2.6
#>=2.4.0,<2.6
- pip install -q python-coveralls
- pip install -q codacy-coverage
- pip install -q h5py
- pip install -q tensorflow==$TF_VERSION
- pip install -e .
Expand All @@ -43,4 +52,6 @@ notifications:
on_failure: always

after_success:
coveralls
- coveralls
- coverage xml
- python-codacy-coverage -r coverage.xml
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# DeepCTR
![dep1](https://img.shields.io/badge/Tensorflow-1.4/1.5/1.6-blue.svg
)

[![Python Versions](https://img.shields.io/pypi/pyversions/deepctr.svg)](https://pypi.org/project/deepctr)
[![Downloads](https://pepy.tech/badge/deepctr)](https://pepy.tech/project/deepctr)
[![PyPI Version](https://img.shields.io/pypi/v/deepctr.svg)](https://pypi.org/project/deepctr)
[![GitHub Issues](https://img.shields.io/github/issues/shenweichen/deepctr.svg
)](https://github.com/shenweichen/deepctr/issues)
[![License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://github.com/shenweichen/deepctr/blob/master/LICENSE)
[![Activity](https://img.shields.io/github/last-commit/shenweichen/deepctr.svg)](https://github.com/shenweichen/DeepCTR/commits/master)


[![Documentation Status](https://readthedocs.org/projects/deepctr-doc/badge/?version=latest)](https://deepctr-doc.readthedocs.io/)
[![Build Status](https://travis-ci.com/shenweichen/DeepCTR.svg?branch=master)](https://travis-ci.com/shenweichen/DeepCTR)
[![Coverage Status](https://coveralls.io/repos/github/shenweichen/DeepCTR/badge.svg?branch=master)](https://coveralls.io/github/shenweichen/DeepCTR?branch=master)
[![Codacy Badge](https://api.codacy.com/project/badge/Grade/d4099734dc0e4bab91d332ead8c0bdd0)](https://www.codacy.com/app/wcshen1994/DeepCTR?utm_source=github.com&amp;utm_medium=referral&amp;utm_content=shenweichen/DeepCTR&amp;utm_campaign=Badge_Grade)
[![License](https://img.shields.io/github/license/shenweichen/deepctr.svg)](https://github.com/shenweichen/deepctr/blob/master/LICENSE)

DeepCTR is a **Easy-to-use**,**Modular** and **Extendible** package of deep-learning based CTR models along with lots of core components layer which can be used to build your own custom model easily.You can use any complex model with `model.fit()`and`model.predict()` just like any other keras model.And the layers are compatible with tensorflow.

DeepCTR is a **Easy-to-use**,**Modular** and **Extendible** package of deep-learning based CTR models along with lots of core components layer which can be used to build your own custom model easily.You can use any complex model with `model.fit()`and`model.predict()` just like any other keras model.And the layers are compatible with tensorflow.Through `pip install deepctr` get the package and [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Start.html)
Through `pip install deepctr` get the package and [**Get Started!**](https://deepctr-doc.readthedocs.io/en/latest/Quick-Start.html)


## Models List
Expand Down
5 changes: 4 additions & 1 deletion deepctr/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Dice(Layer):
- **epsilon** : Small float added to variance to avoid dividing by zero.
References
- [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
"""

def __init__(self, axis=-1, epsilon=1e-9, **kwargs):
Expand All @@ -43,3 +43,6 @@ def get_config(self,):
config = {'axis': self.axis, 'epsilon': self.epsilon}
base_config = super(Dice, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def compute_output_shape(self, input_shape):
return input_shape
48 changes: 28 additions & 20 deletions deepctr/layers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
from tensorflow.python.keras.layers import Layer,Activation,BatchNormalization
from tensorflow.python.keras.regularizers import l2
from tensorflow.python.keras.initializers import Zeros,glorot_normal,glorot_uniform
Expand Down Expand Up @@ -35,10 +36,10 @@ def call(self, inputs,**kwargs):

concated_embeds_value = inputs

square_of_sum = K.square(K.sum(concated_embeds_value, axis=1, keepdims=True))
sum_of_square = K.sum(concated_embeds_value * concated_embeds_value, axis=1, keepdims=True)
square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keep_dims=True))
sum_of_square = tf.reduce_sum(concated_embeds_value * concated_embeds_value, axis=1, keep_dims=True)
cross_term = square_of_sum - sum_of_square
cross_term = 0.5 * K.sum(cross_term, axis=2, keepdims=False)
cross_term = 0.5 * tf.reduce_sum(cross_term, axis=2, keep_dims=False)

return cross_term

Expand All @@ -56,7 +57,7 @@ class AFMLayer(Layer):
- 2D tensor with shape: ``(batch_size, 1)``.
Arguments
- **attention_factor** : Positive integer, dimensionality of the attention network output space.
- **l2_reg_w** : float between 0 and 1. L2 regularizer strength applied to attention network.
Expand Down Expand Up @@ -99,7 +100,7 @@ def build(self, input_shape):



embedding_size = input_shape[0][-1]
embedding_size = input_shape[0][-1].value

self.attention_W = self.add_weight(shape=(embedding_size, self.attention_factor), initializer=glorot_normal(seed=self.seed),regularizer=l2(self.l2_reg_w),
name="attention_W")
Expand All @@ -119,13 +120,18 @@ def call(self, inputs,**kwargs):
embeds_vec_list = inputs
row = []
col = []
num_inputs = len(embeds_vec_list)
for i in range(num_inputs - 1):
for j in range(i + 1, num_inputs):
row.append(i)
col.append(j)
p = tf.concat([embeds_vec_list[idx] for idx in row],axis=1)
q = tf.concat([embeds_vec_list[idx] for idx in col],axis=1)
# num_inputs = len(embeds_vec_list)
# for i in range(num_inputs - 1):
# for j in range(i + 1, num_inputs):
# row.append(i)
# col.append(j)
for r, c in itertools.combinations(embeds_vec_list, 2):
row.append(r)
col.append(c)
#p = tf.concat([embeds_vec_list[idx] for idx in row],axis=1)
#q = tf.concat([embeds_vec_list[idx] for idx in col], axis=1)
p = tf.concat(row,axis=1)
q = tf.concat(col,axis=1)
inner_product = p * q

bi_interaction = inner_product
Expand Down Expand Up @@ -155,7 +161,6 @@ def get_config(self,):

class PredictionLayer(Layer):


def __init__(self, activation='sigmoid',use_bias=True, **kwargs):
self.activation = activation
self.use_bias = use_bias
Expand Down Expand Up @@ -208,7 +213,7 @@ class CrossNet(Layer):
- **seed**: A Python integer to use as random seed.
References
- [Deep & Cross Network for Ad Click Predictions](https://arxiv.org/abs/1708.05123)
- [Wang R, Fu B, Fu G, et al. Deep & cross network for ad click predictions[C]//Proceedings of the ADKDD'17. ACM, 2017: 12.](https://arxiv.org/abs/1708.05123)
"""
def __init__(self, layer_num=2,l2_reg=0,seed=1024, **kwargs):
self.layer_num = layer_num
Expand All @@ -221,7 +226,7 @@ def build(self, input_shape):
if len(input_shape) != 2:
raise ValueError("Unexpected inputs dimensions %d, expect to be 2 dimensions" % (len(input_shape),))

dim = input_shape[-1]
dim = input_shape[-1].value
self.kernels = [self.add_weight(name='kernel'+str(i),
shape=(dim, 1),
initializer=glorot_normal(seed=self.seed),
Expand Down Expand Up @@ -252,6 +257,9 @@ def get_config(self,):
base_config = super(CrossNet, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def compute_output_shape(self, input_shape):
return input_shape


class MLP(Layer):
"""The Multi Layer Percetron
Expand Down Expand Up @@ -345,7 +353,7 @@ class BiInteractionPooling(Layer):
- 3D tensor with shape: ``(batch_size,1,embedding_size)``.
References
- [Neural Factorization Machines for Sparse Predictive Analytics](http://arxiv.org/abs/1708.05027)
- [He X, Chua T S. Neural factorization machines for sparse predictive analytics[C]//Proceedings of the 40th International ACM SIGIR conference on Research and Development in Information Retrieval. ACM, 2017: 355-364.](http://arxiv.org/abs/1708.05027)
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -389,7 +397,7 @@ class OutterProductLayer(Layer):
- **seed**: A Python integer to use as random seed.
References
- [Product-based Neural Networks for User Response Prediction](https://arxiv.org/pdf/1611.00144.pdf)
- [Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response prediction[C]//Data Mining (ICDM), 2016 IEEE 16th International Conference on. IEEE, 2016: 1149-1154.](https://arxiv.org/pdf/1611.00144.pdf)
"""

def __init__(self, kernel_type='mat', seed=1024, **kwargs):
Expand Down Expand Up @@ -423,7 +431,7 @@ def build(self, input_shape):
num_inputs = len(input_shape)
num_pairs = int(num_inputs * (num_inputs - 1) / 2)
input_shape = input_shape[0]
embed_size = input_shape[-1]
embed_size = input_shape[-1].value
if self.kernel_type == 'mat':

self.kernel = self.add_weight(shape=(embed_size,num_pairs,embed_size), initializer=glorot_uniform(seed=self.seed),
Expand Down Expand Up @@ -520,7 +528,7 @@ class InnerProductLayer(Layer):
- **reduce_sum**: bool. Whether return inner product or element-wise product
References
- [Product-based Neural Networks for User Response Prediction](https://arxiv.org/pdf/1611.00144.pdf)
- [Qu Y, Cai H, Ren K, et al. Product-based neural networks for user response prediction[C]//Data Mining (ICDM), 2016 IEEE 16th International Conference on. IEEE, 2016: 1149-1154.](https://arxiv.org/pdf/1611.00144.pdf)
"""

def __init__(self,reduce_sum=True,**kwargs):
Expand Down Expand Up @@ -612,7 +620,7 @@ class LocalActivationUnit(Layer):
- **seed**: A Python integer to use as random seed.
References
- [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
"""

def __init__(self,hidden_size=(64,32), activation='sigmoid',l2_reg=0, keep_prob=1, use_bn=False,seed=1024,**kwargs):
Expand Down
2 changes: 1 addition & 1 deletion deepctr/models/afm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Weichen Shen,[email protected]
Reference:
[1] Attentional Factorization Machines: Learning the Weight of Feature Interactions via Attention Networks
[1] Xiao J, Ye H, He X, et al. Attentional factorization machines: Learning the weight of feature interactions via attention networks[J]. arXiv preprint arXiv:1708.04617, 2017.
(https://arxiv.org/abs/1708.04617)
"""
Expand Down
8 changes: 4 additions & 4 deletions deepctr/models/mlr.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,17 @@ def MLR(region_feature_dim_dict,base_feature_dim_dict={"sparse":{},"dense":[]},r
add([region_embeddings[j][i](region_sparse_input[i]) for i in range(region_sparse_feature_num)])
for j in range(region_num)]
region_logits =Concatenate()([add([region_sparse_logits[i],region_dense_logits_[i]]) for i in range(region_num)])

if base_dense_feature_num > 0 and base_sparse_feature_num == 0:
base_logits = base_dense_logits
elif base_dense_feature_num == 0 and base_sparse_feature_num > 0:
base_sparse_logits = [add(
[base_embeddings[j][i](base_sparse_input_[i]) for i in range(base_sparse_feature_num)]) if base_sparse_feature_num > 1 else base_embeddings[j][0](base_sparse_input_[0])
[base_embeddings[j][i](base_sparse_input_[i]) for i in range(base_sparse_feature_num)]) if base_sparse_feature_num > 1 else base_embeddings[j][0](base_sparse_input_[0])
for j in range(region_num)]
base_logits = base_sparse_logits
else:
base_sparse_logits = [add(
[base_embeddings[j][i](base_sparse_input_[i]) for i in range(base_sparse_feature_num)]) if base_sparse_feature_num > 1 else base_embeddings[j][0](base_sparse_input_[0])
[base_embeddings[j][i](base_sparse_input_[i]) for i in range(base_sparse_feature_num)]) if base_sparse_feature_num > 1 else base_embeddings[j][0](base_sparse_input_[0])
for j in range(region_num)]
base_logits = [add([base_sparse_logits[i], base_dense_logits[i]]) for i in range(region_num)]

Expand All @@ -128,7 +128,7 @@ def MLR(region_feature_dim_dict,base_feature_dim_dict={"sparse":{},"dense":[]},r
bias_cate_logits = bias_embedding[0](bias_sparse_input[0])
else:
pass

if bias_dense_feature_num >0 and bias_sparse_feature_num > 0:
bias_logits = add([bias_dense_logits, bias_cate_logits])
elif bias_dense_feature_num > 0:
Expand Down
2 changes: 1 addition & 1 deletion deepctr/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class AttentionSequencePoolingLayer(Layer):
- **weight_normalization**: bool.Whether normalize the attention score of local activation unit.
References
- [Deep Interest Network for Click-Through Rate Prediction](https://arxiv.org/pdf/1706.06978.pdf)
- [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
"""

def __init__(self, hidden_size=(80, 40), activation='sigmoid', weight_normalization=False, **kwargs):
Expand Down
26 changes: 25 additions & 1 deletion docs/source/FAQ.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,31 @@ To save/load models,just a little different.
from deepctr.utils import custom_objects
model = load_model('DeepFM.h5',custom_objects)# load_model,just add a parameter
2. Does the models support multi-value input?
2. How can I get the attentional weights of feature interactions in AFM?

First,make sure that you have install the latest version of deepctr.

Then,use the following code,the ``attentional_weights[:,i,0]`` is the ``feature_interactions[i]``'s attentional weight of all samples.

.. code-block:: python
import itertools
import deepctr
from tensorflow.python.keras.models import Model
from tensorflow.python.keras.layers import Lambda
feature_dim_dict = {"sparse": sparse_feature_dict, "dense": dense_feature_list}
model = deepctr.models.AFM(feature_dim_dict)
model.fit(model_input,target)
afmlayer = model.layers[-3]
afm_weight_model = Model(model.input,outputs=Lambda(lambda x:afmlayer.normalized_att_score)(model.input))
attentional_weights = afm_weight_model.predict(model_input,batch_size=4096)
feature_interactions = list(itertools.combinations(list(feature_dim_dict['sparse'].keys()) + feature_dim_dict['dense'] ,2))
3. Does the models support multi-value input?

Now only the `DIN <Features.html#din-deep-interest-network>`_ model support multi-value input,you can use layers in `sequence <deepctr.sequence.html>`_ to build your own models!
And I will add the feature soon~
Loading

0 comments on commit adabf33

Please sign in to comment.