Skip to content

Commit

Permalink
Merge pull request #10 from invenia/ox/use
Browse files Browse the repository at this point in the history
Updates because I tried to use this, also adding submodels
  • Loading branch information
oxinabox authored Apr 9, 2020
2 parents afcd659 + 1169b24 commit 5506930
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 38 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Models"
uuid = "e6388cff-ecff-480c-9b53-83211bf7812a"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.0"
version = "0.1.1"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -10,7 +10,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Distributions = "0.16, 0.22"
Distributions = "0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23"
NamedDims = "0.1, 0.2"
StatsBase = "0.32"
julia = "1"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Model
```@docs
fit
predict
submodels
estimate_type
output_type
```
Expand Down
26 changes: 24 additions & 2 deletions src/Models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module Models
import StatsBase: fit, predict

export Model, Template
export fit, predict, estimate_type, output_type
export fit, predict, submodels, estimate_type, output_type
export EstimateTrait, PointEstimate, DistributionEstimate
export OutputTrait, SingleOutput, MultiOutput

Expand All @@ -28,10 +28,11 @@ Defined as well are the traits:
abstract type Model end

"""
fit(::Template, output, input) -> Model
fit(::Template, output, input, [weights]) -> Model
Fit the [`Template`](@ref) to the `output` and `input` data and return a trained
[`Model`](@ref).
Convention is that `weights` defaults to `StatsBase.uweights(Float32, size(outputs, 2))`
"""
function fit end

