Re-Introduce `ShardedForm` for large expressions #937
Someone just needs to fix it. We cannot default to it if it's not correct. If you're willing to fix its dependency analysis, then we'd be happy to re-enable it.
I am happy to take a look. I have little exposure in this area, so I don't know whether I can be of actual use here. Are there any more pointers or evidence of what exactly the issue is?
The observed equations are not appended to the front of the sharded equations, so it errors with any observables. At least that's what someone mentioned to me 2 weeks ago (@shashi?)
I would like to see a reproducer from you guys for the multithreading deadlock. The problem might have gone away by now. @wsphillips was saying it gave wrong answers. But just from the code, I don't see a chance of deadlock or race conditions, unless you are passing indices into `build_function` with repeated indices, in which case even the serial version would be wrong.
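For anyone wanting a reproducer smaller than the full MWE below, a minimal standalone `build_function` call with the multithreaded form looks roughly like this. This is a sketch assuming the current Symbolics.jl API (`build_function` with the `parallel` keyword and `Symbolics.MultithreadedForm`); the `(2, 2)` shard sizes just mirror the values used elsewhere in this thread.

```julia
using Symbolics

@variables x y
exprs = [x + y, x * y, x - y, x / y]

# build_function returns an out-of-place and an in-place variant.
# parallel = Symbolics.MultithreadedForm(2, 2) splits the generated
# assignments across threads; the serial default does them in order.
f_oop, f_ip! = build_function(exprs, [x, y];
                              parallel = Symbolics.MultithreadedForm(2, 2),
                              expression = Val{false})

out = zeros(4)
f_ip!(out, [2.0, 3.0])  # fills out in place with [5.0, 6.0, -1.0, 2/3]
```

If the threaded form is racy, comparing `out` against the serial result for the same inputs should expose it.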
The fix from @shashi solves the observable scoping issue. But when I tested it with some Conductor.jl models, it returns solutions that are wrong. You can use this MWE I adapted from the MTK docs to reproduce (Edit: using Shashi's opaque closure PR):

```julia
using ModelingToolkit, Plots, OrdinaryDiffEq, LinearAlgebra
using Symbolics: scalarize
@variables t
D = Differential(t)
function Mass(; name, m = 1.0, xy = [0.0, 0.0], u = [0.0, 0.0])
ps = @parameters m = m
sts = @variables pos(t)[1:2]=xy v(t)[1:2]=u
eqs = scalarize(D.(pos) .~ v)
ODESystem(eqs, t, [pos..., v...], ps; name)
end
function Spring(; name, k = 1e4, l = 1.0)
ps = @parameters k=k l=l
@variables x(t), dir(t)[1:2]
ODESystem(Equation[], t, [x, dir...], ps; name)
end
function connect_spring(spring, a, b)
[spring.x ~ norm(scalarize(a .- b))
scalarize(spring.dir .~ scalarize(a .- b))]
end
function spring_force(spring)
-spring.k .* scalarize(spring.dir) .* (spring.x - spring.l) ./ spring.x
end
m = 1.0
xy = [1.0, -1.0]
k = 1e4
l = 1.0
center = [0.0, 0.0]
g = [0.0, -9.81]
@named mass = Mass(m = m, xy = xy)
@named spring = Spring(k = k, l = l)
eqs = [connect_spring(spring, mass.pos, center)
scalarize(D.(mass.v) .~ spring_force(spring) / mass.m .+ g)]
@named _model = ODESystem(eqs, t, [spring.x; spring.dir; mass.pos], [])
@named model = compose(_model, mass, spring)
sys = structural_simplify(model)
# if parallel = `Symbolics.ShardedForm(2,2)` or default serial form this works fine
prob = ODEProblem(sys, [], (0.0, 3.0); parallel = Symbolics.MultithreadedForm(2,2))
sol = solve(prob, Rosenbrock23())
plot(sol) # no oscillations / wrong output when multithreaded
```
The removal of `ShardedForm` for large arrays in 616ef52 has caused a major performance regression in one of my projects. Compilation time is now about 10x longer. Beyond that, previously, calling the resulting function from 23 distributed workers in parallel worked just fine. Now it quickly runs out of memory on a machine with 64 GB of RAM; I guess because all workers have to compile the function upon first call, and compilation is more memory-intensive for the `SerialForm`.

Is there a way to default again to `ShardedForm` for large functions, or is this a fundamental limitation of RuntimeGeneratedFunctions.jl?
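To put numbers on the compile-time regression, a rough timing harness like the following can compare first-call cost across forms. This is a sketch assuming the current Symbolics.jl API (`Symbolics.SerialForm` and `Symbolics.MultithreadedForm` as `parallel` options); `ShardedForm` itself is gone after 616ef52, so it cannot be included in the comparison on a current release. The system size and expressions here are arbitrary, chosen only to make codegen non-trivial.

```julia
using Symbolics

@variables x[1:50]
xs = collect(x)
# A moderately large dense system: 50 equations, each touching all 50 states.
exprs = [sum(xs[i] * xs[j] for j in 1:50) for i in 1:50]

for form in (Symbolics.SerialForm(), Symbolics.MultithreadedForm(2, 2))
    t = @elapsed begin
        f, f! = build_function(exprs, xs; parallel = form,
                               expression = Val{false})
        f(ones(50))  # first call triggers compilation
    end
    println(typeof(form), ": ", t, " s")
end
```

With all-ones input, each equation evaluates to 50.0, which gives a quick correctness check alongside the timing.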