Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split for OGBDatasets #172

Merged
merged 7 commits into from
Oct 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 129 additions & 52 deletions src/datasets/graphs/ogbdataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,11 @@ function read_ogb_metadata(fullname, dir = nothing)
@assert fullname ∈ names(df)
metadata = Dict{String, Any}(String(r[1]) => parse_pystring(r[2]) for r in eachrow(df[!,[names(df)[1], fullname]]))
# edge cases for additional node and edge files
for additional_keys in ["additional edge files", "additional node files"]
if !isnothing(metadata[additional_keys])
metadata[additional_keys] = Vector{String}(split(metadata[additional_keys], ","))
for additional_key in ["additional edge files", "additional node files"]
if !isnothing(metadata[additional_key])
metadata[additional_key] = Vector{String}(split(metadata[additional_key], ","))
else
metadata[additional_key] = Vector{String}()
end
end
if prefix == "ogbn"
Expand Down Expand Up @@ -300,60 +302,70 @@ function read_ogb_graph(path, metadata)
length(dlabels) == 1 ? first(dlabels)[2] : dlabels

splits = readdir(joinpath(path, "split"))
@assert length(splits) == 1 # TODO check if datasets with multiple splits exist in OGB

# TODO sometimes splits are given in .pt format
# Use read_pytorch in src/io.jl to load them.
split_idx = (train = read_ogb_file(joinpath(path, "split", splits[1], "train.csv"), Int; tovec=true),
val = read_ogb_file(joinpath(path, "split", splits[1], "valid.csv"), Int; tovec=true),
test = read_ogb_file(joinpath(path, "split", splits[1], "test.csv"), Int; tovec=true))

if split_idx.train !== nothing
split_idx.train .+= 1
end
if split_idx.val !== nothing
split_idx.val .+= 1
end
if split_idx.test !== nothing
split_idx.test .+= 1
end

@assert length(splits) == 1 "Current implementation supports only 1 split"

graph_data = nothing
if metadata["task level"] == "node"
if metadata["task level"] in ["node", "graph"]
split_idx = (train = read_ogb_file(joinpath(path, "split", splits[1], "train.csv"), Int; tovec=true),
val = read_ogb_file(joinpath(path, "split", splits[1], "valid.csv"), Int; tovec=true),
test = read_ogb_file(joinpath(path, "split", splits[1], "test.csv"), Int; tovec=true))

if metadata["task level"] == "node"
# During the time of writing this piece of code,
# node level OGBDataset had only 1 graph.
# We need to implement splits for multiple graphs if that changes
@assert length(graphs) == 1
g = graphs[1]

# TODO: Implement for more than one split
for key in keys(split_idx)
if !isnothing(split_idx[key])
g["node_$(key)_mask"] = indexes2mask(split_idx[key] .+ 1, g["num_nodes"])
end
end
else
split_mask = Dict()
for key in keys(split_idx)
if !isnothing(split_idx[key])
split_mask[Symbol("$(key)_mask")] = indexes2mask(split_idx[key] .+ 1, num_graphs)
end
end
graph_data = clean_nt((; labels=maybesqueeze(labels), split_mask...))
end
elseif metadata["task level"] == "link"
# During the time of writing this piece of code,
# link level OGBDataset had only 1 graph.
# We need to implement splits for multiple graphs if that changes
@assert length(graphs) == 1
g = graphs[1]
if split_idx.train !== nothing
g["node_train_mask"] = indexes2mask(split_idx.train, g["num_nodes"])
end
if split_idx.val !== nothing
g["node_val_mask"] = indexes2mask(split_idx.val, g["num_nodes"])
end
if split_idx.test !== nothing
g["node_test_mask"] = indexes2mask(split_idx.test, g["num_nodes"])
end

