Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add back_AD and test #133

Merged
merged 3 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions src/computational_graph/operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,66 @@ function backAD(diag::Graph{F,W}, debug::Bool=false) where {F,W}
# return dual[rootid]
end


function forwardAD_root(diags::AbstractVector{G}) where {G<:Graph}
dual = Dict{Int,G}()
# println("rootID: ", diag.id)
for diag in diags
for node in PreOrderDFS(diag)
visited = false
if haskey(dual, node.id)
dual[node.id].name != "None" && continue
visited = true
end
# println("Node: ", node.id)

if node.operator == Sum
nodes_deriv = G[]
for sub_node in node.subgraphs
if haskey(dual, sub_node.id)
# println("subNode haskey: ", sub_node.id)
push!(nodes_deriv, dual[sub_node.id])
else
# println("subNode nokey: ", sub_node.id)
g_dual = Graph(G[]; name="None")
push!(nodes_deriv, g_dual)
dual[sub_node.id] = g_dual
end
end
if visited
dual[node.id].subgraphs = nodes_deriv
dual[node.id].subgraph_factors = node.subgraph_factors
dual[node.id].name = node.name
else
dual[node.id] = Graph(nodes_deriv; subgraph_factors=node.subgraph_factors, factor=node.factor)
end
elseif node.operator == Prod
nodes_deriv = G[]
for (i, sub_node) in enumerate(node.subgraphs)
if haskey(dual, sub_node.id)
# println("subNode haskey: ", sub_node.id)
subgraphs = [j == i ? dual[subg.id] : subg for (j, subg) in enumerate(node.subgraphs)]
push!(nodes_deriv, Graph(subgraphs; operator=Prod(), subgraph_factors=node.subgraph_factors))
else
# println("subNode nokey: ", sub_node.id)
g_dual = Graph(G[]; name="None")
dual[sub_node.id] = g_dual
subgraphs = [j == i ? g_dual : subg for (j, subg) in enumerate(node.subgraphs)]
push!(nodes_deriv, Graph(subgraphs; operator=Prod(), subgraph_factors=node.subgraph_factors))

end
end
if visited
dual[node.id].subgraphs = nodes_deriv
dual[node.id].subgraph_factors = one.(eachindex(nodes_deriv))
dual[node.id].name = node.name
else
dual[node.id] = Graph(nodes_deriv; factor=node.factor)
end
end
end
end
return dual
end

forwardAD_root(diag::Graph) = forwardAD_root([diag])
55 changes: 54 additions & 1 deletion test/computational_graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ end

@testset verbose = true "Auto Differentiation" begin
using FeynmanDiagram.ComputationalGraphs:
eval!, frontAD, backAD, node_derivative
eval!, frontAD, backAD, node_derivative, forwardAD_root
g1 = propagator(𝑓⁻(1)𝑓⁺(2))
g2 = propagator(𝑓⁻(3)𝑓⁺(4))
g3 = propagator(𝑓⁻(5)𝑓⁺(6), factor=2.0)
Expand Down Expand Up @@ -286,6 +286,59 @@ end
# end
# end
end
@testset "forwardAD_root" begin
F3 = g1 + g2
F2 = linear_combination([g1, g3, F3], [2, 1, 3])
F1 = Graph([g1, F2, F3], operator=Graphs.Prod(), subgraph_factors=[3.0, 1.0, 1.0])

dual = forwardAD_root(F1) # auto-differentation!
@test dual[F3.id].subgraphs == [dual[g1.id], dual[g2.id]]
@test dual[F2.id].subgraphs == [dual[g1.id], dual[g3.id], dual[F3.id]]

leafmap = Dict{Int,Int}()
leafmap[g1.id], leafmap[g2.id], leafmap[g3.id] = 1, 2, 3
leafmap[dual[g1.id].id] = 4
leafmap[dual[g2.id].id] = 5
leafmap[dual[g3.id].id] = 6
leaf = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0] # d F1 / d g1
@test eval!(dual[F1.id], leafmap, leaf) == 120.0
@test eval!(dual[F2.id], leafmap, leaf) == 5.0
@test eval!(dual[F3.id], leafmap, leaf) == 1.0

leaf = [5.0, -1.0, 2.0, 0.0, 1.0, 0.0] # d F1 / d g2
@test eval!(dual[F1.id], leafmap, leaf) == 570.0
@test eval!(dual[F2.id], leafmap, leaf) == 3.0
@test eval!(dual[F3.id], leafmap, leaf) == 1.0

leaf = [5.0, -1.0, 2.0, 0.0, 0.0, 1.0] # d F1 / d g3
@test eval!(dual[F1.id], leafmap, leaf) == 60.0
@test eval!(dual[F2.id], leafmap, leaf) == 1.0
@test eval!(dual[F3.id], leafmap, leaf) == 0.0

F0 = F1 * F3
dual1 = forwardAD_root(F0)
leafmap[dual1[g1.id].id] = 4
leafmap[dual1[g2.id].id] = 5
leafmap[dual1[g3.id].id] = 6

leaf = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0]
@test eval!(dual1[F0.id], leafmap, leaf) == 300.0
leaf = [5.0, -1.0, 2.0, 0.0, 1.0, 0.0]
@test eval!(dual1[F0.id], leafmap, leaf) == 3840.0
leaf = [5.0, -1.0, 2.0, 0.0, 0.0, 1.0]
@test eval!(dual1[F0.id], leafmap, leaf) == 240.0
@test isequiv(dual[F1.id], dual1[F1.id], :id, :weight, :vertices)

F0_r1 = F1 + F3
dual = forwardAD_root([F0, F0_r1])
leafmap[dual[g1.id].id] = 4
leafmap[dual[g2.id].id] = 5
leafmap[dual[g3.id].id] = 6
@test eval!(dual[F0.id], leafmap, leaf) == 240.0
@test eval!(dual[F0_r1.id], leafmap, leaf) == 60.0
@test isequiv(dual[F0.id], dual1[F0.id], :id, :weight)
@test isequiv(dual[F1.id], dual1[F1.id], :id, :weight)
end
end

@testset verbose = true "Tree properties" begin
Expand Down