Skip to content

Commit

Permalink
Inferring closures types within higher order functions
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi authored and dean-starkware committed Jan 18, 2025
1 parent 739ecc2 commit 583014c
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 71 deletions.
2 changes: 1 addition & 1 deletion corelib/src/test/option_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ fn test_default_for_option() {
#[test]
fn test_option_some_map() {
let maybe_some_string: Option<ByteArray> = Option::Some("Hello, World!");
let maybe_some_len = maybe_some_string.map(|s: ByteArray| s.len());
let maybe_some_len = maybe_some_string.map(|s| s.len());
assert!(maybe_some_len == Option::Some(13));
}

Expand Down
8 changes: 4 additions & 4 deletions crates/cairo-lang-lowering/src/lower/test_data/for
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Statements:
(v16: core::array::Array::<core::felt252>, v17: @core::array::Array::<core::felt252>) <- snapshot(v15)
(v18: core::array::Span::<core::felt252>) <- core::array::ArrayToSpan::<core::felt252>::span(v17)
(v19: core::array::SpanIter::<core::felt252>) <- core::array::SpanIntoIterator::<core::felt252>::into_iter(v18)
(v21: core::array::SpanIter::<core::felt252>, v22: core::felt252, v20: ()) <- test::foo[expr30](v19, v0, v2)
(v21: core::array::SpanIter::<core::felt252>, v22: core::felt252, v20: ()) <- test::foo[expr38](v19, v0, v2)
End:
Return(v22)

Expand All @@ -69,7 +69,7 @@ Statements:
(v14: core::array::Array::<core::felt252>, v15: @core::array::Array::<core::felt252>) <- snapshot(v10)
(v16: core::array::Span::<core::felt252>) <- struct_construct(v15)
(v17: core::array::SpanIter::<core::felt252>) <- struct_construct(v16)
(v18: core::RangeCheck, v19: core::gas::GasBuiltin, v20: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[expr30](v0, v1, v17, v11, v13)
(v18: core::RangeCheck, v19: core::gas::GasBuiltin, v20: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[expr38](v0, v1, v17, v11, v13)
End:
Match(match_enum(v20) {
PanicResult::Ok(v21) => blk1,
Expand Down Expand Up @@ -111,7 +111,7 @@ End:
blk1:
Statements:
(v6: core::felt252) <- core::Felt252Add::add(v1, v2)
(v8: core::array::SpanIter::<core::felt252>, v9: core::felt252, v7: ()) <- test::foo[expr30](v4, v6, v2)
(v8: core::array::SpanIter::<core::felt252>, v9: core::felt252, v7: ()) <- test::foo[expr38](v4, v6, v2)
End:
Goto(blk3, {v9 -> v12, v8 -> v13, v7 -> v11})

Expand Down Expand Up @@ -195,7 +195,7 @@ End:
blk8:
Statements:
(v30: core::felt252) <- core::felt252_add(v3, v4)
(v31: core::RangeCheck, v32: core::gas::GasBuiltin, v33: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[expr30](v5, v6, v27, v30, v4)
(v31: core::RangeCheck, v32: core::gas::GasBuiltin, v33: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[expr38](v5, v6, v27, v30, v4)
End:
Return(v31, v32, v33)

Expand Down
20 changes: 10 additions & 10 deletions crates/cairo-lang-lowering/src/lower/test_data/loop
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ Statements:
(v1: test::T) <- struct_construct(v0)
(v2: test::S) <- struct_destructure(v1)
(v3: test::S, v4: @test::S) <- snapshot(v2)
(v5: ()) <- test::foo[expr7](v4)
(v5: ()) <- test::foo[expr8](v4)
(v6: ()) <- test::STT::f1oo(v4)
(v7: ()) <- struct_construct()
End:
Expand Down Expand Up @@ -1158,7 +1158,7 @@ Statements:
(v4: test::B, v5: @test::B) <- snapshot(v3)
(v6: core::integer::u32) <- struct_destructure(v4)
(v7: test::B) <- struct_construct(v6)
(v8: ()) <- test::foo[expr16](v6, v5)
(v8: ()) <- test::foo[expr18](v6, v5)
(v9: ()) <- struct_construct()
End:
Return(v9)
Expand Down Expand Up @@ -1214,7 +1214,7 @@ Parameters: v0: core::integer::u32, v1: @test::B
blk0 (root):
Statements:
() <- test::ex1(v1)
(v2: ()) <- test::foo[expr14](v0)
(v2: ()) <- test::foo[expr16](v0)
(v3: ()) <- struct_construct()
End:
Return(v3)
Expand Down Expand Up @@ -1282,7 +1282,7 @@ Statements:
(v1: test::B) <- struct_construct(v0)
(v2: test::A) <- struct_construct(v1)
(v3: test::B) <- struct_destructure(v2)
(v4: ()) <- test::foo[expr16](v3)
(v4: ()) <- test::foo[expr18](v3)
(v5: ()) <- struct_construct()
End:
Return(v5)
Expand Down Expand Up @@ -1339,7 +1339,7 @@ Statements:
() <- test::ex1(v0)
(v1: core::integer::u32) <- struct_destructure(v0)
(v2: core::integer::u32, v3: @core::integer::u32) <- snapshot(v1)
(v4: ()) <- test::foo[expr14](v3)
(v4: ()) <- test::foo[expr16](v3)
(v5: ()) <- struct_construct()
End:
Return(v5)
Expand Down Expand Up @@ -1399,7 +1399,7 @@ Statements:
(v0: core::felt252) <- 0
(v1: test::A) <- struct_construct(v0)
(v2: core::felt252) <- 0
(v4: test::A, v5: core::felt252, v3: ()) <- test::foo[expr19](v2, v1)
(v4: test::A, v5: core::felt252, v3: ()) <- test::foo[expr20](v2, v1)
(v6: ()) <- struct_construct()
End:
Return(v6)
Expand All @@ -1412,7 +1412,7 @@ Statements:
(v2: core::felt252) <- 0
(v3: core::felt252) <- 0
(v4: test::A) <- struct_construct(v2)
(v5: core::RangeCheck, v6: core::gas::GasBuiltin, v7: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr19](v0, v1, v3, v4)
(v5: core::RangeCheck, v6: core::gas::GasBuiltin, v7: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr20](v0, v1, v3, v4)
End:
Match(match_enum(v7) {
PanicResult::Ok(v8) => blk1,
Expand Down Expand Up @@ -1462,7 +1462,7 @@ Statements:
() <- test::use_a(v12)
(v13: core::felt252) <- 1
(v15: core::felt252, v14: ()) <- core::ops::arith::DeprecatedAddAssign::<core::felt252, core::Felt252AddEq>::add_assign(v2, v13)
(v17: test::A, v18: core::felt252, v16: ()) <- test::foo[expr19](v15, v11)
(v17: test::A, v18: core::felt252, v16: ()) <- test::foo[expr20](v15, v11)
End:
Goto(blk3, {v17 -> v21, v18 -> v22, v16 -> v20})

Expand Down Expand Up @@ -1513,7 +1513,7 @@ Statements:
() <- test::use_a(v16)
(v17: core::felt252) <- 1
(v18: core::felt252) <- core::felt252_add(v2, v17)
(v19: core::RangeCheck, v20: core::gas::GasBuiltin, v21: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr19](v4, v5, v18, v15)
(v19: core::RangeCheck, v20: core::gas::GasBuiltin, v21: core::panics::PanicResult::<(test::A, core::felt252, ())>) <- test::foo[expr20](v4, v5, v18, v15)
End:
Return(v19, v20, v21)

Expand Down Expand Up @@ -1584,7 +1584,7 @@ Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
blk0 (root):
Statements:
(v2: core::integer::u8) <- 0
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[expr9](v0, v1, v2)
(v3: core::RangeCheck, v4: core::gas::GasBuiltin, v5: core::panics::PanicResult::<(core::integer::u8, ())>) <- test::MyImpl::impl_in_trait[expr10](v0, v1, v2)
End:
Match(match_enum(v5) {
PanicResult::Ok(v6) => blk1,
Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-lowering/src/test_data/for
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Statements:
(v11: core::array::Array::<core::felt252>, v12: @core::array::Array::<core::felt252>) <- snapshot(v10)
(v13: core::array::Span::<core::felt252>) <- struct_construct(v12)
(v14: core::array::SpanIter::<core::felt252>) <- struct_construct(v13)
(v15: core::RangeCheck, v16: core::gas::GasBuiltin, v17: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, ())>) <- test::foo[expr26](v0, v1, v14)
(v15: core::RangeCheck, v16: core::gas::GasBuiltin, v17: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, ())>) <- test::foo[expr34](v0, v1, v14)
End:
Match(match_enum(v17) {
PanicResult::Ok(v18) => blk1,
Expand Down
2 changes: 1 addition & 1 deletion crates/cairo-lang-lowering/src/test_data/strings
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ Statements:
(v8: core::byte_array::ByteArray, v9: @core::byte_array::ByteArray) <- snapshot(v7)
(v10: @core::array::Array::<core::bytes_31::bytes31>, v11: @core::felt252, v12: @core::integer::u32) <- struct_destructure(v9)
(v13: core::array::Span::<core::bytes_31::bytes31>) <- struct_construct(v10)
(v14: core::RangeCheck, v15: core::gas::GasBuiltin, v16: core::panics::PanicResult::<(core::array::Span::<core::bytes_31::bytes31>, core::array::Array::<core::bytes_31::bytes31>, ())>) <- core::array::ArrayTCloneImpl::clone[expr14](v0, v1, v13, v4)
(v14: core::RangeCheck, v15: core::gas::GasBuiltin, v16: core::panics::PanicResult::<(core::array::Span::<core::bytes_31::bytes31>, core::array::Array::<core::bytes_31::bytes31>, ())>) <- core::array::ArrayTCloneImpl::clone[expr15](v0, v1, v13, v4)
End:
Match(match_enum(v16) {
PanicResult::Ok(v17) => blk1,
Expand Down
18 changes: 17 additions & 1 deletion crates/cairo-lang-semantic/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::diagnostic::SemanticDiagnosticKind;
use crate::expr::inference::{self, ImplVar, ImplVarId};
use crate::items::constant::{ConstCalcInfo, ConstValueId, Constant, ImplConstantId};
use crate::items::function_with_body::FunctionBody;
use crate::items::functions::{ImplicitPrecedence, InlineConfiguration};
use crate::items::functions::{GenericFunctionId, ImplicitPrecedence, InlineConfiguration};
use crate::items::generics::{GenericParam, GenericParamData, GenericParamsData};
use crate::items::imp::{
ImplId, ImplImplId, ImplLookupContext, ImplicitImplImplData, UninferredImpl,
Expand Down Expand Up @@ -1431,6 +1431,22 @@ pub trait SemanticGroup:
#[salsa::invoke(items::functions::concrete_function_signature)]
fn concrete_function_signature(&self, function_id: FunctionId) -> Maybe<semantic::Signature>;

/// Returns a mapping of closure types to their associated parameter types for a concrete
/// function.
#[salsa::invoke(items::functions::concrete_function_closure_params)]
fn concrete_function_closure_params(
&self,
function_id: FunctionId,
) -> Maybe<OrderedHashMap<semantic::TypeId, semantic::TypeId>>;

/// Returns a mapping of closure types to their associated parameter types for a generic
/// function.
#[salsa::invoke(items::functions::get_closure_params)]
fn get_closure_params(
&self,
generic_function_id: GenericFunctionId,
) -> Maybe<OrderedHashMap<TypeId, TypeId>>;

// Generic type.
// =============
/// Returns the generic params of a generic type.
Expand Down
82 changes: 65 additions & 17 deletions crates/cairo-lang-semantic/src/expr/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use cairo_lang_filesystem::ids::{FileKind, FileLongId, VirtualFile};
use cairo_lang_proc_macros::DebugWithDb;
use cairo_lang_syntax::node::ast::{
BinaryOperator, BlockOrIf, ClosureParamWrapper, ExprPtr, OptionReturnTypeClause, PatternListOr,
PatternStructParam, UnaryOperator,
PatternStructParam, TerminalIdentifier, UnaryOperator,
};
use cairo_lang_syntax::node::db::SyntaxGroup;
use cairo_lang_syntax::node::helpers::{GetIdentifier, PathSegmentEx};
Expand Down Expand Up @@ -66,7 +66,7 @@ use crate::expr::inference::{ImplVarTraitItemMappings, InferenceId};
use crate::items::constant::{ConstValue, resolve_const_expr_and_evaluate, validate_const_expr};
use crate::items::enm::SemanticEnumEx;
use crate::items::feature_kind::extract_item_feature_config;
use crate::items::functions::function_signature_params;
use crate::items::functions::{concrete_function_closure_params, function_signature_params};
use crate::items::imp::{ImplLookupContext, filter_candidate_traits, infer_impl_by_self};
use crate::items::modifiers::compute_mutability;
use crate::items::us::get_use_path_segments;
Expand Down Expand Up @@ -424,7 +424,7 @@ pub fn maybe_compute_expr_semantic(
ast::Expr::Indexed(expr) => compute_expr_indexed_semantic(ctx, expr),
ast::Expr::FixedSizeArray(expr) => compute_expr_fixed_size_array_semantic(ctx, expr),
ast::Expr::For(expr) => compute_expr_for_semantic(ctx, expr),
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr),
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr, None),
}
}

Expand Down Expand Up @@ -882,7 +882,7 @@ fn compute_expr_function_call_semantic(
let mut arg_types = vec![];
for arg_syntax in args_iter {
let stable_ptr = arg_syntax.stable_ptr();
let arg = compute_named_argument_clause(ctx, arg_syntax);
let arg = compute_named_argument_clause(ctx, arg_syntax, None);
if arg.2 != Mutability::Immutable {
return Err(ctx.diagnostics.report(stable_ptr, RefClosureArgument));
}
Expand Down Expand Up @@ -930,7 +930,7 @@ fn compute_expr_function_call_semantic(
let named_args: Vec<_> = args_syntax
.elements(syntax_db)
.into_iter()
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax))
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax, None))
.collect();
if named_args.len() != 1 {
return Err(ctx.diagnostics.report(syntax, WrongNumberOfArguments {
Expand Down Expand Up @@ -979,16 +979,21 @@ fn compute_expr_function_call_semantic(
let mut args_iter = args_syntax.elements(syntax_db).into_iter();
// Normal parameters
let mut named_args = vec![];
for _ in function_parameter_types(ctx, function)? {
let closure_params = concrete_function_closure_params(db, function)?;
for ty in function_parameter_types(ctx, function)? {
let Some(arg_syntax) = args_iter.next() else {
continue;
};
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(
ctx,
arg_syntax,
closure_params.get(&ty).cloned(),
));
}

// Maybe coupon
if let Some(arg_syntax) = args_iter.next() {
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
}

expr_function_call(ctx, function, named_args, syntax, syntax.stable_ptr().into())
Expand All @@ -1006,6 +1011,7 @@ fn compute_expr_function_call_semantic(
pub fn compute_named_argument_clause(
ctx: &mut ComputationContext<'_>,
arg_syntax: ast::Arg,
closure_param_types: Option<TypeId>,
) -> NamedArg {
let syntax_db = ctx.db.upcast();

Expand All @@ -1017,11 +1023,16 @@ pub fn compute_named_argument_clause(

let arg_clause = arg_syntax.arg_clause(syntax_db);
let (expr, arg_name_identifier) = match arg_clause {
ast::ArgClause::Unnamed(arg_unnamed) => {
(compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None)
}
ast::ArgClause::Named(arg_named) => (
compute_expr_semantic(ctx, &arg_named.value(syntax_db)),
ast::ArgClause::Unnamed(arg_unnamed) => handle_possible_closure_expr(
ctx,
&arg_unnamed.value(syntax_db),
closure_param_types,
None,
),
ast::ArgClause::Named(arg_named) => handle_possible_closure_expr(
ctx,
&arg_named.value(syntax_db),
closure_param_types,
Some(arg_named.name(syntax_db)),
),
ast::ArgClause::FieldInitShorthand(arg_field_init_shorthand) => {
Expand All @@ -1034,10 +1045,32 @@ pub fn compute_named_argument_clause(
(expr, Some(arg_name_identifier))
}
};

NamedArg(expr, arg_name_identifier, mutability)
}

/// Handles the semantic computation of a closure expression.
/// It processes a closure expression, computes its semantic model,
/// allocates it in the expression arena, and ensures that the closure's
/// parameter types are conformed if provided.
fn handle_possible_closure_expr(
ctx: &mut ComputationContext<'_>,
expr: &ast::Expr,
closure_param_types: Option<TypeId>,
arg_name: Option<TerminalIdentifier>,
) -> (ExprAndId, Option<TerminalIdentifier>) {
if let ast::Expr::Closure(expr_closure) = expr {
let expr = compute_expr_closure_semantic(ctx, expr_closure, closure_param_types);
let expr = wrap_maybe_with_missing(ctx, expr, expr_closure.stable_ptr().into());
let id = ctx.arenas.exprs.alloc(expr.clone());
(ExprAndId { expr, id }, arg_name)
} else {
let expr = compute_expr_semantic(ctx, expr);
let expr = wrap_maybe_with_missing(ctx, Ok(expr.expr.clone()), expr.stable_ptr());
let id = ctx.arenas.exprs.alloc(expr.clone());
(ExprAndId { expr, id }, arg_name)
}
}

pub fn compute_root_expr(
ctx: &mut ComputationContext<'_>,
syntax: &ast::ExprBlock,
Expand Down Expand Up @@ -1645,6 +1678,7 @@ fn compute_loop_body_semantic(
fn compute_expr_closure_semantic(
ctx: &mut ComputationContext<'_>,
syntax: &ast::ExprClosure,
param_types: Option<TypeId>,
) -> Maybe<Expr> {
ctx.are_closures_in_context = true;
let syntax_db = ctx.db.upcast();
Expand All @@ -1663,6 +1697,14 @@ fn compute_expr_closure_semantic(
} else {
vec![]
};
let closure_type =
TypeLongId::Tuple(params.iter().map(|param| param.ty).collect()).intern(new_ctx.db);
if let Some(param_types) = param_types {
if let Err(err_set) = new_ctx.resolver.inference().conform_ty(closure_type, param_types)
{
new_ctx.resolver.inference().consume_error_without_reporting(err_set);
}
}

params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| {
new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam);
Expand Down Expand Up @@ -2834,16 +2876,22 @@ fn method_call_expr(
// Self argument.
let mut named_args = vec![NamedArg(fixed_lexpr, None, mutability)];
// Other arguments.
for _ in function_parameter_types(ctx, function_id)?.skip(1) {
let closure_params: OrderedHashMap<TypeId, TypeId> =
concrete_function_closure_params(ctx.db, function_id)?;
for ty in function_parameter_types(ctx, function_id)?.skip(1) {
let Some(arg_syntax) = args_iter.next() else {
break;
};
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(
ctx,
arg_syntax,
closure_params.get(&ty).cloned(),
));
}

// Maybe coupon
if let Some(arg_syntax) = args_iter.next() {
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
}

expr_function_call(ctx, function_id, named_args, &expr, stable_ptr)
Expand Down
20 changes: 20 additions & 0 deletions crates/cairo-lang-semantic/src/expr/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,23 @@ error: Closure parameters cannot be references
--> lib.cairo:2:14
let _ = |ref a| {
^^^^^

//! > ==========================================================================

//! > Passing closures as args with less explicit typing.

//! > test_runner_name
test_function_diagnostics(expect_diagnostics: false)

//! > function
fn foo() -> Option<u32> {
let x: Option<Array<i32>> = Option::Some(array![1, 2, 3]);
x.map(|x| x.len())
}

//! > function_name
foo

//! > module_code

//! > expected_diagnostics
Loading

0 comments on commit 583014c

Please sign in to comment.