end
if metadata["task level"] == "link"
@assert length(graphs) == 1
g = graphs[1]
if split_idx.train !== nothing
g["edge_train_mask"] = indexes2mask(split_idx.train, g["num_edges"])
end
if split_idx.val !== nothing
g["edge_val_mask"] = indexes2mask(split_idx.val, g["num_edges"])
end
if split_idx.test !== nothing
g["edge_test_mask"] = indexes2mask(split_idx.test, g["num_edges"])
split_dict = (train = read_pytorch(joinpath(path, "split", splits[1], "train.pt")),
val = read_pytorch(joinpath(path, "split", splits[1], "valid.pt")),
test = read_pytorch(joinpath(path, "split", splits[1], "test.pt")))

for key in keys(split_dict)
if !isnothing(split_dict[key])
for k in keys(split_dict[key])
if k in ["edge", "edge_neg"]
split_dict[key][k] .+= 1
ei = split_dict[key][k]
s, t = ei[:,1], ei[:,2]
if metadata["add_inverse_edge"]
split_dict[key][k] = ([s;t], [t;s])
else
split_dict[key][k] = (s, t)
end
else
if metadata["add_inverse_edge"]
v = split_dict[key][k]
split_dict[key][k] = [v; v]
end
end
end
g["edge_$(key)_dict"] = split_dict[key]
end
end
end
if metadata["task level"] == "graph"
train_mask = split_idx.train !== nothing ? indexes2mask(split_idx.train, num_graphs) : nothing
val_mask = split_idx.val !== nothing ? indexes2mask(split_idx.val, num_graphs) : nothing
test_mask = split_idx.test !== nothing ? indexes2mask(split_idx.test, num_graphs) : nothing

graph_data = clean_nt((; labels=maybesqueeze(labels), train_mask, val_mask, test_mask))
end
return graphs, graph_data
end

Expand Down Expand Up @@ -495,7 +507,72 @@ function read_ogb_hetero_graph(path, metadata)
end

push!(graphs, graph)
end
end

dlabels = Dict{String, Any}()
for k in keys(dict)
if contains(k, "label")
if k ∉ [node_keys; edge_keys]
dlabels[k] = dict[k]
end
end
end
labels = isempty(dlabels) ? nothing :
length(dlabels) == 1 ? first(dlabels)[2] : dlabels

# Similar to OGB Graphs
# Also see split implementation for normal ogb graphs
# for any possible issues
splits = readdir(joinpath(path, "split"))
@assert length(splits) == 1 "Current implementation supports only 1 split"

graph_data = nothing
split_dir = joinpath(path, "split", splits[1])
if metadata["task level"] == "node"
@assert length(graphs) == 1
g = graphs[1]
split_idx_dict = Dict{String, Dict{String, Vector{Int}}}()
split_idx_dict["train"] = Dict()
split_idx_dict["test"] = Dict()
split_idx_dict["val"] = Dict()

nodetype_has_label_df = read_csv_asdf(joinpath(path, "raw", "nodetype-has-label.csv"))
nodetype_has_label = Dict(String(node) => num[1] for (node, num) in pairs(eachcol(nodetype_has_label_df)))
for (node_type, has_label) in nodetype_has_label
@assert node_type ∈ node_types
if has_label
split_idx_dict["train"][node_type] = read_ogb_file(joinpath(split_dir, node_type, "train.csv"), Int; tovec=true)
split_idx_dict["test"][node_type] = read_ogb_file(joinpath(split_dir, node_type, "test.csv"), Int; tovec=true)
split_idx_dict["val"][node_type] = read_ogb_file(joinpath(split_dir, node_type, "valid.csv"), Int; tovec=true)
end
end

