diff --git a/parser/src/ast.rs b/parser/src/ast.rs index 9ac15ae..9d198dd 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -39,11 +39,7 @@ pub enum NodeValue { left: Box, index: Box, }, - If { - condition: Box, - consequence: Vec, - alternative: Vec, - }, + If(IfNode), While { condition: Box, body: Vec, @@ -76,6 +72,13 @@ pub struct HashLiteralPair { pub value: Node, } +#[derive(Debug, PartialEq, Clone)] +pub struct IfNode { + pub condition: Box, + pub consequence: Vec, + pub alternative: Vec, +} + #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum PrefixOperatorKind { Not, @@ -217,24 +220,26 @@ impl Display for NodeValue { } => write!(f, "({left} {operator} {right})"), NodeValue::Assign { ident, value } => write!(f, "({ident} = {value})"), NodeValue::Index { left, index } => write!(f, "({left}[{index}])"), - NodeValue::If { - condition, - consequence, - alternative, - } => { - let cons = consequence + NodeValue::If(if_node) => { + let cons = if_node + .consequence .iter() .map(|node| node.to_string()) .collect::>() .join("\n"); - let alt = alternative + let alt = if_node + .alternative .iter() .map(|node| node.to_string()) .collect::>() .join("\n"); - write!(f, "if ({condition}) {{{cons}}} else {{{alt}}}") + write!( + f, + "if ({}) {{{}}} else {{{}}}", + if_node.condition, cons, alt + ) } NodeValue::While { condition, body } => { let body = body diff --git a/parser/src/parser/mod.rs b/parser/src/parser/mod.rs index 0ce4d35..89c3b52 100644 --- a/parser/src/parser/mod.rs +++ b/parser/src/parser/mod.rs @@ -159,7 +159,10 @@ impl Parser<'_> { TokenKind::LBracket => self.parse_grouped(range)?, TokenKind::LSquare => self.parse_array_literal(range)?, TokenKind::LCurly => self.parse_hash_map_literal(range)?, - TokenKind::If => todo!("parse if statement"), + TokenKind::If => { + let (if_node, end) = self.parse_if(range)?; + (ast::NodeValue::If(if_node), end) + } TokenKind::While => todo!("parse while loop"), TokenKind::For => todo!("parse for loop"), TokenKind::Break => (ast::NodeValue::Break, range.end), @@ -473,6 +476,153 @@ impl Parser<'_> { )) } + fn parse_if(&mut self, start_range: Range) -> Result<(ast::IfNode, Position)> { + // Read `(` + let token = self.next_token(start_range)?; + if token.kind != TokenKind::LBracket { + return Err(Error { + kind: ErrorKind::InvalidTokenKind { + expected: TokenKind::LBracket, + got: token.kind, + }, + range: token.range, + }); + } + + // Parse condition + let cond_token = self.next_token(Range { + start: start_range.start, + end: token.range.end, + })?; + let condition = self.parse_node(cond_token, Precedence::Lowest)?; + validate_node_kind(&condition, NodeKind::Expression)?; + + // Read `)` + let token = self.next_token(Range { + start: start_range.start, + end: condition.range.end, + })?; + if token.kind != TokenKind::RBracket { + return Err(Error { + kind: ErrorKind::InvalidTokenKind { + expected: TokenKind::RBracket, + got: token.kind, + }, + range: token.range, + }); + } + + // Parse consequence + let cons_token = self.next_token(Range { + start: start_range.start, + end: token.range.end, + })?; + let (consequence, cons_end) = self.parse_block(cons_token)?; + + // Construct the if node + let mut if_node = ast::IfNode { + condition: Box::new(condition), + consequence, + alternative: vec![], + }; + + // After consequence we can have eof, eol or else. + peek_token!(self, else_token, return Ok((if_node, cons_end))); + if else_token.kind == TokenKind::Eol { + return Ok((if_node, cons_end)); + } + + if else_token.kind != TokenKind::Else { + return Err(Error { + kind: ErrorKind::InvalidTokenKind { + expected: TokenKind::Else, + got: else_token.kind.clone(), + }, + range: else_token.range, + }); + } + + // Read else token and discard it. + let else_token_range = else_token.range; + self.lexer.next(); + + // Handle else and else if + let token = self.next_token(else_token_range)?; + if token.kind == TokenKind::If { + let (alternative, alternative_end) = self.parse_if(token.range)?; + + if_node.alternative = vec![ast::Node { + value: ast::NodeValue::If(alternative), + range: Range { + start: token.range.start, + end: alternative_end, + }, + }]; + Ok((if_node, alternative_end)) + } else { + let (alternative, alternative_end) = self.parse_block(token)?; + + if_node.alternative = alternative; + Ok((if_node, alternative_end)) + } + } + + // Helper function that reads block { ... }. + // It returns vector of nodes and end position, which is the end + // position of `}` + // + // This function checks if the start token is `{`, so the caller doesn't have to do this. + fn parse_block(&mut self, start_token: Token) -> Result<(Vec, Position)> { + // Start token should be `{` + validate::validate_token_kind(&start_token, TokenKind::LCurly)?; + + let mut nodes = Vec::new(); + let mut end = start_token.range.end; + loop { + // Skip \n's + self.skip_eol()?; + + // Check if next token is `}`. In this case we are done with the block + let token = self.next_token(Range { + start: start_token.range.start, + end, + })?; + + if token.kind == TokenKind::RCurly { + return Ok((nodes, token.range.end)); + } + + // Parse next node + let node = self.parse_node(token, Precedence::Lowest)?; + end = node.range.end; + nodes.push(node); + + // Token after ndoe should be one of: + // - `}` => We are done with the block + // - `\n` => We repeat the loop + // - `// ...` => We repeat the loop + // Otherwise, we return an error + let token = self.next_token(Range { + start: start_token.range.start, + end, + })?; + + if token.kind == TokenKind::RCurly { + return Ok((nodes, token.range.end)); + } + + if !matches!(token.kind, TokenKind::Eol | TokenKind::Comment(_)) { + return Err(Error { + kind: ErrorKind::InvalidTokenKind { + expected: TokenKind::RCurly, + got: token.kind, + }, + range: token.range, + }); + } + } + } + // Helper function used for parsing arrays, hash maps, function arguments, function calls. fn parse_multiple( &mut self, diff --git a/parser/src/parser/test.rs b/parser/src/parser/test.rs index 413bcb1..5457f16 100644 --- a/parser/src/parser/test.rs +++ b/parser/src/parser/test.rs @@ -800,6 +800,97 @@ fn index() -> Result<()> { Ok(()) } +#[test] +fn if_node() -> Result<()> { + let tests = [ + ( + "if (true) {}", + ast::Node { + value: ast::NodeValue::If(ast::IfNode { + condition: Box::new(ast::Node { + value: ast::NodeValue::BoolLiteral(true), + range: Range { + start: Position::new(0, 4), + end: Position::new(0, 8), + }, + }), + consequence: vec![], + alternative: vec![], + }), + range: Range { + start: Position::new(0, 0), + end: Position::new(0, 12), + }, + }, + ), + ( + "if (true) {\n} else {\n}", + ast::Node { + value: ast::NodeValue::If(ast::IfNode { + condition: Box::new(ast::Node { + value: ast::NodeValue::BoolLiteral(true), + range: Range { + start: Position::new(0, 4), + end: Position::new(0, 8), + }, + }), + consequence: vec![], + alternative: vec![], + }), + range: Range { + start: Position::new(0, 0), + end: Position::new(2, 1), + }, + }, + ), + ( + "if (true) {\n} else if (false) {\n}", + ast::Node { + value: ast::NodeValue::If(ast::IfNode { + condition: Box::new(ast::Node { + value: ast::NodeValue::BoolLiteral(true), + range: Range { + start: Position::new(0, 4), + end: Position::new(0, 8), + }, + }), + consequence: vec![], + alternative: vec![ast::Node { + value: ast::NodeValue::If(ast::IfNode { + condition: Box::new(ast::Node { + value: ast::NodeValue::BoolLiteral(false), + range: Range { + start: Position::new(1, 11), + end: Position::new(1, 16), + }, + }), + consequence: vec![], + alternative: vec![], + }), + range: Range { + start: Position::new(1, 7), + end: Position::new(2, 1), + }, + }], + }), + range: Range { + start: Position::new(0, 0), + end: Position::new(2, 1), + }, + }, + ), + ]; + + for (input, expected) in tests { + let program = parse(input)?; + + assert_eq!(program.statements.len(), 1); + assert_eq!(program.statements[0], expected); + } + + Ok(()) +} + #[test] fn precedence() -> Result<()> { let tests = [ diff --git a/parser/src/parser/validate.rs b/parser/src/parser/validate.rs index e1d1761..90a1c5f 100644 --- a/parser/src/parser/validate.rs +++ b/parser/src/parser/validate.rs @@ -1,6 +1,7 @@ use crate::{ ast::{HashLiteralPair, Node, NodeKind, NodeValue}, error::{Error, ErrorKind, Result}, + token::{Token, TokenKind}, }; pub fn validate_hash_literal(items: &[HashLiteralPair]) -> Result<()> { @@ -78,3 +79,17 @@ pub fn validate_node_kind(node: &Node, expected: NodeKind) -> Result<()> { Ok(()) } + +pub fn validate_token_kind(token: &Token, expected: TokenKind) -> Result<()> { + if token.kind != expected { + return Err(Error { + kind: ErrorKind::InvalidTokenKind { + expected, + got: token.kind.clone(), + }, + range: token.range, + }); + } + + Ok(()) +}