Skip to content

Commit

Permalink
Implement backwards dataflow analysis to fix tests (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
julian-hartl authored Mar 24, 2024
1 parent e9e8cf1 commit be77b30
Show file tree
Hide file tree
Showing 26 changed files with 622 additions and 447 deletions.
50 changes: 10 additions & 40 deletions ir/crates/back/src/codegen/machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use smallvec::{SmallVec, smallvec};
use tracing::debug;

pub use abi::Abi;
use natrix_middle;
use natrix_middle::instruction::CmpOp;
use natrix_middle::ty::Type;
pub use module::Module;
Expand Down Expand Up @@ -1055,28 +1054,28 @@ impl<A: Abi> BasicBlock<A> {
}

#[derive(Debug)]
pub struct FunctionBuilder<B: Backend> {
function: Function<B::ABI>,
backend: B,
pub struct FunctionBuilder<TM: TargetMachine> {
function: Function<TM::Abi>,
backend: TM::Backend,
bb_mapping: FxHashMap<natrix_middle::cfg::BasicBlockId, BasicBlockId>,
}

impl<B: Backend> FunctionBuilder<B> {
impl<TM: TargetMachine> FunctionBuilder<TM> {
pub fn new() -> Self {
Self {
function: Function::new(Default::default()),
backend: B::new(),
backend: TM::backend(),
bb_mapping: FxHashMap::default(),
}
}

pub fn build(mut self, function: &mut natrix_middle::Function) -> Function<B::ABI> {
pub fn build(mut self, function: &mut natrix_middle::Function) -> Function<TM::Abi> {
self.function.name = function.name.clone();
self.function.return_ty_size = Size::from_ty(
&function.ret_ty
);
debug!("Building machine function for function {}", function.name);
let mut sel_dag_builder = selection_dag::Builder::<B::ABI>::new(&mut self.function);
let mut sel_dag_builder = selection_dag::Builder::new(&mut self.function);
let mut sel_dag = sel_dag_builder.build(function);
for bb in function.cfg.basic_block_ids_ordered() {
self.create_bb(bb);
Expand Down Expand Up @@ -1191,7 +1190,7 @@ impl<B: Backend> FunctionBuilder<B> {
let mut matching_pattern = None;
debug!("Matching patterns for node {:?}", op);

for pattern in B::patterns() {
for pattern in TM::Backend::patterns() {
let pattern_in = pattern.in_();
debug!("Checking {:?}", pattern_in);
debug!("Matching with {:?}", dag_node_pattern);
Expand Down Expand Up @@ -1276,49 +1275,20 @@ impl<B: Backend> FunctionBuilder<B> {
mbb
}

fn operand_to_matched_pattern_operand(&self, src: &Operand<<B as Backend>::ABI>) -> MatchedPatternOperand<<B as Backend>::ABI> {
fn operand_to_matched_pattern_operand(&self, src: &Operand<TM::Abi>) -> MatchedPatternOperand<TM::Abi> {
match src {
Operand::Reg(reg) => MatchedPatternOperand::Reg(reg.clone()),
Operand::Imm(imm) => MatchedPatternOperand::Imm(imm.clone()),
}
}

fn operand_to_pattern(&self, src: &Operand<<B as Backend>::ABI>) -> PatternInOperand {
fn operand_to_pattern(&self, src: &Operand<TM::Abi>) -> PatternInOperand {
match src {
Operand::Reg(reg) => PatternInOperand::Reg(reg.size(&self.function)),
Operand::Imm(imm) => PatternInOperand::Imm(imm.size),
}
}
}

#[cfg(test)]
mod tests {
use tracing_test::traced_test;

use crate::codegen::isa;
use crate::natrix_middle::cfg;
use crate::natrix_middle::cfg::{RetTerm, TerminatorKind};
use crate::natrix_middle::instruction::{Const, Op};
use crate::test_utils::create_test_function;

#[test]
#[traced_test]
fn test() {
let mut function = create_test_function();
let mut builder = cfg::Builder::new(&mut function);
let bb = builder.start_bb();
let (val1, _) = builder.op(None, Op::Const(Const::i32(323))).unwrap();
let (val2, _) = builder.op(None, Op::Const(Const::i32(90))).unwrap();
let (return_value, _) = builder.sub(None, Op::Value(val1), Op::Value(val2)).unwrap();
builder.end_bb(TerminatorKind::Ret(RetTerm::new(Op::Value(return_value))));
drop(builder);
let function_builder = super::FunctionBuilder::<isa::x86_64::Backend>::new();
let function = function_builder.build(&function);
println!("{:?}", function.basic_blocks);
println!("{}", function);
function.assemble();
}
}



2 changes: 1 addition & 1 deletion ir/crates/back/src/codegen/machine/module/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ impl<'module, TM: TargetMachine> Builder<'module, TM> {

pub fn build(mut self) -> Module<TM> {
for (_, function) in &mut self.module.functions {
let builder = FunctionBuilder::<TM::Backend>::new();
let builder = FunctionBuilder::<TM>::new();
self.mtbb.functions.push(
builder.build(function)
);
Expand Down
152 changes: 132 additions & 20 deletions ir/crates/back/src/codegen/register_allocator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,16 @@ impl Display for InstrUid {
}
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgPoint {
Read(InstrNr),
Write(InstrNr),
}

impl ProgPoint {
pub fn instr_nr(&self) -> InstrNr {
pub const fn instr_nr(&self) -> InstrNr {
match self {
ProgPoint::Read(nr) => *nr,
ProgPoint::Write(nr) => *nr,
Self::Write(nr) | Self::Read(nr) => *nr,
}
}
}
Expand All @@ -65,19 +64,25 @@ impl Default for ProgPoint {
}
}

impl PartialOrd for ProgPoint {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl Ord for ProgPoint {
fn cmp(&self, other: &Self) -> Ordering {
match (self, other) {
(ProgPoint::Read(a), ProgPoint::Read(b)) => a.cmp(b),
(ProgPoint::Write(a), ProgPoint::Write(b)) => a.cmp(b),
(ProgPoint::Read(a), ProgPoint::Write(b)) => {
(Self::Read(a), Self::Read(b)) |
(Self::Write(a), Self::Write(b)) => a.cmp(b),
(Self::Read(a), Self::Write(b)) => {
if a <= b {
Ordering::Less
} else {
Ordering::Greater
}
}
(ProgPoint::Write(a), ProgPoint::Read(b)) => {
(Self::Write(a), Self::Read(b)) => {
if a < b {
Ordering::Less
} else {
Expand Down Expand Up @@ -114,8 +119,21 @@ impl Lifetime {
self.start <= pp && pp <= self.end
}

pub fn overlaps(&self, other: &Self) -> bool {
self.start <= other.end && other.start <= self.end
/// Returns true if the two lifetimes overlap.
///
/// Two lifetimes l, j overlap if the intersection of their ranges is not empty.
/// This is the case iff.
/// - l.start <= j.end
/// and
/// - j.start <= l.end
///
/// Overlaps are **symmetric**.
pub fn are_overlapping(l: &Self, j: &Self) -> bool {
l.start <= j.end && j.start <= l.end
}

pub fn overlaps_with(&self, other: &Self) -> bool {
Self::are_overlapping(self, other)
}
}

Expand All @@ -128,11 +146,105 @@ impl Display for Lifetime {
}
}

#[cfg(test)]
mod prog_point_tests {
use super::*;

#[test]
fn ord_should_be_correct() {
let inputs = [
(ProgPoint::Read(0), ProgPoint::Read(1), Ordering::Less),
(ProgPoint::Read(0), ProgPoint::Write(1), Ordering::Less),
(ProgPoint::Write(0), ProgPoint::Read(1), Ordering::Less),
(ProgPoint::Write(0), ProgPoint::Write(1), Ordering::Less),
(ProgPoint::Read(0), ProgPoint::Read(0), Ordering::Equal),
(ProgPoint::Write(0), ProgPoint::Write(0), Ordering::Equal),
(ProgPoint::Read(0), ProgPoint::Write(0), Ordering::Less),
(ProgPoint::Write(0), ProgPoint::Read(0), Ordering::Greater),
];

for (a, b, expected) in inputs {
assert_eq!(a.cmp(&b), expected);
}
}
}

#[cfg(test)]
mod lifetime_tests {
use crate::codegen::register_allocator::{Lifetime, ProgPoint};

#[test]
fn test1() {
todo!()
fn lifetimes_overlap() {
let inputs = [
// Both lifetimes are the same
(
Lifetime::new(0, ProgPoint::Read(2)),
Lifetime::new(0, ProgPoint::Read(2)),
true,
),
// The second lifetime is within the first one
(
Lifetime::new(0, ProgPoint::Read(2)),
Lifetime::new(0, ProgPoint::Read(1)),
true,
),
// The lifetimes do not overlap
(
Lifetime::new(0, ProgPoint::Read(1)),
Lifetime::new(2, ProgPoint::Read(3)),
false,
),
// The lifetimes overlap at one point
(
Lifetime::new(0, ProgPoint::Write(2)),
Lifetime::new(2, ProgPoint::Read(3)),
true,
),
// The lifetimes are the same but with different ProgPoints
(
Lifetime::new(0, ProgPoint::Write(2)),
Lifetime::new(0, ProgPoint::Read(2)),
true,
),
];
for (l1, l2, should_overlap) in inputs {
// Overlaps are symmetric
assert_eq!(l1.overlaps_with(&l2), should_overlap, "{:?} and {:?} should overlap: {}", l1, l2, should_overlap);
assert_eq!(l2.overlaps_with(&l1), should_overlap, "{:?} and {:?} should overlap: {}", l2, l1, should_overlap);
}
}

#[test]
fn lifetimes_contain() {
let inputs = [
// The lifetime contains the ProgPoint
(
Lifetime::new(0, ProgPoint::Read(2)),
ProgPoint::Read(1),
true,
),
// The lifetime does not contain the ProgPoint
(
Lifetime::new(0, ProgPoint::Read(1)),
ProgPoint::Write(2),
false,
),
// The lifetime contains the ProgPoint at the interval end
(
Lifetime::new(0, ProgPoint::Write(2)),
ProgPoint::Write(2),
true,
),
// The lifetime does not contain the ProgPoint at the interval start
(
Lifetime::new(0, ProgPoint::Read(1)),
ProgPoint::Write(0),
true,
),
];
for (lifetime, pp, should_contain) in inputs {
assert_eq!(lifetime.contains(pp), should_contain, "{:?} should contain {:?}: {}", lifetime, pp, should_contain);
}
}
}

Expand Down Expand Up @@ -427,7 +539,7 @@ impl<'liveness, 'func, A: Abi, RegAlloc: RegAllocAlgorithm<'liveness, A>> Regist
}
Some(tied_to) => {
debug!("{vreg} is tied to {tied_to}. Trying to put it in the same register");
assert!(!lifetime.overlaps(&self.liveness_repr.lifetime(tied_to, &self.func)), "Tied register {tied_to} overlaps with {vreg}");
assert!(!lifetime.overlaps_with(&self.liveness_repr.lifetime(tied_to, &self.func)), "Tied register {tied_to} overlaps with {vreg}");
let allocated_reg = self.allocations.get_allocated_reg(tied_to);
match allocated_reg {
None => {
Expand Down Expand Up @@ -487,12 +599,12 @@ impl<'liveness, 'func, A: Abi, RegAlloc: RegAllocAlgorithm<'liveness, A>> Regist
self.func.params.iter().map(|param| self.func.vregs[*param].size)
).collect_vec();
for (arg, slot) in self.func.params.iter().copied().zip(slots) {
match slot {
Slot::Register(reg) => {
self.func.vregs[arg].fixed = Some(reg);
}
Slot::Stack => unimplemented!()
}
match slot {
Slot::Register(reg) => {
self.func.vregs[arg].fixed = Some(reg);
}
Slot::Stack => unimplemented!()
}
}
}
}
Expand Down Expand Up @@ -597,7 +709,7 @@ impl<A: Abi> Function<A> {
// liveins.insert(bb_id, undeclared_reg);
// current_intervals.get_mut(&undeclared_reg).unwrap().start = entry_pp;
// }
//
//
// for (reg, interval) in current_intervals {
// repr.reg_lifetimes[reg].insert(interval);
// }
Expand Down
Loading

0 comments on commit be77b30

Please sign in to comment.