diff --git a/parser/src/parser/mod.rs b/parser/src/parser/mod.rs index a9fe084..d8e3fc0 100644 --- a/parser/src/parser/mod.rs +++ b/parser/src/parser/mod.rs @@ -200,7 +200,14 @@ impl Parser<'_> { TokenKind::For => self.parse_for()?, TokenKind::Break => (ast::NodeValue::Break, range.end), TokenKind::Continue => (ast::NodeValue::Continue, range.end), - TokenKind::Return => todo!("parse return statement"), + TokenKind::Return => { + let token = self.next_token()?; + let node = self.parse_node(token, Precedence::Lowest)?; + validate_node_kind(&node, NodeKind::Expression)?; + + let end = node.range.end; + (ast::NodeValue::Return(Box::new(node)), end) + } TokenKind::Fn => self.parse_fn_literal()?, TokenKind::Use => { let token = self.next_token()?; diff --git a/parser/src/parser/test.rs b/parser/src/parser/test.rs index 55de514..94a1170 100644 --- a/parser/src/parser/test.rs +++ b/parser/src/parser/test.rs @@ -1132,6 +1132,47 @@ fn fn_literal_named() -> Result<()> { Ok(()) } +#[test] +fn return_statement() -> Result<()> { + let program = parse("return 1 + 2")?; + + assert_eq!(program.statements.len(), 1); + assert_eq!( + program.statements[0], + ast::Node { + value: ast::NodeValue::Return(Box::new(ast::Node { + value: ast::NodeValue::InfixOperator { + operator: ast::InfixOperatorKind::Add, + left: Box::new(ast::Node { + value: ast::NodeValue::IntegerLiteral(1), + range: Range { + start: Position::new(0, 7), + end: Position::new(0, 8) + } + }), + right: Box::new(ast::Node { + value: ast::NodeValue::IntegerLiteral(2), + range: Range { + start: Position::new(0, 11), + end: Position::new(0, 12) + } + }) + }, + range: Range { + start: Position::new(0, 7), + end: Position::new(0, 12) + } + })), + range: Range { + start: Position::new(0, 0), + end: Position::new(0, 12) + } + } + ); + + Ok(()) +} + #[test] fn errors() { let tests = [ @@ -1175,6 +1216,19 @@ fn errors() { }, }, ), + ( + "return continue", + Error { + kind: ErrorKind::InvalidNodeKind { + expected: ast::NodeKind::Expression, + got: ast::NodeKind::Statement, + }, + range: Range { + start: Position::new(0, 7), + end: Position::new(0, 15), + }, + }, + ), ]; for (input, expected) in tests { @@ -1217,6 +1271,8 @@ fn precedence() -> Result<()> { ), ("// comment", ""), ("//", ""), + ("return 1 + 1 * 2", "return (1 + (1 * 2))"), + ("return fn(){}", "return fn() {}"), ]; for (input, expected) in tests {