diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 43168ae..d446c5d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -19,11 +19,8 @@ jobs: fail-fast: false matrix: version: - - '1.0' - '1.8' - - 'nightly' - os: - - ubuntu-latest + os: [ubuntu-latest, windows-latest, macOS-latest] arch: - x64 steps: diff --git a/.gitignore b/.gitignore index b067edd..11d80c3 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /Manifest.toml +.CondaPkg diff --git a/Project.toml b/Project.toml index b8df21a..529b9cb 100644 --- a/Project.toml +++ b/Project.toml @@ -10,17 +10,16 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -julia = "1" MLJBase = "0.21" -MLJModelInterface = "1.9" OrderedCollections = "1.6" +julia = "1.6" [extras] DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Imbalance = "c709b415-507b-45b7-9a3d-1767c89fde68" -MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" +MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "Imbalance", "DataFrames", "MLJLIBSVMInterface", "MLJLinearModels"] +test = ["Test", "Imbalance", "DataFrames", "MLJLinearModels", "MLJModels"] diff --git a/README.md b/README.md index 5d94990..1db6fd2 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,56 @@ # MLJBalancing -A package with exported learning networks that combine resampling methods from Imbalance.jl and classification models from MLJ +A package providing composite models wrapping class imbalance algorithms from [Imbalance.jl](https://github.com/JuliaAI/Imbalance.jl). + +## ⏬ Installation +```julia +import Pkg; +Pkg.add("MLJBalancing") +``` + +## 🚅 Sequential Resampling + +This package allows chaining of resampling methods from Imbalance.jl with classification models from MLJ. 
Simply construct a `BalancedModel` object while specifying the model (classifier) and an arbitrary number of resamplers (also called *balancers* - typically oversamplers and/or undersamplers). + +### 📖 Example + +#### Construct the resamplers and the model +```julia +SMOTENC = @load SMOTENC pkg=Imbalance verbosity=0 +TomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0 + +oversampler = SMOTENC(k=5, ratios=1.0, rng=42) +undersampler = TomekUndersampler(min_ratios=0.5, rng=42) + +logistic_model = LogisticClassifier() +``` + +#### Wrap them all in BalancedModel +```julia +balanced_model = BalancedModel(model=logistic_model, balancer1=oversampler, balancer2=undersampler) +``` +Here training data will be passed to `balancer1` then `balancer2`, whose output is used to train the classifier `model`. In prediction, the resamplers `balancer1` and `balancer2` are bypassed. + +In general, there can be any number of balancers, and the user can give the balancers arbitrary names. + +#### At this point, they behave like one single model +You can fit, predict, cross-validate and fine-tune it like any other MLJ model. Here is an example for fine-tuning: +```julia +r1 = range(balanced_model, :(balancer1.k), lower=3, upper=10) +r2 = range(balanced_model, :(balancer2.min_ratios), lower=0.1, upper=0.9) + +tuned_balanced_model = TunedModel( + model=balanced_model, + tuning=Grid(goal=4), + resampling=CV(nfolds=4), + range=[r1, r2], + measure=cross_entropy +); + +mach = machine(tuned_balanced_model, X, y); +fit!(mach, verbosity=0); +fitted_params(mach).best_model +``` + +## 🚆🚆 Parallel Resampling with EasyEnsemble + +Coming soon... 
\ No newline at end of file diff --git a/examples/BalancedModel.ipynb b/examples/BalancedModel.ipynb new file mode 100644 index 0000000..7713429 --- /dev/null +++ b/examples/BalancedModel.ipynb @@ -0,0 +1,442 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `~/Documents/GitHub/MLJBalancing`\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: The project dependencies or compat requirements have changed since the manifest was last resolved.\n", + "│ It is recommended to `Pkg.resolve()` or consider `Pkg.update()` if necessary.\n", + "└ @ Pkg.API /Users/julia/.julia/scratchspaces/a66863c6-20e8-4ff4-8a62-49f30b1f605e/agent-cache/default-macmini-aarch64-4.0/build/default-macmini-aarch64-4-0/julialang/julia-release-1-dot-8/usr/share/julia/stdlib/v1.8/Pkg/src/API.jl:1535\n" + ] + } + ], + "source": [ + "ENV[\"JULIA_PKG_SERVER\"] = \"\"\n", + "using Pkg\n", + "Pkg.activate(@__DIR__)\n", + "Pkg.instantiate()\n", + "\n", + "using MLJ\n", + "using MLJBalancing: BalancedModel\n", + "using Imbalance\n", + "using Random\n", + "using DataFrames" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load Some Data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 200 (39.5%) \n", + "1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 294 (58.1%) \n", + "2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 506 (100.0%) \n" + ] + } + ], + "source": [ + "X, y = Imbalance.generate_imbalanced_data(1000, 5; probs=[0.2, 0.3, 0.5])\n", + "X = DataFrame(X)\n", + "(X_train, X_test), (y_train, y_test) = partition((X, y), 0.8, rng=123, multi=true)\n", + "Imbalance.checkbalance(y)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + 
"source": [ + "#### Load Some Balancers" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ROSE(\n", + " s = 1.0, \n", + " ratios = 1.3, \n", + " rng = 42, \n", + " try_perserve_type = true)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "balancer1 = Imbalance.MLJ.RandomOversampler(ratios=1.0, rng=42)\n", + "balancer2 = Imbalance.MLJ.SMOTENC(k=10, ratios=1.2, rng=42)\n", + "balancer3 = Imbalance.MLJ.ROSE(ratios=1.3, rng=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Load a Classification Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LogisticClassifier(\n", + " lambda = 2.220446049250313e-16, \n", + " gamma = 0.0, \n", + " penalty = :l2, \n", + " fit_intercept = true, \n", + " penalize_intercept = false, \n", + " scale_penalty_with_samples = true, \n", + " solver = nothing)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0\n", + "model_prob = LogisticClassifier()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Wrap the Balancers and the Classification Model Together" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BalancedModelProbabilistic(\n", + " model = LogisticClassifier(\n", + " lambda = 2.220446049250313e-16, \n", + " gamma = 0.0, \n", + " penalty = :l2, \n", + " fit_intercept = true, \n", + " penalize_intercept = false, \n", + " scale_penalty_with_samples = true, \n", + " solver = nothing), \n", + " balancer1 = RandomOversampler(\n", + " ratios = 1.0, \n", + " rng = 42, \n", + " try_perserve_type = true), \n", + " balancer2 = SMOTENC(\n", + " k = 10, \n", + " ratios = 1.2, \n", + " rng = 42, 
\n", + " try_perserve_type = true), \n", + " balancer3 = ROSE(\n", + " s = 1.0, \n", + " ratios = 1.3, \n", + " rng = 42, \n", + " try_perserve_type = true))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "balanced_model = BalancedModel(model=model_prob, \n", + " balancer1=balancer1, \n", + " balancer2=balancer2, \n", + " balancer3=balancer3)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Now they behave as a single model!" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(BalancedModelProbabilistic(model = LogisticClassifier(lambda = 2.220446049250313e-16, …), …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(ROSE(s = 1.0, …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", + "┌ Info: Training machine(SMOTENC(k = 10, …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", + "┌ Info: Training machine(RandomOversampler(ratios = 1.0, …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", + "┌ Info: Training machine(:model, …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:02\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 2\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:03\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 2\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 2\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 2\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 2\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 
100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}\n", + "│ optim_options: Optim.Options{Float64, Nothing}\n", + "│ lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()\n", + "└ @ MLJLinearModels /Users/essam/.julia/packages/MLJLinearModels/zSQnL/src/mlj/interface.jl:72\n" + ] + }, + { + "data": { + "text/plain": [ + "200-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, Int64, UInt32, Float64}:\n", + " UnivariateFinite{Multiclass{3}}(0=>0.348, 1=>0.343, 2=>0.309)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.282, 1=>0.306, 2=>0.412)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.306, 1=>0.319, 2=>0.374)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.347, 1=>0.334, 2=>0.319)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.319, 1=>0.333, 2=>0.348)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.372, 1=>0.337, 2=>0.29)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.36, 1=>0.337, 2=>0.303)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.289, 1=>0.293, 2=>0.418)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.312, 1=>0.308, 2=>0.38)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.345, 1=>0.349, 2=>0.306)\n", + " ⋮\n", + " UnivariateFinite{Multiclass{3}}(0=>0.371, 1=>0.374, 2=>0.255)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.365, 1=>0.376, 2=>0.259)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.355, 1=>0.361, 2=>0.284)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.332, 1=>0.35, 2=>0.318)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.376, 1=>0.354, 2=>0.27)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.326, 1=>0.342, 2=>0.332)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.273, 1=>0.308, 2=>0.419)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.283, 1=>0.295, 2=>0.421)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.317, 1=>0.341, 
2=>0.342)" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "mach = machine(balanced_model, X_train, y_train)\n", + "fit!(mach)\n", + "y_pred = predict(mach, X_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### You can even tune it if you wish" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "r1 = range(balanced_model, :(balancer1.ratios), lower=0.8, upper=1.0)\n", + "r2 = range(balanced_model, :(balancer2.k), lower=3, upper=10)\n", + "r3 = range(balanced_model, :(balancer3.s), lower=0.0, upper=0.5)\n", + "\n", + "tuned_balanced_model = TunedModel(model=balanced_model,\n", + "\t\t\t\t\t\t\t\t\t tuning=Grid(goal=4),\n", + "\t\t\t\t\t\t\t\t\t resampling=CV(nfolds=4),\n", + "\t\t\t\t\t\t\t\t\t range=[r1, r2, r3],\n", + "\t\t\t\t\t\t\t\t\t measure=cross_entropy);\n", + "\n", + "mach = machine(tuned_balanced_model, X, y);\n", + "fit!(mach, verbosity=0);" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BalancedModelProbabilistic(\n", + " model = LogisticClassifier(\n", + " lambda = 2.220446049250313e-16, \n", + " gamma = 0.0, \n", + " penalty = :l2, \n", + " fit_intercept = true, \n", + " penalize_intercept = false, \n", + " scale_penalty_with_samples = true, \n", + " solver = nothing), \n", + " balancer1 = RandomOversampler(\n", + " ratios = 0.8, \n", + " rng = 42, \n", + " try_perserve_type = true), \n", + " balancer2 = SMOTENC(\n", + " k = 10, \n", + " ratios = 1.2, \n", + " rng = 42, \n", + " try_perserve_type = true), \n", + " balancer3 = ROSE(\n", + " s = 0.5, \n", + " ratios = 1.3, \n", + " rng = 42, \n", + " try_perserve_type = true))" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fitted_params(mach).best_model" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.8.5", + "language": 
"julia", + "name": "julia-1.8" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.8.5" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/Manifest.toml b/examples/Manifest.toml new file mode 100644 index 0000000..b43e53e --- /dev/null +++ b/examples/Manifest.toml @@ -0,0 +1,994 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.8.5" +manifest_format = "2.0" +project_hash = "bb8a7bcab9f98b02a846587066cc275d77de9db7" + +[[deps.ARFFFiles]] +deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] +git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" +uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" +version = "1.4.1" + +[[deps.AbstractTrees]] +git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" +uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +version = "0.4.4" + +[[deps.Adapt]] +deps = ["LinearAlgebra", "Requires"] +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" +uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +version = "3.6.2" + +[[deps.ArgCheck]] +git-tree-sha1 = "a3a402a35a2f7e0b87828ccabbd5ebfbebe356b4" +uuid = "dce04be8-c92d-5529-be00-80e4d2c0e197" +version = "2.3.0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.ArrayInterface]] +deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" +uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +version = "7.4.11" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.BangBang]] +deps = ["Compat", "ConstructionBase", "InitialValues", "LinearAlgebra", "Requires", "Setfield", "Tables"] +git-tree-sha1 = "e28912ce94077686443433c2800104b061a827ed" +uuid = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" +version = "0.3.39" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.Baselet]] 
+git-tree-sha1 = "aebf55e6d7795e02ca500a689d326ac979aaf89e" +uuid = "9718e550-a3fa-408a-8086-8db961cd8217" +version = "0.1.1" + +[[deps.BitFlags]] +git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" +uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" +version = "0.1.7" + +[[deps.Calculus]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" +uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" +version = "0.5.1" + +[[deps.CategoricalArrays]] +deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] +git-tree-sha1 = "1568b28f91293458345dabba6a5ea3f183250a61" +uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" +version = "0.10.8" + +[[deps.CategoricalDistributions]] +deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes", "UnicodePlots"] +git-tree-sha1 = "ed760a4fde49997ff9360a780abe6e20175162aa" +uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" +version = "0.1.11" + +[[deps.ChainRulesCore]] +deps = ["Compat", "LinearAlgebra", "SparseArrays"] +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "1.16.0" + +[[deps.ChangesOfVariables]] +deps = ["InverseFunctions", "LinearAlgebra", "Test"] +git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" +uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" +version = "0.1.8" + +[[deps.CodecZlib]] +deps = ["TranscodingStreams", "Zlib_jll"] +git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" +uuid = "944b1d66-785c-5afd-91f1-9de20f533193" +version = "0.7.2" + +[[deps.ColorSchemes]] +deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] +git-tree-sha1 = "d9a8f86737b665e15a9641ecbac64deef9ce6724" +uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +version = "3.23.0" + +[[deps.ColorTypes]] +deps = ["FixedPointNumbers", "Random"] +git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" +uuid = 
"3da002f7-5984-5a60-b8a6-cbb66c0b333f" +version = "0.11.4" + +[[deps.ColorVectorSpace]] +deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "Requires", "Statistics", "TensorCore"] +git-tree-sha1 = "a1f44953f2382ebb937d60dafbe2deea4bd23249" +uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" +version = "0.10.0" + +[[deps.Colors]] +deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] +git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" +uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" +version = "0.12.10" + +[[deps.Combinatorics]] +git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" +uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +version = "1.0.2" + +[[deps.CommonSubexpressions]] +deps = ["MacroTools", "Test"] +git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.3.0" + +[[deps.Compat]] +deps = ["Dates", "LinearAlgebra", "UUIDs"] +git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.9.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "1.0.1+0" + +[[deps.CompositionsBase]] +git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad" +uuid = "a33af91c-f02d-484b-be07-31d278c5ca2b" +version = "0.1.2" + +[[deps.ComputationalResources]] +git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" +uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" +version = "0.3.2" + +[[deps.ConcurrentUtilities]] +deps = ["Serialization", "Sockets"] +git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" +uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" +version = "2.2.1" + +[[deps.ConstructionBase]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "c53fc348ca4d40d7b371e71fd52251839080cbc9" +uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +version = "1.5.4" + +[[deps.Contour]] +git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" +uuid = 
"d38c429a-6771-53c6-b99e-75d170b6e991" +version = "0.6.2" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataAPI]] +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" +uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" +version = "1.15.0" + +[[deps.DataFrames]] +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" +uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +version = "1.6.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.15" + +[[deps.DataValueInterfaces]] +git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" +uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" +version = "1.0.0" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.DefineSingletons]] +git-tree-sha1 = "0fba8b706d0178b4dc7fd44a96a92382c9065c2c" +uuid = "244e2a9f-e319-4986-a169-4d1fe445cd52" +version = "0.1.2" + +[[deps.DelimitedFiles]] +deps = ["Mmap"] +uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" + +[[deps.DensityInterface]] +deps = ["InverseFunctions", "Test"] +git-tree-sha1 = "80c3e8639e3353e5d2912fb3a1916b8455e2494b" +uuid = "b429d917-457f-4dbc-8f4c-0cc954292b1d" +version = "0.4.0" + +[[deps.DiffResults]] +deps = ["StaticArraysCore"] +git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.1.0" + +[[deps.DiffRules]] +deps = ["IrrationalConstants", "LogExpFunctions", 
"NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.15.1" + +[[deps.Distances]] +deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "b6def76ffad15143924a2199f72a5cd883a2e8a9" +uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +version = "0.10.9" + +[[deps.Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[deps.Distributions]] +deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" +uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" +version = "0.25.100" + +[[deps.DocStringExtensions]] +deps = ["LibGit2"] +git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" +uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +version = "0.9.3" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.DualNumbers]] +deps = ["Calculus", "NaNMath", "SpecialFunctions"] +git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" +uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" +version = "0.6.8" + +[[deps.EarlyStopping]] +deps = ["Dates", "Statistics"] +git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" +uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" +version = "0.3.0" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "e90caa41f5a86296e014e148ee061bd6c3edec96" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.9" + +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.20" + +[[deps.FileWatching]] +uuid = 
"7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] +git-tree-sha1 = "a20eaa3ad64254c61eeb5f230d9306e937405434" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "1.6.1" + +[[deps.FiniteDiff]] +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"] +git-tree-sha1 = "c6e4a1fbe73b31a3dea94b1da449503b8830c306" +uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" +version = "2.21.1" + +[[deps.FixedPointNumbers]] +deps = ["Statistics"] +git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" +uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" +version = "0.8.4" + +[[deps.ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.36" + +[[deps.Future]] +deps = ["Random"] +uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" + +[[deps.HTTP]] +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "5eab648309e2e060198b45820af1a37182de3cce" +uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" +version = "1.10.0" + +[[deps.HypergeometricFunctions]] +deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] +git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" +uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" +version = "0.3.23" + +[[deps.Imbalance]] +deps = ["CategoricalArrays", "CategoricalDistributions", "Distances", "LinearAlgebra", "MLJModelInterface", "MLJTestInterface", "NearestNeighbors", "OrderedCollections", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "TableOperations", 
"TableTransforms", "Tables", "TransformsBase"] +git-tree-sha1 = "53eeb73d88913134cab0b0e04dd58901769fc7db" +uuid = "c709b415-507b-45b7-9a3d-1767c89fde68" +version = "0.1.0" + +[[deps.InitialValues]] +git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" +uuid = "22cec73e-a1b8-11e9-2c92-598750a2cf9c" +version = "0.3.1" + +[[deps.InlineStrings]] +deps = ["Parsers"] +git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" +uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" +version = "1.4.0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.InverseFunctions]] +deps = ["Test"] +git-tree-sha1 = "68772f49f54b479fa88ace904f6127f0a3bb2e46" +uuid = "3587e190-3f89-42d0-90ee-14403ec27112" +version = "0.1.12" + +[[deps.InvertedIndices]] +git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" +uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" +version = "1.3.0" + +[[deps.IrrationalConstants]] +git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" +uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" +version = "0.2.2" + +[[deps.IterationControl]] +deps = ["EarlyStopping", "InteractiveUtils"] +git-tree-sha1 = "d7df9a6fdd82a8cfdfe93a94fcce35515be634da" +uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" +version = "0.5.3" + +[[deps.IterativeSolvers]] +deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] +git-tree-sha1 = "1169632f425f79429f245113b775a0e3d121457c" +uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" +version = "0.9.2" + +[[deps.IteratorInterfaceExtensions]] +git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" +uuid = "82899510-4779-5014-852e-03e436cf321d" +version = "1.0.0" + +[[deps.JLLWrappers]] +deps = ["Artifacts", "Preferences"] +git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" +uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" +version = "1.5.0" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" +uuid = 
"682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.4" + +[[deps.LaTeXStrings]] +git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" +uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" +version = "1.3.0" + +[[deps.LatinHypercubeSampling]] +deps = ["Random", "StableRNGs", "StatsBase", "Test"] +git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" +uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" +version = "1.9.0" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LineSearches]] +deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] +git-tree-sha1 = "7bbea35cec17305fc70a0e5b4641477dc0789d9d" +uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" +version = "7.2.0" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.LinearMaps]] +deps = ["ChainRulesCore", "LinearAlgebra", "SparseArrays", "Statistics"] +git-tree-sha1 = "6698ab5e662b47ffc63a82b2f43c1cee015cf80d" +uuid = "7a12625a-238d-50fd-b39a-03d52299707e" +version = "3.11.0" + +[[deps.LogExpFunctions]] +deps = ["ChainRulesCore", "ChangesOfVariables", "DocStringExtensions", "InverseFunctions", "IrrationalConstants", "LinearAlgebra"] +git-tree-sha1 = "7d6dd4e9212aebaeed356de34ccf262a3cd415aa" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.3.26" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + 
+[[deps.LoggingExtras]] +deps = ["Dates", "Logging"] +git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" +uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" +version = "1.0.2" + +[[deps.LossFunctions]] +deps = ["Markdown", "Requires", "Statistics"] +git-tree-sha1 = "df9da07efb9b05ca7ef701acec891ee8f73c99e2" +uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" +version = "0.11.1" + +[[deps.MLFlowClient]] +deps = ["Dates", "FilePathsBase", "HTTP", "JSON", "ShowCases", "URIs", "UUIDs"] +git-tree-sha1 = "32cee10a6527476bef0c6484ff4c60c2cead5d3e" +uuid = "64a0f543-368b-4a9a-827a-e71edb2a0b83" +version = "0.4.4" + +[[deps.MLJ]] +deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJFlow", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "Reexport", "ScientificTypes", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "193f1f1ac77d91eabe1ac81ff48646b378270eef" +uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +version = "0.19.5" + +[[deps.MLJBalancing]] +deps = ["MLJBase", "MLJModelInterface", "OrderedCollections", "Random"] +path = "/Users/essam/.julia/dev/MLJBalancing" +uuid = "45f359ea-796d-4f51-95a5-deb1a414c586" +version = "1.0.0-DEV" + +[[deps.MLJBase]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "0b7307d1a7214ec3c0ba305571e713f9492ea984" +uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +version = "0.21.14" + +[[deps.MLJEnsembles]] +deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", 
"MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] +git-tree-sha1 = "95b306ef8108067d26dfde9ff3457d59911cc0d6" +uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" +version = "0.3.3" + +[[deps.MLJFlow]] +deps = ["MLFlowClient", "MLJBase", "MLJModelInterface"] +git-tree-sha1 = "bceeeb648c9aa2fc6f65f957c688b164d30f2905" +uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" +version = "0.1.1" + +[[deps.MLJIteration]] +deps = ["IterationControl", "MLJBase", "Random", "Serialization"] +git-tree-sha1 = "be6d5c71ab499a59e82d65e00a89ceba8732fcd5" +uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" +version = "0.5.1" + +[[deps.MLJLinearModels]] +deps = ["DocStringExtensions", "IterativeSolvers", "LinearAlgebra", "LinearMaps", "MLJModelInterface", "Optim", "Parameters"] +git-tree-sha1 = "c92bf0ea37bf51e1ef0160069c572825819748b8" +uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692" +version = "0.9.2" + +[[deps.MLJModelInterface]] +deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] +git-tree-sha1 = "03ae109be87f460fe3c96b8a0dbbf9c7bf840bd5" +uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" +version = "1.9.2" + +[[deps.MLJModels]] +deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "2b49f04f70266a2b040eb46ece157c4f5c1b0c13" +uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +version = "0.16.10" + +[[deps.MLJTestInterface]] +deps = ["MLJBase", "Pkg", "Test"] +git-tree-sha1 = "9131806695e6a6d32c61ed5f7bccaadef9fef57e" +uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd" +version = "0.2.2" + +[[deps.MLJTuning]] +deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", 
"RecipesBase"] +git-tree-sha1 = "02688098bd77827b64ed8ad747c14f715f98cfc4" +uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" +version = "0.7.4" + +[[deps.MacroTools]] +deps = ["Markdown", "Random"] +git-tree-sha1 = "9ee1618cbf5240e6d4e0371d6f24065083f60c48" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.11" + +[[deps.MarchingCubes]] +deps = ["PrecompileTools", "StaticArrays"] +git-tree-sha1 = "c8e29e2bacb98c9b6f10445227a8b0402f2f173a" +uuid = "299715c1-40a9-479a-aaf9-4a633d36f717" +version = "0.1.8" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS]] +deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] +git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" +uuid = "739be429-bea8-5141-9913-cc70e7f3736d" +version = "1.1.7" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" + +[[deps.MicroCollections]] +deps = ["BangBang", "InitialValues", "Setfield"] +git-tree-sha1 = "629afd7d10dbc6935ec59b32daeb33bc4460a42e" +uuid = "128add7d-3638-4c79-886c-908ea0c25c34" +version = "0.1.4" + +[[deps.Missings]] +deps = ["DataAPI"] +git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" +uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" +version = "1.1.0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" + +[[deps.NLSolversBase]] +deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] +git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" +uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" +version = "7.8.3" + +[[deps.NaNMath]] +deps = ["OpenLibm_jll"] +git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "1.0.2" + +[[deps.NearestNeighbors]] +deps = ["Distances", "StaticArrays"] +git-tree-sha1 = 
"2c3726ceb3388917602169bed973dbc97f1b51a8" +uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +version = "0.4.13" + +[[deps.NelderMead]] +git-tree-sha1 = "25abc2f9b1c752e69229f37909461befa7c1f85d" +uuid = "2f6b4ddb-b4ff-44c0-b59b-2ab99302f970" +version = "0.4.0" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" + +[[deps.OpenLibm_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "05823500-19ac-5b8b-9628-191a04bc5112" +version = "0.8.1+0" + +[[deps.OpenML]] +deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] +git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" +uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" +version = "0.3.1" + +[[deps.OpenSSL]] +deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] +git-tree-sha1 = "51901a49222b09e3743c65b8847687ae5fc78eb2" +uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" +version = "1.4.1" + +[[deps.OpenSSL_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e78db7bd5c26fc5a6911b50a47ee302219157ea8" +uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" +version = "3.0.10+0" + +[[deps.OpenSpecFun_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.5+0" + +[[deps.Optim]] +deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] +git-tree-sha1 = "963b004d15216f8129f6c0f7d187efa136570be0" +uuid = "429524aa-4258-5aef-a3af-852621145aeb" +version = "1.7.7" + +[[deps.OrderedCollections]] +git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.6.2" 
+ +[[deps.PDMats]] +deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" +uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" +version = "0.11.17" + +[[deps.Parameters]] +deps = ["OrderedCollections", "UnPack"] +git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" +uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" +version = "0.12.3" + +[[deps.Parsers]] +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.7.2" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" + +[[deps.PooledArrays]] +deps = ["DataAPI", "Future"] +git-tree-sha1 = "36d8b4b899628fb92c2749eb488d884a926614d3" +uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" +version = "1.4.3" + +[[deps.PositiveFactorizations]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" +uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" +version = "0.2.4" + +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "03b4c25b43cb84cee5c90aa9b5ea0a78fd848d2f" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.2.0" + +[[deps.Preferences]] +deps = ["TOML"] +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" +uuid = "21216c6a-2e73-6563-6e65-726566657250" +version = "1.4.0" + +[[deps.PrettyPrinting]] +git-tree-sha1 = "22a601b04a154ca38867b991d5017469dc75f2db" +uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" +version = "0.4.1" + +[[deps.PrettyTables]] +deps = ["Crayons", "LaTeXStrings", "Markdown", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "ee094908d720185ddbdc58dbe0c1cbe35453ec7a" +uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +version = "2.2.7" + +[[deps.Printf]] +deps = 
["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.ProgressMeter]] +deps = ["Distributed", "Printf"] +git-tree-sha1 = "00099623ffee15972c16111bcf84c58a0051257c" +uuid = "92933f4c-e287-5a05-a399-4b506db050ca" +version = "1.9.0" + +[[deps.QuadGK]] +deps = ["DataStructures", "LinearAlgebra"] +git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" +uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" +version = "2.8.2" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.RecipesBase]] +deps = ["PrecompileTools"] +git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" +uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +version = "1.3.4" + +[[deps.Reexport]] +git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "1.2.2" + +[[deps.RelocatableFolders]] +deps = ["SHA", "Scratch"] +git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" +uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" +version = "1.0.0" + +[[deps.Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.3.0" + +[[deps.Rmath]] +deps = ["Random", "Rmath_jll"] +git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" +uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" +version = "0.7.1" + +[[deps.Rmath_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" +uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" +version = "0.4.0+0" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.ScientificTypes]] +deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] +git-tree-sha1 = 
"75ccd10ca65b939dab03b812994e571bf1e3e1da" +uuid = "321657f4-b219-11e9-178b-2701a2544e81" +version = "3.0.2" + +[[deps.ScientificTypesBase]] +git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" +uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" +version = "3.0.0" + +[[deps.Scratch]] +deps = ["Dates"] +git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" +uuid = "6c6a2e73-6563-6170-7368-637461726353" +version = "1.2.0" + +[[deps.SentinelArrays]] +deps = ["Dates", "Random"] +git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" +uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" +version = "1.4.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.Setfield]] +deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] +git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" +uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" +version = "1.1.1" + +[[deps.ShowCases]] +git-tree-sha1 = "7f534ad62ab2bd48591bdeac81994ea8c445e4a5" +uuid = "605ecd9f-84a6-4c9e-81e2-4798472b76a3" +version = "0.1.0" + +[[deps.SimpleBufferStream]] +git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" +uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" +version = "1.1.0" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.SortingAlgorithms]] +deps = ["DataStructures"] +git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" +uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" +version = "1.1.1" + +[[deps.SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[deps.SpecialFunctions]] +deps = ["ChainRulesCore", "IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] +git-tree-sha1 = "e2cfc4012a19088254b3950b85c3c1d8882d864d" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "2.3.1" + +[[deps.SplittablesBase]] +deps = ["Setfield", "Test"] +git-tree-sha1 = "e08a62abc517eb79667d0a29dc08a3b589516bb5" +uuid = "171d559e-b47b-412a-8079-5efa626c420e" +version 
= "0.1.15" + +[[deps.StableRNGs]] +deps = ["Random", "Test"] +git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" +uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" +version = "1.0.0" + +[[deps.StaticArrays]] +deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] +git-tree-sha1 = "51621cca8651d9e334a659443a74ce50a3b6dfab" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "1.6.3" + +[[deps.StaticArraysCore]] +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" +uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" +version = "1.4.2" + +[[deps.StatisticalTraits]] +deps = ["ScientificTypesBase"] +git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782" +uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" +version = "3.2.0" + +[[deps.Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[deps.StatsAPI]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1ff449ad350c9c4cbc756624d6f8a8c3ef56d3ed" +uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" +version = "1.7.0" + +[[deps.StatsBase]] +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] +git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" +uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +version = "0.34.0" + +[[deps.StatsFuns]] +deps = ["ChainRulesCore", "HypergeometricFunctions", "InverseFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] +git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" +uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +version = "1.3.0" + +[[deps.StringManipulation]] +deps = ["PrecompileTools"] +git-tree-sha1 = "a04cabe79c5f01f4d723cc6704070ada0b9d46d5" +uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" +version = "0.3.4" + +[[deps.SuiteSparse]] +deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] +uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" + +[[deps.TOML]] 
+deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" + +[[deps.TableOperations]] +deps = ["SentinelArrays", "Tables", "Test"] +git-tree-sha1 = "e383c87cf2a1dc41fa30c093b2a19877c83e1bc1" +uuid = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" +version = "1.2.0" + +[[deps.TableTraits]] +deps = ["IteratorInterfaceExtensions"] +git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" +uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" +version = "1.0.1" + +[[deps.TableTransforms]] +deps = ["AbstractTrees", "CategoricalArrays", "Distributions", "LinearAlgebra", "NelderMead", "PrettyTables", "Random", "ScientificTypes", "Statistics", "StatsBase", "Tables", "Transducers", "TransformsBase"] +git-tree-sha1 = "d2fc117cc24ad1e459c9ff9d839e201431ec608a" +uuid = "0d432bfd-3ee1-4ac1-886a-39f05cc69a3e" +version = "1.10.0" + +[[deps.Tables]] +deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits"] +git-tree-sha1 = "a1f34829d5ac0ef499f6d84428bd6b4c71f02ead" +uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" +version = "1.11.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.1" + +[[deps.TensorCore]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" +uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" +version = "0.1.1" + +[[deps.Test]] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[deps.TranscodingStreams]] +deps = ["Random", "Test"] +git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" +uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" +version = "0.9.13" + +[[deps.Transducers]] +deps = ["Adapt", "ArgCheck", "BangBang", "Baselet", "CompositionsBase", "ConstructionBase", "DefineSingletons", "Distributed", "InitialValues", "Logging", "Markdown", "MicroCollections", "Requires", "Setfield", "SplittablesBase", "Tables"] +git-tree-sha1 = 
"53bd5978b182fa7c57577bdb452c35e5b4fb73a5" +uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999" +version = "0.4.78" + +[[deps.TransformsBase]] +deps = ["AbstractTrees"] +git-tree-sha1 = "53e92e907bd67eef12e319ca932a7dd036428bfc" +uuid = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8" +version = "1.2.1" + +[[deps.URIs]] +git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.5.0" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.UnPack]] +git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" +uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" +version = "1.0.2" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.UnicodePlots]] +deps = ["ColorSchemes", "ColorTypes", "Contour", "Crayons", "Dates", "LinearAlgebra", "MarchingCubes", "NaNMath", "PrecompileTools", "Printf", "Requires", "SparseArrays", "StaticArrays", "StatsBase"] +git-tree-sha1 = "b96de03092fe4b18ac7e4786bee55578d4b75ae8" +uuid = "b8865327-cd53-5732-bb35-84acbb429228" +version = "3.6.0" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+3" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.1+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/examples/Project.toml b/examples/Project.toml new file mode 100644 index 0000000..d236ff2 --- /dev/null +++ b/examples/Project.toml @@ -0,0 +1,7 @@ +[deps] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +Imbalance = "c709b415-507b-45b7-9a3d-1767c89fde68" +MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" +MLJBalancing = "45f359ea-796d-4f51-95a5-deb1a414c586" +MLJBase = 
"a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" diff --git a/src/MLJBalancing.jl b/src/MLJBalancing.jl index c461030..d891030 100644 --- a/src/MLJBalancing.jl +++ b/src/MLJBalancing.jl @@ -5,6 +5,7 @@ using MLJModelInterface using OrderedCollections MMI = MLJModelInterface - +include("balanced_model.jl") +export BalancedModel end \ No newline at end of file diff --git a/src/balanced_model.jl b/src/balanced_model.jl new file mode 100644 index 0000000..88cd78e --- /dev/null +++ b/src/balanced_model.jl @@ -0,0 +1,201 @@ +#= +This is how the struct and the constructor for the model balancer +would look if it were to support only the probabilistic model type: + +struct BalancedModel <:ProbabilisticNetworkComposite + balancer # oversampler or undersampler + model::Probabilistic # get rid of abstract types +end + +BalancedModel(;model=nothing, balancer=nothing) = BalancedModel(model, balancer) +BalancedModel(model; kwargs...) = BalancedModel(; model, kwargs...) + +In the following, we use macros to automate code generation of these for all model +types +=# + +### 1. 
Define model structs
+
+# Supported Model Types
+const SUPPORTED_MODEL_TYPES = (:Probabilistic, :Deterministic, :Interval)
+
+# A dictionary to convert e.g., from Probabilistic to BalancedModelProbabilistic
+const MODELTYPE_TO_COMPOSITETYPE = Dict(atom => Symbol("BalancedModel$atom") for atom in SUPPORTED_MODEL_TYPES)
+# A dictionary to convert e.g., from Probabilistic to ProbabilisticNetworkComposite
+const MODELTYPE_TO_SUPERTYPE = Dict(atom => Symbol("$(atom)NetworkComposite") for atom in SUPPORTED_MODEL_TYPES)
+
+# Define a struct for each model type (corresponds to a composite type and supertype used in struct)
+# The names of the balancers are stored as the first type parameter, which lets the
+# property overloads recover them from the type alone.
+for model_type in SUPPORTED_MODEL_TYPES
+    struct_name = MODELTYPE_TO_COMPOSITETYPE[model_type]
+    super_type = MODELTYPE_TO_SUPERTYPE[model_type]
+    ex = quote
+        mutable struct $struct_name{balancernames, M <: $model_type} <: $super_type
+            balancers
+            model::M
+            function $struct_name(balancernames, balancers, model::M) where M <: $model_type
+                # generate an instance and use balancernames as type parameter
+                return new{balancernames, M}(balancers, model)
+            end
+        end
+    end
+    eval(ex)
+end
+
+### 2. Define one keyword constructor for model structs
+
+# A version of MODELTYPE_TO_COMPOSITETYPE with evaluated keys and values (used in keyword constructor)
+# i.e., it maps the abstract model *type* to the generated composite *type*, not symbol to symbol.
+const MODELTYPE_TO_COMPOSITETYPE_EVAL = Dict()
+for MODELTYPE in SUPPORTED_MODEL_TYPES
+    type = MODELTYPE_TO_COMPOSITETYPE[MODELTYPE]
+    # interpolating the symbols into the expression makes them evaluate to the types they name
+    @eval(MODELTYPE_TO_COMPOSITETYPE_EVAL[$MODELTYPE] = $type)
+end
+
+# To represent any model type (to check input model type is one of them in keyword constructor)
+const UNION_MODEL_TYPES = Union{keys(MODELTYPE_TO_COMPOSITETYPE_EVAL)...}
+
+
+# Possible Errors (for the constructor as well)
+const ERR_MODEL_UNSPECIFIED = ArgumentError("Expected an atomic model as argument. None specified. ")
+
+const WRN_BALANCER_UNSPECIFIED = "No balancer was provided. Data will be directly passed to the model. "
+
+const PRETTY_SUPPORTED_MODEL_TYPES = join([string("`", opt, "`") for opt in SUPPORTED_MODEL_TYPES], ", ",", and ")
+
+const ERR_UNSUPPORTED_MODEL(model) = ArgumentError(
+    "Only these model supertypes support wrapping: "*
+    "$PRETTY_SUPPORTED_MODEL_TYPES.\n"*
+    "Model provided has type `$(typeof(model))`. "
+)
+
+
+"""
+    BalancedModel(; model=nothing, balancer1=..., balancer2=..., ...)
+
+Wrap a classification model with named balancers that resample the data before it is
+passed to the model.
+
+# Keywords
+- `model=nothing`: The classification model, which must be provided. It must be a
+    `Probabilistic`, `Deterministic` or `Interval` model.
+- `balancer1=..., balancer2=..., ...`: Any number of balancers (i.e., resampling models),
+    each given under an arbitrary keyword name. The names are kept, in the order given,
+    as the first type parameter of the returned model.
+
+# Throws
+- `ArgumentError` if `model` is not specified or is not one of the supported model types.
+
+A warning is logged if no balancer is given.
+"""
+function BalancedModel(; model=nothing, named_balancers...)
+    # check model and balancer are given
+    model === nothing && throw(ERR_MODEL_UNSPECIFIED)
+    # check model is supported
+    model isa UNION_MODEL_TYPES || throw(ERR_UNSUPPORTED_MODEL(model))
+
+    # split the keyword arguments into their names and the balancers themselves
+    nt = NamedTuple(named_balancers)
+    balancernames = keys(nt)
+    balancers = collect(nt)
+    # warn if balancer is not given
+    isempty(balancers) && @warn WRN_BALANCER_UNSPECIFIED
+    # call the appropriate constructor
+    return MODELTYPE_TO_COMPOSITETYPE_EVAL[MMI.abstract_type(model)](balancernames, balancers, model)
+end
+
+
+### 3. Make a property for each balancer given via keyword arguments
+
+# set the property names to include the keyword arguments
+# (done for every composite type, not only the probabilistic one)
+for model_type in SUPPORTED_MODEL_TYPES
+    struct_name = MODELTYPE_TO_COMPOSITETYPE[model_type]
+    ex = quote
+        Base.propertynames(::$struct_name{balancernames}) where balancernames =
+            tuple(:model, balancernames...)
+    end
+    eval(ex)
+end
+
+# overload getproperty to return the balancer from the vector in the struct
+for model_type in SUPPORTED_MODEL_TYPES
+    struct_name = MODELTYPE_TO_COMPOSITETYPE[model_type]
+    ex = quote
+        function Base.getproperty(b::$struct_name{balancernames}, name::Symbol) where balancernames
+            balancers = getfield(b, :balancers)
+            # a balancer name maps to the entry at the same position in the vector
+            for j in eachindex(balancernames)
+                name === balancernames[j] && return balancers[j]
+            end
+            # anything else is looked up as an ordinary field of the struct
+            return getfield(b, name)
+        end
+    end
+    eval(ex)
+end
+
+# Kept free of interpolation on purpose: a `$name` here would be evaluated once, when the
+# constant is defined, and not with the property that was actually requested. The offending
+# property is appended where the error is raised instead.
+const ERR_NO_PROP = "trying to access a property which does not exist"
+# overload set property to set the property from the vector in the struct
+for model_type in SUPPORTED_MODEL_TYPES
+    struct_name = MODELTYPE_TO_COMPOSITETYPE[model_type]
+    ex = quote
+        function Base.setproperty!(b::$struct_name{balancernames}, name::Symbol, val) where balancernames
+            # find the balancer model with given balancer names
+            idx = findfirst(==(name), balancernames)
+            # get it from the vector in the struct and set it with the value
+            if !isnothing(idx)
+                getfield(b, :balancers)[idx] = val
+                return val
+            end
+            # the other only option is model
+            name === :model && return setfield!(b, :model, val)
+            error(string(ERR_NO_PROP, ": ", name))
+        end
+    end
+    eval(ex)
+end
+
+
+
+### 4. Define the prefit method
+# used below, represents any composite model type offered by our package (e.g., BalancedModelProbabilistic)
+const UNION_COMPOSITE_TYPES{balancernames} = Union{[type{balancernames} for type in values(MODELTYPE_TO_COMPOSITETYPE_EVAL)]...}
+
+"""
+    MLJBase.prefit(balanced_model, verbosity, _X, _y)
+
+Overload the prefit method to export a learning network composed of a sequential pipeline
+of balancers followed by a final model.
+
+The training data `_X`, `_y` is passed through each balancer in the order of
+`balancernames`, and the result is used to train the model. Prediction bypasses the
+balancers and consumes data directly from the source. `verbosity` is not used.
+
+Returns the learning network interface as the named tuple `(; predict=yhat)`.
+"""
+function MLJBase.prefit(balanced_model::UNION_COMPOSITE_TYPES{balancernames}, verbosity, _X, _y) where balancernames
+    # the learning network:
+    X = source(_X)
+    y = source(_y)
+    X_over, y_over = X, y
+    # Let's transform the data through :balancer1, :balancer2,...
+    for symbolic_balancer in balancernames
+        balancer = getproperty(balanced_model, symbolic_balancer)
+        mach1 = machine(balancer)
+        data = MLJBase.transform(mach1, X_over, y_over)
+        X_over, y_over = first(data), last(data)
+    end
+    # we use the oversampled data for training:
+    mach2 = machine(:model, X_over, y_over) # wrap with the data to be trained
+    # but consume new production data from the source:
+    yhat = MLJBase.predict(mach2, X)
+    # return the learning network interface:
+    return (; predict=yhat)
+end
+
+
+### 5. Provide package information and pass up model traits
+MMI.package_name(::Type{<:UNION_COMPOSITE_TYPES}) = "MLJBalancing"
+MMI.package_license(::Type{<:UNION_COMPOSITE_TYPES}) = "MIT"
+MMI.package_uuid(::Type{<:UNION_COMPOSITE_TYPES}) = "45f359ea-796d-4f51-95a5-deb1a414c586"
+MMI.is_wrapper(::Type{<:UNION_COMPOSITE_TYPES}) = true
+MMI.package_url(::Type{<:UNION_COMPOSITE_TYPES}) = "https://github.com/JuliaAI/MLJBalancing.jl"
+
+# All the composite types BalancedModelProbabilistic, BalancedModelDeterministic, etc.
+const COMPOSITE_TYPES = values(MODELTYPE_TO_COMPOSITETYPE)
+# Forward the traits of the wrapped model type `M` to each composite type.
+for composite_type in COMPOSITE_TYPES
+    # the iteration parameter lives inside the wrapped model, hence the `:model` prefix
+    quote
+        MMI.iteration_parameter(::Type{<:$composite_type{balancernames, M}}) where {balancernames, M} =
+            MLJBase.prepend(:model, MMI.iteration_parameter(M))
+    end |> eval
+    # every other trait is passed up unchanged
+    for trait in [
+        :input_scitype,
+        :output_scitype,
+        :target_scitype,
+        :fit_data_scitype,
+        :predict_scitype,
+        :transform_scitype,
+        :inverse_transform_scitype,
+        :is_pure_julia,
+        :supports_weights,
+        :supports_class_weights,
+        :supports_online,
+        :supports_training_losses,
+        :is_supervised,
+        :prediction_type]
+        quote
+            MMI.$trait(::Type{<:$composite_type{balancernames, M}}) where {balancernames, M} = MMI.$trait(M)
+        end |> eval
+    end
+end
\ No newline at end of file
diff --git a/test/balanced_model.jl b/test/balanced_model.jl
new file mode 100644
index 0000000..0552a5f
--- /dev/null
+++ b/test/balanced_model.jl
@@ -0,0 +1,86 @@
+@testset "BalancedModel" begin
+    ### end-to-end test
+    # Create and split data
+    X, y = generate_imbalanced_data(100, 5; probs = [0.2, 0.3, 0.5])
+    X = DataFrame(X)
+    train_inds, test_inds =
+        partition(eachindex(y), 0.8, shuffle = true, stratify = y, rng = Random.Xoshiro(42))
+    X_train, X_test = X[train_inds, :], X[test_inds, :]
+    y_train, y_test = y[train_inds], y[test_inds]
+
+    # Load models and balancers
+    DeterministicConstantClassifier = @load DeterministicConstantClassifier pkg=MLJModels
+    LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels
+
+    # Here are a probabilistic and a deterministic model
+    model_prob = LogisticClassifier()
+    model_det = DeterministicConstantClassifier()
+    # And here are three resamplers from Imbalance.
+    # The package should actually work with any `Static` transformer of the form `(X, y) -> (Xout, yout)`
+    # provided that it implements the MLJ interface. Here, the balancer is the transformer
+    balancer1 = Imbalance.MLJ.RandomOversampler(ratios = 1.0, rng = 42)
+    balancer2 = Imbalance.MLJ.SMOTENC(k = 10, ratios = 1.2, rng = 42)
+    balancer3 = Imbalance.MLJ.ROSE(ratios = 1.3, rng = 42)
+
+    ### 1. Make a pipeline of the three balancers and a probabilistic model
+    ## ordinary way
+    mach = machine(balancer1)
+    Xover, yover = MLJBase.transform(mach, X_train, y_train)
+    mach = machine(balancer2)
+    Xover, yover = MLJBase.transform(mach, Xover, yover)
+    mach = machine(balancer3)
+    Xover, yover = MLJBase.transform(mach, Xover, yover)
+
+    mach = machine(model_prob, Xover, yover)
+    fit!(mach)
+    y_pred = MLJBase.predict(mach, X_test)
+
+    # with MLJ balancing
+    @test_throws MLJBalancing.ERR_MODEL_UNSPECIFIED begin
+        BalancedModel(b1 = balancer1, b2 = balancer2, b3 = balancer3)
+    end
+    @test_throws "ArgumentError: Only these model supertypes support wrapping: `Probabilistic`, `Deterministic`, and `Interval`.\nModel provided has type `Int64`." begin
+        BalancedModel(model = 1, b1 = balancer1, b2 = balancer2, b3 = balancer3)
+    end
+    @test_logs (:warn, MLJBalancing.WRN_BALANCER_UNSPECIFIED) begin
+        BalancedModel(model = model_prob)
+    end
+
+    balanced_model =
+        BalancedModel(model = model_prob, b1 = balancer1, b2 = balancer2, b3 = balancer3)
+    mach = machine(balanced_model, X_train, y_train)
+    fit!(mach)
+    y_pred2 = MLJBase.predict(mach, X_test)
+
+    @test y_pred ≈ y_pred2
+
+    ### 2. Make a pipeline of the three balancers and a deterministic model
+    ## ordinary way
+    mach = machine(balancer1)
+    Xover, yover = MLJBase.transform(mach, X_train, y_train)
+    mach = machine(balancer2)
+    Xover, yover = MLJBase.transform(mach, Xover, yover)
+    mach = machine(balancer3)
+    Xover, yover = MLJBase.transform(mach, Xover, yover)
+
+    mach = machine(model_det, Xover, yover)
+    fit!(mach)
+    y_pred = MLJBase.predict(mach, X_test)
+
+    # with MLJ balancing
+    balanced_model =
+        BalancedModel(model = model_det, b1 = balancer1, b2 = balancer2, b3 = balancer3)
+    mach = machine(balanced_model, X_train, y_train)
+    fit!(mach)
+    y_pred2 = MLJBase.predict(mach, X_test)
+
+    @test y_pred == y_pred2
+
+    ### check that setproperty! and getproperty work
+    @test Base.getproperty(balanced_model, :b1) == balancer1
+    Base.setproperty!(balanced_model, :b1, balancer2)
+    @test Base.getproperty(balanced_model, :b1) == balancer2
+    @test_throws MLJBalancing.ERR_NO_PROP begin
+        Base.setproperty!(balanced_model, :name11, balancer2)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl index 7a9b9f8..1008fc8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,11 +1,10 @@ using Test using MLJBalancing -using MLJ +using MLJBase +using MLJModels using Imbalance -import MLJBase -using MLJModelInterface -MMI = MLJModelInterface using Random using DataFrames +include("balanced_model.jl")