MultitargetNeuralNetworkRegressor doc example doesn't work as intended #268

Closed
EssamWisam opened this issue Aug 4, 2024 · 1 comment · Fixed by #270


@EssamWisam
Collaborator

I was trying to run the following example from the `MultitargetNeuralNetworkRegressor` documentation:

```julia
import Pkg; Pkg.activate("./scratchpad")

using MLJ
import MLJFlux
using Flux
import Optimisers

X, y = make_regression(100, 9; n_targets = 2) # both tables
schema(y)
schema(X)

(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true);

builder = MLJFlux.@builder begin
    init=Flux.glorot_uniform(rng)
    Chain(
        Dense(n_in, 64, relu, init=init),
        Dense(64, 32, relu, init=init),
        Dense(32, n_out, init=init),
    )
end

MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor pkg = MLJFlux
model = MultitargetNeuralNetworkRegressor(builder=builder, rng=123, epochs=20)

pipe = Standardizer |> TransformedTargetModel(model, transformer=Standardizer)

mach = machine(pipe, X, y)
fit!(mach, verbosity=2)

# first element initial loss, 2:end per epoch training losses
report(mach).transformed_target_model_deterministic.model.training_losses

# custom MLJ loss:
multi_loss(yhat, y) = l2(MLJ.matrix(yhat), MLJ.matrix(y))

# CV estimate, based on `(X, y)`:
evaluate!(mach, resampling=CV(nfolds=5), measure=multi_loss)

# loss for `(Xtest, ytest)`:
fit!(mach) # trains on all data `(X, y)`
yhat = predict(mach, Xtest)
multi_loss(yhat, ytest)
```

The error log is very long, but here is a snippet:

```text
ERROR: MethodError: no method matching abs(::Vector{Float64})

Closest candidates are:
  abs(::Missing)
   @ Base missing.jl:101
  abs(::Bool)
   @ Base bool.jl:153
  abs(::Pkg.Resolve.VersionWeight)
   @ Pkg ~/.julia/juliaup/julia-1.10.4+0.aarch64.apple.darwin14/share/julia/stdlib/v1.10/Pkg/src/Resolve/versionweights.jl:32
  ...

Stacktrace:
  [1] pth_power_of_absolute_difference(yhat::Vector{Float64}, y::Vector{Float64}, p::Int64)
    @ StatisticalMeasures.Functions ~/.julia/packages/StatisticalMeasures/hPDX2/src/functions.jl:12
  [2] (::StatisticalMeasures.LPLossOnScalars{Int64})(yhat::Vector{Float64}, y::Vector{Float64})
    @ StatisticalMeasures ~/.julia/packages/StatisticalMeasuresBase/7euvM/src/tools_for_implementers.jl:353
  [3] call(::StatisticalMeasuresBase.SupportsMissingsMeasure{…}, ::Vector{…}, ::Vector{…})
    @ StatisticalMeasuresBase ~/.julia/packages/StatisticalMeasuresBase/7euvM/src/supports_missings_measure.jl:21
  [4] (::StatisticalMeasuresBase.SupportsMissingsMeasure{…})(arg::Vector{…}, args::Vector{…})
    @ StatisticalMeasuresBase ~/.julia/packages/StatisticalMeasuresBase/7euvM/src/wrappers.jl:26
  [5] (::StatisticalMeasuresBase.var"#28#29"{StatisticalMeasuresBase.Multimeasure{…}, Nothing})(::Tuple{Vector{…}, Vector{…}})
    @ StatisticalMeasuresBase ~/.julia/packages/StatisticalMeasuresBase/7euvM/src/multimeasure.jl:158
  [6] _broadcast_getindex_evalf
    @ ./broadcast.jl:709 [inlined]
  [7] _broadcast_getindex
    @ ./broadcast.jl:682 [inlined]
  [8] _getindex
```

This seems to stem from the code shown in the following screenshot:

[screenshot not preserved]

Possibly some missing broadcasting dots?
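The `MethodError` itself is easy to reproduce in plain Julia: `abs` has methods for scalars (and a few special types such as `Missing`), but not for whole vectors, so applying it elementwise requires the broadcasting dot. The snippet below is only an illustration of that error, using a made-up vector `v`. Judging from the stacktrace, a vector reaches the scalar loss here because `l2` was handed matrices; the reply below identifies a different measure as the actual fix, rather than an added dot.

```julia
v = [-1.5, 2.0, -3.0]  # illustrative data, not from the example above

# `abs` has no method for a whole vector, so this throws the same kind of
# MethodError as in the log; we catch it to inspect its type.
result = try
    abs(v)
catch err
    err
end
println(typeof(result))  # prints MethodError

# Broadcasting with a dot applies `abs` to each element instead.
println(abs.(v))  # prints [1.5, 2.0, 3.0]
```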

@ablaom
Collaborator

ablaom commented Aug 4, 2024

Thanks @EssamWisam for reporting this.

The docstring you quote is indeed incorrect. Since MLJBase 1.0, `multitarget_l2` should be used here instead of `l2`. Also, there is no longer any need to convert `y` and `yhat` to matrices, as multi-target measures in StatisticalMeasures support most tables.
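For readers who want the corrected example inline, here is a sketch of the docstring example adapted as described above: the custom `multi_loss` wrapper is dropped and `multitarget_l2` is passed directly to `evaluate!` and applied to the tables returned by `predict`. It assumes a recent MLJ in which `multitarget_l2` is exported (as stated above), with MLJFlux and Flux installed in the active environment; the name `test_loss` is illustrative. It is not necessarily identical to the fix merged in #270.

```julia
# Assumes MLJ, MLJFlux and Flux are installed in the active environment.
using MLJ
import MLJFlux
using Flux

X, y = make_regression(100, 9; n_targets = 2) # both tables
(X, Xtest), (y, ytest) = partition((X, y), 0.7, multi=true)

builder = MLJFlux.@builder begin
    init = Flux.glorot_uniform(rng)
    Chain(
        Dense(n_in, 64, relu, init=init),
        Dense(64, 32, relu, init=init),
        Dense(32, n_out, init=init),
    )
end

MultitargetNeuralNetworkRegressor = @load MultitargetNeuralNetworkRegressor pkg=MLJFlux
model = MultitargetNeuralNetworkRegressor(builder=builder, rng=123, epochs=20)
pipe = Standardizer |> TransformedTargetModel(model, transformer=Standardizer)
mach = machine(pipe, X, y)

# `multitarget_l2` accepts tables directly, so neither a custom loss
# wrapper nor an `MLJ.matrix` conversion is needed.
evaluate!(mach, resampling=CV(nfolds=5), measure=multitarget_l2)

# loss for `(Xtest, ytest)`:
fit!(mach) # trains on the training part `(X, y)`
yhat = predict(mach, Xtest)
test_loss = multitarget_l2(yhat, ytest)
```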

A fix to the docstring (#270) is cross-referenced above, if you wish to see the corrected example now.
