From e9bf33f8e8be693b8b6b6eb95b4bafa94f6333dd Mon Sep 17 00:00:00 2001 From: jake-87 Date: Sun, 7 Jul 2024 21:56:22 +1000 Subject: [PATCH] trait resolution and bugfixing --- lib/common/error.ml | 10 -- lib/common/fresh.ml | 5 + lib/common/info.ml | 13 ++ lib/common/log.ml | 62 ++++++++ lib/front/ast.ml | 35 +++-- lib/front/convert_idents_to_strings.ml | 119 ++++++++------- lib/front/driver.ml | 1 + lib/front/error.ml | 0 lib/front/instance_resolve.ml | 137 +++++++++++++++++ lib/front/match_compl.ml | 3 + lib/front/modules.ml | 56 ++++--- lib/front/tycheck.ml | 203 ++++++++++++++++--------- lib/front/unify.ml | 55 ++++--- lib/main.ml | 77 +++++----- 14 files changed, 538 insertions(+), 238 deletions(-) create mode 100644 lib/common/fresh.ml create mode 100644 lib/common/log.ml delete mode 100644 lib/front/error.ml create mode 100644 lib/front/instance_resolve.ml create mode 100644 lib/front/match_compl.ml diff --git a/lib/common/error.ml b/lib/common/error.ml index 39c36dc..24e1d47 100644 --- a/lib/common/error.ml +++ b/lib/common/error.ml @@ -40,13 +40,3 @@ let option_app (f : 'a -> ('b, 'c) result) (x : 'a option) : | None -> ok None | Some n -> ( match f n with Ok s -> ok @@ Some s | Error e -> Error e) - -type error_location = Frontend' -[@@deriving show { with_path = false }] - -let compiler_error loc str = - print_string - ("Compiler error in " ^ show_error_location loc ^ " :\n"); - print_string str; - print_string "\n\n Aborting compilation.\n"; - exit 1 diff --git a/lib/common/fresh.ml b/lib/common/fresh.ml new file mode 100644 index 0000000..b960ffd --- /dev/null +++ b/lib/common/fresh.ml @@ -0,0 +1,5 @@ +let fresh = + let i = ref (-1) in + fun () -> + incr i; + !i diff --git a/lib/common/info.ml b/lib/common/info.ml index 3bc5ff4..8c744af 100644 --- a/lib/common/info.ml +++ b/lib/common/info.ml @@ -1,10 +1,13 @@ type info = .. type data = .. +type info += Dummy +type data += Dummy' (* unique ids *) type id = Id of int [@@deriving show { with_path = false }] let noid = Id (-1) +let id' () = Id (Fresh.fresh ()) let p_INFO_TABLE : (id, (info * data) list) Hashtbl.t = Hashtbl.create 100 @@ -22,3 +25,13 @@ let set_property (id : id) (prop : info) (data : data) : unit = | Some l -> (* update *) Hashtbl.replace p_INFO_TABLE id ((prop, data) :: l) + +let print_related_entries i printer = + Hashtbl.iter + (fun id v -> + List.iter + (fun (info, data) -> + if info = i then + printer id data) + v) + p_INFO_TABLE diff --git a/lib/common/log.ml b/lib/common/log.ml new file mode 100644 index 0000000..99792d4 --- /dev/null +++ b/lib/common/log.ml @@ -0,0 +1,62 @@ +type level = + | TRACE + | DEBUG + | INFO + | WARN + | ERROR + | FATAL + +let lvl_to_ident (lvl : level) = + match lvl with + | TRACE -> "\x1B[0;32mTrace:" + | DEBUG -> "\x1B[0;36mDebug:" + | INFO -> "\x1B[0;34mInfo:" + | WARN -> "\x1B[0;33mWarn:" + | ERROR -> "\x1B[0;31mError:" + | FATAL -> "\x1B[1;31m!!! FATAL !!!" + +let reset = "\x1B[0m" + +type message = Msg of level * string + +let log_ : message list ref = ref [] +let init_log () = () +let append_ msg = log_ := msg :: !log_ +let get_ () = List.rev !log_ + +let msg_to_str (Msg (lvl, msg)) = + let intro = lvl_to_ident lvl in + intro ^ "\n" ^ reset ^ msg ^ "\n\n" + +let print_log () = + let l = get_ () in + match l with + | [] -> print_endline "\n\x1B[1;35mlog was empty :)\x1B[0m\n" + | _ -> + print_string "\n\x1B[1;35m~~~ LOG ~~~ \x1B[0m\n\n"; + List.iter (fun x -> print_string (msg_to_str x)) l; + print_string "\x1B[1;35m~~~ END LOG ~~~ \x1B[0m\n" + +let log level str rest = + let m = Msg (level, Printf.sprintf str rest) in + append_ m; + match level with + | FATAL -> raise (Invalid_argument "fatal") + | _ -> () + +let log_fatal level str rest = + let m = Msg (level, Printf.sprintf str rest) in + append_ m; + match level with + | FATAL -> + print_log (); + print_endline "Khasmc terminated from fatal error!"; + exit 1 + | _ -> raise (Invalid_argument "non-fatal") + +let trace x = log TRACE "%s" x +let debug x = log DEBUG "%s" x +let info x = log INFO "%s" x +let warn x = log WARN "%s" x +let error x = log ERROR "%s" x +let fatal x = log_fatal FATAL "%s" x diff --git a/lib/front/ast.ml b/lib/front/ast.ml index f1531d8..d0d7a7a 100644 --- a/lib/front/ast.ml +++ b/lib/front/ast.ml @@ -87,7 +87,7 @@ and pp_ty (fmt : Format.formatter) (ty : ty) : unit = | Free s -> Format.fprintf fmt "'%s" s | Custom t -> Format.fprintf fmt "%a" pp_path t | Tuple t -> Format.fprintf fmt "(%a)" (pp_list fmt) t - | Arrow (a, b) -> Format.fprintf fmt "%a -> %a" pp_ty a pp_ty b + | Arrow (a, b) -> Format.fprintf fmt "(%a) -> %a" pp_ty a pp_ty b | TApp (p, l) -> Format.fprintf fmt "%a (%a)" pp_path p (pp_list fmt) l | TForall (s, t) -> @@ -102,6 +102,11 @@ and pp_list fmt fmt x = Format.fprintf fmt ", ") x +let print_ty ty = + pp_ty Format.std_formatter ty; + Format.print_newline (); + Format.print_flush () + (* also carries free vars *) type ty' = freevar list * ty [@@deriving show { with_path = false }] @@ -150,8 +155,6 @@ type tm = | Record of id * path * (string * tm) list (* foo.bar *) | Project of id * tm * string - (* error *) - | Poison of id * exn [@@deriving show { with_path = false }] type definition_no_body = { @@ -160,6 +163,7 @@ type definition_no_body = { constraints : constraint' list; args : (string * ty) list; ret : ty; + id : Common.Info.id; } [@@deriving show { with_path = false }] @@ -170,18 +174,19 @@ type definition = { args : (string * ty) list; ret : ty; body : tm; + id : Common.Info.id; } [@@deriving show { with_path = false }] let to_definition_no_body (d : definition) : definition_no_body = match d with - | { name; free_vars; constraints; args; ret; body = _body } -> - { name; free_vars; constraints; args; ret } + | { name; free_vars; constraints; id; args; ret; body = _body } -> + { name; free_vars; constraints; id; args; ret } let to_definition (b : tm) (d : definition_no_body) : definition = match d with - | { name; free_vars; constraints; args; ret } -> - { body = b; name; free_vars; constraints; args; ret } + | { name; free_vars; constraints; id; args; ret } -> + { body = b; name; free_vars; id; constraints; args; ret } type trait = { name : string; @@ -192,14 +197,16 @@ type trait = { constraints : constraint' list; (* member functions *) functions : definition_no_body list; + id : Common.Info.id; } [@@deriving show { with_path = false }] type impl = { name : string; - args : ty list; + args : (string * ty) list; assoc_types : (string * ty) list; impls : definition list; + id : Common.Info.id; } [@@deriving show { with_path = false }] @@ -207,16 +214,17 @@ type typ = { name : string; args : freevar list; expr : tyexpr; + id : Common.Info.id; } [@@deriving show { with_path = false }] type statement = (* name, freevars, constraints, args, return type, term *) - | Definition of id * definition + | Definition of definition (* name, args, body *) - | Type of id * typ - | Trait of id * trait - | Impl of id * impl + | Type of typ + | Trait of trait + | Impl of impl [@@deriving show { with_path = false }] type file = { @@ -259,6 +267,5 @@ let get_tm_id (t : tm) : id = | ITE (i, _, _, _) | Annot (i, _, _) | Record (i, _, _) - | Project (i, _, _) - | Poison (i, _) -> + | Project (i, _, _) -> i diff --git a/lib/front/convert_idents_to_strings.ml b/lib/front/convert_idents_to_strings.ml index 6b6cbc2..4fd9b42 100644 --- a/lib/front/convert_idents_to_strings.ml +++ b/lib/front/convert_idents_to_strings.ml @@ -56,7 +56,6 @@ let rec collapse_tm (tm : tm) : tm = collapse pth, List.map (fun (a, b) -> (a, collapse_tm b)) flds ) | Project (i, p, s) -> Project (i, collapse_tm p, s) - | Poison (_, _) -> tm let rec collapse_tyexpr (tyexpr : tyexpr) : tyexpr = match (tyexpr : tyexpr) with @@ -69,64 +68,74 @@ let rec collapse_tyexpr (tyexpr : tyexpr) : tyexpr = let convert' (s : statement) : statement = match (s : statement) with - | Definition (id, { name; free_vars; constraints; args; ret; body }) + | Definition { id; name; free_vars; constraints; args; ret; body } -> Definition - ( id, - { - name; - free_vars; - constraints = collapse_cons constraints; - args = List.map (fun (a, b) -> (a, collapse_ty b)) args; - ret = collapse_ty ret; - body = collapse_tm body; - } ) - | Type (id, { name; args; expr }) -> - Type (id, { name; args; expr = collapse_tyexpr expr }) - | Trait (i, { name; args; assoc_types; constraints; functions }) -> + { + id; + name; + free_vars; + constraints = collapse_cons constraints; + args = List.map (fun (a, b) -> (a, collapse_ty b)) args; + ret = collapse_ty ret; + body = collapse_tm body; + } + | Type { id; name; args; expr } -> + Type { id; name; args; expr = collapse_tyexpr expr } + | Trait { id; name; args; assoc_types; constraints; functions } -> Trait - ( i, - { - name; - args; - assoc_types; - constraints = collapse_cons constraints; - functions = - List.map - (fun ({ name; free_vars; constraints; args; ret } : - definition_no_body) -> - { - name; - free_vars; - constraints = collapse_cons constraints; - args = - List.map (fun (a, b) -> (a, collapse_ty b)) args; - ret = collapse_ty ret; - }) - functions; - } ) - | Impl (i, { name; args; assoc_types; impls }) -> + { + id; + name; + args; + assoc_types; + constraints = collapse_cons constraints; + functions = + List.map + (fun ({ name; id; free_vars; constraints; args; ret } : + definition_no_body) -> + { + id; + name; + free_vars; + constraints = collapse_cons constraints; + args = + List.map (fun (a, b) -> (a, collapse_ty b)) args; + ret = collapse_ty ret; + }) + functions; + } + | Impl { id; name; args; assoc_types; impls } -> Impl - ( i, - { - name; - args = List.map collapse_ty args; - assoc_types = - List.map (fun (a, b) -> (a, collapse_ty b)) assoc_types; - impls = - List.map - (fun { name; free_vars; constraints; args; ret; body } -> - { - name; - free_vars; - constraints = collapse_cons constraints; - args = - List.map (fun (a, b) -> (a, collapse_ty b)) args; - ret = collapse_ty ret; - body = collapse_tm body; - }) - impls; - } ) + { + id; + name; + args = List.map (fun (a, b) -> (a, collapse_ty b)) args; + assoc_types = + List.map (fun (a, b) -> (a, collapse_ty b)) assoc_types; + impls = + List.map + (fun { + id; + name; + free_vars; + constraints; + args; + ret; + body; + } -> + { + id; + name; + free_vars; + constraints = collapse_cons constraints; + args = + List.map (fun (a, b) -> (a, collapse_ty b)) args; + ret = collapse_ty ret; + body = collapse_tm body; + }) + impls; + } let convert (s : statement list) : statement list = List.map convert' s diff --git a/lib/front/driver.ml b/lib/front/driver.ml index 27480e9..dc39f33 100644 --- a/lib/front/driver.ml +++ b/lib/front/driver.ml @@ -10,4 +10,5 @@ let do_frontend (files : file list) : (statement list, 'a) result = Convert_idents_to_strings.convert statements in let+ tycheckd = Tycheck.typecheck compressed_paths in + Instance_resolve.test (); [] diff --git a/lib/front/error.ml b/lib/front/error.ml deleted file mode 100644 index e69de29..0000000 diff --git a/lib/front/instance_resolve.ml b/lib/front/instance_resolve.ml new file mode 100644 index 0000000..7d73156 --- /dev/null +++ b/lib/front/instance_resolve.ml @@ -0,0 +1,137 @@ +open Ast +open Common.Info +open Common.Error +open Common +open Tycheck + +type info += Traitfn_inst +type data += Traitfn_inst' of id option + +type ctx = { + trait_fns : (string * trait) list; + impls : impl list; +} + +let empty_ctx () = { trait_fns = []; impls = [] } + +let add_traitfn ctx fnname trait = + { ctx with trait_fns = (fnname, trait) :: ctx.trait_fns } + +let add_impl ctx impl = { ctx with impls = impl :: ctx.impls } + +let match_solve_frees args fnty traitty = + let inst = Unify.inst_frees args fnty in + let pairs = ref [] in + let cont (a : ty) (b : ty) : ty = + match (a, b) with + | x, Free s -> + pairs := (s, x) :: !pairs; + x + | _ -> raise (Unify.BadUnify (a, b)) + in + Log.trace (show_ty inst); + Log.trace (show_ty traitty); + ignore @@ Unify.unify' cont inst traitty; + List.iter (fun (a, b) -> Log.trace (a ^ " & " ^ show_ty b)) !pairs; + !pairs + +let solve ctx id fnname ty trait = + let fn = + List.find + (fun (d : definition_no_body) -> d.name = fnname) + trait.functions + in + let t = mk_ty (List.map snd fn.args) fn.ret in + let pairs = match_solve_frees trait.args ty t in + if + not + (List.for_all (( = ) true) + @@ List.map (fun x -> List.mem (fst x) trait.args) pairs) + then + Log.fatal "we don't support non-full inference at this time" + else + let sorted = List.sort (fun (a, _) (b, _) -> compare a b) pairs in + match + List.find_opt + (fun (i : impl) -> + let args = + List.sort (fun (a, _) (b, _) -> compare a b) i.args + in + Log.debug (string_of_int @@ List.length sorted); + Log.debug (string_of_int @@ List.length args); + List.map2 + (fun (_, pair) (_, inst) -> Unify.unify_b [] pair inst) + sorted args + |> List.mem false + |> not) + ctx.impls + with + | None -> Log.fatal "no instance found" + | Some s -> + Log.trace "found:"; + Log.trace (Ast.show_impl s); + set_property id Traitfn_inst (Traitfn_inst' (Some s.id)) + +let test () = + let open Ast in + let show = + { + name = "Show"; + args = [ "a" ]; + assoc_types = []; + constraints = []; + functions = + [ + { + name = "show"; + free_vars = []; + constraints = []; + args = [ ("dummy", Free "a") ]; + ret = TyString; + id = id' (); + }; + ]; + id = id' (); + } + in + let ctx = + { + trait_fns = [ ("show", show) ]; + impls = + [ + { + name = "Show"; + args = [ ("a", TyInt) ]; + assoc_types = []; + impls = + [ + { + name = "show"; + free_vars = []; + constraints = []; + args = [ ("dummy", TyInt) ]; + ret = TyString; + body = Var (noid, "DUMMY"); + id = id' (); + }; + ]; + id = id' (); + }; + ]; + } + in + solve ctx (id' ()) "show" (Arrow (TyInt, TyString)) show + +(* trait fn names & impls *) +let gather_needed_info (stmts : statement list) : ctx = + List.fold_left + (fun acc x -> + match x with + | Trait t -> + List.fold_left + (fun acc (fn : definition_no_body) -> + add_traitfn acc fn.name t) + acc t.functions + | Impl t -> add_impl acc t + | _ -> acc) + (empty_ctx ()) stmts diff --git a/lib/front/match_compl.ml b/lib/front/match_compl.ml new file mode 100644 index 0000000..86189ef --- /dev/null +++ b/lib/front/match_compl.ml @@ -0,0 +1,3 @@ +open Ast +open Common.Info +open Common.Error diff --git a/lib/front/modules.ml b/lib/front/modules.ml index 5c10a34..9b82617 100644 --- a/lib/front/modules.ml +++ b/lib/front/modules.ml @@ -216,7 +216,6 @@ let rec handle_tm (bound : string list) (ctx : ctx) (tm : tm) : | Project (id, path, field) -> let+ tm = handle_tm bound ctx path in Project (id, tm, field) - | Poison (id, exn) -> ok @@ Poison (id, exn) let handle_constraints (ctx : ctx) (tm : constraint') : (constraint', 'a) result = @@ -226,7 +225,7 @@ let handle_constraints (ctx : ctx) (tm : constraint') : (nm, tys) let handle_definition (ctx : ctx) - { name; free_vars; constraints; args; ret; body } : + { id; name; free_vars; constraints; args; ret; body } : (definition, 'a) result = let* cons = collect @@ List.map (handle_constraints ctx) constraints @@ -237,6 +236,7 @@ let handle_definition (ctx : ctx) and+ body = handle_tm arg1 ctx body in let args = List.combine arg1 tys in { + id; name = add_file' ctx name; free_vars; constraints = cons; @@ -286,8 +286,7 @@ let rec handle_trait (ctx : ctx) (trait : trait) : (trait, 'a) result dfn with name = to_str (InMod (trait.name, Base dfn.name)); }) - |> List.map - (to_definition (Poison (noid, Failure "handle_trait"))) + |> List.map (to_definition (Var (noid, "%.%BAD BAD BAD%.%"))) |> List.map (handle_definition ctx) |> collect |$> List.map to_definition_no_body @@ -295,7 +294,13 @@ let rec handle_trait (ctx : ctx) (trait : trait) : (trait, 'a) result { trait with name = trait.name; constraints; functions } let rec handle_impl (ctx : ctx) (impl : impl) : (impl, 'a) result = - let+ args = collect @@ List.map (handle_ty ctx) impl.args + let+ args = + collect + @@ List.map + (fun (a, b) -> + let+ b = handle_ty ctx b in + (a, b)) + impl.args and+ assoc_types = collect @@ List.map @@ -306,7 +311,13 @@ let rec handle_impl (ctx : ctx) (impl : impl) : (impl, 'a) result = and+ impls = collect @@ List.map (handle_definition ctx) impl.impls in - { name = add_file' ctx impl.name; args; assoc_types; impls } + { + id = impl.id; + name = add_file' ctx impl.name; + args; + assoc_types; + impls; + } let rec base_name b = match b with @@ -332,15 +343,14 @@ let handle_file (ctx : ctx) (file : file) : (file * ctx, 'a) result = let collect_names = List.map (function - | Definition (_, dfn) -> - InMod (file.name, Base dfn.name) :: [] - | Type (_, def) -> + | Definition dfn -> InMod (file.name, Base dfn.name) :: [] + | Type def -> InMod (file.name, Base def.name) :: get_constr_paths file def.name def.expr - | Trait (_, def) -> + | Trait def -> InMod (file.name, Base def.name) :: get_impl_paths file def.name def.functions - | Impl (_, impl) -> InMod (file.name, Base impl.name) :: []) + | Impl impl -> InMod (file.name, Base impl.name) :: []) file.toplevel |> List.flatten |> List.map (fun x -> (x, x)) @@ -352,24 +362,19 @@ let handle_file (ctx : ctx) (file : file) : (file * ctx, 'a) result = collect @@ List.map (function - | Definition (id, dfn) -> + | Definition dfn -> let+ def = handle_definition ctx dfn in - Definition (id, def) - | Type (id, def) -> + Definition def + | Type def -> let+ t = handle_tyexpr ctx def.expr in Type - ( id, - { - def with - expr = t; - name = add_file' ctx def.name; - } ) - | Trait (id, trait) -> + { def with expr = t; name = add_file' ctx def.name } + | Trait trait -> let+ t = handle_trait ctx trait in - Trait (id, t) - | Impl (id, impl) -> + Trait t + | Impl impl -> let+ t = handle_impl ctx impl in - Impl (id, t)) + Impl t) file.toplevel in ({ file with toplevel = defs }, ctx) @@ -412,4 +417,5 @@ let handle_files files = in let tmp = List.flatten @@ List.map errfmt e in let formatted = List.map (fun (a, b) -> format_error a b) tmp in - err' formatted + List.iter Log.error formatted; + err' "Module removal failed" diff --git a/lib/front/tycheck.ml b/lib/front/tycheck.ml index 0a7ccc8..62ba0e7 100644 --- a/lib/front/tycheck.ml +++ b/lib/front/tycheck.ml @@ -3,6 +3,7 @@ open Common.Info open Common.Error module BUR = BatUref open Unify +open Common type ('a, 'b) either = ('a, 'b) Either.t type info += Type @@ -14,6 +15,12 @@ let is_trait id = set_property id IsTrait IsTrait' let set_ty id ty = set_property id Type (Type' ty) let get_ty id = get_property id Type +let print_types () = + print_related_entries Type (fun id (Type' t) -> + print_string (show_id id); + print_string " : "; + print_ty (force t)) + let rec is_rank_one' (t : ty) : bool = match force t with | Tuple t -> List.for_all is_rank_one' t @@ -98,17 +105,40 @@ let add_args ctx id args = List.map (fun (a, b) -> (a, id, ([], b))) args @ ctx.locals; } +let rec find_frees (ty : ty) = + match force ty with + | Free s -> + [ (s, "fresh_free + " ^ string_of_int @@ Fresh.fresh ()) ] + | Tuple t -> List.flatten @@ List.map find_frees t + | Arrow (a, b) -> find_frees a @ find_frees b + | TApp (_, b) -> List.flatten @@ List.map find_frees b + | TForall (_, s) -> find_frees s + | _ -> [] + +let rec rename_frees' map ty = + match force ty with + | Free s -> Free (List.assoc s map) + | Tuple t -> Tuple (List.map (rename_frees' map) t) + | Arrow (a, b) -> Arrow (rename_frees' map a, rename_frees' map b) + | TApp (t, b) -> TApp (t, List.map (rename_frees' map) b) + | TForall (s, b) -> TForall (s, rename_frees' map b) + | _ -> ty + +let rename_frees ty = + let map = find_frees ty in + rename_frees' map ty + let find_ty (ctx : ctx) (p : string) : (ty, 'a) result = match List.filter (fun (nm, _, _) -> nm = p) ctx.locals with | [ (_, _, ty) ] -> ok (snd ty) | _ -> ( match List.filter (fun (nm, _, _) -> nm = p) ctx.bound with - | [ (_, _, ty) ] -> ok (snd ty) + | [ (_, _, ty) ] -> ok (rename_frees (snd ty)) | _ -> ( match List.filter (fun (nm, _, _, _) -> nm = p) ctx.constrs with - | [ (_, _, _, ty) ] -> ok (snd ty) + | [ (_, _, _, ty) ] -> ok (rename_frees @@ snd ty) | _ -> err @@ `Impossible_ctx (ctx, "var not found: " ^ p))) (* TODO: check *) @@ -155,12 +185,12 @@ let process_trait (ctx : ctx) (id : id) let collect_statement (ctx : ctx) (state : statement) : ctx = match state with - | Definition (id, def) -> + | Definition def -> let ty = mk_ty (List.map snd def.args) def.ret in - add_def ctx def.name id (def.free_vars, ty) - | Type (id, typ) -> add_typ ctx typ - | Trait (id, trait) -> process_trait ctx id trait - | Impl (_, _) -> ctx (* we ignore impls for the moment *) + add_def ctx def.name def.id (def.free_vars, ty) + | Type typ -> add_typ ctx typ + | Trait trait -> process_trait ctx trait.id trait + | Impl _ -> ctx (* we ignore impls for the moment *) (* goal: to take a pattern, a context, and a type, and figure out @@ -193,6 +223,7 @@ let rec check_tm (ctx : ctx) (tm : tm) (ty : ty) : (unit, 'a) result = begin match (tm, ty) with | Let (id, pat, head, body), t -> + (* TODO: let generalize ? *) let* head't = infer_tm ctx head in let* vars = deduce_pat_type ctx pat head't in let ctx' = @@ -220,7 +251,7 @@ let rec check_tm (ctx : ctx) (tm : tm) (ty : ty) : (unit, 'a) result = begin match tymb with | None -> ok () - | Some t -> unify t a |$> fun _ -> () + | Some t -> unify ctx.frees t a |$> fun _ -> () end in let* tys = deduce_pat_type ctx pat a in @@ -233,7 +264,7 @@ let rec check_tm (ctx : ctx) (tm : tm) (ty : ty) : (unit, 'a) result = | Lam _, _ -> err @@ `Lam_Not_Arrow (tm, ty) | Annot (_, tm, ty), t -> let+ _ = check_tm ctx tm ty - and+ _ = unify ty t in + and+ _ = unify ctx.frees ty t in () | Record (_, nm, fields), Custom t -> let* typ, field't = find_typ_by_record_nm ctx (to_str nm) in @@ -266,29 +297,14 @@ let rec check_tm (ctx : ctx) (tm : tm) (ty : ty) : (unit, 'a) result = |> collect |$> fun _ -> () end - | Project (_, tm, field), t -> - let* tm't = infer_tm ctx tm in - begin - match tm't with - | Custom r -> - let* typ, fields = - find_typ_by_record_nm ctx (to_str r) - in - begin - match List.assoc_opt field fields with - | Some n -> unify t n - | None -> err @@ `Field_Not_Found (ctx, field) - end - | _ -> err @@ `Project_Not_Record (ctx, tm't) - end - | Poison (_, _), t -> failwith "poison" | tm, ty -> let* tm't = infer_tm ctx tm in - unify tm't ty + unify ctx.frees tm't ty end |$> fun () -> set_ty (get_tm_id tm) ty and infer_tm (ctx : ctx) (tm : tm) : (ty, 'a) result = + let open Common in begin match tm with | App (id, f, x) -> @@ -296,20 +312,21 @@ and infer_tm (ctx : ctx) (tm : tm) : (ty, 'a) result = let rec go ty xs = match (ty, xs) with | Arrow (a, b), x :: xs -> - let* _ = unify a x in + let* _ = unify ctx.frees a x in go b xs | t, [] -> ok t | _ -> err @@ `App_Mismatch (id, f, x) in let* f't = infer_tm ctx f in go f't x't - | Var (_, t) -> find_ty ctx t - | Bound (_, t) -> find_ty ctx (to_str t) + | Var (_, t) -> find_ty ctx t |$> inst_frees ctx.frees + | Bound (_, t) -> find_ty ctx (to_str t) |$> inst_frees ctx.frees | Bool _ -> ok TyBool | String _ -> ok TyString | Int _ -> ok TyInt | Char _ -> ok TyChar | Let (id, pat, head, body) -> + (* TODO: let generalization ? *) let* h't = infer_tm ctx head in let* adds = deduce_pat_type ctx pat h't in let ctx' = @@ -329,10 +346,47 @@ and infer_tm (ctx : ctx) (tm : tm) : (ty, 'a) result = let+ _ = check_tm ctx t e't in e't end + | Project (_, tm, field) -> + let* tm't = infer_tm ctx tm in + begin + match tm't with + | Custom r -> + let* typ, fields = + find_typ_by_record_nm ctx (to_str r) + in + begin + match List.assoc_opt field fields with + | Some n -> ok n + | None -> err @@ `Field_Not_Found (ctx, field) + end + | _ -> err @@ `Project_Not_Record (ctx, tm't) + end | Annot (id, tm, ty) -> let+ _ = check_tm ctx tm ty in ty - (* TODO: add a lot more cases here *) + (* TODO: add a lot more cases here *) + | Match (id, tm, exprs) -> + let* tm't = infer_tm ctx tm in + let rec go (pat, body) = + let* vars = deduce_pat_type ctx pat tm't in + let ctx' = + List.map (fun (nm, ty) -> (nm, id, ty)) vars + |> add_locals ctx + in + infer_tm ctx' body + in + List.map go exprs |> collect |=> fun tms -> + let rec go acc = function + | [] -> ok acc + | x :: xs -> + let* rest = go acc xs in + unify ctx.frees x rest |$> fun () -> x + in + begin + match tms with + | [] -> err @@ `Can't_Infer_Empty_Match (ctx, tm) + | x :: xs -> go x xs + end | _ -> print_endline (show_ctx ctx); print_endline (show_tm tm); @@ -345,74 +399,77 @@ and infer_tm (ctx : ctx) (tm : tm) : (ty, 'a) result = let typecheck_statement (ctx : ctx) (s : statement) : (statement, 'a) result = match s with - | Definition (id, def) -> + | Definition def -> let ctx' = add_frees ctx def.free_vars in - let ctx' = add_args ctx' id def.args in + let ctx' = add_args ctx' def.id def.args in (* TODO: check all toplevel types are valid *) let+ _ = check_tm ctx' def.body def.ret in - Definition (id, def) - | Impl (id, impl) -> failwith "todo: typecheck impl" + Definition def + | Impl impl -> failwith "todo: typecheck impl ?" | _ -> ok s let typecheck (files : statement list) : (statement list, 'a) result = + let open Common in let ctx = List.fold_left collect_statement (empty ()) files in match collect @@ List.map (typecheck_statement ctx) files with | Ok s -> ok s | Error e -> - List.map + List.iter (fun e -> match e with - | `App_Mismatch (id, tm, tms) -> failwith "appmismatch" + | `App_Mismatch (id, tm, tms) -> Log.error "appmismatch" | `Bad_Pattern (pat, ty) -> - print_endline (show_pat pat); - print_endline (show_ty ty); - failwith "bad pat" + Log.trace (show_pat pat); + Log.trace (show_ty ty); + Log.error "bad pat" | `Bad_Unify ((t1, t1'), (t2, t2')) -> - print_endline (show_ty t1'); - print_endline (show_ty t2'); - print_endline "as part of:"; - print_endline (show_ty t1); - print_endline (show_ty t2); - failwith "bad unify" + Log.trace "bad unify"; + Log.trace (show_ty t1'); + Log.trace (show_ty t2'); + Log.trace "as part of:"; + Log.trace (show_ty t1); + Log.trace (show_ty t2); + Log.error "bad unify oop" | `Constr_Not_In_Type (pat, ty) -> - failwith "constr not in typ" + Log.error "constr not in typ" | `Impossible i -> - print_endline i; - failwith "impossible" - | `Lam_Not_Arrow (tm, ty) -> failwith "lam not arrow" + Log.trace i; + Log.error "impossible" + | `Lam_Not_Arrow (tm, ty) -> Log.error "lam not arrow" | `Mismatch_Pattern_Args (pat, ty) -> - failwith "mismatch pat arg" + Log.error "mismatch pat arg" | `Mismatched_Tuple_Length (pat, ty) -> - failwith "tuple len mismatch" + Log.error "tuple len mismatch" | `Impossible_ctx (ctx, i) -> - print_endline (show_ctx ctx); - print_endline i; - failwith "impossible" + Log.trace (show_ctx ctx); + Log.trace i; + Log.error "impossible" | `Mismatched_Field_Ty (ctx, s, e) -> - print_endline (show_ctx ctx); - print_endline s; - print_endline e; - failwith "mismatched field type" + Log.trace (show_ctx ctx); + Log.trace s; + Log.trace e; + Log.error "mismatched field type" | `Mismatched_Record (ctx, fields, tys) -> - print_endline (show_ctx ctx); + Log.trace (show_ctx ctx); List.iter - (fun (nm, _) -> print_endline ("field: " ^ nm)) + (fun (nm, _) -> Log.trace ("field: " ^ nm)) fields; List.iter (fun (nm, ty) -> - print_endline ("field: " ^ nm ^ " : " ^ show_ty ty)) + Log.trace ("field: " ^ nm ^ " : " ^ show_ty ty)) tys; - failwith "mismatched record" + Log.error "mismatched record" | `Record_Type_Mismatch (ctx, exp, got) -> - print_endline (show_ctx ctx); - print_endline (to_str exp); - print_endline (to_str got); - failwith "record type mismatch" + Log.trace (show_ctx ctx); + Log.trace (to_str exp); + Log.trace (to_str got); + Log.error "record type mismatch" | `Type_Not_Record (ctx, s) -> - print_endline (show_ctx ctx); - print_endline s; - failwith "type not record" - | `Project_Not_Record e -> failwith "e" - | `Field_Not_Found e -> failwith "fnf") - e - |> collect + Log.trace (show_ctx ctx); + Log.trace s; + Log.error "type not record" + | `Project_Not_Record e -> Log.error "proj not record" + | `Field_Not_Found e -> Log.error "field not found" + | `Can't_Infer_Empty_Match e -> Log.error "empty match") + e; + err' "Typechecking failed" diff --git a/lib/front/unify.ml b/lib/front/unify.ml index c0cb8bc..8f5c62b 100644 --- a/lib/front/unify.ml +++ b/lib/front/unify.ml @@ -18,7 +18,7 @@ let join_metas (a : meta BUR.t) (b : meta BUR.t) : unit = exception BadUnify of ty * ty -let rec unify' (a : ty) (b : ty) : ty = +let rec unify' cont (a : ty) (b : ty) : ty = match (force a, force b) with | TyInt, TyInt | TyBool, TyBool @@ -27,65 +27,76 @@ let rec unify' (a : ty) (b : ty) : ty = a | TyMeta a, TyMeta b -> ( match (get a, get b) with - | Solved a, Solved b -> unify' a b + | Solved a, Solved b -> unify' cont a b | _ -> join_metas a b; TyMeta a) | TyMeta a, t | t, TyMeta a -> begin match get a with - | Solved b -> unify' t b + | Solved b -> unify' cont t b | Unsolved -> join_metas a (BUR.uref (Solved t)); t end | Free a, Free b when a = b -> Free a | Custom a, Custom b when a = b -> Custom a - | Tuple w, Tuple e -> Tuple (List.map2 unify' w e) - | Arrow (a, b), Arrow (q, w) -> Arrow (unify' a q, unify' b w) + | Tuple w, Tuple e -> Tuple (List.map2 (unify' cont) w e) + | Arrow (a, b), Arrow (q, w) -> + Arrow ((unify' cont) a q, (unify' cont) b w) | TApp (a, b), TApp (q, w) when a = q -> - TApp (q, List.map2 unify' b w) - | TForall _, t -> failwith "not sure yet" - | _ -> raise (BadUnify (a, b)) + TApp (q, List.map2 (unify' cont) b w) + | TForall _, t -> Common.Log.fatal "not sure yet" + | a, b -> cont a b -let rec inst_frees (mapping : (string * meta BUR.t) list) (ty : ty) : +let rec inst_frees' (exclude : string list) + (mapping : (string * meta BUR.t) list) (ty : ty) : (string * meta BUR.t) list * ty = match force ty with | Custom _ | TyInt | TyChar | TyBool | TyString | TyMeta _ -> (mapping, ty) | Free s -> begin - match List.assoc_opt s mapping with - | Some s -> (mapping, TyMeta s) - | None -> - let n = new_meta () in - ((s, n) :: mapping, TyMeta n) + if List.mem s exclude then + (mapping, Free s) + else + match List.assoc_opt s mapping with + | Some s -> (mapping, TyMeta s) + | None -> + let n = new_meta () in + ((s, n) :: mapping, TyMeta n) end | Tuple w -> let map, t = List.fold_left (fun (acc, l) x -> - let map, me = inst_frees acc x in + let map, me = inst_frees' exclude acc x in (map, me :: l)) (mapping, []) w in (map, Tuple t) | Arrow (a, b) -> - let map, a = inst_frees mapping a in - let map, b = inst_frees map b in + let map, a = inst_frees' exclude mapping a in + let map, b = inst_frees' exclude map b in (map, Arrow (a, b)) | TApp (a, b) -> let map, t = List.fold_left (fun (acc, l) x -> - let map, me = inst_frees acc x in + let map, me = inst_frees' exclude acc x in (map, me :: l)) (mapping, []) b in (map, TApp (a, t)) | TForall (f, l) -> failwith "handle scoping issues" -let unify (a : ty) (b : ty) : (unit, 'a) result = - let map, a = inst_frees [] a in - let map, b = inst_frees map b in - match unify' a b with +let inst_frees exclude x = snd @@ inst_frees' exclude [] x + +let unify exe (a : ty) (b : ty) : (unit, 'a) result = + let map, a = inst_frees' exe [] a in + let map, b = inst_frees' exe map b in + let cont a b = raise (BadUnify (a, b)) in + match unify' cont a b with | s -> ok () | exception BadUnify (q, w) -> err @@ `Bad_Unify ((a, q), (b, w)) + +let unify_b exe a b = + match unify exe a b with Ok _ -> true | Error _ -> false diff --git a/lib/main.ml b/lib/main.ml index 2cde559..3415fd1 100644 --- a/lib/main.ml +++ b/lib/main.ml @@ -11,40 +11,28 @@ let example_files = opens = []; toplevel = [ - Type - ( noid, - { - name = "MyRecord"; - args = []; - expr = TRecord [ ("foo", TyInt); ("bar", TyBool) ]; - } ); Definition - ( noid, - { - name = "RecordTest"; - free_vars = []; - constraints = []; - args = []; - ret = Custom (Base "MyRecord"); - body = - Record - ( noid, - Base "MyRecord", - [ - ("foo", Int (noid, "5")); - ("bar", Bool (noid, true)); - ] ); - } ); + { + id = id' (); + name = "id"; + constraints = []; + free_vars = [ "a" ]; + args = [ ("x", Free "a") ]; + ret = Free "a"; + body = Var (id' (), "x"); + }; Definition - ( noid, - { - name = "ProjTest"; - free_vars = []; - constraints = []; - args = [ ("x", Custom (Base "MyRecord")) ]; - ret = TyInt; - body = Project (noid, Var (noid, "x"), "foo"); - } ); + { + id = id' (); + name = "integer"; + constraints = []; + free_vars = [ "a"; "b" ]; + args = []; + ret = Arrow (Free "b", Free "b"); + body = + App + (id' (), Var (id' (), "id"), [ Var (id' (), "id") ]); + }; ]; }; ] @@ -52,11 +40,22 @@ let example_files = let driver () = example_files |> Front.Driver.do_frontend let main () = + Common.Log.init_log (); Printexc.record_backtrace true; - driver () |> function - | Ok e -> - print_endline "ok!"; - ok e - | Error e -> - print_endline "err :("; - err' e + ignore + @@ begin + match driver () with + | Ok e -> + print_endline "ok!"; + ok e + | Error e -> + Common.Log.error "error"; + print_endline "err :("; + err' e + | exception e -> + Common.Log.error "exception!"; + print_endline "exception!"; + err' "exn" + end; + Common.Log.print_log (); + Front.Tycheck.print_types ()