Skip to content

Commit

Permalink
support regions in canonical format
Browse files Browse the repository at this point in the history
  • Loading branch information
vaivaswatha committed Jan 2, 2025
1 parent 4770b78 commit 7fd6eea
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 46 deletions.
96 changes: 60 additions & 36 deletions pliron-derive/src/derive_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,17 @@ fn derive_from_parsed(input: FmtInput, irobj: DeriveIRObject) -> Result<TokenStr
Ok(format_tokens)
}

/// Mutable state threaded through the [PrintableBuilder] methods when
/// `DeriveOpPrintable` builds a printer.
#[derive(Default)]
struct OpPrinterState {
/// Set to `true` by `DeriveOpPrintable::build_directive` when it sees the
/// `canonical` directive. `DeriveOpPrintable::build_body` reads it afterwards
/// and, when set, does not emit its own "results = opid" prefix.
is_canonical: bool,
}

/// Generate token stream for derived [Printable](::pliron::printable::Printable) trait.
trait PrintableBuilder {
trait PrintableBuilder<State: Default> {
// Entry function. Builds the outer function outline.
fn build(input: &FmtInput) -> Result<TokenStream> {
let name = input.ident.clone();
let body = Self::build_body(input)?;
let body = Self::build_body(input, &mut State::default())?;

let derived = quote! {
impl ::pliron::printable::Printable for #name {
Expand All @@ -113,45 +118,47 @@ trait PrintableBuilder {
}

// Build the body of the outer function Printable::fmt.
fn build_body(input: &FmtInput) -> Result<TokenStream> {
Self::build_format(input)
fn build_body(input: &FmtInput, state: &mut State) -> Result<TokenStream> {
Self::build_format(input, state)
}

fn build_lit(_input: &FmtInput, lit: &str) -> TokenStream {
fn build_lit(_input: &FmtInput, _state: &mut State, lit: &str) -> TokenStream {
quote! { ::pliron::printable::Printable::fmt(&#lit, ctx, state, fmt)?; }
}

fn build_format(input: &FmtInput) -> Result<TokenStream> {
fn build_format(input: &FmtInput, state: &mut State) -> Result<TokenStream> {
let derived_format = input
.format
.elems
.iter()
.map(|elem| Self::build_elem(input, elem))
.map(|elem| Self::build_elem(input, state, elem))
.try_fold(TokenStream::new(), |mut acc, e| {
acc.extend(e?);
Ok(acc)
});
derived_format
}

fn build_elem(input: &FmtInput, elem: &Elem) -> Result<TokenStream> {
fn build_elem(input: &FmtInput, state: &mut State, elem: &Elem) -> Result<TokenStream> {
match elem {
Elem::Lit(Lit { lit, .. }) => Ok(Self::build_lit(input, lit)),
Elem::Var(Var { name, .. }) => Self::build_var(input, name),
Elem::UnnamedVar(UnnamedVar { index, .. }) => Self::build_unnamed_var(input, *index),
Elem::Directive(ref d) => Self::build_directive(input, d),
Elem::Lit(Lit { lit, .. }) => Ok(Self::build_lit(input, state, lit)),
Elem::Var(Var { name, .. }) => Self::build_var(input, state, name),
Elem::UnnamedVar(UnnamedVar { index, .. }) => {
Self::build_unnamed_var(input, state, *index)
}
Elem::Directive(ref d) => Self::build_directive(input, state, d),
}
}

fn build_var(input: &FmtInput, name: &str) -> Result<TokenStream>;
fn build_unnamed_var(input: &FmtInput, index: usize) -> Result<TokenStream>;
fn build_directive(input: &FmtInput, d: &Directive) -> Result<TokenStream>;
fn build_var(input: &FmtInput, state: &mut State, name: &str) -> Result<TokenStream>;
fn build_unnamed_var(input: &FmtInput, state: &mut State, index: usize) -> Result<TokenStream>;
fn build_directive(input: &FmtInput, state: &mut State, d: &Directive) -> Result<TokenStream>;
}

struct DeriveBasePrintable;

impl PrintableBuilder for DeriveBasePrintable {
fn build_var(input: &FmtInput, name: &str) -> Result<TokenStream> {
impl PrintableBuilder<()> for DeriveBasePrintable {
fn build_var(input: &FmtInput, _state: &mut (), name: &str) -> Result<TokenStream> {
let FmtData::Struct(ref struct_fields) = input.data;
if !struct_fields
.fields
Expand All @@ -168,7 +175,7 @@ impl PrintableBuilder for DeriveBasePrintable {
Ok(quote! { ::pliron::printable::Printable::fmt(&self.#field, ctx, state, fmt)?; })
}

fn build_unnamed_var(input: &FmtInput, index: usize) -> Result<TokenStream> {
fn build_unnamed_var(input: &FmtInput, _state: &mut (), index: usize) -> Result<TokenStream> {
// This is a struct unnamed field access.
let FmtData::Struct(ref struct_fields) = input.data;
if !struct_fields
Expand All @@ -186,15 +193,19 @@ impl PrintableBuilder for DeriveBasePrintable {
Ok(quote! { ::pliron::printable::Printable::fmt(&self.#index, ctx, state, fmt)?; })
}

fn build_directive(_input: &FmtInput, _d: &Directive) -> Result<TokenStream> {
fn build_directive(_input: &FmtInput, _state: &mut (), _d: &Directive) -> Result<TokenStream> {
todo!()
}
}

struct DeriveOpPrintable;

impl PrintableBuilder for DeriveOpPrintable {
fn build_var(input: &FmtInput, attr_name: &str) -> Result<TokenStream> {
impl PrintableBuilder<OpPrinterState> for DeriveOpPrintable {
fn build_var(
input: &FmtInput,
_state: &mut OpPrinterState,
attr_name: &str,
) -> Result<TokenStream> {
let attr_name = attr_name.to_string();
let op_name = input.ident.clone();
let missing_attr_err = format!("Missing attribute {} on Op {}", &attr_name, &op_name);
Expand All @@ -208,15 +219,24 @@ impl PrintableBuilder for DeriveOpPrintable {
})
}

fn build_unnamed_var(_input: &FmtInput, index: usize) -> Result<TokenStream> {
/// Generate code that fetches operand number `index` of the underlying
/// operation and prints it.
fn build_unnamed_var(
    _input: &FmtInput,
    _state: &mut OpPrinterState,
    index: usize,
) -> Result<TokenStream> {
    let print_operand = quote! {
        let opd = self.get_operation().deref(ctx).get_operand(#index);
        ::pliron::printable::Printable::fmt(&opd, ctx, state, fmt)?;
    };
    Ok(print_operand)
}

fn build_directive(input: &FmtInput, d: &Directive) -> Result<TokenStream> {
fn build_directive(
input: &FmtInput,
state: &mut OpPrinterState,
d: &Directive,
) -> Result<TokenStream> {
if d.name == "canonical" {
state.is_canonical = true;
Ok(quote! { ::pliron::op::canonical_syntax_print(Box::new(*self), ctx, state, fmt)?; })
} else if d.name == "type" {
let err = Err(syn::Error::new_spanned(
Expand All @@ -239,19 +259,23 @@ impl PrintableBuilder for DeriveOpPrintable {
}
}

fn build_body(input: &FmtInput) -> Result<TokenStream> {
let mut output = quote! {
use ::pliron::op::Op;
use ::pliron::irfmt::printers::iter_with_sep;
let op = self.get_operation().deref(ctx);
if op.get_num_results() > 0 {
let sep = ::pliron::printable::ListSeparator::CharSpace(',');
let results = iter_with_sep(op.results(), sep);
write!(fmt, "{} = ", results.disp(ctx))?;
}
write!(fmt, "{} ", self.get_opid())?;
};
output.extend(Self::build_format(input)?);
fn build_body(input: &FmtInput, state: &mut OpPrinterState) -> Result<TokenStream> {
let formatted_tokens = Self::build_format(input, state)?;
let mut output = quote! {};
if !state.is_canonical {
output.extend(quote! {
use ::pliron::op::Op;
use ::pliron::irfmt::printers::iter_with_sep;
let op = self.get_operation().deref(ctx);
if op.get_num_results() > 0 {
let sep = ::pliron::printable::ListSeparator::CharSpace(',');
let results = iter_with_sep(op.results(), sep);
write!(fmt, "{} = ", results.disp(ctx))?;
}
write!(fmt, "{} ", self.get_opid())?;
});
}
output.extend(formatted_tokens);
Ok(output)
}
}
Expand Down
9 changes: 8 additions & 1 deletion src/irfmt/parsers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
value::Value,
};
use combine::{
between, many1,
between, many, many1,
parser::char::{digit, spaces},
sep_by, token, Parser, Stream,
};
Expand Down Expand Up @@ -104,6 +104,13 @@ pub fn delimited_list_parser<Input: Stream<Token = char>, Output>(
)
}

/// Apply `parser` repeatedly (zero or more times), collecting the outputs in a [Vec].
/// Whitespace before and after every occurrence is skipped.
pub fn zero_or_more_parser<Input: Stream<Token = char>, Output>(
    parser: impl Parser<Input, Output = Output>,
) -> impl Parser<Input, Output = Vec<Output>> {
    // A single occurrence of `parser` with the surrounding whitespace discarded.
    let padded_occurrence = spaces().with(parser.skip(spaces()));
    many::<Vec<_>, _, _>(padded_occurrence)
}

/// Parse an identifier into an SSA [Value]. Typically called to parse
/// the SSA operands of an [Operation]. If the SSA value hasn't been defined yet,
/// a [forward reference](crate::builtin::ops::ForwardRefOp) is returned.
Expand Down
33 changes: 25 additions & 8 deletions src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use crate::{
irfmt::{
parsers::{
block_opd_parser, delimited_list_parser, location, process_parsed_ssa_defs, spaced,
ssa_opd_parser,
ssa_opd_parser, zero_or_more_parser,
},
printers::{functional_type, iter_with_sep},
},
Expand All @@ -62,6 +62,7 @@ use crate::{
parsable::{IntoParseResult, Parsable, ParseResult, ParserFn, StateStream},
printable::{self, Printable},
r#type::Typed,
region::Region,
result::Result,
};

Expand Down Expand Up @@ -291,12 +292,11 @@ pub static OP_INTERFACE_VERIFIERS_MAP: LazyLock<

/// Printer for an [Op] in canonical syntax.
/// `res_1, res_2, ... res_n =
/// op_id (opd_1, opd_2, ... opd_n) [succ_1, succ_2, ... succ_n] [attr-dic]: function-type`
/// TODO: Handle operations with regions.
/// op_id (opd_1, opd_2, ... opd_n) [succ_1, succ_2, ... succ_n] [attr-dict]: function-type (regions)*`
pub fn canonical_syntax_print(
op: OpObj,
ctx: &Context,
_state: &printable::State,
state: &printable::State,
f: &mut fmt::Formatter<'_>,
) -> fmt::Result {
let sep = printable::ListSeparator::CharSpace(',');
Expand All @@ -307,21 +307,27 @@ pub fn canonical_syntax_print(
iter_with_sep(op.operands().map(|opd| opd.get_type(ctx)), sep),
iter_with_sep(op.results().map(|res| res.get_type(ctx)), sep),
);
let regions = iter_with_sep(op.regions.iter(), printable::ListSeparator::Newline);

if op.get_num_results() != 0 {
let results = iter_with_sep(op.results(), sep);
write!(f, "{} = ", results.disp(ctx))?;
}
let ret = write!(

write!(
f,
"{} ({}) [{}] {}: {}",
op.get_opid().disp(ctx),
operands.disp(ctx),
successors.disp(ctx),
op.attributes.disp(ctx),
op_type.disp(ctx),
);
ret
)?;

if op.regions.len() > 0 {
regions.fmt(ctx, state, f)?;
}
Ok(())
}

#[derive(Error, Debug)]
Expand All @@ -339,14 +345,22 @@ pub fn canonical_syntax_parse<'a>(
state_stream: &mut StateStream<'a>,
results: Vec<(Identifier, Location)>,
) -> ParseResult<'a, OpObj> {
let parent_for_regions = state_stream.state.parent_for_regions;
// Results and opid have already been parsed. Continue after that.
delimited_list_parser('(', ')', ',', ssa_opd_parser())
.and(spaces().with(delimited_list_parser('[', ']', ',', block_opd_parser())))
.and(spaces().with(AttributeDict::parser(())))
.skip(spaced(token(':')))
.and((location(), FunctionType::parser(())))
.and((
location(),
zero_or_more_parser(Region::parser(parent_for_regions)),
))
.then(
move |(((operands, successors), attr_dict), (fty_loc, fty))| {
move |(
(((operands, successors), attr_dict), (fty_loc, fty)),
(_regions_loc, regions),
)| {
let opid = opid.clone();
let results = results.clone();
let fty_loc = fty_loc.clone();
Expand Down Expand Up @@ -383,6 +397,9 @@ pub fn canonical_syntax_parse<'a>(
);
opr.deref_mut(ctx).attributes = attr_dict.clone();
let op = from_operation(ctx, opr);
for region in regions.iter() {
Region::move_to_op(*region, opr, ctx);
}
process_parsed_ssa_defs(parsable_state, &results, opr)?;
Ok(op).into_parse_result()
})
Expand Down
24 changes: 23 additions & 1 deletion src/parsable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::{
input_err,
irfmt::parsers::int_parser,
location::{self, Located, Location},
op::op_impls,
op::{op_impls, Op},
operation::Operation,
result::{self, Result},
value::Value,
Expand All @@ -39,19 +39,41 @@ pub struct State<'a> {
pub ctx: &'a mut Context,
pub(crate) name_tracker: NameTracker,
pub src: location::Source,
/// Constructing a [Region](crate::region::Region) requires a parent [Operation],
/// but during parsing, the parent itself may not be constructed yet.
/// So we use a dummy [ForwardRefOp] to represent the parent until it is actually constructed,
/// and then move the region to the actual parent.
pub parent_for_regions: Ptr<Operation>,
}

impl<'a> State<'a> {
    /// Build a fresh parser [State] for `src`, with a default name tracker.
    ///
    /// A new [ForwardRefOp] is created in `ctx`, and its operation is stored
    /// as `parent_for_regions`.
    pub fn new(ctx: &'a mut Context, src: location::Source) -> State<'a> {
        // Create the placeholder while `ctx` is still available as a local,
        // i.e. before it is moved into the struct below.
        let placeholder_parent = ForwardRefOp::new(ctx).get_operation();
        Self {
            ctx,
            name_tracker: NameTracker::default(),
            src,
            parent_for_regions: placeholder_parent,
        }
    }
}

impl<'a> Drop for State<'a> {
    /// Erase the placeholder `parent_for_regions` operation once the parser
    /// state goes away.
    ///
    /// Regions constructed during parsing are expected to have been moved to
    /// their respective parents by then. This is deliberately NOT asserted:
    /// if the parser fails, regions may have been constructed but never moved,
    /// and dropping the state must not panic in that case.
    fn drop(&mut self) {
        let parent_for_regions = self.parent_for_regions;
        Operation::erase(parent_for_regions, self.ctx);
    }
}

/// A wrapper around any [char] [Iterator] object.
/// Buffering and positioning are automatically handled hereafter.
pub struct CharIterator<'a>(Box<dyn Iterator<Item = char> + 'a>);
Expand Down
20 changes: 20 additions & 0 deletions src/region.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ impl Region {
Self::alloc(ctx, f)
}

/// Detach the [Region] `ptr` from its current parent and append it to the
/// regions of `new_parent_op`.
/// Nothing is done when `new_parent_op` already is the parent.
/// The region is taken out of its old parent with a swap-remove, so the
/// indices of the regions that remain there are not preserved.
pub fn move_to_op(ptr: Ptr<Self>, new_parent_op: Ptr<Operation>, ctx: &mut Context) {
    let old_parent_op = ptr.deref(ctx).get_parent_op();
    if old_parent_op != new_parent_op {
        {
            // Unlink from the old parent.
            let old_regions = &mut old_parent_op.deref_mut(ctx).regions;
            let idx = old_regions
                .iter()
                .position(|candidate| *candidate == ptr)
                .expect("Region missing in it's current parent Operations");
            old_regions.swap_remove(idx);
        }
        // Link into the new parent and update the back-reference.
        new_parent_op.deref_mut(ctx).regions.push(ptr);
        ptr.deref_mut(ctx).parent_op = new_parent_op;
    }
}

/// Get the operation that contains this region.
pub fn get_parent_op(&self) -> Ptr<Operation> {
self.parent_op
Expand Down Expand Up @@ -109,6 +128,7 @@ impl Printable for Region {
state: &printable::State,
f: &mut core::fmt::Formatter<'_>,
) -> core::fmt::Result {
fmt_indented_newline(state, f)?;
write!(f, "{{")?;

indented_block!(state, {
Expand Down

0 comments on commit 7fd6eea

Please sign in to comment.