Skip to content

Commit

Permalink
Reformat to a smaller margin
Browse files Browse the repository at this point in the history
Motivation: 100 chars seems sufficiently small for vertical use of wide-ratio monitors.
  • Loading branch information
lukstafi committed Jul 3, 2024
1 parent 8833052 commit 51c903b
Show file tree
Hide file tree
Showing 45 changed files with 1,970 additions and 1,088 deletions.
2 changes: 1 addition & 1 deletion .ocamlformat
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
profile = default
margin = 110
margin = 100
parse-docstrings = true
wrap-comments = true
59 changes: 41 additions & 18 deletions arrayjit/lib/assignments.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ let get_name_exn asgns =
let punct_or_sp = Str.regexp "[-@*/:.;, ]" in
let punct_and_sp = Str.regexp {|[-@*/:.;,]\( |$\)|} in
let rec loop = function
| Block_comment (s, _) -> Str.global_replace punct_and_sp "" s |> Str.global_replace punct_or_sp "_"
| Block_comment (s, _) ->
Str.global_replace punct_and_sp "" s |> Str.global_replace punct_or_sp "_"
| Seq (t1, t2) ->
let n1 = loop t1 and n2 = loop t2 in
let prefix = String.common_prefix2_length n1 n2 in
Expand All @@ -63,15 +64,19 @@ let get_name_exn asgns =
let recurrent_nodes asgns =
let open Utils.Set_O in
let empty = Set.empty (module Tn) in
let single = function Node tn -> Set.singleton (module Tn) tn | Merge_buffer _ -> Set.empty (module Tn) in
let single = function
| Node tn -> Set.singleton (module Tn) tn
| Merge_buffer _ -> Set.empty (module Tn)
in
let maybe have lhs = if have then Set.singleton (module Tn) lhs else empty in
let rec loop = function
| Noop -> empty
| Seq (t1, t2) -> loop t1 + (loop t2 - assigned t1)
| Block_comment (_, t) -> loop t
| Accum_binop { initialize_neutral; lhs; rhs1; rhs2; _ } ->
maybe (not initialize_neutral) lhs + single rhs1 + single rhs2
| Accum_unop { initialize_neutral; lhs; rhs; _ } -> maybe (not initialize_neutral) lhs + single rhs
| Accum_unop { initialize_neutral; lhs; rhs; _ } ->
maybe (not initialize_neutral) lhs + single rhs
| Fetch _ -> empty
and assigned = function
| Noop -> Set.empty (module Tn)
Expand Down Expand Up @@ -101,7 +106,8 @@ let%debug_sexp to_low_level code =
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
match buffer with
| Node tn -> Low_level.Get (tn, idcs)
| Merge_buffer tn -> Low_level.Get_global (Ops.Merge_buffer { source_node_id = tn.Tn.id }, Some idcs)
| Merge_buffer tn ->
Low_level.Get_global (Ops.Merge_buffer { source_node_id = tn.Tn.id }, Some idcs)
in
let set tn idcs llv =
if not (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)) then
Expand All @@ -121,13 +127,16 @@ let%debug_sexp to_low_level code =
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
let projections = Lazy.force projections in
let lhs_idx =
derive_index ~product_syms:projections.product_iterators ~projection:projections.project_lhs
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_lhs
in
let rhs1_idx =
derive_index ~product_syms:projections.product_iterators ~projection:projections.project_rhs.(0)
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_rhs.(0)
in
let rhs2_idx =
derive_index ~product_syms:projections.product_iterators ~projection:projections.project_rhs.(1)
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_rhs.(1)
in
let is_assignment = initialize_neutral && Indexing.is_bijective projections in
let basecase rev_iters =
Expand Down Expand Up @@ -170,10 +179,12 @@ let%debug_sexp to_low_level code =
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
let projections = Lazy.force projections in
let lhs_idx =
derive_index ~product_syms:projections.product_iterators ~projection:projections.project_lhs
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_lhs
in
let rhs_idx =
derive_index ~product_syms:projections.product_iterators ~projection:projections.project_rhs.(0)
derive_index ~product_syms:projections.product_iterators
~projection:projections.project_rhs.(0)
in
let is_assignment = initialize_neutral && Indexing.is_bijective projections in
let basecase rev_iters =
Expand Down Expand Up @@ -242,9 +253,13 @@ let get_ident_within_code ?no_dots c =
let nograd_idents = Hashtbl.create (module String) in
let grad_idents = Hashtbl.create (module String) in
let visit tn =
let idents = if List.mem ~equal:String.equal tn.Tn.label "grad" then grad_idents else nograd_idents in
let idents =
if List.mem ~equal:String.equal tn.Tn.label "grad" then grad_idents else nograd_idents
in
Option.iter (Tn.ident_label tn)
~f:(Hashtbl.update idents ~f:(fun old -> Set.add (Option.value ~default:Utils.no_ints old) tn.id))
~f:
(Hashtbl.update idents ~f:(fun old ->
Set.add (Option.value ~default:Utils.no_ints old) tn.id))
in
let tn = function Node tn -> tn | Merge_buffer tn -> tn in
let rec loop (c : t) =
Expand All @@ -264,7 +279,9 @@ let get_ident_within_code ?no_dots c =
let repeating_nograd_idents =
Hashtbl.filter nograd_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1)
in
let repeating_grad_idents = Hashtbl.filter grad_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1) in
let repeating_grad_idents =
Hashtbl.filter grad_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1)
in
Tn.styled_ident ~repeating_nograd_idents ~repeating_grad_idents ident_style

