Skip to content

Commit

Permalink
ADT recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
Grant Wuerker committed Feb 14, 2024
1 parent fd274c9 commit e3b2f62
Show file tree
Hide file tree
Showing 12 changed files with 422 additions and 61 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/common2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ camino = "1.1.4"
smol_str = "0.1.24"
salsa = { git = "https://github.com/salsa-rs/salsa", package = "salsa-2022" }
parser = { path = "../parser2", package = "fe-parser2" }
rustc-hash = "1.1.0"
1 change: 1 addition & 0 deletions crates/common2/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod diagnostics;
pub mod input;
pub mod recursion;

pub use input::{InputFile, InputIngot};

Expand Down
195 changes: 195 additions & 0 deletions crates/common2/src/recursion/dsf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
//! code copied from https://gist.github.com/jlikhuva/70bef29102d62054e8379a7c8f51ad87#
//!
//! Sometimes, we need to keep track of k disjoint groups of
//! items — meaning that each item uniquely belongs to one group.
//! The most common operations that we'd like to run on our collection of groups are:
//! `find_set(x)`, which tells us which of our k groups x belongs to, and `union(s1, s2)`,
//! which allows us to merge two groups. Remember in the section on connected
//! components we needed to quickly find out if a node was already part of some component.

use std::marker::PhantomData;

/// A single node in the disjoint set forest.
#[derive(Debug, Eq, PartialEq)]
pub struct DSFNode<T: PartialEq> {
    /// The payload stored at this node. This is a unique
    /// identifier of a node.
    key: T,

    /// Index of this node's parent in the forest vector.
    /// A root node is its own parent.
    parent: usize,

    /// The location of this node in the forest vector.
    index: usize,

    /// An upper bound on the height of this node -- that is, an upper
    /// bound on the number of edges in the longest simple path from this
    /// node down to a descendant leaf. A freshly created node has rank 0.
    rank: usize,
}

/// A handle identifying a node of a [`DSF`] by its position in the forest.
///
/// Handles are returned by [`DSF::make_set`] and [`DSF::find_set`]. Two
/// handles compare equal when they refer to the same position.
#[derive(Debug, PartialEq, Eq)]
pub struct DSFNodeHandle<T>(usize, PhantomData<T>);

impl<T: PartialEq> DSFNode<T> {
    /// Creates a new singleton node holding `key`.
    ///
    /// `parent` should be the location of this node in the forest vector:
    /// a singleton node is its own parent, so the same value is also
    /// recorded as the node's own index. The rank starts at 0.
    pub fn new(key: T, parent: usize) -> Self {
        DSFNode {
            key,
            parent,
            rank: 0,
            index: parent,
        }
    }
}

/// A disjoint set forest is a collection of trees.
/// Only the nodes in a single tree are linked together.
/// A user interacts with the forest through node handles.
#[derive(Debug)]
pub struct DSF<T: PartialEq> {
    /// All the nodes of all the trees in the forest.
    forest: Vec<DSFNode<T>>,
}

impl<T: PartialEq> Default for DSF<T> {
    /// Equivalent to [`DSF::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: PartialEq> DSF<T> {
    /// Creates a new disjoint set forest with no trees in it.
    pub fn new() -> Self {
        DSF { forest: Vec::new() }
    }

    /// Adds a new singleton node holding `x` to the forest.
    ///
    /// Returns a handle to the node that can be passed to
    /// [`DSF::union`] and [`DSF::find_set`].
    pub fn make_set(&mut self, x: T) -> DSFNodeHandle<T> {
        let idx = self.forest.len();
        self.forest.push(DSFNode::new(x, idx));
        DSFNodeHandle(idx, PhantomData)
    }

    /// Merges the sets containing `a` and `b`.
    ///
    /// The union operation has two cases. If the roots have unequal rank,
    /// the root with the lower rank is made to point to the root with the
    /// higher rank and no rank changes. If the roots have equal rank, one
    /// of them becomes the root of the combined set and its rank is
    /// increased by 1.
    ///
    /// If `a` and `b` already belong to the same set this is a no-op.
    pub fn union(&mut self, a: &DSFNodeHandle<T>, b: &DSFNodeHandle<T>) {
        let a_root = Self::find_set_helper(&mut self.forest, a.0);
        let b_root = Self::find_set_helper(&mut self.forest, b.0);

        // Already in the same set. Without this guard the equal-rank branch
        // below would bump the root's rank on every redundant call, making
        // the rank a needlessly loose bound and degrading union-by-rank.
        if a_root == b_root {
            return;
        }

        // The root with the higher rank becomes the parent of the other one,
        // which makes it the representative of the combined set.
        if self.forest[a_root].rank > self.forest[b_root].rank {
            self.forest[b_root].parent = a_root;
        } else {
            // Ranks only change when two trees of the same rank are merged.
            // The choice of whose rank to increase is arbitrary.
            if self.forest[a_root].rank == self.forest[b_root].rank {
                self.forest[b_root].rank += 1;
            }
            self.forest[a_root].parent = b_root;
        }
    }

