Skip to content

Commit

Permalink
refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
lthoang committed Jan 18, 2024
1 parent 3d4ac7d commit f625f24
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 35 deletions.
36 changes: 5 additions & 31 deletions cornac/models/dnntsp/dnntsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,9 @@ def forward(
def get_edges_weight(history_baskets):
edges_weight_dict = defaultdict(float)
for basket_items in history_baskets:
for i in range(len(basket_items)):
for j in range(i + 1, len(basket_items)):
edges_weight_dict[(basket_items[i], basket_items[j])] += 1
edges_weight_dict[(basket_items[j], basket_items[i])] += 1
for (i,j) in itertools.combinations(range(len(basket_items)), 2):
edges_weight_dict[(basket_items[i], basket_items[j])] += 1
edges_weight_dict[(basket_items[j], basket_items[i])] += 1
return edges_weight_dict


Expand Down Expand Up @@ -373,19 +372,11 @@ def transform_data(
torch.tensor(list(range(nodes.shape[0]))) for nodes in batch_nodes
]
batch_src = [
(
torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=1)
.flatten()
.tolist()
)
project_nodes.repeat((project_nodes.shape[0], 1)).T.flatten().tolist()
for project_nodes in batch_project_nodes
]
batch_dst = [
(
torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=0)
.flatten()
.tolist()
)
project_nodes.repeat((project_nodes.shape[0],)).flatten().tolist()
for project_nodes in batch_project_nodes
]
batch_g = [
Expand Down Expand Up @@ -487,11 +478,8 @@ def forward(self, predict, truth):
Returns:
output: tensor
"""
# predict = torch.softmax(predict, dim=-1)
predict = torch.sigmoid(predict)
truth = truth.float()
# print(predict.device)
# print(truth.device)
if self.weights is not None:
self.weights = self.weights.to(truth.device)
predict = predict * self.weights
Expand Down Expand Up @@ -623,17 +611,3 @@ def learn(

# Note that step should be called after validate
scheduler.step(total_val_loss)


def score(model: TemporalSetPrediction, history_baskets, total_items, device="cpu"):
model = model.to(device)
model.eval()
(g, nodes_feature, edges_weight, lengths, nodes, _) = transform_data(
[history_baskets],
item_embedding=model.embedding_matrix,
total_items=total_items,
device=device,
is_test=True,
)
preds = model(g, nodes_feature, edges_weight, lengths, nodes)
return preds.cpu().detach().numpy()
14 changes: 10 additions & 4 deletions cornac/models/dnntsp/recom_dnntsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,15 @@ def fit(self, train_set, val_set=None):
return self

def score(self, user_idx, history_baskets, **kwargs):
from .dnntsp import score
from .dnntsp import transform_data

item_scores = score(
self.model, history_baskets, self.total_items, device=self.device
self.model.eval()
(g, nodes_feature, edges_weight, lengths, nodes, _) = transform_data(
[history_baskets],
item_embedding=self.model.embedding_matrix,
total_items=self.total_items,
device=self.device,
is_test=True,
)
return item_scores
preds = self.model(g, nodes_feature, edges_weight, lengths, nodes)
return preds.cpu().detach().numpy()

0 comments on commit f625f24

Please sign in to comment.