Skip to content

Commit

Permalink
Support inferring the type of if-else based on the else branch (#762)
Browse files Browse the repository at this point in the history
If we can't infer the type of the `then` branch but we can infer the type of the `else` branch, then use that type and check the `then` branch matches it.
  • Loading branch information
Timmmm authored Nov 6, 2024
1 parent 66725a8 commit a93d125
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 40 deletions.
96 changes: 56 additions & 40 deletions src/lib/type_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
open Ast
open Ast_defs
open Ast_util
open Either
open Util
open Lazy
open Parse_ast.Attribute_data
Expand Down Expand Up @@ -3594,48 +3595,63 @@ and infer_exp env (E_aux (exp_aux, (l, uannot)) as exp) =
else annot_exp (E_for (v, inferred_t, inferred_f, checked_step, ord, checked_body)) unit_typ
| _, _ -> typ_error l "Ranges in foreach overlap"
end
| E_if (cond, then_branch, else_branch) ->
| E_if (cond, then_branch, else_branch) -> begin
(* Try to infer the type of the condition - in some cases it may be a constant `true`
or `false`, e.g. `xlen == 32`. If that fails check it is a bool without inference. *)
let cond' = try irule infer_exp env cond with Type_error _ -> crule check_exp env cond bool_typ in
let cond_constraint = destruct_atom_bool env (typ_of cond') in
let then_branch' =
irule infer_exp (add_opt_constraint l "then branch" (assert_constraint env true cond') env) then_branch

(* Constraints to apply when reasoning about the branch types. The condition must be
true when evaluating the type of the `then` branch, and false for `else`. *)
let then_env = add_opt_constraint l "then branch" (assert_constraint env true cond') env in
let else_env = add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env in

(* Infer the type of a branch and also see if it is a simple numeric type. Sail doesn't support
generic type unions (`int | string`) but it does support them for simple numeric types. For
example we will infer the type of `if foo then 2 else 4` as `{2, 4}`. *)
let branch_typ branch cond_env =
try
let inferred_exp = irule infer_exp cond_env branch in
let maybe_simple_numeric =
Option.map
(fun (kids, nc, nexp) -> to_simple_numeric kids nc nexp)
(destruct_numeric (Env.expand_synonyms env (typ_of inferred_exp)))
in
Left (inferred_exp, maybe_simple_numeric)
with Type_error (l, err) -> Right (Type_error (l, err))
in
(* We don't have generic type union in Sail, but we can union simple numeric types. *)
begin
match destruct_numeric (Env.expand_synonyms env (typ_of then_branch')) with
| Some (kids, nc, then_nexp) ->
let then_sn = to_simple_numeric kids nc then_nexp in
let else_branch' =
irule infer_exp
(add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env)
else_branch
in
begin
match destruct_numeric (Env.expand_synonyms env (typ_of else_branch')) with
| Some (kids, nc, else_nexp) ->
let else_sn = to_simple_numeric kids nc else_nexp in
let typ = typ_of_simple_numeric (union_simple_numeric cond_constraint then_sn else_sn) in
annot_exp (E_if (cond', then_branch', else_branch')) typ
| None -> typ_error l ("Could not infer type of " ^ string_of_exp else_branch)
end
| None -> begin
match typ_of then_branch' with
| Typ_aux (Typ_app (f, [_]), _) when string_of_id f = "atom_bool" ->
let else_branch' =
crule check_exp
(add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env)
else_branch bool_typ
in
annot_exp (E_if (cond', then_branch', else_branch')) bool_typ
| _ ->
let else_branch' =
crule check_exp
(add_opt_constraint l "else branch" (Option.map nc_not (assert_constraint env false cond')) env)
else_branch (typ_of then_branch')
in
annot_exp (E_if (cond', then_branch', else_branch')) (typ_of then_branch')
end
end

(* When one branch's type is inferred, check the other branch's type against it. *)
let one_branch_inferred inferred_branch other_branch other_env =
(* If the type of the inferred branch is a function call to `atom_bool` treat it as bool. *)
let inferred_typ = if is_atom_bool (typ_of inferred_branch) then bool_typ else typ_of inferred_branch in
(* Check the other branch matches the type. *)
let other_branch' = crule check_exp other_env other_branch inferred_typ in
(other_branch', inferred_typ)
in

match (branch_typ then_branch then_env, branch_typ else_branch else_env) with
(* Both branches are simple numeric types. *)
| Left (then_branch', Some then_sn), Left (else_branch', Some else_sn) ->
let cond_constraint = destruct_atom_bool env (typ_of cond') in
let typ = typ_of_simple_numeric (union_simple_numeric cond_constraint then_sn else_sn) in
annot_exp (E_if (cond', then_branch', else_branch')) typ
(* Both branches could be inferred but exactly one is a simple numeric type. *)
| Left (_, Some _), Left (_, None) | Left (_, None), Left (_, Some _) ->
typ_error l ("Incompatible types: " ^ string_of_exp then_branch ^ " vs " ^ string_of_exp else_branch)
(* One branch is a simple numeric type but the type of the other branch couldn't be inferred. *)
| Left (_, Some _), _ -> typ_error l ("Could not infer type of " ^ string_of_exp else_branch)
| _, Left (_, Some _) -> typ_error l ("Could not infer type of " ^ string_of_exp then_branch)
(* Neither branch is a simple numeric type, but we inferred the `then` branch. *)
| Left (then_branch', None), _ ->
let other_branch, inferred_typ = one_branch_inferred then_branch' else_branch else_env in
annot_exp (E_if (cond', then_branch', other_branch)) inferred_typ
(* Neither branch is a simple numeric type, but we inferred the `else` branch (but not the `then` branch). *)
| _, Left (else_branch', None) ->
let other_branch, inferred_typ = one_branch_inferred else_branch' then_branch then_env in
annot_exp (E_if (cond', other_branch, else_branch')) inferred_typ
(* We couldn't infer the type of either branch; raise the error for the `then` branch. *)
| Right err, _ -> raise err
end
| E_vector_access (v, n) -> begin
try infer_exp env (E_aux (E_app (mk_id "vector_access", [v; n]), (l, uannot))) with
| Type_error (err_l, err) -> (
Expand Down
17 changes: 17 additions & 0 deletions test/typecheck/pass/if_infer_else.sail
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
default Order dec

$include <prelude.sail>

val zeros : forall 'n, 'n >= 0 . implicit('n) -> bits('n)
function zeros (n) = sail_zeros (n)

function main() -> unit = {
let b : bool = true;
let _ = if b then 0x12 else zeros();
let _ = if b then zeros() else 0x12;
let _ = if b then 1 else 2;

let _ = if true then 0x12 else zeros();
let _ = if true then zeros() else 0x12;
let _ = if true then 1 else 2;
}

0 comments on commit a93d125

Please sign in to comment.