Skip to content

Commit

Permalink
FUCK YEAH, TEMPLATE INFERENCE! :tad
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Nov 19, 2024
1 parent 66f82e4 commit d200f24
Show file tree
Hide file tree
Showing 8 changed files with 234 additions and 97 deletions.
4 changes: 2 additions & 2 deletions src/flattening/typechecking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,8 +580,8 @@ pub fn apply_types(
);
}
for FailedUnification{mut found, mut expected, span, context} in type_checker.domain_substitutor.extract_errors() {
found.fully_substitute(&type_checker.domain_substitutor).unwrap();
expected.fully_substitute(&type_checker.domain_substitutor).unwrap();
assert!(found.fully_substitute(&type_checker.domain_substitutor));
assert!(expected.fully_substitute(&type_checker.domain_substitutor));

let expected_name = format!("{expected:?}");
let found_name = format!("{found:?}");
Expand Down
84 changes: 82 additions & 2 deletions src/instantiation/concrete_typecheck.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@

use crate::typing::template::ConcreteTemplateArg;
use std::ops::Deref;

use crate::flattening::{DeclarationPortInfo, WireReferenceRoot, WireSource, WrittenType};
use crate::linker::LinkInfo;
use crate::typing::template::{ConcreteTemplateArg, HowDoWeKnowTheTemplateArg};
use crate::typing::{
concrete_type::{ConcreteType, BOOL_CONCRETE_TYPE, INT_CONCRETE_TYPE},
delayed_constraint::{DelayedConstraint, DelayedConstraintStatus, DelayedConstraintsList},
Expand Down Expand Up @@ -103,7 +107,7 @@ impl<'fl, 'l> InstantiationContext<'fl, 'l> {

fn finalize(&mut self) {
for (_id, w) in &mut self.wires {
if let Err(()) = w.typ.fully_substitute(&self.type_substitutor) {
if w.typ.fully_substitute(&self.type_substitutor) == false {
let typ_as_str = w.typ.to_string(&self.linker.types);

let span = self.md.get_instruction_span(w.original_instruction);
Expand Down Expand Up @@ -146,8 +150,84 @@ struct SubmoduleTypecheckConstraint {
sm_id: SubModuleID
}

/// Part of Template Value Inference.
///
/// Specifically, for code like this:
///
/// ```sus
/// module add_all #(int Size) {
/// input int[Size] arr // We're targeting the 'Size' within the array size
/// output int total
/// }
/// ```
fn can_wire_can_be_value_inferred(link_info: &LinkInfo, flat_wire: FlatID) -> Option<TemplateID> {
let wire = link_info.instructions[flat_wire].unwrap_wire();
let WireSource::WireRef(wr) = &wire.source else {return None};
if !wr.path.is_empty() {return None} // Must be a plain, no fuss reference to a de
let WireReferenceRoot::LocalDecl(wire_declaration, _span) = &wr.root else {return None};
let template_arg_decl = link_info.instructions[*wire_declaration].unwrap_wire_declaration();
let DeclarationPortInfo::GenerativeInput(template_id) = &template_arg_decl.is_port else {return None};
Some(*template_id)
}

fn try_to_attach_value_to_template_arg(template_wire_referernce: FlatID, found_value: &ConcreteType, template_args: &mut ConcreteTemplateArgs, submodule_link_info: &LinkInfo) {
let ConcreteType::Value(v) = found_value else {return}; // We don't have a value to assign
if let Some(template_id) = can_wire_can_be_value_inferred(submodule_link_info, template_wire_referernce) {
if let ConcreteTemplateArg::NotProvided = &template_args[template_id] {
template_args[template_id] = ConcreteTemplateArg::Value(TypedValue::from_value(v.clone()), HowDoWeKnowTheTemplateArg::Inferred)
}
}
}

fn infer_parameters_by_walking_type(port_wr_typ: &WrittenType, connected_typ: &ConcreteType, template_args: &mut ConcreteTemplateArgs, submodule_link_info: &LinkInfo) {
match port_wr_typ {
WrittenType::Error(_) => {} // Can't continue, bad written type
WrittenType::Named(_) => {} // Seems we've run out of type to check
WrittenType::Array(_span, written_arr_box) => {
let ConcreteType::Array(concrete_arr_box) = connected_typ else {return}; // Can't continue, type not worked out. TODO should we seed concrete types with derivates from AbstractTypes?
let (written_arr, written_size_var, _) = written_arr_box.deref();
let (concrete_arr, concrete_size) = concrete_arr_box.deref();

infer_parameters_by_walking_type(written_arr, concrete_arr, template_args, submodule_link_info); // Recurse down

try_to_attach_value_to_template_arg(*written_size_var, concrete_size, template_args, submodule_link_info); // Potential place for template inference!
}
WrittenType::TemplateVariable(_span, template_id) => {
if !connected_typ.contains_unknown() {
if let ConcreteTemplateArg::NotProvided = &template_args[*template_id] {
template_args[*template_id] = ConcreteTemplateArg::Type(connected_typ.clone(), HowDoWeKnowTheTemplateArg::Inferred)
}
}
}
}
}

impl SubmoduleTypecheckConstraint {
fn try_infer_parameters(&mut self, context: &mut InstantiationContext) {
let sm = &mut context.submodules[self.sm_id];

let sub_module = &context.linker.modules[sm.module_uuid];

for (id, p) in sm.port_map.iter_valids() {
let wire = &context.wires[p.maps_to_wire];

let mut wire_typ_clone = wire.typ.clone();
wire_typ_clone.fully_substitute(&context.type_substitutor);

let port_decl_instr = sub_module.ports[id].declaration_instruction;
let port_decl = sub_module.link_info.instructions[port_decl_instr].unwrap_wire_declaration();

infer_parameters_by_walking_type(&port_decl.typ_expr, &wire_typ_clone, &mut sm.template_args, &sub_module.link_info);
}
}

}

impl DelayedConstraint<InstantiationContext<'_, '_>> for SubmoduleTypecheckConstraint {
fn try_apply(&mut self, context : &mut InstantiationContext) -> DelayedConstraintStatus {
// Try to infer template arguments based on the connections to the ports of the module
self.try_infer_parameters(context);

let sm = &context.submodules[self.sm_id];

let submod_instr = context.md.link_info.instructions[sm.original_instruction].unwrap_submodule();
Expand Down
4 changes: 2 additions & 2 deletions src/typing/abstract_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,12 @@ impl TypeUnifier {

pub fn finalize_domain_type(&mut self, typ_domain: &mut DomainType) {
use super::type_inference::HindleyMilner;
typ_domain.fully_substitute(&self.domain_substitutor).unwrap();
assert!(typ_domain.fully_substitute(&self.domain_substitutor) == true);
}

pub fn finalize_abstract_type(&mut self, types: &ArenaAllocator<StructType, TypeUUIDMarker>, typ: &mut AbstractType, span: Span, errors: &ErrorCollector) {
use super::type_inference::HindleyMilner;
if typ.fully_substitute(&self.type_substitutor).is_err() {
if typ.fully_substitute(&self.type_substitutor) == false {
let typ_as_string = typ.to_string(types, &self.template_type_names);
errors.error(span, format!("Could not fully figure out the type of this object. {typ_as_string}"));
}
Expand Down
11 changes: 11 additions & 0 deletions src/typing/concrete_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,15 @@ impl ConcreteType {
let (sub, _sz) = arr_box.deref();
sub
}
pub fn contains_unknown(&self) -> bool {
match self {
ConcreteType::Named(_) => false,
ConcreteType::Value(_) => false,
ConcreteType::Array(arr_box) => {
let (arr_arr, arr_size) = arr_box.deref();
arr_arr.contains_unknown() || arr_size.contains_unknown()
}
ConcreteType::Unknown(_) => true,
}
}
}
29 changes: 15 additions & 14 deletions src/typing/type_inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@ pub trait HindleyMilner<VariableIDMarker: UUIDMarker> : Sized {
/// This is never called by the user, only by [TypeSubstitutor::unify]
fn unify_all_args<F : FnMut(&Self, &Self) -> bool>(left : &Self, right : &Self, unify : &mut F) -> bool;

/// Has to be implemented per
/// Has to be implemented separately per type
///
/// Returns Ok(()) when no variables remain
fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, VariableIDMarker>) -> Result<(), ()>;
/// Returns true when no Unknowns remain
fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, VariableIDMarker>) -> bool;
}


Expand Down Expand Up @@ -205,15 +205,15 @@ impl HindleyMilner<TypeVariableIDMarker> for AbstractType {
}
}

fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, TypeVariableIDMarker>) -> Result<(), ()> {
fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, TypeVariableIDMarker>) -> bool {
match self {
AbstractType::Template(_) => Ok(()), // Template Name is included in get_hm_info
AbstractType::Named(_) => Ok(()), // Name is included in get_hm_info
AbstractType::Named(_) | AbstractType::Template(_) => true, // Template Name & Name is included in get_hm_info
AbstractType::Array(arr_typ) => {
arr_typ.fully_substitute(substitutor)
},
AbstractType::Unknown(var) => {
*self = substitutor.substitution_map[var.get_hidden_value()].get().ok_or(())?.clone();
let Some(replacement) = substitutor.substitution_map[var.get_hidden_value()].get() else {return false};
*self = replacement.clone();
self.fully_substitute(substitutor)
}
}
Expand All @@ -236,10 +236,10 @@ impl HindleyMilner<DomainVariableIDMarker> for DomainType {
true
}

/// For domains, always returns Ok(()). Or rather it should, since any leftover unconnected domains should be assigned an ID of their own by the type checker
fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, DomainVariableIDMarker>) -> Result<(), ()> {
/// For domains, always returns true. Or rather it should, since any leftover unconnected domains should be assigned an ID of their own by the type checker
fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, DomainVariableIDMarker>) -> bool {
match self {
DomainType::Generative | DomainType::Physical(_) => Ok(()), // Do nothing, These are done already
DomainType::Generative | DomainType::Physical(_) => true, // Do nothing, These are done already
DomainType::DomainVariable(var) => {
*self = substitutor.substitution_map[var.get_hidden_value()].get().expect("It's impossible for domain variables to remain, as any unset domain variable would have been replaced with a new physical domain").clone();
self.fully_substitute(substitutor)
Expand Down Expand Up @@ -282,16 +282,17 @@ impl HindleyMilner<ConcreteTypeVariableIDMarker> for ConcreteType {
}
}

fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, ConcreteTypeVariableIDMarker>) -> Result<(), ()> {
fn fully_substitute(&mut self, substitutor: &TypeSubstitutor<Self, ConcreteTypeVariableIDMarker>) -> bool {
match self {
ConcreteType::Named(_) | ConcreteType::Value(_) => Ok(()), // Already included in get_hm_info
ConcreteType::Named(_) | ConcreteType::Value(_) => true, // Don't need to do anything, this is already final
ConcreteType::Array(arr_typ) => {
let (arr_typ, arr_sz) = arr_typ.deref_mut();
arr_typ.fully_substitute(substitutor)?;
arr_typ.fully_substitute(substitutor) &&
arr_sz.fully_substitute(substitutor)
},
ConcreteType::Unknown(var) => {
*self = substitutor.substitution_map[var.get_hidden_value()].get().ok_or(())?.clone();
let Some(replacement) = substitutor.substitution_map[var.get_hidden_value()].get() else {return false};
*self = replacement.clone();
self.fully_substitute(substitutor)
}
}
Expand Down
8 changes: 7 additions & 1 deletion test.sus
Original file line number Diff line number Diff line change
Expand Up @@ -847,8 +847,14 @@ module replicate #(T, int NUM_REPLS) {
}

module use_replicate {
replicate #(NUM_REPLS: 50, NUM_REPLS: 30, T: type bool) a
//replicate #(NUM_REPLS: 50, NUM_REPLS: 30, T: type bool) a
replicate #(NUM_REPLS: 20, T: type int[30]) b
replicate c

int val = 3

c.data = val
int[30] out = c.result
}

module permute_t #(T, int SIZE, int[SIZE] SOURCES) {
Expand Down
Loading

0 comments on commit d200f24

Please sign in to comment.