-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bug fix BinaryHeap:remove() + add MSTSearch.lua, etc.
- Loading branch information
1 parent
42d0671
commit 143e5ca
Showing
5 changed files
with
147 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
local MSTSearch = require "algo.MSTSearch" | ||
local Graph = require "algo.Graph" | ||
local DEBUG = require "algo.Debug":new(false) | ||
local debugf, debug = DEBUG.debugf, DEBUG.debug | ||
|
||
local function load_input(input) | ||
local G = Graph:new() | ||
for _, edge in ipairs(input) do | ||
local from, to, cost = edge:match('(%w-)%-(%w+)=(%d+)') | ||
G:add_undirected(from, to, cost) | ||
end | ||
return G | ||
end | ||
|
||
local function test_mst(G, expected_total_weight) | ||
for src in G:vertices() do | ||
local total = 0 | ||
for u, v, weight in MSTSearch:new(G):iterate(src) do | ||
debugf("%s-%s: %d", u, v, weight) | ||
total = total + weight | ||
end | ||
assert(total == expected_total_weight) | ||
debug() | ||
end | ||
end | ||
|
||
-- Algorithms Illuminated Part 3 Example 15.2.1 by Time Roughgarden. | ||
print("Testing Tim MST ...") | ||
test_mst(load_input {'a-b=1', 'b-d=2', 'a-d=3', 'a-c=4', 'c-d=5'}, 7) | ||
|
||
print("Testing geeksforgeeks MST ...") | ||
-- https://www.geeksforgeeks.org/kruskals-minimum-spanning-tree-algorithm-greedy-algo-2/ | ||
test_mst(load_input { | ||
'0-1=4', '0-7=8', | ||
'1-2=8', '1-7=11', | ||
'2-3=7', '2-5=4', '2-8=2', | ||
'3-4=9', '3-5=14', | ||
'4-5=10', '5-6=2', | ||
'6-7=1', '6-8=6', '7-8=7' | ||
}, 37) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
local BinaryHeap = require "algo.BinaryHeap" | ||
local GraphSearch = require "algo.GraphSearch" | ||
local MSTSearch = GraphSearch:class() | ||
|
||
local DEBUG = require "algo.Debug":new(false) | ||
local debugf, debug = DEBUG.debugf, DEBUG.debug | ||
|
||
-- Format of a heap entry: {weight, u, v}. Note a heap reference, in contrast, has the structure: | ||
-- { pos=(number), val=(heap entry) } | ||
local WEIGHT <const>, U <const>, V <const> = 1, 2, 3 | ||
|
||
-- Iterates through a miminum spanning tree of a given undirected connected graph. | ||
-- Greedily expand the MST frontier by selecting the least weight among all the nodes | ||
-- outside but directly connected to the frontier. | ||
-- This is basically Prim's Algorithm, optimized with a binary heap. | ||
---@param self (table) | ||
---@param src (string?) a source node, or default to an arbitrary node if not specified. | ||
local function _iterate(self, src) | ||
-- In general, nodes that belong to the MST is put as key into 'mst' with it's from edge as the value. | ||
-- One exception is the source node which has itself as the value in the 'mst'. | ||
src = src or next(self.graph) | ||
-- Used to keep track of the MST with each node as the key, | ||
-- and each "from" node as the value. | ||
local mst = {[src] = src} | ||
local heap_ref = {} -- keep track of the heap references, with the node as the key. | ||
local heap = BinaryHeap:new({}, function(a_entry, b_entry) | ||
local a, b = a_entry[WEIGHT], b_entry[WEIGHT] | ||
return a <= b | ||
end) | ||
local u = src | ||
local u_vertex = self.graph:vertex(u) | ||
assert(u_vertex, "Vertex not found in graph.") | ||
for v, weight in u_vertex:outgoings() do | ||
assert(not heap_ref[v]) | ||
heap_ref[v] = heap:add {weight, u, v} | ||
end | ||
local entry = heap:remove() | ||
while entry do | ||
local weight, u, v = table.unpack(entry) | ||
assert(heap_ref[v]) | ||
heap_ref[v] = nil -- entry no longer on the heap | ||
if mst[v] then | ||
-- debugf("Discarding %s-%s: %s", u, v, weight) | ||
else | ||
mst[v] = u | ||
self._yield {u, v, weight} | ||
for w, weight in self.graph:vertex(v):outgoings() do | ||
if mst[w] then | ||
else | ||
local w_ref = heap_ref[w] -- Check if 'w' is currently on the heap. | ||
if w_ref then -- If so, we update it only if we find an edge with a lesser weight to lead to it. | ||
if weight < w_ref.val[WEIGHT] then | ||
assert(w == w_ref.val[V]) | ||
debugf("%s - %s-%s: %s < %s-%s: %s on heap", | ||
w, v, w, weight, w_ref.val[U], w, w_ref.val[WEIGHT]) | ||
heap:remove(w_ref.pos) | ||
-- heap:verify() | ||
heap_ref[w] = heap:add {weight, v, w} | ||
-- heap:verify() | ||
end | ||
else -- w not on heap | ||
heap_ref[w] = heap:add {weight, v, w} | ||
end | ||
end | ||
end -- end for | ||
end -- end if | ||
entry = heap:remove() | ||
end -- end while | ||
assert(not heap:remove()) | ||
assert(not next(heap_ref)) | ||
end | ||
|
||
---@param G (table) graph | ||
function MSTSearch:new(G) | ||
local o = getmetatable(self):new(G, _iterate) | ||
return o | ||
end | ||
|
||
return MSTSearch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters