Skip to content

Commit

Permalink
Refactor of how interface ports are handled
Browse files Browse the repository at this point in the history
  • Loading branch information
VonTum committed Jan 3, 2024
1 parent 2d8f728 commit 68c915b
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 107 deletions.
4 changes: 2 additions & 2 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ pub struct Module {
impl Module {
pub fn print_flattened_module(&self, linker : &Linker) {
println!("Interface:");
for port in &self.interface.interface_wires {
let port_direction = if port.is_input {"input"} else {"output"};
for (port_idx, port) in self.interface.interface_wires.iter().enumerate() {
let port_direction = if port_idx < self.interface.outputs_start {"input"} else {"output"};
let port_type = port.typ.to_string(linker);
let port_name = &port.port_name;
println!(" {port_direction} {port_type} {port_name} -> {:?}", port.wire_id);
Expand Down
12 changes: 6 additions & 6 deletions src/codegen_fallback.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{iter::zip, ops::Deref};

use crate::{ast::{Module, IdentifierType}, instantiation::{InstantiatedModule, RealWireDataSource, StateInitialValue, ConnectToPathElem}, linker::{NamedUUID, get_builtin_uuid}, typing::ConcreteType, tokenizer::get_token_type_name, flattening::{Instantiation, WireDeclaration}, value::Value};
use crate::{ast::{Module, IdentifierType}, instantiation::{InstantiatedModule, RealWireDataSource, StateInitialValue, ConnectToPathElem}, linker::{NamedUUID, get_builtin_uuid}, typing::ConcreteType, tokenizer::get_token_type_name, flattening::Instantiation, value::Value};

fn get_type_name_size(id : NamedUUID) -> u64 {
if id == get_builtin_uuid("int") {
Expand Down Expand Up @@ -56,9 +56,9 @@ pub fn gen_verilog_code(md : &Module, instance : &InstantiatedModule) -> String
assert!(!instance.errors.did_error(), "Module cannot have experienced an error");
let mut program_text : String = format!("module {}(\n\tinput clk, \n", md.link_info.name);
let submodule_interface = instance.interface.as_ref().unwrap();
for (port, real_port) in zip(&md.interface.interface_wires, submodule_interface) {
let wire = &instance.wires[real_port.id];
program_text.push_str(if port.is_input {"\tinput"} else {"\toutput /*mux_wire*/ reg"});
for (port_idx, (port, real_port)) in zip(md.interface.interface_wires.iter(), submodule_interface).enumerate() {
let wire = &instance.wires[*real_port];
program_text.push_str(if port_idx < md.interface.outputs_start {"\tinput"} else {"\toutput /*mux_wire*/ reg"});
program_text.push_str(&typ_to_verilog_array(&wire.typ));
program_text.push(' ');
program_text.push_str(&wire.name);
Expand Down Expand Up @@ -117,7 +117,7 @@ pub fn gen_verilog_code(md : &Module, instance : &InstantiatedModule) -> String
let Some(sm_interface) = &sm.instance.interface else {unreachable!()}; // Having an invalid interface in a submodule is an error! This should have been caught before!
for (port, wire) in zip(sm_interface, &sm.wires) {
program_text.push_str(",\n.");
program_text.push_str(&sm.instance.wires[port.id].name);
program_text.push_str(&sm.instance.wires[*port].name);
program_text.push('(');
program_text.push_str(&instance.wires[*wire].name);
program_text.push_str(")");
Expand All @@ -132,7 +132,7 @@ pub fn gen_verilog_code(md : &Module, instance : &InstantiatedModule) -> String
let output_name = w.name.deref();
match is_state {
StateInitialValue::Combinatorial => {
program_text.push_str(&format!("/*always_comb*/ always @(*) begin\n\t{output_name} <= 1'bX; // Not defined when not valid\n"));
program_text.push_str(&format!("/*always_comb*/ always @(*) begin\n\t{output_name} <= 1'bX; // Combinatorial wires are not defined when not valid\n"));
}
StateInitialValue::State{initial_value : _} => {
program_text.push_str(&format!("/*always_ff*/ always @(posedge clk) begin\n"));
Expand Down
144 changes: 72 additions & 72 deletions src/flattening.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,22 +45,27 @@ pub struct Connection {
pub condition : Option<FlatID>
}

#[derive(Debug,Clone,Copy)]
pub struct InterfacePort {
pub is_input : bool,
pub id : FlatID
}

#[derive(Debug)]
pub enum WireSource {
WireRead{from_wire : FlatID}, // Used to add a span to the reference of a wire.
//SubModuleOutput{submodule : FlatID, port_idx : usize},
UnaryOp{op : Operator, right : FlatID},
BinaryOp{op : Operator, left : FlatID, right : FlatID},
ArrayAccess{arr : FlatID, arr_idx : FlatID},
Constant{value : Value},
}

impl WireSource {
pub fn for_each_input_wire<F : FnMut(FlatID)>(&self, func : &mut F) {
match self {
&WireSource::WireRead { from_wire } => {func(from_wire)}
&WireSource::UnaryOp { op:_, right } => {func(right)}
&WireSource::BinaryOp { op:_, left, right } => {func(left); func(right)}
&WireSource::ArrayAccess { arr, arr_idx } => {func(arr); func(arr_idx)}
WireSource::Constant { value:_ } => {}
}
}
}

#[derive(Debug)]
pub struct WireInstance {
pub typ : Type,
Expand All @@ -85,9 +90,18 @@ impl WireDeclaration {
}
}

#[derive(Debug)]
pub struct SubModuleInstance {
pub module_uuid : NamedUUID,
pub name : Box<str>,
pub typ_span : Span,
pub outputs_start : usize,
pub local_wires : Box<[FlatID]>
}

#[derive(Debug)]
pub enum Instantiation {
SubModule{module_uuid : NamedUUID, name : Box<str>, typ_span : Span, interface_wires : Vec<InterfacePort>},
SubModule(SubModuleInstance),
WireDeclaration(WireDeclaration),
Wire(WireInstance),
Connection(Connection),
Expand Down Expand Up @@ -174,26 +188,20 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> {
Some(())
}
fn alloc_module_interface(&self, name : Box<str>, module : &Module, module_uuid : NamedUUID, typ_span : Span) -> Instantiation {
let interface_wires = module.interface.interface_wires.iter().enumerate().map(|(port_idx, port)| {
let identifier_type = if port.is_input {
IdentifierType::Input
} else {
IdentifierType::Output
};
let id = self.instantiations.alloc(Instantiation::WireDeclaration(WireDeclaration{
let local_wires : Vec<FlatID> = module.interface.interface_wires.iter().enumerate().map(|(port_idx, port)| {
self.instantiations.alloc(Instantiation::WireDeclaration(WireDeclaration{
typ: port.typ.clone(),
typ_span,
read_only : !port.is_input,
identifier_type,
name : format!("{}_{}", &module.link_info.name, &port.port_name).into_boxed_str(),
read_only : port_idx >= module.interface.outputs_start,
identifier_type : IdentifierType::Local,
name : format!("{}_{}", &name, &port.port_name).into_boxed_str(),
name_token : None
}));
InterfacePort{is_input : port.is_input, id}
}))
}).collect();

Instantiation::SubModule{name, module_uuid, typ_span, interface_wires}
Instantiation::SubModule(SubModuleInstance{name, module_uuid, typ_span, outputs_start : module.interface.outputs_start, local_wires : local_wires.into_boxed_slice()})
}
fn desugar_func_call(&self, func_and_args : &[SpanExpression], closing_bracket_pos : usize, condition : Option<FlatID>) -> Option<(&Module, &[InterfacePort])> {
fn desugar_func_call(&self, func_and_args : &[SpanExpression], closing_bracket_pos : usize, condition : Option<FlatID>) -> Option<(&Module, &[FlatID])> {
let (name_expr, name_expr_span) = &func_and_args[0]; // Function name is always there
let func_instantiation_id = match name_expr {
Expression::Named(LocalOrGlobal::Local(l)) => {
Expand All @@ -212,9 +220,9 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> {
}
};
let func_instantiation = &self.instantiations[func_instantiation_id];
let Instantiation::SubModule{module_uuid, name : _, typ_span : _, interface_wires} = func_instantiation else {unreachable!("It should be proven {func_instantiation:?} was a Module!");};
let Instantiation::SubModule(SubModuleInstance{module_uuid, name : _, typ_span : _, outputs_start:_, local_wires}) = func_instantiation else {unreachable!("It should be proven {func_instantiation:?} was a Module!");};
let Named::Module(md) = &self.linker.links.globals[*module_uuid] else {unreachable!("UUID Should be a module!");};
let (inputs, output_range) = md.interface.get_function_sugar_inputs_outputs();
let (inputs, output_range) = md.interface.func_call_syntax_interface();

let mut args = &func_and_args[1..];

Expand All @@ -239,12 +247,12 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> {
if self.typecheck(arg_read_side, &md.interface.interface_wires[field].typ, "submodule output") == None {
continue;
}
let func_input_port = &interface_wires[field];
self.create_connection(Connection { num_regs: 0, from: arg_read_side, to: ConnectionWrite::simple(func_input_port.id, *name_expr_span), condition });
let func_input_port = &local_wires[field];
self.create_connection(Connection { num_regs: 0, from: arg_read_side, to: ConnectionWrite::simple(*func_input_port, *name_expr_span), condition });
}
}

Some((md, &interface_wires[output_range]))
Some((md, &local_wires[output_range]))
}
fn flatten_single_expr(&self, (expr, expr_span) : &SpanExpression, condition : Option<FlatID>) -> Option<FlatID> {
let span = *expr_span; // for more compact constructors
Expand Down Expand Up @@ -302,7 +310,7 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> {
return None;
}

outputs[0].id
outputs[0]
}
};
Some(single_connection_side)
Expand Down Expand Up @@ -430,8 +438,8 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> {
let Some(write_side) = self.flatten_assignable_expr(&to_i.expr, condition) else {return;};

// temporary
let module_port_wire_decl = self.instantiations[field.id].extract_wire_declaration();
let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : module_port_wire_decl.identifier_type == IdentifierType::Generative, span : *func_span, inst : WireSource::WireRead { from_wire: field.id }}));
let module_port_wire_decl = self.instantiations[*field].extract_wire_declaration();
let module_port_proxy = self.instantiations.alloc(Instantiation::Wire(WireInstance{typ : module_port_wire_decl.typ.clone(), is_compiletime : module_port_wire_decl.identifier_type == IdentifierType::Generative, span : *func_span, inst : WireSource::WireRead { from_wire: *field }}));
self.create_connection(Connection{num_regs : to_i.num_regs, from: module_port_proxy, to: write_side, condition});
}
},
Expand All @@ -457,42 +465,30 @@ impl<'l, 'm, 'fl> FlatteningContext<'l, 'm, 'fl> {
#[derive(Debug)]
pub struct FlattenedInterfacePort {
pub wire_id : FlatID,
pub is_input : bool,
pub typ : Type,
pub port_name : Box<str>,
pub span : Span
}

#[derive(Debug, Default)]
pub struct FlattenedInterface {
pub interface_wires : Vec<FlattenedInterfacePort>, // Indexed by FieldID
pub interface_wires : Box<[FlattenedInterfacePort]>, // Ordered such that all inputs come first, then all outputs
pub outputs_start : usize
}

impl FlattenedInterface {
pub fn new() -> Self {
FlattenedInterface { interface_wires: Vec::new() }
FlattenedInterface { interface_wires: Box::new([]), outputs_start : 0 }
}
pub fn get_function_sugar_inputs_outputs(&self) -> (Range<FieldID>, Range<FieldID>) {
let mut last_output = self.interface_wires.len() - 1;

while last_output > 0 {
last_output -= 1;
if self.interface_wires[last_output].is_input {
last_output += 1;
break;
}
}

let mut last_input = last_output - 1;
while last_input > 0 {
last_input -= 1;
if !self.interface_wires[last_input].is_input {
last_input += 1;
break;
}
}

(last_input..last_output, last_output..self.interface_wires.len())
// Todo, just treat all inputs and outputs as function call interface
pub fn func_call_syntax_interface(&self) -> (Range<FieldID>, Range<FieldID>) {
(0..self.outputs_start, self.outputs_start..self.interface_wires.len())
}
pub fn inputs(&self) -> &[FlattenedInterfacePort] {
&self.interface_wires[..self.outputs_start]
}
pub fn outputs(&self) -> &[FlattenedInterfacePort] {
&self.interface_wires[self.outputs_start..]
}
}

Expand All @@ -515,8 +511,6 @@ impl FlattenedModule {
Must be further processed by flatten, but this requires all modules to have been Initial Flattened for dependency resolution
*/
pub fn initialize_interfaces(linker : &Linker, module : &Module) -> (FlattenedInterface, FlattenedModule, FlatAlloc<Option<FlatID>, DeclIDMarker>) {
let mut interface = FlattenedInterface::new();

let flat_mod = FlattenedModule {
instantiations: ListAllocator::new(),
errors: ErrorCollector::new(module.link_info.file)
Expand All @@ -530,6 +524,8 @@ impl FlattenedModule {
module,
};

let mut inputs = Vec::new();
let mut outputs = Vec::new();
for (decl_id, decl) in &module.declarations {
let is_input = match decl.identifier_type {
IdentifierType::Input => true,
Expand All @@ -546,11 +542,22 @@ impl FlattenedModule {
name : decl.name.clone(),
name_token : Some(decl.name_token)
}));

let port = FlattenedInterfacePort { wire_id, typ, port_name: decl.name.clone(), span: decl.span };
if is_input {
inputs.push(port);
} else {
outputs.push(port);
}

interface.interface_wires.push(FlattenedInterfacePort { wire_id, is_input, typ, port_name: decl.name.clone(), span: decl.span });
context.decl_to_flat_map[decl_id] = Some(wire_id);
}

let outputs_start = inputs.len();
inputs.reserve(outputs.len());
inputs.append(&mut outputs);
let interface = FlattenedInterface{interface_wires: inputs.into_boxed_slice(), outputs_start};

let decl_to_flat_map = context.decl_to_flat_map;
(interface, flat_mod, decl_to_flat_map)
}
Expand Down Expand Up @@ -589,11 +596,9 @@ impl FlattenedModule {

let mut wire_to_explore_queue : Vec<FlatID> = Vec::new();

for port in &md.interface.interface_wires {
if !port.is_input {
is_instance_used_map[port.wire_id] = true;
wire_to_explore_queue.push(port.wire_id);
}
for port in md.interface.outputs() {
is_instance_used_map[port.wire_id] = true;
wire_to_explore_queue.push(port.wire_id);
}

println!("Pre Explore");
Expand All @@ -611,17 +616,12 @@ impl FlattenedModule {
match &self.instantiations[item] {
Instantiation::WireDeclaration(_) => {}
Instantiation::Wire(wire) => {
match &wire.inst {
WireSource::WireRead{from_wire} => {
func(*from_wire);
}
_other => {}
}
wire.inst.for_each_input_wire(&mut func);
}
Instantiation::SubModule{module_uuid : _, name : _, typ_span : _, interface_wires} => {
for port in interface_wires {
if port.is_input {
func(port.id);
Instantiation::SubModule(submodule) => {
for (port_id, port) in submodule.local_wires.iter().enumerate() {
if port_id < submodule.outputs_start {
func(*port);
}
}
}
Expand Down
Loading

0 comments on commit 68c915b

Please sign in to comment.