From 3e81accf042ce8c0b4a95ba542ab26e9bb78359c Mon Sep 17 00:00:00 2001 From: peefy Date: Wed, 31 Jul 2024 16:04:26 +0800 Subject: [PATCH] feat: enhance runtime type cast and check for lambda arguments and return values Signed-off-by: peefy --- kclvm/compiler/src/codegen/llvm/node.rs | 54 ++++++++++++++++--- kclvm/evaluator/src/func.rs | 6 ++- kclvm/evaluator/src/node.rs | 24 ++++++--- .../type_annotation_schema_2/main.k | 11 ++++ .../type_annotation_schema_2/stdout.golden | 3 ++ .../type_annotation_schema_3/main.k | 11 ++++ .../type_annotation_schema_3/stdout.golden | 3 ++ 7 files changed, 98 insertions(+), 14 deletions(-) create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_2/main.k create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_3/main.k create mode 100644 test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden diff --git a/kclvm/compiler/src/codegen/llvm/node.rs b/kclvm/compiler/src/codegen/llvm/node.rs index 24f392f43..e3353a184 100644 --- a/kclvm/compiler/src/codegen/llvm/node.rs +++ b/kclvm/compiler/src/codegen/llvm/node.rs @@ -2179,9 +2179,21 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> { } } self.walk_arguments(&lambda_expr.args, args, kwargs); - let val = self + let mut val = self .walk_stmts(&lambda_expr.body) .expect(kcl_error::COMPILE_ERROR_MSG); + if let Some(ty) = &lambda_expr.return_ty { + let type_annotation = self.native_global_string_value(&ty.node.to_string()); + val = self.build_call( + &ApiFunc::kclvm_convert_collection_value.name(), + &[ + self.current_runtime_ctx_ptr(), + val, + type_annotation, + self.bool_value(false), + ], + ); + } self.builder.build_return(Some(&val)); // Exist the function self.builder.position_at_end(func_before_block); @@ -2731,23 +2743,39 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { kwargs: BasicValueEnum<'ctx>, ) { // Arguments names and defaults - let (arg_names, arg_defaults) = if let Some(args) = &arguments { + let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments { let names = &args.node.args; + let types = &args.node.ty_list; let defaults = &args.node.defaults; ( names.iter().map(|identifier| &identifier.node).collect(), + types.iter().collect(), defaults.iter().collect(), ) } else { - (vec![], vec![]) + (vec![], vec![], vec![]) }; // Default parameter values - for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) { - let arg_value = if let Some(value) = value { + for ((arg_name, arg_type), value) in + arg_names.iter().zip(&arg_types).zip(arg_defaults.iter()) + { + let mut arg_value = if let Some(value) = value { self.walk_expr(value).expect(kcl_error::COMPILE_ERROR_MSG) } else { self.none_value() }; + if let Some(ty) = arg_type { + let type_annotation = self.native_global_string_value(&ty.node.to_string()); + arg_value = self.build_call( + &ApiFunc::kclvm_convert_collection_value.name(), + &[ + self.current_runtime_ctx_ptr(), + arg_value, + type_annotation, + self.bool_value(false), + ], + ); + } // Arguments are immutable, so we place them in different scopes. self.store_argument_in_current_scope(&arg_name.get_name()); self.walk_identifier_with_ctx(arg_name, &ast::ExprContext::Store, Some(arg_value)) @@ -2756,7 +2784,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { // for loop in 0..argument_len in LLVM begin let argument_len = self.build_call(&ApiFunc::kclvm_list_len.name(), &[args]); let end_block = self.append_block(""); - for (i, arg_name) in arg_names.iter().enumerate() { + for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() { // Positional arguments let is_in_range = self.builder.build_int_compare( IntPredicate::ULT, @@ -2768,7 +2796,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { self.builder .build_conditional_branch(is_in_range, next_block, end_block); self.builder.position_at_end(next_block); - let arg_value = self.build_call( + let mut arg_value = self.build_call( &ApiFunc::kclvm_list_get_option.name(), &[ self.current_runtime_ctx_ptr(), @@ -2776,6 +2804,18 @@ impl<'ctx> LLVMCodeGenContext<'ctx> { self.native_int_value(i as i32), ], ); + if let Some(ty) = arg_type { + let type_annotation = self.native_global_string_value(&ty.node.to_string()); + arg_value = self.build_call( + &ApiFunc::kclvm_convert_collection_value.name(), + &[ + self.current_runtime_ctx_ptr(), + arg_value, + type_annotation, + self.bool_value(false), + ], + ); + } self.store_variable(&arg_name.names[0].node, arg_value); } // for loop in 0..argument_len in LLVM end diff --git a/kclvm/evaluator/src/func.rs b/kclvm/evaluator/src/func.rs index 978130765..f6e9ef65e 100644 --- a/kclvm/evaluator/src/func.rs +++ b/kclvm/evaluator/src/func.rs @@ -8,6 +8,7 @@ use kclvm_runtime::ValueRef; use scopeguard::defer; use crate::proxy::Proxy; +use crate::ty::type_pack_and_check; use crate::Evaluator; use crate::{error as kcl_error, EvalContext}; @@ -125,8 +126,11 @@ pub fn func_body( } // Evaluate arguments and keyword arguments and store values to local variables. s.walk_arguments(&ctx.node.args, args, kwargs); - let result = s + let mut result = s .walk_stmts(&ctx.node.body) .expect(kcl_error::RUNTIME_ERROR_MSG); + if let Some(ty) = &ctx.node.return_ty { + result = type_pack_and_check(s, &result, vec![&ty.node.to_string()], false); + } result } diff --git a/kclvm/evaluator/src/node.rs b/kclvm/evaluator/src/node.rs index 70ea3b52b..1e02094ec 100644 --- a/kclvm/evaluator/src/node.rs +++ b/kclvm/evaluator/src/node.rs @@ -1449,23 +1449,31 @@ impl<'ctx> Evaluator<'ctx> { kwargs: &ValueRef, ) { // Arguments names and defaults - let (arg_names, arg_defaults) = if let Some(args) = &arguments { + let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments { let names = &args.node.args; + let types = &args.node.ty_list; let defaults = &args.node.defaults; ( names.iter().map(|identifier| &identifier.node).collect(), + types.iter().collect(), defaults.iter().collect(), ) } else { - (vec![], vec![]) + (vec![], vec![], vec![]) }; // Default parameter values - for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) { - let arg_value = if let Some(value) = value { + for ((arg_name, arg_type), value) in + arg_names.iter().zip(&arg_types).zip(arg_defaults.iter()) + { + let mut arg_value = if let Some(value) = value { self.walk_expr(value).expect(kcl_error::RUNTIME_ERROR_MSG) } else { self.none_value() }; + if let Some(ty) = arg_type { + arg_value = + type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false); + } // Arguments are immutable, so we place them in different scopes. let name = arg_name.get_name(); self.store_argument_in_current_scope(&name); @@ -1477,14 +1485,18 @@ impl<'ctx> Evaluator<'ctx> { } // Positional arguments let argument_len = args.len(); - for (i, arg_name) in arg_names.iter().enumerate() { + for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() { // Positional arguments let is_in_range = i < argument_len; if is_in_range { - let arg_value = match args.list_get_option(i as isize) { + let mut arg_value = match args.list_get_option(i as isize) { Some(v) => v, None => self.undefined_value(), }; + if let Some(ty) = arg_type { + arg_value = + type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false); + } self.store_variable(&arg_name.names[0].node, arg_value); } else { break; diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_2/main.k b/test/grammar/schema/type_annotation/type_annotation_schema_2/main.k new file mode 100644 index 000000000..8a577f18a --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_2/main.k @@ -0,0 +1,11 @@ +schema ProviderFamily: + version: str + marketplace: bool = True + +providerFamily = lambda family: ProviderFamily -> ProviderFamily { + family +} + +v = providerFamily({ + version: "1.6.0" +}) diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden b/test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden new file mode 100644 index 000000000..67d7078bb --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_2/stdout.golden @@ -0,0 +1,3 @@ +v: + version: '1.6.0' + marketplace: true diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_3/main.k b/test/grammar/schema/type_annotation/type_annotation_schema_3/main.k new file mode 100644 index 000000000..544f1cb23 --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_3/main.k @@ -0,0 +1,11 @@ +schema ProviderFamily: + version: str + marketplace: bool = True + +providerFamily = lambda -> ProviderFamily { + { + version: "1.6.0" + } +} + +v = providerFamily() diff --git a/test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden b/test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden new file mode 100644 index 000000000..67d7078bb --- /dev/null +++ b/test/grammar/schema/type_annotation/type_annotation_schema_3/stdout.golden @@ -0,0 +1,3 @@ +v: + version: '1.6.0' + marketplace: true