diff --git a/Project.toml b/Project.toml index 66fc0e3..a2a12ae 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBalancing" uuid = "45f359ea-796d-4f51-95a5-deb1a414c586" authors = ["Essam Wisam ", "Anthony Blaom and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" diff --git a/src/balanced_bagging.jl b/src/balanced_bagging.jl index 8b67b6e..eaa5af6 100644 --- a/src/balanced_bagging.jl +++ b/src/balanced_bagging.jl @@ -144,17 +144,16 @@ MMI.metadata_pkg( package_uuid = "45f359ea-796d-4f51-95a5-deb1a414c586", package_url = "https://github.com/JuliaAI/MLJBalancing.jl", is_pure_julia = true, + is_wrapper = true, ) MMI.metadata_model( BalancedBaggingClassifier, - input_scitype = Union{Union{Infinite,Finite}}, - output_scitype = Union{Union{Infinite,Finite}}, - target_scitype = AbstractVector, + target_scitype = AbstractVector{<:Finite}, load_path = "MLJBalancing." * string(BalancedBaggingClassifier), ) -MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{P}}) where {P} = +MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{<:Any,<:Any,P}}) where P = MLJBase.prepend(:model, iteration_parameter(P)) for trait in [ :input_scitype, @@ -173,7 +172,8 @@ for trait in [ :prediction_type, ] quote - MMI.$trait(::Type{<:BalancedBaggingClassifier{P}}) where {P} = MMI.$trait(P) + MMI.$trait(::Type{<:BalancedBaggingClassifier{<:Any,<:Any, P}}) where {P} = + MMI.$trait(P) end |> eval end diff --git a/test/balanced_bagging.jl b/test/balanced_bagging.jl index 62eb823..e5178a4 100644 --- a/test/balanced_bagging.jl +++ b/test/balanced_bagging.jl @@ -120,6 +120,10 @@ end mach = machine(modelo, X, y) fit!(mach) @test report(mach) == (chosen_T = 9,) + + ## traits + @test fit_data_scitype(modelo) == fit_data_scitype(model) + @test is_wrapper(modelo) end @@ -137,4 +141,4 @@ end @test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin BalancedBaggingClassifier(model; model=model, T=T, rng=R) end -end \ No newline at end of file +end diff --git a/test/balanced_model.jl b/test/balanced_model.jl index 2a77318..7ffbf3e 100644 --- a/test/balanced_model.jl +++ b/test/balanced_model.jl @@ -52,7 +52,12 @@ fit!(mach) y_pred2 = MLJBase.predict(mach, X_test) - @test y_pred ≈ y_pred2 + @test y_pred ≈ y_pred2 + + # traits: + @test fit_data_scitype(balanced_model) == fit_data_scitype(model_prob) + @test is_wrapper(balanced_model) + ### 2. Make a pipeline of the three balancers and a deterministic model ## ordinary way @@ -99,4 +104,4 @@ end @test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin BalancedModel(model; model=model, balancer1=balancer1) end -end \ No newline at end of file +end