diff --git a/src/parser/ast.rs b/src/parser/ast.rs index e171ffa..a07f8a3 100644 --- a/src/parser/ast.rs +++ b/src/parser/ast.rs @@ -9,7 +9,13 @@ pub enum Symbol<'de> { } #[derive(Debug, Clone)] -pub enum Statement<'de> { +pub struct Statement<'de> { + pub value: StatementValue<'de>, + pub span: miette::SourceSpan, +} + +#[derive(Debug, Clone)] +pub enum StatementValue<'de> { Block(Vec>), Expression(Expression<'de>), Assignment { @@ -27,7 +33,6 @@ pub enum Statement<'de> { Function { header: FunctionHeader<'de>, body: Expression<'de>, - explicit_return_type: Option>, }, Trait { name: Cow<'de, str>, @@ -47,40 +52,6 @@ pub struct Expression<'de> { pub span: miette::SourceSpan, } -impl<'de> Expression<'de> { - pub fn at(span: miette::SourceSpan, value: ExpressionValue<'de>) -> Self { - Self { value, span } - } - - pub fn at_multiple( - spans: Vec>, - value: ExpressionValue<'de>, - ) -> Expression<'de> { - let spans = spans.into_iter().map(|s| s.into()).collect::>(); - - let start = spans - .iter() - .min_by_key(|s| s.offset()) - .map(|s| s.offset()) - .unwrap_or(0); - - // Go through all the spans and find the one with the highest end offset - let end = spans - .iter() - .max_by_key(|s| s.offset() + s.len()) - .map(|s| s.offset() + s.len()) - .unwrap_or(0); - - let span = miette::SourceSpan::new(start.into(), end - start); - - Expression::at(span, value) - } - - pub fn label(&self, text: impl Into) -> miette::LabeledSpan { - miette::LabeledSpan::at(self.span, text.into()) - } -} - #[derive(Debug, Clone)] pub enum ExpressionValue<'de> { Primitive(Primitive<'de>), @@ -113,6 +84,7 @@ pub enum ExpressionValue<'de> { pub struct FunctionHeader<'de> { pub name: Cow<'de, str>, pub parameters: Vec>, + pub explicit_return_type: Option>, } #[derive(Debug, Clone)] @@ -168,7 +140,13 @@ pub enum UnaryOperator { } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum Type<'de> { +pub struct Type<'de> { + pub value: TypeValue<'de>, + pub span: SourceSpan, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TypeValue<'de> { Unit, Boolean, Integer, @@ -179,3 +157,53 @@ pub enum Type<'de> { Collection(Box>), Set(Box>), } + +pub trait Spannable<'de>: Sized { + type Value; + + fn at(span: miette::SourceSpan, value: Self::Value) -> Self; + + fn at_multiple(spans: Vec>, value: Self::Value) -> Self { + let spans = spans.into_iter().map(|s| s.into()).collect::>(); + + let start = spans + .iter() + .min_by_key(|s| s.offset()) + .map(|s| s.offset()) + .unwrap_or(0); + + let end = spans + .iter() + .max_by_key(|s| s.offset() + s.len()) + .map(|s| s.offset() + s.len()) + .unwrap_or(0); + + let span = miette::SourceSpan::new(start.into(), end - start); + + Self::at(span, value) + } +} + +impl<'de> Spannable<'de> for Expression<'de> { + type Value = ExpressionValue<'de>; + + fn at(span: miette::SourceSpan, value: Self::Value) -> Self { + Self { value, span } + } +} + +impl<'de> Spannable<'de> for Statement<'de> { + type Value = StatementValue<'de>; + + fn at(span: miette::SourceSpan, value: Self::Value) -> Self { + Self { value, span } + } +} + +impl<'de> Spannable<'de> for Type<'de> { + type Value = TypeValue<'de>; + + fn at(span: miette::SourceSpan, value: Self::Value) -> Self { + Self { value, span } + } +} diff --git a/src/parser/expression/binary.rs b/src/parser/expression/binary.rs index 35ef41f..edb5078 100644 --- a/src/parser/expression/binary.rs +++ b/src/parser/expression/binary.rs @@ -1,5 +1,5 @@ use crate::parser::{ - ast::{BinaryOperator, Expression, ExpressionValue}, + ast::{BinaryOperator, Expression, ExpressionValue, Spannable}, lookup::BindingPower, Parser, }; diff --git a/src/parser/expression/mod.rs b/src/parser/expression/mod.rs index b14a22f..45f7aa7 100644 --- a/src/parser/expression/mod.rs +++ b/src/parser/expression/mod.rs @@ -1,7 +1,7 @@ use miette::Result; use super::{ - ast::{Expression, ExpressionValue}, + ast::{Expression, ExpressionValue, Spannable}, Parser, }; use crate::{lexer::TokenKind, parser::lookup::BindingPower}; diff --git a/src/parser/expression/primitive.rs b/src/parser/expression/primitive.rs index 5fd33be..cf004c5 100644 --- a/src/parser/expression/primitive.rs +++ b/src/parser/expression/primitive.rs @@ -1,7 +1,7 @@ use crate::{ lexer::{TokenKind, TokenValue}, parser::{ - ast::{Expression, ExpressionValue, Primitive}, + ast::{Expression, ExpressionValue, Primitive, Spannable}, Parser, }, }; diff --git a/src/parser/expression/unary.rs b/src/parser/expression/unary.rs index bf64043..dbc414d 100644 --- a/src/parser/expression/unary.rs +++ b/src/parser/expression/unary.rs @@ -1,7 +1,7 @@ use crate::{ lexer::TokenKind, parser::{ - ast::{Expression, ExpressionValue, UnaryOperator}, + ast::{Expression, ExpressionValue, Spannable, UnaryOperator}, lookup::BindingPower, Parser, }, diff --git a/src/parser/lookup.rs b/src/parser/lookup.rs index 5e75a8b..a9d7cea 100644 --- a/src/parser/lookup.rs +++ b/src/parser/lookup.rs @@ -1,5 +1,5 @@ use super::{ - ast::{Expression, ExpressionValue, Primitive, Statement, Type}, + ast::{Expression, ExpressionValue, Primitive, Spannable, Statement, StatementValue, Type}, expression, statement, typing, Parser, }; use crate::lexer::{TokenKind, TokenValue}; @@ -287,9 +287,9 @@ fn block<'de>(parser: &mut Parser<'de>) -> Result> { } let return_value = if last_is_return { - match statements.last() { - Some(Statement::Expression(_)) => match statements.pop() { - Some(Statement::Expression(expression)) => expression, + match statements.last().map(|s| &s.value) { + Some(StatementValue::Expression(_)) => match statements.pop().map(|s| s.value) { + Some(StatementValue::Expression(expression)) => expression, _ => unreachable!(), }, _ => Expression::at( diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 04e84ed..f952b08 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -1,5 +1,5 @@ use crate::lexer::Lexer; -use ast::{Statement, Symbol}; +use ast::{Spannable, Statement, StatementValue, Symbol}; use lookup::Lookup; use miette::Result; @@ -29,6 +29,9 @@ impl<'de> Parser<'de> { statements.push(statement::parse(self, false)?); } - Ok(Symbol::Statement(Statement::Block(statements))) + Ok(Symbol::Statement(Statement::at_multiple( + statements.iter().map(|s| s.span).collect(), + StatementValue::Block(statements), + ))) } } diff --git a/src/parser/statement.rs b/src/parser/statement.rs index 9f7eb4d..4ac44d6 100644 --- a/src/parser/statement.rs +++ b/src/parser/statement.rs @@ -1,13 +1,13 @@ use super::{ ast::{ - EnumMemberDeclaration, FunctionHeader, ParameterDeclaration, Statement, - StructMemberDeclaration, + EnumMemberDeclaration, FunctionHeader, ParameterDeclaration, Spannable, Statement, + StatementValue, StructMemberDeclaration, }, expression, lookup::BindingPower, statement, typing, Parser, }; -use crate::lexer::{TokenKind, TokenValue}; +use crate::lexer::{Token, TokenKind, TokenValue}; use miette::{Context, Result}; pub fn parse<'de>(parser: &mut Parser<'de>, optional_semicolon: bool) -> Result> { @@ -33,16 +33,21 @@ pub fn parse<'de>(parser: &mut Parser<'de>, optional_semicolon: bool) -> Result< .wrap_err("while parsing a statement")?; if !optional_semicolon { - parser + let token = parser .lexer .expect( TokenKind::Semicolon, "expected a semicolon at the end of an expression", ) .wrap_err(format!("while parsing for {}", token_kind))?; - } - Statement::Expression(expression) + Statement::at_multiple( + vec![expression.span, token.span], + StatementValue::Expression(expression), + ) + } else { + Statement::at(expression.span, StatementValue::Expression(expression)) + } } }; @@ -50,13 +55,13 @@ pub fn parse<'de>(parser: &mut Parser<'de>, optional_semicolon: bool) -> Result< } pub fn let_<'de>(parser: &mut Parser<'de>) -> Result> { - parser + let token = parser .lexer .expect(TokenKind::Let, "expected a let keyword")?; let identifier = parser .lexer .expect(TokenKind::Identifier, "expected a variable name")?; - let identifier = match identifier.value { + let name = match identifier.value { TokenValue::Identifier(identifier) => identifier, _ => unreachable!(), }; @@ -65,14 +70,17 @@ pub fn let_<'de>(parser: &mut Parser<'de>) -> Result> { .expect(TokenKind::Equal, "expected an equal sign")?; let expression = expression::parse(parser, BindingPower::None)?; - Ok(Statement::Assignment { - name: identifier, - value: expression, - }) + Ok(Statement::at_multiple( + vec![token.span, identifier.span], + StatementValue::Assignment { + name, + value: expression, + }, + )) } pub fn struct_<'de>(parser: &mut Parser<'de>) -> Result> { - parser + let token = parser .lexer .expect(TokenKind::Struct, "expected a struct keyword")?; @@ -80,7 +88,7 @@ pub fn struct_<'de>(parser: &mut Parser<'de>) -> Result> { .lexer .expect(TokenKind::Identifier, "expected a struct name")?; - let identifier = match identifier.value { + let name = match identifier.value { TokenValue::Identifier(identifier) => identifier, _ => unreachable!(), }; @@ -123,14 +131,14 @@ pub fn struct_<'de>(parser: &mut Parser<'de>) -> Result> { .lexer .expect(TokenKind::Semicolon, "expected a semicolon")?; - Ok(Statement::Struct { - name: identifier, - fields, - }) + Ok(Statement::at_multiple( + vec![token.span, identifier.span], + StatementValue::Struct { name, fields }, + )) } pub fn enum_<'de>(parser: &mut Parser<'de>) -> Result> { - parser + let token = parser .lexer .expect(TokenKind::Enum, "expected an enum keyword")?; @@ -138,7 +146,7 @@ pub fn enum_<'de>(parser: &mut Parser<'de>) -> Result> { .lexer .expect(TokenKind::Identifier, "expected an enum name")?; - let identifier = match identifier.value { + let name = match identifier.value { TokenValue::Identifier(identifier) => identifier, _ => unreachable!(), }; @@ -177,14 +185,14 @@ pub fn enum_<'de>(parser: &mut Parser<'de>) -> Result> { .lexer .expect(TokenKind::Semicolon, "expected a semicolon")?; - Ok(Statement::Enum { - name: identifier, - variants, - }) + Ok(Statement::at_multiple( + vec![token.span, identifier.span], + StatementValue::Enum { name, variants }, + )) } pub fn function_<'de>(parser: &mut Parser<'de>) -> Result> { - parser + let token = parser .lexer .expect(TokenKind::Function, "expected a function keyword")?; @@ -192,7 +200,7 @@ pub fn function_<'de>(parser: &mut Parser<'de>) -> Result> { .lexer .expect(TokenKind::Identifier, "expected function name")?; - let identifier = match identifier.value { + let name = match identifier.value { TokenValue::Identifier(identifier) => identifier, _ => unreachable!(), }; @@ -247,13 +255,13 @@ pub fn function_<'de>(parser: &mut Parser<'de>) -> Result> { let body = expression::parse(parser, BindingPower::None)?; - Ok(Statement::Function { + Ok(Statement::at_multiple(vec![], value)::Function { header: FunctionHeader { - name: identifier, + name, parameters, + explicit_return_type, }, body, - explicit_return_type, }) } diff --git a/src/passer/typing/mod.rs b/src/passer/typing/mod.rs index 113a81d..b59d56d 100644 --- a/src/passer/typing/mod.rs +++ b/src/passer/typing/mod.rs @@ -1,6 +1,6 @@ use super::{Passer, PasserResult}; use crate::parser::{ - ast::{Expression, ExpressionValue, Statement, Symbol, Type}, + ast::{Expression, ExpressionValue, Statement, StatementValue, Symbol, Type}, expression, }; use miette::{Error, LabeledSpan, Report, Result}; @@ -14,20 +14,20 @@ impl Passer for TypingPasser { } -pub fn walk_statement<'de>(statement: &Statement<'de>, statement_fn: fn(&Statement<'de>, expression_fn: fn(&Expression<'de>)) { +pub fn walk_statement<'de>(statement: &Statement<'de>, statement_fn: fn(&Statement<'de>), expression_fn: fn(&Expression<'de>)) { match statement { Statement::Block(statements) => statements.iter().for_each(|statement| { - walk_statement(statement, statement_fn); + walk_statement(statement, statement_fn, expression_fn); }), - Statement::Expression(expression) => walk_expression(expression), + Statement::Expression(expression) => walk_expression(statement_fn, expression_fn), Statement::Assignment { name, value } => walk_expression(expression, expression_fn), - Statement::Struct { name, fields } => todo!(), - Statement::Enum { name, variants } => todo!(), + Statement::Struct { name, fields } => {}, + Statement::Enum { name, variants } => {}, Statement::Function { header, body, explicit_return_type, - } => todo!(), + } => , Statement::Trait { name, functions } => todo!(), Statement::Return(expression) => todo!(), Statement::Conditional { @@ -38,8 +38,8 @@ pub fn walk_statement<'de>(statement: &Statement<'de>, statement_fn: fn(&Stateme } } -pub fn walk_expression<'de, T>(expression: Expression<'de>, statement_fn: fn(&Statement<'de>, expression_fn: fn(&Expression<'de>)) { - +pub fn walk_expression<'de, T>(expression: Expression<'de>, statement_fn: fn(&Statement<'de>, expression_fn: fn(&Expression<'de>))) { + todo!() } pub trait Typing { @@ -94,3 +94,40 @@ impl Typing for Expression<'_> { } } } + +impl Typing for Statement<'_> { + fn possible_types(&self) -> Vec<(Type, miette::SourceSpan)> { + match &self.value { + StatementValue::Block(statements) => vec![], + StatementValue::Expression(expression) => expression.possible_types(), + StatementValue::Assignment { name: _, value } => value.possible_types(), + StatementValue::Struct { name, fields } => vec![(Type::Symbol(name.clone()), self.span)], + StatementValue::Enum { name, variants } => vec![(Type::Symbol(name.clone()), self.span)], + StatementValue::Function { + header: _, + body, + + } => { + let mut types = body.possible_types(); + if let Some(explicit_return_type) = explicit_return_type { + types.push((explicit_return_type,)); + } + types + } + StatementValue::Trait { name, functions } => vec![(Type::Symbol(name.clone()), self.span)], + StatementValue::Return(expression) => expression.possible_types(), + StatementValue::Conditional { + condition, + truthy, + falsy, + } => { + let mut types = condition.possible_types(); + types.extend(truthy.possible_types()); + if let Some(falsy) = falsy { + types.extend(falsy.possible_types()); + } + types + } + } + } +}