diff --git a/cornac/models/dnntsp/dnntsp.py b/cornac/models/dnntsp/dnntsp.py index 2e4a1ec6..82056e2a 100644 --- a/cornac/models/dnntsp/dnntsp.py +++ b/cornac/models/dnntsp/dnntsp.py @@ -328,10 +328,9 @@ def forward( def get_edges_weight(history_baskets): edges_weight_dict = defaultdict(float) for basket_items in history_baskets: - for i in range(len(basket_items)): - for j in range(i + 1, len(basket_items)): - edges_weight_dict[(basket_items[i], basket_items[j])] += 1 - edges_weight_dict[(basket_items[j], basket_items[i])] += 1 + for (i,j) in itertools.combinations(range(len(basket_items)), 2): + edges_weight_dict[(basket_items[i], basket_items[j])] += 1 + edges_weight_dict[(basket_items[j], basket_items[i])] += 1 return edges_weight_dict @@ -373,19 +372,11 @@ def transform_data( torch.tensor(list(range(nodes.shape[0]))) for nodes in batch_nodes ] batch_src = [ - ( - torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=1) - .flatten() - .tolist() - ) + project_nodes.repeat((project_nodes.shape[0], 1)).T.flatten().tolist() for project_nodes in batch_project_nodes ] batch_dst = [ - ( - torch.stack([project_nodes for _ in range(project_nodes.shape[0])], dim=0) - .flatten() - .tolist() - ) + project_nodes.repeat((project_nodes.shape[0],)).flatten().tolist() for project_nodes in batch_project_nodes ] batch_g = [ @@ -487,11 +478,8 @@ def forward(self, predict, truth): Returns: output: tensor """ - # predict = torch.softmax(predict, dim=-1) predict = torch.sigmoid(predict) truth = truth.float() - # print(predict.device) - # print(truth.device) if self.weights is not None: self.weights = self.weights.to(truth.device) predict = predict * self.weights @@ -623,17 +611,3 @@ def learn( # Note that step should be called after validate scheduler.step(total_val_loss) - - -def score(model: TemporalSetPrediction, history_baskets, total_items, device="cpu"): - model = model.to(device) - model.eval() - (g, nodes_feature, edges_weight, lengths, nodes, _) = transform_data( - [history_baskets], - item_embedding=model.embedding_matrix, - total_items=total_items, - device=device, - is_test=True, - ) - preds = model(g, nodes_feature, edges_weight, lengths, nodes) - return preds.cpu().detach().numpy() diff --git a/cornac/models/dnntsp/recom_dnntsp.py b/cornac/models/dnntsp/recom_dnntsp.py index 9196d335..9ed5bc3c 100644 --- a/cornac/models/dnntsp/recom_dnntsp.py +++ b/cornac/models/dnntsp/recom_dnntsp.py @@ -122,9 +122,15 @@ def fit(self, train_set, val_set=None): return self def score(self, user_idx, history_baskets, **kwargs): - from .dnntsp import score + from .dnntsp import transform_data - item_scores = score( - self.model, history_baskets, self.total_items, device=self.device + self.model.eval() + (g, nodes_feature, edges_weight, lengths, nodes, _) = transform_data( + [history_baskets], + item_embedding=self.model.embedding_matrix, + total_items=self.total_items, + device=self.device, + is_test=True, ) - return item_scores + preds = self.model(g, nodes_feature, edges_weight, lengths, nodes) + return preds.cpu().detach().numpy()