Skip to content

Wasm: specialization of number comparisons #1954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 100 additions & 3 deletions compiler/lib-wasm/generate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ module Generate (Target : Target_sig.S) = struct
{ live : int array
; in_cps : Effects.in_cps
; deadcode_sentinal : Var.t
; types : Typing.typ Var.Tbl.t
; blocks : block Addr.Map.t
; closures : Closure_conversion.closure Var.Map.t
; global_context : Code_generation.context
Expand Down Expand Up @@ -233,6 +234,39 @@ module Generate (Target : Target_sig.S) = struct
f context (transl_prim_arg x) (transl_prim_arg y) (transl_prim_arg z)
| _ -> invalid_arity name l ~expected:3)

let get_type ctx p =
match p with
| Pv x -> Var.Tbl.get ctx.types x
| Pc c -> Typing.constant_type c

let register_comparison name cmp_int cmp_boxed_int cmp_float =
register_prim name `Mutable (fun ctx _ transl_prim_arg l ->
match l with
| [ x; y ] -> (
let x' = transl_prim_arg x in
let y' = transl_prim_arg y in
match get_type ctx x, get_type ctx y with
| Number Int, Number Int -> cmp_int x' y'
| Number Int32, Number Int32 ->
let* x' = Memory.unbox_int32 x' in
let* y' = Memory.unbox_int32 y' in
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
| Number Nativeint, Number Nativeint ->
let* x' = Memory.unbox_nativeint x' in
let* y' = Memory.unbox_nativeint y' in
Value.val_int (return (W.BinOp (I32 cmp_boxed_int, x', y')))
| Number Int64, Number Int64 ->
let* x' = Memory.unbox_int64 x' in
let* y' = Memory.unbox_int64 y' in
Value.val_int (return (W.BinOp (I64 cmp_boxed_int, x', y')))
| Number Float, Number Float -> float_comparison cmp_float x' y'
| _ ->
let* f = register_import ~name (Fun (func_type 2)) in
let* x' = x' in
let* y' = y' in
return (W.Call (f, [ x'; y' ])))
| _ -> invalid_arity name l ~expected:2)

let () =
register_bin_prim "caml_array_unsafe_get" `Mutable Memory.gen_array_get;
register_bin_prim "caml_floatarray_unsafe_get" `Mutable Memory.float_array_get;
Expand Down Expand Up @@ -605,7 +639,66 @@ module Generate (Target : Target_sig.S) = struct
l
~init:(return [])
in
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l)
Memory.allocate ~tag:0 ~deadcode_sentinal:ctx.deadcode_sentinal l);
register_comparison "caml_greaterthan" (fun y x -> Value.lt x y) (Gt S) Gt;
register_comparison "caml_greaterequal" (fun y x -> Value.le x y) (Ge S) Ge;
register_comparison "caml_lessthan" Value.lt (Lt S) Lt;
register_comparison "caml_lessequal" Value.le (Le S) Le;
register_comparison
"caml_equal"
(fun x y ->
let* x = x in
let* y = y in
Value.val_int (return (W.RefEq (x, y))))
Eq
Eq;
register_comparison
"caml_notequal"
(fun x y ->
let* x = x in
let* y = y in
Value.val_int (return (W.UnOp (I32 Eqz, RefEq (x, y)))))
Ne
Ne;
register_prim "caml_compare" `Mutable (fun ctx _ transl_prim_arg l ->
match l with
| [ x; y ] -> (
let x' = transl_prim_arg x in
let y' = transl_prim_arg y in
match get_type ctx x, get_type ctx y with
| Number Int, Number Int ->
Value.val_int
Arith.(
(Value.int_val y' < Value.int_val x')
- (Value.int_val x' < Value.int_val y'))
| Number Int32, Number Int32 ->
let* f = register_import ~name:"caml_int32_compare" (Fun (func_type 2)) in
let* x' = Memory.unbox_int32 x' in
let* y' = Memory.unbox_int32 y' in
return (W.Call (f, [ x'; y' ]))
| Number Nativeint, Number Nativeint ->
let* f =
register_import ~name:"caml_nativeint_compare" (Fun (func_type 2))
in
let* x' = Memory.unbox_nativeint x' in
let* y' = Memory.unbox_nativeint y' in
return (W.Call (f, [ x'; y' ]))
| Number Int64, Number Int64 ->
let* f = register_import ~name:"caml_int64_compare" (Fun (func_type 2)) in
let* x' = Memory.unbox_int64 x' in
let* y' = Memory.unbox_int64 y' in
return (W.Call (f, [ x'; y' ]))
| Number Float, Number Float ->
let* f = register_import ~name:"caml_float_compare" (Fun (func_type 2)) in
let* x' = Memory.unbox_int64 x' in
let* y' = Memory.unbox_int64 y' in
return (W.Call (f, [ x'; y' ]))
| _ ->
let* f = register_import ~name:"caml_compare" (Fun (func_type 2)) in
let* x' = x' in
let* y' = y' in
return (W.Call (f, [ x'; y' ])))
| _ -> invalid_arity "caml_compare" l ~expected:2)

let rec translate_expr ctx context x e =
match e with
Expand Down Expand Up @@ -1175,7 +1268,8 @@ module Generate (Target : Target_sig.S) = struct
~should_export
~warn_on_unhandled_effect
*)
~deadcode_sentinal =
~deadcode_sentinal
~types =
global_context.unit_name <- unit_name;
let p, closures = Closure_conversion.f p in
(*
Expand All @@ -1185,6 +1279,7 @@ module Generate (Target : Target_sig.S) = struct
{ live = live_vars
; in_cps
; deadcode_sentinal
; types
; blocks = p.blocks
; closures
; global_context
Expand Down Expand Up @@ -1292,8 +1387,10 @@ let start () = make_context ~value_type:Gc_target.Value.value

let f ~context ~unit_name p ~live_vars ~in_cps ~deadcode_sentinal =
let t = Timer.make () in
let state, info = Global_flow.f' ~fast:false p in
let types = Typing.f ~state ~info p in
let p = fix_switch_branches p in
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal p in
let res = G.f ~context ~unit_name ~live_vars ~in_cps ~deadcode_sentinal ~types p in
if times () then Format.eprintf " code gen.: %a@." Timer.print t;
res

Expand Down
Loading
Loading