Skip to content

Commit

Permalink
Bug fix BinaryHeap:remove() + add MSTSearch.lua, etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
hansonchar committed Feb 4, 2025
1 parent 42d0671 commit 143e5ca
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 18 deletions.
25 changes: 18 additions & 7 deletions learning-lua/algo/BinaryHeap.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
-- Courtesy of https://en.wikipedia.org/wiki/J._W._J._Williams
local BinaryHeap = {}

local TOMBSTONE <const> = {}

local function swap(a, i, j)
a[i], a[j] = a[j], a[i]
a[i].pos, a[j].pos = i, j
Expand All @@ -15,7 +17,7 @@ local function bubble_up(self, i)
local a, comp = self, self.comp
while i > 1 do
local p = i >> 1 -- parent
if comp(a[p].val, a[i].val) then -- value stored at index i is not smaller than that of its parent
if a[i].val ~= TOMBSTONE and comp(a[p].val, a[i].val) then -- value stored at index i is not smaller than that of its parent
return
end
swap(a, i, p) -- swap with parent
Expand Down Expand Up @@ -43,6 +45,13 @@ local function trickle_down(self, i)
end
end

local function remove_root(self)
local a = self
a[1], a[#a] = a[#a], nil
a[1].pos = 1
trickle_down(a, 1)
end

--- Move the last element to the root, and then maintain the heap invariant as necessary by repeatedly
--- swapping with the smallest/largest of the two children.
function BinaryHeap:remove(i)
Expand All @@ -55,15 +64,17 @@ function BinaryHeap:remove(i)
end
local a = self
assert(0 < i and i <= #a, "index out of bound")
local root = a[i]
local old_val = a[i].val
if #a == i then
a[i] = nil
else
a[i], a[#a] = a[#a], nil
a[i].pos = i
trickle_down(a, i)
elseif i == 1 then
remove_root(self)
else -- somewhere in the middle
a[i].val = TOMBSTONE
bubble_up(self, i) -- bubble up to the top, and
remove_root(self) -- then remove it
end
return root.val
return old_val
end

function BinaryHeap:top()
Expand Down
2 changes: 1 addition & 1 deletion learning-lua/algo/Graph.lua
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ function Graph:add(u, v, weight)
self[u].egress = self[u].egress or {}
local egress = self[u].egress
if v then
assert(not egress[v], string.format("Duplicate addition of %s-%s %d", u, v, weight))
assert(not egress[v], string.format("Duplicate addition of %s-%s %s", u, v, weight))
egress[v] = tonumber(weight)
add_vertex(self, v)
if self:is_ingress_built() then
Expand Down
40 changes: 40 additions & 0 deletions learning-lua/algo/MSTSearch-tests.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
local MSTSearch = require "algo.MSTSearch"
local Graph = require "algo.Graph"
local DEBUG = require "algo.Debug":new(false)
local debugf, debug = DEBUG.debugf, DEBUG.debug

local function load_input(input)
local G = Graph:new()
for _, edge in ipairs(input) do
local from, to, cost = edge:match('(%w-)%-(%w+)=(%d+)')
G:add_undirected(from, to, cost)
end
return G
end

local function test_mst(G, expected_total_weight)
for src in G:vertices() do
local total = 0
for u, v, weight in MSTSearch:new(G):iterate(src) do
debugf("%s-%s: %d", u, v, weight)
total = total + weight
end
assert(total == expected_total_weight)
debug()
end
end

-- Algorithms Illuminated Part 3 Example 15.2.1 by Time Roughgarden.
print("Testing Tim MST ...")
test_mst(load_input {'a-b=1', 'b-d=2', 'a-d=3', 'a-c=4', 'c-d=5'}, 7)

print("Testing geeksforgeeks MST ...")
-- https://www.geeksforgeeks.org/kruskals-minimum-spanning-tree-algorithm-greedy-algo-2/
test_mst(load_input {
'0-1=4', '0-7=8',
'1-2=8', '1-7=11',
'2-3=7', '2-5=4', '2-8=2',
'3-4=9', '3-5=14',
'4-5=10', '5-6=2',
'6-7=1', '6-8=6', '7-8=7'
}, 37)
79 changes: 79 additions & 0 deletions learning-lua/algo/MSTSearch.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
local BinaryHeap = require "algo.BinaryHeap"
local GraphSearch = require "algo.GraphSearch"
local MSTSearch = GraphSearch:class()

local DEBUG = require "algo.Debug":new(false)
local debugf, debug = DEBUG.debugf, DEBUG.debug

-- Format of a heap entry: {weight, u, v}. Note a heap reference, in contrast, has the structure:
-- { pos=(number), val=(heap entry) }
local WEIGHT <const>, U <const>, V <const> = 1, 2, 3

-- Iterates through a miminum spanning tree of a given undirected connected graph.
-- Greedily expand the MST frontier by selecting the least weight among all the nodes
-- outside but directly connected to the frontier.
-- This is basically Prim's Algorithm, optimized with a binary heap.
---@param self (table)
---@param src (string?) a source node, or default to an arbitrary node if not specified.
local function _iterate(self, src)
-- In general, nodes that belong to the MST is put as key into 'mst' with it's from edge as the value.
-- One exception is the source node which has itself as the value in the 'mst'.
src = src or next(self.graph)
-- Used to keep track of the MST with each node as the key,
-- and each "from" node as the value.
local mst = {[src] = src}
local heap_ref = {} -- keep track of the heap references, with the node as the key.
local heap = BinaryHeap:new({}, function(a_entry, b_entry)
local a, b = a_entry[WEIGHT], b_entry[WEIGHT]
return a <= b
end)
local u = src
local u_vertex = self.graph:vertex(u)
assert(u_vertex, "Vertex not found in graph.")
for v, weight in u_vertex:outgoings() do
assert(not heap_ref[v])
heap_ref[v] = heap:add {weight, u, v}
end
local entry = heap:remove()
while entry do
local weight, u, v = table.unpack(entry)
assert(heap_ref[v])
heap_ref[v] = nil -- entry no longer on the heap
if mst[v] then
-- debugf("Discarding %s-%s: %s", u, v, weight)
else
mst[v] = u
self._yield {u, v, weight}
for w, weight in self.graph:vertex(v):outgoings() do
if mst[w] then
else
local w_ref = heap_ref[w] -- Check if 'w' is currently on the heap.
if w_ref then -- If so, we update it only if we find an edge with a lesser weight to lead to it.
if weight < w_ref.val[WEIGHT] then
assert(w == w_ref.val[V])
debugf("%s - %s-%s: %s < %s-%s: %s on heap",
w, v, w, weight, w_ref.val[U], w, w_ref.val[WEIGHT])
heap:remove(w_ref.pos)
-- heap:verify()
heap_ref[w] = heap:add {weight, v, w}
-- heap:verify()
end
else -- w not on heap
heap_ref[w] = heap:add {weight, v, w}
end
end
end -- end for
end -- end if
entry = heap:remove()
end -- end while
assert(not heap:remove())
assert(not next(heap_ref))
end

---@param G (table) graph
function MSTSearch:new(G)
local o = getmetatable(self):new(G, _iterate)
return o
end

return MSTSearch
19 changes: 9 additions & 10 deletions learning-lua/algo/ShortestPathSearch-tests.lua
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
local ShortestPathSearch = require "algo.ShortestPathSearch"
local Graph = require "algo.Graph"

local function debug(...)
print(...)
end
local DEBUG = require "algo.Debug":new(false)
local debugf, debug = DEBUG.debugf, DEBUG.debug

local function basic_tests()
print("ShortestPathSearch basic tests...")
Expand All @@ -16,11 +15,11 @@ local function basic_tests()

local search = ShortestPathSearch:new(G)
for from, to, weight, level, min_cost in search:iterate('s') do
debug(string.format("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost))
debugf("%d: %s-%s=%d, min:%d", level, from, to, weight, min_cost)
end

local sssp = search:shortest_paths()
print(sssp)
debug(sssp)
assert(sssp:min_cost_of('s') == 0)
assert(sssp:min_cost_of('t') == 6)
assert(sssp:min_cost_of('v') == 1)
Expand Down Expand Up @@ -53,7 +52,7 @@ local function tim_test()
assert(count == 1)
end
local shortest_paths = search:shortest_paths()
print(shortest_paths)
debug(shortest_paths)
assert(shortest_paths:min_cost_of('s') == 0)
assert(shortest_paths:shortest_path_of('s') == 's')
assert(shortest_paths:min_cost_of('v') == 1)
Expand Down Expand Up @@ -85,7 +84,7 @@ local function geek_test()
assert(level_counts[4] == 1)

local sssp = search:shortest_paths()
print(sssp)
debug(sssp)
assert(sssp:shortest_path_of('0') == '0')
assert(sssp:min_cost_of('0') == 0)

Expand Down Expand Up @@ -133,7 +132,7 @@ local function redblobgames_test()
end

local sssp = search:shortest_paths()
print(sssp)
debug(sssp)
assert(sssp:shortest_path_of('A') == 'A')
assert(sssp:min_cost_of(src) == 0)

Expand Down Expand Up @@ -170,7 +169,7 @@ local function algodaily_test()
assert(level_counts[3] == 1)

local sssp = search:shortest_paths()
print(sssp)
debug(sssp)
assert(sssp:shortest_path_of('A') == 'A')
assert(sssp:min_cost_of(src) == 0)

Expand Down Expand Up @@ -211,7 +210,7 @@ local function scott_moura_test()
assert(level_counts[4] == 1)

local sssp = search:shortest_paths()
print(sssp)
debug(sssp)
assert(sssp:shortest_path_of('A') == 'A')
assert(sssp:min_cost_of(src) == 0)

Expand Down

0 comments on commit 143e5ca

Please sign in to comment.