Skip to content

Commit

Permalink
Merge pull request #383 from frankwhzhang/fix_esmm_0208
Browse files Browse the repository at this point in the history
fix esmm& multi-auc problem
  • Loading branch information
seemingwang authored Feb 9, 2021
2 parents 5aeafbc + a049613 commit 26dd5b9
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 8 deletions.
1 change: 1 addition & 0 deletions models/multitask/esmm/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ runner:
train_reader_path: "esmm_reader" # importlib format
use_gpu: False
use_auc: True
auc_num: 2
train_batch_size: 2
epochs: 3
print_interval: 2
Expand Down
1 change: 1 addition & 0 deletions models/multitask/esmm/config_bigdata.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ runner:
train_reader_path: "esmm_reader" # importlib format
use_gpu: True
use_auc: True
auc_num: 2
train_batch_size: 1024
epochs: 10
print_interval: 10
Expand Down
2 changes: 1 addition & 1 deletion models/multitask/esmm/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ ESMM是发表在 SIGIR’2018 的论文[《Entire Space Multi-Task Model: An E

### 效果复现
为了方便使用者能够快速的跑通每一个模型,我们在每个模型下都提供了样例数据。如果需要复现readme中的效果,请按如下步骤依次操作即可。
在全量数据下模型的指标如下
在全量数据下模型的训练指标如下
| 模型 | auc_ctcvr | batch_size | epoch_num | Time of each epoch |
| :------| :------ | :------ | :------| :------ |
| ESMM | 0.82 | 1024 | 10 | 约3分钟 |
Expand Down
3 changes: 2 additions & 1 deletion tools/static_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def main(args):

use_gpu = config.get("runner.use_gpu", True)
use_auc = config.get("runner.use_auc", False)
auc_num = config.get("runner.auc_num", 1)
test_data_dir = config.get("runner.test_data_dir", None)
print_interval = config.get("runner.print_interval", None)
model_load_path = config.get("runner.infer_load_path", "model_output")
Expand Down Expand Up @@ -92,7 +93,7 @@ def main(args):
epoch_begin = time.time()
interval_begin = time.time()
if use_auc:
reset_auc()
reset_auc(auc_num)
for batch_id, batch_data in enumerate(test_dataloader()):
fetch_batch_var = exe.run(
program=paddle.static.default_main_program(),
Expand Down
3 changes: 2 additions & 1 deletion tools/static_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def main(args):

use_gpu = config.get("runner.use_gpu", True)
use_auc = config.get("runner.use_auc", False)
auc_num = config.get("runner.auc_num", 1)
train_data_dir = config.get("runner.train_data_dir", None)
epochs = config.get("runner.epochs", None)
print_interval = config.get("runner.print_interval", None)
Expand Down Expand Up @@ -93,7 +94,7 @@ def main(args):

epoch_begin = time.time()
if use_auc:
reset_auc()
reset_auc(auc_num)
if reader_type == 'DataLoader':
fetch_batch_var = dataloader_train(epoch_id, train_dataloader,
input_data_names, fetch_vars,
Expand Down
11 changes: 6 additions & 5 deletions tools/utils/utils_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ def load_yaml(yaml_file, other_part=None):
return running_config


def reset_auc():
auc_var_name = [
"_generated_var_0", "_generated_var_1", "_generated_var_2",
"_generated_var_3"
]
def reset_auc(auc_num=1):
# for static clear auc
auc_var_name = []
for i in range(auc_num * 4):
auc_var_name.append("_generated_var_{}".format(i))

for name in auc_var_name:
param = paddle.fluid.global_scope().var(name)
if param == None:
Expand Down

0 comments on commit 26dd5b9

Please sign in to comment.