Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pure pattern-matching #152

Merged
merged 6 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/Lang/Core.mli
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,19 @@ type _ typ =
(** Effect of the delimiter *)
} -> ktype typ

| TData : ttype * ctor_type list -> ktype typ
| TData : ttype * effect * ctor_type list -> ktype typ
(** Proof of the shape of ADT.

Algebraic data type (ADTs) are just abstract types, but each operation
on them like constructors or pattern-matching requires additional
computationally irrelevant parameter of type that describes the shape
of ADTs. This approach simplifies many things, e.g., mutually recursive
types are not recursive at all! *)
types are not recursive at all!

The element of type [TData(tp, eff, ctors)] is a witness that type [tp]
has constructors [ctors]. The effect [eff] is an effect of
pattern-matching: its pure for strictly positively recursive types,
and impure for other types (because it may lead to non-termination). *)

| TApp : ('k1 -> 'k2) typ * 'k1 typ -> 'k2 typ
(** Type application *)
Expand Down Expand Up @@ -130,8 +135,13 @@ type data_def =
args : TVar.ex list;
(** List of type parameters of this ADT. *)

ctors : ctor_type list
ctors : ctor_type list;
(** List of constructors. *)

strictly_positive : bool
(** A flag indicating if the type is strictly positively recursive (in
particular, not recursive at all) and therefore can be deconstructed
without performing NTerm effect. *)
}

| DD_Label of (** Label *)
Expand Down
5 changes: 3 additions & 2 deletions src/Lang/CorePriv/SExprPrinter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ let rec tr_type : type k. k typ -> SExpr.t =
List (List.map tr_type lbl.val_types);
tr_type lbl.delim_tp;
tr_type lbl.delim_eff ]
| TData(tp, ctors) ->
List (Sym "data" :: tr_type tp :: List.map tr_ctor_type ctors)
| TData(tp, eff, ctors) ->
List (Sym "data" :: tr_type tp :: List (tr_effect eff) ::
List.map tr_ctor_type ctors)
| TApp _ -> tr_type_app tp []

and tr_effect : effect -> SExpr.t list =
Expand Down
5 changes: 3 additions & 2 deletions src/Lang/CorePriv/Subst.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ let rec in_type_rec : type k. t -> k typ -> k typ =
delim_tp = in_type_rec sub lbl.delim_tp;
delim_eff = in_type_rec sub lbl.delim_eff
}
| TData(tp, ctors) ->
TData(in_type_rec sub tp, List.map (in_ctor_type_rec sub) ctors)
| TData(tp, eff, ctors) ->
TData(in_type_rec sub tp, in_type_rec sub eff,
List.map (in_ctor_type_rec sub) ctors)
| TApp(tp1, tp2) ->
TApp(in_type_rec sub tp1, in_type_rec sub tp2)

Expand Down
9 changes: 5 additions & 4 deletions src/Lang/CorePriv/Syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ type var = Var.t

