diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index a023e0bfa..41a9f4e98 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -1,4 +1,3 @@
-
 version: 2
 
 build:
diff --git a/docs/source/component/backbone.md b/docs/source/component/backbone.md
index 5e05ec589..de77f85ec 100644
--- a/docs/source/component/backbone.md
+++ b/docs/source/component/backbone.md
@@ -131,6 +131,7 @@ MovieLens-1M数据集效果对比:
 - 还有一些特殊的`block`关联了一个特殊的模块,包括`lambda layer`、`sequential layers`、`repeated layer`和`recurrent layer`。这些特殊layer分别实现了自定义表达式、顺序执行多个layer、重复执行某个layer、循环执行某个layer的功能。
 - DAG的输出节点名由`concat_blocks`配置项指定,配置了多个输出节点时自动执行tensor的concat操作。
 - 如果不配置`concat_blocks`,框架会自动拼接DAG的所有叶子节点并输出。
+- 如果多个`block`的输出不需要concat在一起,而是作为一个list类型输出(下游对接多目标学习的tower),可以用`output_blocks`代替`concat_blocks`。
 - 可以为主干网络配置一个可选的`MLP`模块。
 
 ![](../../images/component/wide_deep.png)
@@ -1275,6 +1276,8 @@ message InputLayer {
   optional bool only_output_3d_tensor = 6;
   optional bool output_2d_tensor_and_feature_list = 7;
   optional bool output_seq_and_normal_feature = 8;
+  optional uint32 wide_output_dim = 9;
+  optional bool concat_seq_feature = 10 [default = true];
 }
 ```
@@ -1288,6 +1291,8 @@ message InputLayer {
 - `only_output_3d_tensor` 输出`feature group`对应的一个3d tensor,在`embedding_dim`相同时可配置该项
 - `output_2d_tensor_and_feature_list` 是否同时输出2d tensor与特征list
 - `output_seq_and_normal_feature` 是否输出(sequence特征, 常规特征)元组
+- `wide_output_dim` wide模型每个特征的参数权重维度,一般设定为1
+- `concat_seq_feature` 是否需要把序列特征的embedding拼接在一起
 
 ## 3. Lambda组件块
@@ -1437,6 +1442,12 @@ blocks {
 }
 ```
 
+## 8. 输出组件
+
+- 使用`concat_blocks`或者`output_blocks`配置主干网络的输出。
+- 两者的区别是:前者会把多个输出组件块的结果按照最后一个axis拼接在一起;后者不会拼接,而是以list类型输出,配置示例见下文。
+- 如果不配置上述两个选项,框架会自动拼接DAG的所有叶子节点并输出。
+
 ## 通过`组件包`实现参数共享的子网络
 
 `组件包`封装了由多个`组件块`搭建的一个子网络DAG,作为整体可以被以参数共享的方式多次调用,通常用在 *自监督学习* 模型中。
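A minimal sketch of the new `output_blocks` option documented above. The block names and MLP sizes are illustrative; the `blocks`/`keras_layer` syntax follows the sample config added later in this diff:

```protobuf
backbone {
  blocks {
    name: 'tower_a'
    inputs { feature_group_name: 'all' }
    keras_layer {
      class_name: 'MLP'
      mlp { hidden_units: [128, 64] }
    }
  }
  blocks {
    name: 'tower_b'
    inputs { feature_group_name: 'all' }
    keras_layer {
      class_name: 'MLP'
      mlp { hidden_units: [128, 64] }
    }
  }
  # outputs are returned as the list [tower_a, tower_b] instead of being concatenated
  output_blocks: ['tower_a', 'tower_b']
}
```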
diff --git a/docs/source/component/component.md b/docs/source/component/component.md
index 49a18662a..8ef90b79e 100644
--- a/docs/source/component/component.md
+++ b/docs/source/component/component.md
@@ -4,10 +4,10 @@
 
 | 类名 | 功能 | 说明 | 示例 |
 | ----------------- | ------ | ------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
-| MLP | 多层感知机 | 可定制激活函数、initializer、Dropout、BN等 | [案例1](backbone.md#wide-deep) |
+| MLP | 多层感知机 | 可定制激活函数、initializer、Dropout、BN等 | [案例1](backbone.html#wide-deep) |
 | Highway | 类似残差链接 | 可用来对预训练embedding做增量微调 | [highway network](../models/highway.html) |
 | Gate | 门控 | 多个输入的加权求和 | [Cross Decoupling Network](../models/cdn.html#id2) |
-| PeriodicEmbedding | 周期激活函数 | 数值特征Embedding | [案例5](backbone.md#dlrm-embedding) |
+| PeriodicEmbedding | 周期激活函数 | 数值特征Embedding | [案例5](backbone.html#dlrm-embedding) |
 | AutoDisEmbedding | 自动离散化 | 数值特征Embedding | [dlrm_on_criteo_with_autodis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_autodis.config) |
 | NaryDisEmbedding | N进制编码 | 数值特征Embedding | [dlrm_on_criteo_with_narydis.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/dlrm_on_criteo_with_narydis.config) |
 | TextCNN | 文本卷积 | 提取文本序列的特征 | [text_cnn_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/text_cnn_on_movielens.config) |
@@ -18,9 +18,9 @@
 
 | 类名 | 功能 | 说明 | 示例 |
 | -------------- | ---------------- | ------------ | -------------------------------------------------------------------------------------------------------------------------- |
-| FM | 二阶交叉 | DeepFM模型的组件 | [案例2](backbone.md#deepfm) |
-| DotInteraction | 二阶内积交叉 | DLRM模型的组件 | [案例4](backbone.md#dlrm) |
-| Cross | bit-wise交叉 | DCN v2模型的组件 | [案例3](backbone.md#dcn) |
+| FM | 二阶交叉 | DeepFM模型的组件 | [案例2](backbone.html#deepfm) |
+| DotInteraction | 二阶内积交叉 | DLRM模型的组件 | [案例4](backbone.html#dlrm) |
+| Cross | bit-wise交叉 | DCN v2模型的组件 | [案例3](backbone.html#dcn) |
 | BiLinear | 双线性 | FiBiNet模型的组件 | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |
 | FiBiNet | SENet & BiLinear | FiBiNet模型 | [fibinet_on_movielens.config](https://github.com/alibaba/EasyRec/tree/master/examples/configs/fibinet_on_movielens.config) |
@@ -50,14 +50,14 @@
 
 | 类名 | 功能 | 说明 | 示例 |
 | --------- | --------------------------- | --------- | ----------------------------- |
-| MMoE | Multiple Mixture of Experts | MMoE模型的组件 | [案例8](backbone.md#mmoe) |
+| MMoE | Multiple Mixture of Experts | MMoE模型的组件 | [案例8](backbone.html#mmoe) |
 | AITMTower | AITM模型的一个tower | AITM模型的组件 | [AITM](../models/aitm.md#id2) |
 
 ## 6. 辅助损失函数组件
 
-| 类名 | 功能 | 说明 | 示例 |
-| ------------- | ---------- | --------- | ---------------------- |
-| AuxiliaryLoss | 用来计算辅助损失函数 | 常用在自监督学习中 | [案例7](backbone.md#id7) |
+| 类名 | 功能 | 说明 | 示例 |
+| ------------- | ---------- | --------- | ------------------------ |
+| AuxiliaryLoss | 用来计算辅助损失函数 | 常用在自监督学习中 | [案例7](backbone.html#id7) |
 
 # 组件详细参数
@@ -138,6 +138,31 @@
 
 ## 2.特征交叉组件
 
+- FM
+
+| 参数 | 类型 | 默认值 | 说明 |
+| ----------- | ---- | ----- | -------------------------- |
+| use_variant | bool | false | 是否使用FM的变体:所有二阶交叉项直接输出,而不求和 |
+
+- DotInteraction
+
+| 参数 | 类型 | 默认值 | 说明 |
+| ---------------- | ---- | ----- | ------------------------------------ |
+| self_interaction | bool | false | 是否允许特征自己与自己交叉 |
+| skip_gather | bool | false | 一个优化开关,设置为true,可以提高运行速度,但需要占用更多的内存空间 |
+
+- Cross
+
+| 参数 | 类型 | 默认值 | 说明 |
+| ------------------ | ------ | ---------------- | --------------------------------------------------------------------------------------------------------------------------- |
+| projection_dim | uint32 | None | 使用矩阵分解降低计算开销,把大的权重矩阵分解为两个小的矩阵相乘,projection_dim是第一个小矩阵的列数,也是第二个小矩阵的行数 |
+| diag_scale | float | 0 | used to increase the diagonal of the kernel W by `diag_scale`, that is, W + diag_scale * I, where I is an identity matrix |
+| use_bias | bool | true | whether to add a bias term for this layer. |
+| kernel_initializer | string | truncated_normal | Initializer to use on the kernel matrix |
+| bias_initializer | string | zeros | Initializer to use on the bias vector |
+| kernel_regularizer | string | None | Regularizer to use on the kernel matrix |
+| bias_regularizer | string | None | Regularizer to use on the bias vector |
+
 - Bilinear
 
 | 参数 | 类型 | 默认值 | 说明 |
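A sketch of wiring the `Cross` component into a backbone block. Whether `Cross` takes the generic `st_params` shown here or a typed parameter message is an assumption; consult the component's proto definition for the exact field names:

```protobuf
blocks {
  name: 'dcn_cross'
  inputs { feature_group_name: 'all' }
  keras_layer {
    class_name: 'Cross'
    st_params {
      # low-rank factorization of the kernel matrix (assumed parameter name)
      fields { key: 'projection_dim' value { number_value: 64 } }
    }
  }
}
```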
diff --git a/docs/source/component/custom_loss.md b/docs/source/component/custom_loss.md
new file mode 100644
index 000000000..5e2c2a1dc
--- /dev/null
+++ b/docs/source/component/custom_loss.md
@@ -0,0 +1,36 @@
+# 自定义辅助损失函数组件
+
+可以使用如下方法添加多个辅助损失函数。
+
+在`easy_rec/python/layers/keras/auxiliary_loss.py`里添加一个新的loss函数。
+如果计算逻辑比较复杂,建议在一个单独的python文件中实现,然后在`auxiliary_loss.py`里import并使用。
+
+注意:用来标记损失函数类型的`loss_type`参数需要全局唯一。
+
+## 配置方法
+
+```protobuf
+blocks {
+  name: 'custom_loss'
+  inputs {
+    block_name: 'pred'
+  }
+  inputs {
+    block_name: 'logit'
+  }
+  merge_inputs_into_list: true
+  keras_layer {
+    class_name: 'AuxiliaryLoss'
+    st_params {
+      fields {
+        key: "loss_type"
+        value { string_value: "my_custom_loss" }
+      }
+    }
+  }
+}
+```
+
+st_params 参数列表下可以追加自定义参数。
+
+记得使用`concat_blocks`或者`output_blocks`配置输出的block列表(不包括当前`custom_loss`节点)。
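A hypothetical sketch of what the new branch in `auxiliary_loss.py` might look like. Only the dispatch-on-`loss_type` pattern comes from the doc above; the surrounding `AuxiliaryLoss` class structure and all names here are illustrative assumptions:

```python
import tensorflow as tf


def my_custom_loss(pred, logit):
  """Illustrative auxiliary loss: distillation-style MSE between two blocks."""
  return tf.reduce_mean(tf.square(pred - logit))

# Inside AuxiliaryLoss (easy_rec/python/layers/keras/auxiliary_loss.py),
# assuming it dispatches on the globally unique `loss_type` string:
#
#   if self.loss_type == 'my_custom_loss':
#     return my_custom_loss(inputs[0], inputs[1])
```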
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 7eeebba67..10ed89920 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -31,6 +31,7 @@ Welcome to easy_rec's documentation!
    component/backbone
    component/component
    component/sequence
+   component/custom_loss
    component/custom_op
 
 .. toctree::
diff --git a/docs/source/models/loss.md b/docs/source/models/loss.md
index d0c028d5d..e098aa0a6 100644
--- a/docs/source/models/loss.md
+++ b/docs/source/models/loss.md
@@ -25,7 +25,12 @@ EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2
 | ORDER_CALIBRATE_LOSS | 使用目标依赖关系校正预测结果的辅助损失函数,详见[AITM](aitm.md)模型 |
 | LISTWISE_RANK_LOSS | listwise的排序损失 |
 | LISTWISE_DISTILL_LOSS | 用来蒸馏给定list排序的损失函数,与listwise rank loss 比较类似 |
+| ZILN_LOSS | LTV预测任务的损失函数(num_class必须设置为3) |
 
+- ZILN_LOSS:使用时模型有3个可选的输出(在多目标任务中,输出名有一个目标相关的后缀)
+  - probs: 预估的转化概率
+  - y: 预估的LTV值
+  - logits: Shape为`[batch_size, 3]`的tensor,第一列经sigmoid变换后即为`probs`,第二列和第三列是学习到的LogNormal分布的均值与标准差
 - 说明:SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
   - 支持参数配置,升级为 [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317) ,
   - 目前只在DropoutNet模型中可用,可参考《 [冷启动推荐模型DropoutNet深度解析与改进](https://zhuanlan.zhihu.com/p/475117993) 》。
@@ -184,3 +189,4 @@
 - [Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning](https://arxiv.org/abs/2111.10603)
 - [AITM: Modeling the Sequential Dependence among Audience Multi-step Conversions with Multi-task Learning in Targeted Display Advertising](https://arxiv.org/pdf/2105.08489.pdf)
 - [Pairwise Ranking Distillation for Deep Face Recognition](https://ceur-ws.org/Vol-2744/paper30.pdf)
+- [A Deep Probabilistic Model for Customer Lifetime Value Prediction](https://arxiv.org/pdf/1912.07753)
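For reference, the relevant fragment of the sample config added by this PR (`samples/model_config/mlp_on_taobao_with_ziln_loss.config`), showing the required `num_class: 3` alongside the loss:

```protobuf
model_config: {
  ...
  num_class: 3
  losses {
    loss_type: ZILN_LOSS
    weight: 1.0
    loss_name: 'LTV'
  }
}
```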
+ """ + logits = tf.convert_to_tensor(logits, dtype=tf.float32) + positive_probs = tf.keras.backend.sigmoid(logits[..., :1]) + loc = logits[..., 1:2] + scale = tf.keras.backend.softplus(logits[..., 2:]) + preds = ( + positive_probs * + tf.keras.backend.exp(loc + 0.5 * tf.keras.backend.square(scale))) + return positive_probs, preds + + +def zero_inflated_lognormal_loss(labels, logits, name=''): + """Computes the zero inflated lognormal loss. + + Usage with tf.keras API: + + ```python + model = tf.keras.Model(inputs, outputs) + model.compile('sgd', loss=zero_inflated_lognormal) + ``` + + Arguments: + labels: True targets, tensor of shape [batch_size, 1]. + logits: Logits of output layer, tensor of shape [batch_size, 3]. + name: the name of loss + + Returns: + Zero inflated lognormal loss value. + """ + loss_name = name if name else 'ziln_loss' + labels = tf.cast(labels, dtype=tf.float32) + if labels.shape.ndims == 1: + labels = tf.expand_dims(labels, 1) # [B, 1] + positive = tf.cast(labels > 0, tf.float32) + + logits = tf.convert_to_tensor(logits, dtype=tf.float32) + logits.shape.assert_is_compatible_with( + tf.TensorShape(labels.shape[:-1].as_list() + [3])) + + positive_logits = logits[..., :1] + classification_loss = tf.keras.backend.binary_crossentropy( + positive, positive_logits, from_logits=True) + classification_loss = tf.keras.backend.mean(classification_loss) + tf.summary.scalar('loss/%s_classify' % loss_name, classification_loss) + + loc = logits[..., 1:2] + scale = tf.math.maximum( + tf.keras.backend.softplus(logits[..., 2:]), + tf.math.sqrt(tf.keras.backend.epsilon())) + safe_labels = positive * labels + ( + 1 - positive) * tf.keras.backend.ones_like(labels) + regression_loss = -tf.keras.backend.mean( + positive * tfd.LogNormal(loc=loc, scale=scale).log_prob(safe_labels)) + tf.summary.scalar('loss/%s_regression' % loss_name, regression_loss) + return classification_loss + regression_loss diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py index a144b999a..640f52502 100644 --- a/easy_rec/python/model/rank_model.py +++ b/easy_rec/python/model/rank_model.py @@ -9,6 +9,8 @@ from easy_rec.python.model.easy_rec_model import EasyRecModel from easy_rec.python.protos.loss_pb2 import LossType +from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_pred # NOQA + if tf.__version__ >= '2.0': tf = tf.compat.v1 @@ -78,6 +80,14 @@ def _output_to_prediction_impl(self, prediction_dict['logits' + suffix] = output prediction_dict['pos_logits' + suffix] = output[:, 1] prediction_dict['probs' + suffix] = probs[:, 1] + elif loss_type == LossType.ZILN_LOSS: + assert num_class == 3, 'num_class must be 3 when loss type is ZILN_LOSS' + probs, preds = zero_inflated_lognormal_pred(output) + tf.summary.scalar('prediction/probs', tf.reduce_mean(probs)) + tf.summary.scalar('prediction/y', tf.reduce_mean(preds)) + prediction_dict['logits' + suffix] = output + prediction_dict['probs' + suffix] = probs + prediction_dict['y' + suffix] = preds elif loss_type == LossType.CLASSIFICATION: if num_class == 1: output = tf.squeeze(output, axis=1) @@ -148,7 +158,7 @@ def build_rtp_output_dict(self): 'failed to build RTP rank_predict output: classification model ' + "expect 'probs' prediction, which is not found. 
Please check if" + ' build_predict_graph() is called.') - elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}: + elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}: if 'y' in self._prediction_dict: forwarded = self._prediction_dict['y'] else: @@ -181,7 +191,7 @@ def _build_loss_impl(self, LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS, LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS, - LossType.LISTWISE_DISTILL_LOSS + LossType.LISTWISE_DISTILL_LOSS, LossType.ZILN_LOSS } if loss_type in { LossType.CLASSIFICATION, LossType.BINARY_CROSS_ENTROPY_LOSS @@ -288,12 +298,12 @@ def _build_metric_impl(self, LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS, - LossType.LISTWISE_RANK_LOSS + LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS } metric_dict = {} if metric.WhichOneof('metric') == 'auc': assert loss_type & binary_loss_set - if num_class == 1 or loss_type & {LossType.JRC_LOSS}: + if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}: label = tf.to_int64(self._labels[label_name]) metric_dict['auc' + suffix] = metrics_tf.auc( label, @@ -309,7 +319,7 @@ def _build_metric_impl(self, raise ValueError('Wrong class number') elif metric.WhichOneof('metric') == 'gauc': assert loss_type & binary_loss_set - if num_class == 1 or loss_type & {LossType.JRC_LOSS}: + if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}: label = tf.to_int64(self._labels[label_name]) uids = self._feature_dict[metric.gauc.uid_field] if isinstance(uids, tf.sparse.SparseTensor): @@ -332,7 +342,7 @@ def _build_metric_impl(self, raise ValueError('Wrong class number') elif metric.WhichOneof('metric') == 'session_auc': assert loss_type & binary_loss_set - if num_class == 1 or loss_type & {LossType.JRC_LOSS}: + if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}: label = tf.to_int64(self._labels[label_name]) metric_dict['session_auc' + suffix] = metrics_lib.session_auc( label, @@ -350,7 +360,7 @@ def _build_metric_impl(self, raise ValueError('Wrong class number') elif metric.WhichOneof('metric') == 'max_f1': assert loss_type & binary_loss_set - if num_class == 1 or loss_type & {LossType.JRC_LOSS}: + if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}: label = tf.to_int64(self._labels[label_name]) metric_dict['max_f1' + suffix] = metrics_lib.max_f1( label, self._prediction_dict['logits' + suffix]) @@ -369,7 +379,7 @@ def _build_metric_impl(self, metric.recall_at_topk.topk) elif metric.WhichOneof('metric') == 'mean_absolute_error': label = tf.to_float(self._labels[label_name]) - if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}: + if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}: metric_dict['mean_absolute_error' + suffix] = metrics_tf.mean_absolute_error( label, self._prediction_dict['y' + suffix]) @@ -381,7 +391,7 @@ def _build_metric_impl(self, assert False, 'mean_absolute_error is not supported for this model' elif metric.WhichOneof('metric') == 'mean_squared_error': label = tf.to_float(self._labels[label_name]) - if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}: + if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}: metric_dict['mean_squared_error' + suffix] = metrics_tf.mean_squared_error( label, self._prediction_dict['y' + suffix]) @@ -393,7 
diff --git a/easy_rec/python/model/rank_model.py b/easy_rec/python/model/rank_model.py
index a144b999a..640f52502 100644
--- a/easy_rec/python/model/rank_model.py
+++ b/easy_rec/python/model/rank_model.py
@@ -9,6 +9,8 @@
 from easy_rec.python.model.easy_rec_model import EasyRecModel
 from easy_rec.python.protos.loss_pb2 import LossType
 
+from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_pred  # NOQA
+
 if tf.__version__ >= '2.0':
   tf = tf.compat.v1
@@ -78,6 +80,14 @@ def _output_to_prediction_impl(self,
       prediction_dict['logits' + suffix] = output
       prediction_dict['pos_logits' + suffix] = output[:, 1]
       prediction_dict['probs' + suffix] = probs[:, 1]
+    elif loss_type == LossType.ZILN_LOSS:
+      assert num_class == 3, 'num_class must be 3 when loss type is ZILN_LOSS'
+      probs, preds = zero_inflated_lognormal_pred(output)
+      tf.summary.scalar('prediction/probs', tf.reduce_mean(probs))
+      tf.summary.scalar('prediction/y', tf.reduce_mean(preds))
+      prediction_dict['logits' + suffix] = output
+      prediction_dict['probs' + suffix] = probs
+      prediction_dict['y' + suffix] = preds
     elif loss_type == LossType.CLASSIFICATION:
       if num_class == 1:
         output = tf.squeeze(output, axis=1)
@@ -148,7 +158,7 @@ def build_rtp_output_dict(self):
           'failed to build RTP rank_predict output: classification model ' +
           "expect 'probs' prediction, which is not found. Please check if" +
           ' build_predict_graph() is called.')
-    elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+    elif loss_types & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
       if 'y' in self._prediction_dict:
         forwarded = self._prediction_dict['y']
       else:
@@ -181,7 +191,7 @@ def _build_loss_impl(self,
         LossType.BINARY_FOCAL_LOSS, LossType.PAIRWISE_FOCAL_LOSS,
         LossType.LISTWISE_RANK_LOSS, LossType.PAIRWISE_HINGE_LOSS,
         LossType.PAIRWISE_LOGISTIC_LOSS, LossType.JRC_LOSS,
-        LossType.LISTWISE_DISTILL_LOSS
+        LossType.LISTWISE_DISTILL_LOSS, LossType.ZILN_LOSS
     }
     if loss_type in {
         LossType.CLASSIFICATION, LossType.BINARY_CROSS_ENTROPY_LOSS
@@ -288,12 +298,12 @@ def _build_metric_impl(self,
         LossType.PAIR_WISE_LOSS, LossType.BINARY_FOCAL_LOSS,
         LossType.PAIRWISE_FOCAL_LOSS, LossType.PAIRWISE_LOGISTIC_LOSS,
         LossType.JRC_LOSS, LossType.LISTWISE_DISTILL_LOSS,
-        LossType.LISTWISE_RANK_LOSS
+        LossType.LISTWISE_RANK_LOSS, LossType.ZILN_LOSS
     }
     metric_dict = {}
     if metric.WhichOneof('metric') == 'auc':
       assert loss_type & binary_loss_set
-      if num_class == 1 or loss_type & {LossType.JRC_LOSS}:
+      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
         label = tf.to_int64(self._labels[label_name])
         metric_dict['auc' + suffix] = metrics_tf.auc(
             label,
@@ -309,7 +319,7 @@ def _build_metric_impl(self,
         raise ValueError('Wrong class number')
     elif metric.WhichOneof('metric') == 'gauc':
       assert loss_type & binary_loss_set
-      if num_class == 1 or loss_type & {LossType.JRC_LOSS}:
+      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
         label = tf.to_int64(self._labels[label_name])
         uids = self._feature_dict[metric.gauc.uid_field]
         if isinstance(uids, tf.sparse.SparseTensor):
@@ -332,7 +342,7 @@ def _build_metric_impl(self,
         raise ValueError('Wrong class number')
     elif metric.WhichOneof('metric') == 'session_auc':
       assert loss_type & binary_loss_set
-      if num_class == 1 or loss_type & {LossType.JRC_LOSS}:
+      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
         label = tf.to_int64(self._labels[label_name])
         metric_dict['session_auc' + suffix] = metrics_lib.session_auc(
             label,
@@ -350,7 +360,7 @@ def _build_metric_impl(self,
         raise ValueError('Wrong class number')
     elif metric.WhichOneof('metric') == 'max_f1':
       assert loss_type & binary_loss_set
-      if num_class == 1 or loss_type & {LossType.JRC_LOSS}:
+      if num_class == 1 or loss_type & {LossType.JRC_LOSS, LossType.ZILN_LOSS}:
         label = tf.to_int64(self._labels[label_name])
         metric_dict['max_f1' + suffix] = metrics_lib.max_f1(
             label, self._prediction_dict['logits' + suffix])
@@ -369,7 +379,7 @@ def _build_metric_impl(self,
           metric.recall_at_topk.topk)
     elif metric.WhichOneof('metric') == 'mean_absolute_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['mean_absolute_error' +
                     suffix] = metrics_tf.mean_absolute_error(
                         label, self._prediction_dict['y' + suffix])
@@ -381,7 +391,7 @@ def _build_metric_impl(self,
         assert False, 'mean_absolute_error is not supported for this model'
     elif metric.WhichOneof('metric') == 'mean_squared_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['mean_squared_error' +
                     suffix] = metrics_tf.mean_squared_error(
                         label, self._prediction_dict['y' + suffix])
@@ -393,7 +403,7 @@ def _build_metric_impl(self,
         assert False, 'mean_squared_error is not supported for this model'
     elif metric.WhichOneof('metric') == 'root_mean_squared_error':
       label = tf.to_float(self._labels[label_name])
-      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS}:
+      if loss_type & {LossType.L2_LOSS, LossType.SIGMOID_L2_LOSS, LossType.ZILN_LOSS}:
         metric_dict['root_mean_squared_error' +
                     suffix] = metrics_tf.root_mean_squared_error(
                         label, self._prediction_dict['y' + suffix])
@@ -435,6 +445,8 @@ def _get_outputs_impl(self, loss_type, num_class=1, suffix=''):
       return ['probs' + suffix, 'logits' + suffix]
     if loss_type == LossType.JRC_LOSS:
       return ['probs' + suffix, 'pos_logits' + suffix]
+    if loss_type == LossType.ZILN_LOSS:
+      return ['probs' + suffix, 'y' + suffix, 'logits' + suffix]
     if loss_type == LossType.CLASSIFICATION:
       if num_class == 1:
         return ['probs' + suffix, 'logits' + suffix]
diff --git a/easy_rec/python/protos/loss.proto b/easy_rec/python/protos/loss.proto
index b377cd75c..4416111a8 100644
--- a/easy_rec/python/protos/loss.proto
+++ b/easy_rec/python/protos/loss.proto
@@ -23,6 +23,7 @@ enum LossType {
   KL_DIVERGENCE_LOSS = 16;
   LISTWISE_RANK_LOSS = 18;
   LISTWISE_DISTILL_LOSS = 19;
+  ZILN_LOSS = 20;
 }
 
 message Loss {
diff --git a/easy_rec/python/test/train_eval_test.py b/easy_rec/python/test/train_eval_test.py
index 72eee9667..83656f2a0 100644
--- a/easy_rec/python/test/train_eval_test.py
+++ b/easy_rec/python/test/train_eval_test.py
@@ -7,11 +7,11 @@
 import threading
 import time
 import unittest
+from distutils.version import LooseVersion
 
 import numpy as np
 import six
 import tensorflow as tf
-from distutils.version import LooseVersion
 from tensorflow.python.platform import gfile
 
 from easy_rec.python.main import predict
@@ -374,6 +374,12 @@ def test_dcn(self):
         'samples/model_config/dcn_on_taobao.config', self._test_dir)
     self.assertTrue(self._success)
 
+  def test_ziln_loss(self):
+    self._success = test_utils.test_single_train_eval(
+        'samples/model_config/mlp_on_taobao_with_ziln_loss.config',
+        self._test_dir)
+    self.assertTrue(self._success)
+
   def test_fibinet(self):
     self._success = test_utils.test_single_train_eval(
         'samples/model_config/fibinet_on_taobao.config', self._test_dir)
diff --git a/easy_rec/python/test/zero_inflated_lognormal_test.py b/easy_rec/python/test/zero_inflated_lognormal_test.py
new file mode 100644
index 000000000..f512e48e8
--- /dev/null
+++ b/easy_rec/python/test/zero_inflated_lognormal_test.py
@@ -0,0 +1,53 @@
+# -*- encoding:utf-8 -*-
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import numpy as np
+import tensorflow as tf
+from scipy import stats
+
+from easy_rec.python.loss.zero_inflated_lognormal import zero_inflated_lognormal_loss  # NOQA
+
+if tf.__version__ >= '2.0':
+  tf = tf.compat.v1
+
+# Absolute error tolerance in asserting array near.
+_ERR_TOL = 1e-6
+
+
+# softplus function that calculates log(1+exp(x))
+def _softplus(x):
+  return np.log(1.0 + np.exp(x))
+
+
+# sigmoid function that calculates 1/(1+exp(-x))
+def _sigmoid(x):
+  return 1 / (1 + np.exp(-x))
+
+
+class ZeroInflatedLognormalLossTest(tf.test.TestCase):
+
+  def setUp(self):
+    super(ZeroInflatedLognormalLossTest, self).setUp()
+    self.logits = np.array([[.1, .2, .3], [.4, .5, .6]])
+    self.labels = np.array([[0.], [1.5]])
+
+  def zero_inflated_lognormal(self, labels, logits):
+    positive_logits = logits[..., :1]
+    loss_zero = _softplus(positive_logits)
+    loc = logits[..., 1:2]
+    scale = np.maximum(
+        _softplus(logits[..., 2:]), np.sqrt(tf.keras.backend.epsilon()))
+    log_prob_non_zero = stats.lognorm.logpdf(
+        x=labels, s=scale, loc=0, scale=np.exp(loc))
+    loss_non_zero = _softplus(-positive_logits) - log_prob_non_zero
+    return np.mean(np.where(labels == 0., loss_zero, loss_non_zero), axis=-1)
+
+  def test_loss_value(self):
+    expected_loss = self.zero_inflated_lognormal(self.labels, self.logits)
+    expected_loss = np.average(expected_loss)
+    loss = zero_inflated_lognormal_loss(self.labels, self.logits)
+    self.assertNear(self.evaluate(loss), expected_loss, _ERR_TOL)
+
+
+if __name__ == '__main__':
+  tf.enable_eager_execution()
+  tf.test.main()
diff --git a/easy_rec/python/utils/estimator_utils.py b/easy_rec/python/utils/estimator_utils.py
index a90f0b0f0..ea15063d1 100644
--- a/easy_rec/python/utils/estimator_utils.py
+++ b/easy_rec/python/utils/estimator_utils.py
@@ -885,8 +885,11 @@ def get_latest_checkpoint_from_checkpoint_path(checkpoint_path,
                                                ignore_ckpt_error):
   ckpt_path = None
   if checkpoint_path.endswith('/') or gfile.IsDirectory(checkpoint_path + '/'):
-    if gfile.Exists(checkpoint_path):
-      ckpt_path = latest_checkpoint(checkpoint_path)
+    checkpoint_dir = checkpoint_path
+    if not checkpoint_dir.endswith('/'):
+      checkpoint_dir = checkpoint_dir + '/'
+    if gfile.Exists(checkpoint_dir):
+      ckpt_path = latest_checkpoint(checkpoint_dir)
     if ckpt_path:
       logging.info(
           'fine_tune_checkpoint is directory, will use the latest checkpoint: %s'
diff --git a/easy_rec/version.py b/easy_rec/version.py
index 7da645311..759f7a8b3 100644
--- a/easy_rec/version.py
+++ b/easy_rec/version.py
@@ -1,4 +1,4 @@
 # -*- encoding:utf-8 -*-
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
-__version__ = '0.8.4'
+__version__ = '0.8.5'
diff --git a/pai_jobs/deploy.sh b/pai_jobs/deploy.sh
index 4d325ecb1..77b1065b6 100755
--- a/pai_jobs/deploy.sh
+++ b/pai_jobs/deploy.sh
@@ -92,6 +92,7 @@ fi
 cp easy_rec/__init__.py easy_rec/__init__.py.bak
 sed -i -e "s/\[VERSION\]/$VERSION/g" easy_rec/__init__.py
 find -L easy_rec -name "*.pyc" | xargs rm -rf
+echo "tensorflow-probability==0.5.0" > requirements.txt
 
 if [ ! -d "datahub" ]
 then
@@ -102,7 +103,7 @@ then
   fi
   tar -zvxf pydatahub.tar.gz
 fi
 
-tar -cvzhf $RES_PATH easy_rec run.py
+tar -cvzhf $RES_PATH easy_rec run.py requirements.txt
 mv easy_rec/__init__.py.bak easy_rec/__init__.py
 
 # 2 means generate only
diff --git a/pai_jobs/deploy_ext.sh b/pai_jobs/deploy_ext.sh
index 26a1dd091..3c8383439 100755
--- a/pai_jobs/deploy_ext.sh
+++ b/pai_jobs/deploy_ext.sh
@@ -143,7 +143,29 @@ then
   rm -rf faiss.tar.gz
 fi
 
-tar -cvzhf $RES_PATH easy_rec datahub lz4 cprotobuf kafka faiss run.py
+if [ ! -d "tensorflow_probability" ]
+then
+  if [ $is_tf15 -gt 0 ]; then
+    tfp_version='0.8.0'
+  else
+    tfp_version='0.5.0'
+  fi
+  # download the tarball only if it is not already present
+  if [ ! -e "probability-${tfp_version}.tar.gz" ]
+  then
+    wget http://easyrec.oss-cn-beijing.aliyuncs.com/3rdparty/probability-${tfp_version}.tar.gz
+    if [ $? -ne 0 ]
+    then
+      echo "tensorflow_probability download failed."
+      exit 1
+    fi
+  fi
+  tar -xzvf probability-${tfp_version}.tar.gz --strip-components=1 probability-${tfp_version}/tensorflow_probability
+  rm -rf tensorflow_probability/examples
+  rm -rf tensorflow_probability/g3doc
+  rm -rf probability-${tfp_version}.tar.gz
+fi
+
+tar -cvzhf $RES_PATH easy_rec datahub lz4 cprotobuf kafka faiss tensorflow_probability run.py
 
 # 2 means generate only
 if [ $mode -ne 2 ]
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 2ee199bb6..596bd527b 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -5,3 +5,4 @@ recommonmark==0.6.0
 sphinx==5.1.1
 sphinx_markdown_tables==0.0.17
 sphinx_rtd_theme
+tensorflow-probability==0.11.0
\ No newline at end of file
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 37302e7d5..8e6fa5616 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -2,7 +2,6 @@ eas_prediction
 future
 matplotlib
 numpy <= 1.23
-numpy <= 1.23
 oss2
 pandas
 psutil
diff --git a/samples/model_config/mlp_on_taobao_with_ziln_loss.config b/samples/model_config/mlp_on_taobao_with_ziln_loss.config
new file mode 100644
index 000000000..1f05afa91
--- /dev/null
+++ b/samples/model_config/mlp_on_taobao_with_ziln_loss.config
@@ -0,0 +1,279 @@
+train_input_path: "data/test/tb_data/taobao_train_data"
+eval_input_path: "data/test/tb_data/taobao_test_data"
+model_dir: "experiments/mlp_ziln_taobao_ckpt"
+
+train_config {
+  log_step_count_steps: 100
+  optimizer_config: {
+    adam_optimizer: {
+      learning_rate: {
+        exponential_decay_learning_rate {
+          initial_learning_rate: 0.001
+          decay_steps: 1000
+          decay_factor: 0.5
+          min_learning_rate: 0.00001
+        }
+      }
+    }
+    use_moving_average: false
+  }
+  save_checkpoints_steps: 100
+  sync_replicas: True
+  num_steps: 100
+}
+
+eval_config {
+  metrics_set: {
+    auc {}
+  }
+}
+
+data_config {
+  input_fields {
+    input_name: 'clk'
+    input_type: INT32
+  }
+  input_fields {
+    input_name: 'buy'
+    input_type: INT32
+  }
+  input_fields {
+    input_name: 'pid'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'adgroup_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'cate_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'campaign_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'customer'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'brand'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'user_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'cms_segid'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'cms_group_id'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'final_gender_code'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'age_level'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'pvalue_level'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'shopping_level'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'occupation'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'new_user_class_level'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'tag_category_list'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'tag_brand_list'
+    input_type: STRING
+  }
+  input_fields {
+    input_name: 'price'
+    input_type: INT32
+  }
+
+  label_fields: 'clk'
+  batch_size: 4096
+  num_epochs: 10000
+  prefetch_size: 32
+  input_type: CSVInput
+}
+
+feature_config: {
+  features: {
+    input_names: 'pid'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10
+  }
+  features: {
+    input_names: 'adgroup_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'cate_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10000
+  }
+  features: {
+    input_names: 'campaign_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'customer'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'brand'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'user_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100000
+  }
+  features: {
+    input_names: 'cms_segid'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100
+  }
+  features: {
+    input_names: 'cms_group_id'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 100
+  }
+  features: {
+    input_names: 'final_gender_code'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10
+  }
+  features: {
+    input_names: 'age_level'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10
+  }
+  features: {
+    input_names: 'pvalue_level'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10
+  }
+  features: {
+    input_names: 'shopping_level'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10
+  }
+  features: {
+    input_names: 'occupation'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10
+  }
+  features: {
+    input_names: 'new_user_class_level'
+    feature_type: IdFeature
+    embedding_dim: 16
+    hash_bucket_size: 10
+  }
+  features: {
+    input_names: 'tag_category_list'
+    feature_type: TagFeature
+    separator: '|'
+    hash_bucket_size: 100000
+    embedding_dim: 16
+  }
+  features: {
+    input_names: 'tag_brand_list'
+    feature_type: TagFeature
+    separator: '|'
+    hash_bucket_size: 100000
+    embedding_dim: 16
+  }
+  features: {
+    input_names: 'price'
+    feature_type: IdFeature
+    embedding_dim: 16
+    num_buckets: 50
+  }
+}
+model_config: {
+  model_class: 'RankModel'
+  feature_groups: {
+    group_name: 'all'
+    feature_names: 'user_id'
+    feature_names: 'cms_segid'
+    feature_names: 'cms_group_id'
+    feature_names: 'age_level'
+    feature_names: 'pvalue_level'
+    feature_names: 'shopping_level'
+    feature_names: 'occupation'
+    feature_names: 'new_user_class_level'
+    feature_names: 'adgroup_id'
+    feature_names: 'cate_id'
+    feature_names: 'campaign_id'
+    feature_names: 'customer'
+    feature_names: 'brand'
+    feature_names: 'price'
+    feature_names: 'pid'
+    feature_names: 'tag_category_list'
+    feature_names: 'tag_brand_list'
+    wide_deep: DEEP
+  }
+  backbone {
+    blocks {
+      name: "deep"
+      inputs {
+        feature_group_name: "all"
+      }
+      keras_layer {
+        class_name: "MLP"
+        mlp {
+          hidden_units: [256, 128, 64]
+        }
+      }
+    }
+  }
+  model_params {
+    l2_regularization: 1e-6
+  }
+  num_class: 3
+  losses {
+    loss_type: ZILN_LOSS
+    weight: 1.0
+    loss_name: 'LTV'
+  }
+  embedding_regularization: 1e-4
+}
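The new unit test drives this config through `test_single_train_eval`; to try it manually, the standard EasyRec train/eval entry point should work (command assumed from the usual workflow):

```bash
python -m easy_rec.python.train_eval \
  --pipeline_config_path samples/model_config/mlp_on_taobao_with_ziln_loss.config
```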
diff --git a/setup.cfg b/setup.cfg
index b43211827..d8ed85f21 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -10,7 +10,7 @@ multi_line_output = 7
 force_single_line = true
 known_standard_library = setuptools
 known_first_party = easy_rec
-known_third_party = absl,common_io,distutils,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,yaml
+known_third_party = absl,common_io,docutils,eas_prediction,faiss,future,google,graphlearn,kafka,matplotlib,numpy,oss2,pai,pandas,psutil,scipy,six,sklearn,sparse_operation_kit,sphinx_markdown_tables,sphinx_rtd_theme,tensorflow,tensorflow_probability,yaml
 no_lines_before = LOCALFOLDER
 default_section = THIRDPARTY
 skip = easy_rec/python/protos