for key in keys(split_idx_dict)
g["node_$(key)_mask"] = Dict()
for node_type in keys(split_idx_dict[key])
num_nodes = dict["num_nodes"][node_type][1]
g["node_$(key)_mask"][node_type] = indexes2mask(split_idx_dict[key][node_type] .+ 1, num_nodes)
end
end
elseif metadata["task level"] == "graph"
split_mask = Dict()
for key in keys(split_idx)
if !isnothing(split_idx[key])
split_mask[Symbol("$(key)_mask")] = indexes2mask(split_idx[key] .+ 1, num_graphs)
end
end
graph_data = clean_nt((; labels=maybesqueeze(labels), split_mask...))
elseif metadata["task level"] == "link"
@assert length(graphs) == 1
@warn "Link split for HeteroData has not been implemented yet."

# g = graphs[1]

# split_dict = (train = read_pytorch(joinpath(split_dir, "train.pt")),
# val = read_pytorch(joinpath(split_dir, "valid.pt")),
# test = read_pytorch(joinpath(split_dir, "test.pt")))

end
return graphs, graph_data
end

Expand Down Expand Up @@ -530,7 +607,7 @@ function ogbdict2heterograph(d::Dict)
edge_indices = Dict(triplet => (ei[:, 1], ei[:, 2])
for (triplet, ei) in d["edge_indices"])

edge_data = Dict(k => Dict{Symbol, Any}() for k in keys(edge_indices))
edge_data = Dict{Tuple{String, String, String}, Dict}(k => Dict{Symbol, Any}() for k in keys(edge_indices))
for (feature_name, v) in d
# v is a dict
# the number of nothing values should not be equal to total number of values
Expand Down
40 changes: 39 additions & 1 deletion test/datasets/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ Sys.iswindows() || @testset "OGBn-mag" begin
@test g.num_nodes[type] == num_nodes[type]
node_data = get(g.node_data, type, nothing)
isnothing(node_data) || for (key, val) in node_data
@test key ∈ [:year, :features, :label]
@test key ∈ [:year, :features, :label, :train_mask, :test_mask, :val_mask]
if key == :features
@test size(val)[1] == 128
end
Expand All @@ -172,3 +172,41 @@ Sys.iswindows() || @testset "OGBn-mag" begin
end
end

# Smoke test for the `ogbl-ddi` dataset: verifies the graph container type,
# the node/edge counts, that the edge index uses 1-based node ids, and that
# every split dictionary stored in `g.edge_data` holds valid (source, target)
# index vectors for its "edge" entry and, when present, its "edge_neg" entry.
@testset "OGBl-ddi" begin
    data = OGBDataset("ogbl-ddi")
    # @test data isa AbstractDataset
    @test length(data) == 1

    g = data[1]
    @test g == data[:]
    @test g isa MLDatasets.Graph

    @test g.num_nodes == 4267
    @test g.num_edges == 2135822
    @test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}
    s, t = g.edge_index
    for a in (s, t)
        @test a isa Vector{Int}
        @test length(a) == g.num_edges
        @test minimum(a) == 1
        @test maximum(a) == g.num_nodes
    end

    # Shared check for a (source, target) pair taken from a split dictionary:
    # both vectors must contain valid 1-based node ids.
    check_endpoints = function (endpoints)
        src, dst = endpoints
        for a in (src, dst)
            @test a isa Vector{Int}
            @test minimum(a) >= 1
            @test maximum(a) <= g.num_nodes
        end
    end

    for key in keys(g.edge_data)
        # `@test` rather than `@assert`: an unexpected key must be recorded as
        # a test failure, and `@assert` may be disabled at some optimisation
        # levels, which would silently drop the check.
        @test key ∈ [:train_dict, :val_dict, :test_dict]
        split = g.edge_data[key]
        # "edge" is accessed unconditionally; a missing entry surfaces as an error.
        check_endpoints(split["edge"])
        # "edge_neg" is only checked for the splits that provide it.
        if haskey(split, "edge_neg")
            check_endpoints(split["edge_neg"])
        end
    end
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using ImageShow
using ColorTypes
using FixedPointNumbers
using JSON3
using Pickle

ENV["DATADEPS_ALWAYS_ACCEPT"] = true

Expand Down