type data_def =
| DD_Data of
{ tvar : TVar.ex;
proof : var;
args : TVar.ex list;
ctors : ctor_type list
{ tvar : TVar.ex;
proof : var;
args : TVar.ex list;
ctors : ctor_type list;
strictly_positive : bool
}
| DD_Label of
{ tvar : keffect tvar;
Expand Down
74 changes: 69 additions & 5 deletions src/Lang/CorePriv/Type.ml
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ let rec equal : type k. k typ -> k typ -> bool =
end
| TLabel _, _ -> false

| TData(tp1, ctors1), TData(tp2, ctors2) ->
| TData(tp1, eff1, ctors1), TData(tp2, eff2, ctors2) ->
equal tp1 tp2 &&
equal eff1 eff2 &&
List.length ctors1 = List.length ctors2 &&
List.for_all2 ctor_type_equal ctors1 ctors2
| TData _, _ -> false
Expand Down Expand Up @@ -176,7 +177,11 @@ let rec subtype tp1 tp2 =
| TLabel _, (TUVar _ | TVar _ | TArrow _ | TForall _ | TData _ | TApp _) ->
false

| TData _, TData _ -> equal tp1 tp2
| TData(tp1, eff1, ctors1), TData(tp2, eff2, ctors2) ->
equal tp1 tp2 &&
subeffect eff1 eff2 &&
List.length ctors1 = List.length ctors2 &&
List.for_all2 ctor_type_equal ctors1 ctors2
| TData _, (TUVar _ | TVar _ | TArrow _ | TForall _ | TLabel _ | TApp _) ->
false

Expand Down Expand Up @@ -241,12 +246,13 @@ let rec type_in_scope : type k. _ -> k typ -> k typ option =
{ effect; tvars = lbl.tvars; val_types; delim_tp; delim_eff })
| _ -> None
end
| TData(tp, ctors) ->
| TData(tp, eff, ctors) ->
begin match
type_in_scope scope tp,
type_in_scope scope eff,
forall_map (ctor_type_in_scope scope) ctors
with
| Some tp, Some ctors -> Some (TData(tp, ctors))
| Some tp, Some eff, Some ctors -> Some (TData(tp, eff, ctors))
| _ -> None
end
| TApp(tp1, tp2) ->
Expand Down Expand Up @@ -318,7 +324,7 @@ let rec supertype_in_scope scope (tp : ttype) =
are members of given set ([scope]) *)
and subtype_in_scope scope (tp : ttype) =
match tp with
| TUVar _ | TVar _ | TLabel _ | TData _ | TApp _ -> type_in_scope scope tp
| TUVar _ | TVar _ | TLabel _ | TApp _ -> type_in_scope scope tp
| TArrow(tp1, tp2, eff) ->
begin match
supertype_in_scope scope tp1,
Expand All @@ -333,6 +339,64 @@ and subtype_in_scope scope (tp : ttype) =
| Some body -> Some (TForall(a, body))
| None -> None
end
| TData(tp, eff, ctors) ->
begin match
type_in_scope scope tp,
subeffect_in_scope scope eff,
forall_map (ctor_type_in_scope scope) ctors
with
| Some tp, eff, Some ctors -> Some (TData(tp, eff, ctors))
| _ -> None
end

(** Check if all types on non-strictly positive positions fits in given
scope. *)
let rec strictly_positive : type k. nonrec_scope:_ -> k typ -> bool =
fun ~nonrec_scope tp ->
match tp with
| TUVar _ | TVar _ | TEffPure -> true
| TLabel _ | TData _ ->
begin match type_in_scope nonrec_scope tp with
| Some _ -> true
| None -> false
end

| TEffJoin(eff1, eff2) ->
strictly_positive ~nonrec_scope eff1 &&
strictly_positive ~nonrec_scope eff2

| TArrow(tp1, tp2, eff) ->
begin match
type_in_scope nonrec_scope tp1,
strictly_positive ~nonrec_scope tp2,
strictly_positive ~nonrec_scope eff
with
| Some _, true, true -> true
| _ -> false
end

| TForall(a, tp) ->
strictly_positive ~nonrec_scope:(TVar.Set.add a nonrec_scope) tp

| TApp(tp1, tp2) ->
begin match
strictly_positive ~nonrec_scope tp1,
type_in_scope nonrec_scope tp2
with
| true, Some _ -> true
| _ -> false
end

(** Check if all types on non-strictly positive positions fits in given
scope (for ADT constructors) *)
let strictly_positive_ctor ~nonrec_scope ctor =
let nonrec_scope = add_tvars_to_scope ctor.ctor_tvars nonrec_scope in
List.for_all (strictly_positive ~nonrec_scope) ctor.ctor_arg_types

(** Check if all types on non-strictly positive positions fits in given
scope (for list of ADT constructors) *)
let strictly_positive_ctors ~nonrec_scope ctors =
List.for_all (strictly_positive_ctor ~nonrec_scope) ctors

type ex = Ex : 'k typ -> ex

