From 1e88fe5aa4d139c3357d41a4d720493e43b8b3b5 Mon Sep 17 00:00:00 2001 From: Volkan Date: Thu, 7 Dec 2023 10:17:01 +0100 Subject: [PATCH] fix(codegen): Convert pointer to actual type when passing by-value (#1039) We were missing an edge-case where an argument is passed into a function A by-ref (INOUT) which in turn is passed into a function B by-val (INPUT) which generated the following IR for a DINT array ``` %call = call i32 @function1(i32 %load_x)` // we want `@function1([5 x i32] %load_x)` however ``` This commit fixes it, in that the commented IR is generated - specifically before passing the argument into function B it is bit-cast into its actual type. --- .../generators/expression_generator.rs | 47 ++++++++++++++-- src/codegen/tests/function_tests.rs | 33 ++++++++++++ ...ests__argument_fed_by_ref_then_by_val.snap | 49 +++++++++++++++++ src/typesystem.rs | 4 ++ tests/correctness/functions.rs | 53 +++++++++++++++++++ 5 files changed, 181 insertions(+), 5 deletions(-) create mode 100644 src/codegen/tests/snapshots/rusty__codegen__tests__function_tests__argument_fed_by_ref_then_by_val.snap diff --git a/src/codegen/generators/expression_generator.rs b/src/codegen/generators/expression_generator.rs index b0e279a18e..cfaf1f9be9 100644 --- a/src/codegen/generators/expression_generator.rs +++ b/src/codegen/generators/expression_generator.rs @@ -168,6 +168,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { // we trust that the validator only passed us valid parameters (so left & right should be same type) return self.generate_expression(statement); } + let v = self .generate_expression_value(expression)? .as_r_value(self.llvm, self.get_load_name(expression)) @@ -763,12 +764,48 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { type_name: &str, param_statement: &AstNode, ) -> Result, Diagnostic> { - Ok(match self.index.find_effective_type_by_name(type_name) { - Some(type_info) if type_info.information.is_string() => { - self.generate_string_argument(type_info, param_statement)? + let Some(type_info) = self.index.find_effective_type_by_name(type_name) else { + return self.generate_expression(param_statement); + }; + + if type_info.is_string() { + return self.generate_string_argument(type_info, param_statement); + } + + // https://github.com/PLC-lang/rusty/issues/1037: + // This if-statement covers the case where we want to convert a pointer into its actual + // type, e.g. if an argument is passed into a function A by-ref (INOUT) which in turn is + // passed into another function B by-val (INPUT) then the pointer argument in function A has + // to be bit-cast into its actual type before passing it into function B. + if type_info.is_aggregate_type() && !type_info.is_vla() { + let deref = self.generate_expression_value(param_statement)?; + + if deref.get_basic_value_enum().is_pointer_value() { + let ty = self.llvm_index.get_associated_type(type_name)?; + let cast = self.llvm.builder.build_bitcast( + deref.get_basic_value_enum(), + ty.ptr_type(AddressSpace::from(ADDRESS_SPACE_GENERIC)), + "", + ); + + let load = self.llvm.builder.build_load( + cast.into_pointer_value(), + &self.get_load_name(param_statement).unwrap_or_default(), + ); + + if let Some(target_ty) = self.annotations.get_type_hint(param_statement, self.index) { + let actual_ty = self.annotations.get_type_or_void(param_statement, self.index); + let annotation = self.annotations.get(param_statement); + + return Ok(cast_if_needed!(self, target_ty, actual_ty, load, annotation)); + } + + return Ok(load); } - _ => self.generate_expression(param_statement)?, - }) + } + + // Fallback + self.generate_expression(param_statement) } /// Before passing a string to a function, it is copied to a new string with the diff --git a/src/codegen/tests/function_tests.rs b/src/codegen/tests/function_tests.rs index a01d117879..d12f525b5e 100644 --- a/src/codegen/tests/function_tests.rs +++ b/src/codegen/tests/function_tests.rs @@ -370,3 +370,36 @@ fn return_variable_in_nested_call() { // we want a call passing the return-variable as apointer (actually the adress as a LWORD) insta::assert_snapshot!(codegen(src)); } + +#[test] +fn argument_fed_by_ref_then_by_val() { + let result = codegen( + " + TYPE MyType : ARRAY[1..5] OF DWORD; END_TYPE + + FUNCTION main : DINT + VAR + arr : MyType; + END_VAR + + fn_by_ref(arr); + END_FUNCTION + + FUNCTION fn_by_ref : DINT + VAR_IN_OUT + arg_by_ref : MyType; + END_VAR + + fn_by_val(arg_by_ref); + END_FUNCTION + + FUNCTION fn_by_val : DINT + VAR_INPUT + arg_by_val : MyType; + END_VAR + END_FUNCTION + ", + ); + + insta::assert_snapshot!(result) +} diff --git a/src/codegen/tests/snapshots/rusty__codegen__tests__function_tests__argument_fed_by_ref_then_by_val.snap b/src/codegen/tests/snapshots/rusty__codegen__tests__function_tests__argument_fed_by_ref_then_by_val.snap new file mode 100644 index 0000000000..d382b75e49 --- /dev/null +++ b/src/codegen/tests/snapshots/rusty__codegen__tests__function_tests__argument_fed_by_ref_then_by_val.snap @@ -0,0 +1,49 @@ +--- +source: src/codegen/tests/function_tests.rs +expression: result +--- +; ModuleID = 'main' +source_filename = "main" + +define i32 @main() { +entry: + %main = alloca i32, align 4 + %arr = alloca [5 x i32], align 4 + %0 = bitcast [5 x i32]* %arr to i8* + call void @llvm.memset.p0i8.i64(i8* align 1 %0, i8 0, i64 ptrtoint ([5 x i32]* getelementptr ([5 x i32], [5 x i32]* null, i32 1) to i64), i1 false) + store i32 0, i32* %main, align 4 + %1 = bitcast [5 x i32]* %arr to i32* + %call = call i32 @fn_by_ref(i32* %1) + %main_ret = load i32, i32* %main, align 4 + ret i32 %main_ret +} + +define i32 @fn_by_ref(i32* %0) { +entry: + %fn_by_ref = alloca i32, align 4 + %arg_by_ref = alloca i32*, align 8 + store i32* %0, i32** %arg_by_ref, align 8 + store i32 0, i32* %fn_by_ref, align 4 + %deref = load i32*, i32** %arg_by_ref, align 8 + %1 = bitcast i32* %deref to [5 x i32]* + %load_arg_by_ref = load [5 x i32], [5 x i32]* %1, align 4 + %call = call i32 @fn_by_val([5 x i32] %load_arg_by_ref) + %fn_by_ref_ret = load i32, i32* %fn_by_ref, align 4 + ret i32 %fn_by_ref_ret +} + +define i32 @fn_by_val([5 x i32] %0) { +entry: + %fn_by_val = alloca i32, align 4 + %arg_by_val = alloca [5 x i32], align 4 + store [5 x i32] %0, [5 x i32]* %arg_by_val, align 4 + store i32 0, i32* %fn_by_val, align 4 + %fn_by_val_ret = load i32, i32* %fn_by_val, align 4 + ret i32 %fn_by_val_ret +} + +; Function Attrs: argmemonly nofree nounwind willreturn writeonly +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg) #0 + +attributes #0 = { argmemonly nofree nounwind willreturn writeonly } + diff --git a/src/typesystem.rs b/src/typesystem.rs index 2a9bd3f4e4..e97cdd34d7 100644 --- a/src/typesystem.rs +++ b/src/typesystem.rs @@ -176,6 +176,10 @@ impl DataType { self.get_type_information().is_aggregate() } + pub fn is_string(&self) -> bool { + self.get_type_information().is_string() + } + pub fn get_nature(&self) -> TypeNature { self.nature } diff --git a/tests/correctness/functions.rs b/tests/correctness/functions.rs index 9fb9cc106d..3ab64d6154 100644 --- a/tests/correctness/functions.rs +++ b/tests/correctness/functions.rs @@ -1243,3 +1243,56 @@ fn sizeof_len() { assert_eq!(13, res); } + +#[test] +fn argument_passed_by_ref_then_by_val() { + #[repr(C)] + struct MainType { + arr: [i32; 5], + } + + let source = r" + TYPE MyType : ARRAY[1..5] OF DINT; END_TYPE + + PROGRAM main + VAR + arr : MyType; + END_VAR + + fn_by_ref(arr); + END_PROGRAM + + FUNCTION fn_by_ref : DINT + VAR_IN_OUT + arg_by_ref : MyType; + END_VAR + + // These SHOULD modify the underlying array passed from main + arg_by_ref[1] := 1; + arg_by_ref[2] := 2; + arg_by_ref[3] := 3; + arg_by_ref[4] := 4; + arg_by_ref[5] := 5; + + fn_by_val(arg_by_ref); + END_FUNCTION + + FUNCTION fn_by_val : DINT + VAR_INPUT + arg_by_val : MyType; + END_VAR + + // These should NOT modify the underlying array passed from main + arg_by_val[1] := 10; + arg_by_val[2] := 20; + arg_by_val[3] := 30; + arg_by_val[4] := 40; + arg_by_val[5] := 50; + END_FUNCTION + "; + + let mut maintype = MainType { arr: [0; 5] }; + let _: i32 = compile_and_run(source, &mut maintype); + + assert_eq!(maintype.arr, [1, 2, 3, 4, 5]); +}