Skip to content

Commit

Permalink
Added array optimisations
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiManyer committed Apr 3, 2024
1 parent f080df6 commit 1f6ab00
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 69 deletions.
22 changes: 15 additions & 7 deletions src/ChainRules/AlgebraicIntegrandOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Algebraic view of an `IntegrandOperator` that allows working with plain
arrays instead of `FEFunction`s and `DomainContribution`s.
"""
struct AlgebraicIntegrandOperator{A,B,C,D}
struct AlgebraicIntegrandOperator{A,B,C,D} <: IntegrandOperator
F :: A
spaces :: B
assems :: C
Expand Down Expand Up @@ -34,21 +34,29 @@ function AlgebraicIntegrandOperator(
return AlgebraicIntegrandOperator(op,spaces;assems=assems)
end

function gradient_cache(AF::AlgebraicIntegrandOperator,uh,k)
function evaluate_cache(AF::AlgebraicIntegrandOperator,uh)
return evaluate_cache(AF.F,uh)
end

function Arrays.evaluate!(cache,AF::AlgebraicIntegrandOperator,uh;updated=false)
return evaluate!(cache,AF.F,uh;updated)
end

function gradient_cache(AF::AlgebraicIntegrandOperator,uh,K)
@check 0 < K <= length(AF.spaces)
Vk, ak = AF.spaces[k], AF.assems[k]
Vk, ak = AF.spaces[K], AF.assems[K]

dFduk_cache = gradient_cache(AF.F,uh,k)
dFduk = gradient!(dFduk_cache,AF.F,uh,k)
dFduk_cache = gradient_cache(AF.F,uh,K)
dFduk = gradient!(dFduk_cache,AF.F,uh,K)
xk = allocate_vector(ak,collect_cell_vector(Vk,dFduk))
return xk, dFduk_cache
end

function gradient!(cache,AF::AlgebraicIntegrandOperator,uh,K)
function gradient!(cache,AF::AlgebraicIntegrandOperator,uh,K;updated=false)
@check 0 < K <= length(AF.spaces)
xk, dFduk_cache = cache

dFduk = gradient!(dFduk_cache,AF.F,uh,K)
dFduk = gradient!(dFduk_cache,AF.F,uh,K;updated)
assemble_vector!(xk,ak,collect_cell_vector(Vk,dFduk))
return xk
end
46 changes: 46 additions & 0 deletions src/ChainRules/IntegrandOperatorArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@


function evaluate_cache(a::AbstractVector{<:IntegrandOperator},uh)
return evaluate_cache(first(a),uh)
end

function Arrays.evaluate!(cache,a::AbstractVector{<:IntegrandOperator},uh)
map(enumerate(a)) do (i,ai)
evaluate!(cache,ai,uh;updated=(i!=1))
end
end

function Arrays.evaluate(a::AbstractVector{<:IntegrandOperator},uh)
cache = evaluate_cache(a,uh)
return evaluate!(cache,a,uh)
end

function gradient_cache(a::AbstractVector{<:IntegrandOperator},uh,K)
return gradient_cache(first(a),uh,K)
end

function gradient!(cache,a::AbstractVector{<:IntegrandOperator},uh,K)
map(enumerate(a)) do (i,ai)
gradient!(cache,ai,uh,K;updated=(i!=1))
end
end

function Gridap.gradient(a::AbstractVector{<:IntegrandOperator},uh,K)
cache = gradient_cache(a,uh,K)
return gradient!(cache,a,uh,K)
end

function evaluate_and_gradient_cache(a::AbstractVector{<:IntegrandOperator},uh,K)
return evaluate_and_gradient_cache(first(a),uh,K)
end

function evaluate_and_gradient!(cache,a::AbstractVector{<:IntegrandOperator},uh,K)
map(enumerate(a)) do (i,ai)
evaluate_and_gradient!(cache,ai,uh,K;updated=(i!=1))
end
end

function evaluate_and_gradient(a::AbstractVector{<:IntegrandOperator},uh,K)
cache = evaluate_and_gradient_cache(a,uh,K)
return evaluate_and_gradient!(cache,a,uh,K)
end
83 changes: 69 additions & 14 deletions src/ChainRules/IntegrandOperators.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,49 @@
abstract type IntegrandOperator end

function gradient_cache(F::IntegrandOperator,uh,K)
(F::IntegrandOperator)(args...) = evaluate(F,args...)

function evaluate_cache(F::IntegrandOperator,uh)
return nothing
end

function gradient!(cache,F::IntegrandOperator,uh,K)
function Arrays.evaluate!(cache,F::IntegrandOperator,uh;updated=false)
@abstractmethod
end

Gridap.gradient(F::IntegrandOperator,uh) = Gridap.gradient(F,[uh],1)
function Arrays.evaluate(F::IntegrandOperator,uh)
cache = evaluate_cache(F,uh)
return evaluate!(cache,F,uh)
end

function gradient_cache(F::IntegrandOperator,uh,K)
return nothing
end

function gradient!(cache,F::IntegrandOperator,uh,K;updated=false)
@abstractmethod
end

function Gridap.gradient(F::IntegrandOperator,uh,K)
cache = gradient_cache(F,uh,K)
return gradient!(cache,F,uh,K)
end

function evaluate_and_gradient_cache(F::IntegrandOperator,uh,K)
eval_cache = evaluate_cache(F,uh)
grad_cache = gradient_cache(F,uh,K)
return eval_cache, grad_cache
end

function evaluate_and_gradient!(cache,F::IntegrandOperator,uh,K;updated=false)
eval_cache, grad_cache = cache
return evaluate!(eval_cache,F,uh;updated), gradient!(grad_cache,F,uh,K;updated)
end

function evaluate_and_gradient(F::IntegrandOperator,uh,K)
cache = evaluate_and_gradient_cache(F,uh,K)
return evaluate_and_gradient!(cache,F,uh,K)
end

"""
struct GenericIntegrandOperator{A,B<:Tuple} <: IntegrandOperator
Expand All @@ -36,15 +65,17 @@ struct GenericIntegrandOperator{A,B<:Tuple}
end
end