    /// Finds the representative of the set containing `x`.
    ///
    /// Also performs path compression. Ranks are left unchanged.
    pub fn find_set(&mut self, x: &DSFNodeHandle<T>) -> DSFNodeHandle<T> {
        let idx = Self::find_set_helper(&mut self.forest, x.0);
        DSFNodeHandle(idx, PhantomData)
    }

    /// Returns the index of the root of the tree containing `x_index` and
    /// re-parents every node on the `x_index -> root` path directly to it.
    fn find_set_helper(forest: &mut [DSFNode<T>], x_index: usize) -> usize {
        // First pass: walk upwards until we hit the node that is its own parent.
        let mut root = x_index;
        while forest[root].parent != root {
            root = forest[root].parent;
        }

        // Second pass: walk the same path again, pointing each node at the root.
        let mut cur = x_index;
        while cur != root {
            cur = std::mem::replace(&mut forest[cur].parent, root);
        }

        root
    }
}

#[cfg(test)]
mod test {
    use super::DSF;

    /// Every freshly created set is a singleton: its handle is its own
    /// representative and differs from every other handle.
    #[test]
    fn make_set() {
        let mut forest = DSF::<&str>::new();
        let handles = [
            forest.make_set("good"),
            forest.make_set("splendid"),
            forest.make_set("remarkable"),
            forest.make_set("nice"),
            forest.make_set("amazing"),
        ];

        for handle in &handles {
            assert_eq!(&forest.find_set(handle), handle);
        }

        for (i, a) in handles.iter().enumerate() {
            for b in &handles[i + 1..] {
                assert_ne!(a, b);
            }
        }
    }

    /// Builds two groups of words via `union` and checks that `find_set`
    /// reports exactly those two disjoint sets.
    #[test]
    fn union_and_find_set() {
        let mut forest = DSF::<&str>::new();

        // Synonyms for good
        let t1 = forest.make_set("good");
        let t2 = forest.make_set("splendid");
        let t3 = forest.make_set("remarkable");
        let t4 = forest.make_set("nice");
        let t5 = forest.make_set("amazing");

        // Assert singleton trees
        assert_ne!(forest.find_set(&t1), forest.find_set(&t2));
        assert_ne!(forest.find_set(&t2), forest.find_set(&t3));
        assert_ne!(forest.find_set(&t3), forest.find_set(&t4));
        assert_ne!(forest.find_set(&t4), forest.find_set(&t5));

        // Synonyms for bad
        let t6 = forest.make_set("bad");
        let t7 = forest.make_set("schlecht");
        let t8 = forest.make_set("unpleasany");
        let t9 = forest.make_set("poor");

        // Assert singleton trees
        assert_ne!(forest.find_set(&t6), forest.find_set(&t7));
        assert_ne!(forest.find_set(&t7), forest.find_set(&t8));
        assert_ne!(forest.find_set(&t8), forest.find_set(&t9));

        // Union galore
        forest.union(&t1, &t2);
        forest.union(&t1, &t3);
        forest.union(&t2, &t4);
        forest.union(&t5, &t3);

        forest.union(&t6, &t7);
        forest.union(&t8, &t9);
        forest.union(&t9, &t7);

        // Assert only 2 disjoint sets
        //
        // First set
        assert_eq!(forest.find_set(&t1), forest.find_set(&t2));
        assert_eq!(forest.find_set(&t2), forest.find_set(&t3));
        assert_eq!(forest.find_set(&t3), forest.find_set(&t4));
        assert_eq!(forest.find_set(&t4), forest.find_set(&t5));

        // Second set
        assert_eq!(forest.find_set(&t6), forest.find_set(&t7));
        assert_eq!(forest.find_set(&t7), forest.find_set(&t8));
        assert_eq!(forest.find_set(&t8), forest.find_set(&t9));

        // Assert disjointness
        assert_ne!(forest.find_set(&t6), forest.find_set(&t1));
        assert_ne!(forest.find_set(&t7), forest.find_set(&t3));
        assert_ne!(forest.find_set(&t8), forest.find_set(&t4));
    }
}
125 changes: 125 additions & 0 deletions crates/common2/src/recursion/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::hash::Hash;

use self::dsf::{DSFNodeHandle, DSF};
use rustc_hash::FxHashMap;

mod dsf;

