Skip to content

Commit

Permalink
Merge pull request #26 from JuliaAI/dev
Browse files Browse the repository at this point in the history
For a 0.1.4 Release
  • Loading branch information
EssamWisam authored Jan 8, 2024
2 parents 7b6f43f + a0d7ac2 commit 96a09a5
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJBalancing"
uuid = "45f359ea-796d-4f51-95a5-deb1a414c586"
authors = ["Essam Wisam <[email protected]>", "Anthony Blaom <[email protected]> and contributors"]
version = "0.1.3"
version = "0.1.4"

[deps]
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
Expand Down
10 changes: 5 additions & 5 deletions src/balanced_bagging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,17 +144,16 @@ MMI.metadata_pkg(
package_uuid = "45f359ea-796d-4f51-95a5-deb1a414c586",
package_url = "https://github.com/JuliaAI/MLJBalancing.jl",
is_pure_julia = true,
is_wrapper = true,
)

MMI.metadata_model(
BalancedBaggingClassifier,
input_scitype = Union{Union{Infinite,Finite}},
output_scitype = Union{Union{Infinite,Finite}},
target_scitype = AbstractVector,
target_scitype = AbstractVector{<:Finite},
load_path = "MLJBalancing." * string(BalancedBaggingClassifier),
)

MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{P}}) where {P} =
MMI.iteration_parameter(::Type{<:BalancedBaggingClassifier{<:Any,<:Any,P}}) where P =
MLJBase.prepend(:model, iteration_parameter(P))
for trait in [
:input_scitype,
Expand All @@ -173,7 +172,8 @@ for trait in [
:prediction_type,
]
quote
MMI.$trait(::Type{<:BalancedBaggingClassifier{P}}) where {P} = MMI.$trait(P)
MMI.$trait(::Type{<:BalancedBaggingClassifier{<:Any,<:Any, P}}) where {P} =
MMI.$trait(P)
end |> eval
end

Expand Down
6 changes: 5 additions & 1 deletion test/balanced_bagging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ end
mach = machine(modelo, X, y)
fit!(mach)
@test report(mach) == (chosen_T = 9,)

## traits
@test fit_data_scitype(modelo) == fit_data_scitype(model)
@test is_wrapper(modelo)
end


Expand All @@ -137,4 +141,4 @@ end
@test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin
BalancedBaggingClassifier(model; model=model, T=T, rng=R)
end
end
end
9 changes: 7 additions & 2 deletions test/balanced_model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@
fit!(mach)
y_pred2 = MLJBase.predict(mach, X_test)

@test y_pred y_pred2
@test y_pred y_pred2

# traits:
@test fit_data_scitype(balanced_model) == fit_data_scitype(model_prob)
@test is_wrapper(balanced_model)


### 2. Make a pipeline of the three balancers and a deterministic model
## ordinary way
Expand Down Expand Up @@ -99,4 +104,4 @@ end
@test_logs (:warn, MLJBalancing.WRN_MODEL_GIVEN) begin
BalancedModel(model; model=model, balancer1=balancer1)
end
end
end

0 comments on commit 96a09a5

Please sign in to comment.