Skip to content

Commit

Permalink
explicitly perform multiplications
Browse files Browse the repository at this point in the history
  • Loading branch information
Shreyas-Ekanathan committed Sep 8, 2024
1 parent 4252639 commit 3318e37
Showing 1 changed file with 47 additions and 11 deletions.
58 changes: 47 additions & 11 deletions lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1369,19 +1369,19 @@ end

J = calc_J(integrator, cache)
if u isa Number
LU1 = -γdt * mass_matrix + J
tmp = -(αdt[1] + βdt[1] * im) * mass_matrix + J
else
LU1 = lu(-γdt * mass_matrix + J)
tmp = lu(-(αdt[1] + βdt[1] * im) * mass_matrix + J)
end
LU2 = Vector{typeof(tmp)}(undef, (num_stages - 1) ÷ 2)
LU2[1] = tmp
if u isa Number
LU1 = -γdt * mass_matrix + J
for i in 2 : (num_stages - 1) ÷ 2
LU2[i] = -(αdt[i] + βdt[i] * im) * mass_matrix + J
end
else
LU1 = lu(-γdt * mass_matrix + J)
for i in 2 : (num_stages - 1) ÷ 2
LU2[i] = lu(-(αdt[i] + βdt[i] * im) * mass_matrix + J)
end
Expand Down Expand Up @@ -1412,7 +1412,13 @@ end
end
z[i] = @.. z[i] * c_prime[i]
end
w = TI*z
#w = TI*z
for i in 1:num_stages
w[i] = zero(u)
for j in 1:num_stages
w[i] += TI[i,j] * z[j]
end
end
end

# Newton iteration
Expand All @@ -1431,7 +1437,15 @@ end
end
integrator.stats.nf += num_stages

fw = TI * ff
#fw = TI * ff
fw = Vector{typeof(u)}(undef, num_stages)
for i in 1:num_stages
fw[i] = zero(u)
for j in 1:num_stages
fw[i] += TI[i,j] * ff[j]
end
end

Mw = Vector{typeof(u)}(undef, num_stages)
if mass_matrix isa UniformScaling # `UniformScaling` doesn't play nicely with broadcast
for i in 1 : num_stages
Expand Down Expand Up @@ -1481,8 +1495,13 @@ end
w = @.. w - dw

# transform `w` to `z`
z = T * w

#z = T * w
for i in 1:num_stages
z[i] = zero(u)
for j in 1:num_stages
z[i] += T[i,j] * w[j]
end
end
# check stopping criterion
iter > 1 &&= θ / (1 - θ))
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
Expand Down Expand Up @@ -1565,7 +1584,6 @@ end
mass_matrix = integrator.f.mass_matrix

# precalculations

γdt, αdt, βdt = γ / dt, α ./ dt, β ./ dt
(new_jac = do_newJ(integrator, alg, cache, repeat_step)) &&
(calc_J!(J, integrator, cache); cache.W_γdt = dt)
Expand Down Expand Up @@ -1600,7 +1618,13 @@ end
end
@.. z[i] *= c_prime[i]
end
mul!(w, TI, z)
#mul!(w, TI, z)
for i in 1:num_stages
w[i] = zero(u)
for j in 1:num_stages
w[i] += TI[i,j] * z[j]
end
end
end

# Newton iteration
Expand All @@ -1620,13 +1644,19 @@ end
end
integrator.stats.nf += num_stages

mul!(fw, TI, ks)
#mul!(fw, TI, ks)
for i in 1:num_stages
fw[i] = zero(u)
for j in 1:num_stages
fw[i] += TI[i,j] * ks[j]
end
end

if mass_matrix === I
Mw = w
elseif mass_matrix isa UniformScaling
for i in 1 : num_stages
mul!(z[i], mass_matrix.λ, w[i])
mul!(z[i], mass_matrix.λ, w[i])
end
Mw = z
else
Expand Down Expand Up @@ -1698,7 +1728,13 @@ end
end

# transform `w` to `z`
mul!(z, T, w)
#mul!(z, T, w)
for i in 1:num_stages
z[i] = zero(u)
for j in 1:num_stages
z[i] += T[i,j] * w[j]
end
end
# check stopping criterion
iter > 1 &&= θ / (1 - θ))
if η * ndw < κ && (iter > 1 || iszero(ndw) || !iszero(integrator.success_iter))
Expand Down

0 comments on commit 3318e37

Please sign in to comment.