From 0d91bfe2c3b96ad6b0daba51da0bdc68db230535 Mon Sep 17 00:00:00 2001 From: Deeptendu Santra <55111154+Dsantra92@users.noreply.github.com> Date: Tue, 4 Oct 2022 04:50:03 -0400 Subject: [PATCH] Split for OGBDatasets (#172) * Tensor split for Simple Graphs * Split for HeteroGraphs * Fix the node dataset * Pass the tests * Keep Hetero API as it is * Add backlinks for directed graphs + structerize edge syntax * Add tests --- src/datasets/graphs/ogbdataset.jl | 181 +++++++++++++++++++++--------- test/datasets/graphs.jl | 40 ++++++- test/runtests.jl | 1 + 3 files changed, 169 insertions(+), 53 deletions(-) diff --git a/src/datasets/graphs/ogbdataset.jl b/src/datasets/graphs/ogbdataset.jl index 290d1e8e..56a8f125 100644 --- a/src/datasets/graphs/ogbdataset.jl +++ b/src/datasets/graphs/ogbdataset.jl @@ -159,9 +159,11 @@ function read_ogb_metadata(fullname, dir = nothing) @assert fullname ∈ names(df) metadata = Dict{String, Any}(String(r[1]) => parse_pystring(r[2]) for r in eachrow(df[!,[names(df)[1], fullname]])) # edge cases for additional node and edge files - for additional_keys in ["additional edge files", "additional node files"] - if !isnothing(metadata[additional_keys]) - metadata[additional_keys] = Vector{String}(split(metadata[additional_keys], ",")) + for additional_key in ["additional edge files", "additional node files"] + if !isnothing(metadata[additional_key]) + metadata[additional_key] = Vector{String}(split(metadata[additional_key], ",")) + else + metadata[additional_key] = Vector{String}() end end if prefix == "ogbn" @@ -300,60 +302,70 @@ function read_ogb_graph(path, metadata) length(dlabels) == 1 ? first(dlabels)[2] : dlabels splits = readdir(joinpath(path, "split")) - @assert length(splits) == 1 # TODO check if datasets with multiple splits existin in OGB - - # TODO sometimes splits are given in .pt format - # Use read_pytorch in src/io.jl to load them. - split_idx = (train = read_ogb_file(joinpath(path, "split", splits[1], "train.csv"), Int; tovec=true), - val = read_ogb_file(joinpath(path, "split", splits[1], "valid.csv"), Int; tovec=true), - test = read_ogb_file(joinpath(path, "split", splits[1], "test.csv"), Int; tovec=true)) - - if split_idx.train !== nothing - split_idx.train .+= 1 - end - if split_idx.val !== nothing - split_idx.val .+= 1 - end - if split_idx.test !== nothing - split_idx.test .+= 1 - end - + @assert length(splits) == 1 "Current implementation supports only 1 split" graph_data = nothing - if metadata["task level"] == "node" + if metadata["task level"] in ["node", "graph"] + split_idx = (train = read_ogb_file(joinpath(path, "split", splits[1], "train.csv"), Int; tovec=true), + val = read_ogb_file(joinpath(path, "split", splits[1], "valid.csv"), Int; tovec=true), + test = read_ogb_file(joinpath(path, "split", splits[1], "test.csv"), Int; tovec=true)) + + if metadata["task level"] == "node" + # During the time of writing this piece of code, + # node level OGBDataset had only 1 graph. + # We need to implement splits for multiple graphs if that changes + @assert length(graphs) == 1 + g = graphs[1] + + # TODO: Implement for more than one split + for key in keys(split_idx) + if !isnothing(split_idx[key]) + g["node_$(key)_mask"] = indexes2mask(split_idx[key] .+ 1, g["num_nodes"]) + end + end + else + split_mask = Dict() + for key in keys(split_idx) + if !isnothing(split_idx[key]) + split_mask[Symbol("$(key)_mask")] = indexes2mask(split_idx[key] .+ 1, num_graphs) + end + end + graph_data = clean_nt((; labels=maybesqueeze(labels), split_mask...)) + end + elseif metadata["task level"] == "link" + # During the time of writing this piece of code, + # link level OGBDataset had only 1 graph. + # We need to implement splits for multiple graphs if that changes @assert length(graphs) == 1 g = graphs[1] - if split_idx.train !== nothing - g["node_train_mask"] = indexes2mask(split_idx.train, g["num_nodes"]) - end - if split_idx.val !== nothing - g["node_val_mask"] = indexes2mask(split_idx.val, g["num_nodes"]) - end - if split_idx.test !== nothing - g["node_test_mask"] = indexes2mask(split_idx.test, g["num_nodes"]) - end - end - if metadata["task level"] == "link" - @assert length(graphs) == 1 - g = graphs[1] - if split_idx.train !== nothing - g["edge_train_mask"] = indexes2mask(split_idx.train, g["num_edges"]) - end - if split_idx.val !== nothing - g["edge_val_mask"] = indexes2mask(split_idx.val, g["num_edges"]) - end - if split_idx.test !== nothing - g["edge_test_mask"] = indexes2mask(split_idx.test, g["num_edges"]) + split_dict = (train = read_pytorch(joinpath(path, "split", splits[1], "train.pt")), + val = read_pytorch(joinpath(path, "split", splits[1], "valid.pt")), + test = read_pytorch(joinpath(path, "split", splits[1], "test.pt"))) + + for key in keys(split_dict) + if !isnothing(split_dict[key]) + for k in keys(split_dict[key]) + if k in ["edge", "edge_neg"] + split_dict[key][k] .+= 1 + ei = split_dict[key][k] + s, t = ei[:,1], ei[:,2] + if metadata["add_inverse_edge"] + split_dict[key][k] = ([s;t], [t;s]) + else + split_dict[key][k] = (s, t) + end + else + if metadata["add_inverse_edge"] + v = split_dict[key][k] + split_dict[key][k] = [v; v] + end + end + end + g["edge_$(key)_dict"] = split_dict[key] + end end end - if metadata["task level"] == "graph" - train_mask = split_idx.train !== nothing ? indexes2mask(split_idx.train, num_graphs) : nothing - val_mask = split_idx.val !== nothing ? indexes2mask(split_idx.val, num_graphs) : nothing - test_mask = split_idx.test !== nothing ? indexes2mask(split_idx.test, num_graphs) : nothing - - graph_data = clean_nt((; labels=maybesqueeze(labels), train_mask, val_mask, test_mask)) - end return graphs, graph_data end @@ -495,7 +507,72 @@ function read_ogb_hetero_graph(path, metadata) end push!(graphs, graph) - end + end + + dlabels = Dict{String, Any}() + for k in keys(dict) + if contains(k, "label") + if k ∉ [node_keys; edge_keys] + dlabels[k] = dict[k] + end + end + end + labels = isempty(dlabels) ? nothing : + length(dlabels) == 1 ? first(dlabels)[2] : dlabels + + # Similar to OGB Graphs + # Also see split implementation for normal ogb graphs + # for any possible issues + splits = readdir(joinpath(path, "split")) + @assert length(splits) == 1 "Current implementation supports only 1 split" + + graph_data = nothing + split_dir = joinpath(path, "split", splits[1]) + if metadata["task level"] == "node" + @assert length(graphs) == 1 + g = graphs[1] + split_idx_dict = Dict{String, Dict{String, Vector{Int}}}() + split_idx_dict["train"] = Dict() + split_idx_dict["test"] = Dict() + split_idx_dict["val"] = Dict() + + nodetype_has_label_df = read_csv_asdf(joinpath(path, "raw", "nodetype-has-label.csv")) + nodetype_has_label = Dict(String(node) => num[1] for (node, num) in pairs(eachcol(nodetype_has_label_df))) + for (node_type, has_label) in nodetype_has_label + @assert node_type ∈ node_types + if has_label + split_idx_dict["train"][node_type] = read_ogb_file(joinpath(split_dir, node_type, "train.csv"), Int; tovec=true) + split_idx_dict["test"][node_type] = read_ogb_file(joinpath(split_dir, node_type, "test.csv"), Int; tovec=true) + split_idx_dict["val"][node_type] = read_ogb_file(joinpath(split_dir, node_type, "valid.csv"), Int; tovec=true) + end + end + + for key in keys(split_idx_dict) + g["node_$(key)_mask"] = Dict() + for node_type in keys(split_idx_dict[key]) + num_nodes = dict["num_nodes"][node_type][1] + g["node_$(key)_mask"][node_type] = indexes2mask(split_idx_dict[key][node_type] .+ 1, num_nodes) + end + end + elseif metadata["task level"] == "graph" + split_mask = Dict() + for key in keys(split_idx) + if !isnothing(split_idx[key]) + split_mask[Symbol("$(key)_mask")] = indexes2mask(split_idx[key] .+ 1, num_graphs) + end + end + graph_data = clean_nt((; labels=maybesqueeze(labels), split_mask...)) + elseif metadata["task level"] == "link" + @assert length(graphs) == 1 + @warn "Link split for HeteroData has not been implemented yet." + + # g = graphs[1] + + # split_dict = (train = read_pytorch(joinpath(split_dir, "train.pt")), + # val = read_pytorch(joinpath(split_dir, "valid.pt")), + # test = read_pytorch(joinpath(split_dir, "test.pt"))) + + end return graphs, graph_data end @@ -530,7 +607,7 @@ function ogbdict2heterograph(d::Dict) edge_indices = Dict(triplet => (ei[:, 1], ei[:, 2]) for (triplet, ei) in d["edge_indices"]) - edge_data = Dict(k => Dict{Symbol, Any}() for k in keys(edge_indices)) + edge_data = Dict{Tuple{String, String, String}, Dict}(k => Dict{Symbol, Any}() for k in keys(edge_indices)) for (feature_name, v) in d # v is a dict # the number of nothing values should not be equal to total number of values diff --git a/test/datasets/graphs.jl b/test/datasets/graphs.jl index 85b3a6a3..dcd84440 100644 --- a/test/datasets/graphs.jl +++ b/test/datasets/graphs.jl @@ -150,7 +150,7 @@ Sys.iswindows() || @testset "OGBn-mag" begin @test g.num_nodes[type] == num_nodes[type] node_data = get(g.node_data, type, nothing) isnothing(node_data) || for (key, val) in node_data - @test key ∈ [:year, :features, :label] + @test key ∈ [:year, :features, :label, :train_mask, :test_mask, :val_mask] if key == :features @test size(val)[1] == 128 end @@ -172,3 +172,41 @@ Sys.iswindows() || @testset "OGBn-mag" begin end end +@testset "OGBl-ddi" begin + data = OGBDataset("ogbl-ddi") + # @test data isa AbstractDataset + @test length(data) == 1 + + g = data[1] + @test g == data[:] + @test g isa MLDatasets.Graph + + @test g.num_nodes == 4267 + @test g.num_edges == 2135822 + @test g.edge_index isa Tuple{Vector{Int}, Vector{Int}} + s, t = g.edge_index + for a in (s, t) + @test a isa Vector{Int} + @test length(a) == g.num_edges + @test minimum(a) == 1 + @test maximum(a) == g.num_nodes + end + for key in keys(g.edge_data) + @assert key ∈ [:train_dict, :val_dict, :test_dict] + s, t = g.edge_data[key]["edge"] + for a in (s, t) + @test a isa Vector{Int} + @test minimum(a) >= 1 + @test maximum(a) <= g.num_nodes + end + if haskey(g.edge_data[key], "edge_neg") + s, t = g.edge_data[key]["edge_neg"] + for a in (s, t) + @test a isa Vector{Int} + @test minimum(a) >= 1 + @test maximum(a) <= g.num_nodes + end + end + end +end + diff --git a/test/runtests.jl b/test/runtests.jl index 30385eb4..396e0c44 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using ImageShow using ColorTypes using FixedPointNumbers using JSON3 +using Pickle ENV["DATADEPS_ALWAYS_ACCEPT"] = true