Skip to content

Commit

Permalink
Target size adj fix
Browse files Browse the repository at this point in the history
  • Loading branch information
vakker committed Sep 17, 2020
1 parent a1da98e commit 58d09e0
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion hetsage/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,16 @@ def sample(self, batch):
_, sorted_idx_inv = torch.sort(sorted_idx)
n_id_map = n_id_map[sorted_idx]
adjs = []
target_size = len(batch)
for i, size in enumerate(self.sizes):
edge_index = torch.cat(edge_indeces[i], dim=-1)
edge_index = reindex(sorted_idx, sorted_idx_inv, edge_index)
e_feat = torch.cat(e_feats[i], dim=0)
M = edge_index[0].max().item() + 1
N = edge_index[1].max().item() + 1
size = (M, N)
adjs.append(Adj(edge_index, e_feat, size))
adjs.append(Adj(edge_index, e_feat, size, target_size))
target_size = M

return batch_size, n_id_map, adjs[::-1]

Expand Down Expand Up @@ -420,10 +422,12 @@ class Adj(NamedTuple):
edge_index: torch.Tensor
edge_features: torch.Tensor
size: Tuple[int, int]
target_size: int

def to(self, *args, **kwargs):
return Adj(
self.edge_index.to(*args, **kwargs),
self.edge_features.to(*args, **kwargs),
self.size,
self.target_size
)
4 changes: 2 additions & 2 deletions hetsage/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ def forward(self, input_nodes, adjs):
h[node_info.h_id] = self.embedders[node_type](node_info.x)

# message passing
for i, (edge_index, e_feat, size) in enumerate(adjs):
h_target = h[:size[1]]
for i, (edge_index, e_feat, size, target_size) in enumerate(adjs):
h_target = h[:target_size]
h = self.convs[i]((h, h_target), edge_index, e_feat)
if self.bns:
h = self.bns[i](h)
Expand Down

0 comments on commit 58d09e0

Please sign in to comment.