Skip to content

Commit

Permalink
Merge pull request #133 from numericalEFT/hpc_AD
Browse files Browse the repository at this point in the history
add back_AD and test
  • Loading branch information
fsxbhyy authored Sep 25, 2023
2 parents 7047e66 + fef35c2 commit 6449ea7
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 1 deletion.
63 changes: 63 additions & 0 deletions src/computational_graph/operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,66 @@ function backAD(diag::Graph{F,W}, debug::Bool=false) where {F,W}
# return dual[rootid]
end


function forwardAD_root(diags::AbstractVector{G}) where {G<:Graph}
dual = Dict{Int,G}()
# println("rootID: ", diag.id)
for diag in diags
for node in PreOrderDFS(diag)
visited = false
if haskey(dual, node.id)
dual[node.id].name != "None" && continue
visited = true
end
# println("Node: ", node.id)

if node.operator == Sum
nodes_deriv = G[]
for sub_node in node.subgraphs
if haskey(dual, sub_node.id)
# println("subNode haskey: ", sub_node.id)
push!(nodes_deriv, dual[sub_node.id])
else
# println("subNode nokey: ", sub_node.id)
g_dual = Graph(G[]; name="None")
push!(nodes_deriv, g_dual)
dual[sub_node.id] = g_dual
end
end
if visited
dual[node.id].subgraphs = nodes_deriv
dual[node.id].subgraph_factors = node.subgraph_factors
dual[node.id].name = node.name
else
dual[node.id] = Graph(nodes_deriv; subgraph_factors=node.subgraph_factors, factor=node.factor)
end
elseif node.operator == Prod
nodes_deriv = G[]
for (i, sub_node) in enumerate(node.subgraphs)
if haskey(dual, sub_node.id)
# println("subNode haskey: ", sub_node.id)
subgraphs = [j == i ? dual[subg.id] : subg for (j, subg) in enumerate(node.subgraphs)]
push!(nodes_deriv, Graph(subgraphs; operator=Prod(), subgraph_factors=node.subgraph_factors))
else
# println("subNode nokey: ", sub_node.id)
g_dual = Graph(G[]; name="None")
dual[sub_node.id] = g_dual
subgraphs = [j == i ? g_dual : subg for (j, subg) in enumerate(node.subgraphs)]
push!(nodes_deriv, Graph(subgraphs; operator=Prod(), subgraph_factors=node.subgraph_factors))

end
end
if visited
dual[node.id].subgraphs = nodes_deriv
dual[node.id].subgraph_factors = one.(eachindex(nodes_deriv))
dual[node.id].name = node.name
else
dual[node.id] = Graph(nodes_deriv; factor=node.factor)
end
end
end
end
return dual
end

forwardAD_root(diag::Graph) = forwardAD_root([diag])
55 changes: 54 additions & 1 deletion test/computational_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ end

@testset verbose = true "Auto Differentiation" begin
using FeynmanDiagram.ComputationalGraphs:
eval!, frontAD, backAD, node_derivative
eval!, frontAD, backAD, node_derivative, forwardAD_root
g1 = propagator(𝑓⁻(1)𝑓⁺(2))
g2 = propagator(𝑓⁻(3)𝑓⁺(4))
g3 = propagator(𝑓⁻(5)𝑓⁺(6), factor=2.0)
Expand Down Expand Up @@ -286,6 +286,59 @@ end
# end
# end
end
@testset "forwardAD_root" begin
F3 = g1 + g2
F2 = linear_combination([g1, g3, F3], [2, 1, 3])
F1 = Graph([g1, F2, F3], operator=Graphs.Prod(), subgraph_factors=[3.0, 1.0, 1.0])

dual = forwardAD_root(F1) # auto-differentation!
@test dual[F3.id].subgraphs == [dual[g1.id], dual[g2.id]]
@test dual[F2.id].subgraphs == [dual[g1.id], dual[g3.id], dual[F3.id]]

leafmap = Dict{Int,Int}()
leafmap[g1.id], leafmap[g2.id], leafmap[g3.id] = 1, 2, 3
leafmap[dual[g1.id].id] = 4
leafmap[dual[g2.id].id] = 5
leafmap[dual[g3.id].id] = 6
leaf = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0] # d F1 / d g1
@test eval!(dual[F1.id], leafmap, leaf) == 120.0
@test eval!(dual[F2.id], leafmap, leaf) == 5.0
@test eval!(dual[F3.id], leafmap, leaf) == 1.0

leaf = [5.0, -1.0, 2.0, 0.0, 1.0, 0.0] # d F1 / d g2
@test eval!(dual[F1.id], leafmap, leaf) == 570.0
@test eval!(dual[F2.id], leafmap, leaf) == 3.0
@test eval!(dual[F3.id], leafmap, leaf) == 1.0

leaf = [5.0, -1.0, 2.0, 0.0, 0.0, 1.0] # d F1 / d g3
@test eval!(dual[F1.id], leafmap, leaf) == 60.0
@test eval!(dual[F2.id], leafmap, leaf) == 1.0
@test eval!(dual[F3.id], leafmap, leaf) == 0.0

F0 = F1 * F3
dual1 = forwardAD_root(F0)
leafmap[dual1[g1.id].id] = 4
leafmap[dual1[g2.id].id] = 5
leafmap[dual1[g3.id].id] = 6

leaf = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
@test eval!(dual1[F0.id], leafmap, leaf) == 300.0
leaf = [5.0, -1.0, 2.0, 0.0, 1.0, 0.0]
@test eval!(dual1[F0.id], leafmap, leaf) == 3840.0
leaf = [5.0, -1.0, 2.0, 0.0, 0.0, 1.0]
@test eval!(dual1[F0.id], leafmap, leaf) == 240.0
@test isequiv(dual[F1.id], dual1[F1.id], :id, :weight, :vertices)

F0_r1 = F1 + F3
dual = forwardAD_root([F0, F0_r1])
leafmap[dual[g1.id].id] = 4
leafmap[dual[g2.id].id] = 5
leafmap[dual[g3.id].id] = 6
@test eval!(dual[F0.id], leafmap, leaf) == 240.0
@test eval!(dual[F0_r1.id], leafmap, leaf) == 60.0
@test isequiv(dual[F0.id], dual1[F0.id], :id, :weight)
@test isequiv(dual[F1.id], dual1[F1.id], :id, :weight)
end
end

@testset verbose = true "Tree properties" begin
Expand Down

0 comments on commit 6449ea7

Please sign in to comment.