-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Train test and validation for graph datasets. #168
Changes from all commits
3b3b2fa
362c986
e12c43f
ca318db
9361fa2
7387071
f0fba57
b63ddfb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -4,7 +4,7 @@ | |||||
Super-type from which all datasets in MLDatasets.jl inherit. | ||||||
|
||||||
Implements the following functionality: | ||||||
- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]` | ||||||
- `getobs(d)` and `getobs(d, i)` falling back to `d[:]` and `d[i]` | ||||||
- Pretty printing. | ||||||
""" | ||||||
abstract type AbstractDataset <: AbstractDataContainer end | ||||||
|
@@ -19,9 +19,9 @@ end | |||||
|
||||||
function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset | ||||||
recur_io = IOContext(io, :compact => false) | ||||||
|
||||||
print(io, "dataset $(D.name.name):") # if the type is parameterized don't print the parameters | ||||||
|
||||||
for f in fieldnames(D) | ||||||
if !startswith(string(f), "_") | ||||||
fstring = leftalign(string(f), 10) | ||||||
|
@@ -34,7 +34,7 @@ function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset | |||||
end | ||||||
|
||||||
function leftalign(s::AbstractString, n::Int) | ||||||
m = length(s) | ||||||
m = length(s) | ||||||
if m > n | ||||||
return s[1:n] | ||||||
else | ||||||
|
@@ -53,37 +53,35 @@ _summary(x::BitVector) = "$(count(x))-trues BitVector" | |||||
""" | ||||||
SupervisedDataset <: AbstractDataset | ||||||
|
||||||
An abstract dataset type for supervised learning tasks. | ||||||
An abstract dataset type for supervised learning tasks. | ||||||
Concrete dataset types inheriting from it must provide | ||||||
a `features` and a `targets` fields. | ||||||
""" | ||||||
abstract type SupervisedDataset <: AbstractDataset end | ||||||
|
||||||
|
||||||
Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) : | ||||||
Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) : | ||||||
numobs((d.features, d.targets)) | ||||||
|
||||||
|
||||||
# We return named tuples | ||||||
Base.getindex(d::SupervisedDataset, ::Colon) = Tables.istable(d.features) ? | ||||||
(features = d.features, targets=d.targets) : | ||||||
getobs((; d.features, d.targets)) | ||||||
|
||||||
Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ? | ||||||
Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ? | ||||||
(features = getobs_table(d.features, i), targets=getobs_table(d.targets, i)) : | ||||||
getobs((; d.features, d.targets), i) | ||||||
|
||||||
""" | ||||||
UnsupervisedDataset <: AbstractDataset | ||||||
|
||||||
An abstract dataset type for unsupervised or self-supervised learning tasks. | ||||||
An abstract dataset type for unsupervised or self-supervised learning tasks. | ||||||
Concrete dataset types inheriting from it must provide a `features` field. | ||||||
""" | ||||||
abstract type UnsupervisedDataset <: AbstractDataset end | ||||||
|
||||||
|
||||||
Base.length(d::UnsupervisedDataset) = numobs(d.features) | ||||||
|
||||||
Base.getindex(d::UnsupervisedDataset, ::Colon) = getobs(d.features) | ||||||
Base.getindex(d::UnsupervisedDataset, i) = getobs(d.features, i) | ||||||
|
||||||
|
@@ -99,13 +97,13 @@ const ARGUMENTS_SUPERVISED_TABLE = """ | |||||
|
||||||
const FIELDS_SUPERVISED_TABLE = """ | ||||||
- `metadata`: A dictionary containing additional information on the dataset. | ||||||
- `features`: The data features. An array if `as_df=true`, otherwise a dataframe. | ||||||
- `features`: The data features. An array if `as_df=true`, otherwise a dataframe. | ||||||
- `targets`: The targets for supervised learning. An array if `as_df=true`, otherwise a dataframe. | ||||||
- `dataframe`: A dataframe containing both `features` and `targets`. It is `nothing` if `as_df=false`. | ||||||
""" | ||||||
|
||||||
const METHODS_SUPERVISED_TABLE = """ | ||||||
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. | ||||||
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. | ||||||
- `dataset[:]`: Return all observations as a named tuple of features and targets. | ||||||
- `length(dataset)`: Number of observations. | ||||||
""" | ||||||
|
@@ -119,12 +117,23 @@ const ARGUMENTS_SUPERVISED_ARRAY = """ | |||||
|
||||||
const FIELDS_SUPERVISED_ARRAY = """ | ||||||
- `metadata`: A dictionary containing additional information on the dataset. | ||||||
- `features`: An array storing the data features. | ||||||
- `features`: An array storing the data features. | ||||||
- `targets`: An array storing the targets for supervised learning. | ||||||
""" | ||||||
|
||||||
const METHODS_SUPERVISED_ARRAY = """ | ||||||
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. | ||||||
- `dataset[i]`: Return observation(s) `i` as a named tuple of features and targets. | ||||||
- `dataset[:]`: Return all observations as a named tuple of features and targets. | ||||||
- `length(dataset)`: Number of observations. | ||||||
""" | ||||||
|
||||||
""" | ||||||
GraphDataset <: AbstractDataset | ||||||
|
||||||
An abstract dataset type for graph learning tasks. | ||||||
""" | ||||||
abstract type GraphDataset <: AbstractDataset end | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. better rename this as
Suggested change
especially if we go with #169 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, it is better to file this change as a separate PR. Generally a PR should target a single issue or a single feature addition |
||||||
|
||||||
Base.length(data::GraphDataset) = length(data.graphs) | ||||||
Base.getindex(data::GraphDataset, ::Colon) = length(data) == 1 ? data.graphs[1] : data.graphs | ||||||
Base.getindex(data::GraphDataset, i) = data.graphs[i] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it would be good to avoid polluting PRs with this formatting changes, there should be some option in your editor to avoid doing this