Initial review of MLJ interface. #120
@ablaom I have just come back to this, because I'm facing very similar issues with another one of our packages: https://github.com/JuliaTrustworthyAI/JointEnergyModels.jl?tab=readme-ov-file The reason we targeted MLJFlux in both cases was that the underlying atomic models are Flux models and MLJFlux comes with useful functionality like the builders.
This was done here in order to make the Laplace approximation (LA) part of the training call itself. As I see it, there are two options going forward:

1. Interface MLJ directly (not through MLJFlux)

This is manageable (see, e.g., NeuroTreeModels). We lose the added functionality of MLJFlux, like builders and such, but I'm starting to think we'll have to live with that.

2. Give up on the MLJ interface for now

Another option would be to give up on the custom MLJ models and instead just run LA post hoc. In this scenario, users could just rely on MLJFlux as they normally do to train conventional neural networks and then run LA. But this does not seem ideal, since we need the interface for JuliaTrustworthyAI/ConformalPrediction.jl#125.

I'm leaning towards interfacing MLJ directly, but curious to hear your thoughts. Thanks again for your input here @ablaom, and sorry for submitting this to MLJ a little prematurely.
@pat-alt I guess we can interface MLJ directly, and in the model definition we leave a field for the Flux chain provided by the user. It would have saved us a lot of time.
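A minimal sketch of what such a model definition could look like (names, defaults, and fields here are illustrative assumptions, not the package's actual API):

```julia
using MLJModelInterface
const MMI = MLJModelInterface
import Flux

# Hypothetical model definition: interfacing MLJ directly, with a field
# holding the user-provided Flux chain.
MMI.@mlj_model mutable struct LaplaceClassifier <: MMI.Probabilistic
    chain::Flux.Chain = Flux.Chain(Flux.Dense(4 => 3))  # user-supplied network
    epochs::Int = 100::(_ > 0)                          # number of training epochs
end
```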
Thanks for the review and feedback. It's really helpful to have some direct input to improve this.
My sense is that the degree of code complexity you need to add to make MLJFlux.jl work is not worth the extra functionality you get, and so interfacing MLJ directly may be better. Of course, you should feel free to mirror whatever is useful to you from MLJFlux.jl. Let me know what you decide.

That said, I would support a redesign of MLJFlux's "internal API" that accommodates a wider range of models. I just don't have the resources to do this on my own. It sounds like you and your team would be in a good place to suggest such a design, if you likewise have the resources at some point.
To clarify an earlier point regarding the form of predictions: it is not absolutely necessary that predictions be distributions. I noticed that prediction is VERY slow at present, so I am guessing you are doing some sort of sampling to get the parametric forms (sorry, I didn't research LA). An option in MLJ is to just directly return a vector of "sampleable" objects, i.e. objects implementing just `rand`.

I have also been making the usual assumption that prediction components are not correlated.
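For illustration, a minimal sketch of such a "sampleable" object (the type name and its contents are hypothetical); MLJ would then only require that `rand` works on it:

```julia
using Random

# A bare-bones "sampleable" prediction: stores pre-computed posterior draws
# for one observation and supports nothing but random sampling.
struct PosteriorDraws{T}
    draws::Vector{T}
end

Random.rand(rng::AbstractRNG, s::PosteriorDraws) =
    s.draws[rand(rng, eachindex(s.draws))]

# Usage:
s = PosteriorDraws(randn(1000))
rand(s)  # draws one sample
```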
@pat-alt please correct me if I am wrong, but no, there is no sampling involved and the predictions are not correlated. The Hessian is found during the training phase, so I guess it's a defect of our MLJFlux implementation.
FYI, this timing is fairly reproducible on my machine:

```julia
X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
model = LaplaceClassification();
fitresult, _, _ = MLJBase.fit(model, 0, X, y);
MLJBase.predict(model, fitresult, X);  # warm up
@time MLJBase.predict(model, fitresult, X)
# 30.179502 seconds (58.51 k allocations: 8.287 GiB, 2.13% gc time)
```

For comparison:

```julia
julia> @time MLJBase.predict(model, fitresult, X)
  0.000275 seconds (1.33 k allocations: 83.375 KiB)
```
@ablaom I went back to the theory behind the Laplace approximation. For each new element x, it is necessary to recompute the Jacobians, so maybe this is why inference takes longer than for a standard neural network.
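For reference, this is the standard linearized (GLM) predictive used by Laplace approximations (textbook material, not specific to this package); the Jacobian must indeed be recomputed for every new input:

```latex
% Linearized (GLM) predictive of the Laplace approximation, with MAP
% estimate \hat\theta and posterior covariance \Sigma:
p\bigl(f(x_\star) \mid \mathcal{D}\bigr) \approx
  \mathcal{N}\Bigl(f_{\hat\theta}(x_\star),\;
  J(x_\star)\,\Sigma\,J(x_\star)^{\top}\Bigr),
\qquad
J(x_\star) = \nabla_\theta f_\theta(x_\star)\big|_{\theta=\hat\theta}
```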
Thanks very much for flagging this @ablaom. The Jacobian computation is indeed a bottleneck in forward passes, but it's not what is causing this issue. Adapting @ablaom's example from above:

```julia
using LaplaceRedux
using MLJBase

X = MLJBase.table(rand(Float32, 100, 3));
y = coerce(rand("abc", 100), Multiclass);
model = LaplaceClassification();
fitresult, _, _ = MLJBase.fit(model, 0, X, y);
la = fitresult[1];
Xmat = matrix(X) |> permutedims;

# Ten test samples:
Xtest = Xmat[:, 1:10];
Xtest_tab = MLJBase.table(Xtest');

MLJBase.predict(model, fitresult, Xtest_tab);  # warm up
LaplaceRedux.predict(la, Xmat);                # warm up
```

Generating predictions using our MLJ interface vs. our default predict method leads to wildly different computation times:

```julia
julia> @time MLJBase.predict(model, fitresult, Xtest_tab);
  1.806758 seconds (5.86 k allocations: 848.624 MiB, 0.78% gc time)

julia> @time LaplaceRedux.predict(la, Xtest);
  0.189871 seconds (4.71 k allocations: 86.994 MiB)

julia> @time glm_predictive_distribution(la, Xtest);
  0.189886 seconds (4.71 k allocations: 86.994 MiB)
```

Curiously, it takes pretty much exactly 10x as long using the MLJ interface, and here I've chosen 10 test samples. Trying it with 50 test samples seems to confirm that the overhead scales with the number of samples:

```julia
julia> Xtest = Xmat[:, 1:50];

julia> Xtest_tab = MLJBase.table(Xtest');

julia> MLJBase.predict(model, fitresult, Xtest_tab);  # warm up

julia> LaplaceRedux.predict(la, Xmat);  # warm up

julia> @time MLJBase.predict(model, fitresult, Xtest_tab);
  8.926642 seconds (29.26 k allocations: 4.144 GiB, 0.76% gc time)

julia> @time LaplaceRedux.predict(la, Xtest);
  0.260523 seconds (22.60 k allocations: 101.032 MiB)

julia> @time glm_predictive_distribution(la, Xtest);
  0.252946 seconds (22.60 k allocations: 101.032 MiB)
```

The problem seems to be the per-sample loop in our MLJ predict method.
@pat-alt eh, I was modifying exactly this part yesterday. I was trying a change there, because I remembered that `LaplaceRedux.predict` accepted vectors as input.
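If that's the diagnosis, a sketch of the kind of fix implied (the helper name is hypothetical; the conversions reuse names from the example above): convert the table to a matrix once and make a single batched call, instead of one call per row:

```julia
# Hypothetical sketch of batched prediction, based on the example above.
function predict_batched(fitresult, Xnew)
    la = fitresult[1]                            # fitted Laplace object
    Xmat = MLJBase.matrix(Xnew) |> permutedims   # features × samples
    return LaplaceRedux.predict(la, Xmat)        # single batched call
end
```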
I'm posting this in response to the request at JuliaAI/MLJModels.jl#571.
I can see some work has gone into understanding MLJ's API requirements (and the internals of MLJFlux).
I have not made an exhaustive review of the interface, but I list below some issues identified so far. Read point 4 first, as it is the most serious.
1. Form of predictions

Whenever possible, probabilistic predictions must take the form of a vector of distributions, where a "distribution" is something implementing `Distributions.pdf` and `Random.rand` (docs). So, instead of returning raw probabilities, the classifier should return a vector with element type `UnivariateFinite` (owned by CategoricalDistributions.jl), which is the form `MLJFlux.NeuralNetworkClassifier` predictions take. Perhaps you can mirror the code for that model, here.

Similarly, the regressor should return a vector of whatever `Distributions` distribution you are returning, e.g. a vector of `Distributions.Normal`, and not simply return the parameters.
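For instance, here is a minimal standalone sketch (with made-up class labels and probabilities) of building such a prediction vector with CategoricalDistributions.jl:

```julia
using CategoricalArrays
using CategoricalDistributions
import Distributions: pdf

# Class pool of the target, as a categorical vector:
classes = categorical(["a", "b", "c"])

# One row of class probabilities per sample (rows sum to one):
probs = [0.2 0.5 0.3;
         0.1 0.8 0.1]

# A vector of two UnivariateFinite distributions, one per sample:
yhat = UnivariateFinite(classes, probs)

pdf.(yhat, "b")  # probability assigned to class "b" by each prediction
```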
2. Table handling

I suspect there is something not generic about table handling. If I train the classifier using data `X, y = MLJBase.@load_iris` I get an error, although training with `X, y = make_moons()` works fine. Getting the number of rows of a generic table (if that's the issue) has always been a bit of a problem, because the Tables.jl API was designed to include tables without length. I think the idea is that you should use `DataAPI.nrow` (`row` singular) for this, but I think `MLJModelInterface.nrows` or `MLJBase.nrows` (`rows` plural) are probably okay.
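For example, a minimal sketch of generic row counting (a NamedTuple of vectors is a valid Tables.jl table):

```julia
import MLJBase

X = (x1 = rand(5), x2 = rand(5))  # generic Tables.jl-compatible table

MLJBase.nrows(X)  # 5; works for generic tables as well as arrays
```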
3. Metadata/traits

The `load_path`s are wrong (see correction below). Your input and target types need some tweaking. For example, I'm getting warnings with the above data sets about the type of data when I do `machine(model, X, y)`. One problem is you have `Finite` in some places where you probably want `<:Finite`, because `Finite` is a `UnionAll` type (parameterised, `Finite{N}`). See my suggestion below.

a. Do you really support inputs `X` with categorical features? (If you do, you may be interested in the pending MLJFlux PR which adds entity embedding for categorical features for the non-image models. This might be more useful than static one-hot encoding, if that is what you do to handle categoricals.)

b. Do you really support classification for non-categorical targets `y` (you currently allow `y` to be `Continuous`)?

c. Do you really intend to support regression with categorical targets `y`? What would that mean?

d. Do you really intend to exclude mixed data types in input `X` (some categorical, some continuous)?

e. Do you handle `OrderedFactor` and `Multiclass` differently (as you probably should)? If not, perhaps you mean to restrict to `Multiclass` and have the user coerce `OrderedFactor` to `Continuous` (assuming you do not already do this under the hood).

Assuming the answers to a-d are: yes, no, no, no, here's my stab at a revised metadata declaration:
4. Use of private API (more serious)

The overloaded methods `MLJFlux.shape`, `MLJFlux.build(::FluxModel, ...)`, `MLJFlux.fitresult`, and `MLJFlux.train` are not public API. They are simply abstractions that arose to try to remove some code duplication between the different models provided by MLJFlux. I am consequently reluctant to make them public. Indeed, the entity embedding PR referred to above breaks `MLJFlux.fitresult`, and future patch releases may break the API further. There may be a good argument for making this API public, but I feel this requires a substantial rethink. Indeed, your own attempts to "hack" aspects of this API reveal the inadequacies: the fact that you feel the need to overload `MLJFlux.train` at all; the fact that the `chain` gets modified in `train` and not at some earlier stage; and so on.

Unfortunately, I personally don't have the bandwidth for this kind of refactoring of MLJFlux any time soon. Your best option may simply be to cut and paste the MLJFlux code you need and have LaplaceRedux own independent versions of the private MLJFlux API methods referenced above. Alternatively, you could leave things as they are and live with breakages as they occur. I am not sure how keen I am on registering such a model, however. Perhaps we wait and see how stable the internal API winds up being.
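To illustrate the cut-and-paste option, a minimal sketch (the loop body is illustrative only, not MLJFlux's actual code): the function lives in LaplaceRedux's own namespace, so it overloads nothing in MLJFlux and MLJFlux patch releases cannot break it.

```julia
using Flux

# LaplaceRedux-owned training helper, adapted freely from MLJFlux:
function train!(chain, opt_state, epochs, loss, X, y)
    for _ in 1:epochs
        grads = Flux.gradient(m -> loss(m(X), y), chain)
        opt_state, chain = Flux.update!(opt_state, chain, grads[1])
    end
    return chain
end

# Usage sketch:
# chain = Flux.Chain(Flux.Dense(3 => 16, relu), Flux.Dense(16 => 3), softmax)
# opt_state = Flux.setup(Flux.Adam(), chain)
# train!(chain, opt_state, 100, Flux.crossentropy, X, y)
```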
5. Minor nomenclature point

For consistency with other MLJ models, I suggest `LaplaceRegressor` over `LaplaceRegression` and `LaplaceClassifier` over `LaplaceClassification`. Of course, I understand you may have other reasons for the name choices.