From 759bedae0165f55a30c6d23318420f6f0b576fef Mon Sep 17 00:00:00 2001
From: Vince Jankovics
Date: Tue, 8 Sep 2020 16:10:03 +0100
Subject: [PATCH] Initial version

---
 exp.py              | 154 ++++++++++++++++++++++++
 hetsage/__init__.py |   0
 hetsage/data.py     | 277 ++++++++++++++++++++++++++++++++++++++++++++
 hetsage/model.py    | 252 ++++++++++++++++++++++++++++++++++++++++
 hetsage/utils.py    |  12 ++
 5 files changed, 695 insertions(+)
 create mode 100644 exp.py
 create mode 100644 hetsage/__init__.py
 create mode 100644 hetsage/data.py
 create mode 100644 hetsage/model.py
 create mode 100644 hetsage/utils.py

diff --git a/exp.py b/exp.py
new file mode 100644
index 0000000..3b8948c
--- /dev/null
+++ b/exp.py
@@ -0,0 +1,154 @@
+import argparse
+
+import networkx as nx
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+
+import torch_geometric
+from hetsage.model import Model
+from hetsage.utils import init_random
+
+
+def zero_grad(model):
+    # Cheaper alternative to optimizer.zero_grad(): drop the gradient
+    # tensors instead of filling them with zeros.
+    for p in model.parameters():
+        p.grad = None
+
+
+def run_iter(model, optimizer, device):
+    metrics = {}
+    loss, acc = _run_iter(model, model.tng_loader, optimizer, device=device)
+    metrics['tng-loss'] = loss
+    metrics['tng-acc'] = acc
+
+    with torch.no_grad():
+        loss, acc = _run_iter(model, model.val_loader, device=device)
+        metrics['val-loss'] = loss
+        metrics['val-acc'] = acc
+
+    return metrics
+
+
+def _run_iter(model, data_loader, optimizer=None, device='cpu'):
+    # Train if an optimizer is given, otherwise only evaluate.
+    if optimizer is not None:
+        model.train()
+    else:
+        model.eval()
+
+    total_loss = 0
+    total_correct = 0
+    total_nodes = 0
+    for batch_size, n_id, adjs in tqdm(data_loader):
+        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.
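+        # A sampler with a single layer yields a bare Adj instead of a
+        # list, so normalize it here.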
+        if isinstance(adjs, torch_geometric.data.sampler.Adj):
+            adjs = [adjs]
+        adjs = [adj.to(device) for adj in adjs]
+        n_id = n_id.to(device)
+        targets = model.get_targets(n_id[:batch_size, 0])
+        if optimizer is not None:
+            optimizer.zero_grad()
+        out = model(n_id, adjs)
+        loss = F.nll_loss(F.log_softmax(out, dim=-1), targets)
+        if optimizer is not None:
+            loss.backward()
+            optimizer.step()
+
+        total_loss += float(loss.detach()) * batch_size
+        y_pred = torch.argmax(out.detach(), dim=-1)
+        total_correct += float((y_pred == targets).sum())
+        total_nodes += batch_size
+
+    loss = total_loss / total_nodes
+    approx_acc = total_correct / total_nodes
+
+    return loss, approx_acc
+
+
+def main(args):
+    init_random()
+    # Load the heterogeneous graph from GML.
+    g = nx.readwrite.gml.read_gml(args.gml)
+
+    if args.use_gpu:
+        device = torch.device('cuda:0')
+    else:
+        device = torch.device('cpu')
+
+    model = Model(g, args.target, device=device)
+    model = model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
+
+    for epoch in range(1, 1 + args.max_epochs):
+        metrics = run_iter(model, optimizer, device=device)
+        tng_loss = metrics['tng-loss']
+        tng_acc = metrics['tng-acc']
+        val_loss = metrics['val-loss']
+        val_acc = metrics['val-acc']
+        msg = ''
+        msg += f'Epoch {epoch:02d}, '
+        msg += f'Tng loss: {tng_loss:.4f}, '
+        msg += f'Tng acc: {100*tng_acc:.2f}, '
+        msg += f'Val loss: {val_loss:.4f}, '
+        msg += f'Val acc: {100*val_acc:.2f}'
+        print(msg)
+
+
+if __name__ == '__main__':
+    PARSER = argparse.ArgumentParser()
+    PARSER.add_argument('--gml')
+    PARSER.add_argument('--target')
+    PARSER.add_argument('--use-gpu', action='store_true')
+    PARSER.add_argument('--max-epochs', type=int)
+
+    ARGS = PARSER.parse_args()
+    main(ARGS)
diff --git a/hetsage/__init__.py b/hetsage/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/hetsage/data.py b/hetsage/data.py
new file mode 100644
index 0000000..2c51762
--- /dev/null
+++ b/hetsage/data.py
@@ -0,0 +1,277 @@
+from typing import List, Optional
+
+import networkx as nx
+import numpy as np
+import torch
+from torch_sparse import SparseTensor
+from tqdm import tqdm
+
+from torch_geometric.data.sampler import Adj
+
+tensor = torch.FloatTensor
+
+
+def get_props(g):
+    return {
+        data['nodetype']: list(data.get('properties', {}).keys())
+        for n, data in g.nodes(data=True)
+    }
+
+
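+# featurize makes two passes over the graph: the first collects value
+# ranges and category vocabularies per node type, the second builds the
+# per-node feature tensors and the classification targets.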
+def featurize(g, target_node, target_prop):
+    g = nx.convert_node_labels_to_integers(g)
+    # One-hot encode the edge labels.
+    edge_cats = set()
+    for n1, n2, data in tqdm(g.edges(data=True)):
+        edge_cats = edge_cats.union({data['label']})
+    edge_cats = sorted(list(edge_cats))
+    edge_feats = []
+    for n1, n2, k, data in tqdm(g.edges(data=True, keys=True)):
+        edge_feats.append((data['label'] == np.array(edge_cats)).astype(int))
+    edge_feats = tensor(np.stack(edge_feats))
+
+    # First pass: collect min/max per numeric property and the category
+    # vocabularies per node type.
+    features = {}
+    for n, data in tqdm(g.nodes(data=True)):
+        nt = data['nodetype']
+        if nt in features:
+            if data.get('prop', {}).keys() != features[nt]['prop'].keys():
+                raise ValueError('Inconsistent prop keys')
+            if data.get('single_cat', {}).keys() != features[nt]['single_cat'].keys():
+                raise ValueError('Inconsistent single_cat keys')
+            if data.get('multi_cat', {}).keys() != features[nt]['multi_cat'].keys():
+                raise ValueError('Inconsistent multi_cat keys')
+
+            for p, v in data.get('prop', {}).items():
+                if v < features[nt]['prop'][p]['min']:
+                    features[nt]['prop'][p]['min'] = v
+                if v > features[nt]['prop'][p]['max']:
+                    features[nt]['prop'][p]['max'] = v
+
+            for p, v in data.get('single_cat', {}).items():
+                features[nt]['single_cat'][p].add(v)
+            for p, v in data.get('multi_cat', {}).items():
+                features[nt]['multi_cat'][p].update(v)
+        else:
+            prop = {p: {'min': v, 'max': v} for p, v in data.get('prop', {}).items()}
+            single_cat = {p: {v} for p, v in data.get('single_cat', {}).items()}
+            multi_cat = {p: set(v) for p, v in data.get('multi_cat', {}).items()}
+
+            features[nt] = {
+                'prop': prop,
+                'single_cat': single_cat,
+                'multi_cat': multi_cat,
+            }
+
+    for nt in features:
+        for pc, cats in features[nt]['single_cat'].items():
+            features[nt]['single_cat'][pc] = sorted(list(cats))
+        for pc, cats in features[nt]['multi_cat'].items():
+            features[nt]['multi_cat'][pc] = sorted(list(cats))
+        features[nt].update({'x_in': [], 'x_out': [], 'y': [], 'n_ids': []})
+
+    # Second pass: build the feature vectors; the target property is
+    # excluded from the input features of the target node type.
+    target_nodes = []
+    targets = []
+    for n, data in tqdm(g.nodes(data=True)):
+        nt = data['nodetype']
+        prop_keys = set(features[nt]['prop'].keys())
+        single_cat_keys = set(features[nt]['single_cat'].keys())
+        multi_cat_keys = set(features[nt]['multi_cat'].keys())
+        if nt == target_node:
+            prop_keys = prop_keys - {target_prop}
+            single_cat_keys = single_cat_keys - {target_prop}
+            multi_cat_keys = multi_cat_keys - {target_prop}
+
+        prop = data.get('prop', {})
+        single_cats = data.get('single_cat', {})
+        multi_cats = data.get('multi_cat', {})
+        x_p = get_prop(prop, prop_keys, features[nt]['prop'])
+        x_sc = get_cat(single_cats, single_cat_keys, features[nt]['single_cat'])
+        x_mc = get_cat(multi_cats, multi_cat_keys, features[nt]['multi_cat'])
+        if nt == target_node:
+            if target_prop in features[nt]['prop']:
+                y = get_prop(prop, [target_prop], features[nt]['prop'])
+            elif target_prop in features[nt]['single_cat']:
+                y = get_cat(single_cats, [target_prop], features[nt]['single_cat'])
+            elif target_prop in features[nt]['multi_cat']:
+                y = get_cat(multi_cats, [target_prop], features[nt]['multi_cat'])
+            else:
+                raise ValueError(f'{target_prop} is not a property')
+            target_nodes.append(n)
+            # Store the class index rather than the one-hot vector.
+            targets.append(torch.nonzero(torch.tensor(y) == 1, as_tuple=False).squeeze())
+        else:
+            y = []
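+        # x_in keeps the (possibly empty) label vector y, so the known
+        # labels of neighbouring target nodes can be used as inputs;
+        # x_out drops it for the node whose label is being predicted.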
+        g.nodes[n]['x_in'] = get_tensor(np.concatenate([x_p, x_sc, x_mc, y]))
+        g.nodes[n]['x_out'] = get_tensor(np.concatenate([x_p, x_sc, x_mc]))
+        g.nodes[n]['y'] = torch.tensor(y)
+        features[nt]['x_in'].append(g.nodes[n]['x_in'])
+        features[nt]['x_out'].append(g.nodes[n]['x_out'])
+        features[nt]['y'].append(g.nodes[n]['y'])
+        features[nt]['n_ids'].append(torch.tensor(n))
+
+    for nt in features:
+        features[nt]['x_in'] = torch.stack(features[nt]['x_in'])
+        features[nt]['x_out'] = torch.stack(features[nt]['x_out'])
+        features[nt]['y'] = torch.stack(features[nt]['y'])
+        features[nt]['n_ids'] = torch.stack(features[nt]['n_ids'])
+
+    target_nodes = torch.LongTensor(target_nodes)
+    targets = torch.stack(targets)
+    return g, target_nodes, targets, features, edge_feats
+
+
+def get_tensor(np_arr):
+    # Nodes with no features at all get a single constant input.
+    if np_arr.shape[0] > 0:
+        return tensor(np_arr)
+    return tensor([0])
+
+
+def get_cat(cats, cat_keys, all_cats):
+    # Multi-hot encode the categorical properties listed in cat_keys.
+    x_c = []
+    for k in cat_keys:
+        if not isinstance(cats[k], list):
+            cs = [cats[k]]
+        else:
+            cs = cats[k]
+        c_ = [1 if c in cs else 0 for c in all_cats[k]]
+        x_c.append(c_)
+    if x_c:
+        return np.concatenate(x_c)
+    return np.array(x_c)
+
+
+def get_prop(props, prop_keys, all_props, normalize=True):
+    x_p = np.array([
+        norm(props[k], all_props[k]['min'], all_props[k]['max']) if normalize else props[k]
+        for k in prop_keys
+    ])
+    return x_p
+
+
+def norm(v, v_min, v_max):
+    return (v - v_min) / (v_max - v_min)
+
+
+class NeighborSampler(torch.utils.data.DataLoader):
+    def __init__(self,
+                 edge_index: torch.Tensor,
+                 sizes: List[int],
+                 node_idx: Optional[torch.Tensor] = None,
+                 num_nodes: Optional[int] = None,
+                 flow: str = "source_to_target",
+                 **kwargs):
+
+        N = int(edge_index.max() + 1) if num_nodes is None else num_nodes
+        # Store each edge's id as the sparse value so that sampled
+        # sub-adjacencies can be mapped back to edge features.
+        edge_attr = torch.arange(edge_index.size(1))
+        adj = SparseTensor(row=edge_index[0],
+                           col=edge_index[1],
+                           value=edge_attr,
+                           sparse_sizes=(N, N),
+                           is_sorted=False)
+        adj = adj.t() if flow == 'source_to_target' else adj
+        self.adj = adj.to('cpu')
+
+        if node_idx is None:
+            node_idx = torch.arange(N)
+        elif node_idx.dtype == torch.bool:
+            node_idx = node_idx.nonzero(as_tuple=False).view(-1)
+
+        self.sizes = sizes
+        self.flow = flow
+        assert self.flow in ['source_to_target', 'target_to_source']
+
+        super(NeighborSampler, self).__init__(node_idx.tolist(), collate_fn=self.sample, **kwargs)
+
+    def sample(self, batch):
+        if not isinstance(batch, torch.Tensor):
+            batch = torch.tensor(batch)
+
+        batch_size: int = len(batch)
+
+        n_id_offset = 0
+        n_id_map = []
+        edge_indices = [[] for _ in self.sizes]
+        e_ids = [[] for _ in self.sizes]
+
+        # Sample an L-hop neighbourhood around each target separately.
+        for target_id in batch:
+            n_id = target_id.unsqueeze(dim=0)
+            n_id_targets = []
+            for i, size in enumerate(self.sizes):
+                n_id_targets.append(n_id)
+                adj, n_id = self.adj.sample_adj(n_id, size, replace=False)
+                if self.flow == 'source_to_target':
+                    adj = adj.t()
+                row, col, e_id = adj.coo()
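+                # Shift the local indices into the id space of the
+                # batch-wide concatenated neighbourhoods.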
+                row += n_id_offset
+                col += n_id_offset
+                edge_index = torch.stack([row, col], dim=0)
+                edge_indices[i].append(edge_index)
+                e_ids[i].append(e_id)
+
+            # Record for every sampled node whether it is the target and
+            # in which layer it first appeared.
+            is_target = (n_id == target_id).long()
+            n_id_layers = torch.zeros_like(n_id)
+            for i, layer_targets in enumerate(reversed(n_id_targets)):
+                id_in_layer = [idx for idx, n in enumerate(n_id) if n in layer_targets]
+                n_id_layers[id_in_layer] = i + 1
+            n_id_map.append(torch.stack([n_id, is_target, n_id_layers], dim=1))
+            n_id_offset += len(n_id)
+
+        # Sort so that nodes of deeper layers (the targets) come first.
+        n_id_map = torch.cat(n_id_map)
+        _, sorted_idx = torch.sort(n_id_map[:, 2], descending=True)
+        n_id_map = n_id_map[sorted_idx]
+        adjs = []
+        for i, size in enumerate(self.sizes):
+            edge_index = torch.cat(edge_indices[i], dim=-1)
+            edge_index = reindex(sorted_idx, edge_index)
+            e_id = torch.cat(e_ids[i], dim=-1)
+            M = edge_index[0].max().item() + 1
+            N = edge_index[1].max().item() + 1
+            adjs.append(Adj(edge_index, e_id, (M, N)))
+
+        # Every batch element must survive as a final target.
+        assert adjs[0].size[-1] == len(batch)
+        return batch_size, n_id_map, adjs[::-1]
+
+    def __repr__(self):
+        return '{}(sizes={})'.format(self.__class__.__name__, self.sizes)
+
+
+def reindex(idx_map, edge_index):
+    edge_reindex = torch.ones_like(edge_index) * -1
+    for new_idx, old_idx in enumerate(idx_map):
+        edge_reindex[edge_index == old_idx] = new_idx
+
+    return edge_reindex
diff --git a/hetsage/model.py b/hetsage/model.py
new file mode 100644
index 0000000..12ec158
--- /dev/null
+++ b/hetsage/model.py
@@ -0,0 +1,252 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from torch_geometric.nn import NNConv
+
+from .data import NeighborSampler, featurize
+
+
+class MLP(nn.Module):
+    def __init__(self, input_size, hidden_sizes, output_size, activation='ReLU'):
+        super().__init__()
+
+        if not isinstance(hidden_sizes, list):
+            hidden_sizes = [hidden_sizes]
+
+        layers = []
+        layers.append(nn.Linear(input_size, hidden_sizes[0]))
+        act = getattr(nn, activation)
+        layers.append(act())
+
+        for i, s in enumerate(hidden_sizes):
+            if i < len(hidden_sizes) - 1:
+                output_feats = hidden_sizes[i + 1]
+            else:
+                output_feats = output_size
+            layers.append(nn.Linear(hidden_sizes[i], output_feats))
+            if i < len(hidden_sizes) - 1:
+                layers.append(act())
+
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x):
+        return self.layers(x)
+
+    def is_cuda(self):
+        return next(self.parameters()).is_cuda
+
+
+class Model(nn.Module):
+    def __init__(self,
+                 nx_graph,
+                 target,
+                 embed_size=256,
+                 emb_hidden=[256, 256],
+                 hidden_size=256,
+                 device='cpu'):
+        super().__init__()
+
+        self.device = device
+        self.embed_size = embed_size
+        # The target is given as '<node type>:<property>'.
+        target_node, target_prop = target.split(':')
+        self.target_node = target_node
+        self.g, self.target_nodes, self.targets, self.node_features, self.edge_feats = featurize(
+            nx_graph, target_node, target_prop)
+        # 80/20 train/validation split over the target nodes.
+        k = self.target_nodes.size(0) + 1
+        perm = torch.randperm(self.target_nodes.size(0))
+        subset_idx = perm[:k]
+        last_tng_id = int(0.8 * subset_idx.size(0))
+        tng_idx, _ = torch.sort(subset_idx[:last_tng_id])
+        val_idx, _ = torch.sort(subset_idx[last_tng_id:])
+
+        unique_targets, target_counts = torch.unique(self.targets, return_counts=True)
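+        # Log the class distribution and split sizes as a sanity check.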
+        print('Data stats', len(self.targets), 100 * target_counts / float(len(self.targets)))
+        print('Tng len', len(tng_idx))
+        print('Val len', len(val_idx))
+        self.tng_target_nodes = self.target_nodes[tng_idx]
+        self.tng_targets = self.targets[tng_idx]
+        # FIXME: this is wrong, it includes edges to val nodes in the val set
+        self.val_target_nodes = self.target_nodes[val_idx]
+        self.val_targets = self.targets[val_idx]
+        edge_idx = torch.tensor(list(self.g.edges)).t().contiguous()
+        tng_edge_idx = self.filter_edge_index(edge_idx, self.val_target_nodes)
+
+        neighbor_sizes = [10, 10, 10]
+        batch_size = min(50, len(self.tng_targets) // 4)
+        workers = 1
+        self.tng_loader = NeighborSampler(
+            edge_idx,
+            node_idx=self.tng_target_nodes,
+            sizes=neighbor_sizes,
+            batch_size=batch_size,
+            shuffle=True,
+            num_workers=workers,
+            pin_memory=True,
+            drop_last=True)
+
+        self.val_loader = NeighborSampler(
+            # tng_edge_idx,
+            edge_idx,
+            node_idx=self.val_target_nodes,
+            sizes=neighbor_sizes,
+            batch_size=batch_size,
+            shuffle=False,
+            num_workers=workers,
+            pin_memory=True,
+            drop_last=True)
+
+        self.edge_feats = self.edge_feats.to(device)
+        self.targets = self.targets.to(device)
+        self.target_nodes = self.target_nodes.to(device)
+        # One input embedder per node type, plus a separate output-side
+        # embedder for target nodes that must not see their own label.
+        embedders = {}
+        for node_type, node_props in self.node_features.items():
+            self.node_features[node_type]['x_in'] = node_props['x_in'].to(device)
+            self.node_features[node_type]['n_ids'] = node_props['n_ids'].to(device)
+            embedders[node_type] = MLP(input_size=node_props['x_in'].shape[1],
+                                       hidden_sizes=emb_hidden,
+                                       output_size=embed_size)
+            if node_type == target_node:
+                self.node_features[node_type]['x_out'] = node_props['x_out'].to(device)
+                self.embedders_out = MLP(input_size=node_props['x_out'].shape[1],
+                                         hidden_sizes=emb_hidden,
+                                         output_size=embed_size)
+        self.embedders = nn.ModuleDict(embedders)
+
+        output_size = self.node_features[target_node]['y'].shape[1]
+
+        # Edge-conditioned convolutions; the edge MLP maps each edge's
+        # features to the weight matrix of its message.
+        root_weight = True
+        self.convs = nn.ModuleList([
+            NNConv(embed_size,
+                   hidden_size,
+                   MLP(input_size=self.edge_feats.shape[1],
+                       hidden_sizes=[hidden_size, hidden_size],
+                       output_size=hidden_size * embed_size),
+                   aggr='mean',
+                   root_weight=root_weight,
+                   bias=True),
+            NNConv(hidden_size,
+                   hidden_size,
+                   MLP(input_size=self.edge_feats.shape[1],
+                       hidden_sizes=[hidden_size, hidden_size],
+                       output_size=hidden_size * hidden_size),
+                   aggr='mean',
+                   root_weight=root_weight,
+                   bias=True),
+            NNConv(hidden_size,
+                   hidden_size,
+                   MLP(input_size=self.edge_feats.shape[1],
+                       hidden_sizes=[hidden_size, hidden_size],
+                       output_size=hidden_size * hidden_size),
+                   aggr='mean',
+                   root_weight=root_weight,
+                   bias=True),
+            NNConv(hidden_size,
+                   hidden_size,
+                   MLP(input_size=self.edge_feats.shape[1],
+                       hidden_sizes=[hidden_size, hidden_size],
+                       output_size=hidden_size * hidden_size),
+                   aggr='mean',
+                   root_weight=root_weight,
+                   bias=True),
+        ])
+        self.bns = nn.ModuleList([
+            torch.nn.BatchNorm1d(hidden_size),
+            torch.nn.BatchNorm1d(hidden_size),
+            torch.nn.BatchNorm1d(hidden_size),
+            torch.nn.BatchNorm1d(hidden_size),
+        ])
+        self.lin1 = torch.nn.Linear(hidden_size, output_size)
+
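+    # Remove every edge incident to a node in node_idx; used to build a
+    # training edge set that excludes the validation targets.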
+    def filter_edge_index(self, edge_idx, node_idx):
+        mask = [
+            i for i, edge in enumerate(edge_idx.t())
+            if not (edge[0] in node_idx or edge[1] in node_idx)
+        ]
+        mask = torch.LongTensor(mask)
+        return edge_idx[:, mask]
+
+    def get_targets(self, n_id):
+        ind_map = self.get_ind_map(n_id, self.target_nodes)
+        return self.targets[ind_map[:, 2]]
+
+    def get_ind_map(self, n_id1, n_id2, ignore1=()):
+        # Map each element of n_id1 to its position in n_id2, skipping the
+        # positions listed in ignore1; rows are (idx1, node id, idx2).
+        ind_map = []
+        for i, id in enumerate(n_id1):
+            if id in n_id2 and i not in ignore1:
+                t = torch.tensor([i, id, torch.nonzero(n_id2 == id, as_tuple=False).squeeze()])
+                ind_map.append(t)
+        if len(ind_map) == 0:
+            return None
+        return torch.stack(ind_map)
+
+    def forward(self, n_id, adjs):
+        # Embed every sampled node with the embedder of its node type; the
+        # NaN fill makes any node that is missed fail loudly downstream.
+        h = torch.zeros((n_id.shape[0], self.embed_size)).to(self.device)
+        h += np.nan
+        for node_type, node_props in self.node_features.items():
+            np_n_ids = node_props['n_ids']
+            ind_map = self.get_ind_map(n_id[:, 0], np_n_ids,
+                                       torch.nonzero(n_id[:, 1] == 1, as_tuple=False).squeeze())
+            if ind_map is None:
+                continue
+
+            h[ind_map[:, 0]] = self.embedders[node_type](node_props['x_in'][ind_map[:, 2]])
+
+        # The prediction targets themselves are embedded from x_out, which
+        # excludes their label.
+        node_props = self.node_features[self.target_node]
+        np_n_ids = node_props['n_ids']
+        ind_map = self.get_ind_map(n_id[:, 0], np_n_ids,
+                                   torch.nonzero(n_id[:, 1] == 0, as_tuple=False).squeeze())
+        if ind_map is None:
+            print('No target nodes in batch; this should never happen')
+
+        h[ind_map[:, 0]] = self.embedders_out(node_props['x_out'][ind_map[:, 2]])
+
+        # Message passing; each hop narrows h down to the target nodes of
+        # the next layer.
+        for i, (edge_index, e_id, size) in enumerate(adjs):
+            h_target = h[:size[1]]
+            h = self.convs[i]((h, h_target), edge_index, self.edge_feats[e_id])
+            h = F.relu(h)
+
+        # Project to class logits; exp.py applies log_softmax on top.
+        h = self.lin1(h)
+        return h
diff --git a/hetsage/utils.py b/hetsage/utils.py
new file mode 100644
index 0000000..4ad5c10
--- /dev/null
+++ b/hetsage/utils.py
@@ -0,0 +1,12 @@
+import random
+
+import numpy as np
+import torch
+
+
+def init_random(seed=0):
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    np.random.seed(seed)
+    random.seed(seed)
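+    # Full run-to-run determinism on GPU may additionally require
+    # torch.backends.cudnn.deterministic = True and
+    # torch.backends.cudnn.benchmark = False.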