Skip to content

Commit

Permalink
Improve swap predicate with memoization and add condition for swaps o…
Browse files Browse the repository at this point in the history
…nly with closer nodes
  • Loading branch information
ba2tro committed Nov 22, 2024
1 parent a65fae9 commit bed4af2
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 20 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
NetworkLayout = "46757867-2c16-5918-afeb-47bfcb05e46a"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down
2 changes: 1 addition & 1 deletion examples/controlplane/3a_cl_interactive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ end
sim, net, obs, entlog, entlogaxis, fid_axis, histaxis, num_epr_axis, fig = prepare_vis(consumer)

step_ts = range(0.0, 1000.0, step=0.1)
record(fig, "sim.mp4", step_ts; framerate=10, visible=true) do t
record(fig, "sim2.mp4", step_ts; framerate=10, visible=true) do t
run(sim, t)
notify.((obs,entlog))
ylims!(entlogaxis, (-1.04,1.04))
Expand Down
23 changes: 14 additions & 9 deletions src/ProtocolZoo/ProtocolZoo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import ResumableFunctions
using ResumableFunctions: @resumable
import SumTypes
using Graphs
using Memoize

export
# protocols
Expand All @@ -25,7 +26,7 @@ export
# controllers
NetController, Controller, CLController,
#utils
PathMetadata, path_selection
PathMetadata, path_selection, swap_predicate, swap_choose

abstract type AbstractProtocol end

Expand Down Expand Up @@ -189,6 +190,8 @@ See also: [`SwapperProt`](@ref), [`EntanglementTracker`](@ref), [`EntanglementRe
src::Int
"""destination node for the `DistributionRequest`"""
dst::Int
"""whether the request is from a connection-oriented(0) or connection-less controller(1)"""
conn::Int
end
Base.show(io::IO, tag::SwapRequest) = print(io, "Node $(tag.swapping_node) perform a swap")
Tag(tag::SwapRequest) = Tag(SwapRequest, tag.swapping_node, tag.rounds, tag.src, tag.dst)
Expand Down Expand Up @@ -536,19 +539,21 @@ end
entangler = EntanglerProt(prot.sim, prot.net, prot.node, neighbor; rounds=rounds, randomize=true)
@process entangler()
else
msg = querydelete!(mb, requesttagsymbol, ❓, ❓, ❓, ❓)
msg = querydelete!(mb, requesttagsymbol, ❓, ❓, ❓, ❓, ❓)
@debug "RequestTracker @$(prot.node): Received $msg"
isnothing(msg) && continue
workwasdone = true
(msg_src, (_, _, rounds, req_src, req_dst)) = msg
(msg_src, (_, _, rounds, req_src, req_dst, conn)) = msg
@debug "RequestTracker @$(prot.node): Performing a swap"
if req_src == 0
swapper = SwapperProt(prot.sim, prot.net, prot.node; nodeL = <(prot.node), nodeH = >(prot.node), chooseL=argmin, chooseH=argmax, rounds=rounds)
if conn == 0 # connection-oriented
swapper = SwapperProt(prot.sim, prot.net, prot.node; nodeL = req_src, nodeH = req_dst, rounds=rounds)
@process swapper()
else
###instantiate predicate
### instantiate choosing function
swapper = SwapperProt(prot.sim, prot.net, prot.node; nodeL = <(prot.node), nodeH = >(prot.node), chooseL=argmin, chooseH=argmax, rounds=rounds)
else # connection-less
pred_low = swap_predicate(prot.net.graph, req_src, req_dst, prot.node)
pred_high = swap_predicate(prot.net.graph, req_src, req_dst, prot.node; low=false)
choose_low = swap_choose(prot.net.graph, req_src)
choose_high = swap_choose(prot.net.graph, req_dst)
swapper = SwapperProt(prot.sim, prot.net, prot.node; nodeL = pred_low, nodeH = pred_high, chooseL=choose_low, chooseH=choose_high, rounds=rounds)
@process swapper()
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/ProtocolZoo/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ end
end

for i in 2:length(path)-1
msg = Tag(SwapRequest, path[i], 1, 0, 0)
msg = Tag(SwapRequest, path[i], 1, path[i-1], path[i+1], 0)
if prot.node == path[i]
put!(mb, msg)
else
Expand Down Expand Up @@ -150,7 +150,7 @@ end
(msg_src, (_, req_src, req_dst)) = msg
for v in vertices(prot.net)
if v != req_src && v != req_dst
msg = Tag(SwapRequest, v, 1, req_src, req_dst)
msg = Tag(SwapRequest, v, 1, req_src, req_dst, 1)
if prot.node == v
put!(mb, msg)
else
Expand Down
26 changes: 18 additions & 8 deletions src/ProtocolZoo/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,27 @@ function findswapablequbits(net, node, pred_low, pred_high, choose_low, choose_h
end

"""
A generic predicate function for any arbitrary topology and entanglement flow
A generic predicate function for any arbitrary topology and entanglement flow used by [`SwapperProt`](@ref) and `findswapablequbits`
Returns a predicate function indicating whether a node is closer to Alice(source) or Bob(destination)
"""
function predicate(graph, src, dst, node; low=true)
d_src = length(a_star(graph, src, node))
d_dst = length(a_star(graph, dst, node))
return low ? d_src <= d_dst : d_src > d_dst
function swap_predicate(graph, src, dst, curr_node; low=true)
return node -> begin
d_src = get_distance(graph, src, node)
d_dst = get_distance(graph, dst, node)
is_closer = low ? d_src < get_distance(graph, src, curr_node) : d_dst < get_distance(graph, dst, curr_node)
res = d_src <= d_dst
low ? res & is_closer : !res & is_closer
end
end

"""
A generic choosing function for any arbitrary topology and entanglement flow
A generic choosing function for any arbitrary topology and entanglement flow. Returns the index of the node closest to `target_node`
for performing a swap.
"""
function choose(graph, target_node, arr)
return argmin(length.([a_star(graph, node, target_node) for node in arr]))
function swap_choose(graph, target_node)
return arr -> argmin([get_distance(graph, target_node, node) for node in arr])
end

@memoize function get_distance(graph, nodeA, nodeB)
return length(a_star(graph, nodeA, nodeB))
end

0 comments on commit bed4af2

Please sign in to comment.