From b08b796d704f252172bb73bcfe009b03e39c12e4 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 9 Jan 2024 09:06:31 +1300 Subject: [PATCH 1/4] fix traits for BalancedBaggingClassifier --- src/balanced_bagging.jl | 12 ++++++------ test/balanced_bagging.jl | 6 +++++- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/balanced_bagging.jl b/src/balanced_bagging.jl index 8b67b6e..f0b4600 100644 --- a/src/balanced_bagging.jl +++ b/src/balanced_bagging.jl @@ -133,7 +133,7 @@ function MLJBase.prefit(composite_model::BalancedBaggingClassifier, verbosity, X machines = (machine(:model, Xsub, ysub) for (Xsub, ysub) in X_y_list_s) # Average the predictions from nodes all_preds = [MLJBase.predict(mach, Xs) for (mach, (X, _)) in zip(machines, X_y_list_s)] - yhat = mean(all_preds) +; yhat = mean(all_preds) return (; predict=yhat, report=(;chosen_T=node(()->T))) end @@ -144,17 +144,16 @@ MMI.metadata_pkg( package_uuid = "45f359ea-796d-4f51-95a5-deb1a414c586", package_url = "https://github.com/JuliaAI/MLJBalancing.jl", is_pure_julia = true, + is_wrapper = true, ) MMI.metadata_model( BalancedBaggingClassifier, - input_scitype = Union{Union{Infinite,Finite}}, - output_scitype = Union{Union{Infinite,Finite}}, - target_scitype = AbstractVector, + target_scitype = AbstractVector{<:Finite}, load_path = "MLJBalancing." * string(BalancedBaggingClassifier), ) -MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{P}}) where {P} = +MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{<:Any,<:Any,P}}) where P = MLJBase.prepend(:model, iteration_parameter(P)) for trait in [ :input_scitype, @@ -173,7 +172,8 @@ for trait in [ :prediction_type, ] quote - MMI.$trait(::Type{<:BalancedBaggingClassifier{P}}) where {P} = MMI.$trait(P) + MMI.$trait(::Type{<:BalancedBaggingClassifier{<:Any,<:Any, P}}) where {P} = + MMI.$trait(P) end |> eval end diff --git a/test/balanced_bagging.jl b/test/balanced_bagging.jl index 62eb823..e5178a4 100644 --- a/test/balanced_bagging.jl +++ b/test/balanced_bagging.jl @@ -120,6 +120,10 @@ end mach = machine(modelo, X, y) fit!(mach) @test report(mach) == (chosen_T = 9,) + + ## traits + @test fit_data_scitype(modelo) == fit_data_scitype(model) + @test is_wrapper(modelo) end @@ -137,4 +141,4 @@ end @test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin BalancedBaggingClassifier(model; model=model, T=T, rng=R) end -end \ No newline at end of file +end From 83f804195afb7506049237fe5a1937c6af525eee Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 9 Jan 2024 09:14:55 +1300 Subject: [PATCH 2/4] add test of traits for BalancedModel --- test/balanced_model.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/balanced_model.jl b/test/balanced_model.jl index 2a77318..7ffbf3e 100644 --- a/test/balanced_model.jl +++ b/test/balanced_model.jl @@ -52,7 +52,12 @@ fit!(mach) y_pred2 = MLJBase.predict(mach, X_test) - @test y_pred ≈ y_pred2 + @test y_pred ≈ y_pred2 + + # traits: + @test fit_data_scitype(balanced_model) == fit_data_scitype(model_prob) + @test is_wrapper(balanced_model) + ### 2. Make a pipeline of the three balancers and a deterministic model ## ordinary way @@ -99,4 +104,4 @@ end @test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin BalancedModel(model; model=model, balancer1=balancer1) end -end \ No newline at end of file +end From 5dcd1aac1fec6744996bd06ff3bc82840c2a0bb9 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 9 Jan 2024 09:21:54 +1300 Subject: [PATCH 3/4] remove stray semicolon --- src/balanced_bagging.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/balanced_bagging.jl b/src/balanced_bagging.jl index f0b4600..eaa5af6 100644 --- a/src/balanced_bagging.jl +++ b/src/balanced_bagging.jl @@ -133,7 +133,7 @@ function MLJBase.prefit(composite_model::BalancedBaggingClassifier, verbosity, X machines = (machine(:model, Xsub, ysub) for (Xsub, ysub) in X_y_list_s) # Average the predictions from nodes all_preds = [MLJBase.predict(mach, Xs) for (mach, (X, _)) in zip(machines, X_y_list_s)] -; yhat = mean(all_preds) + yhat = mean(all_preds) return (; predict=yhat, report=(;chosen_T=node(()->T))) end From a0d7ac21562767761519abcc698168dfdb2c2e6c Mon Sep 17 00:00:00 2001 From: Essam Date: Mon, 8 Jan 2024 15:00:51 -0600 Subject: [PATCH 4/4] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Bump=20Project.toml?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 66fc0e3..a2a12ae 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBalancing" uuid = "45f359ea-796d-4f51-95a5-deb1a414c586" authors = ["Essam Wisam ", "Anthony Blaom and contributors"] -version = "0.1.3" +version = "0.1.4" [deps] MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"