reverse edge types and convert y to long
yanbing-j committed Nov 20, 2022
1 parent 2dfa0fb commit a842429
Showing 2 changed files with 8 additions and 75 deletions.
59 changes: 5 additions & 54 deletions examples/lsc/mag240m/gnn.py
@@ -48,7 +48,6 @@ def num_classes(self) -> int:

     def metadata(self) -> Tuple[List[NodeType], List[EdgeType]]:
         node_types = ['paper', 'author', 'institution']
-        # node_types = ['paper']
         edge_types = [
             ('author', 'affiliated_with', 'institution'),
             ('institution', 'rev_affiliated_with', 'author'),
@@ -67,57 +66,18 @@ def __init__(self, model: str, in_channels: int, out_channels: int,
         self.dropout = dropout
         self.num_layers = num_layers

-        # self.convs = ModuleList()
-        # self.norms = ModuleList()
-        # self.skips = ModuleList()

-        # if self.model == 'gat':
-        #     self.convs.append(
-        #         GATConv(in_channels, hidden_channels // heads, heads))
-        #     self.skips.append(Linear(in_channels, hidden_channels))
-        #     for _ in range(num_layers - 1):
-        #         self.convs.append(
-        #             GATConv(hidden_channels, hidden_channels // heads, heads))
-        #         self.skips.append(Linear(hidden_channels, hidden_channels))

-        # elif self.model == 'graphsage':
-        #     self.convs.append(SAGEConv(in_channels, hidden_channels))
-        #     for _ in range(num_layers - 1):
-        #         self.convs.append(SAGEConv(hidden_channels, hidden_channels))

-        # for _ in range(num_layers):
-        #     self.norms.append(BatchNorm1d(hidden_channels))

-        # self.mlp = Sequential(
-        #     Linear(hidden_channels, hidden_channels),
-        #     BatchNorm1d(hidden_channels),
-        #     ReLU(inplace=True),
-        #     Dropout(p=self.dropout),
-        #     Linear(hidden_channels, out_channels),
-        # )

         self.conv1 = SAGEConv(in_channels, hidden_channels)
         self.conv2 = SAGEConv(hidden_channels, hidden_channels)
         self.lin = Linear(hidden_channels, out_channels)
-        # self.relu = ReLU(inplace=True)

     def forward(self, x: Tensor, edge_index: Adj) -> Tensor:
-        # for i in range(self.num_layers):
-        #     x = self.convs[i](x, edge_index)
-        #     if self.model == 'gat':
-        #         x = x + self.skips[i](x)
-        #         x = F.elu(self.norms[i](x))
-        #     elif self.model == 'graphsage':
-        #         x = F.relu(self.norms[i](x))
-        #     x = F.dropout(x, p=self.dropout, training=self.training)
         x = x.to(torch.float)
-        x = self.conv1(x, edge_index)
-        x = F.relu(x, inplace=True)
-        x = self.conv2(x, edge_index)
-        x = F.relu(x, inplace=True)
+        x = self.conv1(x, edge_index).relu()
         x = F.dropout(x, p=self.dropout, training=self.training)
-        return x
-        # return self.mlp(x)
+        x = self.conv2(x, edge_index).relu()
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        return self.lin(x)

 class HeteroGNN(LightningModule):
     def __init__(self, model_name: str, metadata: Tuple[List[NodeType], List[EdgeType]], in_channels: int, out_channels: int,
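For orientation, here is a minimal, self-contained sketch of the two-layer GraphSAGE backbone that the hunk above reduces the model to. This is not part of the commit; the class name, hidden size, and default dropout are illustrative, only the layer layout and forward pass mirror the diff.

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import SAGEConv


class TwoLayerSAGE(torch.nn.Module):
    """Illustrative stand-in for the simplified GNN in the diff above."""
    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, dropout: float = 0.5):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, out_channels)
        self.dropout = dropout

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = x.to(torch.float)  # features may be stored on disk as float16
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)  # per-node logits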
@@ -138,19 +98,10 @@ def forward(
     ) -> Dict[NodeType, Tensor]:
         return self.model(x_dict, edge_index_dict)

-    # @torch.no_grad()
-    # def setup(self, stage: Optional[str] = None): # Initialize parameters.
-    #     data = self.trainer.datamodule
-    #     loader = data.dataloader(torch.arange(1), shuffle=False, num_workers=0)
-    #     batch = next(iter(loader))
-    #     self(batch.x_dict, batch.edge_index_dict)

     def common_step(self, batch: Batch) -> Tuple[Tensor, Tensor]:
         batch_size = batch['paper'].batch_size
-        print(batch.x_dict)
-        print(batch.edge_index_dict)
         y_hat = self(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]
-        y = batch['paper'].y[:batch_size]
+        y = batch['paper'].y[:batch_size].to(torch.long)
         return y_hat, y

     def training_step(self, batch: Batch, batch_idx: int) -> Tensor:
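A note on the .to(torch.long) cast above: PyTorch classification losses take class-index targets of dtype int64, while the paper labels here presumably arrive from NumPy as a floating-point array, so the cast is needed before the loss is computed. A minimal sketch of the failure mode it avoids (shapes and values are illustrative, not taken from the dataset):

import torch
import torch.nn.functional as F

y_hat = torch.randn(4, 153)                      # logits: 4 papers, 153 classes (illustrative)
y = torch.tensor([3.0, 7.0, 0.0, 12.0])          # labels that arrived as floats
loss = F.cross_entropy(y_hat, y.to(torch.long))  # class-index targets must be int64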
24 changes: 3 additions & 21 deletions ogb/lsc/mag240m.py
@@ -66,28 +66,19 @@ def to_pyg_hetero_data(self):

         data['author'].num_nodes = self.__meta__['author']
         path = osp.join(self.dir, 'processed', 'author', 'author.npy')
-        # data['author'].x = np.memmap(path, mode='w+', shape=(data['author'].num_nodes, self.num_paper_features))
         data['author'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['author'].num_nodes, self.num_paper_features))
         data['institution'].num_nodes = self.__meta__['institution']
         path = osp.join(self.dir, 'processed', 'institution', 'inst.npy')
-        # data['institution'].x = np.memmap(path, mode='w+', shape=(data['institution'].num_nodes, self.num_paper_features))
         data['institution'].x = np.memmap(path, mode='r', dtype="float16", shape=(data['institution'].num_nodes, self.num_paper_features))
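The two np.memmap context lines above read back feature matrices that were materialized earlier: mode='r' maps an existing file read-only, whereas the removed mode='w+' would create or overwrite it, and the dtype and shape must match what was written. A small sketch of the pattern (file name and sizes are illustrative):

import numpy as np

num_nodes, num_feats = 1000, 768  # illustrative sizes
# Writing: materialize a float16 feature matrix on disk.
feats = np.memmap('author_feat.bin', mode='w+', dtype='float16',
                  shape=(num_nodes, num_feats))
feats[:] = 0.5
feats.flush()
# Reading back: same dtype and shape, read-only mapping.
feats_ro = np.memmap('author_feat.bin', mode='r', dtype='float16',
                     shape=(num_nodes, num_feats))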

-        # data['author'].num_nodes = self.__meta__['author']
-        # data['institution'].num_nodes = self.__meta__['institution']
-        # path = osp.join(self.dir, 'processed', 'author', 'author.npy')
-        # data['author'].x = np.load(path, mmap_mode='r', encoding='bytes', allow_pickle=True)
-        # path = osp.join(self.dir, 'processed', 'institution', 'inst.npy')
-        # data['institution'].x = np.load(path, mmap_mode='r', allow_pickle=True)

-        print("node done")
         for edge_type in [('author', 'affiliated_with', 'institution'),
                           ('author', 'writes', 'paper'),
                           ('paper', 'cites', 'paper')]:
             name = '___'.join(edge_type)
             path = osp.join(self.dir, 'processed', name, 'edge_index.npy')
             edge_index = torch.from_numpy(np.load(path))
-            data[edge_type].edge_index = edge_index.flip([0])
+            data[edge_type].edge_index = edge_index
+            data[edge_type[2], f'rev_{edge_type[1]}', edge_type[0]].edge_index = edge_index.flip([0])

         for f, v in [('train', 'train'), ('valid', 'val'), ('test-dev', 'test')]:
             idx = self.get_idx_split(f)
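The loop above now stores each relation in its original direction and adds an explicit reverse relation named 'rev_<relation>' whose edge_index has its source and target rows swapped, matching the rev_* edge types declared in metadata() in gnn.py. A minimal sketch of the pattern on a toy heterogeneous graph, assuming torch_geometric's HeteroData (node counts and indices are illustrative):

import torch
from torch_geometric.data import HeteroData

data = HeteroData()
data['author'].num_nodes = 3
data['paper'].num_nodes = 8

edge_index = torch.tensor([[0, 1, 2],    # row 0: source node ids (authors)
                           [5, 6, 7]])   # row 1: target node ids (papers)
data['author', 'writes', 'paper'].edge_index = edge_index
# flip([0]) swaps the two rows, i.e. reverses the edge direction.
data['paper', 'rev_writes', 'author'].edge_index = edge_index.flip([0])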
@@ -152,15 +143,6 @@ def all_paper_year(self) -> np.ndarray:
         path = osp.join(self.dir, 'processed', 'paper', 'node_year.npy')
         return np.load(path)

-    # def edge_index(self, id1: str, id2: str,
-    #                id3: Optional[str] = None) -> np.ndarray:
-    #     src = id1
-    #     rel, dst = (id3, id2) if id3 is None else (id2, id3)
-    #     rel = self.__rels__[(src, dst)] if rel is None else rel
-    #     name = f'{src}___{rel}___{dst}'
-    #     path = osp.join(self.dir, 'processed', name, 'edge_index.npy')
-    #     return np.load(path)

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}()'

@@ -207,7 +189,7 @@ def save_test_submission(self, input_dict: Dict, dir_path: str, mode: str):


 if __name__ == '__main__':
-    dataset = MAG240MDataset('/home/user/yanbing/pyg/ogb/ogb/lsc/dataset')
+    dataset = MAG240MDataset()
     data = dataset.to_pyg_hetero_data()
     print(dataset)
     print(dataset.num_papers)