Skip to content

Commit

Permalink
feat: enhance runtime type cast and check for lambda arguments and re…
Browse files Browse the repository at this point in the history
…turn values

Signed-off-by: peefy <[email protected]>
  • Loading branch information
Peefy committed Jul 31, 2024
1 parent bbac702 commit 3e81acc
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 14 deletions.
54 changes: 47 additions & 7 deletions kclvm/compiler/src/codegen/llvm/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2179,9 +2179,21 @@ impl<'ctx> TypedResultWalker<'ctx> for LLVMCodeGenContext<'ctx> {
}
}
self.walk_arguments(&lambda_expr.args, args, kwargs);
let val = self
let mut val = self
.walk_stmts(&lambda_expr.body)
.expect(kcl_error::COMPILE_ERROR_MSG);
if let Some(ty) = &lambda_expr.return_ty {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
val = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
val,
type_annotation,
self.bool_value(false),
],
);
}
self.builder.build_return(Some(&val));
// Exist the function
self.builder.position_at_end(func_before_block);
Expand Down Expand Up @@ -2731,23 +2743,39 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
kwargs: BasicValueEnum<'ctx>,
) {
// Arguments names and defaults
let (arg_names, arg_defaults) = if let Some(args) = &arguments {
let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments {
let names = &args.node.args;
let types = &args.node.ty_list;
let defaults = &args.node.defaults;
(
names.iter().map(|identifier| &identifier.node).collect(),
types.iter().collect(),
defaults.iter().collect(),
)
} else {
(vec![], vec![])
(vec![], vec![], vec![])
};
// Default parameter values
for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) {
let arg_value = if let Some(value) = value {
for ((arg_name, arg_type), value) in
arg_names.iter().zip(&arg_types).zip(arg_defaults.iter())
{
let mut arg_value = if let Some(value) = value {
self.walk_expr(value).expect(kcl_error::COMPILE_ERROR_MSG)
} else {
self.none_value()
};
if let Some(ty) = arg_type {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
arg_value = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
arg_value,
type_annotation,
self.bool_value(false),
],
);
}
// Arguments are immutable, so we place them in different scopes.
self.store_argument_in_current_scope(&arg_name.get_name());
self.walk_identifier_with_ctx(arg_name, &ast::ExprContext::Store, Some(arg_value))
Expand All @@ -2756,7 +2784,7 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
// for loop in 0..argument_len in LLVM begin
let argument_len = self.build_call(&ApiFunc::kclvm_list_len.name(), &[args]);
let end_block = self.append_block("");
for (i, arg_name) in arg_names.iter().enumerate() {
for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() {
// Positional arguments
let is_in_range = self.builder.build_int_compare(
IntPredicate::ULT,
Expand All @@ -2768,14 +2796,26 @@ impl<'ctx> LLVMCodeGenContext<'ctx> {
self.builder
.build_conditional_branch(is_in_range, next_block, end_block);
self.builder.position_at_end(next_block);
let arg_value = self.build_call(
let mut arg_value = self.build_call(
&ApiFunc::kclvm_list_get_option.name(),
&[
self.current_runtime_ctx_ptr(),
args,
self.native_int_value(i as i32),
],
);
if let Some(ty) = arg_type {
let type_annotation = self.native_global_string_value(&ty.node.to_string());
arg_value = self.build_call(
&ApiFunc::kclvm_convert_collection_value.name(),
&[
self.current_runtime_ctx_ptr(),
arg_value,
type_annotation,
self.bool_value(false),
],
);
}
self.store_variable(&arg_name.names[0].node, arg_value);
}
// for loop in 0..argument_len in LLVM end
Expand Down
6 changes: 5 additions & 1 deletion kclvm/evaluator/src/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use kclvm_runtime::ValueRef;
use scopeguard::defer;

use crate::proxy::Proxy;
use crate::ty::type_pack_and_check;
use crate::Evaluator;
use crate::{error as kcl_error, EvalContext};

Expand Down Expand Up @@ -125,8 +126,11 @@ pub fn func_body(
}
// Evaluate arguments and keyword arguments and store values to local variables.
s.walk_arguments(&ctx.node.args, args, kwargs);
let result = s
let mut result = s
.walk_stmts(&ctx.node.body)
.expect(kcl_error::RUNTIME_ERROR_MSG);
if let Some(ty) = &ctx.node.return_ty {
result = type_pack_and_check(s, &result, vec![&ty.node.to_string()], false);
}
result
}
24 changes: 18 additions & 6 deletions kclvm/evaluator/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1449,23 +1449,31 @@ impl<'ctx> Evaluator<'ctx> {
kwargs: &ValueRef,
) {
// Arguments names and defaults
let (arg_names, arg_defaults) = if let Some(args) = &arguments {
let (arg_names, arg_types, arg_defaults) = if let Some(args) = &arguments {
let names = &args.node.args;
let types = &args.node.ty_list;
let defaults = &args.node.defaults;
(
names.iter().map(|identifier| &identifier.node).collect(),
types.iter().collect(),
defaults.iter().collect(),
)
} else {
(vec![], vec![])
(vec![], vec![], vec![])
};
// Default parameter values
for (arg_name, value) in arg_names.iter().zip(arg_defaults.iter()) {
let arg_value = if let Some(value) = value {
for ((arg_name, arg_type), value) in
arg_names.iter().zip(&arg_types).zip(arg_defaults.iter())
{
let mut arg_value = if let Some(value) = value {
self.walk_expr(value).expect(kcl_error::RUNTIME_ERROR_MSG)
} else {
self.none_value()
};
if let Some(ty) = arg_type {
arg_value =
type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false);
}
// Arguments are immutable, so we place them in different scopes.
let name = arg_name.get_name();
self.store_argument_in_current_scope(&name);
Expand All @@ -1477,14 +1485,18 @@ impl<'ctx> Evaluator<'ctx> {
}
// Positional arguments
let argument_len = args.len();
for (i, arg_name) in arg_names.iter().enumerate() {
for (i, (arg_name, arg_type)) in arg_names.iter().zip(arg_types).enumerate() {
// Positional arguments
let is_in_range = i < argument_len;
if is_in_range {
let arg_value = match args.list_get_option(i as isize) {
let mut arg_value = match args.list_get_option(i as isize) {
Some(v) => v,
None => self.undefined_value(),
};
if let Some(ty) = arg_type {
arg_value =
type_pack_and_check(self, &arg_value, vec![&ty.node.to_string()], false);
}
self.store_variable(&arg_name.names[0].node, arg_value);
} else {
break;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
schema ProviderFamily:
version: str
marketplace: bool = True

providerFamily = lambda family: ProviderFamily -> ProviderFamily {
family
}

v = providerFamily({
version: "1.6.0"
})
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
v:
version: '1.6.0'
marketplace: true
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
schema ProviderFamily:
version: str
marketplace: bool = True

providerFamily = lambda -> ProviderFamily {
{
version: "1.6.0"
}
}

v = providerFamily()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
v:
version: '1.6.0'
marketplace: true

0 comments on commit 3e81acc

Please sign in to comment.