Expand Down
2 changes: 1 addition & 1 deletion src/Lang/CorePriv/TypeBase.ml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ type _ typ =
delim_tp : ttype;
delim_eff : effect
} -> ktype typ
| TData : ttype * ctor_type list -> ktype typ
| TData : ttype * effect * ctor_type list -> ktype typ
| TApp : ('k1 -> 'k2) typ * 'k1 typ -> 'k2 typ

and ttype = ktype typ
Expand Down
30 changes: 22 additions & 8 deletions src/Lang/CorePriv/WellTypedInvariant.ml
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ let rec tr_type : type k. Env.t -> k typ -> k typ =
delim_tp = tr_type env lbl.delim_tp;
delim_eff = tr_type env lbl.delim_eff
}
| TData(tp, ctors) ->
TData(tr_type env tp, List.map (tr_ctor_type env) ctors)
| TData(tp, eff, ctors) ->
TData(tr_type env tp, tr_type env eff, List.map (tr_ctor_type env) ctors)
| TApp(tp1, tp2) ->
TApp(tr_type env tp1, tr_type env tp2)

Expand Down Expand Up @@ -202,14 +202,27 @@ let prepare_data_def env (dd : data_def) =
let (env, a) = Env.add_tvar env lbl.tvar in
(env, DD_Label { lbl with tvar = a })

let finalize_data_def (env, dd_eff) dd =
let adt_effect ~nonrec_scope strictly_positive args ctors =
if strictly_positive then
let nonrec_scope = Type.add_tvars_to_scope args nonrec_scope in
if Type.strictly_positive_ctors ~nonrec_scope ctors then
TEffPure
else
InterpLib.InternalError.report
~reason:"Type is not strictly positvely recursive"
()
else
Effect.nterm

let finalize_data_def ~nonrec_scope (env, dd_eff) dd =
match dd with
| DD_Data adt ->
let (TVar.Ex a) = adt.tvar in
let (xs, data_tp, ctors) = check_data env (TVar a) adt.args adt.ctors in
let eff = adt_effect ~nonrec_scope adt.strictly_positive xs ctors in
let env =
Env.add_irr_var env adt.proof
(Type.t_foralls xs (TData(data_tp, ctors))) in
(Type.t_foralls xs (TData(data_tp, eff, ctors))) in
(env, dd_eff)

| DD_Label lbl ->
Expand All @@ -224,8 +237,9 @@ let finalize_data_def (env, dd_eff) dd =
(env, Effect.join Effect.nterm dd_eff)

let check_data_defs env dds =
let nonrec_scope = Env.scope env in
let (env, dds) = List.fold_left_map prepare_data_def env dds in
List.fold_left finalize_data_def (env, TEffPure) dds
List.fold_left (finalize_data_def ~nonrec_scope) (env, TEffPure) dds

let rec infer_type_eff env e =
match e with
Expand Down Expand Up @@ -285,7 +299,7 @@ let rec infer_type_eff env e =
let tp = tr_type env tp in
let eff = tr_type env eff in
begin match infer_type_check_eff (Env.irrelevant env) proof TEffPure with
| TData(data_tp, ctors) when List.length cls = List.length ctors ->
| TData(data_tp, meff, ctors) when List.length cls = List.length ctors ->
check_vtype env v data_tp;
List.iter2 (fun cl ctor ->
let xs = cl.cl_vars in
Expand All @@ -296,7 +310,7 @@ let rec infer_type_eff env e =
let env = List.fold_left2 Env.add_var env xs tps in
check_type_eff env cl.cl_body tp eff
) cls ctors;
(tp, Effect.join Effect.nterm eff)
(tp, Effect.join meff eff)
| _ ->
failwith "Internal type error"
end
Expand Down Expand Up @@ -375,7 +389,7 @@ and infer_vtype env v =
and infer_ctor_type env proof n tps args check_arg =
assert (n >= 0);
begin match infer_type_check_eff (Env.irrelevant env) proof TEffPure with
| TData(tp, ctors) ->
| TData(tp, _, ctors) ->
begin match List.nth_opt ctors n with
| Some ctor ->
let sub = tr_types_sub env Subst.empty ctor.ctor_tvars tps in
Expand Down
15 changes: 9 additions & 6 deletions src/Lang/Unif.ml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ module CtorDecl = struct
let subst = UnifPriv.Subst.in_ctor_decl

