From 58d09e079713b61dc0c5650ad870755b92e4a0e1 Mon Sep 17 00:00:00 2001 From: Vince Jankovics Date: Thu, 17 Sep 2020 16:09:16 +0100 Subject: [PATCH] Target size adj fix --- hetsage/data.py | 6 +++++- hetsage/model.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/hetsage/data.py b/hetsage/data.py index 230d859..25ca188 100644 --- a/hetsage/data.py +++ b/hetsage/data.py @@ -228,6 +228,7 @@ def sample(self, batch): _, sorted_idx_inv = torch.sort(sorted_idx) n_id_map = n_id_map[sorted_idx] adjs = [] + target_size = len(batch) for i, size in enumerate(self.sizes): edge_index = torch.cat(edge_indeces[i], dim=-1) edge_index = reindex(sorted_idx, sorted_idx_inv, edge_index) @@ -235,7 +236,8 @@ def sample(self, batch): M = edge_index[0].max().item() + 1 N = edge_index[1].max().item() + 1 size = (M, N) - adjs.append(Adj(edge_index, e_feat, size)) + adjs.append(Adj(edge_index, e_feat, size, target_size)) + target_size = M return batch_size, n_id_map, adjs[::-1] @@ -420,10 +422,12 @@ class Adj(NamedTuple): edge_index: torch.Tensor edge_features: torch.Tensor size: Tuple[int, int] + target_size: int def to(self, *args, **kwargs): return Adj( self.edge_index.to(*args, **kwargs), self.edge_features.to(*args, **kwargs), self.size, + self.target_size ) diff --git a/hetsage/model.py b/hetsage/model.py index 92aa968..11f7d19 100644 --- a/hetsage/model.py +++ b/hetsage/model.py @@ -147,8 +147,8 @@ def forward(self, input_nodes, adjs): h[node_info.h_id] = self.embedders[node_type](node_info.x) # message passing - for i, (edge_index, e_feat, size) in enumerate(adjs): - h_target = h[:size[1]] + for i, (edge_index, e_feat, size, target_size) in enumerate(adjs): + h_target = h[:target_size] h = self.convs[i]((h, h_target), edge_index, e_feat) if self.bns: h = self.bns[i](h)