Skip to content

Commit

Permalink
Add HM for Concrete Typing
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Nov 13, 2024
1 parent cd4f5bc commit a3716d0
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 203 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ In this example, we create a memory block with a read port and a write port. Thi
### Typing & Inference
- [x] Hindley-Milner typing for Abstract Types
- [x] Hindley-Milner typing for Domain Types
- [ ] Hindley-Milner typing for Concrete Types
- [x] Hindley-Milner typing for Concrete Types
- [ ] Template Type Inference
- [ ] Generative Parameter Inference
- [ ] Latency Count Inference
Expand Down
125 changes: 62 additions & 63 deletions src/instantiation/concrete_typecheck.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::typing::{concrete_type::ConcreteType, type_inference::FailedUnification};
use crate::typing::{concrete_type::{ConcreteType, BOOL_CONCRETE_TYPE, INT_CONCRETE_TYPE}, type_inference::FailedUnification};

use super::*;

Expand All @@ -7,69 +7,83 @@ use crate::typing::type_inference::HindleyMilner;
impl<'fl, 'l> InstantiationContext<'fl, 'l> {
fn walk_type_along_path(
&self,
mut cur_typ: ConcreteType,
path: &[RealWirePathElem],
mut current_type_in_progress: ConcreteType,
path: &[RealWirePathElem]
) -> ConcreteType {
for p in path {
let typ_after_applying_array = ConcreteType::Unknown(self.type_substitutor.alloc());
match p {
RealWirePathElem::ArrayAccess {
span: _,
idx_wire: _,
} => {
cur_typ = cur_typ.down_array().clone();
RealWirePathElem::ArrayAccess {span: _, idx_wire: _} => { // TODO #28 integer size <-> array bound check
let arr_size = ConcreteType::Unknown(self.type_substitutor.alloc());
let arr_box = Box::new((typ_after_applying_array.clone(), arr_size));
self.type_substitutor.unify_must_succeed(&current_type_in_progress, &ConcreteType::Array(arr_box));
current_type_in_progress = typ_after_applying_array;
}
}
}

cur_typ
current_type_in_progress
}

pub fn typecheck(&mut self) {
fn make_array_of(&self, concrete_typ: ConcreteType) -> ConcreteType {
ConcreteType::Array(Box::new((concrete_typ, ConcreteType::Unknown(self.type_substitutor.alloc()))))
}

fn typecheck_all_wires(&self) {
for this_wire_id in self.wires.id_range() {
let this_wire = &self.wires[this_wire_id];
let span = self.md.get_instruction_span(this_wire.original_instruction);
span.debug();

match &this_wire.source {
RealWireDataSource::ReadOnly => {}
RealWireDataSource::Multiplexer {
is_state: _,
sources: _,
} => {} // Do muxes later.
RealWireDataSource::Multiplexer { is_state, sources } => {
if let Some(is_state) = is_state {
assert!(is_state.is_of_type(&this_wire.typ));
}
for s in sources {
let source_typ = &self.wires[s.from.from].typ;
let destination_typ = self.walk_type_along_path(self.wires[this_wire_id].typ.clone(), &s.to_path);
self.type_substitutor.unify_report_error(&destination_typ, &source_typ, span, "write wire access");
}
}
&RealWireDataSource::UnaryOp { op, right } => {
let right_typ = self.wires[right].typ.clone();
self.wires[this_wire_id]
.typ
.typecheck_concrete_unary_operator(
op,
&right_typ,
span,
&self.linker.types,
&self.errors,
);
// TODO overloading
let (input_typ, output_typ) = match op {
UnaryOperator::Not => (BOOL_CONCRETE_TYPE, BOOL_CONCRETE_TYPE),
UnaryOperator::Negate => (INT_CONCRETE_TYPE, INT_CONCRETE_TYPE),
UnaryOperator::And | UnaryOperator::Or | UnaryOperator::Xor => (self.make_array_of(BOOL_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
UnaryOperator::Sum | UnaryOperator::Product => (self.make_array_of(INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
};

self.type_substitutor.unify_report_error(&self.wires[right].typ, &input_typ, span, "unary input");
self.type_substitutor.unify_report_error(&self.wires[this_wire_id].typ, &output_typ, span, "unary output");
}
&RealWireDataSource::BinaryOp { op, left, right } => {
let left_typ = self.wires[left].typ.clone();
let right_typ = self.wires[right].typ.clone();
self.wires[this_wire_id]
.typ
.typecheck_concrete_binary_operator(
op,
&left_typ,
&right_typ,
span,
&self.linker.types,
&self.errors,
);
// TODO overloading
let ((in_left, in_right), out) = match op {
BinaryOperator::And => ((BOOL_CONCRETE_TYPE, BOOL_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Or => ((BOOL_CONCRETE_TYPE, BOOL_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Xor => ((BOOL_CONCRETE_TYPE, BOOL_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Add => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Subtract => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Multiply => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Divide => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Modulo => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Equals => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::NotEquals => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::GreaterEq => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Greater => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::LesserEq => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Lesser => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
};
self.type_substitutor.unify_report_error(&self.wires[this_wire_id].typ, &out, span, "binary output");
self.type_substitutor.unify_report_error(&self.wires[left].typ, &in_left, span, "binary left");
self.type_substitutor.unify_report_error(&self.wires[right].typ, &in_right, span, "binary right");
}
RealWireDataSource::Select { root, path } => {
let found_typ = self.walk_type_along_path(self.wires[*root].typ.clone(), path);
self.wires[this_wire_id].typ.check_or_update_type(
&found_typ,
span,
&self.linker.types,
&self.errors,
);
self.type_substitutor.unify_report_error(&found_typ, &self.wires[this_wire_id].typ, span, "wire access");
}
RealWireDataSource::Constant { value } => {
assert!(
Expand All @@ -79,27 +93,6 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {
}
};
}

// Do typechecking of Multiplexers afterwards, because typechecker isn't so smart right now.
for this_wire_id in self.wires.id_range() {
let this_wire = &self.wires[this_wire_id];
let span = self.md.get_instruction_span(this_wire.original_instruction);
span.debug();

if let RealWireDataSource::Multiplexer { is_state, sources } = &this_wire.source {
if let Some(is_state) = is_state {
assert!(is_state.is_of_type(&this_wire.typ));
}
for s in sources {
let source_typ = &self.wires[s.from.from].typ;
let destination_typ =
self.walk_type_along_path(self.wires[this_wire_id].typ.clone(), &s.to_path);
destination_typ.check_type(&source_typ, span, &self.linker.types, &self.errors);
}
};
}

self.finalize();
}

fn finalize(&mut self) {
Expand Down Expand Up @@ -128,4 +121,10 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {
);
}
}

pub fn typecheck(&mut self) {
self.typecheck_all_wires();

self.finalize();
}
}
142 changes: 4 additions & 138 deletions src/typing/concrete_type.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
use crate::flattening::StructType;
use crate::prelude::*;

use std::ops::{Deref, Index};
use std::ops::Deref;

use crate::linker::get_builtin_type;
use crate::{
flattening::{BinaryOperator, UnaryOperator},
value::Value,
};
use crate::
value::Value
;

use super::type_inference::ConcreteTypeVariableID;

Expand All @@ -22,63 +20,6 @@ pub enum ConcreteType {
Unknown(ConcreteTypeVariableID)
}

/// Panics on Type Errors that should have been caught by [AbstractType]
///
/// TODO Add checks for array sizes being equal etc.
pub fn get_unary_operator_expected_output(
op: UnaryOperator,
input_typ: &ConcreteType,
) -> ConcreteType {
let gather_type = match op {
UnaryOperator::Not => {
assert_eq!(*input_typ, BOOL_CONCRETE_TYPE);
return BOOL_CONCRETE_TYPE;
}
UnaryOperator::Negate => {
assert_eq!(*input_typ, INT_CONCRETE_TYPE);
return INT_CONCRETE_TYPE;
}
UnaryOperator::And => BOOL_CONCRETE_TYPE,
UnaryOperator::Or => BOOL_CONCRETE_TYPE,
UnaryOperator::Xor => BOOL_CONCRETE_TYPE,
UnaryOperator::Sum => INT_CONCRETE_TYPE,
UnaryOperator::Product => INT_CONCRETE_TYPE,
};
assert_eq!(input_typ.down_array(), &gather_type);
gather_type
}

/// Panics on Type Errors that should have been caught by [AbstractType]
///
/// TODO Add checks for array sizes being equal etc.
pub fn get_binary_operator_expected_output(
op: BinaryOperator,
left_typ: &ConcreteType,
right_typ: &ConcreteType,
) -> ConcreteType {
let ((in_left, in_right), out) = match op {
BinaryOperator::And => ((BOOL_CONCRETE_TYPE, BOOL_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Or => ((BOOL_CONCRETE_TYPE, BOOL_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Xor => ((BOOL_CONCRETE_TYPE, BOOL_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Add => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Subtract => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Multiply => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Divide => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Modulo => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), INT_CONCRETE_TYPE),
BinaryOperator::Equals => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::NotEquals => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::GreaterEq => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Greater => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::LesserEq => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
BinaryOperator::Lesser => ((INT_CONCRETE_TYPE, INT_CONCRETE_TYPE), BOOL_CONCRETE_TYPE),
};

assert_eq!(*left_typ, in_left);
assert_eq!(*right_typ, in_right);

out
}

impl ConcreteType {
#[track_caller]
pub fn unwrap_value(&self) -> &Value {
Expand All @@ -94,79 +35,4 @@ impl ConcreteType {
let (sub, _sz) = arr_box.deref();
sub
}

pub fn type_compare(&self, found: &ConcreteType) -> bool {
match (self, found) {
(ConcreteType::Named(exp), ConcreteType::Named(fnd)) => exp == fnd,
(ConcreteType::Array(exp), ConcreteType::Array(fnd)) => {
let (target_arr_typ, target_arr_size) = exp.deref();
let (found_arr_typ, found_arr_size) = fnd.deref();
target_arr_typ.type_compare(found_arr_typ)
&& target_arr_size.type_compare(found_arr_size)
}
(ConcreteType::Value(lv), ConcreteType::Value(rv)) => lv == rv,
(ConcreteType::Unknown(_), _) | (_, ConcreteType::Unknown(_)) => {
todo!("Type Unification {self:?} {found:?}")
}
_ => false,
}
}
pub fn check_type<TypVec: Index<TypeUUID, Output = StructType>>(
&self,
source_type: &ConcreteType,
span: Span,
linker_types: &TypVec,
errors: &ErrorCollector,
) {
if !self.type_compare(source_type) {
errors.error(
span,
format!(
"Concrete Type Error! Expected {} but found {}",
self.to_string(linker_types),
source_type.to_string(linker_types)
),
);
}
}

pub fn check_or_update_type<TypVec: Index<TypeUUID, Output = StructType>>(
&mut self,
source_type: &ConcreteType,
span: Span,
linker_types: &TypVec,
errors: &ErrorCollector,
) {
if let ConcreteType::Unknown(_) = self {
*self = source_type.clone();
} else {
self.check_type(source_type, span, linker_types, errors);
}
}

pub fn typecheck_concrete_unary_operator<TypVec: Index<TypeUUID, Output = StructType>>(
&mut self,
op: UnaryOperator,
input_typ: &ConcreteType,
span: Span,
linker_types: &TypVec,
errors: &ErrorCollector,
) {
let expected = get_unary_operator_expected_output(op, input_typ);

self.check_or_update_type(&expected, span, linker_types, errors);
}
pub fn typecheck_concrete_binary_operator<TypVec: Index<TypeUUID, Output = StructType>>(
&mut self,
op: BinaryOperator,
left_typ: &ConcreteType,
right_typ: &ConcreteType,
span: Span,
linker_types: &TypVec,
errors: &ErrorCollector,
) {
let expected = get_binary_operator_expected_output(op, left_typ, right_typ);

self.check_or_update_type(&expected, span, linker_types, errors);
}
}
3 changes: 2 additions & 1 deletion src/typing/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ impl HindleyMilner<ConcreteTypeVariableIDMarker> for ConcreteType {

fn unify_all_args<F : FnMut(&Self, &Self) -> bool>(left : &Self, right : &Self, unify : &mut F) -> bool {
match (left, right) {
(ConcreteType::Named(na), ConcreteType::Named(nb)) => {assert!(*na == *nb); true}, // Already covered by get_hm_info
(ConcreteType::Named(na), ConcreteType::Named(nb)) => {assert!(*na == *nb); true} // Already covered by get_hm_info
(ConcreteType::Value(v_1), ConcreteType::Value(v_2)) => {assert!(*v_1 == *v_2); true} // Already covered by get_hm_info
(ConcreteType::Array(arr_typ_1), ConcreteType::Array(arr_typ_2)) => {
let (arr_typ_1_arr, arr_typ_1_sz) = arr_typ_1.deref();
let (arr_typ_2_arr, arr_typ_2_sz) = arr_typ_2.deref();
Expand Down

0 comments on commit a3716d0

Please sign in to comment.