(F::GenericIntegrandOperator)(args...) = F.F(args...,F....)
function evaluate!(cache,F::GenericIntegrandOperator,uh)
return F.F(uh...,F....)
end

function Gridap.gradient!(cache,F::GenericIntegrandOperator,uh::Vector{<:FEFunction},K::Int)
function gradient!(cache,F::GenericIntegrandOperator,uh::Vector{<:FEFunction},K::Int)
@check 0 < K <= length(uh)
_f(uk) = F.F(uh[1:K-1]...,uk,uh[K+1:end]...,F....)
return Gridap.gradient(_f,uh[K])
end

function Gridap.gradient!(cache,F::GenericIntegrandOperator,uh::Vector,K::Int)
function gradient!(cache,F::GenericIntegrandOperator,uh::Vector,K::Int)
@check 0 < K <= length(uh)
local_fields = map(local_views,uh) |> to_parray_of_arrays
local_measures = map(local_views,F.dΩ) |> to_parray_of_arrays
Expand Down Expand Up @@ -76,31 +107,55 @@ struct ParametricIntegrandOperator{A,B} <: IntegrandOperator
end
end

function evaluate_cache(G::ParametricIntegrandOperator,φh)
uh = zero(get_trial_space(G.state_map))
F_cache = evaluate_cache(G.F,[uh,φh])
return uh, F_cache
end

function Arrays.evaluate!(cache,G::ParametricIntegrandOperator,φh;updated=false)
uh, F_cache = cache
if update
uh = forward_solve!(G.state_map,uh,φh)
end
return evaluate!(F_cache,G.F,[uh,φh])
end

function gradient_cache(G::ParametricIntegrandOperator,φh,K)
@check K == 1
return gradient_cache(G.F,φh[1])
end

function gradient!(cache,G::ParametricIntegrandOperator,φh,K)
@check K == 1
return gradient!(cache,G.F,φh[1])
end

function gradient_cache(G::ParametricIntegrandOperator,φh)
U = get_trial_space(G.state_map)
uh = zero(U)

dFdu_cache = gradient_cache(AF.F,[uh,φh],1)
dFdφ_cache = gradient_cache(AF.F,[uh,φh],2)

dFdu = gradient!(dFdu_cache,AF.F,[uh,φh],1)
x = allocate_vector(get_pde_assembler(G.state_map),collect_cell_vector(U,dFdu))
return x, dFdu_cache, dFdφ_cache
dFdu_vec = allocate_vector(get_pde_assembler(G.state_map),collect_cell_vector(U,dFdu))
return uh, dFdu_vec, dFdu_cache, dFdφ_cache
end

function gradient!(cache,G::ParametricIntegrandOperator,φh,K)
@check K == 1
dFdu_vec, dFdu_cache, dFdφ_cache = cache
function gradient!(cache,G::ParametricIntegrandOperator,φh;updated=false)
uh, dFdu_vec, dFdu_cache, dFdφ_cache = cache
U = get_trial_space(G.state_map)

u, u_pullback = rrule(G.state_map,φh)
uh = FEFunction(U,u)
if !updated
uh = forward_solve!(G.state_map,uh,φh)
update_adjoint_caches!(G.state_map,uh,φh)
end

dFdu = gradient!(dFdu_cache,AF.F,[uh,φh],1)
dFdφ = gradient!(dFdφ_cache,AF.F,[uh,φh],2)

assemble_vector!(dFdu_vec,collect_cell_vector(U,dFdu))
dGdφ = dFdφ + u_pullback(dFdu_vec)
dGdφ = dFdφ + pullback(G.state_map,uh,φh,dFdu_vec;updated=true)
return dGdφ
end
Loading

0 comments on commit 1f6ab00

Please sign in to comment.