let find_index cs name = List.find_index (fun c -> c.ctor_name = name) cs

let strictly_positive = UnifPriv.Type.ctor_strictly_positive
end

module Subst = UnifPriv.Subst
Expand All @@ -58,10 +60,11 @@ type var = Var.t

type data_def =
| DD_Data of
{ tvar : tvar;
proof : var;
args : named_tvar list;
ctors : ctor_decl list
{ tvar : tvar;
proof : var;
args : named_tvar list;
ctors : ctor_decl list;
strictly_positive : bool
}

| DD_Label of
Expand Down Expand Up @@ -94,8 +97,8 @@ and expr_data =
| ELetRec of rec_def list * expr
| ECtor of expr * int * typ list * expr list
| EData of data_def list * expr
| EMatchEmpty of expr * expr * typ * effrow
| EMatch of expr * match_clause list * typ * effrow
| EMatchEmpty of expr * expr * typ * effrow option
| EMatch of expr * match_clause list * typ * effrow option
| EHandle of tvar * var * typ * expr * expr
| EHandler of
{ label : var;
Expand Down
22 changes: 17 additions & 5 deletions src/Lang/Unif.mli
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,13 @@ type data_def =
args : named_tvar list;
(** List of type parameters of this ADT. *)

ctors : ctor_decl list
ctors : ctor_decl list;
(** List of constructors. *)

strictly_positive : bool
(** A flag indicating if the type is strictly positively recursive (in
particular, not recursive at all) and therefore can be deconstructed
in pure way. *)
}

| DD_Label of (** Label *)
Expand Down Expand Up @@ -203,13 +208,15 @@ and expr_data =
| EData of data_def list * expr
(** Definition of mutually recursive ADTs. *)

| EMatchEmpty of expr * expr * typ * effrow
| EMatchEmpty of expr * expr * typ * effrow option
(** Pattern-matching of an empty type. The first parameter is an
irrelevant expression, that is a witness that the type of the second
parameter is an empty ADT *)
parameter is an empty ADT. The last parameter is an optional effect of
the whole expression: [None] means that pattern-matching is pure. *)

| EMatch of expr * match_clause list * typ * effrow
(** Pattern-matching. It stores type and effect of the whole expression. *)
| EMatch of expr * match_clause list * typ * effrow option
(** Pattern-matching. It stores type and effect of the whole expression.
If the effect is [None], the pattern-matching is pure. *)

| EHandle of tvar * var * typ * expr * expr
(** Handling construct. In [EHandle(a, x, tp, e1, e2)] the meaning of
Expand Down Expand Up @@ -745,6 +752,11 @@ module CtorDecl : sig

(** Get the index of a constructor with a given name *)
val find_index : ctor_decl list -> string -> int option

(** Check if given constructor is strictly positive, i.e., if all type
variables on non-strictly positive positions and all scopes of unification
variables fit in [nonrec_scope]. *)
val strictly_positive : nonrec_scope:scope -> ctor_decl -> bool
end

(* ========================================================================= *)
Expand Down
2 changes: 2 additions & 0 deletions src/Lang/UnifPriv/Scope.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ let filter lvl f scope =

let mem scope x = TVar.Set.mem x scope.tvar_set

let for_all f scope = TVar.Set.for_all f scope.tvar_set

let perm p scope =
{ tvar_set = TVar.Perm.map_set p scope.tvar_set;
level = scope.level
Expand Down
3 changes: 3 additions & 0 deletions src/Lang/UnifPriv/Scope.mli
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ val filter : int -> (TVar.t -> bool) -> t -> t
(** Check if given type variable is defined in given scope *)
val mem : t -> TVar.t -> bool

(** Check if given predicate holds for each element of given scope *)
val for_all : (TVar.t -> bool) -> t -> bool

(** Permute variables in given scope *)
val perm : TVar.Perm.t -> t -> t

Expand Down
Loading