diff --git a/src/mlj_model_interface.jl b/src/mlj_model_interface.jl
index f25c137e..aa9850c4 100644
--- a/src/mlj_model_interface.jl
+++ b/src/mlj_model_interface.jl
@@ -7,6 +7,11 @@ MLJModelInterface.deep_properties(::Type{<:MLJFluxModel}) =
 
 # # CLEAN METHOD
 
+const ERR_BAD_OPTIMISER = ArgumentError(
+    "Flux.jl optimiser detected. Only optimisers from Optimisers.jl are supported. "*
+    "For example, use `optimiser=Optimisers.Momentum()` after `import Optimisers`. "
+)
+
 function MLJModelInterface.clean!(model::MLJFluxModel)
     warning = ""
     if model.lambda < 0
@@ -40,6 +45,9 @@ function MLJModelInterface.clean!(model::MLJFluxModel)
             "on an RNG during training, such as `Dropout`. Consider using "*
             " `Random.default_rng()` instead. `"
     end
+    # TODO: This could be removed in next breaking release (0.6.0):
+    model.optimiser isa Flux.Optimise.AbstractOptimiser && throw(ERR_BAD_OPTIMISER)
+
     return warning
 end
 
@@ -79,7 +87,7 @@ function regularized_optimiser(model, nbatches)
             Optimisers.WeightDecay(λ_weight),
             model.optimiser,
         )
-    end 
+    end
 end
 
 function MLJModelInterface.fit(model::MLJFluxModel,
diff --git a/test/mlj_model_interface.jl b/test/mlj_model_interface.jl
index 7b929050..522b059e 100644
--- a/test/mlj_model_interface.jl
+++ b/test/mlj_model_interface.jl
@@ -35,6 +35,10 @@ end
         end
         @test model.acceleration == CUDALibs()
     end
+
+    @test_throws MLJFlux.ERR_BAD_OPTIMISER NeuralNetworkClassifier(
+        optimiser=Flux.Optimise.Adam(),
+    )
 end
 
 @testset "regularization: logic" begin