diff --git a/.gitignore b/.gitignore index e0f63a0998..ea323cb541 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ /target **/*.rs.bk **/*.un~ -**/*.toml~ \ No newline at end of file +**/*.toml~ +*.ir +*.o diff --git a/examples/function_with_return.st b/examples/function_with_return.st new file mode 100644 index 0000000000..e1304b1500 --- /dev/null +++ b/examples/function_with_return.st @@ -0,0 +1,11 @@ +FUNCTION smaller_than_ten: BOOL + VAR_INPUT + n : INT; + END_VAR + + IF n < 10 THEN + smaller_than_ten := TRUE; + RETURN; + END_IF; + smaller_than_ten := FALSE; +END_FUNCTION diff --git a/src/ast.rs b/src/ast.rs index 134fc1e116..738520340f 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -523,6 +523,9 @@ pub enum Statement { CaseCondition { condition: Box, }, + ReturnStatement { + location: SourceRange, + }, } impl Debug for Statement { @@ -729,6 +732,7 @@ impl Debug for Statement { .debug_struct("CaseCondition") .field("condition", condition) .finish(), + Statement::ReturnStatement { .. } => f.debug_struct("ReturnStatement").finish(), } } } @@ -807,6 +811,7 @@ impl Statement { } Statement::MultipliedStatement { location, .. } => location.clone(), Statement::CaseCondition { condition } => condition.get_location(), + Statement::ReturnStatement { location } => location.clone(), } } } diff --git a/src/codegen/generators/pou_generator.rs b/src/codegen/generators/pou_generator.rs index b7ff754302..43a6a97e6f 100644 --- a/src/codegen/generators/pou_generator.rs +++ b/src/codegen/generators/pou_generator.rs @@ -138,6 +138,8 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { let statement_gen = StatementCodeGenerator::new( &self.llvm, self.index, + self, + implementation.pou_type, &local_index, &function_context, ); @@ -254,7 +256,7 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { /// generates the function's return statement only if the given pou_type is a `PouType::Function` /// /// a function returns the value of the local variable that has the function's name - fn generate_return_statement( + pub fn generate_return_statement( &self, function_context: &FunctionContext<'ink>, local_index: &LlvmTypedIndex<'ink>, diff --git a/src/codegen/generators/statement_generator.rs b/src/codegen/generators/statement_generator.rs index 0aee562fcb..3906c8e4d9 100644 --- a/src/codegen/generators/statement_generator.rs +++ b/src/codegen/generators/statement_generator.rs @@ -1,7 +1,9 @@ // Copyright (c) 2020 Ghaith Hachem and Mathias Rieder use std::ops::Range; +use super::pou_generator::PouGenerator; use super::{expression_generator::ExpressionCodeGenerator, llvm::Llvm}; +use crate::ast::PouType; use crate::codegen::LlvmTypedIndex; use crate::typesystem::{RANGE_CHECK_LS_FN, RANGE_CHECK_LU_FN, RANGE_CHECK_S_FN, RANGE_CHECK_U_FN}; use crate::{ast::SourceRange, codegen::llvm_typesystem::cast_if_needed}; @@ -31,6 +33,8 @@ pub struct FunctionContext<'a> { pub struct StatementCodeGenerator<'a, 'b> { llvm: &'b Llvm<'a>, index: &'b Index, + pou_generator: &'b PouGenerator<'a, 'b>, + pou_type: PouType, llvm_index: &'b LlvmTypedIndex<'a>, function_context: &'b FunctionContext<'a>, @@ -43,12 +47,16 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { pub fn new( llvm: &'b Llvm<'a>, index: &'b Index, + pou_generator: &'b PouGenerator<'a, 'b>, + pou_type: PouType, llvm_index: &'b LlvmTypedIndex<'a>, linking_context: &'b FunctionContext<'a>, ) -> StatementCodeGenerator<'a, 'b> { StatementCodeGenerator { llvm, index, + pou_generator, + pou_type, llvm_index, function_context: linking_context, load_prefix: "load_".to_string(), @@ -119,6 +127,14 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { } => { self.generate_case_statement(selector, case_blocks, else_block)?; } + Statement::ReturnStatement { location } => { + self.pou_generator.generate_return_statement( + self.function_context, + self.llvm_index, + self.pou_type, + Some(location.clone()), + )?; + } _ => { self.create_expr_generator() .generate_expression(statement)?; diff --git a/src/codegen/tests/code_gen_tests.rs b/src/codegen/tests/code_gen_tests.rs index d07c93d26b..8a3127b9d8 100644 --- a/src/codegen/tests/code_gen_tests.rs +++ b/src/codegen/tests/code_gen_tests.rs @@ -3797,6 +3797,85 @@ fn nested_array_access() { assert_eq!(result, expected); } +#[test] +fn returning_early_in_function() { + let result = codegen!( + " + FUNCTION smaller_than_ten: INT + VAR_INPUT n : SINT; END_VAR + IF n < 10 THEN + RETURN; + END_IF; + END_FUNCTION + " + ); + + let expected = r#"; ModuleID = 'main' +source_filename = "main" + +%smaller_than_ten_interface = type { i8 } + +define i16 @smaller_than_ten(%smaller_than_ten_interface* %0) { +entry: + %n = getelementptr inbounds %smaller_than_ten_interface, %smaller_than_ten_interface* %0, i32 0, i32 0 + %smaller_than_ten = alloca i16, align 2 + %load_n = load i8, i8* %n, align 1 + %1 = sext i8 %load_n to i32 + %tmpVar = icmp slt i32 %1, 10 + br i1 %tmpVar, label %condition_body, label %continue + +condition_body: ; preds = %entry + %smaller_than_ten_ret = load i16, i16* %smaller_than_ten, align 2 + ret i16 %smaller_than_ten_ret + br label %continue + +continue: ; preds = %condition_body, %entry + %smaller_than_ten_ret1 = load i16, i16* %smaller_than_ten, align 2 + ret i16 %smaller_than_ten_ret1 +} +"#; + + assert_eq!(result, expected); +} + +#[test] +fn returning_early_in_function_block() { + let result = codegen!( + " + FUNCTION_BLOCK abcdef + VAR_INPUT n : SINT; END_VAR + IF n < 10 THEN + RETURN; + END_IF; + END_FUNCTION_BLOCK + " + ); + + let expected = r#"; ModuleID = 'main' +source_filename = "main" + +%abcdef_interface = type { i8 } + +define void @abcdef(%abcdef_interface* %0) { +entry: + %n = getelementptr inbounds %abcdef_interface, %abcdef_interface* %0, i32 0, i32 0 + %load_n = load i8, i8* %n, align 1 + %1 = sext i8 %load_n to i32 + %tmpVar = icmp slt i32 %1, 10 + br i1 %tmpVar, label %condition_body, label %continue + +condition_body: ; preds = %entry + ret void + br label %continue + +continue: ; preds = %condition_body, %entry + ret void +} +"#; + + assert_eq!(result, expected); +} + #[test] fn accessing_nested_array_in_struct() { let result = codegen!( diff --git a/src/lexer/tokens.rs b/src/lexer/tokens.rs index 4b115e0a4d..21d5479f2b 100644 --- a/src/lexer/tokens.rs +++ b/src/lexer/tokens.rs @@ -221,6 +221,9 @@ pub enum Token { #[token("CASE")] KeywordCase, + #[token("RETURN")] + KeywordReturn, + #[token("ARRAY")] KeywordArray, diff --git a/src/parser/control_parser.rs b/src/parser/control_parser.rs index 02488581e5..cb89df3228 100644 --- a/src/parser/control_parser.rs +++ b/src/parser/control_parser.rs @@ -17,10 +17,17 @@ pub fn parse_control_statement(lexer: &mut ParseSession) -> Statement { KeywordWhile => parse_while_statement(lexer), KeywordRepeat => parse_repeat_statement(lexer), KeywordCase => parse_case_statement(lexer), + KeywordReturn => parse_return_statement(lexer), _ => parse_statement(lexer), } } +fn parse_return_statement(lexer: &mut ParseSession) -> Statement { + let location = lexer.location(); + lexer.advance(); + Statement::ReturnStatement { location } +} + fn parse_if_statement(lexer: &mut ParseSession) -> Statement { let start = lexer.range().start; lexer.advance(); //If diff --git a/src/parser/tests/control_parser_tests.rs b/src/parser/tests/control_parser_tests.rs index f3e33418d3..4a315de41e 100644 --- a/src/parser/tests/control_parser_tests.rs +++ b/src/parser/tests/control_parser_tests.rs @@ -33,6 +33,16 @@ fn if_statement() { assert_eq!(ast_string, expected_ast); } +#[test] +fn test_return_statement() { + let lexer = super::lex("PROGRAM ret RETURN END_PROGRAM"); + let result = parse(lexer).0; + let prg = &result.implementations[0]; + let stmt = &prg.statements[0]; + + assert_eq!(format!("{:?}", stmt), "ReturnStatement"); +} + #[test] fn if_else_statement_with_expressions() { let lexer = super::lex( diff --git a/tests/correctness/control_flow.rs b/tests/correctness/control_flow.rs index 5e24e40268..d162cd7a13 100644 --- a/tests/correctness/control_flow.rs +++ b/tests/correctness/control_flow.rs @@ -100,6 +100,30 @@ fn adding_through_conditions_to_function_return() { assert_eq!(res, 100); } +#[test] +fn early_return_test() { + #[allow(dead_code)] + #[repr(C)] + struct MainType { + ret: i32, + } + + let function = r#" + FUNCTION main : DINT + main := 100; + // Windows does not like multiple returns in a + // row. That's why we wrap it inside a dummy IF. + IF TRUE THEN + RETURN + END_IF; + main := 200; + END_FUNCTION + "#; + + let (res, _) = compile_and_run(function.to_string(), &mut MainType { ret: 0 }); + assert_eq!(res, 100); +} + #[test] fn for_loop_and_increment_10_times() { #[allow(dead_code)]