let fprint_hum ?name ?static_indices () ppf c =
Expand All @@ -278,7 +295,8 @@ let fprint_hum ?name ?static_indices () ppf c =
| Imported (Merge_buffer { source_node_id }) ->
let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
fprintf ppf "merge %s" (ident tn)
| Imported (Ops.External_unsafe { ptr; prec; dims = _ }) -> fprintf ppf "%s" @@ Ops.ptr_to_string ptr prec
| Imported (Ops.External_unsafe { ptr; prec; dims = _ }) ->
fprintf ppf "%s" @@ Ops.ptr_to_string ptr prec
| Slice { batch_idx; sliced } ->
fprintf ppf "%s @@| %s" (ident sliced) (Indexing.symbol_ident batch_idx.static_symbol)
| Embed_symbol { static_symbol; static_range = _ } ->
Expand All @@ -295,24 +313,29 @@ let fprint_hum ?name ?static_indices () ppf c =
loop c
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
let proj_spec =
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec else "<not-in-yet>"
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
else "<not-in-yet>"
in
fprintf ppf "%s %s %s %s %s%s;@ " (ident lhs)
(Ops.assign_op_cd_syntax ~initialize_neutral accum)
(buffer_ident rhs1) (Ops.binop_cd_syntax op) (buffer_ident rhs2)
(if (not (String.equal proj_spec ".")) || List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op then
" ~logic:\"" ^ proj_spec ^ "\""
(if
(not (String.equal proj_spec "."))
|| List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op
then " ~logic:\"" ^ proj_spec ^ "\""
else "")
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
let proj_spec =
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec else "<not-in-yet>"
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
else "<not-in-yet>"
in
fprintf ppf "%s %s %s%s%s;@ " (ident lhs)
(Ops.assign_op_cd_syntax ~initialize_neutral accum)
(if not @@ Ops.equal_unop op Ops.Identity then Ops.unop_cd_syntax op ^ " " else "")
(buffer_ident rhs)
(if not (String.equal proj_spec ".") then " ~logic:\"" ^ proj_spec ^ "\"" else "")
| Fetch { array; fetch_op; dims = _ } -> fprintf ppf "%s := %a;@ " (ident array) out_fetch_op fetch_op
| Fetch { array; fetch_op; dims = _ } ->
fprintf ppf "%s := %a;@ " (ident array) out_fetch_op fetch_op
in
fprintf ppf "@,@[<v 2>";
Low_level.fprint_function_header ?name ?static_indices () ppf;
Expand Down
Loading

0 comments on commit 51c903b

Please sign in to comment.