Skip to content

Commit

Permalink
Inferring closures types within higher order functions
Browse files Browse the repository at this point in the history
fixed diagnostic path bug

inferring closures types within higher order functions
  • Loading branch information
orizi authored and dean-starkware committed Jan 19, 2025
1 parent 9351827 commit 441054e
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 23 deletions.
2 changes: 1 addition & 1 deletion corelib/src/test/option_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ fn test_default_for_option() {
#[test]
fn test_option_some_map() {
let maybe_some_string: Option<ByteArray> = Option::Some("Hello, World!");
let maybe_some_len = maybe_some_string.map(|s: ByteArray| s.len());
let maybe_some_len = maybe_some_string.map(|s| s.len());
assert!(maybe_some_len == Option::Some(13));
}

Expand Down
18 changes: 17 additions & 1 deletion crates/cairo-lang-semantic/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use crate::diagnostic::SemanticDiagnosticKind;
use crate::expr::inference::{self, ImplVar, ImplVarId};
use crate::items::constant::{ConstCalcInfo, ConstValueId, Constant, ImplConstantId};
use crate::items::function_with_body::FunctionBody;
use crate::items::functions::{ImplicitPrecedence, InlineConfiguration};
use crate::items::functions::{GenericFunctionId, ImplicitPrecedence, InlineConfiguration};
use crate::items::generics::{GenericParam, GenericParamData, GenericParamsData};
use crate::items::imp::{
ImplId, ImplImplId, ImplLookupContext, ImplicitImplImplData, UninferredImpl,
Expand Down Expand Up @@ -1431,6 +1431,22 @@ pub trait SemanticGroup:
#[salsa::invoke(items::functions::concrete_function_signature)]
fn concrete_function_signature(&self, function_id: FunctionId) -> Maybe<semantic::Signature>;

/// Returns a mapping of closure types to their associated parameter types for a concrete
/// function.
#[salsa::invoke(items::functions::concrete_function_closure_params)]
fn concrete_function_closure_params(
&self,
function_id: FunctionId,
) -> Maybe<OrderedHashMap<semantic::TypeId, semantic::TypeId>>;

/// Returns a mapping of closure types to their associated parameter types for a generic
/// function.
#[salsa::invoke(items::functions::get_closure_params)]
fn get_closure_params(
&self,
generic_function_id: GenericFunctionId,
) -> Maybe<OrderedHashMap<TypeId, TypeId>>;

// Generic type.
// =============
/// Returns the generic params of a generic type.
Expand Down
74 changes: 59 additions & 15 deletions crates/cairo-lang-semantic/src/expr/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ use crate::expr::inference::{ImplVarTraitItemMappings, InferenceId};
use crate::items::constant::{ConstValue, resolve_const_expr_and_evaluate, validate_const_expr};
use crate::items::enm::SemanticEnumEx;
use crate::items::feature_kind::extract_item_feature_config;
use crate::items::functions::function_signature_params;
use crate::items::functions::{concrete_function_closure_params, function_signature_params};
use crate::items::imp::{ImplLookupContext, filter_candidate_traits, infer_impl_by_self};
use crate::items::modifiers::compute_mutability;
use crate::items::us::get_use_path_segments;
Expand Down Expand Up @@ -424,7 +424,7 @@ pub fn maybe_compute_expr_semantic(
ast::Expr::Indexed(expr) => compute_expr_indexed_semantic(ctx, expr),
ast::Expr::FixedSizeArray(expr) => compute_expr_fixed_size_array_semantic(ctx, expr),
ast::Expr::For(expr) => compute_expr_for_semantic(ctx, expr),
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr),
ast::Expr::Closure(expr) => compute_expr_closure_semantic(ctx, expr, None),
}
}

