Train test and validation for graph datasets. #168

Closed
wants to merge 8 commits into from
4 changes: 2 additions & 2 deletions src/MLDatasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using MLUtils: getobs, numobs, AbstractDataContainer
using Glob
using DelimitedFiles: readdlm
using FileIO
import CSV
using LazyModules: @lazy

include("require.jl") # export @require
Expand All @@ -23,9 +24,8 @@ include("require.jl") # export @require
@require import DataFrames="a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@require import ImageShow="4e3cecfd-b093-5904-9786-8bbb286a6a31"
# @lazy import NPZ # lazy imported by FileIO
@lazy import Pickle="fbb45041-c46e-462f-888f-7c521cafbc2c"
@require import Pickle="fbb45041-c46e-462f-888f-7c521cafbc2c"
@lazy import MAT="23992714-dd62-5051-b70f-ba57cb901cac"
import CSV
@lazy import HDF5="f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
# @lazy import JLD2

Expand Down
37 changes: 23 additions & 14 deletions src/abstract_datasets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Super-type from which all datasets in MLDatasets.jl inherit.

Implements the following functionality:
- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]`
- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]`
- Pretty printing.
"""
abstract type AbstractDataset <: AbstractDataContainer end
Expand All @@ -19,9 +19,9 @@ end

function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset
recur_io = IOContext(io, :compact => false)
Member

It would be good to avoid polluting PRs with these formatting changes. There should be an option in your editor to avoid doing this.


print(io, "dataset $(D.name.name):") # if the type is parameterized don't print the parameters

for f in fieldnames(D)
if !startswith(string(f), "_")
fstring = leftalign(string(f), 10)
Expand All @@ -34,7 +34,7 @@ function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset
end

function leftalign(s::AbstractString, n::Int)
m = length(s)
m = length(s)
if m > n
return s[1:n]
else
Expand All @@ -53,37 +53,35 @@ _summary(x::BitVector) = "$(count(x))-trues BitVector"
"""
SupervisedDataset <: AbstractDataset

An abstract dataset type for supervised learning tasks.
An abstract dataset type for supervised learning tasks.
Concrete dataset types inheriting from it must provide
a `features` and a `targets` fields.
"""
abstract type SupervisedDataset <: AbstractDataset end


Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) :
Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) :
numobs((d.features, d.targets))


# We return named tuples
Base.getindex(d::SupervisedDataset, ::Colon) = Tables.istable(d.features) ?
(features = d.features, targets=d.targets) :
getobs((; d.features, d.targets))

Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ?
Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ?
(features = getobs_table(d.features, i), targets=getobs_table(d.targets, i)) :
getobs((; d.features, d.targets), i)

"""
UnsupervisedDataset <: AbstractDataset

An abstract dataset type for unsupervised or self-supervised learning tasks.
An abstract dataset type for unsupervised or self-supervised learning tasks.
Concrete dataset types inheriting from it must provide a `features` field.
"""
abstract type UnsupervisedDataset <: AbstractDataset end


Base.length(d::UnsupervisedDataset) = numobs(d.features)

Base.getindex(d::UnsupervisedDataset, ::Colon) = getobs(d.features)
Base.getindex(d::UnsupervisedDataset, i) = getobs(d.features, i)

Expand All @@ -99,13 +97,13 @@ const ARGUMENTS_SUPERVISED_TABLE = """

const FIELDS_SUPERVISED_TABLE = """
- `metadata`: A dictionary containing additional information on the dataset.
- `features`: The data features. A dataframe if `as_df=true`, otherwise an array.
- `features`: The data features. A dataframe if `as_df=true`, otherwise an array.
- `targets`: The targets for supervised learning. A dataframe if `as_df=true`, otherwise an array.
- `dataframe`: A dataframe containing both `features` and `targets`. It is `nothing` if `as_df=false`.
"""

const METHODS_SUPERVISED_TABLE = """
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
- `dataset[:]`: Return all observations as a named tuple of features and targets.
- `length(dataset)`: Number of observations.
"""
Expand All @@ -119,12 +117,23 @@ const ARGUMENTS_SUPERVISED_ARRAY = """

const FIELDS_SUPERVISED_ARRAY = """
- `metadata`: A dictionary containing additional information on the dataset.
- `features`: An array storing the data features.
- `features`: An array storing the data features.
- `targets`: An array storing the targets for supervised learning.
"""

const METHODS_SUPERVISED_ARRAY = """
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets.
- `dataset[:]`: Return all observations as a named tuple of features and targets.
- `length(dataset)`: Number of observations.
"""

"""
GraphDataset <: AbstractDataset

An abstract dataset type for graph learning tasks.
"""
abstract type GraphDataset <: AbstractDataset end
Member

Better to rename this to `AbstractGraphDataset`:

Suggested change
abstract type GraphDataset <: AbstractDataset end
abstract type AbstractGraphDataset <: AbstractDataset end

especially if we go with #169.

Member

Also, it is better to file this change as a separate PR. Generally, a PR should target a single issue or a single feature addition.


Base.length(data::GraphDataset) = length(data.graphs)
Base.getindex(data::GraphDataset, ::Colon) = length(data) == 1 ? data.graphs[1] : data.graphs
Base.getindex(data::GraphDataset, i) = data.graphs[i]
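
The three fallback methods above can be tried out in isolation. The following is a minimal sketch, not the MLDatasets.jl code: `ToyGraph` and `ToyGraphDataset` are hypothetical stand-ins, and `AbstractDataset` is declared here as a bare abstract type rather than as a subtype of `AbstractDataContainer`.

```julia
# Hypothetical stand-ins; these are not the MLDatasets.jl definitions.
abstract type AbstractDataset end
abstract type GraphDataset <: AbstractDataset end

# Shared fallbacks, mirroring the three methods added in the diff.
Base.length(data::GraphDataset) = length(data.graphs)
Base.getindex(data::GraphDataset, ::Colon) = length(data) == 1 ? data.graphs[1] : data.graphs
Base.getindex(data::GraphDataset, i) = data.graphs[i]

struct ToyGraph
    num_nodes::Int
end

# A concrete dataset only needs a `graphs` field to pick up the fallbacks.
struct ToyGraphDataset <: GraphDataset
    metadata::Dict{String, Any}
    graphs::Vector{ToyGraph}
end

single = ToyGraphDataset(Dict{String, Any}(), [ToyGraph(34)])
multi = ToyGraphDataset(Dict{String, Any}(), [ToyGraph(3), ToyGraph(5)])

single[:]  # a single ToyGraph, because the dataset holds exactly one graph
multi[:]   # the whole Vector{ToyGraph}
multi[2]   # the second graph
```

Note that with this definition the return type of `dataset[:]` depends on how many graphs the dataset holds, so callers that must handle both cases may prefer to index `dataset.graphs` directly.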
14 changes: 4 additions & 10 deletions src/datasets/graphs/citeseer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ end

The CiteSeer citation network dataset from Ref. [1].
Nodes represent documents and edges represent citation links.
The dataset is designed for the node classification task.
The dataset is designed for the node classification task.
The task is to predict the category of a given paper.
The dataset is retrieved from Ref. [2].

# References

[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)
[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)

[2]: [Planetoid](https://github.com/kimiyoung/planetoid)
"""
struct CiteSeer <: AbstractDataset
struct CiteSeer <: GraphDataset
metadata::Dict{String, Any}
graphs::Vector{Graph}
end
Expand All @@ -41,10 +41,6 @@ function CiteSeer(; dir=nothing, reverse_edges=true)
return CiteSeer(metadata, [g])
end

Base.length(d::CiteSeer) = length(d.graphs)
Base.getindex(d::CiteSeer, ::Colon) = d.graphs[1]
Base.getindex(d::CiteSeer, i) = d.graphs[i]


# DEPRECATED in v0.6.0
function Base.getproperty(::Type{CiteSeer}, s::Symbol)
Expand All @@ -67,5 +63,3 @@ function Base.getproperty(::Type{CiteSeer}, s::Symbol)
return getfield(CiteSeer, s)
end
end
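
The deprecation hook above overloads `getproperty` on the type itself and falls back to `getfield` for everything else. The branches hidden in the collapsed part of the diff are not shown, so the sketch below only illustrates the general pattern: `ToyDataset` and the deprecated `dataset` accessor are hypothetical names chosen for this example.

```julia
struct ToyDataset
    graphs::Vector{Int}
end
ToyDataset() = ToyDataset([0])

function Base.getproperty(::Type{ToyDataset}, s::Symbol)
    if s === :dataset
        # Assumed deprecated accessor; the real branches are elided in the diff.
        Base.depwarn("`ToyDataset.dataset()` is deprecated, use `ToyDataset()` instead.", :getproperty)
        return () -> ToyDataset()
    else
        # Everything else falls through to the type's own fields (e.g. `name`).
        return getfield(ToyDataset, s)
    end
end

old_style = ToyDataset.dataset()  # still works; warns when Julia runs with --depwarn=yes
new_style = ToyDataset()
```

The `getfield` fallback matters because properties of the type object, such as `ToyDataset.name`, would otherwise stop working once `getproperty` is overloaded.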


18 changes: 6 additions & 12 deletions src/datasets/graphs/cora.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,17 @@ function __init__cora()
))
end


"""
Cora()

The Cora citation network dataset from Ref. [1].
Nodes represent documents and edges represent citation links.
Each node has a predefined feature with 1433 dimensions.
The dataset is designed for the node classification task.
Each node has a predefined feature with 1433 dimensions.
The dataset is designed for the node classification task.
The task is to predict the category of a given paper.
The dataset is retrieved from Ref. [2].

# Statistics
# Statistics

- Nodes: 2708
- Edges: 10556
Expand All @@ -39,17 +38,16 @@ The dataset is retrieved from Ref. [2].
- Val: 500
- Test: 1000

The split is the one used in the original paper [1] and
The split is the one used in the original paper [1] and
doesn't consider all nodes.


# References

[1]: [Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via Ranking](https://arxiv.org/abs/1707.03815)

[2]: [Planetoid](https://github.com/kimiyoung/planetoid
[2]: [Planetoid](https://github.com/kimiyoung/planetoid)
"""
struct Cora <: AbstractDataset
struct Cora <: GraphDataset
metadata::Dict{String, Any}
graphs::Vector{Graph}
end
Expand All @@ -59,10 +57,6 @@ function Cora(; dir=nothing, reverse_edges=true)
return Cora(metadata, [g])
end

Base.length(d::Cora) = length(d.graphs)
Base.getindex(d::Cora, ::Colon) = d.graphs[1]
Base.getindex(d::Cora, i) = getindex(d.graphs, i)


# DEPRECATED in v0.6.0
function Base.getproperty(::Type{Cora}, s::Symbol)
Expand Down
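
The `Cora` docstring above describes a train/val/test split that does not cover all nodes. A minimal sketch of such a split as boolean node masks follows. The sizes are toy values chosen for illustration, not the Cora split sizes, and `index_to_mask` is a hypothetical helper rather than an MLDatasets.jl function.

```julia
# Toy sizes for illustration only; these are not the Cora split sizes.
num_nodes = 10
train_idx = [1, 2]
val_idx = [4, 5, 6]
test_idx = [8, 9]

# Hypothetical helper: turn a list of node indices into a boolean mask.
function index_to_mask(idx, n)
    mask = falses(n)
    mask[idx] .= true
    return mask
end

train_mask = index_to_mask(train_idx, num_nodes)
val_mask = index_to_mask(val_idx, num_nodes)
test_mask = index_to_mask(test_idx, num_nodes)

# The three masks are disjoint, but some nodes belong to none of them.
unused = .!(train_mask .| val_mask .| test_mask)
```

Storing the split as masks of length `num_nodes` keeps it aligned with per-node features and labels, which is one plausible way to attach it to a graph's node data.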
18 changes: 7 additions & 11 deletions src/datasets/graphs/karateclub.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ export KarateClub
"""
KarateClub()

The Zachary's karate club dataset originally appeared in Ref [1].
The Zachary's karate club dataset originally appeared in Ref. [1].

The network contains 34 nodes (members of the karate club).
The nodes are connected by 78 undirected and unweighted edges.
The edges indicate if the two members interacted outside the club.

The node labels indicate which community or the karate club the member belongs to.
The club based labels are as per the original dataset in Ref [1].
The community labels are obtained by modularity-based clustering following Ref [2].
The data is retrieved from Ref [3] and [4].
The club based labels are as per the original dataset in Ref. [1].
The community labels are obtained by modularity-based clustering following Ref. [2].
The data is retrieved from Ref. [3] and [4].
One node per unique label is used as training data.

# References
Expand All @@ -25,7 +25,7 @@ One node per unique label is used as training data.

[4]: [NetworkX Zachary's Karate Club Dataset](https://networkx.org/documentation/stable/_modules/networkx/generators/social.html#karate_club_graph)
"""
struct KarateClub <: AbstractDataset
struct KarateClub <: GraphDataset
metadata::Dict{String, Any}
graphs::Vector{Graph}
end
Expand Down Expand Up @@ -60,14 +60,10 @@ function KarateClub()
labels_comm = [
1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1,
0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0]
node_data = (; labels_clubs, labels_comm)

node_data = (; labels_clubs, labels_comm)
g = Graph(; num_nodes=34, edge_index=(src, target), node_data)

metadata = Dict{String, Any}()
return KarateClub(metadata, [g])
end

Base.length(d::KarateClub) = length(d.graphs)
Base.getindex(d::KarateClub, ::Colon) = d.graphs[1]
Base.getindex(d::KarateClub, i) = d.graphs[i]
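
The docstring states that one node per unique label is used as training data, but it does not say which node is chosen. The sketch below assumes the first node carrying each label is selected; that choice, and the `train_idx` and `train_mask` names, are assumptions made for illustration.

```julia
# Community labels copied from the `KarateClub` constructor in the diff.
labels_comm = [
    1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1,
    0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0]

# Assumption: pick the first node that carries each distinct label.
train_idx = [findfirst(==(l), labels_comm) for l in unique(labels_comm)]

# Boolean mask over all 34 nodes, true only for the selected training nodes.
train_mask = falses(length(labels_comm))
train_mask[train_idx] .= true
```

With four distinct community labels this yields four training nodes, one per label.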
44 changes: 20 additions & 24 deletions src/datasets/graphs/movielens.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,21 @@ end
"""
MovieLens(name; dir=nothing)

Datasets from the [MovieLens website](https://movielens.org) collected and maintained by [GroupLens](https://grouplens.org/datasets/movielens/).
The MovieLens datasets are presented in a Graph format.
Datasets from the [MovieLens website](https://movielens.org) collected and maintained by [GroupLens](https://grouplens.org/datasets/movielens/).
The MovieLens datasets are presented in a Graph format.
For license and usage restrictions please refer to the Readme.md of the datasets.

There are 6 versions of movielens datasets currently supported: "100k", "1m", "10m", "20m", "25m", "latest-small".
There are 6 versions of movielens datasets currently supported: "100k", "1m", "10m", "20m", "25m", "latest-small".
The 100k and 1m datasets contain movie data and rating data along with demographic data.
Starting from the 10m dataset, Movielens datasets no longer contain the demographic data.
These datasets contain movie data, rating data, and tag information.
Starting from the 10m dataset, Movielens datasets no longer contain the demographic data.
These datasets contain movie data, rating data, and tag information.

The 20m and 25m datasets additionally contain [genome tag scores](http://files.grouplens.org/papers/tag_genome.pdf).
The 20m and 25m datasets additionally contain [genome tag scores](http://files.grouplens.org/papers/tag_genome.pdf).
Each movie in these datasets contains tag relevance scores for every tag.

Each dataset contains a heterogeneous graph, with two kinds of nodes,
`movie` and `user`. The rating is represented by an edge between them: `(user, rating, movie)`.
20m, 25m, and latest-small datasets also contain `tag` nodes and edges of type `(user, tag, movie)` and
Each dataset contains a heterogeneous graph, with two kinds of nodes,
`movie` and `user`. The rating is represented by an edge between them: `(user, rating, movie)`.
20m, 25m, and latest-small datasets also contain `tag` nodes and edges of type `(user, tag, movie)` and
optionally `(movie, score, tag)`.

# Examples
Expand Down Expand Up @@ -56,15 +56,15 @@ julia> g = data[:]
node_data => Dict{String, Dict} with 2 entries
edge_data => Dict{Tuple{String, String, String}, Dict} with 1 entry

# Access the user information
### Access the user information
julia> user_data = g.node_data["user"]
Dict{Symbol, AbstractVector} with 4 entries:
:age => [24, 53, 23, 24, 33, 42, 57, 36, 29, 53 … 61, 42, 24, 48, 38, 26, 32, 20, 48, 22]
:occupation => ["technician", "other", "writer", "technician", "other", "executive", "administrator", "administrator", "student",…
:zipcode => ["85711", "94043", "32067", "43537", "15213", "98101", "91344", "05201", "01002", "90703" … "22902", "66221", "3…
:gender => Bool[1, 0, 1, 1, 0, 1, 1, 1, 1, 1 … 1, 1, 1, 1, 0, 0, 1, 1, 0, 1]

# Access rating information
### Access rating information
julia> g.edge_data[("user", "rating", "movie")]
Dict{Symbol, Vector} with 2 entries:
:timestamp => [881250949, 891717742, 878887116, 880606923, 886397596, 884182806, 881171488, 891628467, 886324817, 883603013 … 8…
Expand All @@ -79,7 +79,7 @@ MovieLens 20m:
metadata => Dict{String, Any} with 4 entries
graphs => 1-element Vector{MLDatasets.HeteroGraph}

# There is only 1 graph in MovieLens dataset
### There is only 1 graph in MovieLens dataset
julia> g = data[1]
Heterogeneous Graph:
node_types => 3-element Vector{String}
Expand All @@ -90,20 +90,20 @@ Heterogeneous Graph:
node_data => Dict{String, Dict} with 0 entries
edge_data => Dict{Tuple{String, String, String}, Dict} with 3 entries

# Apart from user rating a movie, a user assigns tag to movies and there are genome-scores for movie-tag pairs
### Apart from user rating a movie, a user assigns tag to movies and there are genome-scores for movie-tag pairs
julia> g.edge_indices
Dict{Tuple{String, String, String}, Tuple{Vector{Int64}, Vector{Int64}}} with 3 entries:
("movie", "score", "tag") => ([1, 1, 1, 1, 1, 1, 1, 1, 1, 1 … 131170, 131170, 131170, 131170, 131170, 131170, 131170, 131170,…
("user", "tag", "movie") => ([18, 65, 65, 65, 65, 65, 65, 65, 65, 65 … 3489, 7045, 7045, 7164, 7164, 55999, 55999, 55999, 55…
("user", "rating", "movie") => ([1, 1, 1, 1, 1, 1, 1, 1, 1, 1 … 60816, 61160, 65682, 66762, 68319, 68954, 69526, 69644, 70286, …

# Access the rating
### Access the rating
julia> g.edge_data[("user", "rating", "movie")]
Dict{Symbol, Vector} with 2 entries:
:timestamp => [1112486027, 1112484676, 1112484819, 1112484727, 1112484580, 1094785740, 1094785734, 1112485573, 1112484940, 111248…
:rating => Float16[3.5, 3.5, 3.5, 3.5, 3.5, 3.5, 4.0, 4.0, 4.0, 4.0 … 4.5, 4.0, 4.5, 4.5, 4.5, 4.5, 4.5, 3.0, 5.0, 2.5]

# Access the movie-tag scores
### Access the movie-tag scores
score = g.edge_data[("movie", "score", "tag")][:score]
23419536-element Vector{Float64}:
0.025000000000000022
Expand All @@ -112,17 +112,17 @@ score = g.edge_data[("movie", "score", "tag")][:score]
```

## References
# References

[1] [GroupLens Website](https://grouplens.org/datasets/movielens/)

[2] [TensorFlow MovieLens Implementation](https://www.tensorflow.org/datasets/catalog/movielens)
[2] [TensorFlow MovieLens Implementation](https://www.tensorflow.org/datasets/catalog/movielens)

[3] Jesse Vig, Shilad Sen, and John Riedl. 2012. The Tag Genome: Encoding Community Knowledge to Support Novel Interaction. ACM Trans. Interact. Intell. Syst. 2, 3, Article 13 (September 2012), 44 pages. https://doi.org/10.1145/2362394.2362395.
[3] Jesse Vig, Shilad Sen, and John Riedl. 2012. The Tag Genome: Encoding Community Knowledge to Support Novel Interaction. ACM Trans. Interact. Intell. Syst. 2, 3, Article 13 (September 2012), 44 pages. https://doi.org/10.1145/2362394.2362395.

[4] F. Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets: History and Context. ACM Trans. Interact. Intell. Syst. 5, 4, Article 19 (January 2016), 19 pages. https://doi.org/10.1145/2827872
[4] F. Maxwell Harper and Joseph A. Konstan. 2015. The MovieLens Datasets: History and Context. ACM Trans. Interact. Intell. Syst. 5, 4, Article 19 (January 2016), 19 pages. https://doi.org/10.1145/2827872
"""
struct MovieLens
struct MovieLens <: GraphDataset
name::String
metadata::Dict{String, Any}
graphs::Vector{HeteroGraph}
Expand Down Expand Up @@ -526,7 +526,3 @@ function Base.show(io::IO, ::MIME"text/plain", d::MovieLens)
end
end
end

Base.length(data::MovieLens) = length(data.graphs)
Base.getindex(data::MovieLens, ::Colon) = length(data.graphs) == 1 ? data.graphs[1] : data.graphs
Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)
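
Once `MovieLens` subtypes `GraphDataset`, both the generic `Base.show` for `AbstractDataset` and the `MovieLens`-specific `Base.show` apply to it, and Julia dispatches to the more specific method. The sketch below illustrates this with simplified, hypothetical types (`ToyMovieLens`, `ToyCora`) and simplified printing. It is not the package's implementation.

```julia
abstract type AbstractDataset end
abstract type GraphDataset <: AbstractDataset end

# Generic pretty printer shared by all datasets (simplified).
Base.show(io::IO, ::MIME"text/plain", d::AbstractDataset) =
    print(io, "dataset ", nameof(typeof(d)), ":")

struct ToyMovieLens <: GraphDataset
    name::String
    graphs::Vector{Int}
end

# More specific method, analogous to the `MovieLens` one kept in the diff.
Base.show(io::IO, ::MIME"text/plain", d::ToyMovieLens) =
    print(io, "ToyMovieLens ", d.name, ":")

struct ToyCora <: GraphDataset
    graphs::Vector{Int}
end

# Render a value the way the REPL would, but into a String.
shown(d) = sprint(show, MIME("text/plain"), d)

movielens_text = shown(ToyMovieLens("100k", [1]))  # uses the specific method
cora_text = shown(ToyCora([1]))                     # uses the generic fallback
```

The same reasoning applies to `length` and `getindex`: removing the per-type methods is safe only because the `GraphDataset` fallbacks now cover them.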