Skip to content

Commit

Permalink
format print
Browse files Browse the repository at this point in the history
  • Loading branch information
XU-YaoKun committed Jan 22, 2020
1 parent 55782cd commit 3b41fac
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 41 deletions.
11 changes: 11 additions & 0 deletions common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,14 @@ def unfreeze(model):
for param in model.parameters():
param.requires_grad = True
return model


def print_dict(dic):
"""print dictionary using specified format
example: {"a": 1, "b": 2}
output:
"a": 1
"b": 2
"""
print('\n'.join('{:10s}: {}'.format(key, values) for key, values in dic.items()))
47 changes: 6 additions & 41 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

from tqdm import tqdm
from copy import deepcopy
from prettytable import PrettyTable

from common.test import test_v2
from common.utils import early_stopping
from common.utils import early_stopping, print_dict
from common.config.parser import parse_args

from common.dataset.build import build_loader
Expand Down Expand Up @@ -101,13 +102,9 @@ def train_one_epoch(
reg_loss += reg_loss_batch

avg_reward = epoch_reward / num_batch
print(
" Epoch {0:4d}: \
\n Training loss: [{1:4f} = {2:4f} + {3:4f}] \
\n Reward: {4:4f}".format(
cur_epoch, loss, base_loss, reg_loss, avg_reward
)
)
train_res = PrettyTable()
train_res.field_names = ["Epoch", "Loss", "BPR-Loss", "Regulation", "AVG-Reward"]
train_res.add_row([cur_epoch, loss, base_loss, reg_loss, avg_reward])

return loss, base_loss, reg_loss, avg_reward

Expand Down Expand Up @@ -228,47 +225,15 @@ def train(train_loader, test_loader, graph, data_config, args_config):
"""Test"""
if cur_epoch % args_config.show_step == 0:
with torch.no_grad():
t2 = time()
ret = test_v2(recommender, args_config.Ks, graph)

t3 = time()
loss_loger.append(loss)
rec_loger.append(ret["recall"])
pre_loger.append(ret["precision"])
ndcg_loger.append(ret["ndcg"])
hit_loger.append(ret["hit_ratio"])

perf_str = (
"Evaluate[%.1fs]: \
\n recall=[%.5f, %.5f, %.5f, %.5f, %.5f], \
\n precision=[%.5f, %.5f, %.5f, %.5f, %.5f], \
\n hit=[%.5f, %.5f, %.5f, %.5f, %.5f], \
\n ndcg=[%.5f, %.5f, %.5f, %.5f, %.5f] "
% (
t3 - t2,
ret["recall"][0],
ret["recall"][1],
ret["recall"][2],
ret["recall"][3],
ret["recall"][4],
ret["precision"][0],
ret["precision"][1],
ret["precision"][2],
ret["precision"][3],
ret["precision"][4],
ret["hit_ratio"][0],
ret["hit_ratio"][1],
ret["hit_ratio"][2],
ret["hit_ratio"][3],
ret["hit_ratio"][4],
ret["ndcg"][0],
ret["ndcg"][1],
ret["ndcg"][2],
ret["ndcg"][3],
ret["ndcg"][4],
)
)
print(perf_str)
print_dict(ret)

cur_best_pre_0, stopping_step, should_stop = early_stopping(
ret["recall"][0],
Expand Down

0 comments on commit 3b41fac

Please sign in to comment.