
Type instabilities in predict #959

Open
MilesCranmer opened this issue Feb 18, 2024 · 10 comments
@MilesCranmer
Contributor

Because `last_model` returns an `Any` (the `old_model` field of `Machine` is untyped):

last_model(mach) = isdefined(mach, :old_model) ? mach.old_model : nothing

none of the operations in `OPERATIONS` can have their output types inferred. This is acceptable for `report`, but for `predict` the output type may matter for downstream operations, so this can slow things down.
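The effect can be reproduced in isolation. This is a minimal sketch, with `MachLike` as a hypothetical stand-in for the real `Machine` type, not the MLJBase definition:

```julia
# Minimal stand-in for a Machine with an untyped `old_model` field.
struct MachLike
    old_model  # no type annotation, so field access is inferred as `Any`
end

last_model_sketch(mach) = isdefined(mach, :old_model) ? mach.old_model : nothing

# Inference cannot narrow the return type past `Any`:
Base.promote_op(last_model_sketch, MachLike)  # == Any
```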

@ablaom
Member

ablaom commented Mar 6, 2024

Nice observation. May be tricky to fix quickly. I'm curious how you came to identify the issue. Do you have a concrete use case with performance substantially compromised by this issue?

@MilesCranmer
Contributor Author

MilesCranmer commented Mar 6, 2024

Sure, here is a ~40x slowdown with a really simple function. But keep in mind that this is only the most minimal example; ultimately the type instability will infect any code it touches. The larger the codebase depending on a particular `predict` call, the worse the performance hit.

using SymbolicRegression, MLJBase, Random, Statistics  # Statistics for mean/std

rng = Random.MersenneTwister(0)
X = randn(rng, 30, 3)
y = @. cos(X[:, 1] * 2.1 - 0.9) * X[:, 2] - X[:, 3]

model = SRRegressor(deterministic=true, seed=0, parallelism=:serial, maxsize=10)
mach = machine(model, X, y)
fit!(mach)

# With predict(mach, X)
function f(mach, X)
    y = predict(mach, X)
    return mean(y), std(y)
end

# With low-level calling
function g(eqn, options, X)
    y, _ = eval_tree_array(eqn, X, options)
    return mean(y), std(y)
end

With the following performance comparison (Julia 1.10.2; `@btime` is from BenchmarkTools.jl):

@btime f($mach, $X)
#   5.521 μs (259 allocations: 18.79 KiB)
@btime g(eqn, options, $(copy(X'))) setup=(r=report(mach); eqn = r.equations[r.best_idx]; options=mach.fitresult.options)
#  134.497 ns (3 allocations: 368 bytes)

(I highly recommend Cthulhu.jl for narrowing down type instabilities)
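Independently of any fix in MLJBase, callers can contain the damage with a function barrier: pass the type-unstable value to an inner function, which then compiles specialized code for the runtime type. A minimal sketch, using a `Ref{Any}` to simulate a type-unstable source like `mach.fitresult` (names are illustrative):

```julia
using Statistics

# `Ref{Any}` simulates a type-unstable source.
unstable_source = Ref{Any}(randn(100))

# Without a barrier, everything downstream of `y` is inferred as `Any`.
function stats_no_barrier(src)
    y = src[]
    return mean(y), std(y)
end

# With a barrier, `_stats` specializes on the runtime type of `y`;
# only one dynamic dispatch at the barrier call remains unstable.
_stats(y) = (mean(y), std(y))
stats_with_barrier(src) = _stats(src[])

stats_with_barrier(unstable_source)
```

Both functions return the same values; the difference is that the work inside `_stats` runs on concrete types.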

@MilesCranmer
Contributor Author

Here is the Cthulhu.jl output when I descend into f(mach, X) – the predict(mach, X) call results in ::Any output (in red):
[Screenshot: Cthulhu.jl output for `f(mach, X)`]

Descending into the predict call, we can immediately see that the first appearance of an ::Any is from accessing the .fitresult:

[Screenshot: Cthulhu.jl output for the `predict` call, with `::Any` appearing at the `.fitresult` access]

Then every call after that is infected by the red (type inference issue).

@MilesCranmer
Contributor Author

MilesCranmer commented Mar 6, 2024

Perhaps one way to fix this without deep structural changes would be to let package developers specify the fitresult type via a trait handled by MLJModelInterface. Then, when calling `predict` or the other operations, you can insert a type assertion, allowing the function to specialize.

For example:

MMI.fitresult_type(::Type{<:MyRegressor{A,B,C}}) where {A,B,C} = MyFitResultType{A,B,C}

and simply have the default method be

MMI.fitresult_type(_) = Any

then, in predict, you could do the following:

ret = $(operator)(
    model,
    mach.fitresult::MMI.fitresult_type(model),
    ...
)

This would be completely backwards compatible, because with the default `fitresult_type(_) = Any` the assertion is a no-op.

Using this type assertion would actually be enough for Julia to specialize.
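A self-contained illustration of why the assertion is enough for inference, using hypothetical names rather than the actual MMI API:

```julia
# Hypothetical trait, defaulting to Any (a no-op assertion):
fitresult_type(::Any) = Any

struct MyRegressor end
struct MachSketch
    fitresult  # untyped, as in MLJBase's Machine
end

# Opt-in: the package declares its concrete fitresult type.
fitresult_type(::MyRegressor) = Vector{Float64}

# The assertion lets inference recover a concrete type:
predict_sketch(model, mach) = sum(mach.fitresult::fitresult_type(model))

Base.promote_op(predict_sketch, MyRegressor, MachSketch)  # == Float64
```

Because `fitresult_type(::MyRegressor)` constant-folds to `Vector{Float64}`, the `Any` from the field access is narrowed at the assertion and everything downstream specializes.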

What do you think?


Edit: it looks like there is also no specialization on the array type passed to `machine`, i.e., the `Machine` does not actually store information about the array in its type. Could that be added as well?
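For reference, a sketch of what carrying the data type in the machine's parameters could look like; `TypedMachine` is hypothetical, not the MLJBase design:

```julia
# Hypothetical machine parameterized on both model and data types:
struct TypedMachine{M,D}
    model::M
    data::D
end

# The element type of the training data is now statically known:
data_eltype(::TypedMachine{M,D}) where {M,D} = eltype(D)

data_eltype(TypedMachine("model", randn(3, 2)))  # == Float64
```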

@ablaom
Member

ablaom commented Mar 20, 2024

Thanks @MilesCranmer for the further details and the very nice idea for a fix. Perhaps after I have a chance to look at a direct fix we can go this route.

@ablaom
Member

ablaom commented Mar 22, 2024

Okay, the type instability in `last_model` is an easy fix: I just need to annotate the `old_model` field in the `Machine` struct. But the benchmarks are still bad. I tried Cthulhu.jl as kindly suggested, but can't get it to work just now.
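For illustration, annotating the field is indeed enough in a stripped-down analogue (all types here are hypothetical stand-ins, not the MLJBase definitions):

```julia
abstract type ModelSketch end

# Untyped field: the accessor infers as `Any`.
mutable struct LooseMach
    old_model
end

# Annotated field: inference recovers `Union{ModelSketch, Nothing}`.
mutable struct TightMach
    old_model::Union{ModelSketch, Nothing}
end

last_model_demo(mach) = isdefined(mach, :old_model) ? mach.old_model : nothing

Base.promote_op(last_model_demo, LooseMach)  # == Any
Base.promote_op(last_model_demo, TightMach)  # == Union{ModelSketch, Nothing}
```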

@MilesCranmer
Contributor Author

I had that issue too. I put the following in my startup.jl: JuliaDebug/Cthulhu.jl#546 (comment)

@MilesCranmer
Contributor Author

Actually, it looks like it's fixed on the master version of Cthulhu, just not released yet. So just install master.

@MilesCranmer
Contributor Author

Also, this tutorial from Tim Holy was extremely helpful for learning this type of analysis:
https://youtube.com/watch?v=pvduxLowpPY
It completely changed how I profile Julia code.

@ablaom
Member

ablaom commented Apr 8, 2024

I have made some more progress removing type instability from `predict`. Will cross-reference to a PR shortly.

However, while this reduces overhead in `predict(::Machine, ...)` across the board, it does not change the disparity reported above for SymbolicRegression. I also wonder whether the comparison above is really fair. If I replace the definition of `g` with

function g(model, fitresult, X)
    y = predict(model, fitresult, X)
    return mean(y), std(y)
end

so that I am comparing `predict(::Machine, ...)` directly with `predict(::Model, ...)` (implemented in SymbolicRegression.jl, not MLJBase), then both benchmarks are similar, with a similarly large number of allocations.

Furthermore, there appears to be a type instability in predict(::Model, ...), as currently implemented in SymbolicRegression.jl:

rng = Random.MersenneTwister(0)
X = randn(rng, 30, 3)
y = @. cos(X[:, 1] * 2.1 - 0.9) * X[:, 2] - X[:, 3]

model = SRRegressor(deterministic=true, seed=0, parallelism=:serial, maxsize=10)
mach = machine(model, X, y)
fit!(mach, verbosity=0)

fitresult = mach.fitresult;

Here is the output of `@code_warntype predict(model, fitresult, X)`:

[Screenshot: `@code_warntype` output]

Note the boldface Anys.

Am I missing something here?
