From a8424295f6ad0a1b8acbbfa207e78c82c4a6870b Mon Sep 17 00:00:00 2001 From: yanbing-j Date: Sun, 20 Nov 2022 17:10:46 +0800 Subject: [PATCH] reverse edge types and convert y to long --- examples/lsc/mag240m/gnn.py | 59 ++++--------------------------------- ogb/lsc/mag240m.py | 24 ++------------- 2 files changed, 8 insertions(+), 75 deletions(-) diff --git a/examples/lsc/mag240m/gnn.py b/examples/lsc/mag240m/gnn.py index 85f1b40f..7f55532f 100644 --- a/examples/lsc/mag240m/gnn.py +++ b/examples/lsc/mag240m/gnn.py @@ -48,7 +48,6 @@ def num_classes(self) -> int: def metadata(self) -> Tuple[List[NodeType], List[EdgeType]]: node_types = ['paper', 'author', 'institution'] - # node_types = ['paper'] edge_types = [ ('author', 'affiliated_with', 'institution'), ('institution', 'rev_affiliated_with', 'author'), @@ -67,57 +66,18 @@ def __init__(self, model: str, in_channels: int, out_channels: int, self.dropout = dropout self.num_layers = num_layers - # self.convs = ModuleList() - # self.norms = ModuleList() - # self.skips = ModuleList() - - # if self.model == 'gat': - # self.convs.append( - # GATConv(in_channels, hidden_channels // heads, heads)) - # self.skips.append(Linear(in_channels, hidden_channels)) - # for _ in range(num_layers - 1): - # self.convs.append( - # GATConv(hidden_channels, hidden_channels // heads, heads)) - # self.skips.append(Linear(hidden_channels, hidden_channels)) - - # elif self.model == 'graphsage': - # self.convs.append(SAGEConv(in_channels, hidden_channels)) - # for _ in range(num_layers - 1): - # self.convs.append(SAGEConv(hidden_channels, hidden_channels)) - - # for _ in range(num_layers): - # self.norms.append(BatchNorm1d(hidden_channels)) - - # self.mlp = Sequential( - # Linear(hidden_channels, hidden_channels), - # BatchNorm1d(hidden_channels), - # ReLU(inplace=True), - # Dropout(p=self.dropout), - # Linear(hidden_channels, out_channels), - # ) - self.conv1 = SAGEConv(in_channels, hidden_channels) self.conv2 = SAGEConv(hidden_channels, hidden_channels) self.lin = Linear(hidden_channels, out_channels) # self.relu = ReLU(inplace=True) def forward(self, x: Tensor, edge_index: Adj) -> Tensor: - # for i in range(self.num_layers): - # x = self.convs[i](x, edge_index) - # if self.model == 'gat': - # x = x + self.skips[i](x) - # x = F.elu(self.norms[i](x)) - # elif self.model == 'graphsage': - # x = F.relu(self.norms[i](x)) - # x = F.dropout(x, p=self.dropout, training=self.training) x = x.to(torch.float) - x = self.conv1(x, edge_index) - x = F.relu(x, inplace=True) - x = self.conv2(x, edge_index) - x = F.relu(x, inplace=True) + x = self.conv1(x, edge_index).relu() x = F.dropout(x, p=self.dropout, training=self.training) - return x - # return self.mlp(x) + x = self.conv2(x, edge_index).relu() + x = F.dropout(x, p=self.dropout, training=self.training) + return self.lin(x) class HeteroGNN(LightningModule): def __init__(self, model_name: str, metadata: Tuple[List[NodeType], List[EdgeType]], in_channels: int, out_channels: int, @@ -138,19 +98,10 @@ def forward( ) -> Dict[NodeType, Tensor]: return self.model(x_dict, edge_index_dict) - # @torch.no_grad() - # def setup(self, stage: Optional[str] = None): # Initialize parameters. - # data = self.trainer.datamodule - # loader = data.dataloader(torch.arange(1), shuffle=False, num_workers=0) - # batch = next(iter(loader)) - # self(batch.x_dict, batch.edge_index_dict) - def common_step(self, batch: Batch) -> Tuple[Tensor, Tensor]: batch_size = batch['paper'].batch_size - print(batch.x_dict) - print(batch.edge_index_dict) y_hat = self(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size] - y = batch['paper'].y[:batch_size] + y = batch['paper'].y[:batch_size].to(torch.long) return y_hat, y def training_step(self, batch: Batch, batch_idx: int) -> Tensor: diff --git a/ogb/lsc/mag240m.py b/ogb/lsc/mag240m.py index 349bfbdd..9ddf7fe3 100644 --- a/ogb/lsc/mag240m.py +++ b/ogb/lsc/mag240m.py @@ -66,28 +66,19 @@ def to_pyg_hetero_data(self): data['author'].num_nodes = self.__meta__['author'] path = osp.join(self.dir, 'processed', 'author', 'author.npy') - # data['author'].x = np.memmap(path, mode='w+', shape=(data['author'].num_nodes, self.num_paper_features)) data['author'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['author'].num_nodes, self.num_paper_features)) data['institution'].num_nodes = self.__meta__['institution'] path = osp.join(self.dir, 'processed', 'institution', 'inst.npy') - # data['institution'].x = np.memmap(path, mode='w+', shape=(data['institution'].num_nodes, self.num_paper_features)) data['institution'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['institution'].num_nodes, self.num_paper_features)) - # data['author'].num_nodes = self.__meta__['author'] - # data['institution'].num_nodes = self.__meta__['institution'] - # path = osp.join(self.dir, 'processed', 'author', 'author.npy') - # data['author'].x = np.load(path, mmap_mode='r', encoding='bytes', allow_pickle=True) - # path = osp.join(self.dir, 'processed', 'institution', 'inst.npy') - # data['institution'].x = np.load(path, mmap_mode='r', allow_pickle=True) - - print("node done") for edge_type in [('author', 'affiliated_with', 'institution'), ('author', 'writes', 'paper'), ('paper', 'cites', 'paper')]: name = '___'.join(edge_type) path = osp.join(self.dir, 'processed', name, 'edge_index.npy') edge_index = torch.from_numpy(np.load(path)) - data[edge_type].edge_index = edge_index.flip([0]) + data[edge_type].edge_index = edge_index + data[edge_type[2], f'rev_{edge_type[1]}', edge_type[0]].edge_index = edge_index.flip([0]) for f, v in [('train', 'train'), ('valid', 'val'), ('test-dev', 'test')]: idx = self.get_idx_split(f) @@ -152,15 +143,6 @@ def all_paper_year(self) -> np.ndarray: path = osp.join(self.dir, 'processed', 'paper', 'node_year.npy') return np.load(path) - # def edge_index(self, id1: str, id2: str, - # id3: Optional[str] = None) -> np.ndarray: - # src = id1 - # rel, dst = (id3, id2) if id3 is None else (id2, id3) - # rel = self.__rels__[(src, dst)] if rel is None else rel - # name = f'{src}___{rel}___{dst}' - # path = osp.join(self.dir, 'processed', name, 'edge_index.npy') - # return np.load(path) - def __repr__(self) -> str: return f'{self.__class__.__name__}()' @@ -207,7 +189,7 @@ def save_test_submission(self, input_dict: Dict, dir_path: str, mode: str): if __name__ == '__main__': - dataset = MAG240MDataset('/home/user/yanbing/pyg/ogb/ogb/lsc/dataset') + dataset = MAG240MDataset() data = dataset.to_pyg_hetero_data() print(dataset) print(dataset.num_papers)