/// One directed step of a recursion: an edge leading from one item to another.
///
/// Several constituents that belong to the same recursion can be grouped
/// together with `RecursionHelper`.
///
/// `T` identifies the items taking part in the recursion; `U` is extra
/// diagnostic information carried alongside each identifier.
#[derive(Eq, PartialEq, Clone, Debug, Hash)]
pub struct RecursionConstituent<T: PartialEq + Copy, U> {
    /// Identifier (and diagnostic data) of the item this step starts at.
    pub from: (T, U),
    /// Identifier (and diagnostic data) of the item this step leads to.
    pub to: (T, U),
}

impl<T: PartialEq + Copy, U> RecursionConstituent<T, U> {
    /// Builds a constituent going from `from` to `to`.
    pub fn new(from: (T, U), to: (T, U)) -> Self {
        RecursionConstituent { from, to }
    }
}

pub struct RecursionHelper<T, U>
where
T: PartialEq + Copy,
{
constituents: Vec<RecursionConstituent<T, U>>,
forest: DSF<T>,
trees: FxHashMap<T, DSFNodeHandle<T>>,
}

/// `RecursionHelper` uses a disjoint set forest to unify constituents of recursions.
impl<T, U> RecursionHelper<T, U>
where
T: Eq + Hash + PartialEq + Copy,
{
pub fn new(constituents: Vec<RecursionConstituent<T, U>>) -> Self {
let mut forest = DSF::<_>::new();
let trees: FxHashMap<_, _> = constituents
.iter()
.map(|constituent| (constituent.from.0, forest.make_set(constituent.from.0)))
.collect();

for constituent in constituents.iter() {
forest.union(&trees[&constituent.from.0], &trees[&constituent.to.0])
}

Self {
constituents,
forest,
trees,
}
}

/// Removes a set of disjoint constituents from the helper and returns them.
///
/// This should be called until the disjoint set is empty.
pub fn remove_disjoint_set(&mut self) -> Option<Vec<RecursionConstituent<T, U>>> {
let mut disjoint_set = vec![];
let mut remaining_set = vec![];
let mut set_id = None;

while let Some(constituent) = self.constituents.pop() {
let cur_set_id = self.forest.find_set(&self.trees[&constituent.from.0]);

if set_id == None {
set_id = Some(cur_set_id);
disjoint_set.push(constituent);
} else if set_id == Some(cur_set_id) {
disjoint_set.push(constituent)
} else {
remaining_set.push(constituent)
}
}

self.constituents = remaining_set;

if set_id.is_some() {
Some(disjoint_set)
} else {
None
}
}
}

/// A two-element cycle forms a single recursion: both constituents come
/// back together and nothing is left afterwards.
#[test]
fn one_recursion() {
    let constituents = vec![
        RecursionConstituent::new((0, ()), (1, ())),
        RecursionConstituent::new((1, ()), (0, ())),
    ];

    let mut helper = RecursionHelper::new(constituents);

    let disjoint_constituents = helper
        .remove_disjoint_set()
        .expect("a non-empty helper must yield a set");

    // Compare sorted `from` ids so the test does not depend on return order.
    let mut from_ids: Vec<_> = disjoint_constituents.iter().map(|c| c.from.0).collect();
    from_ids.sort_unstable();
    assert_eq!(from_ids, vec![0, 1]);

    assert!(helper.remove_disjoint_set().is_none());
}

/// Two independent cycles are returned as two separate sets, after which
/// the helper is exhausted.
#[test]
fn two_recursions() {
    let constituents = vec![
        RecursionConstituent::new((0, ()), (1, ())),
        RecursionConstituent::new((1, ()), (0, ())),
        RecursionConstituent::new((2, ()), (3, ())),
        RecursionConstituent::new((3, ()), (4, ())),
        RecursionConstituent::new((4, ()), (2, ())),
    ];

    let mut helper = RecursionHelper::new(constituents);

    // Drain the helper, recording the sorted `from` ids of every set so the
    // comparison is independent of the order in which sets and members are
    // returned.
    let mut groups = Vec::new();
    while let Some(disjoint_constituents) = helper.remove_disjoint_set() {
        let mut from_ids: Vec<_> = disjoint_constituents.iter().map(|c| c.from.0).collect();
        from_ids.sort_unstable();
        groups.push(from_ids);
    }
    groups.sort();

    assert_eq!(groups, vec![vec![0, 1], vec![2, 3, 4]]);
    assert!(helper.remove_disjoint_set().is_none());
}
1 change: 1 addition & 0 deletions crates/hir-analysis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ pub struct Jar(
ty::diagnostics::ImplTraitDefDiagAccumulator,
ty::diagnostics::ImplDefDiagAccumulator,
ty::diagnostics::FuncDefDiagAccumulator,
ty::diagnostics::AdtRecursionConstituentAccumulator,
);

pub trait HirAnalysisDb: salsa::DbWithJar<Jar> + HirDb {
Expand Down
Loading

0 comments on commit e3b2f62

Please sign in to comment.