Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TC: Iterate typechecking function arguments #753

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2039,6 +2039,10 @@ type ('a, 'b) pattern_functions = {
get_loc_typed : 'b -> l;
}

type ('a, 'b, 'c) function_arg_result = Arg_ok of 'a | Arg_error of 'b | Arg_defer of 'c

let is_arg_defer = function Arg_defer _ -> true | _ -> false

type ('a, 'b) vector_concat_elem = VC_elem_ok of 'a | VC_elem_error of 'b * exn | VC_elem_unknown of 'a

let unwrap_vector_concat_elem ~at:l = function
Expand Down Expand Up @@ -2157,6 +2161,8 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
(E_let (LB_aux (LB_val (tpat, inferred_bind), (let_loc, empty_tannot)), crule check_exp inner_env exp typ))
(check_shadow_leaks l inner_env env typ)
end
| E_vector_append (v1, E_aux (E_vector [], _)), _ -> check_exp env v1 typ
| E_vector_append (v1, v2), _ -> check_exp env (E_aux (E_app (mk_id "append", [v1; v2]), (l, uannot))) typ
| E_app_infix (x, op, y), _ -> check_exp env (E_aux (E_app (deinfix op, [x; y]), (l, uannot))) typ
| E_app (f, [E_aux (E_constraint nc, _)]), _ when string_of_id f = "_prove" ->
Env.wf_constraint ~at:l env nc;
Expand Down Expand Up @@ -2213,7 +2219,7 @@ let rec check_exp env (E_aux (exp_aux, (l, uannot)) as exp : uannot exp) (Typ_au
let rec try_overload = function
| errs, [] -> typ_raise l (Err_no_overloading (orig_f, errs))
| errs, f :: fs -> begin
typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
typ_print (lazy ("Check overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
try crule check_exp env (E_aux (E_app (f, xs), (l, add_overload_attribute l orig_f uannot))) typ
with Type_error (err_l, err) ->
typ_debug (lazy "Error");
Expand Down Expand Up @@ -3532,7 +3538,7 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
let rec try_overload = function
| errs, [] -> typ_raise l (Err_no_overloading (orig_f, errs))
| errs, f :: fs -> begin
typ_print (lazy ("Overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
typ_print (lazy ("Infer overload: " ^ string_of_id f ^ "(" ^ string_of_list ", " string_of_exp xs ^ ")"));
try irule infer_exp env (E_aux (E_app (f, xs), (l, add_overload_attribute l orig_f uannot)))
with Type_error (err_l, err) ->
typ_debug (lazy "Error");
Expand Down Expand Up @@ -3839,7 +3845,15 @@ and infer_funapp' l env f (typq, f_typ) xs uannot expected_ret_typ =
else ([], List.map implicit_to_int typ_args, xs)
in

let typ_args =
typ_debug
( lazy
(Option.fold ~none:"No expected return"
~some:(fun typ -> Printf.sprintf "Expected return %s" (string_of_typ typ))
expected_ret_typ
)
);

let instantiate_return_type typ_args =
match expected_ret_typ with
| None -> typ_args
| Some expect when is_exist (Env.expand_synonyms env expect) -> typ_args
Expand All @@ -3863,14 +3877,16 @@ and infer_funapp' l env f (typq, f_typ) xs uannot expected_ret_typ =
)
in

let typ_args = instantiate_return_type typ_args in

(* We now iterate throught the function arguments, checking them and
instantiating quantifiers. *)
let instantiate env arg typ remaining_typs =
if KidSet.for_all (is_bound env) (tyvars_of_typ typ) then (
try
let checked_exp = crule check_exp env arg typ in
Ok (checked_exp, remaining_typs, env)
with Type_error (l, err) -> Error (l, 0, Err_function_arg (exp_loc arg, typ, err))
Arg_ok (checked_exp, remaining_typs, env)
with Type_error (l, err) -> Arg_error (l, 0, Err_function_arg (exp_loc arg, typ, err))
)
else (
let goals = quant_kopts (mk_typquant !quants) |> List.map kopt_kid |> KidSet.of_list in
Expand All @@ -3879,8 +3895,8 @@ and infer_funapp' l env f (typq, f_typ) xs uannot expected_ret_typ =
as it provides a heuristic for how likely any error is in a
function overloading *)
match can_unify_with env goals (irule infer_exp env arg) typ with
| exception Unification_error (l, m) -> Error (l, 1, Err_function_arg (exp_loc arg, typ, Err_other m))
| exception Type_error (l, err) -> Error (l, 0, Err_function_arg (exp_loc arg, typ, err))
| exception Unification_error (l, m) -> Arg_defer (l, 1, Err_function_arg (exp_loc arg, typ, Err_other m))
| exception Type_error (l, err) -> Arg_defer (l, 0, Err_function_arg (exp_loc arg, typ, err))
| inferred_arg, unifiers, env ->
record_unifiers unifiers;
let unifiers = KBindings.bindings unifiers in
Expand All @@ -3892,27 +3908,50 @@ and infer_funapp' l env f (typq, f_typ) xs uannot expected_ret_typ =
);
List.iter (fun unifier -> quants := instantiate_quants !quants unifier) unifiers;
List.iter (fun (v, arg) -> typ_ret := typ_subst v arg !typ_ret) unifiers;
let remaining_typs = instantiate_return_type remaining_typs in
let remaining_typs =
List.map (fun typ -> List.fold_left (fun typ (v, arg) -> typ_subst v arg typ) typ unifiers) remaining_typs
in
Ok (inferred_arg, remaining_typs, env)
Arg_ok (inferred_arg, remaining_typs, env)
)
in
let fold_instantiate (xs, args, env) x =
match args with
| arg :: remaining_args -> (
match instantiate env x arg remaining_args with
| Ok (x, remaining_args, env) -> (Ok x :: xs, remaining_args, env)
| Error (l, h, m) -> (Error (l, h, m) :: xs, remaining_args, env)
)
| [] -> raise (Reporting.err_unreachable l __POS__ "Empty arguments during instantiation")

(* We don't know the best order to check function arguments in order to instantiate the quantifiers, so we
iterate until we reach a fixpoint *)
let rec do_instantiation ~previously_deferred env xs typ_args =
let fold_instantiate (xs, typs, env, deferred) (n, x) =
match typs with
| typ :: remaining_typs -> (
match instantiate env x typ remaining_typs with
| Arg_ok (x, remaining_typs, env) -> ((n, Arg_ok x) :: xs, remaining_typs, env, deferred)
| Arg_defer (l, h, m) ->
typ_debug (lazy (Printf.sprintf "Deferring %s : %s" (string_of_exp x) (string_of_typ typ)));
((n, Arg_defer (l, h, m)) :: xs, remaining_typs @ [typ], env, deferred @ [(n, x)])
| Arg_error (l, h, m) -> ((n, Arg_error (l, h, m)) :: xs, remaining_typs, env, deferred)
)
| [] -> raise (Reporting.err_unreachable l __POS__ "Empty arguments during instantiation")
in
let xs, typ_args, env, deferred = List.fold_left fold_instantiate ([], typ_args, env, []) xs in
let num_deferred = List.length deferred in
typ_debug (lazy (Printf.sprintf "Have %d deferred arguments" num_deferred));
if num_deferred = previously_deferred then (xs, env)
else (
let ys, env = do_instantiation ~previously_deferred:num_deferred env deferred typ_args in
(List.filter (fun (_, result) -> not (is_arg_defer result)) xs @ ys, env)
)
in
let xs, _, env = List.fold_left fold_instantiate ([], typ_args, env) xs in
let xs, env = do_instantiation ~previously_deferred:0 env (List.mapi (fun n x -> (n, x)) xs) typ_args in
let xs = List.fast_sort (fun (n, _) (m, _) -> Int.compare m n) xs |> List.map snd in
let xs, instantiate_errors =
List.fold_left
(fun (acc, errs) x -> match x with Ok x -> (x :: acc, errs) | Error (l, h, m) -> (acc, (l, h, m) :: errs))
(fun (acc, errs) x ->
match x with
| Arg_ok x -> (x :: acc, errs)
| Arg_defer (l, h, m) | Arg_error (l, h, m) -> (acc, (l, h, m) :: errs)
)
([], []) xs
in
typ_debug (lazy (Printf.sprintf "Have %d instantiation errors" (List.length instantiate_errors)));
begin
match instantiate_errors with
| [] -> ()
Expand Down
4 changes: 2 additions & 2 deletions test/typecheck/fail/and_let_bool.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
fail/and_let_bool.sail:6.11-42:
6 | and_bool(let y : bool = x in not_bool(y), x)
 | ^-----------------------------^ checking function argument has type bool('p)
 | The type variable 'ex16# would leak into an outer scope.
 | The type variable 'ex18# would leak into an outer scope.
 |
 | Try adding a type annotation to this expression.
 |
 | Caused by fail/and_let_bool.sail:6.15-16:
 | 6 | and_bool(let y : bool = x in not_bool(y), x)
 |  | ^
 |  | Type variable 'ex16# was introduced here
 |  | Type variable 'ex18# was introduced here
2 changes: 1 addition & 1 deletion test/typecheck/fail/tuple_lexp1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
fail/tuple_lexp1.sail:10.11-20:
10 | (x, y) = (2, 3, 4)
 | ^-------^
 | Type mismatch between (int('ex180#), int('ex181#)) and (int(2), int(3), int(4))
 | Type mismatch between (int('ex182#), int('ex183#)) and (int(2), int(3), int(4))
 |
 | Caused by fail/tuple_lexp1.sail:10.2-8:
 | 10 | (x, y) = (2, 3, 4)
Expand Down
2 changes: 1 addition & 1 deletion test/typecheck/fail/tuple_lexp2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
fail/tuple_lexp2.sail:10.11-12:
10 | (x, y) = 2
 | ^
 | Type mismatch between (int('ex180#), int('ex181#)) and int(2)
 | Type mismatch between (int('ex182#), int('ex183#)) and int(2)
 |
 | Caused by fail/tuple_lexp2.sail:10.2-8:
 | 10 | (x, y) = 2
Expand Down
2 changes: 1 addition & 1 deletion test/typecheck/pass/Replicate/v2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ Explicit effect annotations are deprecated. They are no longer used and can be r
 | ^------------------------^
 | Could not resolve quantifiers for replicate_bits
 | * 'M >= 0
 | * 'ex176# >= 0
 | * 'ex178# >= 0
6 changes: 3 additions & 3 deletions test/typecheck/pass/bool_constraint/v1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ All external bindings should be marked as either pure or impure
12 | if b then n else 4
 | ^
 | int(4) is not a subtype of {('m : Int), (('b & 'm == 'n) | (not('b) & 'm == 3)). int('m)}
 | as (('b & 'ex171 == 'n) | (not('b) & 'ex171 == 3)) could not be proven
 | as (('b & 'ex173 == 'n) | (not('b) & 'ex173 == 3)) could not be proven
 |
 | type variable 'ex171:
 | type variable 'ex173:
 | pass/bool_constraint/v1.sail:9.25-73:
 | 9 | (bool('b), int('n)) -> {'m, 'b & 'm == 'n | not('b) & 'm == 3. int('m)}
 |  | ^----------------------------------------------^ derived from here
 | pass/bool_constraint/v1.sail:12.19-20:
 | 12 | if b then n else 4
 |  | ^ bound here
 |  | has constraint: 4 == 'ex171
 |  | has constraint: 4 == 'ex173
13 changes: 13 additions & 0 deletions test/typecheck/pass/bv_concat_implicit.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
default Order dec

$include <prelude.sail>

val zeros : forall 'n. implicit('n) -> bits('n)

val main : unit -> unit

function main() = {
let x : bits(5) = zeros() @ 0b1;
let y : bits(5) = 0b1 @ zeros();
let z : bits(5) = 0b1 @ zeros() @ 0b00;
}
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ Old set syntax, {|1, 2, 3|} can now be written as {1, 2, 3}.
26 | Some(Ctor1(a, x, c))
 | ^------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor1
 | * 'ex247# in {32, 64}
 | * 'ex249# in {32, 64}
8 changes: 4 additions & 4 deletions test/typecheck/pass/existential_ast3/v1.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
pass/existential_ast3/v1.sail:17.48-65:
17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 | ^---------------^
 | (int(33), int('ex212)) is not a subtype of (int('ex207), int('ex208))
 | (int(33), int('ex214)) is not a subtype of (int('ex209), int('ex210))
 | as false could not be proven
 |
 | type variable 'ex207:
 | type variable 'ex209:
 | pass/existential_ast3/v1.sail:16.23-25:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex208:
 | type variable 'ex210:
 | pass/existential_ast3/v1.sail:16.26-28:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex212:
 | type variable 'ex214:
 | pass/existential_ast3/v1.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (33, unsigned(a));
 |  | ^---------------^ bound here
8 changes: 4 additions & 4 deletions test/typecheck/pass/existential_ast3/v2.expect
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,26 @@
pass/existential_ast3/v2.sail:17.48-65:
17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 | ^---------------^
 | (int(31), int('ex212)) is not a subtype of (int('ex207), int('ex208))
 | (int(31), int('ex214)) is not a subtype of (int('ex209), int('ex210))
 | as false could not be proven
 |
 | type variable 'ex207:
 | type variable 'ex209:
 | pass/existential_ast3/v2.sail:16.23-25:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex208:
 | type variable 'ex210:
 | pass/existential_ast3/v2.sail:16.26-28:
 | 16 | let (datasize, n) : {'d 'n, datasize('d) & 0 <= 'n < 'd. (int('d), int('n))} =
 |  | ^^ derived from here
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
 |
 | type variable 'ex212:
 | type variable 'ex214:
 | pass/existential_ast3/v2.sail:17.48-65:
 | 17 | if b == 0b0 then (64, unsigned(b @ a)) else (31, unsigned(a));
 |  | ^---------------^ bound here
2 changes: 1 addition & 1 deletion test/typecheck/pass/existential_ast3/v3.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
25 | Some(Ctor(64, unsigned(0b0 @ b @ a)))
 | ^-----------------------------^ checking function argument has type ast
 | Could not resolve quantifiers for Ctor
 | * (64 in {32, 64} & (0 <= 'ex244# & 'ex244# < 64))
 | * (64 in {32, 64} & (0 <= 'ex246# & 'ex246# < 64))
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v4.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
36 | if is_64 then 64 else 32;
 | ^^
 | int(64) is not a subtype of {('d : Int), (('is_64 & 'd == 63) | (not('is_64) & 'd == 32)). int('d)}
 | as (('is_64 & 'ex256 == 63) | (not('is_64) & 'ex256 == 32)) could not be proven
 | as (('is_64 & 'ex258 == 63) | (not('is_64) & 'ex258 == 32)) could not be proven
 |
 | type variable 'ex256:
 | type variable 'ex258:
 | pass/existential_ast3/v4.sail:35.18-79:
 | 35 | let 'datasize : {'d, ('is_64 & 'd == 63) | (not('is_64) & 'd == 32). int('d)} =
 |  | ^-----------------------------------------------------------^ derived from here
 | pass/existential_ast3/v4.sail:36.18-20:
 | 36 | if is_64 then 64 else 32;
 |  | ^^ bound here
 |  | has constraint: 64 == 'ex256
 |  | has constraint: 64 == 'ex258
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v5.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 | ^-------------^
 | range(0, 63) is not a subtype of range(0, ('datasize - 2))
 | as (0 <= 'ex264 & 'ex264 <= ('datasize - 2)) could not be proven
 | as (0 <= 'ex266 & 'ex266 <= ('datasize - 2)) could not be proven
 |
 | type variable 'ex264:
 | type variable 'ex266:
 | pass/existential_ast3/v5.sail:37.10-33:
 | 37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 |  | ^---------------------^ derived from here
 | pass/existential_ast3/v5.sail:37.50-65:
 | 37 | let n : range(0, 'datasize - 2) = if is_64 then unsigned(b @ a) else unsigned(a);
 |  | ^-------------^ bound here
 |  | has constraint: (0 <= 'ex264 & 'ex264 <= 63)
 |  | has constraint: (0 <= 'ex266 & 'ex266 <= 63)
6 changes: 3 additions & 3 deletions test/typecheck/pass/existential_ast3/v6.expect
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
37 | let n : range(0, 'datasize - 1) = if is_64 then unsigned(b @ a) else unsigned(b @ a);
 | ^-------------^
 | range(0, 63) is not a subtype of range(0, ('datasize - 1))
 | as (0 <= 'ex270 & 'ex270 <= ('datasize - 1)) could not be proven
 | as (0 <= 'ex272 & 'ex272 <= ('datasize - 1)) could not be proven
 |
 | type variable 'ex270:
 | type variable 'ex272:
 | pass/existential_ast3/v6.sail:37.10-33:
 | 37 | let n : range(0, 'datasize - 1) = if is_64 then unsigned(b @ a) else unsigned(b @ a);
 |  | ^---------------------^ derived from here
 | pass/existential_ast3/v6.sail:37.71-86:
 | 37 | let n : range(0, 'datasize - 1) = if is_64 then unsigned(b @ a) else unsigned(b @ a);
 |  | ^-------------^ bound here
 |  | has constraint: (0 <= 'ex270 & 'ex270 <= 63)
 |  | has constraint: (0 <= 'ex272 & 'ex272 <= 63)
Loading
Loading