Expand All @@ -44,6 +45,27 @@ Returns a predictive distribution or point estimates depending on the [`Model`](
"""
function predict end

"""
submodels(::Union{Template, Model})
Return all submodels within a multistage model/template.
Submodels are models within a model that have their own inputs (which may or may not be
combined with outputs of _earlier_ submodels, before actually being passed as input to the submodel).
Such multistage models take a tuple of inputs (which may be nested if the submodel itself
has submodels).
The order of submodels returned by `submodels` is as per the order of the inputs in the
tuple.
For single-stage models, (i.e. ones that simply take a matrix as input), this returns an
empty tuple.
Wrapper models which do not expose their inner models to seperate inputs, including ones
that only wrap a single model, should **not** define `submodels` as they are
(from the outside API perspective) single-stage models.
"""
submodels(::Union{Template, Model}) = ()


include("traits.jl")
include("test_utils.jl")

Expand Down
78 changes: 44 additions & 34 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,11 @@ mutable struct FakeModel{E<:EstimateTrait, O<:OutputTrait} <: Model
num_variates::Int
end

estimate_type(::FakeModel{E, O}) where {E, O} = E
output_type(::FakeModel{E, O}) where {E, O} = O
Models.estimate_type(::Type{<:FakeModel{E, O}}) where {E, O} = E
Models.output_type(::Type{<:FakeModel{E, O}}) where {E, O} = O

estimate_type(::FakeTemplate{E, O}) where {E, O} = E
output_type(::FakeTemplate{E, O}) where {E, O} = O
Models.estimate_type(::Type{<:FakeTemplate{E, O}}) where {E, O} = E
Models.output_type(::Type{<:FakeTemplate{E, O}}) where {E, O} = O

function StatsBase.fit(
template::FakeTemplate{E, O},
Expand All @@ -122,94 +122,104 @@ end
StatsBase.predict(m::FakeModel, inputs) = m.predictor(m.num_variates, inputs)

"""
test_interface(temp::Template; inputs=rand(5, 5), outputs=rand(5, 5))
test_interface(template::Template; inputs=rand(5, 5), outputs=rand(5, 5))
Test that subtypes of [`Template`](@ref) and [`Model`](@ref) implement the expected API.
Can be used as an initial test to verify the API has been correctly implemented.
Returns the predictions of the `Model`.
"""
function test_interface(temp::Template; kwargs...)
return test_interface(temp, estimate_type(temp), output_type(temp); kwargs...)
function test_interface(template::Template; kwargs...)
@testset "Models API Interface Test: $(nameof(typeof(template)))" begin
return test_interface(template, estimate_type(template), output_type(template); kwargs...)
end
end

function test_interface(
temp::Template, ::Type{PointEstimate}, ::Type{SingleOutput};
template::Template, ::Type{PointEstimate}, ::Type{SingleOutput};
inputs=rand(5, 5), outputs=rand(1, 5),
)
predictions = test_common(temp, inputs, outputs)
predictions = test_common(template, inputs, outputs)

@test predictions isa NamedDimsArray{(:variates, :observations), <:Real, 2}
@test size(predictions) == size(outputs)
@test size(predictions, 1) == 1
end

function test_interface(
temp::Template, ::Type{PointEstimate}, ::Type{MultiOutput};
template::Template, ::Type{PointEstimate}, ::Type{MultiOutput};
inputs=rand(5, 5), outputs=rand(2, 5),
)
predictions = test_common(temp, inputs, outputs)
predictions = test_common(template, inputs, outputs)
@test predictions isa NamedDimsArray{(:variates, :observations), <:Real, 2}
@test size(predictions) == size(outputs)
end

function test_interface(
temp::Template, ::Type{DistributionEstimate}, ::Type{SingleOutput};
template::Template, ::Type{DistributionEstimate}, ::Type{SingleOutput};
inputs=rand(5, 5), outputs=rand(1, 5),
)
predictions = test_common(temp, inputs, outputs)
predictions = test_common(template, inputs, outputs)
@test predictions isa Vector{<:Normal{<:Real}}
@test length(predictions) == size(outputs, 2)
@test all(length.(predictions) .== size(outputs, 1))
end

function test_interface(
temp::Template, ::Type{DistributionEstimate}, ::Type{MultiOutput};
template::Template, ::Type{DistributionEstimate}, ::Type{MultiOutput};
inputs=rand(5, 5), outputs=rand(3, 5)
)
predictions = test_common(temp, inputs, outputs)
predictions = test_common(template, inputs, outputs)
@test predictions isa Vector{<:MultivariateNormal{<:Real}}
@test length(predictions) == size(outputs, 2)
@test all(length.(predictions) .== size(outputs, 1))
end

function test_common(temp, inputs, outputs)
function test_names(template, model)
template_type_name = string(nameof(typeof(template)))
template_base_name_match = match(r"(.*)Template", template_type_name)
@test template_base_name_match !== nothing # must have Template suffix

model = fit(temp, outputs, inputs)
model_type_name = string(nameof(typeof(model)))
model_base_name_match = match(r"(.*)Model", model_type_name)
@test model_base_name_match !== nothing # must have Model suffix

@test temp isa Template
@test model isa Model
# base_name must agreee
@test model_base_name_match[1] == template_base_name_match[1]
end

@testset "type names" begin
template_type_name = string(nameof(typeof(temp)))
template_base_name_match = match(r"(.*)Template", template_type_name)
@test template_base_name_match !== nothing # must have Template suffix
function test_common(template, inputs, outputs)

model = fit(template, outputs, inputs)

model_type_name = string(nameof(typeof(model)))
model_base_name_match = match(r"(.*)Model", model_type_name)
@test model_base_name_match !== nothing # must have Model suffix
@test template isa Template
@test model isa Model

# base_name must agreee
@test model_base_name_match[1] == template_base_name_match[1]
@testset "type names" begin
test_names(template, model)
end

@testset "test fit/predict errors" begin
@test_throws MethodError predict(temp, inputs)
@test_throws MethodError fit(model, outputs, inputs)
@test_throws MethodError predict(template, inputs) # can only predict on a model
@test_throws MethodError fit(model, outputs, inputs) # can only fit a template
end

@testset "test weights can also be passed" begin
weights = uweights(Float32, size(outputs, 2))
model_weights = fit(temp, outputs, inputs, weights)
model_weights = fit(template, outputs, inputs, weights)
end

@testset "traits" begin
@test estimate_type(temp) == estimate_type(model)
@test output_type(temp) == output_type(model)
@test estimate_type(template) == estimate_type(model)
@test output_type(template) == output_type(model)
end

predictions = predict(model, inputs)
@testset "submodels" begin
@test length(submodels(template)) == length(submodels(model))
foreach(test_names, submodels(template), submodels(model))
end

predictions = predict(model, inputs)
return predictions
end

Expand Down

2 comments on commit 5506930

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/12633

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.1 -m "<description of version>" 5506930d9a587d1b2ab04553fa8443af4387a9e6
git push origin v0.1.1

Please sign in to comment.