Expand Down Expand Up @@ -882,7 +882,7 @@ fn compute_expr_function_call_semantic(
let mut arg_types = vec![];
for arg_syntax in args_iter {
let stable_ptr = arg_syntax.stable_ptr();
let arg = compute_named_argument_clause(ctx, arg_syntax);
let arg = compute_named_argument_clause(ctx, arg_syntax, None);
if arg.2 != Mutability::Immutable {
return Err(ctx.diagnostics.report(stable_ptr, RefClosureArgument));
}
Expand Down Expand Up @@ -930,7 +930,7 @@ fn compute_expr_function_call_semantic(
let named_args: Vec<_> = args_syntax
.elements(syntax_db)
.into_iter()
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax))
.map(|arg_syntax| compute_named_argument_clause(ctx, arg_syntax, None))
.collect();
if named_args.len() != 1 {
return Err(ctx.diagnostics.report(syntax, WrongNumberOfArguments {
Expand Down Expand Up @@ -979,16 +979,21 @@ fn compute_expr_function_call_semantic(
let mut args_iter = args_syntax.elements(syntax_db).into_iter();
// Normal parameters
let mut named_args = vec![];
for _ in function_parameter_types(ctx, function)? {
let closure_params = concrete_function_closure_params(db, function)?;
for ty in function_parameter_types(ctx, function)? {
let Some(arg_syntax) = args_iter.next() else {
continue;
};
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(
ctx,
arg_syntax,
closure_params.get(&ty).copied(),
));
}

// Maybe coupon
if let Some(arg_syntax) = args_iter.next() {
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
}

expr_function_call(ctx, function, named_args, syntax, syntax.stable_ptr().into())
Expand All @@ -1006,6 +1011,7 @@ fn compute_expr_function_call_semantic(
pub fn compute_named_argument_clause(
ctx: &mut ComputationContext<'_>,
arg_syntax: ast::Arg,
closure_params_tuple_ty: Option<TypeId>,
) -> NamedArg {
let syntax_db = ctx.db.upcast();

Expand All @@ -1017,11 +1023,16 @@ pub fn compute_named_argument_clause(

let arg_clause = arg_syntax.arg_clause(syntax_db);
let (expr, arg_name_identifier) = match arg_clause {
ast::ArgClause::Unnamed(arg_unnamed) => {
(compute_expr_semantic(ctx, &arg_unnamed.value(syntax_db)), None)
}
ast::ArgClause::Unnamed(arg_unnamed) => (
handle_possible_closure_expr(
ctx,
&arg_unnamed.value(syntax_db),
closure_params_tuple_ty,
),
None,
),
ast::ArgClause::Named(arg_named) => (
compute_expr_semantic(ctx, &arg_named.value(syntax_db)),
handle_possible_closure_expr(ctx, &arg_named.value(syntax_db), closure_params_tuple_ty),
Some(arg_named.name(syntax_db)),
),
ast::ArgClause::FieldInitShorthand(arg_field_init_shorthand) => {
Expand All @@ -1034,10 +1045,28 @@ pub fn compute_named_argument_clause(
(expr, Some(arg_name_identifier))
}
};

NamedArg(expr, arg_name_identifier, mutability)
}

/// Handles the semantic computation of a closure expression.
/// It processes a closure expression, computes its semantic model,
/// allocates it in the expression arena, and ensures that the closure's
/// parameter types are conformed if provided.
fn handle_possible_closure_expr(
ctx: &mut ComputationContext<'_>,
expr: &ast::Expr,
closure_param_types: Option<TypeId>,
) -> ExprAndId {
if let ast::Expr::Closure(expr_closure) = expr {
let expr = compute_expr_closure_semantic(ctx, expr_closure, closure_param_types);
let expr = wrap_maybe_with_missing(ctx, expr, expr_closure.stable_ptr().into());
let id = ctx.arenas.exprs.alloc(expr.clone());
ExprAndId { expr, id }
} else {
compute_expr_semantic(ctx, expr)
}
}

pub fn compute_root_expr(
ctx: &mut ComputationContext<'_>,
syntax: &ast::ExprBlock,
Expand Down Expand Up @@ -1645,6 +1674,7 @@ fn compute_loop_body_semantic(
fn compute_expr_closure_semantic(
ctx: &mut ComputationContext<'_>,
syntax: &ast::ExprClosure,
params_tuple_ty: Option<TypeId>,
) -> Maybe<Expr> {
ctx.are_closures_in_context = true;
let syntax_db = ctx.db.upcast();
Expand All @@ -1663,6 +1693,14 @@ fn compute_expr_closure_semantic(
} else {
vec![]
};
let closure_type =
TypeLongId::Tuple(params.iter().map(|param| param.ty).collect()).intern(new_ctx.db);
if let Some(param_types) = params_tuple_ty {
if let Err(err_set) = new_ctx.resolver.inference().conform_ty(closure_type, param_types)
{
new_ctx.resolver.inference().consume_error_without_reporting(err_set);
}
}

params.iter().filter(|param| param.mutability == Mutability::Reference).for_each(|param| {
new_ctx.diagnostics.report(param.stable_ptr(ctx.db.upcast()), RefClosureParam);
Expand Down Expand Up @@ -2834,16 +2872,22 @@ fn method_call_expr(
// Self argument.
let mut named_args = vec![NamedArg(fixed_lexpr, None, mutability)];
// Other arguments.
for _ in function_parameter_types(ctx, function_id)?.skip(1) {
let closure_params: OrderedHashMap<TypeId, TypeId> =
concrete_function_closure_params(ctx.db, function_id)?;
for ty in function_parameter_types(ctx, function_id)?.skip(1) {
let Some(arg_syntax) = args_iter.next() else {
break;
};
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(
ctx,
arg_syntax,
closure_params.get(&ty).copied(),
));
}

// Maybe coupon
if let Some(arg_syntax) = args_iter.next() {
named_args.push(compute_named_argument_clause(ctx, arg_syntax));
named_args.push(compute_named_argument_clause(ctx, arg_syntax, None));
}

expr_function_call(ctx, function_id, named_args, &expr, stable_ptr)
Expand Down
20 changes: 20 additions & 0 deletions crates/cairo-lang-semantic/src/expr/test_data/closure
Original file line number Diff line number Diff line change
Expand Up @@ -768,3 +768,23 @@ error: Closure parameters cannot be references
--> lib.cairo:2:14
let _ = |ref a| {
^^^^^

//! > ==========================================================================

//! > Passing closures as args with less explicit typing.

//! > test_runner_name
test_function_diagnostics(expect_diagnostics: false)

//! > function
fn foo() -> Option<u32> {
let x: Option<Array<i32>> = Option::Some(array![1, 2, 3]);
x.map(|x| x.len())
}

//! > function_name
foo

//! > module_code

//! > expected_diagnostics
61 changes: 55 additions & 6 deletions crates/cairo-lang-semantic/src/items/functions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::fmt::Debug;
use std::sync::Arc;

use cairo_lang_debug::DebugWithDb;
Expand All @@ -14,6 +13,7 @@ use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
use cairo_lang_syntax as syntax;
use cairo_lang_syntax::attribute::structured::Attribute;
use cairo_lang_syntax::node::{Terminal, TypedSyntaxNode, ast};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{
Intern, LookupIntern, OptionFrom, define_short_id, require, try_extract_matches,
};
Expand All @@ -27,16 +27,16 @@ use super::generics::{fmt_generic_args, generic_params_to_args};
use super::imp::{ImplId, ImplLongId};
use super::modifiers;
use super::trt::ConcreteTraitGenericFunctionId;
use crate::corelib::{panic_destruct_trait_fn, unit_ty};
use crate::corelib::{fn_traits, panic_destruct_trait_fn, unit_ty};
use crate::db::SemanticGroup;
use crate::diagnostic::{SemanticDiagnosticKind, SemanticDiagnostics, SemanticDiagnosticsBuilder};
use crate::expr::compute::Environment;
use crate::resolve::{Resolver, ResolverData};
use crate::substitution::{GenericSubstitution, SemanticRewriter, SubstitutionRewriter};
use crate::types::resolve_type;
use crate::{
ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericParam, SemanticDiagnostic,
TypeId, semantic, semantic_object_for_id,
ConcreteImplId, ConcreteImplLongId, ConcreteTraitLongId, GenericArgumentId, GenericParam,
SemanticDiagnostic, TypeId, semantic, semantic_object_for_id,
};

/// A generic function of an impl.
Expand Down Expand Up @@ -146,8 +146,11 @@ impl GenericFunctionId {
GenericFunctionId::Extern(id) => db.extern_function_declaration_generic_params(id),
GenericFunctionId::Impl(id) => {
let concrete_trait_id = db.impl_concrete_trait(id.impl_id)?;
let id = ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function);
db.concrete_trait_function_generic_params(id)
let concrete_id =
ConcreteTraitGenericFunctionId::new(db, concrete_trait_id, id.function);
let substitution = GenericSubstitution::from_impl(id.impl_id);
let mut rewriter = SubstitutionRewriter { db, substitution: &substitution };
rewriter.rewrite(db.concrete_trait_function_generic_params(concrete_id)?)
}
GenericFunctionId::Trait(id) => db.concrete_trait_function_generic_params(id),
}
Expand Down Expand Up @@ -860,6 +863,19 @@ pub fn concrete_function_signature(
SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_signature)
}

/// Query implementation of [crate::db::SemanticGroup::concrete_function_closure_params].
pub fn concrete_function_closure_params(
db: &dyn SemanticGroup,
function_id: FunctionId,
) -> Maybe<OrderedHashMap<semantic::TypeId, semantic::TypeId>> {
let ConcreteFunction { generic_function, generic_args, .. } =
function_id.lookup_intern(db).function;
let generic_params = generic_function.generic_params(db)?;
let generic_closure_params = db.get_closure_params(generic_function)?;
let substitution = GenericSubstitution::new(&generic_params, &generic_args);
SubstitutionRewriter { db, substitution: &substitution }.rewrite(generic_closure_params)
}

/// For a given list of AST parameters, returns the list of semantic parameters along with the
/// corresponding environment.
fn update_env_with_ast_params(
Expand Down Expand Up @@ -1010,3 +1026,36 @@ impl FromIterator<TypeId> for ImplicitPrecedence {
Self(Vec::from_iter(iter))
}
}

/// Query implementation of [crate::db::SemanticGroup::get_closure_params].
pub fn get_closure_params(
db: &dyn SemanticGroup,
generic_function_id: GenericFunctionId,
) -> Maybe<OrderedHashMap<TypeId, TypeId>> {
let mut closure_params_map = OrderedHashMap::default();
let generic_params = generic_function_id.generic_params(db)?;

for param in generic_params {
if let GenericParam::Impl(generic_param_impl) = param {
let trait_id = generic_param_impl.concrete_trait?.trait_id(db);

if fn_traits(db).contains(&trait_id) {
if let Ok(concrete_trait) = generic_param_impl.concrete_trait {
let [
GenericArgumentId::Type(closure_type),
GenericArgumentId::Type(params_type),
] = *concrete_trait.generic_args(db)
else {
unreachable!(
"Fn trait must have exactly two generic arguments: closure type and \
parameter type."
)
};

closure_params_map.insert(closure_type, params_type);
}
}
}
}
Ok(closure_params_map)
}

0 comments on commit 441054e

Please sign in to comment.