Skip to content

Commit

Permalink
Untested: primitive ops: change Cmpne to Cmpeq and add Not
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jan 29, 2025
1 parent a7c2053 commit f279856
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
20 changes: 12 additions & 8 deletions arrayjit/lib/ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ type binop =
| Min
| Mod
| Cmplt
| Cmpne
| Cmpeq
(* Waiting till we have a use-case to see how to sensibly introduce bitwise operations. *)
(* | Shl *)
(* | Shr *)
Expand All @@ -170,6 +170,7 @@ type unop =
| Recip_sqrt
| Neg
| Tanh_approx
| Not (** 0. -> 1. | _ -> 0. *)
[@@deriving sexp, compare, equal]

type ternop = Where (** Where(a,b,c): if a then b else c *) | FMA (** FMA(a,b,c): (a * b) + c *)
Expand All @@ -188,7 +189,7 @@ let neutral_elem = function
| Min -> Float.infinity
| And -> 1.
| Or -> 0.
| Arg2 | Arg1 | Mod | Cmplt | Cmpne (* | Shl | Shr *) -> 0.
| Arg2 | Arg1 | Mod | Cmplt | Cmpeq (* | Shl | Shr *) -> 0.

let interpret_binop op v1 v2 =
let open Float in
Expand All @@ -205,7 +206,7 @@ let interpret_binop op v1 v2 =
| Min -> min v1 v2
| Mod -> v1 % v2
| Cmplt -> if v1 < v2 then 1. else 0.
| Cmpne -> if v1 <> v2 then 1. else 0.
| Cmpeq -> if v1 = v2 then 1. else 0.
(* | Shl -> v1 * (int_pow 2. @@ to_int v2) *)
(* | Shr -> v1 / (int_pow 2. @@ to_int v2) *)
| Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
Expand All @@ -231,6 +232,7 @@ let interpret_unop op v =
| Recip_sqrt -> 1. / sqrt v
| Neg -> ~-.v
| Tanh_approx -> tanh v
| Not -> if v = 0. then 1. else 0.

let interpret_ternop op v1 v2 v3 =
let open Float in
Expand All @@ -251,7 +253,7 @@ let binop_cd_syntax = function
| ToPowOf -> "**"
| Relu_gate -> "-?/"
| Cmplt -> "<"
| Cmpne -> "<>"
| Cmpeq -> "="
| Or -> "||"
| And -> "&&"
| Mod -> "%"
Expand All @@ -272,7 +274,7 @@ let binop_cd_fallback_syntax = function
| ToPowOf -> "pow"
| Relu_gate -> "relu_gate"
| Cmplt -> "lt"
| Cmpne -> "le"
| Cmpeq -> "eq"
| Or -> "or_"
| And -> "and_"
| Mod -> "mod_"
Expand Down Expand Up @@ -302,7 +304,7 @@ let binop_c_syntax prec v =
| Min, _ -> ("fminf(", ",", ")")
| Mod, _ -> ("(", " %", ")")
| Cmplt, _ -> ("(", " <", ")")
| Cmpne, _ -> ("(", " !=", ")")
| Cmpeq, _ -> ("(", " ==", ")")
(* | Shl, Byte_prec _ -> ("(", " <<", ")") *)
(* | Shl, _ -> ("((", ") * exp2(", "))") *)
(* | Shr, Byte_prec _ -> ("(", " >>", ")") *)
Expand All @@ -311,7 +313,7 @@ let binop_c_syntax prec v =
| And, _ -> ("(", " &&", ")")

let is_assign_op = function
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne -> false
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq -> false
| Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Max | Min | Or | And -> true

let assign_op_cd_syntax ~initialize_neutral = function
Expand All @@ -336,7 +338,7 @@ let assign_op_cd_syntax ~initialize_neutral = function
| Min -> "=^^"
| Or -> "=||"
| And -> "=&&"
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne ->
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq ->
invalid_arg "Ops.assign_op_cd_syntax: not an assignment op"

(** Note: currently we do not support unary prefix symbols. *)
Expand All @@ -355,6 +357,7 @@ let unop_cd_syntax = function
| Recip_sqrt -> "recip_sqrt"
| Neg -> "neg"
| Tanh_approx -> "tanh"
| Not -> "not"

let unop_c_syntax prec op =
let fmax () =
Expand Down Expand Up @@ -400,6 +403,7 @@ let unop_c_syntax prec op =
| Tanh_approx, Byte_prec _ ->
invalid_arg "Ops.unop_c_syntax: Tanh_approx not supported for byte/integer precisions"
| Tanh_approx, _ -> ("tanhf(", ")")
| Not, _ -> ("(", " == 0.0 ? 1.0 : 0.0)")

(** In the %cd syntax, we use uncurried notation for ternary ops. *)
let ternop_cd_syntax = function Where -> "where" | FMA -> "fma"
Expand Down
2 changes: 1 addition & 1 deletion lib/ppx_cd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ let translate (expr : expression) : result =
@@ Location.error_extensionf ~loc
"ppx_ocannl %%cd: expected a binary operator, one of: %s"
"+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -/> (Arg2), \
< (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^(Max), ^^ (Min)" ))
< (Cmplt), = (Cmpeq), || (Or), && (And), % (Mod), @^(Max), ^^ (Min)" ))
in
let ternary_op tern_op =
loc
Expand Down
5 changes: 3 additions & 2 deletions lib/ppx_shared.ml
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ let binary_ops =
("relu_gate", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate]));
("<", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]));
("lt", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]));
("<>", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne]));
("ne", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne]));
("=", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpeq]));
("eq", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpeq]));
("||", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or]));
("or_", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or]));
("&&", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.And]));
Expand Down Expand Up @@ -182,6 +182,7 @@ let unary_ops =
("recip_sqrt", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Recip_sqrt]));
("neg", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Neg]));
("tanh", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Tanh_approx]));
("not", fun loc -> ([%expr Shape.Pointwise_un], [%expr Arrayjit.Ops.Not]));
]

(** Ternary primitive ops. *)
Expand Down
4 changes: 2 additions & 2 deletions lib/syntax_extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ The binary primitive operations:
| `pow` | `**` | pointwise | `ToPowOf` | `=**`, `=:**` |
| `relu_gate` | `-?/` | pointwise | `Relu_gate` | `=?/`, `=:?/` |
| `lt` | `<` | pointwise | `Cmplt` | none |
| `ne` | `<>` | pointwise | `Cmpne` | none |
| `eq` | `<>` | pointwise | `Cmpeq` | none |
| `or_` | `\|\|` | pointwise | `Or` | `=\|\|`, `=:\|\|` |
| `and_` | `&&` | pointwise | `And` | `=&&`, `=:&&` |
| `mod_` | `%` | pointwise | `Mod` | `=%`, `=:%` |
| `mod_` | `%` | pointwise | `Mod` | none |
| `max` | `@^` | pointwise | `Max` | `=@^`, `=:@^` |
| `min` | `^^` | pointwise | `Min` | `=^^`, `=:^^` |

Expand Down

0 comments on commit f279856

Please sign in to comment.