diff --git a/src/MLDatasets.jl b/src/MLDatasets.jl
index 0fec184c..ecbd4416 100644
--- a/src/MLDatasets.jl
+++ b/src/MLDatasets.jl
@@ -9,6 +9,7 @@ using MLUtils: getobs, numobs, AbstractDataContainer
using Glob
using DelimitedFiles: readdlm
using FileIO
+import CSV
using LazyModules: @lazy

include("require.jl") # export @require
@@ -23,9 +24,8 @@ include("require.jl") # export @require
@require import DataFrames="a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@require import ImageShow="4e3cecfd-b093-5904-9786-8bbb286a6a31"
# @lazy import NPZ # lazy imported by FileIO
-@lazy import Pickle="fbb45041-c46e-462f-888f-7c521cafbc2c"
+@require import Pickle="fbb45041-c46e-462f-888f-7c521cafbc2c"
@lazy import MAT="23992714-dd62-5051-b70f-ba57cb901cac"
-import CSV
@lazy import HDF5="f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
# @lazy import JLD2
diff --git a/src/abstract_datasets.jl b/src/abstract_datasets.jl
index 5b4e07af..b3602fad 100644
--- a/src/abstract_datasets.jl
+++ b/src/abstract_datasets.jl
@@ -4,7 +4,7 @@
Super-type from which all datasets in MLDatasets.jl inherit.

Implements the following functionality:
-- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]` 
+- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]`
- Pretty printing.
"""
abstract type AbstractDataset <: AbstractDataContainer end
@@ -19,9 +19,9 @@ end

function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset
    recur_io = IOContext(io, :compact => false)
-    
+
    print(io, "dataset $(D.name.name):")  # if the type is parameterized don't print the parameters
-    
+
    for f in fieldnames(D)
        if !startswith(string(f), "_")
            fstring = leftalign(string(f), 10)
@@ -34,7 +34,7 @@ function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset
end

function leftalign(s::AbstractString, n::Int)
-    m = length(s) 
+    m = length(s)
    if m > n
        return s[1:n]
    else
@@ -53,37 +53,44 @@ _summary(x::BitVector) = "$(count(x))-trues BitVector"
"""
    SupervisedDataset <: AbstractDataset

-An abstract dataset type for supervised learning tasks. 
+An abstract dataset type for supervised learning tasks.
Concrete dataset types inheriting from it must provide
`features` and `targets` fields.
"""
abstract type SupervisedDataset <: AbstractDataset end

-Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) : 
+Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) :
                                    numobs((d.features, d.targets))
-

# We return named tuples
Base.getindex(d::SupervisedDataset, ::Colon) = Tables.istable(d.features) ?
                (features = d.features, targets=d.targets) :
                getobs((; d.features, d.targets))

-Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ? 
+Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ?
                (features = getobs_table(d.features, i), targets=getobs_table(d.targets, i)) :
                getobs((; d.features, d.targets), i)

"""
    UnsupervisedDataset <: AbstractDataset

-An abstract dataset type for unsupervised or self-supervised learning tasks. 
+An abstract dataset type for unsupervised or self-supervised learning tasks.

Concrete dataset types inheriting from it must provide a `features` field.
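+
+A minimal indexing sketch (a docs-only addition; `SomeUnsupervisedDataset` is a
+hypothetical concrete subtype, and the calls rely on the `length`/`getindex`
+definitions just below):
+
+    d = SomeUnsupervisedDataset()
+    x_all = d[:]      # all observations, via getobs(d.features)
+    x1 = d[1]         # first observation, via getobs(d.features, 1)
+    n = length(d)     # number of observations, via numobs(d.features)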
""" abstract type UnsupervisedDataset <: AbstractDataset end Base.length(d::UnsupervisedDataset) = numobs(d.features) - Base.getindex(d::UnsupervisedDataset, ::Colon) = getobs(d.features) Base.getindex(d::UnsupervisedDataset, i) = getobs(d.features, i) @@ -99,13 +97,13 @@ const ARGUMENTS_SUPERVISED_TABLE = """ const FIELDS_SUPERVISED_TABLE = """ - `metadata`: A dictionary containing additional information on the dataset. -- `features`: The data features. An array if `as_df=true`, otherwise a dataframe. +- `features`: The data features. An array if `as_df=true`, otherwise a dataframe. - `targets`: The targets for supervised learning. An array if `as_df=true`, otherwise a dataframe. - `dataframe`: A dataframe containing both `features` and `targets`. It is `nothing` if `as_df=false`. """ const METHODS_SUPERVISED_TABLE = """ -- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. +- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. - `dataset[:]`: Return all observations as a named tuple of features and targets. - `length(dataset)`: Number of observations. """ @@ -119,12 +117,23 @@ const ARGUMENTS_SUPERVISED_ARRAY = """ const FIELDS_SUPERVISED_ARRAY = """ - `metadata`: A dictionary containing additional information on the dataset. -- `features`: An array storing the data features. +- `features`: An array storing the data features. - `targets`: An array storing the targets for supervised learning. """ const METHODS_SUPERVISED_ARRAY = """ -- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. +- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. - `dataset[:]`: Return all observations as a named tuple of features and targets. - `length(dataset)`: Number of observations. """ + +""" + GraphDataset <: AbstractDataset + +An abstract dataset type for graph learning tasks. +""" +abstract type GraphDataset <: AbstractDataset end + +Base.length(data::GraphDataset) = length(data.graphs) +Base.getindex(data::GraphDataset, ::Colon) = length(data) == 1 ? data.graphs[1] : data.graphs +Base.getindex(data::GraphDataset, i) = data.graphs[i] diff --git a/src/datasets/graphs/citeseer.jl b/src/datasets/graphs/citeseer.jl index 057e4b46..cfb5c123 100644 --- a/src/datasets/graphs/citeseer.jl +++ b/src/datasets/graphs/citeseer.jl @@ -21,17 +21,17 @@ end The CiteSeer citation network dataset from Ref. [1]. Nodes represent documents and edges represent citation links. -The dataset is designed for the node classification task. +The dataset is designed for the node classification task. The task is to predict the category of certain paper. The dataset is retrieved from Ref. [2]. 
# References

-[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815) 
-
+[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)
+
[2]: [Planetoid](https://github.com/kimiyoung/planetoid)
"""
-struct CiteSeer <: AbstractDataset
+struct CiteSeer <: GraphDataset
    metadata::Dict{String, Any}
    graphs::Vector{Graph}
end
@@ -41,10 +41,6 @@ function CiteSeer(; dir=nothing, reverse_edges=true)
    return CiteSeer(metadata, [g])
end

-Base.length(d::CiteSeer) = length(d.graphs)
-Base.getindex(d::CiteSeer, ::Colon) = d.graphs[1]
-Base.getindex(d::CiteSeer, i) = d.graphs[i]
-
# DEPRECATED in v0.6.0

function Base.getproperty(::Type{CiteSeer}, s::Symbol)
@@ -67,5 +63,3 @@ function Base.getproperty(::Type{CiteSeer}, s::Symbol)
        return getfield(CiteSeer, s)
    end
end
-
-
diff --git a/src/datasets/graphs/cora.jl b/src/datasets/graphs/cora.jl
index a1cf33cc..878ea7b7 100644
--- a/src/datasets/graphs/cora.jl
+++ b/src/datasets/graphs/cora.jl
@@ -18,18 +18,17 @@ function __init__cora()
    ))
end

-
"""
    Cora()

The Cora citation network dataset from Ref. [1].
Nodes represent documents and edges represent citation links.
-Each node has a predefined feature with 1433 dimensions. 
-The dataset is designed for the node classification task. 
+Each node has a predefined feature with 1433 dimensions.
+The dataset is designed for the node classification task.
The task is to predict the category of a given paper.
The dataset is retrieved from Ref. [2].

-# Statistics 
+# Statistics

- Nodes: 2708
- Edges: 10556
@@ -39,17 +38,16 @@ The dataset is retrieved from Ref. [2].
  - Val: 500
  - Test: 1000

-The split is the one used in the original paper [1] and 
+The split is the one used in the original paper [1] and
doesn't consider all nodes.

-
# References

[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)

-[2]: [Planetoid](https://github.com/kimiyoung/planetoid
+[2]: [Planetoid](https://github.com/kimiyoung/planetoid)
"""
-struct Cora <: AbstractDataset
+struct Cora <: GraphDataset
    metadata::Dict{String, Any}
    graphs::Vector{Graph}
end
@@ -59,10 +57,6 @@ function Cora(; dir=nothing, reverse_edges=true)
    return Cora(metadata, [g])
end

-Base.length(d::Cora) = length(d.graphs)
-Base.getindex(d::Cora, ::Colon) = d.graphs[1]
-Base.getindex(d::Cora, i) = getindex(d.graphs, i)
-
# DEPRECATED in v0.6.0

function Base.getproperty(::Type{Cora}, s::Symbol)
diff --git a/src/datasets/graphs/karateclub.jl b/src/datasets/graphs/karateclub.jl
index 2c31487e..1aa77127 100644
--- a/src/datasets/graphs/karateclub.jl
+++ b/src/datasets/graphs/karateclub.jl
@@ -3,16 +3,16 @@ export KarateClub
"""
    KarateClub()

-The Zachary's karate club dataset originally appeared in Ref [1].
+Zachary's karate club dataset originally appeared in Ref. [1].

The network contains 34 nodes (members of the karate club).
The nodes are connected by 78 undirected and unweighted edges.
The edges indicate if the two members interacted outside the club.

The node labels indicate which community or club the member belongs to.
-The club based labels are as per the original dataset in Ref [1].
-The community labels are obtained by modularity-based clustering following Ref [2].
-The data is retrieved from Ref [3] and [4].
+The club-based labels are as per the original dataset in Ref. [1].
+The community labels are obtained by modularity-based clustering following Ref. [2].
+The data is retrieved from Ref. [3] and [4].

One node per unique label is used as training data.

# References
@@ -25,7 +25,7 @@ One node per unique label is used as training data.

[4]: [NetworkX Zachary's Karate Club Dataset](https://networkx.org/documentation/stable/_modules/networkx/generators/social.html#karate_club_graph)
"""
-struct KarateClub <: AbstractDataset
+struct KarateClub <: GraphDataset
    metadata::Dict{String, Any}
    graphs::Vector{Graph}
end
@@ -60,14 +60,10 @@ function KarateClub()
    labels_comm = [ 1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1,
                    0, 1, 0, 1, 0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0]
-    
-    node_data = (; labels_clubs, labels_comm) 
+
+    node_data = (; labels_clubs, labels_comm)
    g = Graph(; num_nodes=34, edge_index=(src, target), node_data)
    metadata = Dict{String, Any}()
    return KarateClub(metadata, [g])
end
-
-Base.length(d::KarateClub) = length(d.graphs)
-Base.getindex(d::KarateClub, ::Colon) = d.graphs[1]
-Base.getindex(d::KarateClub, i) = d.graphs[i]
diff --git a/src/datasets/graphs/movielens.jl b/src/datasets/graphs/movielens.jl
index 00684300..2adc3b84 100644
--- a/src/datasets/graphs/movielens.jl
+++ b/src/datasets/graphs/movielens.jl
@@ -14,21 +14,21 @@ end
"""
    MovieLens(name; dir=nothing)

-Datasets from the [MovieLens website](https://movielens.org) collected and maintained by [GroupLens](https://grouplens.org/datasets/movielens/). 
-The MovieLens datasets are presented in a Graph format. 
+Datasets from the [MovieLens website](https://movielens.org) collected and maintained by [GroupLens](https://grouplens.org/datasets/movielens/).
+The MovieLens datasets are presented in a Graph format.
For license and usage restrictions please refer to the Readme.md of the datasets.

-There are 6 versions of movielens datasets currently supported: "100k", "1m", "10m", "20m", "25m", "latest-small". 
+There are 6 versions of MovieLens datasets currently supported: "100k", "1m", "10m", "20m", "25m", "latest-small".
The 100k and 1m datasets contain movie data and rating data along with demographic data.
-Starting from the 10m dataset, Movielens datasets no longer contain the demographic data. 
-These datasets contain movie data, rating data, and tag information. 
+Starting from the 10m dataset, MovieLens datasets no longer contain demographic data.
+These datasets contain movie data, rating data, and tag information.

-The 20m and 25m datasets additionally contain [genome tag scores](http://files.grouplens.org/papers/tag_genome.pdf). 
+The 20m and 25m datasets additionally contain [genome tag scores](http://files.grouplens.org/papers/tag_genome.pdf).
Each movie in these datasets contains tag relevance scores for every tag.

-Each dataset contains an heterogeneous graph, with two kinds of nodes, 
-`movie` and `user`. The rating is represented by an edge between them: `(user, rating, movie)`. 
-20m, 25m, and latest-small datasets also contain `tag` nodes and edges of type `(user, tag, movie)` and 
+Each dataset contains a heterogeneous graph, with two kinds of nodes,
+`movie` and `user`. The rating is represented by an edge between them: `(user, rating, movie)`.
+The 20m, 25m, and latest-small datasets also contain `tag` nodes and edges of type `(user, tag, movie)` and
optionally `(movie, score, tag)`.
# Examples

julia> data = MovieLens("100k")
MovieLens 100k:
  metadata  =>  Dict{String, Any} with 2 entries
  graphs    =>  1-element Vector{MLDatasets.HeteroGraph}

julia> g = data[:]
Heterogeneous Graph:
  node_types  =>  2-element Vector{String}
  edge_types  =>  1-element Vector{Tuple{String, String, String}}
  num_nodes   =>  Dict{String, Int64} with 2 entries
  num_edges   =>  Dict{Tuple{String, String, String}, Int64} with 1 entry
  node_data   =>  Dict{String, Dict} with 2 entries
  edge_data   =>  Dict{Tuple{String, String, String}, Dict} with 1 entry

-# Acess the user information
+### Access the user information
julia> user_data = g.node_data["user"]
Dict{Symbol, AbstractVector} with 4 entries:
  :age        => [24, 53, 23, 24, 33, 42, 57, 36, 29, 53  …  61, 42, 24, 48, 38, 26, 32, 20, 48, 22]
  :occupation => ["technician", "other", "writer", "executive", "administrator", "executive", "administrator", "other", "student…
  :zipcode    => ["85711", "94043", "32067", "43537", "15213", "98101", "91344", "05201", "01002", "90703"  …  "22902", "66221", "3…
  :gender     => Bool[1, 0, 1, 1, 0, 1, 1, 1, 1, 1  …  1, 1, 1, 1, 0, 0, 1, 1, 0, 1]

-# Access rating information
+### Access rating information
julia> g.edge_data[("user", "rating", "movie")]
Dict{Symbol, Vector} with 2 entries:
  :timestamp => [881250949, 891717742, 878887116, 880606923, 886397596, 884182806, 881171488, 891628467, 886324817, 883603013  …  8…
@@ -79,7 +79,7 @@ MovieLens 20m:
  metadata  =>  Dict{String, Any} with 4 entries
  graphs    =>  1-element Vector{MLDatasets.HeteroGraph}

-# There is only 1 graph in MovieLens dataset
+### There is only one graph in the MovieLens dataset
julia> g = data[1]
Heterogeneous Graph:
  node_types  =>  3-element Vector{String}
@@ -90,20 +90,20 @@ Heterogeneous Graph:
  edge_types  =>  3-element Vector{Tuple{String, String, String}}
  num_nodes   =>  Dict{String, Int64} with 2 entries
  num_edges   =>  Dict{Tuple{String, String, String}, Int64} with 3 entries
  node_data   =>  Dict{String, Dict} with 0 entries
  edge_data   =>  Dict{Tuple{String, String, String}, Dict} with 3 entries

-# Apart from user rating a movie, a user assigns tag to movies and there are genome-scores for movie-tag pairs
+### Apart from users rating movies, users assign tags to movies, and there are genome scores for movie-tag pairs
julia> g.edge_indices
Dict{Tuple{String, String, String}, Tuple{Vector{Int64}, Vector{Int64}}} with 3 entries:
  ("movie", "score", "tag")   => ([1, 1, 1, 1, 1, 1, 1, 1, 1, 1  …  131170, 131170, 131170, 131170, 131170, 131170, 131170, 131170,…
  ("user", "tag", "movie")    => ([18, 65, 65, 65, 65, 65, 65, 65, 65, 65  …  3489, 7045, 7045, 7164, 7164, 55999, 55999, 55999, 55…
  ("user", "rating", "movie") => ([1, 1, 1, 1, 1, 1, 1, 1, 1, 1  …  60816, 61160, 65682, 66762, 68319, 68954, 69526, 69644, 70286, …

-# Access the rating
+### Access the rating
julia> g.edge_data[("user", "rating", "movie")]
Dict{Symbol, Vector} with 2 entries:
  :timestamp => [1112486027, 1112484676, 1112484819, 1112484727, 1112484580, 1094785740, 1094785734, 1112485573, 1112484940, 111248…
  :rating    => Float16[3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 4.0, 4.0, 4.0, 4.0  …  4.5, 4.0, 4.5, 4.5, 4.5, 4.5, 4.5, 3.0, 5.0, 2.5]

-# Access the movie-tag scores
+### Access the movie-tag scores
score = g.edge_data[("movie", "score", "tag")][:score]
23419536-element Vector{Float64}:
 0.025000000000000022
 0.03499999999999992
 0.0594999999999999
 ⋮
```

-## References
+# References

[1] [GroupLens Website](https://grouplens.org/datasets/movielens/)

-[2] [TensorFlow MovieLens Implementation](https://www.tensorflow.org/datasets/catalog/movielens) 
+[2] [TensorFlow MovieLens Implementation](https://www.tensorflow.org/datasets/catalog/movielens)

-[3] Jesse Vig, Shilad Sen, and John Riedl. 2012. The Tag Genome: Encoding Community Knowledge to Support Novel Interaction. ACM Trans. Interact. Intell. Syst. 2, 3, Article 13 (September 2012), 44 pages. https://doi.org/10.1145/2362394.2362395. 
+[3] Jesse Vig, Shilad Sen, and John Riedl. 2012. The Tag Genome: Encoding Community Knowledge to Support Novel Interaction. ACM Trans. Interact. Intell. Syst. 2, 3, Article 13 (September 2012), 44 pages. https://doi.org/10.1145/2362394.2362395.

-[4] F.
Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets: History and Context. ACM Trans. Interact. Intell. Syst. 5, 4, Article 19 (January 2016), 19 pages. https://doi.org/10.1145/2827872
+[4] F. Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets: History and Context. ACM Trans. Interact. Intell. Syst. 5, 4, Article 19 (January 2016), 19 pages. https://doi.org/10.1145/2827872
"""
-struct MovieLens
+struct MovieLens <: GraphDataset
    name::String
    metadata::Dict{String, Any}
    graphs::Vector{HeteroGraph}
@@ -526,7 +526,3 @@ function Base.show(io::IO, ::MIME"text/plain", d::MovieLens)
        end
    end
end
-
-Base.length(data::MovieLens) = length(data.graphs)
-Base.getindex(data::MovieLens, ::Colon) = length(data.graphs) == 1 ? data.graphs[1] : data.graphs
-Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)
diff --git a/src/datasets/graphs/ogbdataset.jl b/src/datasets/graphs/ogbdataset.jl
index 290d1e8e..1420b873 100644
--- a/src/datasets/graphs/ogbdataset.jl
+++ b/src/datasets/graphs/ogbdataset.jl
@@ -75,7 +75,7 @@ Dict{String, Any} with 17 entries:
  "is hetero"   => false
  "task level"  => "node"
  ⋮             => ⋮
-```
+
julia> data = OGBDataset("ogbn-mag")
OGBDataset ogbn-mag:
  metadata  =>  Dict{String, Any} with 17 entries
  graphs    =>  1-element Vector{MLDatasets.HeteroGraph}
@@ -89,6 +89,7 @@ Heterogeneous Graph:
  edge_indices  =>  Dict{Tuple{String, String, String}, Tuple{Vector{Int64}, Vector{Int64}}} with 4 entries
  node_data     =>  (year = "Dict{String, Vector{Float32}} with 1 entry", features = "Dict{String, Matrix{Float32}} with 1 entry", label = "Dict{String, Vector{Int64}} with 1 entry")
  edge_data     =>  (reltype = "Dict{Tuple{String, String, String}, Vector{Float32}} with 4 entries",)
+```

## Edge prediction task
@@ -120,8 +121,12 @@ OGBDataset ogbg-molhiv:

julia> data[1]
(graphs = Graph(19, 40), labels = 0)
```
+
+# References
+
+[1] [Open Graph Benchmark: Datasets for Machine Learning on Graphs](https://arxiv.org/pdf/2005.00687.pdf)
"""
-struct OGBDataset{GD} <: AbstractDataset
+struct OGBDataset{GD} <: GraphDataset
    name::String
    metadata::Dict{String, Any}
    graphs::Vector{<:AbstractGraph}
@@ -139,6 +144,14 @@ function OGBDataset(fullname; dir = nothing)
        graph_dicts, graph_data = read_ogb_graph(path, metadata)
        graphs = ogbdict2graph.(graph_dicts)
    end
+    split = read_ogb_split(path, graphs, metadata)
+    for key in keys(split)
+        if get(metadata, key, nothing) !== nothing
+            metadata[key] = merge(split[key], metadata[key])
+        else
+            metadata[key] = split[key]
+        end
+    end
    return OGBDataset(fullname, metadata, graphs, graph_data)
end
@@ -297,64 +310,63 @@ function read_ogb_graph(path, metadata)
        end
    end
    labels = isempty(dlabels) ? nothing :
-             length(dlabels) == 1 ? first(dlabels)[2] : dlabels 
+             length(dlabels) == 1 ? first(dlabels)[2] : dlabels
+    graph_data = nothing
+
+    return graphs, graph_data
+end
+
+function read_ogb_split(path::String, graphs, metadata::Dict)
    splits = readdir(joinpath(path, "split"))
    @assert length(splits) == 1 # TODO check if datasets with multiple splits exist in OGB
    # TODO sometimes splits are given in .pt format
    # Use read_pytorch in src/io.jl to load them.
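+    # Sketch of the returned structure (assumed from the code below and the
+    # updated tests): `split_dict` maps a task level ("node", "edge" or "graph")
+    # to (; split = ...) holding the train/val/test masks or indices, accessed
+    # e.g. as metadata["node"].split["time"][1].train or metadata["graph"].split["scaffold"].train.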
-    split_idx = (train = read_ogb_file(joinpath(path, "split", splits[1], "train.csv"), Int; tovec=true),
-                 val = read_ogb_file(joinpath(path, "split", splits[1], "valid.csv"), Int; tovec=true),
-                 test = read_ogb_file(joinpath(path, "split", splits[1], "test.csv"), Int; tovec=true))
-
-    if split_idx.train !== nothing
-        split_idx.train .+= 1
-    end
-    if split_idx.val !== nothing
-        split_idx.val .+= 1
-    end
-    if split_idx.test !== nothing
-        split_idx.test .+= 1
+    if metadata["task level"] in ["node", "graph"]
+        split_idx = (train = read_ogb_file(joinpath(path, "split", splits[1], "train.csv"), Int; tovec=true),
+                     val = read_ogb_file(joinpath(path, "split", splits[1], "valid.csv"), Int; tovec=true),
+                     test = read_ogb_file(joinpath(path, "split", splits[1], "test.csv"), Int; tovec=true))
+    else
+        split_idx = (train = read_pytorch(joinpath(path, "split", splits[1], "train.pt")),
+                     val = read_pytorch(joinpath(path, "split", splits[1], "valid.pt")),
+                     test = read_pytorch(joinpath(path, "split", splits[1], "test.pt")))
    end
+    split_dict = Dict()

-    graph_data = nothing
    if metadata["task level"] == "node"
        @assert length(graphs) == 1
-        g = graphs[1]
-        if split_idx.train !== nothing
-            g["node_train_mask"] = indexes2mask(split_idx.train, g["num_nodes"])
-        end
-        if split_idx.val !== nothing
-            g["node_val_mask"] = indexes2mask(split_idx.val, g["num_nodes"])
-        end
-        if split_idx.test !== nothing
-            g["node_test_mask"] = indexes2mask(split_idx.test, g["num_nodes"])
-        end
+        num_nodes = graphs[1].num_nodes
-    end
-    if metadata["task level"] == "link"
+        train_mask = split_idx.train === nothing ? nothing : indexes2mask(split_idx.train .+ 1, num_nodes)
+        val_mask = split_idx.val === nothing ? nothing : indexes2mask(split_idx.val .+ 1, num_nodes)
+        test_mask = split_idx.test === nothing ? nothing : indexes2mask(split_idx.test .+ 1, num_nodes)
+
+        split_dict["node"] = (; split = Dict(splits[1] => [(; train=train_mask, val=val_mask, test=test_mask)]))
+
+    elseif metadata["task level"] == "link"
        @assert length(graphs) == 1
-        g = graphs[1]
-        if split_idx.train !== nothing
-            g["edge_train_mask"] = indexes2mask(split_idx.train, g["num_edges"])
-        end
-        if split_idx.val !== nothing
-            g["edge_val_mask"] = indexes2mask(split_idx.val, g["num_edges"])
-        end
-        if split_idx.test !== nothing
-            g["edge_test_mask"] = indexes2mask(split_idx.test, g["num_edges"])
+
+        for key in keys(split_idx)
+            split_idx[key]["edge"] .+= 1
        end
-    end
-    if metadata["task level"] == "graph"
-        train_mask = split_idx.train !== nothing ? indexes2mask(split_idx.train, num_graphs) : nothing
-        val_mask = split_idx.val !== nothing ? indexes2mask(split_idx.val, num_graphs) : nothing
-        test_mask = split_idx.test !== nothing ? indexes2mask(split_idx.test, num_graphs) : nothing
-        graph_data = clean_nt((; labels=maybesqueeze(labels), train_mask, val_mask, test_mask))
+        split_dict["edge"] = (; split = Dict(splits[1] => [(; train=split_idx.train, val=split_idx.val, test=split_idx.test)]))
+
+    elseif metadata["task level"] == "graph"
+        num_graphs = length(graphs)
+        train_mask = split_idx.train === nothing ? nothing : indexes2mask(split_idx.train .+ 1, num_graphs)
+        val_mask = split_idx.val === nothing ? nothing : indexes2mask(split_idx.val .+ 1, num_graphs)
+        test_mask = split_idx.test === nothing ?
nothing : indexes2mask(split_idx.test .+ 1, num_graphs)
+
+        split_dict["graph"] = (; split = Dict(splits[1] => (; train=train_mask, val=val_mask, test=test_mask)))
    end

-    return graphs, graph_data
+    return split_dict
end

function read_ogb_hetero_graph(path, metadata)
@@ -496,6 +504,7 @@ function read_ogb_hetero_graph(path, metadata)
        push!(graphs, graph)
    end

+    graph_data = nothing
    return graphs, graph_data
end

@@ -561,7 +570,6 @@ function ogbdict2heterograph(d::Dict)
    return HeteroGraph(;num_nodes, edge_indices, edge_data, node_data)
end

-Base.length(data::OGBDataset) = length(data.graphs)
Base.getindex(data::OGBDataset{Nothing}, ::Colon) = length(data.graphs) == 1 ? data.graphs[1] : data.graphs
Base.getindex(data::OGBDataset, ::Colon) = (; data.graphs, data.graph_data.labels)
Base.getindex(data::OGBDataset{Nothing}, i) = getobs(data.graphs, i)
diff --git a/src/datasets/graphs/planetoid.jl b/src/datasets/graphs/planetoid.jl
index 845ef0b3..ac10e935 100644
--- a/src/datasets/graphs/planetoid.jl
+++ b/src/datasets/graphs/planetoid.jl
@@ -8,7 +8,7 @@ https://github.com/kimiyoung/planetoid/raw/master/data
"""
function read_planetoid_data(DEPNAME; dir=nothing, reverse_edges=true)
    name = lowercase(DEPNAME)
-    
+
    x = read_planetoid_file(DEPNAME, "ind.$(name).x", dir)
    y = read_planetoid_file(DEPNAME, "ind.$(name).y", dir)
    allx = read_planetoid_file(DEPNAME, "ind.$(name).allx", dir)
@@ -57,27 +57,22 @@ function read_planetoid_data(DEPNAME; dir=nothing, reverse_edges=true)
        end
    end

-    node_data = (features=x, targets=y,
-                 train_indices,
-                 val_indices,
-                 test_indices)
-
-
-    node_data = (features=x, targets=y,
-                 train_mask = indexes2mask(train_indices, num_nodes),
-                 val_mask = indexes2mask(val_indices, num_nodes),
-                 test_mask = indexes2mask(test_indices, num_nodes))
+    node_data = (features=x, targets=y)
+    split = ( train = indexes2mask(train_indices, num_nodes),
+              val = indexes2mask(val_indices, num_nodes),
+              test = indexes2mask(test_indices, num_nodes))

    metadata = Dict{String,Any}(
        "name" => name,
        "num_classes" => length(unique(y)),
-        "classes" => sort(unique(y))
+        "classes" => sort(unique(y)),
+        "node" => (; split=[split]),
        )

    edge_index = adjlist2edgeindex(adj_list)
-    
-    g = Graph(; num_nodes, 
-            edge_index, 
+
+    g = Graph(; num_nodes,
+            edge_index,
            node_data)

    return metadata, g
diff --git a/src/datasets/graphs/polblogs.jl b/src/datasets/graphs/polblogs.jl
index 5cc1adf6..78b8ffda 100644
--- a/src/datasets/graphs/polblogs.jl
+++ b/src/datasets/graphs/polblogs.jl
@@ -1,7 +1,6 @@
function __init__polblogs()
    LINK = "https://netset.telecom-paris.fr/datasets/polblogs.tar.gz"
    DEPNAME = "PolBlogs"
-
    register(DataDep(DEPNAME,
        """
@@ -16,17 +15,27 @@ end

"""
    PolBlogs(; dir=nothing)
- 
-The Political Blogs dataset from the [The Political Blogosphere and
-the 2004 US Election: Divided they Blog](https://dl.acm.org/doi/10.1145/1134271.1134277) paper.
+
+The Political Blogs dataset from Ref. [1].

`PolBlogs` is a graph with 1,490 vertices (representing political blogs) and 19,025 edges (links between blogs).

-The links are automatically extracted from a crawl of the front page of the blog. 
+The links are automatically extracted from a crawl of the front page of the blog.

Each vertex receives a label indicating the political leaning of the blog: liberal or conservative.
+
+# References
+
+[1] [The Political Blogosphere and the 2004 US Election: Divided they Blog](https://dl.acm.org/doi/10.1145/1134271.1134277)
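+
+A minimal usage sketch (field layout as in the constructor below):
+
+    data = PolBlogs()
+    g = data[1]                  # the single `Graph`
+    s, t = g.edge_index          # 19,025 source/target node indices
+    labels = g.node_data.labels  # political leaning per blog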
""" -struct PolBlogs <: AbstractDataset +struct PolBlogs <: GraphDataset metadata::Dict{String, Any} graphs::Vector{Graph} end @@ -44,12 +46,8 @@ function PolBlogs(; dir=nothing) metadata = Dict{String, Any}() g = Graph(; num_nodes = 1490, - edge_index = (s, t), + edge_index = (s, t), node_data = (; labels)) return PolBlogs(metadata, [g]) end - -Base.length(d::PolBlogs) = length(d.graphs) -Base.getindex(d::PolBlogs, ::Colon) = d.graphs[1] -Base.getindex(d::PolBlogs, i) = getindex(d.graphs, i) diff --git a/src/datasets/graphs/pubmed.jl b/src/datasets/graphs/pubmed.jl index c8d57470..22c7c4c8 100644 --- a/src/datasets/graphs/pubmed.jl +++ b/src/datasets/graphs/pubmed.jl @@ -21,7 +21,7 @@ end The PubMed citation network dataset from Ref. [1]. Nodes represent documents and edges represent citation links. -The dataset is designed for the node classification task. +The dataset is designed for the node classification task. The task is to predict the category of certain paper. The dataset is retrieved from Ref. [2]. @@ -31,7 +31,7 @@ The dataset is retrieved from Ref. [2]. [2]: [Planetoid](https://github.com/kimiyoung/planetoid) """ -struct PubMed <: AbstractDataset +struct PubMed <: GraphDataset metadata::Dict{String, Any} graphs::Vector{Graph} end @@ -41,10 +41,6 @@ function PubMed(; dir=nothing, reverse_edges=true) return PubMed(metadata, [g]) end -Base.length(d::PubMed) = length(d.graphs) -Base.getindex(d::PubMed, ::Colon) = d.graphs[1] -Base.getindex(d::PubMed, i) = getindex(d.graphs, i) - # DEPRECATED in v0.6.0 function Base.getproperty(::Type{PubMed}, s::Symbol) diff --git a/src/datasets/graphs/reddit.jl b/src/datasets/graphs/reddit.jl index c6b405cb..d9b7d1b8 100644 --- a/src/datasets/graphs/reddit.jl +++ b/src/datasets/graphs/reddit.jl @@ -19,7 +19,7 @@ end """ Reddit(; full=true, dir=nothing) -The Reddit dataset was introduced in Ref [1]. +The Reddit dataset was introduced in Ref. [1]. It is a graph dataset of Reddit posts made in the month of September, 2014. The dataset contains a single post-to-post graph, connecting posts if the same user comments on both. The node label in this case is one of the 41 communities, or “subreddit”s, that a post belongs to. @@ -31,16 +31,16 @@ Use `full=false` to load only a subsample of the dataset. # References + [1]: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216) [2]: [Benchmarks on the Reddit Dataset](https://paperswithcode.com/dataset/reddit) """ -struct Reddit <: AbstractDataset +struct Reddit <: GraphDataset metadata::Dict{String, Any} graphs::Vector{Graph} end - function Reddit(; full=true, dir=nothing) DATAFILES = [ "reddit-G.json", "reddit-G_full.json", "reddit-adjlist.txt", @@ -49,7 +49,6 @@ function Reddit(; full=true, dir=nothing) DATA = joinpath.("reddit", DATAFILES) DEPNAME = "Reddit" - if full graph_json = datafile(DEPNAME, DATA[2], dir) else @@ -67,12 +66,11 @@ function Reddit(; full=true, dir=nothing) # Metadata directed = graph["directed"] - multigraph = graph["multigraph"] links = graph["links"] nodes = graph["nodes"] num_edges = directed ? 
length(links) : length(links) * 2
    num_nodes = length(nodes)
-    @assert length(graph["graph"]) == 0 # should be zero
+    @assert length(graph["graph"]) == 0

    # edges
    s = get.(links, "source", nothing) .+ 1
@@ -101,15 +99,15 @@
    @assert sum(val_mask .& test_mask) == 0
    train_mask = nor.(test_mask, val_mask)

-    metadata = Dict{String, Any}("directed" => directed, "multigraph" => multigraph,
-                                 "num_edges" => num_edges, "num_nodes" => num_nodes)
-    g = Graph(; num_nodes,
-            edge_index=(s, t),
-            node_data= (; labels, features, train_mask, val_mask, test_mask)
+    split = ( train = train_mask,
+              val = val_mask,
+              test = test_mask)
+
+    metadata = Dict{String, Any}("directed" => directed,
+                                 "num_edges" => num_edges, "num_nodes" => num_nodes, "node" => (; split=[split]))
+    g = Graph(; num_nodes,
+            edge_index=(s, t),
+            node_data= (; labels, features)
            )
    return Reddit(metadata, [g])
end
-
-Base.length(d::Reddit) = length(d.graphs)
-Base.getindex(d::Reddit, ::Colon) = d.graphs
-Base.getindex(d::Reddit, i) = getindex(d.graphs, i)
diff --git a/src/datasets/graphs/tudataset.jl b/src/datasets/graphs/tudataset.jl
index fcc9677f..012d47bb 100644
--- a/src/datasets/graphs/tudataset.jl
+++ b/src/datasets/graphs/tudataset.jl
@@ -17,12 +17,12 @@ end
A variety of graph benchmark datasets, *e.g.* "QM9", "IMDB-BINARY",
"REDDIT-BINARY" or "PROTEINS", collected from the [TU Dortmund University](https://chrsmrrs.github.io/datasets/).
Retrieve from the TUDataset collection the dataset `name`, where `name`
-is any of the datasets available [here](https://chrsmrrs.github.io/datasets/docs/datasets/). 
+is any of the datasets available [here](https://chrsmrrs.github.io/datasets/docs/datasets/).

A `TUDataset` object can be indexed to retrieve a specific graph or a subset of graphs.

-See [here](https://chrsmrrs.github.io/datasets/docs/format/) for an in-depth 
-description of the format. 
+See [here](https://chrsmrrs.github.io/datasets/docs/format/) for an in-depth
+description of the format.

# Usage Example

@@ -40,8 +40,13 @@ dataset TUDataset:

julia> data[1]
(graphs = Graph(42, 162), targets = 1)
```
+
+# References
+
+[1] [TUDataset: A collection of benchmark datasets for learning with graphs](https://arxiv.org/pdf/2007.08663.pdf)
+
"""
-struct TUDataset <: AbstractDataset
+struct TUDataset <: GraphDataset
    name::String
    metadata::Dict{String, Any}
    graphs::Vector{Graph}
@@ -55,14 +60,14 @@ function TUDataset(name; dir=nothing)
    create_default_dir("TUDataset")
    d = tudataset_datadir(name, dir)
    # See here for the file format: https://chrsmrrs.github.io/datasets/docs/format/
-    
+
    st = readdlm(joinpath(d, "$(name)_A.txt"), ',', Int)
    # Check that the first node is labeled 1.
    # TODO this will fail if the first node is isolated
    @assert minimum(st) == 1
    source, target = st[:,1], st[:,2]

-    graph_indicator = readdlm(joinpath(d, "$(name)_graph_indicator.txt"), Int) |> vec 
+    graph_indicator = readdlm(joinpath(d, "$(name)_graph_indicator.txt"), Int) |> vec
    @assert all(sort(unique(graph_indicator)) .== 1:length(unique(graph_indicator)))
    num_nodes = length(graph_indicator)
@@ -70,7 +75,7 @@ function TUDataset(name; dir=nothing)
    num_graphs = length(unique(graph_indicator))

    # LOAD OPTIONAL FILES IF EXIST
-    
+
    node_labels = isfile(joinpath(d, "$(name)_node_labels.txt")) ?
readdlm(joinpath(d, "$(name)_node_labels.txt"), Int) |> vec : nothing
@@ -104,14 +109,13 @@ function TUDataset(name; dir=nothing)
        end
    end

-
    full_dataset = (; num_nodes, num_edges, num_graphs,
-                    source, target, 
+                    source, target,
                    graph_indicator,
                    node_labels,
-                    edge_labels, 
+                    edge_labels,
                    graph_labels,
-                    node_attributes, 
+                    node_attributes,
                    edge_attributes,
                    graph_attributes)
@@ -136,14 +140,13 @@ function tudataset_datadir(name, dir = nothing)
    return d
end

-
function tudataset_getgraph(data::NamedTuple, i::Int)
    vmin = searchsortedfirst(data.graph_indicator, i)
    vmax = searchsortedlast(data.graph_indicator, i)
    nodes = vmin:vmax
    node_labels = isnothing(data.node_labels) ? nothing : getobs(data.node_labels, nodes)
    node_attributes = isnothing(data.node_attributes) ? nothing : getobs(data.node_attributes, nodes)
-    
+
    emin = searchsortedfirst(data.source, vmin)
    emax = searchsortedlast(data.source, vmax)
    edges = emin:emax
@@ -156,18 +159,16 @@ function tudataset_getgraph(data::NamedTuple, i::Int)
    num_edges = length(source)
    node_data = (features = node_attributes, targets = node_labels)
    edge_data = (features = edge_attributes, targets = edge_labels)
-    
+
    return Graph(; num_nodes,
-                edge_index = (source, target), 
+                edge_index = (source, target),
                node_data = node_data |> clean_nt,
                edge_data = edge_data |> clean_nt,
                )
end

-Base.length(data::TUDataset) = length(data.graphs)
-
-function Base.getindex(data::TUDataset, ::Colon) 
+function Base.getindex(data::TUDataset, ::Colon)
    if data.graph_data === nothing
        return data.graphs
    else
@@ -175,10 +176,10 @@ end
    end
end

-function Base.getindex(data::TUDataset, i) 
+function Base.getindex(data::TUDataset, i)
    if data.graph_data === nothing
        return getobs(data.graphs, i)
    else
-        return getobs((; data.graphs, data.graph_data...), i) 
+        return getobs((; data.graphs, data.graph_data...), i)
    end
end
diff --git a/src/graph.jl b/src/graph.jl
index 9f13d67a..f04c0633 100644
--- a/src/graph.jl
+++ b/src/graph.jl
@@ -121,13 +121,13 @@ end

HeteroGraph is used for heterogeneous graphs.

-`HeteroGraph` unlike `Graph` can have different types of nodes. Each node pertains to different types of information. 
+Unlike `Graph`, a `HeteroGraph` can have different types of nodes. Each node pertains to different types of information.

-Edges in `HeteroGraph` is defined by relations. A relation is a tuple of 
+Edges in `HeteroGraph` are defined by relations. A relation is a tuple of
(`src_node_type`, `edge_type`, `target_node_type`) where `edge_type` represents the relation
-between the src and target nodes. Edges between same node types are possible. 
+between the src and target nodes. Edges between the same node types are possible.

-A `HeteroGraph` can be directed or undirected. It doesn't distinguish between directed 
+A `HeteroGraph` can be directed or undirected. It doesn't distinguish between directed
or undirected graphs. Therefore, for undirected graphs, it will store edges in both directions.
Nodes are indexed in `1:num_nodes`.
@@ -137,10 +137,10 @@ Nodes are indexed in `1:num_nodes`.
- `num_edges`: Dictionary containing the number of edges for each relation.
- `edge_indices`: Dictionary containing the `edge_index` for each edge relation. An `edge_index` is a tuple containing two vectors with length equal to the number of edges for the relation. The first vector contains the list of the source nodes of each edge, the second contains the target nodes.
-- `node_data`: node-related data. Can be `nothing`, Dictionary of a dictionary of arrays.
Data of a speific type of node can be accessed
+- `node_data`: node-related data. Can be `nothing` or a dictionary of dictionaries of arrays. Data of a specific type of node can be accessed
using node_data[node_type]. The array's last dimension size should be equal to the number of nodes. Default `nothing`.
-- `edge_data`: Can be `nothing`, Dictionary of a dictionary of arrays. Data of a speific type of edge can be accessed
+- `edge_data`: Can be `nothing` or a dictionary of dictionaries of arrays. Data of a specific type of edge can be accessed
using edge_data[edge_type]. The array's last dimension size should be equal to the number of edges. Default `nothing`.
"""
@@ -227,3 +227,5 @@ function edgeindex2adjlist(s, t, num_nodes; inneigs=false)
    end
    return adj
end
+
+
diff --git a/test/datasets/graphs.jl b/test/datasets/graphs.jl
index 85b3a6a3..fd21c80a 100644
--- a/test/datasets/graphs.jl
+++ b/test/datasets/graphs.jl
@@ -11,9 +11,9 @@
    @test g.num_edges == 9104
    @test size(g.node_data.features) == (3703, g.num_nodes)
    @test size(g.node_data.targets) == (g.num_nodes,)
-    @test sum(g.node_data.train_mask) == 120
-    @test sum(g.node_data.val_mask) == 500
-    @test sum(g.node_data.test_mask) == 1015
+    @test sum(data.metadata["node"].split[1].train) == 120
+    @test sum(data.metadata["node"].split[1].val) == 500
+    @test sum(data.metadata["node"].split[1].test) == 1015
    @test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}
    s, t = g.edge_index
    for a in (s, t)
@@ -36,9 +36,9 @@ end
    @test g.num_edges == 10556
    @test size(g.node_data.features) == (1433, g.num_nodes)
    @test size(g.node_data.targets) == (g.num_nodes,)
-    @test sum(g.node_data.train_mask) == 140
-    @test sum(g.node_data.val_mask) == 500
-    @test sum(g.node_data.test_mask) == 1000
+    @test sum(data.metadata["node"].split[1].train) == 140
+    @test sum(data.metadata["node"].split[1].val) == 500
+    @test sum(data.metadata["node"].split[1].test) == 1000
    @test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}
    s, t = g.edge_index
    for a in (s, t)
@@ -109,9 +109,9 @@ end
    @test g.num_edges == 88648
    @test size(g.node_data.features) == (500, g.num_nodes)
    @test size(g.node_data.targets) == (g.num_nodes,)
-    @test sum(g.node_data.train_mask) == 60
-    @test sum(g.node_data.val_mask) == 500
-    @test sum(g.node_data.test_mask) == 1000
+    @test sum(data.metadata["node"].split[1].train) == 60
+    @test sum(data.metadata["node"].split[1].val) == 500
+    @test sum(data.metadata["node"].split[1].test) == 1000
    @test g.edge_index isa Tuple{Vector{Int}, Vector{Int}}
    s, t = g.edge_index
    for a in (s, t)
diff --git a/test/datasets/graphs_no_ci.jl b/test/datasets/graphs_no_ci.jl
index 675e02bc..1e5e19ad 100644
--- a/test/datasets/graphs_no_ci.jl
+++ b/test/datasets/graphs_no_ci.jl
@@ -247,13 +247,21 @@ end

    @test g.num_nodes == 169343
    @test g.num_edges == 1166243

-    @test sum(count.([g.node_data.train_mask, g.node_data.test_mask, g.node_data.val_mask])) == g.num_nodes
+    train_mask = d.metadata["node"].split["time"][1].train
+    test_mask = d.metadata["node"].split["time"][1].test
+    val_mask = d.metadata["node"].split["time"][1].val
+
+    @test sum(count.([train_mask, test_mask, val_mask])) == g.num_nodes
end

@testset "OGBDataset - ogbg-molhiv" begin
    d = OGBDataset("ogbg-molhiv")

-    @test sum(count.([d.graph_data.train_mask, d.graph_data.test_mask, d.graph_data.val_mask])) == length(d)
+    train_mask = d.metadata["graph"].split["scaffold"].train
+    test_mask = d.metadata["graph"].split["scaffold"].test
+    val_mask = d.metadata["graph"].split["scaffold"].val
+
+    @test sum(count.([train_mask, test_mask,
val_mask])) == length(d) end @testset "Reddit_full" begin @@ -264,9 +272,10 @@ end @test g.num_edges == 114615892 @test size(g.node_data.features) == (602, g.num_nodes) @test size(g.node_data.labels) == (g.num_nodes,) - @test count(g.node_data.train_mask) == 153431 - @test count(g.node_data.val_mask) == 23831 - @test count(g.node_data.test_mask) == 55703 + split = data.metadata["node"].split[1] + @test count(split.train) == 153431 + @test count(split.val) == 23831 + @test count(split.test) == 55703 s, t = g.edge_index @test length(s) == length(t) == g.num_edges @test minimum(s) == minimum(t) == 1 @@ -281,9 +290,10 @@ end @test g.num_edges == 23213838 @test size(g.node_data.features) == (602, g.num_nodes) @test size(g.node_data.labels) == (g.num_nodes,) - @test count(g.node_data.train_mask) == 152410 - @test count(g.node_data.val_mask) == 23699 - @test count(g.node_data.test_mask) == 55334 + split = data.metadata["node"].split[1] + @test count(split.train) == 152410 + @test count(split.val) == 23699 + @test count(split.test) == 55334 s, t = g.edge_index @test length(s) == length(t) == g.num_edges @test minimum(s) == minimum(t) == 1 diff --git a/test/runtests.jl b/test/runtests.jl index 7ed8e5f4..dcb64db9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using ImageShow using ColorTypes using FixedPointNumbers using JSON3 +using Pickle ENV["DATADEPS_ALWAYS_ACCEPT"] = true @@ -22,7 +23,6 @@ dataset_tests = [ no_ci_dataset_tests = [ "datasets/graphs_no_ci.jl", - "datasets/text_no_ci.jl", "datasets/vision/cifar10.jl", "datasets/vision/cifar100.jl", "datasets/vision/emnist.jl",