diff --git a/.ocamlformat b/.ocamlformat index 193f136f..db49a99b 100644 --- a/.ocamlformat +++ b/.ocamlformat @@ -1,4 +1,4 @@ profile = default -margin = 110 +margin = 100 parse-docstrings = true wrap-comments = true \ No newline at end of file diff --git a/arrayjit/lib/assignments.ml b/arrayjit/lib/assignments.ml index 2827e15a..75081674 100644 --- a/arrayjit/lib/assignments.ml +++ b/arrayjit/lib/assignments.ml @@ -48,7 +48,8 @@ let get_name_exn asgns = let punct_or_sp = Str.regexp "[-@*/:.;, ]" in let punct_and_sp = Str.regexp {|[-@*/:.;,]\( |$\)|} in let rec loop = function - | Block_comment (s, _) -> Str.global_replace punct_and_sp "" s |> Str.global_replace punct_or_sp "_" + | Block_comment (s, _) -> + Str.global_replace punct_and_sp "" s |> Str.global_replace punct_or_sp "_" | Seq (t1, t2) -> let n1 = loop t1 and n2 = loop t2 in let prefix = String.common_prefix2_length n1 n2 in @@ -63,7 +64,10 @@ let get_name_exn asgns = let recurrent_nodes asgns = let open Utils.Set_O in let empty = Set.empty (module Tn) in - let single = function Node tn -> Set.singleton (module Tn) tn | Merge_buffer _ -> Set.empty (module Tn) in + let single = function + | Node tn -> Set.singleton (module Tn) tn + | Merge_buffer _ -> Set.empty (module Tn) + in let maybe have lhs = if have then Set.singleton (module Tn) lhs else empty in let rec loop = function | Noop -> empty @@ -71,7 +75,8 @@ let recurrent_nodes asgns = | Block_comment (_, t) -> loop t | Accum_binop { initialize_neutral; lhs; rhs1; rhs2; _ } -> maybe (not initialize_neutral) lhs + single rhs1 + single rhs2 - | Accum_unop { initialize_neutral; lhs; rhs; _ } -> maybe (not initialize_neutral) lhs + single rhs + | Accum_unop { initialize_neutral; lhs; rhs; _ } -> + maybe (not initialize_neutral) lhs + single rhs | Fetch _ -> empty and assigned = function | Noop -> Set.empty (module Tn) @@ -101,7 +106,8 @@ let%debug_sexp to_low_level code = assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)); match buffer with | Node tn -> Low_level.Get (tn, idcs) - | Merge_buffer tn -> Low_level.Get_global (Ops.Merge_buffer { source_node_id = tn.Tn.id }, Some idcs) + | Merge_buffer tn -> + Low_level.Get_global (Ops.Merge_buffer { source_node_id = tn.Tn.id }, Some idcs) in let set tn idcs llv = if not (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)) then @@ -121,13 +127,16 @@ let%debug_sexp to_low_level code = | Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } -> let projections = Lazy.force projections in let lhs_idx = - derive_index ~product_syms:projections.product_iterators ~projection:projections.project_lhs + derive_index ~product_syms:projections.product_iterators + ~projection:projections.project_lhs in let rhs1_idx = - derive_index ~product_syms:projections.product_iterators ~projection:projections.project_rhs.(0) + derive_index ~product_syms:projections.product_iterators + ~projection:projections.project_rhs.(0) in let rhs2_idx = - derive_index ~product_syms:projections.product_iterators ~projection:projections.project_rhs.(1) + derive_index ~product_syms:projections.product_iterators + ~projection:projections.project_rhs.(1) in let is_assignment = initialize_neutral && Indexing.is_bijective projections in let basecase rev_iters = @@ -170,10 +179,12 @@ let%debug_sexp to_low_level code = | Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } -> let projections = Lazy.force projections in let lhs_idx = - derive_index ~product_syms:projections.product_iterators 
~projection:projections.project_lhs + derive_index ~product_syms:projections.product_iterators + ~projection:projections.project_lhs in let rhs_idx = - derive_index ~product_syms:projections.product_iterators ~projection:projections.project_rhs.(0) + derive_index ~product_syms:projections.product_iterators + ~projection:projections.project_rhs.(0) in let is_assignment = initialize_neutral && Indexing.is_bijective projections in let basecase rev_iters = @@ -242,9 +253,13 @@ let get_ident_within_code ?no_dots c = let nograd_idents = Hashtbl.create (module String) in let grad_idents = Hashtbl.create (module String) in let visit tn = - let idents = if List.mem ~equal:String.equal tn.Tn.label "grad" then grad_idents else nograd_idents in + let idents = + if List.mem ~equal:String.equal tn.Tn.label "grad" then grad_idents else nograd_idents + in Option.iter (Tn.ident_label tn) - ~f:(Hashtbl.update idents ~f:(fun old -> Set.add (Option.value ~default:Utils.no_ints old) tn.id)) + ~f: + (Hashtbl.update idents ~f:(fun old -> + Set.add (Option.value ~default:Utils.no_ints old) tn.id)) in let tn = function Node tn -> tn | Merge_buffer tn -> tn in let rec loop (c : t) = @@ -264,7 +279,9 @@ let get_ident_within_code ?no_dots c = let repeating_nograd_idents = Hashtbl.filter nograd_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1) in - let repeating_grad_idents = Hashtbl.filter grad_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1) in + let repeating_grad_idents = + Hashtbl.filter grad_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1) + in Tn.styled_ident ~repeating_nograd_idents ~repeating_grad_idents ident_style let fprint_hum ?name ?static_indices () ppf c = @@ -278,7 +295,8 @@ let fprint_hum ?name ?static_indices () ppf c = | Imported (Merge_buffer { source_node_id }) -> let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in fprintf ppf "merge %s" (ident tn) - | Imported (Ops.External_unsafe { ptr; prec; dims = _ }) -> fprintf ppf "%s" @@ Ops.ptr_to_string ptr prec + | Imported (Ops.External_unsafe { ptr; prec; dims = _ }) -> + fprintf ppf "%s" @@ Ops.ptr_to_string ptr prec | Slice { batch_idx; sliced } -> fprintf ppf "%s @@| %s" (ident sliced) (Indexing.symbol_ident batch_idx.static_symbol) | Embed_symbol { static_symbol; static_range = _ } -> @@ -295,24 +313,29 @@ let fprint_hum ?name ?static_indices () ppf c = loop c | Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } -> let proj_spec = - if Lazy.is_val projections then (Lazy.force projections).debug_info.spec else "" + if Lazy.is_val projections then (Lazy.force projections).debug_info.spec + else "" in fprintf ppf "%s %s %s %s %s%s;@ " (ident lhs) (Ops.assign_op_cd_syntax ~initialize_neutral accum) (buffer_ident rhs1) (Ops.binop_cd_syntax op) (buffer_ident rhs2) - (if (not (String.equal proj_spec ".")) || List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op then - " ~logic:\"" ^ proj_spec ^ "\"" + (if + (not (String.equal proj_spec ".")) + || List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op + then " ~logic:\"" ^ proj_spec ^ "\"" else "") | Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } -> let proj_spec = - if Lazy.is_val projections then (Lazy.force projections).debug_info.spec else "" + if Lazy.is_val projections then (Lazy.force projections).debug_info.spec + else "" in fprintf ppf "%s %s %s%s%s;@ " (ident lhs) (Ops.assign_op_cd_syntax ~initialize_neutral accum) (if not @@ Ops.equal_unop op Ops.Identity then Ops.unop_cd_syntax op ^ " " else "") (buffer_ident 
rhs) (if not (String.equal proj_spec ".") then " ~logic:\"" ^ proj_spec ^ "\"" else "") - | Fetch { array; fetch_op; dims = _ } -> fprintf ppf "%s := %a;@ " (ident array) out_fetch_op fetch_op + | Fetch { array; fetch_op; dims = _ } -> + fprintf ppf "%s := %a;@ " (ident array) out_fetch_op fetch_op in fprintf ppf "@,@["; Low_level.fprint_function_header ?name ?static_indices () ppf; diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index 528b81fe..1d064d2a 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -28,8 +28,8 @@ module type No_device_backend = sig val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.t -> code (** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a - device-agnostic way. If [~shared:false], the backend can opt to postpone compiling altogether until - [link] is called, to benefit from more optimizations. *) + device-agnostic way. If [~shared:false], the backend can opt to postpone compiling altogether + until [link] is called, to benefit from more optimizations. *) val compile_batch : ?shared:bool -> @@ -38,27 +38,32 @@ module type No_device_backend = sig Indexing.unit_bindings -> Assignments.t array -> code_batch - (** Unlike the [~shared] parameter, [compile_batch] vs. [compile] is mostly about improving the compile time - and debugging convenience by generating fewer files -- ideally does not affect execution, but there can - be backend-specific differences. Only array entries for which [occupancy] returns true are included. *) + (** Unlike the [~shared] parameter, [compile_batch] vs. [compile] is mostly about improving the + compile time and debugging convenience by generating fewer files -- ideally does not affect + execution, but there can be backend-specific differences. Only array entries for which + [occupancy] returns true are included. *) val link : merge_buffer:buffer_ptr option ref -> context -> code -> routine (** Returns the routine for the code's procedure, in a new context derived from the given context. *) val link_batch : merge_buffer:buffer_ptr option ref -> context -> code_batch -> context * routine option array - (** Returns the routines for the procedures included in the code batch. The returned context is downstream - of all the returned routines. *) + (** Returns the routines for the procedures included in the code batch. The returned context is + downstream of all the returned routines. *) val unsafe_cleanup : ?unsafe_shutdown:bool -> unit -> unit - (** Cleans up all work on a backend. If [~unsafe_shutdown:true], releases resources, potentially making the - backend unusable. *) + (** Cleans up all work on a backend. If [~unsafe_shutdown:true], releases resources, potentially + making the backend unusable. 
*) val to_buffer : ?rt:(module Minidebug_runtime.Debug_runtime) -> Tnode.t -> dst:buffer_ptr -> src:context -> unit - val host_to_buffer : ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit - val buffer_to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit + val host_to_buffer : + ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit + + val buffer_to_host : + ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit + val get_buffer : Tnode.t -> context -> buffer_ptr option end @@ -69,18 +74,18 @@ module type Backend = sig (** Returns the routine for the code's procedure, in a new context derived from the given context. *) val link_batch : context -> code_batch -> context * routine option array - (** Returns the routines for the procedures included in the code batch. The returned context is downstream - of all the returned routines. *) + (** Returns the routines for the procedures included in the code batch. The returned context is + downstream of all the returned routines. *) val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool - (** If the array is both hosted and in-context, schedules a copy from host to context and returns true, - otherwise returns false. NOTE: when run for a device, it's the caller's responsibility to synchronize - the device before the host's data is overwritten. *) + (** If the array is both hosted and in-context, schedules a copy from host to context and returns + true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility + to synchronize the device before the host's data is overwritten. *) val to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool - (** If the array is both hosted and in-context, schedules a copy from context to host and returns true, - otherwise returns false. NOTE: when run for a device, it's the caller's responsibility to synchronize - the device before the host's data is read. *) + (** If the array is both hosted and in-context, schedules a copy from context to host and returns + true, otherwise returns false. NOTE: when run for a device, it's the caller's responsibility + to synchronize the device before the host's data is read. *) val device_to_device : ?rt:(module Minidebug_runtime.Debug_runtime) -> @@ -92,21 +97,24 @@ module type Backend = sig (** If the node is absent from the [src] context and either it is present in the [dst] context or [~into_merge_buffer] is different from [No]: raises an error. - If [~into_merge_buffer:No]: If the node is present in the [dst] context, schedules a copy of the tensor - node from the device of [src] to the device of [dst] and returns true, otherwise returns false. + If [~into_merge_buffer:No]: If the node is present in the [dst] context, schedules a copy of + the tensor node from the device of [src] to the device of [dst] and returns true, otherwise + returns false. If [~into_merge_buffer] is different from [No]: schedules the following task and returns true. The merge-buffer task sets on [dst] the merge buffer source to the given node. If - [~into_merge_buffer:Streaming], remembers the buffer pointer of the source node to use for streaming, - without blocking. If [~into_merge_buffer:Copy], copies from [src] to the merge buffer of [dst]'s device. + [~into_merge_buffer:Streaming], remembers the buffer pointer of the source node to use for + streaming, without blocking. 
If [~into_merge_buffer:Copy], copies from [src] to the merge + buffer of [dst]'s device. - If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge buffer code, - the [device_to_device] call should fail immediately if there's a mismatch with [~into_merge_buffer]. + If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge + buffer code, the [device_to_device] call should fail immediately if there's a mismatch with + [~into_merge_buffer]. - NOTE: it's the caller's responsibility to synchronize the [src] device, if needed, {i before} calling - [device_to_device], and if [~into_merge_buffer:Streaming], the [dst] device {i afterward}, before any - computations on the [src] device overwrite the node. *) + NOTE: it's the caller's responsibility to synchronize the [src] device, if needed, {i before} + calling [device_to_device], and if [~into_merge_buffer:Streaming], the [dst] device + {i afterward}, before any computations on the [src] device overwrite the node. *) type physical_device type device @@ -201,7 +209,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct let spinup_device ~(ordinal : int) : device = Int.incr global_run_no; let init_pos = - Utils.Cons { hd = Tnode.{ description = "root of task queue"; work = (fun _rt () -> ()) }; tl = Empty } + Utils.Cons + { hd = Tnode.{ description = "root of task queue"; work = (fun _rt () -> ()) }; tl = Empty } in let state = { @@ -256,7 +265,11 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct let%diagn_sexp make_work device task = let%diagn_rt_sexp work () = schedule_task device task in Tnode. - { description = "schedules {" ^ task.description ^ "} on device " ^ Int.to_string device.ordinal; work } + { + description = + "schedules {" ^ task.description ^ "} on device " ^ Int.to_string device.ordinal; + work; + } type context = { device : device; ctx : Backend.context; expected_merge_node : Tnode.t option } [@@deriving sexp_of] @@ -286,7 +299,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct let task = Backend.link ~merge_buffer:device.merge_buffer_ptr ctx code in { task with - context = { ctx = task.context; device; expected_merge_node = Backend.expected_merge_node code }; + context = + { ctx = task.context; device; expected_merge_node = Backend.expected_merge_node code }; schedule = make_work device task.schedule; } @@ -304,7 +318,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct let from_host ?rt (context : context) (tn : Tnode.t) = if Option.is_some rt then - raise @@ Utils.User_error "Multicore_backend.from_host: backend cannot be nested in another runtime"; + raise + @@ Utils.User_error "Multicore_backend.from_host: backend cannot be nested in another runtime"; Option.value ~default:false @@ Option.map (Backend.get_buffer tn context.ctx) ~f:(fun c_arr -> match tn.Tnode.array with @@ -328,7 +343,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct Tnode. 
{ description = - "from_host " ^ Tnode.get_debug_name tn ^ " dst " ^ Int.to_string context.device.ordinal; + "from_host " ^ Tnode.get_debug_name tn ^ " dst " + ^ Int.to_string context.device.ordinal; work; }; true @@ -341,7 +357,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct let to_host ?rt (context : context) (tn : Tnode.t) = if Option.is_some rt then - raise @@ Utils.User_error "Multicore_backend.to_host: backend cannot be nested in another runtime"; + raise + @@ Utils.User_error "Multicore_backend.to_host: backend cannot be nested in another runtime"; Option.value ~default:false @@ Option.map (Backend.get_buffer tn context.ctx) ~f:(fun c_arr -> match tn.Tnode.array with @@ -365,7 +382,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct Tnode. { description = - "from_host " ^ Tnode.get_debug_name tn ^ " dst " ^ Int.to_string context.device.ordinal; + "to_host " ^ Tnode.get_debug_name tn ^ " src " + ^ Int.to_string context.device.ordinal; work; }; true @@ -379,7 +397,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct let device_to_device ?rt tn ~into_merge_buffer ~dst ~src = if Option.is_some rt then raise - @@ Utils.User_error "Multicore_backend.device_to_device: backend cannot be nested in another runtime"; + @@ Utils.User_error + "Multicore_backend.device_to_device: backend cannot be nested in another runtime"; let dev = dst.device in if (not (equal_merge_buffer_use into_merge_buffer No)) @@ -402,10 +421,14 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct @@ Option.value_exn ~here:[%here] ~message:(Tnode.get_debug_name tn) @@ Lazy.force tn.array in - let allocated_capacity = Option.value ~default:0 @@ Option.map dev.allocated_buffer ~f:snd in + let allocated_capacity = + Option.value ~default:0 @@ Option.map dev.allocated_buffer ~f:snd + in if allocated_capacity < size_in_bytes then dev.allocated_buffer <- - Some (Backend.alloc_buffer ?old_buffer:dev.allocated_buffer ~size_in_bytes (), size_in_bytes); + Some + ( Backend.alloc_buffer ?old_buffer:dev.allocated_buffer ~size_in_bytes (), + size_in_bytes ); dev.merge_buffer_ptr := Option.map ~f:fst dev.allocated_buffer; Backend.to_buffer ~rt tn ~dst:(Option.value_exn ~here:[%here] !(dev.merge_buffer_ptr)) @@ -415,8 +438,8 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct Tnode. 
{ description = - "device_to_device " ^ Tnode.get_debug_name tn ^ " dst " ^ Int.to_string dev.ordinal ^ " src " - ^ Int.to_string src.device.ordinal; + "device_to_device " ^ Tnode.get_debug_name tn ^ " dst " ^ Int.to_string dev.ordinal + ^ " src " ^ Int.to_string src.device.ordinal; work; } in @@ -466,8 +489,8 @@ let lower_assignments ?name bindings asgns = let ll_source = Utils.get_debug_formatter ~fname:(name ^ ".ll") in let cd_source = Utils.get_debug_formatter ~fname:(name ^ ".cd") in ( name, - Assignments.lower_proc ~unoptim_ll_source ~ll_source ~cd_source ~name (Indexing.bound_symbols bindings) - asgns ) + Assignments.lower_proc ~unoptim_ll_source ~ll_source ~cd_source ~name + (Indexing.bound_symbols bindings) asgns ) let lower_batch_assignments ?names ?occupancy bindings asgns_l = let names = @@ -485,7 +508,9 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l = let asgns = asgns_l.(src_n) in if occupancy ~name ~src_n then ( Some name, - Some (Assignments.lower_proc ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) ) + Some + (Assignments.lower_proc ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) + ) else (None, None)) module type Simple_backend = sig @@ -530,15 +555,22 @@ module type Simple_backend = sig val to_buffer : ?rt:(module Minidebug_runtime.Debug_runtime) -> Tnode.t -> dst:buffer_ptr -> src:context -> unit - val host_to_buffer : ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit - val buffer_to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit + val host_to_buffer : + ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit + + val buffer_to_host : + ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit end module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend = struct include Backend type code = - | Postponed of { lowered : Low_level.optimized; bindings : Indexing.unit_bindings; name : string } + | Postponed of { + lowered : Low_level.optimized; + bindings : Indexing.unit_bindings; + name : string; + } | Compiled of Backend.procedure [@@deriving sexp_of] @@ -565,7 +597,8 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend = let expected_merge_nodes : code_batch -> _ = function | Postponed { lowereds; _ } -> - Array.map lowereds ~f:(fun lowered -> Option.(join @@ map lowered ~f:(fun optim -> optim.merge_node))) + Array.map lowereds ~f:(fun lowered -> + Option.(join @@ map lowered ~f:(fun optim -> optim.merge_node))) | Compiled (_, procs) -> Array.map ~f:(fun proc -> Option.(join @@ map proc ~f:Backend.expected_merge_node)) procs @@ -583,7 +616,9 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend = let context, bindings, schedule, name = match code with | Postponed { lowered; bindings; name } -> - let proc = Backend.compile ~name ~opt_ctx_arrays:(Some (ctx_arrays old_context)) bindings lowered in + let proc = + Backend.compile ~name ~opt_ctx_arrays:(Some (ctx_arrays old_context)) bindings lowered + in link_compiled ~merge_buffer old_context proc | Compiled code -> link_compiled ~merge_buffer old_context code in @@ -593,7 +628,9 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend = let _opt_ctx_arrays, procs = match code_batch with | Postponed { lowereds; bindings; names } -> - Backend.compile_batch ~names ~opt_ctx_arrays:(Some (ctx_arrays old_context)) bindings lowereds + Backend.compile_batch 
~names + ~opt_ctx_arrays:(Some (ctx_arrays old_context)) + bindings lowereds | Compiled procs -> procs in Array.fold_map procs ~init:old_context ~f:(fun context -> function @@ -605,7 +642,9 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend = let to_buffer ?rt tn ~dst ~src = Backend.to_buffer ?rt tn ~dst ~src let host_to_buffer = Backend.host_to_buffer let buffer_to_host = Backend.buffer_to_host - let get_buffer tn context = Map.find (Backend.ctx_arrays context) tn |> Option.map ~f:Backend.buffer_ptr + + let get_buffer tn context = + Map.find (Backend.ctx_arrays context) tn |> Option.map ~f:Backend.buffer_ptr end module C_device : No_device_backend = Simple_no_device_backend (( diff --git a/arrayjit/lib/cc_backend.ml b/arrayjit/lib/cc_backend.ml index d2f75e8a..0f1ffdc6 100644 --- a/arrayjit/lib/cc_backend.ml +++ b/arrayjit/lib/cc_backend.ml @@ -16,7 +16,8 @@ let compiler_command () = Utils.get_global_arg ~default:"cc" ~arg_name:"cc_backe type mem_properties = | Local_only (** The array is only needed for a local computation, is allocated on the stack. *) - | From_context (** The array has a copy allocated per-cpu-device, may or may not exist on the host. *) + | From_context + (** The array has a copy allocated per-cpu-device, may or may not exist on the host. *) | Constant_from_host (** The array is read directly from the host. *) [@@deriving sexp, equal, compare, variants] @@ -72,8 +73,8 @@ type tn_info = { mutable ptr : string; (** Pointer to the first value of the associated array. - if [mem = Constant_from_host], the pointer to the first element of the hosted [Ndarray], - - if [mem = From_context], either a pointer to [Ndarray] from [context.arrays] when [~shared:false], - or the function parameter when [~shared:true], + - if [mem = From_context], either a pointer to [Ndarray] from [context.arrays] when + [~shared:false], or the function parameter when [~shared:true], - if [mem = Local_only], the address of the on-the-stack array. *) mem : mem_properties; dims : int array; @@ -92,7 +93,11 @@ type info_nodes = { } [@@deriving sexp_of] -type param_source = Log_file_name | Merge_buffer | Param_ptr of Tn.t | Static_idx of Indexing.static_symbol +type param_source = + | Log_file_name + | Merge_buffer + | Param_ptr of Tn.t + | Static_idx of Indexing.static_symbol [@@deriving sexp_of] (* open Ctypes *) @@ -117,13 +122,16 @@ type ctx_nodes = Ctx_arrays of ctx_arrays ref | Param_ptrs of (string * param_so (* https://github.com/yallop/ocaml-ctypes/blob/master/src/ctypes-foreign/dl.mli https://github.com/ahrefs/ocannl/blob/1eb5209772b759f00a0cb8a39e51c4ddae78aee6/lib/exec_as_OCaml.ml *) -let pp_zero_out ppf node = Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " node.ptr node.size_in_bytes +let pp_zero_out ppf node = + Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " node.ptr node.size_in_bytes let get_c_ptr prec nd = let f arr = Ops.ptr_to_string (Ctypes.bigarray_start Ctypes_static.Genarray arr) prec in Ndarray.(map { f } nd) -let is_builtin_op = function Ops.Add | Sub | Mul | Div -> true | ToPowOf | Relu_gate | Arg2 | Arg1 -> false +let is_builtin_op = function + | Ops.Add | Sub | Mul | Div -> true + | ToPowOf | Relu_gate | Arg2 | Arg1 -> false let node_debug_name node = (* FIXME: node.ptr is not the mem address? 
*) @@ -160,8 +168,8 @@ let array_offset_to_string (idcs, dims) = Stdlib.Format.pp_print_flush ppf (); Buffer.contents b -(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim -> idx + - (offset * dim)) *) +(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim + -> idx + (offset * dim)) *) let%debug_sexp prepare_node ~(traced_store : Low_level.traced_store) info ctx_nodes tn = Hash_set.add info.used_tensors tn; @@ -200,7 +208,8 @@ let%debug_sexp prepare_node ~(traced_store : Low_level.traced_store) info ctx_no | From_context, Param_ptrs ptrs -> ptrs := (name, Param_ptr tn) :: !ptrs; ident - | Constant_from_host, _ -> get_c_ptr prec @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array + | Constant_from_host, _ -> + get_c_ptr prec @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array | Local_only, _ -> ident in let backend_info = sexp_of_mem_properties mem in @@ -228,12 +237,13 @@ let compile_main ~traced_store info ppf llc : unit = match c with | Low_level.Noop -> () | Seq (c1, c2) -> - (* Note: no separator. Filter out some entries known to not generate code to avoid whitespace. *) + (* Note: no separator. Filter out some entries known to not generate code to avoid + whitespace. *) fprintf ppf "@[%a@]" (pp_print_list pp_ll) (List.filter [ c1; c2 ] ~f:(function Noop -> false | _ -> true)) | For_loop { index = i; from_; to_; body; trace_it = _ } -> - fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ %a@;<1 -2>}@]@," pp_index i from_ pp_index i - to_ pp_index i pp_ll body + fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ %a@;<1 -2>}@]@," pp_index i from_ + pp_index i to_ pp_index i pp_ll body | Zero_out tn -> let node = Hashtbl.find_exn info.nodes tn in let traced = Low_level.(get_node traced_store tn) in @@ -265,7 +275,8 @@ let compile_main ~traced_store info ppf llc : unit = fprintf ppf {|@[<7>fprintf(log_file,@ @["%s[%%u] = %%f = %s\n",@]@ %a,@ new_set_v%a);@]@ fflush(log_file);@ |} ident v_code pp_array_offset offset pp_args v_idcs; - fprintf ppf "@[<2>%s[@,%a] =@ new_set_v;@]@;<1 -2>}@]@ " ident pp_array_offset (idcs, node.dims)) + fprintf ppf "@[<2>%s[@,%a] =@ new_set_v;@]@;<1 -2>}@]@ " ident pp_array_offset + (idcs, node.dims)) else (* No idea why adding any cut hint at the end of the assign line breaks formatting! *) fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " ident pp_array_offset (idcs, node.dims) loop_f llv; @@ -288,8 +299,8 @@ let compile_main ~traced_store info ppf llc : unit = match vcomp with | Local_scope { id = { scope_id = i; tn = { prec; _ } }; body; orig_indices = _ } -> let num_typ = Ops.cuda_typ_of_prec prec in - (* Arrays are initialized to 0 by default. However, there is typically an explicit initialization for - virtual nodes. *) + (* Arrays are initialized to 0 by default. However, there is typically an explicit + initialization for virtual nodes. 
*) fprintf ppf "@[<2>{@ %s v%d = 0;@ " num_typ i; pp_ll ppf body; pp_print_space ppf (); @@ -312,7 +323,8 @@ let compile_main ~traced_store info ppf llc : unit = fprintf ppf "v%d" id.scope_id | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) -> let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in - fprintf ppf "@[<2>((%s*)merge_buffer)[%a@;<0 -2>]@]" (Ops.cuda_typ_of_prec prec) pp_array_offset + fprintf ppf "@[<2>((%s*)merge_buffer)[%a@;<0 -2>]@]" (Ops.cuda_typ_of_prec prec) + pp_array_offset (idcs, Lazy.force tn.dims) | Get_global _ -> failwith "Cc_backend: Get_global / FFI NOT IMPLEMENTED YET" | Get (tn, idcs) -> @@ -322,7 +334,8 @@ let compile_main ~traced_store info ppf llc : unit = fprintf ppf "@[<2>%s[%a@;<0 -2>]@]" ident pp_array_offset (idcs, node.dims) | Constant c -> fprintf ppf "(%f)" c | Embed_index idx -> - if not @@ List.exists ~f:(String.equal num_typ) [ "int"; "size_t" ] then fprintf ppf "(%s)" num_typ; + if not @@ List.exists ~f:(String.equal num_typ) [ "int"; "size_t" ] then + fprintf ppf "(%s)" num_typ; pp_index_axis ppf idx | Binop (Arg1, v1, _v2) -> loop ppf v1 | Binop (Arg2, _v1, v2) -> loop ppf v2 @@ -338,7 +351,8 @@ let compile_main ~traced_store info ppf llc : unit = let loop = debug_float prec in match value with | Local_scope { id; _ } -> - (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *) + (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug + logs. *) loop @@ Get_local id | Get_local id -> let get_typ = Ops.cuda_typ_of_prec id.tn.prec in @@ -350,7 +364,9 @@ let compile_main ~traced_store info ppf llc : unit = | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) -> let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in let node = get_node tn in - let v = sprintf "@[<2>merge_buffer[%s@;<0 -2>]@]" (array_offset_to_string (idcs, node.dims)) in + let v = + sprintf "@[<2>merge_buffer[%s@;<0 -2>]@]" (array_offset_to_string (idcs, node.dims)) + in ("merge_buffer[%u]{=%f}", [ `Accessor (idcs, node.dims); `Value v ]) | Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET" | Get (tn, idcs) -> @@ -404,14 +420,17 @@ let%track_sexp compile_proc ~name info ppf idx_params Low_level.{ traced_store; [%log "array-used:", (tn : Tn.t), Tn.label tn, (node.mem : mem_properties)]; match node.mem with | Local_only -> None - | From_context -> Some (Ops.cuda_typ_of_prec node.prec ^ " *" ^ info.get_ident tn, Param_ptr tn) + | From_context -> + Some (Ops.cuda_typ_of_prec node.prec ^ " *" ^ info.get_ident tn, Param_ptr tn) | Constant_from_host -> None) in let idx_params = - List.map idx_params ~f:(fun s -> ("int " ^ Indexing.symbol_ident s.Indexing.static_symbol, Static_idx s)) + List.map idx_params ~f:(fun s -> + ("int " ^ Indexing.symbol_ident s.Indexing.static_symbol, Static_idx s)) in let log_file = - if Utils.settings.debug_log_from_routines then [ ("const char* log_file_name", Log_file_name) ] else [] + if Utils.settings.debug_log_from_routines then [ ("const char* log_file_name", Log_file_name) ] + else [] in let merge_param = Option.( @@ -423,13 +442,15 @@ let%track_sexp compile_proc ~name info ppf idx_params Low_level.{ traced_store; fprintf ppf "@[@[void %s(@,@[%a@]@;<0 -4>)@] {@ " name (pp_print_list ~pp_sep:pp_comma pp_print_string) @@ List.map ~f:fst params; - if Utils.settings.debug_log_from_routines then fprintf ppf {|FILE* log_file = fopen(log_file_name, "w");@ |}; + if Utils.settings.debug_log_from_routines 
then + fprintf ppf {|FILE* log_file = fopen(log_file_name, "w");@ |}; fprintf ppf "/* Local declarations and initialization. */@ "; List.iter arrays ~f:(fun tn -> let node = Hashtbl.find_exn info.nodes tn in match node.mem with | Local_only -> - fprintf ppf "%s %s[%d]%s;@ " (Ops.cuda_typ_of_prec node.prec) (info.get_ident tn) node.size_in_elems + fprintf ppf "%s %s[%d]%s;@ " (Ops.cuda_typ_of_prec node.prec) (info.get_ident tn) + node.size_in_elems (if (Hashtbl.find_exn traced_store tn).zero_initialized then " = {0}" else "") | From_context when node.zero_initialized -> pp_zero_out ppf node | _ -> ()); @@ -482,7 +503,9 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_ } in let ctx_nodes = - match opt_ctx_arrays with Some ctx_arrays -> Ctx_arrays (ref ctx_arrays) | None -> Param_ptrs (ref []) + match opt_ctx_arrays with + | Some ctx_arrays -> Ctx_arrays (ref ctx_arrays) + | None -> Param_ptrs (ref []) in prepare_nodes info ctx_nodes lowered; let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in @@ -503,10 +526,13 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_ () done; let result = Dl.dlopen ~filename:libname ~flags:[ RTLD_NOW; RTLD_DEEPBIND ] in - let opt_ctx_arrays = match ctx_nodes with Ctx_arrays ctx_arrays -> Some !ctx_arrays | _ -> None in + let opt_ctx_arrays = + match ctx_nodes with Ctx_arrays ctx_arrays -> Some !ctx_arrays | _ -> None + in { info; result; params; bindings; name; opt_ctx_arrays; expected_merge_node = lowered.merge_node } -let%track_sexp compile_batch ~names ~opt_ctx_arrays bindings (lowereds : Low_level.optimized option array) = +let%track_sexp compile_batch ~names ~opt_ctx_arrays bindings + (lowereds : Low_level.optimized option array) = let get_ident = Low_level.get_ident_within_code ~no_dots:true @@ Array.filter_map lowereds ~f:(Option.map ~f:(fun Low_level.{ llc; _ } -> llc)) @@ -529,7 +555,9 @@ let%track_sexp compile_batch ~names ~opt_ctx_arrays bindings (lowereds : Low_lev in let ctx_nodes = Array.map lowereds ~f:(fun _ -> - match opt_ctx_arrays with Some _ -> Ctx_arrays global_ctx_arrays | None -> Param_ptrs (ref [])) + match opt_ctx_arrays with + | Some _ -> Ctx_arrays global_ctx_arrays + | None -> Param_ptrs (ref [])) in Array.iteri ctx_nodes ~f:(fun i ctx_nodes -> Option.iter infos.(i) ~f:(fun info -> @@ -611,18 +639,20 @@ let%track_sexp link_compiled ~merge_buffer (old_context : context) (code : proce | Bind _, [] -> invalid_arg "Cc_backend.link: too few static index params" | Bind (_, bs), Static_idx _ :: ps -> Param_idx (ref 0, link bs ps Ctypes.(int @-> cs)) | Empty, Static_idx _ :: _ -> invalid_arg "Cc_backend.link: too many static index params" - | bs, Log_file_name :: ps -> Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs)) + | bs, Log_file_name :: ps -> + Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs)) | bs, Merge_buffer :: ps -> Param_2f (Ndarray.get_voidptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs)) | bs, Param_ptr tn :: ps -> let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in - (* let f ba = Ctypes.bigarray_start Ctypes_static.Genarray ba in let c_ptr = Ndarray.(map { f } - nd) in *) + (* let f ba = Ctypes.bigarray_start Ctypes_static.Genarray ba in let c_ptr = + Ndarray.(map { f } nd) in *) let c_ptr = Ndarray.get_voidptr nd in Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs)) in - (* Folding by [link] above reverses the input order. 
Important: [code.bindings] are traversed in the - wrong order but that's OK because [link] only uses them to check the number of indices. *) + (* Folding by [link] above reverses the input order. Important: [code.bindings] are traversed + in the wrong order but that's OK because [link] only uses them to check the number of + indices. *) let params = List.rev_map code.params ~f:(fun (_, p) -> p) in link code.bindings params Ctypes.(void @-> returning void)] in diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index 77072bed..1e7635cc 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -9,7 +9,8 @@ module Debug_runtime = Utils.Debug_runtime type mem_properties = | Local_only (** The array is only needed for a single computation and is allocated locally (or spilled). *) - | Global (** Could not perform optimizations: the array is computed directly in the global memory. *) + | Global + (** Could not perform optimizations: the array is computed directly in the global memory. *) [@@deriving sexp, equal, compare, variants] type tn_info = { @@ -22,8 +23,8 @@ type tn_info = { size_in_bytes : int; size_in_elems : int; num_typ : string; - (** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), [double] - (precision [Double]). *) + (** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), + [double] (precision [Double]). *) zero_initialized : bool; } [@@deriving sexp_of] @@ -36,7 +37,11 @@ type physical_device = { } [@@deriving sexp_of] -and device = { physical : physical_device; stream : (Cudajit.stream[@sexp.opaque]); subordinal : int } +and device = { + physical : physical_device; + stream : (Cudajit.stream[@sexp.opaque]); + subordinal : int; +} and context = { label : string; @@ -125,7 +130,9 @@ let get_ctx_device { device; _ } = device let get_physical_device { physical; _ } = physical let to_ordinal { ordinal; _ } = ordinal let to_subordinal { subordinal; _ } = subordinal -let get_name device = Int.to_string (to_ordinal device.physical) ^ "_" ^ Int.to_string (to_subordinal device) + +let get_name device = + Int.to_string (to_ordinal device.physical) ^ "_" ^ Int.to_string (to_subordinal device) let set_ctx ctx = let cur_ctx = Cudajit.ctx_get_current () in @@ -158,14 +165,16 @@ let finalize ctx = if phys_equal f_ctx ctx then f ~output)) ~finally:(fun () -> ctx.device.physical.postprocess_queue <- - List.filter ctx.device.physical.postprocess_queue ~f:(fun (f_ctx, _) -> phys_equal f_ctx ctx)); + List.filter ctx.device.physical.postprocess_queue ~f:(fun (f_ctx, _) -> + phys_equal f_ctx ctx)); Map.iter ctx.global_arrays ~f:(fun ptr -> Cudajit.mem_free ptr)) let unsafe_cleanup ?unsafe_shutdown:_ () = let len = Core.Weak.length !devices in (* TODO: maybe better to do device_primary_ctx_reset if [unsafe_shutdown=false]. 
*) for i = 0 to len - 1 do - Option.iter (Core.Weak.get !devices i) ~f:(fun device -> Cudajit.device_primary_ctx_release device.dev) + Option.iter (Core.Weak.get !devices i) ~f:(fun device -> + Cudajit.device_primary_ctx_release device.dev) done; Core.Weak.fill !devices 0 len None @@ -180,7 +189,8 @@ let await device = let is_idle _device = failwith "NOT IMPLEMENTED YET" -let%diagn_sexp from_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) (ctx : context) tn = +let%diagn_sexp from_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) (ctx : context) tn + = match (Map.find ctx.all_arrays tn, Map.find ctx.global_arrays tn) with | Some { tn = { Tn.array = (lazy (Some hosted)); _ }; _ }, Some dst -> set_ctx ctx.ctx; @@ -188,7 +198,8 @@ let%diagn_sexp from_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) let f src = Cudajit.memcpy_H_to_D ~dst ~src () in Ndarray.map { f } hosted; (if Utils.settings.with_debug_level > 0 then - let module Debug_runtime = (val Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime))) + let module Debug_runtime = + (val Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime))) in [%log "copied", Tn.label tn, Tn.name tn, "from host"]); true @@ -202,7 +213,8 @@ let%diagn_sexp to_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) ( let f dst = Cudajit.memcpy_D_to_H ~dst ~src () in Ndarray.map { f } hosted; if Utils.settings.with_debug_level > 0 then ( - let module Debug_runtime = (val Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime))) + let module Debug_runtime = + (val Option.value_or_thunk rt ~default:(fun () -> (module Debug_runtime))) in [%log "copied", Tn.label tn, Tn.name tn, "to host"]; if Utils.settings.with_debug_level > 1 then @@ -212,8 +224,8 @@ let%diagn_sexp to_host ?(rt : (module Minidebug_runtime.Debug_runtime) option) ( true | _ -> false -let%diagn_sexp device_to_device ?(rt : (module Minidebug_runtime.Debug_runtime) option) tn ~into_merge_buffer - ~dst ~src = +let%diagn_sexp device_to_device ?(rt : (module Minidebug_runtime.Debug_runtime) option) tn + ~into_merge_buffer ~dst ~src = Option.value ~default:false @@ Option.map (Map.find src.global_arrays tn) ~f:(fun s_arr -> Option.value ~default:false @@ -269,7 +281,10 @@ let array_offset_to_string (idcs, dims) = Buffer.contents b let get_run_ptr array = - match (array.global, array.local) with _, Some lv -> lv | Some rv, _ -> rv | None, None -> assert false + match (array.global, array.local) with + | _, Some lv -> lv + | Some rv, _ -> rv + | None, None -> assert false let get_run_ptr_debug array = match (array.global, array.local) with @@ -277,8 +292,8 @@ let get_run_ptr_debug array = | Some rv, _ -> "global_" ^ rv | None, None -> assert false -(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim -> idx + - (offset * dim)) *) +(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim + -> idx + (offset * dim)) *) let%debug_sexp prepare_node traced_store info tn = Hash_set.add info.used_tensors tn; @@ -325,15 +340,16 @@ let compile_main traced_store info ppf llc : unit = match c with | Low_level.Noop -> () | Seq (c1, c2) -> - (* Note: no separator. Filter out some entries known to not generate code to avoid whitespace. *) + (* Note: no separator. Filter out some entries known to not generate code to avoid + whitespace. 
*) fprintf ppf "@[%a@]" (pp_print_list pp_ll) (List.filter [ c1; c2 ] ~f:(function | Noop -> false | Zero_out ptr -> not Low_level.(get_node traced_store ptr).zero_initialized | _ -> true)) | For_loop { index = i; from_; to_; body; trace_it = _ } -> - fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ %a@]@ }@," pp_index i from_ pp_index i to_ - pp_index i pp_ll body + fprintf ppf "@[<2>for (int@ %a = %d;@ %a <= %d;@ ++%a) {@ %a@]@ }@," pp_index i from_ + pp_index i to_ pp_index i pp_ll body | Zero_out tn -> if Hash_set.mem visited tn then pp_ll ppf @@ -364,25 +380,28 @@ let compile_main traced_store info ppf llc : unit = let run_ptr_debug = get_run_ptr_debug node in let run_ptr = get_run_ptr node in let offset = (idcs, node.dims) in - let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\\n" in + let debug_line = + "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\\n" + in fprintf ppf - "@ @[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(\"%%d: %s\", log_id);@ \ - printf(@[\"%%d: %s[%%u] = %%f = %s\\n\"@], log_id,@ %a,@ %s[%a]%a);@ @]}" - debug_line run_ptr_debug v_code pp_array_offset offset run_ptr pp_array_offset offset pp_args - v_idcs; - fprintf ppf "@[<2>%s[@,%a] =@ new_set_v;@]@ " (get_run_ptr node) pp_array_offset (idcs, node.dims)) + "@ @[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(\"%%d: %s\", \ + log_id);@ printf(@[\"%%d: %s[%%u] = %%f = %s\\n\"@], log_id,@ %a,@ %s[%a]%a);@ @]}" + debug_line run_ptr_debug v_code pp_array_offset offset run_ptr pp_array_offset offset + pp_args v_idcs; + fprintf ppf "@[<2>%s[@,%a] =@ new_set_v;@]@ " (get_run_ptr node) pp_array_offset + (idcs, node.dims)) else (* No idea why adding any cut hint at the end of the assign line breaks formatting! *) - fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " (get_run_ptr node) pp_array_offset (idcs, node.dims) loop_f - llv; + fprintf ppf "@[<2>%s[@,%a] =@ %a;@]@ " (get_run_ptr node) pp_array_offset + (idcs, node.dims) loop_f llv; for _ = 1 to num_closing_braces do fprintf ppf "@]@ }@," done | Comment message -> if Utils.settings.debug_log_from_routines then fprintf ppf - "@[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(@[\"%%d: COMMENT: %s\\n\", \ - log_id@]);@ @]}" + "@[<2>if @[<2>(threadIdx.x == 0 && blockIdx.x == 0@]) {@ printf(@[\"%%d: COMMENT: \ + %s\\n\", log_id@]);@ @]}" (String.substr_replace_all ~pattern:"%" ~with_:"%%" message) else fprintf ppf "/* %s */@ " message | Staged_compilation callback -> callback () @@ -397,8 +416,8 @@ let compile_main traced_store info ppf llc : unit = match vcomp with | Local_scope { id = { scope_id = i; tn = { prec; _ } }; body; orig_indices = _ } -> let num_typ = Ops.cuda_typ_of_prec prec in - (* Arrays are initialized to 0 by default. However, there is typically an explicit initialization for - virtual nodes. *) + (* Arrays are initialized to 0 by default. However, there is typically an explicit + initialization for virtual nodes. 
*) fprintf ppf "@[<2>{@ %s v%d = 0;@ " num_typ i; pp_ll ppf body; pp_print_space ppf (); @@ -428,7 +447,8 @@ let compile_main traced_store info ppf llc : unit = fprintf ppf "@[<2>%s[%a@]]" (get_run_ptr node) pp_array_offset (idcs, node.dims) | Constant c -> fprintf ppf "(%f)" c | Embed_index idx -> - if not @@ List.exists ~f:(String.equal num_typ) [ "int"; "size_t" ] then fprintf ppf "(%s)" num_typ; + if not @@ List.exists ~f:(String.equal num_typ) [ "int"; "size_t" ] then + fprintf ppf "(%s)" num_typ; pp_index_axis ppf idx | Binop (Arg1, v1, _v2) -> loop ppf v1 | Binop (Arg2, _v1, v2) -> loop ppf v2 @@ -443,7 +463,8 @@ let compile_main traced_store info ppf llc : unit = let loop = debug_float ~num_typ prec in match value with | Local_scope { id; _ } -> - (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *) + (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug + logs. *) loop @@ Get_local id | Get_local id -> let get_typ = Ops.cuda_typ_of_prec id.tn.prec in @@ -454,12 +475,17 @@ let compile_main traced_store info ppf llc : unit = (v ^ "{=%f}", [ `Value v ]) | Get_global (Merge_buffer { source_node_id }, Some idcs) -> let tn = Option.value_exn ~here:[%here] @@ Tnode.find ~id:source_node_id in - let v = sprintf "@[<2>merge_buffer[%s@]]" (array_offset_to_string (idcs, Lazy.force tn.dims)) in - ("merge " ^ Tn.get_debug_name tn ^ "[%u]{=%f}", [ `Accessor (idcs, Lazy.force tn.dims); `Value v ]) + let v = + sprintf "@[<2>merge_buffer[%s@]]" (array_offset_to_string (idcs, Lazy.force tn.dims)) + in + ( "merge " ^ Tn.get_debug_name tn ^ "[%u]{=%f}", + [ `Accessor (idcs, Lazy.force tn.dims); `Value v ] ) | Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET" | Get (tn, idcs) -> let node = get_node tn in - let v = sprintf "@[<2>%s[%s@]]" (get_run_ptr node) (array_offset_to_string (idcs, node.dims)) in + let v = + sprintf "@[<2>%s[%s@]]" (get_run_ptr node) (array_offset_to_string (idcs, node.dims)) + in (get_run_ptr_debug node ^ "[%u]{=%f}", [ `Accessor (idcs, node.dims); `Value v ]) | Constant c -> (Float.to_string c, []) | Embed_index (Fixed_idx i) -> (Int.to_string i, []) @@ -521,9 +547,12 @@ type code_batch = { } [@@deriving sexp_of] -let%track_sexp compile_proc ~name ~get_ident ppf idx_params Low_level.{ traced_store; llc; merge_node } = +let%track_sexp compile_proc ~name ~get_ident ppf idx_params + Low_level.{ traced_store; llc; merge_node } = let open Stdlib.Format in - let info = { nodes = Hashtbl.create (module Tn); used_tensors = Hash_set.create (module Tn); get_ident } in + let info = + { nodes = Hashtbl.create (module Tn); used_tensors = Hash_set.create (module Tn); get_ident } + in prepare_nodes traced_store info llc; let arrays = Hash_set.to_list info.used_tensors in let params = @@ -536,15 +565,19 @@ let%track_sexp compile_proc ~name ~get_ident ppf idx_params Low_level.{ traced_s | Global -> Option.map node.global ~f:(fun n -> node.num_typ ^ " *" ^ n)) in let idx_params = - List.map idx_params ~f:(fun { Indexing.static_symbol; _ } -> "int " ^ Indexing.symbol_ident static_symbol) + List.map idx_params ~f:(fun { Indexing.static_symbol; _ } -> + "int " ^ Indexing.symbol_ident static_symbol) in let merge_buffer_param = - Option.to_list merge_node |> List.map ~f:(fun tn -> Ops.cuda_typ_of_prec tn.prec ^ " *merge_buffer") + Option.to_list merge_node + |> List.map ~f:(fun tn -> Ops.cuda_typ_of_prec tn.prec ^ " *merge_buffer") in let log_id = if Utils.settings.debug_log_from_routines then 
[ "int log_id" ] else [] in - fprintf ppf "extern \"C\" __global__ void %s(%a) {@," name (pp_print_list ~pp_sep:pp_comma pp_print_string) + fprintf ppf "extern \"C\" __global__ void %s(%a) {@," name + (pp_print_list ~pp_sep:pp_comma pp_print_string) @@ log_id @ merge_buffer_param @ idx_params @ params; - fprintf ppf "/* FIXME: single-threaded for now. */@,if (threadIdx.x != 0 || blockIdx.x != 0) { return; }@ "; + fprintf ppf + "/* FIXME: single-threaded for now. */@,if (threadIdx.x != 0 || blockIdx.x != 0) { return; }@ "; (* TODO: The following link seems to claim it's better to expand into loops. https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *) fprintf ppf "/* Thread-local declarations and initialization. */@,"; @@ -591,7 +624,8 @@ let%diagn_sexp cuda_to_ptx ~name cu_src = type buffer_ptr = Cudajit.deviceptr -let sexp_of_buffer_ptr (Cudajit.Deviceptr ptr : buffer_ptr) = Sexp.Atom (Unsigned.UInt64.to_hexstring ptr) +let sexp_of_buffer_ptr (Cudajit.Deviceptr ptr : buffer_ptr) = + Sexp.Atom (Unsigned.UInt64.to_hexstring ptr) let alloc_buffer ?old_buffer ~size_in_bytes () = match old_buffer with @@ -620,7 +654,9 @@ let%diagn_sexp link_proc (old_context : context) ~name info ptx = let compile ?name bindings ({ Low_level.llc; _ } as lowered) = let get_ident = Low_level.get_ident_within_code ~no_dots:true [| llc |] in - let name : string = Option.value_or_thunk name ~default:(fun () -> Low_level.extract_block_name [ llc ]) in + let name : string = + Option.value_or_thunk name ~default:(fun () -> Low_level.extract_block_name [ llc ]) + in let idx_params = Indexing.bound_symbols bindings in let b = Buffer.create 4096 in let ppf = Stdlib.Format.formatter_of_buffer b in @@ -644,7 +680,8 @@ let compile_batch ~names bindings lowereds = in let name : string = String.( - strip ~drop:(equal_char '_') @@ common_prefix (Array.to_list names |> List.concat_map ~f:Option.to_list)) + strip ~drop:(equal_char '_') + @@ common_prefix (Array.to_list names |> List.concat_map ~f:Option.to_list)) in let ptx = cuda_to_ptx ~name @@ Buffer.contents b in { ptx; infos; bindings; names } @@ -674,18 +711,20 @@ let link old_context (code : code) = raise @@ Utils.User_error [%string - "Exec_as_cuda: static index %{Indexing.symbol_ident static_symbol} is negative: %{!i#Int}"]; + "Exec_as_cuda: static index %{Indexing.symbol_ident static_symbol} is negative: \ + %{!i#Int}"]; Option.iter static_range ~f:(fun upto -> if !i >= upto then raise @@ Utils.User_error [%string - "Exec_as_cuda: static index %{Indexing.symbol_ident static_symbol} is too big: \ - %{upto#Int}"]); + "Exec_as_cuda: static index %{Indexing.symbol_ident static_symbol} is too \ + big: %{upto#Int}"]); Cu.Int !i) in let args = - (* TODO: should we prohibit or warn about Local_only tensors that are in old_context.global_arrays? *) + (* TODO: should we prohibit or warn about Local_only tensors that are in + old_context.global_arrays? 
*) let arrays = Hash_set.to_list code.info.used_tensors in List.filter_map arrays ~f:(fun tn -> let node = Hashtbl.find_exn code.info.nodes tn in @@ -698,9 +737,11 @@ let link old_context (code : code) = Map.iteri global_arrays ~f:(fun ~key ~data:ptr -> if Hash_set.mem code.info.used_tensors key then let node = Map.find_exn all_arrays key in - if node.zero_initialized then Cu.memset_d8 ptr Unsigned.UChar.zero ~length:node.size_in_bytes); + if node.zero_initialized then + Cu.memset_d8 ptr Unsigned.UChar.zero ~length:node.size_in_bytes); [%log "launching the kernel"]; - (* if Utils.settings.debug_log_from_routines then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE 4096; *) + (* if Utils.settings.debug_log_from_routines then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE + 4096; *) Cu.launch_kernel func ~grid_dim_x:1 ~block_dim_x:1 ~shared_mem_bytes:0 Cu.no_stream @@ log_arg @ idx_args @ args; [%log "kernel launched"]; diff --git a/arrayjit/lib/cuda_backend.missing.ml b/arrayjit/lib/cuda_backend.missing.ml index 9c55e96a..65ce7dd0 100644 --- a/arrayjit/lib/cuda_backend.missing.ml +++ b/arrayjit/lib/cuda_backend.missing.ml @@ -16,7 +16,9 @@ let compile_batch ~names:_ (bindings : Indexing.unit_bindings) optimized : code_ let link (Unimplemented_ctx : context) (code : code) = let lowered_bindings = List.map ~f:(fun s -> (s, ref 0)) @@ Indexing.bound_symbols code in - let task = Tnode.{ description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) } in + let task = + Tnode.{ description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) } + in ((Unimplemented_ctx : context), lowered_bindings, task) let link_batch (Unimplemented_ctx : context) (code_batch : code_batch) = @@ -26,7 +28,9 @@ let link_batch (Unimplemented_ctx : context) (code_batch : code_batch) = in let task = Array.map code_batch ~f:(fun _ -> - Some Tnode.{ description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) }) + Some + Tnode. 
+ { description = "CUDA missing: install cudajit"; work = (fun _debug_runtime () -> ()) }) in ((Unimplemented_ctx : context), lowered_bindings, task) diff --git a/arrayjit/lib/cuda_backend.mli b/arrayjit/lib/cuda_backend.mli index 634bfb0f..bb6d8ae5 100644 --- a/arrayjit/lib/cuda_backend.mli +++ b/arrayjit/lib/cuda_backend.mli @@ -9,10 +9,16 @@ val sexp_of_context : context -> Sexplib.Sexp.t val compile : ?name:string -> Indexing.unit_bindings -> Low_level.optimized -> code val compile_batch : - names:string option array -> Indexing.unit_bindings -> Low_level.optimized option array -> code_batch + names:string option array -> + Indexing.unit_bindings -> + Low_level.optimized option array -> + code_batch val link : context -> code -> context * Indexing.lowered_bindings * Tnode.task -val link_batch : context -> code_batch -> context * Indexing.lowered_bindings * Tnode.task option array + +val link_batch : + context -> code_batch -> context * Indexing.lowered_bindings * Tnode.task option array + val unsafe_cleanup : ?unsafe_shutdown:bool -> unit -> unit val from_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> context -> Tnode.t -> bool @@ -35,8 +41,12 @@ type buffer_ptr [@@deriving sexp_of] val to_buffer : ?rt:(module Minidebug_runtime.Debug_runtime) -> Tnode.t -> dst:buffer_ptr -> src:context -> unit -val host_to_buffer : ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit -val buffer_to_host : ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit +val host_to_buffer : + ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> dst:buffer_ptr -> unit + +val buffer_to_host : + ?rt:(module Minidebug_runtime.Debug_runtime) -> Ndarray.t -> src:buffer_ptr -> unit + val get_buffer : Tnode.t -> context -> buffer_ptr option type physical_device diff --git a/arrayjit/lib/gcc_backend.gccjit.ml b/arrayjit/lib/gcc_backend.gccjit.ml index 752e98c9..eac92682 100644 --- a/arrayjit/lib/gcc_backend.gccjit.ml +++ b/arrayjit/lib/gcc_backend.gccjit.ml @@ -16,7 +16,8 @@ type config = [ `Physical_devices_only | `For_parallel_copying | `Most_parallel_ type mem_properties = | Local_only (** The array is only needed for a local computation, is allocated on the stack. *) - | From_context (** The array has a copy allocated per-cpu-device, may or may not exist on the host. *) + | From_context + (** The array has a copy allocated per-cpu-device, may or may not exist on the host. *) | Constant_from_host (** The array is read directly from the host. *) [@@deriving sexp, equal, compare, variants] @@ -37,7 +38,11 @@ let buffer_ptr ctx_array = Ndarray.get_voidptr ctx_array let buffer_ptr ctx_array = ctx_array -type context = { label : string; arrays : ctx_arrays; result : (Gccjit.result option[@sexp.opaque]) } +type context = { + label : string; + arrays : ctx_arrays; + result : (Gccjit.result option[@sexp.opaque]); +} [@@deriving sexp_of] let ctx_arrays context = context.arrays @@ -79,15 +84,15 @@ type tn_info = { mutable ptr : (Gccjit.rvalue[@sexp.opaque]) Lazy.t; (** Pointer to the first value of the associated array. 
- if [mem = Constant_from_host], the pointer to the first element of the hosted [Ndarray], - - if [mem = From_context], either a pointer to [Ndarray] from [context.arrays] when [~shared:false], - or the function parameter when [~shared:true], + - if [mem = From_context], either a pointer to [Ndarray] from [context.arrays] when + [~shared:false], or the function parameter when [~shared:true], - if [mem = Local_only], the address of the on-the-stack array. *) mem : mem_properties; dims : int array; size_in_bytes : int; num_typ : (Gccjit.type_[@sexp.opaque]); - (** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), [double] - (precision [Double]). *) + (** The type of the stored values: [short] (precision [Half]), [float] (precision [Single]), + [double] (precision [Double]). *) prec : Ops.prec; zero_initialized : bool; } @@ -104,7 +109,11 @@ type info_nodes = { } [@@deriving sexp_of] -type param_source = Log_file_name | Merge_buffer | Param_ptr of Tn.t | Static_idx of Indexing.static_symbol +type param_source = + | Log_file_name + | Merge_buffer + | Param_ptr of Tn.t + | Static_idx of Indexing.static_symbol [@@deriving sexp_of] type procedure = { @@ -160,7 +169,8 @@ let zero_out ctx block node = let get_c_ptr ctx num_typ ba = Gccjit.(RValue.ptr ctx (Type.pointer num_typ) @@ Ctypes.bigarray_start Ctypes_static.Genarray ba) -let prepare_node ~debug_log_zero_out ~get_ident ctx nodes traced_store ctx_nodes initializations (tn : Tn.t) = +let prepare_node ~debug_log_zero_out ~get_ident ctx nodes traced_store ctx_nodes initializations + (tn : Tn.t) = let open Gccjit in Hashtbl.update nodes tn ~f:(function | Some old -> old @@ -244,7 +254,9 @@ let prec_to_kind prec = | Single_prec _ -> Type.Float | Double_prec _ -> Type.Double -let is_builtin_op = function Ops.Add | Sub | Mul | Div -> true | ToPowOf | Relu_gate | Arg2 | Arg1 -> false +let is_builtin_op = function + | Ops.Add | Sub | Mul | Div -> true + | ToPowOf | Relu_gate | Arg2 | Arg1 -> false let builtin_op = function | Ops.Add -> Gccjit.Plus @@ -271,9 +283,14 @@ let debug_log_zero_out ctx log_functions get_ident block node = Block.eval block @@ RValue.call ctx pf @@ lf :: RValue.string_literal ctx - [%string {|memset_zero(%{node_debug_name get_ident node}) where before first element = %g + [%string + {|memset_zero(%{node_debug_name get_ident node}) where before first element = %g |}] - :: [ to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force node.ptr) @@ RValue.zero ctx c_index ]; + :: [ + to_d @@ RValue.lvalue + @@ LValue.access_array (Lazy.force node.ptr) + @@ RValue.zero ctx c_index; + ]; Block.eval block @@ RValue.call ctx ff [ lf ] | _ -> () @@ -288,8 +305,8 @@ let debug_log_index ctx log_functions = Block.eval block @@ RValue.call ctx ff [ lf ] | _ -> fun _block _i _index -> () -let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; _ } func initial_block - (body : Low_level.t) = +let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; _ } func + initial_block (body : Low_level.t) = let open Gccjit in let c_int = Type.get ctx Type.Int in let c_index = c_int in @@ -300,8 +317,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; | Indexing.Fixed_idx i -> RValue.int ctx c_index i | Iterator s -> Map.find_exn env s) with e -> - Stdlib.Format.eprintf "exec_as_gccjit: missing index from@ %a@ among environment keys:@ %a\n%!" 
- Sexp.pp_hum + Stdlib.Format.eprintf + "exec_as_gccjit: missing index from@ %a@ among environment keys:@ %a\n%!" Sexp.pp_hum ([%sexp_of: Indexing.axis_index array] indices) Sexp.pp_hum ([%sexp_of: Indexing.symbol list] @@ Map.keys env); @@ -311,7 +328,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; let c_double = Type.get ctx Type.Double in let cast_bool num_typ v = RValue.cast ctx (RValue.cast ctx v c_int) num_typ in let to_d v = RValue.cast ctx v c_double in - (* Source of unique identifiers for local scope ids, which can be non-unique globally due to inlining. *) + (* Source of unique identifiers for local scope ids, which can be non-unique globally due to + inlining. *) let uid = ref 0 in let get_uid () = let id = @@ -336,7 +354,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; let base = RValue.cast ctx v1 c_float in let expon = RValue.cast ctx v2 c_float in RValue.cast ctx (RValue.call ctx (Function.builtin ctx "powf") [ base; expon ]) num_typ - | ToPowOf, Byte_prec _ -> raise @@ Utils.User_error "gccjit_backend: Byte_prec does not support ToPowOf" + | ToPowOf, Byte_prec _ -> + raise @@ Utils.User_error "gccjit_backend: Byte_prec does not support ToPowOf" | Relu_gate, _ -> let cmp = cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) v1 in RValue.binary_op ctx Mult num_typ cmp @@ v2 @@ -356,7 +375,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; let loop = debug_float ~env prec in match value with | Low_level.Local_scope { id; _ } -> - (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug logs. *) + (* Not printing the inlined definition: (1) code complexity; (2) don't overload the debug + logs. *) loop @@ Get_local id | Get_local id -> let lvalue = Map.find_exn !debug_locals id in @@ -378,7 +398,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; let offset = jit_array_offset ctx ~idcs ~dims:(Lazy.force tn.dims) in let v = to_d @@ RValue.lvalue @@ LValue.access_array ptr offset in ("merge " ^ get_ident tn ^ "[%d]{=%g}", [ offset; v ]) - | Get_global (C_function _, Some _) -> failwith "gccjit_backend: FFI with parameters NOT IMPLEMENTED YET" + | Get_global (C_function _, Some _) -> + failwith "gccjit_backend: FFI with parameters NOT IMPLEMENTED YET" | Get (tn, idcs) -> let node = get_node tn in let idcs = lookup env idcs in @@ -406,7 +427,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; let v_format, v_fillers = debug_float ~env node.prec v_code in let offset = jit_array_offset ctx ~idcs ~dims:node.dims in let debug_line = "# " ^ String.substr_replace_all debug ~pattern:"\n" ~with_:"$" ^ "\n" in - Block.eval !current_block @@ RValue.call ctx pf @@ [ lf; RValue.string_literal ctx debug_line ]; + Block.eval !current_block @@ RValue.call ctx pf + @@ [ lf; RValue.string_literal ctx debug_line ]; Block.eval !current_block @@ RValue.call ctx pf @@ lf :: RValue.string_literal ctx @@ -471,8 +493,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; (* Scope ids can be non-unique due to inlining. *) let v_name = Int.("v" ^ to_string i ^ "_" ^ get_uid ()) in let lvalue = Function.local func typ v_name in - (* Arrays are initialized to 0 by default. However, there is typically an explicit initialization for - virtual nodes. *) + (* Arrays are initialized to 0 by default. 
However, there is typically an explicit + initialization for virtual nodes. *) Block.assign !current_block lvalue @@ RValue.zero ctx typ; let old_locals = !locals in locals := Map.update !locals id ~f:(fun _ -> lvalue); @@ -509,7 +531,8 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; let local_typ = gcc_typ_of_prec tn.prec in let num_typ = Type.get ctx local_typ in if not @@ Ops.equal_prec prec tn.prec then RValue.cast ctx rvalue num_typ else rvalue - | Get_global (C_function _, Some _) -> failwith "gccjit_backend: FFI with parameters NOT IMPLEMENTED YET" + | Get_global (C_function _, Some _) -> + failwith "gccjit_backend: FFI with parameters NOT IMPLEMENTED YET" | Get (tn, idcs) -> Hash_set.add visited tn; let node = get_node tn in @@ -535,7 +558,9 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; | Unop (Identity, c) -> loop c | Unop (Relu, c) -> (* FIXME: don't recompute c *) - let cmp = cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) @@ loop c in + let cmp = + cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) @@ loop c + in RValue.binary_op ctx Mult num_typ cmp @@ loop c | Constant v -> RValue.double ctx num_typ v and loop_for_loop ~toplevel ~env key ~from_ ~to_ body = @@ -599,7 +624,8 @@ let%track_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident let fkind = Function.Exported in let c_str = Type.(get ctx Const_char_ptr) in let log_file_name = - if Utils.settings.debug_log_from_routines then Some (Param.create ctx c_str "log_file_name", Log_file_name) + if Utils.settings.debug_log_from_routines then + Some (Param.create ctx c_str "log_file_name", Log_file_name) else None in let symbols = Indexing.bound_symbols bindings in @@ -619,7 +645,9 @@ let%track_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident ref (Option.to_list log_file_name @ merge_param @ static_indices) in let ctx_nodes : ctx_nodes = - match opt_ctx_arrays with None -> Param_ptrs params | Some ctx_arrays -> Ctx_arrays (ref ctx_arrays) + match opt_ctx_arrays with + | None -> Param_ptrs params + | Some ctx_arrays -> Ctx_arrays (ref ctx_arrays) in let initializations = ref [] in let nodes = Hashtbl.create (module Tn) in @@ -654,19 +682,25 @@ let%track_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident log_functions_ref := Some (log_file, fprintf, fflush)); let log_functions = Lazy.force log_functions in let debug_log_index = debug_log_index ctx log_functions in - Map.iteri env ~f:(fun ~key:sym ~data:idx -> debug_log_index init_block (Indexing.symbol_ident sym) idx); + Map.iteri env ~f:(fun ~key:sym ~data:idx -> + debug_log_index init_block (Indexing.symbol_ident sym) idx); (* Do initializations in the order they were scheduled. *) List.iter (List.rev !initializations) ~f:(fun init -> init init_block func); let main_block = Block.create ~name func in (* let merge_node = Option.map merge_node ~f:(fun _tn -> Function.param) in *) - let ctx_info : info_nodes = { ctx; traced_store; init_block; func; nodes; get_ident; merge_node = None } in + let ctx_info : info_nodes = + { ctx; traced_store; init_block; func; nodes; get_ident; merge_node = None } + in let after_proc = compile_main ~name ~log_functions ~env ctx_info func main_block proc in (match log_functions with | Some (lf, _, _) -> (* FIXME: should be Imported? 
*) let file_ptr = Type.(get ctx File_ptr) in let fclose = - Function.create ctx Imported Type.(get ctx Type.Void_ptr) "fclose" [ Param.create ctx file_ptr "f" ] + Function.create ctx Imported + Type.(get ctx Type.Void_ptr) + "fclose" + [ Param.create ctx file_ptr "f" ] in Block.eval after_proc @@ RValue.call ctx fclose [ lf ] | None -> ()); @@ -688,9 +722,11 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_ let ctx = Context.create_child @@ Option.value_exn ~here:[%here] !root_ctx in Context.set_option ctx Context.Optimization_level (optimization_level ()); (* if Utils.settings.with_debug && Utils.settings.output_debug_files_in_run_directory then ( - Context.set_option ctx Context.Keep_intermediates true; Context.set_option ctx Context.Dump_everything - true); *) - let info, opt_ctx_arrays, params = compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered in + Context.set_option ctx Context.Keep_intermediates true; Context.set_option ctx + Context.Dump_everything true); *) + let info, opt_ctx_arrays, params = + compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered + in (if Utils.settings.output_debug_files_in_run_directory then let f_name = name ^ "-gccjit-debug.c" in Context.dump_to_file ctx ~update_locs:true f_name); @@ -717,8 +753,8 @@ let%track_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bind let ctx = Context.create_child @@ Option.value_exn ~here:[%here] !root_ctx in Context.set_option ctx Context.Optimization_level (optimization_level ()); (* if Utils.settings.with_debug && Utils.settings.output_debug_files_in_run_directory then ( - Context.set_option ctx Context.Keep_intermediates true; Context.set_option ctx Context.Dump_everything - true); *) + Context.set_option ctx Context.Keep_intermediates true; Context.set_option ctx + Context.Dump_everything true); *) let opt_ctx_arrays, funcs = Array.fold_mapi lowereds ~init:opt_ctx_arrays ~f:(fun i opt_ctx_arrays lowered -> match (names.(i), lowered) with @@ -749,7 +785,8 @@ let%track_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bind name; opt_ctx_arrays; params = List.map ~f:snd params; - expected_merge_node = Option.(join @@ map lowereds.(i) ~f:(fun optim -> optim.merge_node)); + expected_merge_node = + Option.(join @@ map lowereds.(i) ~f:(fun optim -> optim.merge_node)); })) ) let alloc_buffer ?old_buffer ~size_in_bytes () = @@ -794,19 +831,22 @@ let%track_sexp link_compiled ~merge_buffer (old_context : context) (code : proce | Empty, [] -> Indexing.Result (Gccjit.Result.code code.result name cs) | Bind _, [] -> invalid_arg "Gccjit_backend.link: too few static index params" | Bind (_, bs), Static_idx _ :: ps -> Param_idx (ref 0, link bs ps Ctypes.(int @-> cs)) - | Empty, Static_idx _ :: _ -> invalid_arg "Gccjit_backend.link: too many static index params" - | bs, Log_file_name :: ps -> Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs)) + | Empty, Static_idx _ :: _ -> + invalid_arg "Gccjit_backend.link: too many static index params" + | bs, Log_file_name :: ps -> + Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs)) | bs, Param_ptr tn :: ps -> let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in - (* let f ba = Ctypes.bigarray_start Ctypes_static.Genarray ba in let c_ptr = Ndarray.(map { f } - nd) in *) + (* let f ba = Ctypes.bigarray_start Ctypes_static.Genarray ba in let c_ptr = + Ndarray.(map { f } nd) in *) let c_ptr = Ndarray.get_voidptr nd in Param_2 (ref (Some c_ptr), 
link bs ps Ctypes.(ptr void @-> cs)) | bs, Merge_buffer :: ps -> Param_2f (Ndarray.get_voidptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs)) in - (* Folding by [link] above reverses the input order. Important: [code.bindings] are traversed in the - wrong order but that's OK because [link] only uses them to check the number of indices. *) + (* Folding by [link] above reverses the input order. Important: [code.bindings] are traversed + in the wrong order but that's OK because [link] only uses them to check the number of + indices. *) link code.bindings (List.rev code.params) Ctypes.(void @-> returning void)] in let%diagn_rt_sexp work () : unit = diff --git a/arrayjit/lib/indexing.ml b/arrayjit/lib/indexing.ml index ec212a74..2bbf5020 100644 --- a/arrayjit/lib/indexing.ml +++ b/arrayjit/lib/indexing.ml @@ -85,12 +85,13 @@ let get_static_symbol ?static_range bindings = let s = { static_symbol = get_symbol (); static_range } in (s, Bind (s, bindings)) -(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2. Outputs - ["-"] for empty dimensions. *) +(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2. + Outputs ["-"] for empty dimensions. *) let dims_to_string ?(with_axis_numbers = false) dims = if Array.is_empty dims then "-" else if with_axis_numbers then - String.concat_array ~sep:" x " @@ Array.mapi dims ~f:(fun d s -> Int.to_string d ^ ":" ^ Int.to_string s) + String.concat_array ~sep:" x " + @@ Array.mapi dims ~f:(fun d s -> Int.to_string d ^ ":" ^ Int.to_string s) else String.concat_array ~sep:"x" @@ Array.map dims ~f:Int.to_string type axis_index = @@ -120,11 +121,11 @@ type projections = { rhs_dims : int array array; (** The dimensions of the RHS arrays, needed for deriving projections from other projections. *) product_iterators : symbol array; - (** The product space iterators (concatentation of the relevant batch, output, input axes) for iterating - over the [product_space] axes, where same axes are at same array indices. *) + (** The product space iterators (concatenation of the relevant batch, output, input axes) for + iterating over the [product_space] axes, where the same axes are at the same array indices. *) project_lhs : axis_index array; - (** A projection that takes an [product_space]-bound index and produces an index into the result of an - operation. *) + (** A projection that takes a [product_space]-bound index and produces an index into the + result of an operation. *) project_rhs : axis_index array array; (** [project_rhs.(i)] Produces an index into the [i+1]th argument of an operation. 
*) debug_info : (projections_debug[@sexp.ignore] [@compare.ignore] [@equal.ignore]); @@ -156,12 +157,15 @@ let identity_projections ~debug_info ~lhs_dims = product_iterators; project_lhs; project_rhs = [| project_lhs |]; - debug_info = { debug_info with trace = ("indentity_projections", unique_debug_id ()) :: debug_info.trace }; + debug_info = + { debug_info with trace = ("identity_projections", unique_debug_id ()) :: debug_info.trace }; } let derive_index ~product_syms ~(projection : axis_index array) = let sym_to_i = - Array.mapi product_syms ~f:(fun i s -> (s, i)) |> Array.to_list |> Map.of_alist_exn (module Symbol) + Array.mapi product_syms ~f:(fun i s -> (s, i)) + |> Array.to_list + |> Map.of_alist_exn (module Symbol) in let positions = Array.map projection ~f:(function diff --git a/arrayjit/lib/low_level.ml b/arrayjit/lib/low_level.ml index af617bcd..3ec2c917 100644 --- a/arrayjit/lib/low_level.ml +++ b/arrayjit/lib/low_level.ml @@ -21,7 +21,8 @@ module Scope_id = struct end) end -type scope_id = Scope_id.t = { tn : Tn.t; scope_id : int } [@@deriving sexp_of, equal, hash, compare] +type scope_id = Scope_id.t = { tn : Tn.t; scope_id : int } +[@@deriving sexp_of, equal, hash, compare] (** *** Low-level representation. *) @@ -54,9 +55,13 @@ and float_t = | Embed_index of Indexing.axis_index [@@deriving sexp_of, equal, compare] -let binop ~op ~rhs1 ~rhs2 = match op with Ops.Arg1 -> rhs1 | Arg2 -> rhs2 | _ -> Binop (op, rhs1, rhs2) +let binop ~op ~rhs1 ~rhs2 = + match op with Ops.Arg1 -> rhs1 | Arg2 -> rhs2 | _ -> Binop (op, rhs1, rhs2) + let unop ~op ~rhs = match op with Ops.Identity -> rhs | _ -> Unop (op, rhs) + -let rec flat_lines ts = List.concat_map ts ~f:(function Seq (t1, t2) -> flat_lines [ t1; t2 ] | t -> [ t ]) +let rec flat_lines ts = + List.concat_map ts ~f:(function Seq (t1, t2) -> flat_lines [ t1; t2 ] | t -> [ t ]) let rec unflat_lines = function | [] -> Noop @@ -68,7 +73,8 @@ let comment_to_name = let nonliteral = Str.regexp {|[^a-zA-Z0-9_]|} in Str.global_replace nonliteral "_" -let extract_block_name llc = match flat_lines llc with Comment s :: _ -> comment_to_name s | _ -> "" +let extract_block_name llc = + match flat_lines llc with Comment s :: _ -> comment_to_name s | _ -> "" (** *** Optimization *** *) @@ -80,7 +86,9 @@ type virtualize_settings = { } let virtualize_settings = - let max_visits = Int.of_string @@ Utils.get_global_arg ~arg_name:"virtualize_max_visits" ~default:"1" in + let max_visits = + Int.of_string @@ Utils.get_global_arg ~arg_name:"virtualize_max_visits" ~default:"1" + in let max_tracing_dim = Int.of_string @@ Utils.get_global_arg ~arg_name:"virtualize_max_tracing_dim" ~default:"5" in @@ -94,28 +102,32 @@ let virtualize_settings = type visits = | Visits of int - | Recurrent (** A [Recurrent] visit is when there is an access prior to any assignment in an update. *) + | Recurrent + (** A [Recurrent] visit is when there is an access prior to any assignment in an update. *) [@@deriving sexp, equal, variants] type traced_array = { tn : Tn.t; mutable computations : (Indexing.axis_index array option * t) list; - (** The computations (of the tensor node) are retrieved for optimization just as they are populated, so - that the inlined code corresponds precisely to the changes to the arrays that would happen up till - that point. Within the code blocks paired with an index tuple, all assignments and accesses must - happen via the index tuple; if this is not the case for some assignment, the node cannot be virtual. 
- Currently, we only allow for-loop symbols in assignment indices of virtual nodes. *) + (** The computations (of the tensor node) are retrieved for optimization just as they are + populated, so that the inlined code corresponds precisely to the changes to the arrays + that would happen up till that point. Within the code blocks paired with an index tuple, + all assignments and accesses must happen via the index tuple; if this is not the case for + some assignment, the node cannot be virtual. Currently, we only allow for-loop symbols in + assignment indices of virtual nodes. *) assignments : int array Hash_set.t; accesses : (int array, visits) Hashtbl.t; - (** For dynamic indexes, we take a value of 0. This leads to an overestimate of visits, which is safe. *) + (** For dynamic indexes, we take a value of 0. This leads to an overestimate of visits, which + is safe. *) mutable zero_initialized : bool; mutable zeroed_out : bool; - mutable read_before_write : bool; (** The node is read before it is written (i.e. it is recurrent). *) + mutable read_before_write : bool; + (** The node is read before it is written (i.e. it is recurrent). *) mutable read_only : bool; mutable is_scalar_constexpr : bool; - (** True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned before - accessed, is assigned at most once, and from an expression involving only constants or tensor nodes - that were at the time is_scalar_constexpr. *) + (** True only if the tensor node has all axes of dimension 1, is either zeroed-out or assigned + before accessed, is assigned at most once, and from an expression involving only constants + or tensor nodes that were at the time is_scalar_constexpr. *) } [@@deriving sexp_of] @@ -151,7 +163,11 @@ let partition_tf_with_comment cs ~f = let visit ~is_assigned old = if not is_assigned then Recurrent - else match old with None -> Visits 1 | Some (Visits i) -> Visits (i + 1) | Some Recurrent -> Recurrent + else + match old with + | None -> Visits 1 + | Some (Visits i) -> Visits (i + 1) + | Some Recurrent -> Recurrent let is_constexpr_comp traced_store llv = let rec loop llv = @@ -201,8 +217,10 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc = | Set { tn; idcs; llv; debug = _ } -> loop_float env llv; let traced : traced_array = get_node traced_store tn in - if Hash_set.is_empty traced.assignments && Hashtbl.is_empty traced.accesses && is_scalar_dims tn then - traced.is_scalar_constexpr <- is_constexpr_comp traced_store llv + if + Hash_set.is_empty traced.assignments + && Hashtbl.is_empty traced.accesses && is_scalar_dims tn + then traced.is_scalar_constexpr <- is_constexpr_comp traced_store llv (* Note: this prevents detection if the same constant is assigned inside a loop. *) else if not @@ Hash_set.is_empty traced.assignments then traced.is_scalar_constexpr <- false; Hash_set.add traced.assignments (lookup env idcs); @@ -245,9 +263,9 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc = then Tn.update_memory_mode tn Virtual 40; if Option.is_none tn.memory_mode && Hashtbl.exists traced.accesses ~f:is_too_many then Tn.update_memory_mode tn Never_virtual 1 - (* The tensor node is read-only/recurrent for this computation, but maybe computed by another one. - However, if the memory mode is unspecified, we assume this will be the first computation involving - the tensor node. *); + (* The tensor node is read-only/recurrent for this computation, but may be computed by + another one. 
However, if the memory mode is unspecified, we assume this will be the first + computation involving the tensor node. *); if (not traced.zeroed_out) && Hash_set.is_empty traced.assignments then ( traced.read_only <- true; if Tn.mode_is_unspecified tn then Tn.update_memory_mode tn (Hosted Constant) 37 @@ -260,7 +278,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc = let%diagn_sexp check_and_store_virtual traced static_indices top_llc = let exception Non_virtual of int in let static_indices = - Set.of_list (module Indexing.Symbol) @@ List.map ~f:(fun s -> s.Indexing.static_symbol) static_indices + Set.of_list (module Indexing.Symbol) + @@ List.map ~f:(fun s -> s.Indexing.static_symbol) static_indices in let at_idcs = ref None in let has_setter = ref false in @@ -268,7 +287,8 @@ let%diagn_sexp check_and_store_virtual traced static_indices top_llc = let check_idcs indices = (match !at_idcs with | None -> at_idcs := Some indices - | Some at -> if not @@ [%equal: Indexing.axis_index array] at indices then raise @@ Non_virtual 4); + | Some at -> + if not @@ [%equal: Indexing.axis_index array] at indices then raise @@ Non_virtual 4); (* TODO(#133): For non-recursive accesses, non-linearity is not supported yet. *) let syms = Set.of_array (module Indexing.Symbol) @@ -276,7 +296,8 @@ let%diagn_sexp check_and_store_virtual traced static_indices top_llc = ~f: Indexing.( function - | Fixed_idx _ -> None | Iterator s -> Option.some_if (not @@ Set.mem static_indices s) s) + | Fixed_idx _ -> None + | Iterator s -> Option.some_if (not @@ Set.mem static_indices s) s) in let num_syms = Array.count indices ~f:(function Iterator s -> not @@ Set.mem static_indices s | _ -> false) @@ -326,7 +347,10 @@ let%diagn_sexp check_and_store_virtual traced static_indices top_llc = | Iterator s when not (Set.mem static_indices s) -> if not @@ Set.mem env_dom s then ( if Utils.settings.with_debug_level > 1 then - [%log "Inlining candidate has an escaping variable", (s : Indexing.symbol), (top_llc : t)]; + [%log + "Inlining candidate has an escaping variable", + (s : Indexing.symbol), + (top_llc : t)]; raise @@ Non_virtual 9) | _ -> ()) | Local_scope { body; _ } -> loop_proc ~env_dom body @@ -336,7 +360,8 @@ let%diagn_sexp check_and_store_virtual traced static_indices top_llc = | Embed_index (Iterator s) -> if not @@ Set.mem env_dom s then ( if Utils.settings.with_debug_level > 1 && not (Set.mem static_indices s) then - [%log "Inlining candidate has an escaping variable", (s : Indexing.symbol), (top_llc : t)]; + [%log + "Inlining candidate has an escaping variable", (s : Indexing.symbol), (top_llc : t)]; raise @@ Non_virtual 10) | Binop (_, llv1, llv2) -> loop_float ~env_dom llv1; @@ -353,7 +378,8 @@ let%diagn_sexp check_and_store_virtual traced static_indices top_llc = let inline_computation ~id traced static_indices call_args = let exception Non_virtual of int in let static_indices = - Set.of_list (module Indexing.Symbol) @@ List.map ~f:(fun s -> s.Indexing.static_symbol) static_indices + Set.of_list (module Indexing.Symbol) + @@ List.map ~f:(fun s -> s.Indexing.static_symbol) static_indices in let make_subst i lhs_ind = let rhs_ind = call_args.(i) in @@ -372,7 +398,10 @@ let inline_computation ~id traced static_indices call_args = @@ Array.to_list @@ Array.filter_mapi def_args ~f:make_subst in - let subst env = function Indexing.Iterator s when Map.mem env s -> Map.find_exn env s | idx -> idx in + let subst env = function + | Indexing.Iterator s when Map.mem env s -> Map.find_exn env s + | 
idx -> idx + in let rec loop env llc : t option = match llc with | Noop -> None @@ -447,7 +476,8 @@ let virtual_llc traced_store reverse_node_map static_indices (llc : t) : t = | Some tn when not @@ Set.mem process_for tn -> let node : traced_array = get_node traced_store tn in let result = loop_proc ~process_for:(Set.add process_for tn) llc in - if not @@ Tn.known_non_virtual node.tn then check_and_store_virtual node static_indices result; + if not @@ Tn.known_non_virtual node.tn then + check_and_store_virtual node static_indices result; result | _ -> For_loop { for_config with body = loop body }) | Zero_out tn -> @@ -469,7 +499,8 @@ let virtual_llc traced_store reverse_node_map static_indices (llc : t) : t = match llv with | Constant _ -> llv | Get (tn, _) when Set.mem process_for tn -> - (* [Get_local] will replace this [Get] during [inline_computation] if [tn] remains virtual. *) + (* [Get_local] will replace this [Get] during [inline_computation] if [tn] remains + virtual. *) llv | Get (tn, indices) -> let traced = get_node traced_store tn in @@ -480,11 +511,13 @@ let virtual_llc traced_store reverse_node_map static_indices (llc : t) : t = @@ Option.map (inline_computation ~id traced static_indices indices) ~f:(fun body -> Local_scope { id; body; orig_indices = indices }) | Local_scope opts -> - Local_scope { opts with body = loop_proc ~process_for:(Set.add process_for opts.id.tn) opts.body } + Local_scope + { opts with body = loop_proc ~process_for:(Set.add process_for opts.id.tn) opts.body } | Get_local _ -> llv | Get_global _ -> llv | Embed_index _ -> llv - | Binop (op, llv1, llv2) -> Binop (op, loop_float ~process_for llv1, loop_float ~process_for llv2) + | Binop (op, llv1, llv2) -> + Binop (op, loop_float ~process_for llv1, loop_float ~process_for llv2) | Unop (op, llv) -> Unop (op, loop_float ~process_for llv) in loop_proc ~process_for:(Set.empty (module Tnode)) llc @@ -524,7 +557,8 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t = Tn.update_memory_mode tn Virtual 152; None) else ( - assert (Array.for_all idcs ~f:(function Indexing.Iterator s -> Set.mem env_dom s | _ -> true)); + assert ( + Array.for_all idcs ~f:(function Indexing.Iterator s -> Set.mem env_dom s | _ -> true)); Some (Set { tn; idcs; llv = loop_float ~balanced ~env_dom llv; debug = "" })) | Set_local (id, llv) -> assert (not @@ Tn.known_non_virtual id.tn); @@ -539,11 +573,14 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t = | Get (a, indices) -> (* DEBUG: *) Tn.update_memory_mode a Never_virtual 17; - assert (Array.for_all indices ~f:(function Indexing.Iterator s -> Set.mem env_dom s | _ -> true)); + assert ( + Array.for_all indices ~f:(function Indexing.Iterator s -> Set.mem env_dom s | _ -> true)); llv | Local_scope { id; body; orig_indices } -> assert ( - Array.for_all orig_indices ~f:(function Indexing.Iterator s -> Set.mem env_dom s | _ -> true)); + Array.for_all orig_indices ~f:(function + | Indexing.Iterator s -> Set.mem env_dom s + | _ -> true)); if Tn.known_non_virtual id.tn then Get (id.tn, orig_indices) else let body = Option.value_exn ~here:[%here] @@ loop_proc ~balanced ~env_dom body in @@ -562,7 +599,8 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t = | Unop (op, llv) -> Unop (op, loop llv) in let static_indices = - Set.of_list (module Indexing.Symbol) @@ List.map ~f:(fun s -> s.Indexing.static_symbol) static_indices + Set.of_list (module Indexing.Symbol) + @@ List.map ~f:(fun s -> s.Indexing.static_symbol) static_indices 
in Option.value_exn ~here:[%here] @@ loop_proc ~balanced:false ~env_dom:static_indices llc @@ -643,12 +681,15 @@ let simplify_llc llc = | Binop (Arg1, llv1, _) -> loop_float llv1 | Binop (Arg2, _, llv2) -> loop_float llv2 | Binop (op, Constant c1, Constant c2) -> Constant (Ops.interpret_binop op c1 c2) - | Binop (Add, llv, Constant 0.) | Binop (Sub, llv, Constant 0.) | Binop (Add, Constant 0., llv) -> + | Binop (Add, llv, Constant 0.) | Binop (Sub, llv, Constant 0.) | Binop (Add, Constant 0., llv) + -> loop_float llv | Binop (Sub, Constant 0., llv) -> loop_float @@ Binop (Mul, Constant (-1.), llv) - | Binop (Mul, llv, Constant 1.) | Binop (Div, llv, Constant 1.) | Binop (Mul, Constant 1., llv) -> + | Binop (Mul, llv, Constant 1.) | Binop (Div, llv, Constant 1.) | Binop (Mul, Constant 1., llv) + -> loop_float llv - | Binop (Mul, _, Constant 0.) | Binop (Div, Constant 0., _) | Binop (Mul, Constant 0., _) -> Constant 0. + | Binop (Mul, _, Constant 0.) | Binop (Div, Constant 0., _) | Binop (Mul, Constant 0., _) -> + Constant 0. | Binop (Add, (Binop (Add, Constant c2, llv) | Binop (Add, llv, Constant c2)), Constant c1) | Binop (Add, Constant c1, (Binop (Add, Constant c2, llv) | Binop (Add, llv, Constant c2))) -> loop_float @@ Binop (Add, Constant (c1 +. c2), llv) @@ -658,8 +699,10 @@ let simplify_llc llc = loop_float @@ Binop (Sub, Constant (c1 -. c2), llv) | Binop (Add, llv1, Binop (Sub, llv2, llv3)) | Binop (Add, Binop (Sub, llv2, llv3), llv1) -> loop_float @@ Binop (Sub, Binop (Add, llv1, llv2), llv3) - | Binop (Sub, llv1, Binop (Sub, llv2, llv3)) -> loop_float @@ Binop (Sub, Binop (Add, llv1, llv3), llv2) - | Binop (Sub, Binop (Sub, llv1, llv2), llv3) -> loop_float @@ Binop (Sub, llv1, Binop (Add, llv2, llv3)) + | Binop (Sub, llv1, Binop (Sub, llv2, llv3)) -> + loop_float @@ Binop (Sub, Binop (Add, llv1, llv3), llv2) + | Binop (Sub, Binop (Sub, llv1, llv2), llv3) -> + loop_float @@ Binop (Sub, llv1, Binop (Add, llv2, llv3)) | Binop (Mul, (Binop (Mul, Constant c2, llv) | Binop (Mul, llv, Constant c2)), Constant c1) | Binop (Mul, Constant c1, (Binop (Mul, Constant c2, llv) | Binop (Mul, llv, Constant c2))) -> loop_float @@ Binop (Mul, Constant (c1 *. c2), llv) @@ -670,8 +713,10 @@ let simplify_llc llc = loop_float @@ Binop (Div, Constant (c1 /. 
c2), llv) | Binop (Mul, llv1, Binop (Div, llv2, llv3)) | Binop (Mul, Binop (Div, llv2, llv3), llv1) -> loop_float @@ Binop (Div, Binop (Mul, llv1, llv2), llv3) - | Binop (Div, llv1, Binop (Div, llv2, llv3)) -> loop_float @@ Binop (Div, Binop (Mul, llv1, llv3), llv2) - | Binop (Div, Binop (Div, llv1, llv2), llv3) -> loop_float @@ Binop (Div, llv1, Binop (Mul, llv2, llv3)) + | Binop (Div, llv1, Binop (Div, llv2, llv3)) -> + loop_float @@ Binop (Div, Binop (Mul, llv1, llv3), llv2) + | Binop (Div, Binop (Div, llv1, llv2), llv3) -> + loop_float @@ Binop (Div, llv1, Binop (Mul, llv2, llv3)) | Binop (ToPowOf, llv1, llv2) -> ( let v1 : float_t = loop_float llv1 in let v2 : float_t = loop_float llv2 in @@ -679,7 +724,8 @@ let simplify_llc llc = if not !optimize_integer_pow then result else match v2 with - | Constant c when Float.is_integer c -> loop_float @@ unroll_pow ~base:v1 ~exp:(Float.to_int c) + | Constant c when Float.is_integer c -> + loop_float @@ unroll_pow ~base:v1 ~exp:(Float.to_int c) | _ -> result) | Binop (op, llv1, llv2) -> let v1 = loop_float llv1 in @@ -696,7 +742,9 @@ let simplify_llc llc = loop_proc llc type traced_store = (Tn.t, traced_array) Base.Hashtbl.t [@@deriving sexp_of] -type optimized = { traced_store : traced_store; llc : t; merge_node : Tn.t option } [@@deriving sexp_of] + +type optimized = { traced_store : traced_store; llc : t; merge_node : Tn.t option } +[@@deriving sexp_of] let%debug_sexp optimize_proc static_indices llc = let traced_store = Hashtbl.create (module Tnode) in @@ -704,7 +752,8 @@ let%debug_sexp optimize_proc static_indices llc = let reverse_node_map = Hashtbl.create (module Indexing.Symbol) in [%log "tracing"]; let merge_node_id = ref None in - visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits:virtualize_settings.max_visits llc; + visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits:virtualize_settings.max_visits + llc; [%log "optimizing"]; let llc = simplify_llc @@ -730,13 +779,16 @@ let pp_index ppf idx = | Indexing.Iterator sym -> pp_symbol ppf sym | Fixed_idx i -> Stdlib.Format.fprintf ppf "%d" i -let pp_indices ppf idcs = Stdlib.Format.pp_print_list ~pp_sep:pp_comma pp_index ppf @@ Array.to_list idcs +let pp_indices ppf idcs = + Stdlib.Format.pp_print_list ~pp_sep:pp_comma pp_index ppf @@ Array.to_list idcs let fprint_function_header ?name ?static_indices () ppf = let open Stdlib.Format in match (name, static_indices) with | Some name, Some static_indices -> - fprintf ppf "%s @[(%a)@]:@ " name (pp_print_list ~pp_sep:pp_comma pp_static_symbol) static_indices + fprintf ppf "%s @[(%a)@]:@ " name + (pp_print_list ~pp_sep:pp_comma pp_static_symbol) + static_indices | Some name, None -> fprintf ppf "%s:@ " name | _ -> () @@ -745,9 +797,13 @@ let get_ident_within_code ?no_dots llcs = let nograd_idents = Hashtbl.create (module String) in let grad_idents = Hashtbl.create (module String) in let visit tn = - let idents = if List.mem ~equal:String.equal tn.Tn.label "grad" then grad_idents else nograd_idents in + let idents = + if List.mem ~equal:String.equal tn.Tn.label "grad" then grad_idents else nograd_idents + in Option.iter (Tn.ident_label tn) - ~f:(Hashtbl.update idents ~f:(fun old -> Set.add (Option.value ~default:Utils.no_ints old) tn.id)) + ~f: + (Hashtbl.update idents ~f:(fun old -> + Set.add (Option.value ~default:Utils.no_ints old) tn.id)) in let rec loop (c : t) = match c with @@ -781,7 +837,9 @@ let get_ident_within_code ?no_dots llcs = let repeating_nograd_idents = Hashtbl.filter nograd_idents ~f:(fun ids -> 
List.length (Set.to_list ids) > 1) in - let repeating_grad_idents = Hashtbl.filter grad_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1) in + let repeating_grad_idents = + Hashtbl.filter grad_idents ~f:(fun ids -> List.length (Set.to_list ids) > 1) + in Tn.styled_ident ~repeating_nograd_idents ~repeating_grad_idents ident_style let fprint_hum ?name ?static_indices () ppf llc = @@ -801,8 +859,10 @@ let fprint_hum ?name ?static_indices () ppf llc = | Zero_out tn -> fprintf ppf "zero_out %a;" pp_ident tn | Set p -> p.debug <- - asprintf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs (pp_float p.tn.prec) p.llv; - fprintf ppf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs (pp_float p.tn.prec) p.llv + asprintf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs (pp_float p.tn.prec) + p.llv; + fprintf ppf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs (pp_float p.tn.prec) + p.llv | Comment message -> fprintf ppf "/* %s */" message | Staged_compilation _ -> fprintf ppf "STAGED_COMPILATION_CALLBACK()" | Set_local (id, llv) -> fprintf ppf "@[<2>%a :=@ %a;@]" pp_local id (pp_float id.tn.prec) llv @@ -850,7 +910,8 @@ let%debug_sexp optimize_proc ~unoptim_ll_source ~ll_source ~(name : string) let loop_over_dims dims ~body = let rec for_loop rev_idcs : _ -> t = function | [] -> body @@ Array.of_list_rev rev_idcs - | d :: product when not @@ Indexing.iterated d -> for_loop (Indexing.Fixed_idx 0 :: rev_idcs) product + | d :: product when not @@ Indexing.iterated d -> + for_loop (Indexing.Fixed_idx 0 :: rev_idcs) product | d :: product -> let index = Indexing.get_symbol () in For_loop diff --git a/arrayjit/lib/ndarray.ml b/arrayjit/lib/ndarray.ml index 12a2898e..b96ff38b 100644 --- a/arrayjit/lib/ndarray.ml +++ b/arrayjit/lib/ndarray.ml @@ -16,7 +16,9 @@ type half_nd = (float, Ops.float16_elt) bigarray type single_nd = (float, Ops.float32_elt) bigarray type double_nd = (float, Ops.float64_elt) bigarray -let sexp_of_address_of arr = Sexp.Atom ("@" ^ Nativeint.Hex.to_string @@ Ctypes_bigarray.unsafe_address arr) +let sexp_of_address_of arr = + Sexp.Atom ("@" ^ Nativeint.Hex.to_string @@ Ctypes_bigarray.unsafe_address arr) + let sexp_of_byte_nd (arr : byte_nd) = sexp_of_address_of arr let sexp_of_half_nd (arr : half_nd) = sexp_of_address_of arr let sexp_of_single_nd (arr : single_nd) = sexp_of_address_of arr @@ -25,7 +27,8 @@ let sexp_of_double_nd (arr : double_nd) = sexp_of_address_of arr type t = Byte_nd of byte_nd | Half_nd of half_nd | Single_nd of single_nd | Double_nd of double_nd [@@deriving sexp_of] -let as_array (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) (arr : (ocaml, elt_t) bigarray) = +let as_array (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) + (arr : (ocaml, elt_t) bigarray) = match prec with | Byte -> Byte_nd arr | Half -> Half_nd arr @@ -75,8 +78,8 @@ let init_bigarray_of_prec (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precisio let indices_to_offset ~dims ~idcs = Array.fold2_exn dims idcs ~init:0 ~f:(fun accu dim idx -> (accu * dim) + idx) -let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~dims (init_op : Ops.init_op) : - (ocaml, elt_t) bigarray = +let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~dims + (init_op : Ops.init_op) : (ocaml, elt_t) bigarray = Option.iter Utils.settings.fixed_state_for_init ~f:(fun seed -> Rand.Lib.init seed); let constant_fill_f f values strict = let len = Array.length values in @@ -86,28 +89,35 @@ let create_bigarray (type ocaml 
elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di raise @@ Utils.User_error [%string - "Ndarray.create_bigarray: Constant_fill: invalid data size %{len#Int}, expected %{size#Int}"]; + "Ndarray.create_bigarray: Constant_fill: invalid data size %{len#Int}, expected \ + %{size#Int}"]; init_bigarray_of_prec prec dims ~f:(fun idcs -> f values.(indices_to_offset ~dims ~idcs))) - else init_bigarray_of_prec prec dims ~f:(fun idcs -> f values.(indices_to_offset ~dims ~idcs % len)) + else + init_bigarray_of_prec prec dims ~f:(fun idcs -> + f values.(indices_to_offset ~dims ~idcs % len)) in let constant_fill_float values strict = constant_fill_f Fn.id values strict in match (prec, init_op) with | Ops.Half, Constant_fill { values; strict } -> constant_fill_float values strict | Ops.Half, Range_over_offsets -> init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs) - | Ops.Half, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0) + | Ops.Half, Standard_uniform -> + init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0) | Ops.Single, Constant_fill { values; strict } -> constant_fill_float values strict | Ops.Single, Range_over_offsets -> init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs) - | Ops.Single, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0) + | Ops.Single, Standard_uniform -> + init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0) | Ops.Double, Constant_fill { values; strict } -> constant_fill_float values strict | Ops.Double, Range_over_offsets -> init_bigarray_of_prec prec dims ~f:(fun idcs -> Float.of_int @@ indices_to_offset ~dims ~idcs) - | Ops.Double, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0) + | Ops.Double, Standard_uniform -> + init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.float_range 0.0 1.0) | Ops.Byte, Constant_fill { values; strict } -> constant_fill_f (Fn.compose Char.of_int_exn Int.of_float) values strict | Ops.Byte, Range_over_offsets -> - init_bigarray_of_prec prec dims ~f:(fun idcs -> Char.of_int_exn @@ indices_to_offset ~dims ~idcs) + init_bigarray_of_prec prec dims ~f:(fun idcs -> + Char.of_int_exn @@ indices_to_offset ~dims ~idcs) | Ops.Byte, Standard_uniform -> init_bigarray_of_prec prec dims ~f:(fun _ -> Rand.Lib.char ()) | _, File_mapped (filename, stored_prec) -> (* See: https://github.com/janestreet/torch/blob/master/src/torch/dataset_helper.ml#L3 *) @@ -115,8 +125,8 @@ let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di raise @@ Utils.User_error [%string - "Ndarray.create_bigarray: File_mapped: precision mismatch %{Ops.prec_string stored_prec} vs \ - %{Ops.precision_to_string prec}"]; + "Ndarray.create_bigarray: File_mapped: precision mismatch %{Ops.prec_string \ + stored_prec} vs %{Ops.precision_to_string prec}"]; let fd = Unix.openfile filename [ Unix.O_RDONLY ] 0o640 in let len = Unix.lseek fd 0 Unix.SEEK_END in ignore (Unix.lseek fd 0 Unix.SEEK_SET : int); @@ -126,10 +136,11 @@ let create_bigarray (type ocaml elt_t) (prec : (ocaml, elt_t) Ops.precision) ~di raise @@ Utils.User_error [%string - "Ndarray.create_bigarray: File_mapped: invalid file bytes %{len#Int}, expected %{size * \ - Ops.prec_in_bytes stored_prec#Int}"]); + "Ndarray.create_bigarray: File_mapped: invalid file bytes %{len#Int}, expected \ + %{size * Ops.prec_in_bytes stored_prec#Int}"]); let ba = - 
Unix.map_file fd (precision_to_bigarray_kind prec) Bigarray.c_layout false dims ~pos:(Int64.of_int 0) + Unix.map_file fd (precision_to_bigarray_kind prec) Bigarray.c_layout false dims + ~pos:(Int64.of_int 0) in Unix.close fd; ba @@ -138,7 +149,8 @@ let create_array prec ~dims init_op = let f prec = as_array prec @@ create_bigarray prec ~dims init_op in Ops.map_prec { f } prec -let empty_array prec = create_array prec ~dims:[||] (Constant_fill { values = [| 0.0 |]; strict = false }) +let empty_array prec = + create_array prec ~dims:[||] (Constant_fill { values = [| 0.0 |]; strict = false }) (** {2 *** Accessing ***} *) @@ -207,7 +219,8 @@ let set_bigarray arr ~f = let len = Array.length dims in cloop (Array.create ~len 0) f 0 -let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precision) (arr : (o, b) bigarray) = +let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precision) + (arr : (o, b) bigarray) = let dims = A.dims arr in let constant_set_f f values strict = let len = Array.length values in @@ -217,7 +230,8 @@ let reset_bigarray (init_op : Ops.init_op) (type o b) (prec : (o, b) Ops.precisi raise @@ Utils.User_error [%string - "Ndarray.reset_bigarray: Constant_fill: invalid data size %{len#Int}, expected %{size#Int}"]; + "Ndarray.reset_bigarray: Constant_fill: invalid data size %{len#Int}, expected \ + %{size#Int}"]; set_bigarray arr ~f:(fun idcs -> f values.(indices_to_offset ~dims ~idcs))) else set_bigarray arr ~f:(fun idcs -> f values.(indices_to_offset ~dims ~idcs % len)) in @@ -263,14 +277,15 @@ let fold_bigarray arr ~init ~f = let fold_as_float ~init ~f arr = match arr with - | Byte_nd arr -> fold_bigarray ~init ~f:(fun accu idx c -> f accu idx @@ Float.of_int @@ Char.to_int c) arr + | Byte_nd arr -> + fold_bigarray ~init ~f:(fun accu idx c -> f accu idx @@ Float.of_int @@ Char.to_int c) arr | Half_nd arr -> fold_bigarray ~init ~f arr | Single_nd arr -> fold_bigarray ~init ~f arr | Double_nd arr -> fold_bigarray ~init ~f arr let size_in_bytes v = - (* Cheating here because 1 number Bigarray is same size as empty Bigarray: it's more informative to report - the cases differently. *) + (* Cheating here because a 1-element Bigarray is the same size as an empty Bigarray: it's more + informative to report the cases differently. *) let f arr = if Array.is_empty @@ A.dims arr then 0 else A.size_in_bytes arr in map { f } v @@ -350,38 +365,45 @@ let retrieve_flat_values arr = (** {2 *** Printing ***} *) -(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2. Outputs - ["-"] for empty dimensions. *) +(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2. + Outputs ["-"] for empty dimensions. *) let int_dims_to_string ?(with_axis_numbers = false) dims = if Array.is_empty dims then "-" else if with_axis_numbers then - String.concat_array ~sep:" x " @@ Array.mapi dims ~f:(fun d s -> Int.to_string d ^ ":" ^ Int.to_string s) + String.concat_array ~sep:" x " + @@ Array.mapi dims ~f:(fun d s -> Int.to_string d ^ ":" ^ Int.to_string s) else String.concat_array ~sep:"x" @@ Array.map dims ~f:Int.to_string let concise_float ~prec v = Printf.sprintf "%.*e" prec v - |> (* The C99 standard requires at least two digits for the exponent, but the leading zero is a waste of - space. *) + |> (* The C99 standard requires at least two digits for the exponent, but the leading zero is a + waste of space. 
*) String.substr_replace_first ~pattern:"e+0" ~with_:"e+" |> String.substr_replace_first ~pattern:"e-0" ~with_:"e-" -(** Prints 0-based [indices] entries out of [arr], where a number between [-5] and [-1] in an axis means to - print out the axis, and a non-negative number means to print out only the indexed dimension of the axis. - Prints up to [entries_per_axis] or [entries_per_axis+1] entries per axis, possibly with ellipsis in the - middle. [labels] provides the axis labels for all axes (use [""] or ["_"] for no label). The last label - corresponds to axis [-1] etc. The printed out axes are arranged as: +(** Prints 0-based [indices] entries out of [arr], where a number between [-5] and [-1] in an axis + means to print out the axis, and a non-negative number means to print out only the indexed + dimension of the axis. Prints up to [entries_per_axis] or [entries_per_axis+1] entries per axis, + possibly with ellipsis in the middle. [labels] provides the axis labels for all axes (use [""] + or ["_"] for no label). The last label corresponds to axis [-1] etc. The printed out axes are + arranged as: - [-1]: a horizontal segment in an inner rectangle (i.e. column numbers of the inner rectangle), - [-2]: a sequence of segments in a line of text (i.e. column numbers of an outer rectangle), - [-3]: a vertical segment in an inner rectangle (i.e. row numbers of the inner rectangle), - [-4]: a vertical sequence of segments (i.e. column numbers of an outer rectangle), - [-5]: a sequence of screens of text (i.e. stack numbers of outer rectangles). *) -let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(labels = [||]) ~indices arr = +let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(labels = [||]) ~indices + arr = let module B = PrintBox in let dims = dims arr in let has_nan = fold_as_float ~init:false ~f:(fun has_nan _ v -> has_nan || Float.is_nan v) arr in - let has_inf = fold_as_float ~init:false ~f:(fun has_inf _ v -> has_inf || Float.(v = infinity)) arr in + let has_inf = + fold_as_float ~init:false ~f:(fun has_inf _ v -> has_inf || Float.(v = infinity)) arr + in let has_neg_inf = - fold_as_float ~init:false ~f:(fun has_neg_inf _ v -> has_neg_inf || Float.(v = neg_infinity)) arr + fold_as_float ~init:false + ~f:(fun has_neg_inf _ v -> has_neg_inf || Float.(v = neg_infinity)) + arr in let header = prefix @@ -393,11 +415,16 @@ let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(label if Array.is_empty dims then B.vlist ~bars:false [ B.text header; B.line "" ] else let indices = Array.copy indices in - let entries_per_axis = if entries_per_axis % 2 = 0 then entries_per_axis + 1 else entries_per_axis in - let var_indices = Array.filter_mapi indices ~f:(fun i d -> if d <= -1 then Some (5 + d, i) else None) in + let entries_per_axis = + if entries_per_axis % 2 = 0 then entries_per_axis + 1 else entries_per_axis + in + let var_indices = + Array.filter_mapi indices ~f:(fun i d -> if d <= -1 then Some (5 + d, i) else None) + in let extra_indices = [| (0, -1); (1, -1); (2, -1); (3, -1); (4, -1) |] - |> Array.filter ~f:(Fn.non @@ Array.mem var_indices ~equal:(fun (a, _) (b, _) -> Int.equal a b)) + |> Array.filter + ~f:(Fn.non @@ Array.mem var_indices ~equal:(fun (a, _) (b, _) -> Int.equal a b)) in let var_indices = Array.append extra_indices var_indices in Array.sort ~compare:(fun (a, _) (b, _) -> Int.compare a b) var_indices; @@ -443,19 +470,22 @@ let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(label B.hpad 1 
@@ B.line @@ if is_ellipsis () then "..." - else concise_float ~prec:Utils.settings.print_decimals_precision (get_as_float arr indices) + else + concise_float ~prec:Utils.settings.print_decimals_precision (get_as_float arr indices) with Invalid_argument _ -> raise @@ Utils.User_error [%string - "Invalid indices: %{int_dims_to_string indices} into array: %{(int_dims_to_string dims)}"]) + "Invalid indices: %{int_dims_to_string indices} into array: \ + %{(int_dims_to_string dims)}"]) in let tag ?pos label ind = if ind = -1 then "" else match pos with | Some pos when elide_for pos ~ind -> "~~~~~" - | Some pos when pos >= 0 -> Int.to_string (expand pos ~ind) ^ " @ " ^ label ^ Int.to_string ind + | Some pos when pos >= 0 -> + Int.to_string (expand pos ~ind) ^ " @ " ^ label ^ Int.to_string ind | _ -> "axis " ^ label ^ Int.to_string ind in let nlines = if brief then size1 else size1 + 1 in @@ -476,7 +506,8 @@ let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(label else let nline = if brief then line else line - 1 in let ncol = if brief then col else col - 1 in - if elide_for ncol ~ind:ind2 || elide_for nline ~ind:ind1 then B.hpad 1 @@ B.line "..." + if elide_for ncol ~ind:ind2 || elide_for nline ~ind:ind1 then + B.hpad 1 @@ B.line "..." else inner_grid v nline ncol) in let screens = @@ -497,7 +528,9 @@ let pp_array_inline fmt ~num_batch_axes ~num_output_axes ~num_input_axes ?axes_s (match axes_spec with None -> () | Some spec -> fprintf fmt "\"%s\" " spec); let rec loop axis = let sep = - if axis < num_batch_axes then ";" else if axis < num_batch_axes + num_output_axes then ";" else "," + if axis < num_batch_axes then ";" + else if axis < num_batch_axes + num_output_axes then ";" + else "," in let open_delim = if axis < num_batch_axes then "[|" diff --git a/arrayjit/lib/ops.ml b/arrayjit/lib/ops.ml index 242866a8..10d06dda 100644 --- a/arrayjit/lib/ops.ml +++ b/arrayjit/lib/ops.ml @@ -94,7 +94,8 @@ let pack_prec (type ocaml elt_t) (prec : (ocaml, elt_t) precision) = type 'r map_prec = { f : 'ocaml 'elt_t. ('ocaml, 'elt_t) precision -> 'r } let map_prec ?default { f } = function - | Void_prec -> Option.value_or_thunk default ~default:(fun () -> invalid_arg "map_prec: Void_prec") + | Void_prec -> + Option.value_or_thunk default ~default:(fun () -> invalid_arg "map_prec: Void_prec") | Byte_prec Byte -> f Byte | Half_prec (Half | Single) -> f Half | Single_prec (Single | Half) -> f Single @@ -111,22 +112,26 @@ let cuda_typ_of_prec = function (** {2 *** Operations ***} *) -(** Initializes or resets a array by filling in the corresponding numbers, at the appropriate precision. *) +(** Initializes or resets an array by filling in the corresponding numbers, at the appropriate + precision. *) type init_op = | Constant_fill of { values : float array; strict : bool } - (** Fills in the numbers where the rightmost axis is contiguous. If [strict=true], loops over the - provided values. *) + (** Fills in the numbers where the rightmost axis is contiguous. If [strict=false], loops over + the provided values. *) | Range_over_offsets - (** Fills in the offset number of each cell (i.e. how many cells away it is from the beginning). *) + (** Fills in the offset number of each cell (i.e. how many cells away it is from the + beginning). *) | Standard_uniform (** Draws the values from U(0,1). *) | File_mapped of string * prec (** Reads the data using [Unix.openfile] and [Unix.map_file]. 
*) [@@deriving equal, sexp] -type binop = Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1 [@@deriving sexp, compare, equal] +type binop = Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1 +[@@deriving sexp, compare, equal] + type unop = Identity | Relu [@@deriving sexp, compare, equal] -(** Either the left-neutral or right-neutral element of the operation. Unspecified if the operation does not - have a neutral element. *) +(** Either the left-neutral or right-neutral element of the operation. Unspecified if the operation + does not have a neutral element. *) let neutral_elem = function | Add | Sub -> 0. | Mul | Div -> 1. @@ -227,7 +232,7 @@ type global_identifier = dims : int array Lazy.t; } | Merge_buffer of { source_node_id : int } - (** Each device has at most one merge buffer, which is re-used, and re-allocated as needed, by merge - operations. The merge buffer is associated with the source node of the device's most recent - [device_to_device ~into_merge_buffer:true] operation. *) + (** Each device has at most one merge buffer, which is re-used, and re-allocated as needed, by + merge operations. The merge buffer is associated with the source node of the device's most + recent [device_to_device ~into_merge_buffer:true] operation. *) [@@deriving sexp_of, equal, compare] diff --git a/arrayjit/lib/ppx_helper.ml b/arrayjit/lib/ppx_helper.ml index d6fd52d0..2ca0cdc7 100644 --- a/arrayjit/lib/ppx_helper.ml +++ b/arrayjit/lib/ppx_helper.ml @@ -36,7 +36,8 @@ let ndarray_constant expr = if depth >= Array.length dims_spec then match expr with | { pexp_desc = Pexp_constant (Pconst_float _); _ } -> expr :: accu - | { pexp_desc = Pexp_constant (Pconst_integer _); _ } -> [%expr Float.of_int [%e expr]] :: accu + | { pexp_desc = Pexp_constant (Pconst_integer _); _ } -> + [%expr Float.of_int [%e expr]] :: accu | { pexp_desc = Pexp_tuple _; pexp_loc = loc; _ } -> (pexp_extension ~loc @@ Location.error_extensionf ~loc @@ -61,7 +62,8 @@ let ndarray_constant expr = List.fold_left exps ~init:accu ~f:(loop_values @@ (depth + 1)) | dim_spec -> (pexp_extension ~loc - @@ Location.error_extensionf ~loc "Arrayjit: ndarray literal axis mismatch, got %s, expected %s" + @@ Location.error_extensionf ~loc + "Arrayjit: ndarray literal axis mismatch, got %s, expected %s" (dim_spec_to_string @@ `Input_dims (List.length exps)) (dim_spec_to_string dim_spec)) :: accu) @@ -71,7 +73,8 @@ let ndarray_constant expr = List.fold_left exps ~init:accu ~f:(loop_values @@ (depth + 1)) | dim_spec -> (pexp_extension ~loc - @@ Location.error_extensionf ~loc "Arrayjit: ndarray literal axis mismatch, got %s, expected %s" + @@ Location.error_extensionf ~loc + "Arrayjit: ndarray literal axis mismatch, got %s, expected %s" (dim_spec_to_string @@ `Batch_dims (List.length exps)) (dim_spec_to_string dim_spec)) :: accu) @@ -82,7 +85,8 @@ let ndarray_constant expr = List.fold_left exps ~init:accu ~f:(loop_values @@ (depth + 1)) | dim_spec -> (pexp_extension ~loc - @@ Location.error_extensionf ~loc "Arrayjit: ndarray literal axis mismatch, got %s, expected %s" + @@ Location.error_extensionf ~loc + "Arrayjit: ndarray literal axis mismatch, got %s, expected %s" (dim_spec_to_string @@ `Output_dims (List.length exps)) (dim_spec_to_string dim_spec)) :: accu) @@ -95,7 +99,8 @@ let ndarray_constant expr = let result = loop_values 0 [] expr in let values = { expr with pexp_desc = Pexp_array (List.rev result) } in let batch_dims, output_dims, input_dims = - Array.fold dims_spec ~init:([], [], []) ~f:(fun (batch_dims, output_dims, 
input_dims) -> function + Array.fold dims_spec ~init:([], [], []) ~f:(fun (batch_dims, output_dims, input_dims) -> + function | `Input_dims dim -> (batch_dims, output_dims, eint ~loc dim :: input_dims) | `Output_dims dim -> (batch_dims, eint ~loc dim :: output_dims, input_dims) | `Batch_dims dim -> (eint ~loc dim :: batch_dims, output_dims, input_dims)) diff --git a/arrayjit/lib/rand.ml b/arrayjit/lib/rand.ml index 7c5ffbc2..322a761f 100644 --- a/arrayjit/lib/rand.ml +++ b/arrayjit/lib/rand.ml @@ -35,7 +35,7 @@ let random_lib = | "for_tests" -> (module Random_for_tests : Random) | _ -> invalid_arg - @@ "Rand.random_lib: invalid setting of the global argument randomness_lib, expected one of: stdlib, \ - for_tests; found: " ^ random_config + @@ "Rand.random_lib: invalid setting of the global argument randomness_lib, expected one of: \ + stdlib, for_tests; found: " ^ random_config module Lib = (val random_lib) diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml index ca374e7c..582a1936 100644 --- a/arrayjit/lib/tnode.ml +++ b/arrayjit/lib/tnode.ml @@ -6,7 +6,10 @@ module Debug_runtime = Utils.Debug_runtime [%%global_debug_log_level Nothing] [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"] -type task = { description : string; work : (module Minidebug_runtime.Debug_runtime) -> unit -> unit } +type task = { + description : string; + work : (module Minidebug_runtime.Debug_runtime) -> unit -> unit; +} [@@deriving sexp_of] let run debug_runtime task = @@ -30,19 +33,19 @@ type memory_mode = | Virtual (** The tensor node's computations are inlined on a per-scalar basis. *) | Never_virtual (** One of: [Local], [On_device], [Hosted]. *) | Local - (** The full tensor node is cached for the duration of a computation but not persisted across calls to - compiled functions. It is not available for merging across devices. *) + (** The full tensor node is cached for the duration of a computation but not persisted across + calls to compiled functions. It is not available for merging across devices. *) | Device_only (** One of: [Local], [On_device]. *) | On_device - (** The tensor node is stored on the devices that compute with it and persisted across function calls. - It is available for merging across devices (for devices that support merging / P2P), but not - (directly) for visualization or storing to disk. *) + (** The tensor node is stored on the devices that compute with it and persisted across + function calls. It is available for merging across devices (for devices that support + merging / P2P), but not (directly) for visualization or storing to disk. *) | Materialized (** One of: [On_device], [Hosted]. *) | Hosted of memory_type - (** The tensor node is stored in a globally addressable memory, in addition to on devices where it is - computed with (or as part of one of them, if "hosting on device", or only on the host and not on - devices, for some backends). It is available for all operations, and visible to OCaml programs as an - {!Ndarray} (the optional [array] of {!t}). *) + (** The tensor node is stored in a globally addressable memory, in addition to on devices + where it is computed with (or as part of one of them, if "hosting on device", or only on + the host and not on devices, for some backends). It is available for all operations, and + visible to OCaml programs as an {!Ndarray} (the optional [array] of {!t}). 
*) [@@deriving sexp, compare, equal] type t = { @@ -87,23 +90,30 @@ let is_materialized_force tn provenance = | Some ((On_device | Hosted _ | Materialized), _) -> true | Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false -let known_not_materialized tn = match tn.memory_mode with Some ((Virtual | Local), _) -> true | _ -> false +let known_not_materialized tn = + match tn.memory_mode with Some ((Virtual | Local), _) -> true | _ -> false let known_constant tn = - match tn.memory_mode with Some ((Effectively_constant | Hosted Constant), _) -> true | _ -> false + match tn.memory_mode with + | Some ((Effectively_constant | Hosted Constant), _) -> true + | _ -> false let known_non_virtual tn = match tn.memory_mode with None | Some ((Virtual | Effectively_constant), _) -> false | _ -> true let known_not_param tn = match tn.memory_mode with - | Some ((Virtual | Local | Effectively_constant | Device_only | On_device | Hosted (Constant | Volatile)), _) - -> + | Some + ( ( Virtual | Local | Effectively_constant | Device_only | On_device + | Hosted (Constant | Volatile) ), + _ ) -> true | _ -> false let mode_is_unspecified tn = - match tn.memory_mode with None | Some ((Never_virtual | Effectively_constant), _) -> true | _ -> false + match tn.memory_mode with + | None | Some ((Never_virtual | Effectively_constant), _) -> true + | _ -> false let update_memory_mode tn mode provenance = match (tn.memory_mode, mode) with @@ -113,8 +123,8 @@ let update_memory_mode tn mode provenance = raise @@ Utils.User_error [%string - "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{name tn}: cannot be \ - virtual"] + "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{name tn}: \ + cannot be virtual"] | Some ((Virtual | Hosted Constant), _), Effectively_constant -> () | Some ((Never_virtual | Materialized), _), Effectively_constant | Some (Effectively_constant, _), (Never_virtual | Materialized | Hosted Constant) -> @@ -128,8 +138,8 @@ let update_memory_mode tn mode provenance = raise @@ Utils.User_error [%string - "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{name tn} is already \ - virtual"] + "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{name tn} is \ + already virtual"] | Some (_, _), Never_virtual -> () | Some (Device_only, _), (Local | On_device) -> tn.memory_mode <- Some (mode, provenance) | Some (Materialized, _), (On_device | Hosted _) -> tn.memory_mode <- Some (mode, provenance) @@ -140,7 +150,8 @@ let update_memory_mode tn mode provenance = | Some (_, prov2), _ -> invalid_arg [%string - "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for %{name tn}"] + "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \ + %{name tn}"] include Comparator.Make (struct type nonrec t = t @@ -171,13 +182,17 @@ let dims_to_string ?(with_axis_numbers = false) arr = let is_alphanum_ = String.for_all ~f:(fun c -> Char.equal c '_' || Char.is_alphanum c) let ident_label tn = - let components = List.filter tn.label ~f:(fun i -> is_alphanum_ i && not (String.equal i "grad")) in + let components = + List.filter tn.label ~f:(fun i -> is_alphanum_ i && not (String.equal i "grad")) + in if List.is_empty components then None else Some (String.concat ~sep:"_" components) let debug_name ~id ~label = let n = "n" ^ Int.to_string id in let ident_label = - let components = List.filter label ~f:(fun i -> is_alphanum_ i && not (String.equal i "grad")) in + 
let components = + List.filter label ~f:(fun i -> is_alphanum_ i && not (String.equal i "grad")) + in if List.is_empty components then None else Some (String.concat ~sep:"_" components) in let is_grad = List.mem ~equal:String.equal label "grad" in @@ -208,7 +223,8 @@ let styled_ident ~repeating_nograd_idents ~repeating_grad_idents style arr = in match ident_label arr with | Some ident -> - if Hashtbl.mem (if is_grad then repeating_grad_idents else repeating_nograd_idents) ident then + if Hashtbl.mem (if is_grad then repeating_grad_idents else repeating_nograd_idents) ident + then if is_grad then [%string "n%{arr.id - 1#Int}_%{ident}%{opt_grad}"] else [%string "n%{arr.id#Int}_%{ident}"] else [%string "%{ident}%{opt_grad}"] @@ -220,7 +236,8 @@ let get_style ?(arg_name = "ll_ident_style") ?(no_dots = false) () = | "heuristic" -> `Heuristic_ocannl (if no_dots then `Under_grad else `Dot_grad) | "name_and_label" -> `Name_and_label | "name_only" -> `Name_only - | _ -> invalid_arg @@ "Wrong " ^ arg_name ^ ", must be one of: heuristic, name_and_label, name_only" + | _ -> + invalid_arg @@ "Wrong " ^ arg_name ^ ", must be one of: heuristic, name_and_label, name_only" let header arr = let mem_size = @@ -248,7 +265,9 @@ let registry = Registry.create 16 let create prec ~id ~label ~dims init_op = let rec array = - lazy (if is_hosted_force tn 30 then Some (Nd.create_array prec ~dims:(Lazy.force dims) init_op) else None) + lazy + (if is_hosted_force tn 30 then Some (Nd.create_array prec ~dims:(Lazy.force dims) init_op) + else None) and tn = { array; prec; id; label; memory_mode = None; backend_info = Sexp.List []; dims } in Registry.add registry tn; tn diff --git a/arrayjit/lib/utils.ml b/arrayjit/lib/utils.ml index b67d5890..101fd40a 100644 --- a/arrayjit/lib/utils.ml +++ b/arrayjit/lib/utils.ml @@ -6,7 +6,8 @@ module Set_O = struct let ( & ) = Set.inter let ( -* ) s1 s2 = - Set.of_sequence (Set.comparator_s s1) @@ Sequence.map ~f:Either.value @@ Set.symmetric_diff s1 s2 + Set.of_sequence (Set.comparator_s s1) + @@ Sequence.map ~f:Either.value @@ Set.symmetric_diff s1 s2 end let no_ints = Set.empty (module Int) @@ -17,7 +18,9 @@ let map_merge m1 m2 ~f = match m with `Right v | `Left v -> Some v | `Both (v1, v2) -> Some (f v1 v2)) let mref_add mref ~key ~data ~or_ = - match Map.add !mref ~key ~data with `Ok m -> mref := m | `Duplicate -> or_ (Map.find_exn !mref key) + match Map.add !mref ~key ~data with + | `Ok m -> mref := m + | `Duplicate -> or_ (Map.find_exn !mref key) let mref_add_missing mref key ~f = if Map.mem !mref key then () else mref := Map.add_exn !mref ~key ~data:(f ()) @@ -28,7 +31,8 @@ type settings = { mutable output_debug_files_in_run_directory : bool; mutable with_debug_level : int; mutable fixed_state_for_init : int option; - mutable print_decimals_precision : int; (** When rendering arrays etc., outputs this many decimal digits. *) + mutable print_decimals_precision : int; + (** When rendering arrays etc., outputs this many decimal digits. 
*) } [@@deriving sexp] @@ -41,7 +45,7 @@ let settings = fixed_state_for_init = None; print_decimals_precision = 2; } - + let accessed_global_args = Hash_set.create (module String) let read_cmdline_or_env_var n = @@ -64,7 +68,8 @@ let read_cmdline_or_env_var n = List.find_map env_variants ~f:(fun env_n -> Option.( join - @@ map (Core.Sys.getenv env_n) ~f:(fun v -> if String.is_empty v then None else Some (env_n, v)))) + @@ map (Core.Sys.getenv env_n) ~f:(fun v -> + if String.is_empty v then None else Some (env_n, v)))) with | None | Some (_, "") -> None | Some (p, arg) -> @@ -93,9 +98,12 @@ let config_file_args = | [] -> None | key :: [ v ] -> let key = - String.(lowercase @@ strip ~drop:(fun c -> equal_char '-' c || equal_char ' ' c) key) + String.( + lowercase @@ strip ~drop:(fun c -> equal_char '-' c || equal_char ' ' c) key) + in + let key = + if String.is_prefix key ~prefix:"ocannl" then String.drop_prefix key 6 else key in - let key = if String.is_prefix key ~prefix:"ocannl" then String.drop_prefix key 6 else key in Some (String.strip ~drop:(equal_char '_') key, v) | _ -> failwith @@ "OCANNL: invalid syntax in the config file " ^ fname @@ -105,8 +113,8 @@ let config_file_args = Stdio.printf "\nWelcome to OCANNL! Configuration defaults file is disabled.\n%!"; Hashtbl.create (module String) -(** Retrieves [arg_name] argument from the command line or from an environment variable, returns [default] if - none found. *) +(** Retrieves [arg_name] argument from the command line or from an environment variable, returns + [default] if none found. *) let get_global_arg ~default ~arg_name:n = let with_debug = settings.with_debug_level > 0 && not (Hash_set.mem accessed_global_args n) in if with_debug then @@ -125,13 +133,15 @@ let get_global_arg ~default ~arg_name:n = result let () = - settings.with_debug_level <- Int.of_string @@ get_global_arg ~arg_name:"with_debug_level" ~default:"0"; + settings.with_debug_level <- + Int.of_string @@ get_global_arg ~arg_name:"with_debug_level" ~default:"0"; settings.debug_log_from_routines <- Bool.of_string @@ get_global_arg ~arg_name:"debug_log_from_routines" ~default:"false"; settings.debug_memory_locations <- Bool.of_string @@ get_global_arg ~arg_name:"debug_memory_locations" ~default:"false"; settings.output_debug_files_in_run_directory <- - Bool.of_string @@ get_global_arg ~arg_name:"output_debug_files_in_run_directory" ~default:"false"; + Bool.of_string + @@ get_global_arg ~arg_name:"output_debug_files_in_run_directory" ~default:"false"; settings.fixed_state_for_init <- (let seed = get_global_arg ~arg_name:"fixed_state_for_init" ~default:"" in if String.is_empty seed then None else Some (Int.of_string seed)); @@ -159,8 +169,8 @@ let get_debug name = | "nanoseconds" -> Nanoseconds | s -> invalid_arg - @@ "ocannl_elapsed_times setting should be not_reported, seconds or milliseconds, microseconds or \ - nanoseconds; found: " ^ s + @@ "ocannl_elapsed_times setting should be not_reported, seconds or milliseconds, \ + microseconds or nanoseconds; found: " ^ s in let location_format = match String.lowercase @@ get_global_arg ~default:"beg_pos" ~arg_name:"location_format" with @@ -170,26 +180,34 @@ let get_debug name = | "beg_pos" -> Beg_pos | "range_line" -> Range_line | "range_pos" -> Range_pos - | s -> invalid_arg @@ "ocannl_location_format setting should be none, clock or elapsed; found: " ^ s + | s -> + invalid_arg @@ "ocannl_location_format setting should be none, clock or elapsed; found: " + ^ s in let flushing, backend = - match String.lowercase @@ 
String.strip @@ get_global_arg ~default:"html" ~arg_name:"debug_backend" with + match + String.lowercase @@ String.strip @@ get_global_arg ~default:"html" ~arg_name:"debug_backend" + with | "text" -> (false, `Text) | "html" -> (false, `Html Minidebug_runtime.default_html_config) | "markdown" -> (false, `Markdown Minidebug_runtime.default_md_config) | "flushing" -> (true, `Text) | s -> - invalid_arg @@ "ocannl_debug_backend setting should be text, html, markdown or flushing; found: " ^ s + invalid_arg + @@ "ocannl_debug_backend setting should be text, html, markdown or flushing; found: " ^ s in let hyperlink = get_global_arg ~default:"./" ~arg_name:"hyperlink_prefix" in - let print_entry_ids = Bool.of_string @@ get_global_arg ~default:"false" ~arg_name:"logs_print_entry_ids" in + let print_entry_ids = + Bool.of_string @@ get_global_arg ~default:"false" ~arg_name:"logs_print_entry_ids" + in let verbose_entry_ids = Bool.of_string @@ get_global_arg ~default:"false" ~arg_name:"logs_verbose_entry_ids" in let filename = if String.is_empty name then "debug" else "debug-" ^ name in let log_level = match - String.lowercase @@ String.strip @@ get_global_arg ~default:"nonempty_entries" ~arg_name:"log_level" + String.lowercase @@ String.strip + @@ get_global_arg ~default:"nonempty_entries" ~arg_name:"log_level" with | "nothing" -> Minidebug_runtime.Nothing | "prefixed_error" -> Prefixed [| "ERROR" |] @@ -200,8 +218,9 @@ let get_debug name = | "everything" -> Everything | s -> invalid_arg - @@ "ocannl_log_level setting should be one of: nothing, prefixed_error, prefixed_warn_error, \ - prefixed_info_warn_error, explicit_logs, nonempty_entries, everything; found: " ^ s + @@ "ocannl_log_level setting should be one of: nothing, prefixed_error, \ + prefixed_warn_error, prefixed_info_warn_error, explicit_logs, nonempty_entries, \ + everything; found: " ^ s in let toc_entry_minimal_depth = let arg = get_global_arg ~default:"" ~arg_name:"toc_entry_minimal_depth" in @@ -222,7 +241,8 @@ let get_debug name = | "us" -> Mtime.Span.us | "ms" -> Mtime.Span.ms | _ -> - invalid_arg @@ "ocannl_toc_entry_minimal_span setting should end with one of: ns, us, ms; found: " + invalid_arg + @@ "ocannl_toc_entry_minimal_span setting should end with one of: ns, us, ms; found: " ^ period in [ Minidebug_runtime.Minimal_span Mtime.Span.(Int.of_string arg * period) ] @@ -231,13 +251,14 @@ let get_debug name = Minidebug_runtime.And (toc_entry_minimal_depth @ toc_entry_minimal_size @ toc_entry_minimal_span) in if flushing then - Minidebug_runtime.debug_flushing ~filename ~time_tagged ~elapsed_times ~print_entry_ids ~verbose_entry_ids - ~global_prefix:name ~for_append:false (* ~log_level *) () + Minidebug_runtime.debug_flushing ~filename ~time_tagged ~elapsed_times ~print_entry_ids + ~verbose_entry_ids ~global_prefix:name ~for_append:false (* ~log_level *) () else Minidebug_runtime.forget_printbox @@ Minidebug_runtime.debug_file ~time_tagged ~elapsed_times ~location_format ~print_entry_ids - ~verbose_entry_ids ~global_prefix:name ~toc_flame_graph:true ~flame_graph_separation:50 ~toc_entry - ~for_append:false ~max_inline_sexp_length:120 ~hyperlink ~toc_specific_hyperlink:"" + ~verbose_entry_ids ~global_prefix:name ~toc_flame_graph:true ~flame_graph_separation:50 + ~toc_entry ~for_append:false ~max_inline_sexp_length:120 ~hyperlink + ~toc_specific_hyperlink:"" ~highlight_terms:Re.(alt []) ~exclude_on_path:Re.(str "env") ~values_first_mode:true ~backend ~log_level ?snapshot_every_sec filename @@ -252,7 +273,8 @@ module Debug_runtime = 
(val get_debug "") let rec union_find ~equal map ~key ~rank = match Map.find map key with | None -> (key, rank) - | Some data -> if equal key data then (key, rank) else union_find ~equal map ~key:data ~rank:(rank + 1) + | Some data -> + if equal key data then (key, rank) else union_find ~equal map ~key:data ~rank:(rank + 1) let union_add ~equal map k1 k2 = if equal k1 k2 then map @@ -283,8 +305,9 @@ let sorted_diff ~compare l1 l2 = in (loop [] l1 l2 [@nontail]) -(** [parallel_merge merge num_devices] progressively invokes the pairwise [merge] callback, converging on the - 0th position, with [from] ranging from [1] to [num_devices - 1], and [to_ < from]. *) +(** [parallel_merge merge num_devices] progressively invokes the pairwise [merge] callback, + converging on the 0th position, with [from] ranging from [1] to [num_devices - 1], and + [to_ < from]. *) let%track_sexp parallel_merge merge (num_devices : int) = let rec loop (upper : int) : unit = let is_even = (upper + 1) % 2 = 0 in @@ -307,8 +330,8 @@ let ( !@ ) = Atomic.get type waiter = { await : keep_waiting:(unit -> bool) -> unit -> bool; - (** Returns [true] if the waiter was not already waiting (in another thread) and waiting was needed - ([keep_waiting] always returned true). *) + (** Returns [true] if the waiter was not already waiting (in another thread) and waiting was + needed ([keep_waiting] always returned true). *) release_if_waiting : unit -> bool; (** Returns [true] if the waiter both was waiting and was not already released. *) is_waiting : unit -> bool; @@ -318,7 +341,8 @@ type waiter = { let waiter ~name:_ () = let is_open = Atomic.make true in - (* TODO: since OCaml 5.2, use [make_contended] for at least [is_released] and maybe [is_waiting]. *) + (* TODO: since OCaml 5.2, use [make_contended] for at least [is_released] and maybe + [is_waiting]. 
*) let is_released = Atomic.make false in let is_waiting = Atomic.make false in let pipe_inp, pipe_out = Unix.pipe ~cloexec:true () in @@ -414,7 +438,8 @@ let%diagn_rt_sexp log_trace_tree logs = loop more | source :: trace :: more when String.is_prefix source ~prefix:"# " -> (let source = - String.concat ~sep:"\n" @@ String.split ~on:'$' @@ String.chop_prefix_exn ~prefix:"# " source + String.concat ~sep:"\n" @@ String.split ~on:'$' + @@ String.chop_prefix_exn ~prefix:"# " source in match split_with_seps header_sep trace with | [] | [ "" ] -> [%log source] @@ -447,14 +472,17 @@ let insert ~next = function cons.tl <- Cons { hd = next; tl = cons.tl }; cons.tl -let tl_exn = function Empty -> raise @@ Not_found_s (Sexp.Atom "mutable_list.tl_exn") | Cons { tl; _ } -> tl +let tl_exn = function + | Empty -> raise @@ Not_found_s (Sexp.Atom "mutable_list.tl_exn") + | Cons { tl; _ } -> tl type pp_file = { f_name : string; ppf : Stdlib.Format.formatter; finalize : unit -> unit } let pp_file ~base_name ~extension = let column_width = 110 in let f_name = - if settings.output_debug_files_in_run_directory then Filename_base.concat "./" base_name ^ extension + if settings.output_debug_files_in_run_directory then + Filename_base.concat "./" base_name ^ extension else Stdlib.Filename.temp_file (base_name ^ "_") extension in (* (try Stdlib.Sys.remove f_name with _ -> ()); *) diff --git a/bin/compilation_speed.ml b/bin/compilation_speed.ml index f556e628..dff501d6 100644 --- a/bin/compilation_speed.ml +++ b/bin/compilation_speed.ml @@ -25,7 +25,8 @@ let benchmark_overhead backend () = let device = new_virtual_device @@ get_device ~ordinal:0 in let ctx = init device in let update_f = Train.grad_update f in - (* Initialize the context with a mock update of x to ensure that it is not optimized as a constant. *) + (* Initialize the context with a mock update of x to ensure that it is not optimized as a + constant. 
*)
   let%cd mock_update_x = x =: 42 in
   let init_assign_x = link ctx @@ compile ~name:"init_assign_x" IDX.empty mock_update_x in
   let f_routine = link init_assign_x.context @@ compile IDX.empty update_f.fwd_bprop in
@@ -61,7 +62,9 @@ let benchmark_overhead backend () =
           (* FIXME: global mem consumption *)
           mem_in_bytes = 0;
           result_label = "x, f(x)";
-          result = [%sexp_of: (float * float) list] @@ [ (xs.(0), ys.(0)); (xs.(n_data / 2), ys.(n_data / 2)) ];
+          result =
+            [%sexp_of: (float * float) list]
+            @@ [ (xs.(0), ys.(0)); (xs.(n_data / 2), ys.(n_data / 2)) ];
         }
       in
       PrintBox_text.output Stdio.stdout plot_box;
@@ -76,4 +79,5 @@ let benchmarks =
   ]

 let () =
-  List.map benchmarks ~f:(fun bench -> bench ()) |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
+  List.map benchmarks ~f:(fun bench -> bench ())
+  |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout
diff --git a/bin/einsum_trivia.ml b/bin/einsum_trivia.ml
index 63ce6492..5250517f 100644
--- a/bin/einsum_trivia.ml
+++ b/bin/einsum_trivia.ml
@@ -33,7 +33,9 @@ let () =
   let hey = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
   let%op ho = hey ++ "b|i->o => o|b->i" in
   Train.forward_and_forget backend ctx ho;
-  let hey2 = TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] () in
+  let hey2 =
+    TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] ()
+  in
   let%op ho2 = hey2 ++ "ab|cd->ef => cf|ae->db" in
   Train.forward_and_forget backend ctx ho2;
   Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ho2
diff --git a/bin/micrograd_basic.ml b/bin/micrograd_basic.ml
index 205419d6..a274d3e0 100644
--- a/bin/micrograd_basic.ml
+++ b/bin/micrograd_basic.ml
@@ -19,10 +19,11 @@ let%diagn_sexp () =
   let%op c = "a" [ -4 ] + "b" [ 2 ] in
   let%op d = c + c + 1 in
   (* let%op c = c + 1 + c + ~-a in *)
-  (* Uncomment just the first "fully on host" line to see which arrays can be virtual, and just the second
-     line to see the intermediate computation values. *)
+  (* Uncomment just the first "fully on host" line to see which arrays can be virtual, and just the
+     second line to see the intermediate computation values. *)
   Train.every_non_literal_on_host d;
-  (* List.iter ~f:(function Some diff -> Train.set_hosted diff.grad | None -> ()) [ a.diff; b.diff ]; *)
+  (* List.iter ~f:(function Some diff -> Train.set_hosted diff.grad | None -> ()) [ a.diff; b.diff
+     ]; *)
   let update = Train.grad_update d in
   let routine = Backend.(link ctx @@ compile IDX.empty update.fwd_bprop) in
   Train.sync_run (module Backend) routine d;
diff --git a/bin/micrograd_demo.ml b/bin/micrograd_demo.ml
index 8d24b933..de388d20 100644
--- a/bin/micrograd_demo.ml
+++ b/bin/micrograd_demo.ml
@@ -57,15 +57,20 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let log_losses = ref [] in
   let learning_rates = ref [] in
   let%op margin_loss = ?/(1 - (moons_class *. mlp moons_input)) in
-  (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update computation. *)
+  (* We don't need a regularization loss formula thanks to weight_decay built into the sgd_update
+     computation. *)
   let scalar_loss, weight_decay =
     if use_builtin_weight_decay then
      let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
      (scalar_loss, 0.0002)
    else
     let%op ssq w = (w **. 2) ++ "...|...->... 
=> 0" in - let reg_loss = List.map ~f:ssq [ w1; w2; w3; b1; b2; b3 ] |> List.reduce_exn ~f:TDSL.O.( + ) in - let%op scalar_loss = ((margin_loss ++ "...|... => 0") /. !..batch_size) + (0.0001 *. reg_loss) in + let reg_loss = + List.map ~f:ssq [ w1; w2; w3; b1; b2; b3 ] |> List.reduce_exn ~f:TDSL.O.( + ) + in + let%op scalar_loss = + ((margin_loss ++ "...|... => 0") /. !..batch_size) + (0.0001 *. reg_loss) + in (scalar_loss, 0.0) in (* So that we can inspect them. *) @@ -80,9 +85,10 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () = let routine = Backend.(link ctx @@ compile bindings (Seq (update.fwd_bprop, sgd))) in Train.all_host_to_device (module Backend) routine.context scalar_loss; Train.all_host_to_device (module Backend) routine.context learning_rate; - (* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true ~with_grad:false - ~depth:9 scalar_loss; Stdio.print_endline "\n******** learning_rate **********"; Tensor.print_tree - ~with_id:true ~with_grad:false ~depth:9 learning_rate; Stdio.printf "\n********\n%!"; *) + (* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true + ~with_grad:false ~depth:9 scalar_loss; Stdio.print_endline "\n******** learning_rate + **********"; Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 learning_rate; + Stdio.printf "\n********\n%!"; *) let open Operation.At in let epoch_loss = ref 0. in let step_ref = IDX.find_exn routine.bindings step_n in @@ -115,14 +121,16 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () = Train.set_on_host Volatile mlp_result.value; (* By using jitted.context here, we don't need to copy the parameters back to the host. *) let result_routine = - Backend.(link routine.context @@ compile IDX.empty @@ Block_comment ("moons infer", mlp_result.forward)) + Backend.( + link routine.context @@ compile IDX.empty @@ Block_comment ("moons infer", mlp_result.forward)) in Stdio.print_endline "\n******** mlp_result **********"; Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 mlp_result; Stdio.printf "\n********\n%!"; let callback (x, y) = Tensor.set_values point [| x; y |]; - (* For the gccjit backend, point is only on host, not on device. For cuda, this will be needed. *) + (* For the gccjit backend, point is only on host, not on device. For cuda, this will be + needed. 
*) assert (Backend.from_host result_routine.context point.value); Train.run result_routine; assert (Backend.to_host result_routine.context mlp_result.value); diff --git a/bin/moons_benchmark.ml b/bin/moons_benchmark.ml index 877ea713..bc818ba6 100644 --- a/bin/moons_benchmark.ml +++ b/bin/moons_benchmark.ml @@ -13,8 +13,8 @@ module Debug_runtime = Utils.Debug_runtime [%%global_debug_log_level Nothing] [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"] -let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~backend_name ~value_prec - ~grad_prec () = +let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~backend_name + ~value_prec ~grad_prec () = [%track_sexp let _debug : string = "started" in (fun (started : unit) -> started) ()]; @@ -22,8 +22,8 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b let bench_title = [%string "seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_devices#Int}, batch \ - %{batch_size#Int}, backend %{backend_name}, val prec %{Ops.prec_string value_prec}, grad prec \ - %{Ops.prec_string grad_prec}"] + %{batch_size#Int}, backend %{backend_name}, val prec %{Ops.prec_string value_prec}, grad \ + prec %{Ops.prec_string grad_prec}"] in Stdio.printf "\n*** %s ***\n%!" bench_title; CDSL.virtualize_settings.enable_device_only <- on_device; @@ -57,9 +57,9 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b let moons_classes ~b = TDSL.init_const ~l:"moons_classes" ~b ~o:[ 1 ] moons_classes in let init_time = Time_now.nanoseconds_since_unix_epoch () in - (* * let%op mlp x = "b6" 1 + "w6" * ?/("b4" hid_4_5 + "w4" * ?/("b2" hid_2_3 + ("w2" * ?/("b1" 16 + ("w1" * - x))) + "b3" hid_2_3 + ("w3" * ?/(b2 + (w2 * ?/(b1 + (w1 * x)))))) + ("b5" hid_4_5 + ("w5" * ?/(b4 + (w4 * - ?/(b3 + (w3 * ?/(b2 + (w2 * ?/(b1 + (w1 * x))))))))))) in * *) + (* * let%op mlp x = "b6" 1 + "w6" * ?/("b4" hid_4_5 + "w4" * ?/("b2" hid_2_3 + ("w2" * ?/("b1" 16 + + ("w1" * x))) + "b3" hid_2_3 + ("w3" * ?/(b2 + (w2 * ?/(b1 + (w1 * x)))))) + ("b5" hid_4_5 + + ("w5" * ?/(b4 + (w4 * ?/(b3 + (w3 * ?/(b2 + (w2 * ?/(b1 + (w1 * x))))))))))) in * *) let%op mlp x = "b3" 1 + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in let start_time = ref None in @@ -70,12 +70,13 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b in (* Tn.print_accessible_headers (); *) let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss = - Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate epoch_loss + Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" 
at_epoch at_step learning_rate + epoch_loss in let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates = - Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_devices:num_devices ~data_len:len ~epochs - ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay ~per_batch_callback - ~per_epoch_callback backend () + Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_devices:num_devices ~data_len:len + ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay + ~per_batch_callback ~per_epoch_callback backend () in let points = Tensor.value_2d_points ~xdim:0 ~ydim:1 inputs in let classes = Tensor.value_1d_points ~xdim:0 outputs in @@ -102,7 +103,8 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b [ Line_plot { - points = Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x)); + points = + Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x)); pixel = "-"; }; ] @@ -159,17 +161,18 @@ let benchmarks = List.concat_map [ 0; 1 (* ; 2; 3; 4 *) ] ~f:(fun seed -> List.concat_map [ "gccjit" (* *; "cuda" *) ] ~f:(fun backend_name -> [ - classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices ~batch_size - ~backend_name ~value_prec:CDSL.single ~grad_prec:CDSL.single; + classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices + ~batch_size ~backend_name ~value_prec:CDSL.single ~grad_prec:CDSL.single; ]))))) -(* let time_of = function PrintBox_utils.Benchmark { time_in_sec; _ } -> time_in_sec let nth_best nth bench = - let results = List.init 5 ~f:(fun seed -> bench ~seed ()) in let sorted = List.sort results ~compare:(fun - r1 r2 -> Float.compare (time_of r1) (time_of r2)) in List.nth_exn sorted (nth - 1) *) +(* let time_of = function PrintBox_utils.Benchmark { time_in_sec; _ } -> time_in_sec let nth_best + nth bench = let results = List.init 5 ~f:(fun seed -> bench ~seed ()) in let sorted = List.sort + results ~compare:(fun r1 r2 -> Float.compare (time_of r1) (time_of r2)) in List.nth_exn sorted + (nth - 1) *) let fixed_seed_search seed = - classify_moons ~seed ~on_device:true ~inlining_cutoff:3 ~num_devices:1 ~batch_size:20 ~backend_name:"cuda" - ~value_prec:CDSL.single ~grad_prec:CDSL.single () + classify_moons ~seed ~on_device:true ~inlining_cutoff:3 ~num_devices:1 ~batch_size:20 + ~backend_name:"cuda" ~value_prec:CDSL.single ~grad_prec:CDSL.single () let _suspended () = List.init 20 ~f:fixed_seed_search |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout @@ -178,6 +181,7 @@ let _suspended () = Stdio.stdout *) let benchmark () = - List.map benchmarks ~f:(fun bench -> bench ()) |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout + List.map benchmarks ~f:(fun bench -> bench ()) + |> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout let _suspended () = benchmark () diff --git a/bin/moons_demo_parallel.ml b/bin/moons_demo_parallel.ml index 61354e9b..473a75ca 100644 --- a/bin/moons_demo_parallel.ml +++ b/bin/moons_demo_parallel.ml @@ -39,23 +39,25 @@ let experiment ~seed () = let%op mlp x = "b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in (* let%op mlp x = "b" + ("w" * x) in *) let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in - (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update computation. 
*)
+  (* We don't need a regularization loss formula thanks to weight_decay built into the sgd_update
+     computation. *)
   let weight_decay = 0.0002 in
   (* So that we can inspect them. *)
   let backend = Train.fresh_backend () in
   let per_batch_callback ~at_batch ~at_step ~learning_rate ~batch_loss ~epoch_loss =
     if (at_batch + 1) % 20 = 0 then
-      Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step learning_rate
-        batch_loss epoch_loss
+      Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step
+        learning_rate batch_loss epoch_loss
   in
   (* Tn.print_accessible_headers (); *)
   let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
-    Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate epoch_loss
+    Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
+      epoch_loss
   in
   let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
-    Train.example_train_loop ~seed ~batch_size ~max_num_devices:(batch_size / 2) ~init_lr ~data_len:len
-      ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay ~per_batch_callback
-      ~per_epoch_callback backend ()
+    Train.example_train_loop ~seed ~batch_size ~max_num_devices:(batch_size / 2) ~init_lr
+      ~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
+      ~weight_decay ~per_batch_callback ~per_epoch_callback backend ()
   in
   let points = Tensor.value_2d_points ~xdim:0 ~ydim:1 inputs in
   let classes = Tensor.value_1d_points ~xdim:0 outputs in
@@ -96,7 +98,8 @@ let experiment ~seed () =
         [
           Line_plot
             {
-              points = Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
+              points =
+                Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
               pixel = "-";
             };
         ]
diff --git a/bin/zero2hero_1of7.ml b/bin/zero2hero_1of7.ml
index c4d5935e..6211f0b0 100644
--- a/bin/zero2hero_1of7.ml
+++ b/bin/zero2hero_1of7.ml
@@ -35,7 +35,10 @@ let _suspended () =
   let%op f5 = f 5 in
   let module Backend = (val Train.fresh_backend ()) in
   Train.every_non_literal_on_host f5;
-  Train.forward_and_forget (module Backend) Backend.(init @@ new_virtual_device @@ get_device ~ordinal:0) f5;
+  Train.forward_and_forget
+    (module Backend)
+    Backend.(init @@ new_virtual_device @@ get_device ~ordinal:0)
+    f5;
   Stdio.printf "\n%!";
   Tensor.print_tree ~with_grad:false ~depth:9 f5;
   Stdio.printf "\n%!"
@@ -167,15 +170,19 @@ let _suspended () =
   which are specified in the tensor in the brackets.|};
   Tensor.print_tree ~with_grad:true ~depth:9 l;
   let%op learning_rate = 0.1 in
-  let routine = link routine.context @@ compile IDX.empty @@ Train.sgd_update ~learning_rate update in
+  let routine =
+    link routine.context @@ compile IDX.empty @@ Train.sgd_update ~learning_rate update
+  in
   (* learning_rate is virtual so this will not print anything. *)
-  Tensor.iter_embedded_arrays learning_rate ~f:(fun a -> ignore (from_host routine.context a : bool));
+  Tensor.iter_embedded_arrays learning_rate ~f:(fun a ->
+      ignore (from_host routine.context a : bool));
   Stdio.print_endline
     {|
Due to how the gccjit backend works, since the parameters were constant in the grad_update
computation, they did not exist on the device before. Now they do. 
This would not be needed on the cuda backend.|}; - List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a -> assert (from_host routine.context a)); + List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a -> + assert (from_host routine.context a)); Train.run routine; Tensor.iter_embedded_arrays l ~f:(fun a -> ignore (to_host routine.context a : bool)); await device; diff --git a/lib/PrintBox_utils.ml b/lib/PrintBox_utils.ml index 88300296..f3456b4e 100644 --- a/lib/PrintBox_utils.ml +++ b/lib/PrintBox_utils.ml @@ -25,7 +25,9 @@ let rec boxify (depth : int) (b : dag) : dag = match b with | b when depth <= 0 -> b | `Tree (n, bs) when depth > 0 -> - `Vlist (false, [ `Align (`Center, `Bottom, n); `Hlist (true, List.map ~f:(boxify @@ (depth - 1)) bs) ]) + `Vlist + ( false, + [ `Align (`Center, `Bottom, n); `Hlist (true, List.map ~f:(boxify @@ (depth - 1)) bs) ] ) | `Hlist (bars, bs) -> `Hlist (bars, List.map ~f:(boxify @@ (depth - 1)) bs) | `Vlist (bars, bs) -> `Vlist (bars, List.map ~f:(boxify @@ (depth - 1)) bs) | `Pad b -> `Pad (boxify depth b) @@ -45,7 +47,8 @@ let dag_to_box (b : dag) = | `Hlist (_, bs) -> Set.union_list s @@ List.map ~f:reused bs | `Vlist (_, bs) -> Set.union_list s @@ List.map ~f:reused bs | `Table bss -> - Set.union_list s @@ Array.to_list @@ Array.concat_map bss ~f:(fun bs -> Array.map ~f:reused bs) + Set.union_list s @@ Array.to_list + @@ Array.concat_map bss ~f:(fun bs -> Array.map ~f:reused bs) in let reused = reused b in let open PrintBox in @@ -74,14 +77,18 @@ type plot_spec = | Scatterplot of { points : (float * float) array; pixel : string } | Line_plot of { points : float array; pixel : string } | Boundary_map of { callback : float * float -> bool; pixel_true : string; pixel_false : string } - | Line_plot_adaptive of { callback : float -> float; mutable cache : float Map.M(Float).t; pixel : string } + | Line_plot_adaptive of { + callback : float -> float; + mutable cache : float Map.M(Float).t; + pixel : string; + } [@@deriving sexp_of] let plot_canvas ?canvas ?(size : (int * int) option) (specs : plot_spec list) : float * float * float * float * _ = let open Float in - (* Unfortunately "x" and "y" of a "matrix" are opposite to how we want them displayed -- the first dimension - (i.e. "x") as the horizontal axis. *) + (* Unfortunately "x" and "y" of a "matrix" are opposite to how we want them displayed -- the first + dimension (i.e. "x") as the horizontal axis. 
*) let (dimx, dimy, canvas) : int * int * _ = match (canvas, size) with | None, None -> invalid_arg "PrintBox_utils.plot: provide ~canvas or ~size" @@ -128,13 +135,16 @@ let plot_canvas ?canvas ?(size : (int * int) option) (specs : plot_spec list) : if Array.is_empty all_x_points then of_int Int.(Array.length all_y_points - 1) else Array.reduce_exn all_x_points ~f:max in - let maxy = if Array.is_empty all_y_points then maxx - minx else Array.reduce_exn all_y_points ~f:max in + let maxy = + if Array.is_empty all_y_points then maxx - minx else Array.reduce_exn all_y_points ~f:max + in let spanx = maxx - minx in let spanx = Float.(if spanx < epsilon_float then 1.0 else spanx) in let spany = maxy - miny in let spany = Float.(if spany < epsilon_float then 1.0 else spany) in let scale_1d y = - try Some (to_int @@ (of_int Int.(dimy - 1) * (y - miny) / spany)) with Invalid_argument _ -> None + try Some (to_int @@ (of_int Int.(dimy - 1) * (y - miny) / spany)) + with Invalid_argument _ -> None in let scale_2d (x, y) = try @@ -226,8 +236,12 @@ let table rows = let speedups = List.map times ~f:(fun x -> max_time /. x) in let mem_gains = List.map sizes ~f:Float.(fun x -> of_int max_size / of_int x) in let small_float = Fn.compose PrintBox.line (Printf.sprintf "%.3f") in - let results = List.map rows ~f:(fun (Benchmark { result; _ }) -> nolines @@ Sexp.to_string_hum result) in - let result_labels = List.map rows ~f:(fun (Benchmark { result_label; _ }) -> nolines result_label) in + let results = + List.map rows ~f:(fun (Benchmark { result; _ }) -> nolines @@ Sexp.to_string_hum result) + in + let result_labels = + List.map rows ~f:(fun (Benchmark { result_label; _ }) -> nolines result_label) + in (* TODO(#140): partition by unique result_label and output a vlist of records. *) PrintBox.( frame diff --git a/lib/operation.ml b/lib/operation.ml index d42488cc..ad0fb648 100644 --- a/lib/operation.ml +++ b/lib/operation.ml @@ -74,21 +74,22 @@ let pointmul ?(label = []) = let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 * v2 in mul Pointwise_bin ~op_asn ~label:("*." :: label) -(* N1: AxB, N2 BxC, v: AxC, A: output of N1, B: input/output of N1/N2, C: input of N2. Although the matrix - algebra would require that we insert additional transposes in gradient multiplies: AxB = AxC * CxB = AxC * - (BxC)^T -> N1g += Ng * N2v^T, BxC = BxA * AxC = (AxB)^T * AxC -> N2g += N1v^T * Ng, in our setup there is - no transposing to do, since the projections produce correct indices for their corresponding matrices. *) +(* N1: AxB, N2 BxC, v: AxC, A: output of N1, B: input/output of N1/N2, C: input of N2. Although the + matrix algebra would require that we insert additional transposes in gradient multiplies: AxB = + AxC * CxB = AxC * (BxC)^T -> N1g += Ng * N2v^T, BxC = BxA * AxC = (AxB)^T * AxC -> N2g += N1v^T * + Ng, in our setup there is no transposing to do, since the projections produce correct indices for + their corresponding matrices. *) let matmul ?(label = []) = let module NTDSL = Initial_NTDSL in let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in mul Compose ~op_asn ~label:("*" :: label) -(** Similar to the explicit mode of [numpy.einsum], the binary variant. Can compute various forms of matrix - multiplication, inner and outer products, etc. +(** Similar to the explicit mode of [numpy.einsum], the binary variant. Can compute various forms of + matrix multiplication, inner and outer products, etc. 
- Note that ["a,b->c"] from [numpy] is ["a;b=>c"] in OCANNL, since ["->"] is used to separate the input and - the output axes. *) + Note that ["a,b->c"] from [numpy] is ["a;b=>c"] in OCANNL, since ["->"] is used to separate the + input and the output axes. *) let einsum ?(label = []) spec = let module NTDSL = Initial_NTDSL in let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2 in @@ -108,11 +109,11 @@ let outer_sum ?(label = []) spec = in Tensor.binop ~label:(";=>+" :: label) ~compose_op:(Einsum spec) ~op_asn ~grad_asn -(** Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract diagonals, - compute traces etc. +(** Similar to the explicit mode of [numpy.einsum], the unary variant. Can permute axes, extract + diagonals, compute traces etc. - Note that ["a->c"] from [numpy] is ["a=>c"] in OCANNL, since ["->"] is used to separate the input and the - output axes. *) + Note that ["a->c"] from [numpy] is ["a=>c"] in OCANNL, since ["->"] is used to separate the + input and the output axes. *) let einsum1 ?(label = []) spec = let module NTDSL = Initial_NTDSL in let%cd op_asn ~v ~t1 ~projections = v =:+ v1 in @@ -173,7 +174,8 @@ let rec pointdiv ?(label : string list = []) ~grad_spec t1 t2 = end end in let%cd op_asn ~v ~t1 ~t2 ~projections = v =: v1 / v2 in - (* We cannot use g in a tensor expression since it's an array, so we keep it to the left (RHS1). *) + (* We cannot use g in a tensor expression since it's an array, so we keep it to the left + (RHS1). *) let%cd grad_asn ~v:_ ~g ~t1 ~t2 ~projections = g1 =+ g / v2; g2 =+ g * (-1 *. t1 /. (t2 **. 2)) @@ -190,18 +192,21 @@ let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto = | None -> result ~output_dims:[ upto + 1 ] () | Some l -> result ~output_axes:[ (l, upto + 1) ] () -let range_of_shape ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?batch_dims ?input_dims ?output_dims - ?batch_axes ?input_axes ?output_axes () = +let range_of_shape ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?batch_dims ?input_dims + ?output_dims ?batch_axes ?input_axes ?output_axes () = let f (dims, axes) = Array.of_list @@ Option.value ~default:[] @@ Option.first_some dims @@ Option.map axes ~f:(List.map ~f:snd) in let dims = - Array.concat_map ~f [| (batch_dims, batch_axes); (output_dims, output_axes); (input_dims, input_axes) |] + Array.concat_map ~f + [| (batch_dims, batch_axes); (output_dims, output_axes); (input_dims, input_axes) |] in let batch_dims = Option.first_some batch_dims @@ Option.some_if (Option.is_none batch_axes) [] in let input_dims = Option.first_some input_dims @@ Option.some_if (Option.is_none input_axes) [] in - let output_dims = Option.first_some output_dims @@ Option.some_if (Option.is_none output_axes) [] in + let output_dims = + Option.first_some output_dims @@ Option.some_if (Option.is_none output_axes) [] + in Tensor.term ~label:(("r" ^ Idx.dims_to_string dims) :: label) ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes @@ -226,13 +231,18 @@ let slice ?(label = []) ~grad_spec (batch_idx : Idx.static_symbol) t1 : Tensor.t } in let%cd grad_asn ~v:_ ~g ~t1 ~projections = g1 =+ g in - Tensor.unop ~label:("@|" :: label) ~transpose_op:(Batch_slice batch_idx) ~op_asn ~grad_asn ~grad_spec t1 + Tensor.unop ~label:("@|" :: label) ~transpose_op:(Batch_slice batch_idx) ~op_asn ~grad_asn + ~grad_spec t1 let embed_symbol ?(label = []) static_sym : Tensor.t = let module NTDSL = Initial_NTDSL in let op_asn ~v ~projections = Asgns.Fetch - { array = v; 
fetch_op = Embed_symbol static_sym; dims = lazy (Lazy.force projections).Idx.lhs_dims } + { + array = v; + fetch_op = Embed_symbol static_sym; + dims = lazy (Lazy.force projections).Idx.lhs_dims; + } in let grad_asn ~v:_ ~g:_ ~projections:_ = Asgns.Noop in Tensor.op ~label:("!@" :: label) ~op_asn ~grad_asn ~grad_spec:Prohibit_grad @@ -273,11 +283,14 @@ module TDSL = struct let range_of_shape = range_of_shape ~grad_spec:If_needed let stop_gradient = stop_gradient - (** The input [i] dimensions default to empty. The batch dimensions will be inferred if omitted. [strict] - controls whether [Constant_fill] will try to fit the given values in the tensor and contribute to shape - inference. If it is not provided explicitly, it will be [true] if [b] is omitted, and [false] otherwise. *) + (** The input [i] dimensions default to empty. The batch dimensions will be inferred if omitted. + [strict] controls whether [Constant_fill] will try to fit the given values in the tensor and + contribute to shape inference. If it is not provided explicitly, it will be [true] if [b] is + omitted, and [false] otherwise. *) let init_const ~l ?strict ?b ?(i = []) ~o values = - let strict = match (strict, b) with Some s, _ -> s | None, Some _ -> false | None, None -> true in + let strict = + match (strict, b) with Some s, _ -> s | None, Some _ -> false | None, None -> true + in Tensor.term ~label:[ l ] ~grad_spec:Prohibit_grad ?batch_dims:b ~input_dims:i ~output_dims:o ~init_op:(Constant_fill { values; strict }) () diff --git a/lib/ppx_cd.ml b/lib/ppx_cd.ml index 14a0e968..292aa546 100644 --- a/lib/ppx_cd.ml +++ b/lib/ppx_cd.ml @@ -14,12 +14,15 @@ let ndarray_op ~ident_label ?axis_labels ?label expr = | None, Some label -> [%expr NTDSL.ndarray ~label:[%e label]] | Some axis_labels, Some label -> [%expr - NTDSL.ndarray ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_labels:[%e axis_labels] - ~label:[%e label]] + NTDSL.ndarray + ~label:[%e opt_pat2string_list ~loc ident_label] + ~axis_labels:[%e axis_labels] ~label:[%e label]] in [%expr - [%e op] ~label:[%e opt_pat2string_list ~loc ident_label] ~batch_dims:[%e edims batch_dims] - ~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims] [%e values]] + [%e op] + ~label:[%e opt_pat2string_list ~loc ident_label] + ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims] + ~output_dims:[%e edims output_dims] [%e values]] type expr_type = | Code @@ -55,10 +58,12 @@ let assignment_op expr = | _ -> ( false, Ast_builder.Default.pexp_extension ~loc - @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected an assignment operator, one of: %s %s" - "=+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =: (Arg2), =:+, =:-," - " =:*, =:/, =:**, =:?/ (same with initializing the tensor to the neutral value before the start \ - of the calculation)" ) + @@ Location.error_extensionf ~loc + "ppx_ocannl %%cd: expected an assignment operator, one of: %s %s" + "=+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =: (Arg2), \ + =:+, =:-," + " =:*, =:/, =:**, =:?/ (same with initializing the tensor to the neutral value before \ + the start of the calculation)" ) let binary_op expr = (* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. 
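   As a reminder of the surface syntax these map to (patterns taken from
   lib/operation.ml in this diff): [v =: v1 * v2] plainly assigns a pointwise
   product, while [v =:+ v1 * v2] first initializes [v] to Add's neutral
   element (0.) and then accumulates:
   {[
     let%cd op_asn ~v ~t1 ~t2 ~projections = v =:+ v1 * v2
   ]}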
*) @@ -75,7 +80,8 @@ let binary_op expr = | [%expr ( / )] -> ( Ast_builder.Default.pexp_extension ~loc @@ Location.error_extensionf ~loc - "For clarity, no default compose type for binary `/`, use ~logic:\".\" for pointwise division", + "For clarity, no default compose type for binary `/`, use ~logic:\".\" for pointwise \ + division", [%expr Arrayjit.Ops.Div] ) | [%expr ( ** )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.ToPowOf]) | [%expr ( -?/ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate]) @@ -87,7 +93,8 @@ let binary_op expr = @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a binary operator, one of: %s" "+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2)" ) -let is_binary_op ident = List.mem [ "+"; "-"; "*"; "/"; "**"; "-?/"; "-/>"; "-@>" ] ident ~equal:String.equal +let is_binary_op ident = + List.mem [ "+"; "-"; "*"; "/"; "**"; "-?/"; "-/>"; "-@>" ] ident ~equal:String.equal let unary_op expr = (* This and is_unary_op should stay in sync with Arrayjit.Ops.unop_cd_syntax. *) @@ -147,7 +154,8 @@ let project_p_slot debug loc slot = "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s" debug | Undet -> Ast_builder.Default.pexp_extension ~loc - @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug + @@ Location.error_extensionf ~loc + "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)" let project_p_dims debug loc slot = @@ -161,7 +169,8 @@ let project_p_dims debug loc slot = "ppx_ocannl %%cd: not a valid accumulation/assignment slot filler at %s" debug | Undet -> Ast_builder.Default.pexp_extension ~loc - @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug + @@ Location.error_extensionf ~loc + "ppx_ocannl %%cd: insufficient slot filler information at %s %s" debug "(incorporate one of: v, v1, v2, g, g1, g2, lhs, rhs, rhs1, rhs2)" type array_setup = { @@ -178,7 +187,8 @@ let setup_array ~is_lhs filler_pat (filler_typ, slot, filler) = if is_lhs then [%expr Some [%e tn]] else [%expr Some (Arrayjit.Assignments.Node [%e tn])] in let buffer opt_tn = - if is_lhs then opt_tn else [%expr Option.map [%e opt_tn] ~f:(fun tn -> Arrayjit.Assignments.Node tn)] + if is_lhs then opt_tn + else [%expr Option.map [%e opt_tn] ~f:(fun tn -> Arrayjit.Assignments.Node tn)] in match filler_typ with | Tensor | Unknown -> @@ -191,7 +201,8 @@ let setup_array ~is_lhs filler_pat (filler_typ, slot, filler) = else Arrayjit.Assignments.Noop] in { - binding = Some { var = filler_pat; lazy_bind_to = [%expr lazy [%e filler]]; fwd_code_or_noop }; + binding = + Some { var = filler_pat; lazy_bind_to = [%expr lazy [%e filler]]; fwd_code_or_noop }; filler_typ; slot; array_opt = opt_buffer [%expr [%e t].value]; @@ -212,8 +223,10 @@ let setup_array ~is_lhs filler_pat (filler_typ, slot, filler) = tensor = None; } | Array -> { binding = None; filler_typ; slot; array_opt = opt_buffer filler; tensor = None } - | Value_of_tensor t -> { binding = None; filler_typ; slot; array_opt = opt_buffer filler; tensor = Some t } - | Grad_of_tensor t -> { binding = None; filler_typ; slot; array_opt = buffer filler; tensor = Some t } + | Value_of_tensor t -> + { binding = None; filler_typ; slot; array_opt = opt_buffer filler; tensor = Some t } + | Grad_of_tensor t -> + { binding = None; filler_typ; slot; array_opt = buffer filler; tensor = Some t } | (Merge_value _ | 
Merge_grad _) when is_lhs -> { binding = None; @@ -237,7 +250,8 @@ let setup_array ~is_lhs filler_pat (filler_typ, slot, filler) = binding = None; filler_typ; slot; - array_opt = [%expr Option.map [%e filler] ~f:(fun tn -> Arrayjit.Assignments.Merge_buffer tn)]; + array_opt = + [%expr Option.map [%e filler] ~f:(fun tn -> Arrayjit.Assignments.Merge_buffer tn)]; tensor = Some t; } @@ -249,40 +263,53 @@ let args_for ~loc = function | _ -> ( Ast_builder.Default.pexp_extension ~loc @@ Location.error_extensionf ~loc - "ppx_ocannl %%cd: cannot use `~logic` (infer shapes) for arrays, use tensors or `.value` or \ - `.grad` notation", + "ppx_ocannl %%cd: cannot use `~logic` (infer shapes) for arrays, use tensors or \ + `.value` or `.grad` notation", [%expr false], [%expr false] ) -let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * projections_slot * expression - = +let rec translate ?ident_label ~proj_in_scope (expr : expression) : + expr_type * projections_slot * expression = let loc = expr.pexp_loc in match expr with | { pexp_desc = Pexp_constant (Pconst_float _); _ } -> - (Tensor, Undet, [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] [%e expr]]) + ( Tensor, + Undet, + [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] [%e expr]] ) | { pexp_desc = Pexp_constant (Pconst_integer _); _ } -> ( Tensor, Undet, - [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] (Float.of_int [%e expr])] ) + [%expr + NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] (Float.of_int [%e expr])] ) | [%expr [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }] [%e? { pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] -> - let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in + let axis = + Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) + in ( Tensor, Undet, - [%expr NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] [%e f]] ) + [%expr + NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] [%e f]] + ) | [%expr [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }] [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] -> - let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in + let axis = + Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) + in ( Tensor, Undet, [%expr - NTDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] + NTDSL.number + ~label:[%e opt_pat2string_list ~loc ident_label] + ~axis_label:[%e axis] (Float.of_int [%e i])] ) - (* | [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident "merge_buffer"; _ }; _ }] [%e? for_tn]] -> let - typ1, slot1, e1 = translate ~proj_in_scope for_tn in ( typ1, slot1, [%expr ] ) *) - | { pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } -> + (* | [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident "merge_buffer"; _ }; _ }] [%e? 
+ for_tn]] -> let typ1, slot1, e1 = translate ~proj_in_scope for_tn in ( typ1, slot1, [%expr + ] ) *) + | { pexp_desc = Pexp_array _; _ } + | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } -> (Tensor, Undet, ndarray_op ~ident_label expr) | { pexp_desc = Pexp_ident { txt = Lident ("v" | "lhs"); _ }; _ } -> (Array, LHS, expr) | { pexp_desc = Pexp_ident { txt = Lident "g"; _ }; _ } -> (Array, LHS, expr) @@ -290,12 +317,16 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * | { pexp_desc = Pexp_ident { txt = Lident "t1"; _ }; _ } -> (Tensor, RHS1, expr) | { pexp_desc = Pexp_ident { txt = Lident "v1"; _ }; _ } -> (Array, RHS1, [%expr t1.Tensor.value]) | { pexp_desc = Pexp_ident { txt = Lident "g1"; _ }; _ } -> - (Grad_of_tensor [%expr t1], RHS1, [%expr Option.map t1.Tensor.diff ~f:(fun d -> d.Tensor.grad)]) + ( Grad_of_tensor [%expr t1], + RHS1, + [%expr Option.map t1.Tensor.diff ~f:(fun d -> d.Tensor.grad)] ) | { pexp_desc = Pexp_ident { txt = Lident "rhs2"; _ }; _ } -> (Array, RHS2, expr) | { pexp_desc = Pexp_ident { txt = Lident "t2"; _ }; _ } -> (Tensor, RHS2, expr) | { pexp_desc = Pexp_ident { txt = Lident "v2"; _ }; _ } -> (Array, RHS2, [%expr t2.Tensor.value]) | { pexp_desc = Pexp_ident { txt = Lident "g2"; _ }; _ } -> - (Grad_of_tensor [%expr t2], RHS2, [%expr Option.map t2.Tensor.diff ~f:(fun d -> d.Tensor.grad)]) + ( Grad_of_tensor [%expr t2], + RHS2, + [%expr Option.map t2.Tensor.diff ~f:(fun d -> d.Tensor.grad)] ) | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident -> (Tensor, Undet, expr) | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] -> @@ -304,13 +335,17 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * let _typ1, slot1, e1 = translate ~proj_in_scope expr1 in ( Tensor, slot1, - [%expr NTDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] (Float.of_int [%e i])] - ) + [%expr + NTDSL.O.( **. ) + ~label:[%e opt_pat2string_list ~loc ident_label] + [%e e1] + (Float.of_int [%e i])] ) | [%expr [%e? expr1] **. [%e? expr2]] -> let _typ1, slot1, e1 = translate ~proj_in_scope expr1 in ( Tensor, slot1, - [%expr NTDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]] ) + [%expr NTDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]] + ) | [%expr [%e? expr1] *+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec] [%e? expr2]] @@ -318,28 +353,36 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * let _typ1, slot1, expr1 = translate ~proj_in_scope expr1 in let _typ2, slot2, expr2 = translate ~proj_in_scope expr2 in let slot = - Option.value ~default:Undet @@ List.find ~f:(function Undet -> false | _ -> true) [ slot1; slot2 ] + Option.value ~default:Undet + @@ List.find ~f:(function Undet -> false | _ -> true) [ slot1; slot2 ] in ( Tensor, slot, - [%expr NTDSL.einsum ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e expr1] [%e expr2]] - ) - | [%expr [%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]] + [%expr + NTDSL.einsum + ~label:[%e opt_pat2string_list ~loc ident_label] + [%e spec] [%e expr1] [%e expr2]] ) + | [%expr + [%e? expr1] ++ [%e? 
{ pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]] when String.contains spec_str '>' -> let _typ1, slot1, expr1 = translate ~proj_in_scope expr1 in ( Tensor, slot1, - [%expr NTDSL.einsum1 ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e expr1]] ) + [%expr NTDSL.einsum1 ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e expr1]] + ) | [%expr [%e? expr1].grad] -> ( let typ1, slot1, expr1 = translate ?ident_label ~proj_in_scope expr1 in match typ1 with | Unknown | Tensor -> - (Grad_of_tensor expr1, slot1, [%expr Option.map [%e expr1].Tensor.diff ~f:(fun d -> d.Tensor.grad)]) + ( Grad_of_tensor expr1, + slot1, + [%expr Option.map [%e expr1].Tensor.diff ~f:(fun d -> d.Tensor.grad)] ) | Merge_value _ -> ( Merge_grad expr1, slot1, Ast_builder.Default.pexp_extension ~loc - @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: write .grad.merge instead of .merge.grad" ) + @@ Location.error_extensionf ~loc + "ppx_ocannl %%cd: write .grad.merge instead of .merge.grad" ) | Code | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_grad _ -> ( Array, slot1, @@ -369,16 +412,24 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * slot1, Ast_builder.Default.pexp_extension ~loc @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: repeated .merge not allowed" )) - | [%expr [%e? accu_op] [%e? lhs] ([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] -> + | [%expr + [%e? accu_op] [%e? lhs] ([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] + -> let initialize_neutral, accu_op = assignment_op accu_op in let setup_l = - setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope:true lhs + setup_array ~is_lhs:true [%pat? nondiff___lhs] + @@ translate ?ident_label ~proj_in_scope:true lhs in let _, bin_op = binary_op bin_op in - let setup_r1 = setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope:true rhs1 in - let setup_r2 = setup_array ~is_lhs:false [%pat? nondiff___rhs2] @@ translate ~proj_in_scope:true rhs2 in + let setup_r1 = + setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope:true rhs1 + in + let setup_r2 = + setup_array ~is_lhs:false [%pat? nondiff___rhs2] @@ translate ~proj_in_scope:true rhs2 + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in - (* TODO: might be better to treat missing [rhs1, rhs2] as zeros rather than eliding the code. *) + (* TODO: might be better to treat missing [rhs1, rhs2] as zeros rather than eliding the + code. *) let body = [%expr Option.value ~default:Arrayjit.Assignments.Noop @@ -395,19 +446,25 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * projections = [%e projections]; })] in - let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in + let setups = + List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] + in with_forward_args setups body | [%expr [%e? accu_op] [%e? lhs] (([%e? un_op] [%e? rhs]) ~projections:[%e? projections])] | [%expr [%e? accu_op] [%e? lhs] ([%e? un_op] ([%e? rhs] ~projections:[%e? projections]))] -> (* Handle both un_op priority levels -- where application binds tighter and less tight. *) let initialize_neutral, accu_op = assignment_op accu_op in - (* FIXME: I think this ignores the slot information here! 
Just assuming [projections] is as-should-be, - but that's not consistent with omitting the projections arg (assuming it comes from the context). *) + (* FIXME: I think this ignores the slot information here! Just assuming [projections] is + as-should-be, but that's not consistent with omitting the projections arg (assuming it + comes from the context). *) let setup_l = - setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope:true lhs + setup_array ~is_lhs:true [%pat? nondiff___lhs] + @@ translate ?ident_label ~proj_in_scope:true lhs in let _, un_op = unary_op un_op in - let setup_r = setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope:true rhs in + let setup_r = + setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope:true rhs + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in (* TODO: might be better to treat missing [rhs] as zeros rather than eliding the code. *) let body = @@ -429,9 +486,12 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * | [%expr [%e? accu_op] [%e? lhs] ([%e? rhs] ~projections:[%e? projections])] -> let initialize_neutral, accu_op = assignment_op accu_op in let setup_l = - setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope:true lhs + setup_array ~is_lhs:true [%pat? nondiff___lhs] + @@ translate ?ident_label ~proj_in_scope:true lhs + in + let setup_r = + setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope:true rhs in - let setup_r = setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope:true rhs in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let body = [%expr @@ -455,7 +515,9 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * ([%e? bin_op] [%e? rhs1] ([%e? rhs2] - ~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))] -> + ~logic: + [%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))] + -> let logic = let loc = s_loc in if String.equal spec "." then [%expr Shape.Pointwise_bin] @@ -467,20 +529,26 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in let _, bin_op = binary_op bin_op in - let setup_r1 = setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 in - let setup_r2 = setup_array ~is_lhs:false [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 in + let setup_r1 = + setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 + in + let setup_r2 = + setup_array ~is_lhs:false [%pat? 
nondiff___rhs2] @@ translate ~proj_in_scope rhs2 + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in let body = [%expr - Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr] - ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr] ~rhs1_is_grad:[%e rhs1_is_grad] - ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr] ~rhs2_is_grad:[%e rhs2_is_grad] - ~rhs2_is_merge:[%e rhs2_is_merge] ~logic:[%e logic]] + Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] + ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr] + ~rhs1_is_grad:[%e rhs1_is_grad] ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr] + ~rhs2_is_grad:[%e rhs2_is_grad] ~rhs2_is_merge:[%e rhs2_is_merge] ~logic:[%e logic]] + in + let setups = + List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in - let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in with_forward_args setups body | [%expr [%e? accu_op] @@ -491,7 +559,9 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * [%e? accu_op] [%e? lhs] ([%e? un_op] - ([%e? rhs] ~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))] + ([%e? rhs] + ~logic: + [%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ } as logic]))] -> (* Handle both un_op priority levels -- where application binds tighter and less tight. *) let logic = @@ -505,15 +575,17 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in let _, un_op = unary_op un_op in - let setup_r = setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs in + let setup_r = + setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in let body = [%expr - Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr] - ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr] ~rhs_is_grad:[%e rhs_is_grad] - ~rhs_is_merge:[%e rhs_is_merge] ~logic:[%e logic]] + Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] + ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr] + ~rhs_is_grad:[%e rhs_is_grad] ~rhs_is_merge:[%e rhs_is_merge] ~logic:[%e logic]] in let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in with_forward_args setups body @@ -529,8 +601,12 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in let _, bin_op = binary_op bin_op in - let setup_r1 = setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 in - let setup_r2 = setup_array ~is_lhs:false [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 in + let setup_r1 = + setup_array ~is_lhs:false [%pat? 
nondiff___rhs1] @@ translate ~proj_in_scope rhs1 + in + let setup_r2 = + setup_array ~is_lhs:false [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let projections = let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in @@ -554,7 +630,9 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * { p.debug_info with trace = - ( "ppx_cd " ^ [%e string_expr ~loc:accu_loc accu_ident] ^ " " + ( "ppx_cd " + ^ [%e string_expr ~loc:accu_loc accu_ident] + ^ " " ^ [%e string_expr ~loc:op_loc binop_ident], Arrayjit.Indexing.unique_debug_id () ) :: p.debug_info.trace; @@ -577,19 +655,24 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * projections = [%e projections]; })] in - let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in + let setups = + List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] + in with_forward_args setups body | [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; loc = accu_loc }; _ } as accu_op] [%e? lhs] - ([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; loc = op_loc }; _ } as un_op] [%e? rhs])] + ([%e? { pexp_desc = Pexp_ident { txt = Lident unop_ident; loc = op_loc }; _ } as un_op] + [%e? rhs])] when is_assignment accu_ident && is_unary_op unop_ident && proj_in_scope -> let initialize_neutral, accu_op = assignment_op accu_op in let setup_l = setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in let _, un_op = unary_op un_op in - let setup_r1 = setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs in + let setup_r1 = + setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let projections = let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in @@ -611,7 +694,9 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * { p.debug_info with trace = - ( "ppx_cd " ^ [%e string_expr ~loc:accu_loc accu_ident] ^ " " + ( "ppx_cd " + ^ [%e string_expr ~loc:accu_loc accu_ident] + ^ " " ^ [%e string_expr ~loc:op_loc unop_ident], Arrayjit.Indexing.unique_debug_id () ) :: p.debug_info.trace; @@ -643,7 +728,9 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * let setup_l = setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in - let setup_r1 = setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs in + let setup_r1 = + setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let projections = let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in @@ -690,27 +777,35 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * | [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op] [%e? lhs] - ([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op] [%e? rhs1] [%e? rhs2])] + ([%e? { pexp_desc = Pexp_ident { txt = Lident binop_ident; _ }; _ } as bin_op] + [%e? rhs1] + [%e? 
rhs2])] when is_assignment accu_ident && is_binary_op binop_ident -> let initialize_neutral, accu_op = assignment_op accu_op in let setup_l = setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in let logic, bin_op = binary_op bin_op in - let setup_r1 = setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 in - let setup_r2 = setup_array ~is_lhs:false [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 in + let setup_r1 = + setup_array ~is_lhs:false [%pat? nondiff___rhs1] @@ translate ~proj_in_scope rhs1 + in + let setup_r2 = + setup_array ~is_lhs:false [%pat? nondiff___rhs2] @@ translate ~proj_in_scope rhs2 + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in let body = [%expr - Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr] - ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr] ~rhs1_is_grad:[%e rhs1_is_grad] - ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr] ~rhs2_is_grad:[%e rhs2_is_grad] - ~rhs2_is_merge:[%e rhs2_is_merge] ~logic:[%e logic]] + Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] + ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr] + ~rhs1_is_grad:[%e rhs1_is_grad] ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr] + ~rhs2_is_grad:[%e rhs2_is_grad] ~rhs2_is_merge:[%e rhs2_is_merge] ~logic:[%e logic]] + in + let setups = + List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in - let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r1; setup_r2 ] in with_forward_args setups body | [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op] @@ -722,32 +817,39 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in let logic, un_op = unary_op un_op in - let setup_r = setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs in + let setup_r = + setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in let body = [%expr - Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr] - ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr] ~rhs_is_grad:[%e rhs_is_grad] - ~rhs_is_merge:[%e rhs_is_merge] ~logic:[%e logic]] + Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] + ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr] + ~rhs_is_grad:[%e rhs_is_grad] ~rhs_is_merge:[%e rhs_is_merge] ~logic:[%e logic]] in let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in with_forward_args setups body - | [%expr [%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op] [%e? lhs] [%e? rhs]] + | [%expr + [%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } as accu_op] + [%e? lhs] + [%e? 
rhs]] when is_assignment op_ident -> let initialize_neutral, accu_op = assignment_op accu_op in let setup_l = setup_array ~is_lhs:true [%pat? nondiff___lhs] @@ translate ?ident_label ~proj_in_scope lhs in - let setup_r = setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs in + let setup_r = + setup_array ~is_lhs:false [%pat? nondiff___rhs] @@ translate ~proj_in_scope rhs + in let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in let body = [%expr - Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] ~t:[%e t_expr] - ~lhs_is_grad:[%e lhs_is_grad] ~op:Arrayjit.Ops.Identity ~t1:[%e t1_expr] + Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op] + ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:Arrayjit.Ops.Identity ~t1:[%e t1_expr] ~rhs_is_grad:[%e rhs_is_grad] ~rhs_is_merge:[%e rhs_is_merge] ~logic:Shape.Pointwise_un] in let setups = List.filter_map ~f:(fun setup -> setup.binding) [ setup_l; setup_r ] in @@ -765,7 +867,8 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * let typ1, slot1, expr1 = translate ?ident_label ~proj_in_scope expr1 in let _typ2, slot2, expr2 = translate ~proj_in_scope expr2 in let slot = - Option.value ~default:Undet @@ List.find ~f:(function Undet -> false | _ -> true) [ slot1; slot2 ] + Option.value ~default:Undet + @@ List.find ~f:(function Undet -> false | _ -> true) [ slot1; slot2 ] in (typ1, slot, [%expr [%e expr1] [%e expr2]]) | { pexp_desc = Pexp_fun ((arg_label : arg_label), arg, opt_val, body); _ } as expr -> @@ -821,7 +924,8 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * slot, [%expr Arrayjit.Assignments.Block_comment - (String.concat_array ~sep:" " [%e Ast_helper.Exp.array ~loc:pexp_loc elements], [%e body])] ) + ( String.concat_array ~sep:" " [%e Ast_helper.Exp.array ~loc:pexp_loc elements], + [%e body] )] ) | [%expr [%e? expr1]; [%e? expr2]] -> @@ -833,7 +937,8 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * let typ3, slot3, expr3 = translate ?ident_label ~proj_in_scope expr3 in let typ = if is_unknown typ2 then typ3 else typ2 in let slot = - Option.value ~default:Undet @@ List.find ~f:(function Undet -> false | _ -> true) [ slot2; slot3 ] + Option.value ~default:Undet + @@ List.find ~f:(function Undet -> false | _ -> true) [ slot2; slot3 ] in (typ, slot, [%expr if [%e expr1] then [%e expr2] else [%e expr3]]) | [%expr if [%e? expr1] then [%e? expr2]] -> @@ -847,16 +952,19 @@ let rec translate ?ident_label ~proj_in_scope (expr : expression) : expr_type * (typ, slot, { c with pc_rhs })) in let typ = Option.value ~default:Unknown @@ List.find typs ~f:(Fn.non is_unknown) in - let slot = Option.value ~default:Undet @@ List.find ~f:(function Undet -> false | _ -> true) slots in + let slot = + Option.value ~default:Undet @@ List.find ~f:(function Undet -> false | _ -> true) slots + in (typ, slot, { expr with pexp_desc = Pexp_match (expr1, cases) }) | { pexp_desc = Pexp_let (_recflag, _bindings, _body); _ } -> (* TODO(80): to properly support local bindings, we need to collect the type environment. 
*) ( Unknown, Undet, Ast_builder.Default.pexp_extension ~loc - @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: let-in: local let-bindings not implemented yet" ) - (* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=translate binding.pvb_expr}) - in {expr with pexp_desc=Pexp_let (recflag, bindings, translate body)} *) + @@ Location.error_extensionf ~loc + "ppx_ocannl %%cd: let-in: local let-bindings not implemented yet" ) + (* let bindings = List.map bindings ~f:(fun binding -> {binding with pvb_expr=translate + binding.pvb_expr}) in {expr with pexp_desc=Pexp_let (recflag, bindings, translate body)} *) | { pexp_desc = Pexp_open (decl, body); _ } -> let typ, slot, body = translate ?ident_label ~proj_in_scope body in (typ, slot, { expr with pexp_desc = Pexp_open (decl, body) }) diff --git a/lib/ppx_ocannl.ml b/lib/ppx_ocannl.ml index e57b41d9..01a4fd2d 100644 --- a/lib/ppx_ocannl.ml +++ b/lib/ppx_ocannl.ml @@ -7,13 +7,17 @@ let rules = @@ Extension.declare "cd" Extension.Context.expression Ast_pattern.(single_expr_payload __) @@ Ppx_cd.expr_expander; Ppxlib.Context_free.Rule.extension - @@ Extension.declare "cd" Extension.Context.structure_item Ast_pattern.(pstr __) Ppx_cd.str_expander; + @@ Extension.declare "cd" Extension.Context.structure_item + Ast_pattern.(pstr __) + Ppx_cd.str_expander; Ppxlib.Context_free.Rule.extension @@ Extension.declare "op" Extension.Context.expression Ast_pattern.(single_expr_payload __) Ppx_op.expr_expander; Ppxlib.Context_free.Rule.extension - @@ Extension.declare "op" Extension.Context.structure_item Ast_pattern.(pstr __) Ppx_op.str_expander; + @@ Extension.declare "op" Extension.Context.structure_item + Ast_pattern.(pstr __) + Ppx_op.str_expander; ] let () = Driver.register_transformation ~rules "ppx_ocannl" diff --git a/lib/ppx_op.ml b/lib/ppx_op.ml index 5f2b679e..09bb9c35 100644 --- a/lib/ppx_op.ml +++ b/lib/ppx_op.ml @@ -13,8 +13,10 @@ let ndarray_op ?ident_label ?axis_labels expr = | Some axis_labels -> [%expr TDSL.ndarray ~axis_labels:[%e axis_labels]] in [%expr - [%e op] ~label:[%e opt_pat2string_list ~loc ident_label] ~batch_dims:[%e edims batch_dims] - ~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims] [%e values]] + [%e op] + ~label:[%e opt_pat2string_list ~loc ident_label] + ~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims] + ~output_dims:[%e edims output_dims] [%e values]] let make_vb ?value ~loc ~str_loc ~ident string = let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in @@ -49,8 +51,8 @@ let make_vb_nd ~loc ~str_loc ?axis_labels ~ident ~init_nd string = | Some axis_labels -> [%expr TDSL.param ~axis_labels:[%e axis_labels]] in [%expr - [%e op] ~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims] ~values:[%e values] - [%e string]] + [%e op] ~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims] + ~values:[%e values] [%e string]] in let vb = Ast_helper.Vb.mk ~loc pat v in (pat, vb) @@ -65,16 +67,24 @@ let rec translate ?ident_label expr = | [%expr [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }] [%e? 
{ pexp_desc = Pexp_constant (Pconst_float _); _ } as f]] -> - let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in + let axis = + Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) + in ( no_vbs, - [%expr TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] [%e f]] ) + [%expr + TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] [%e f]] + ) | [%expr [%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }] [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] -> - let axis = Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) in + let axis = + Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None)) + in ( no_vbs, [%expr - TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] + TDSL.number + ~label:[%e opt_pat2string_list ~loc ident_label] + ~axis_label:[%e axis] (Float.of_int [%e i])] ) | [%expr [%e? expr1] @@ -83,8 +93,10 @@ let rec translate ?ident_label expr = let vbs1, e1 = translate expr1 in let vbs2, e2 = translate expr2 in ( reduce_vbss [ vbs1; vbs2 ], - [%expr TDSL.einsum ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1] [%e e2]] ) - | [%expr [%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]] + [%expr + TDSL.einsum ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1] [%e e2]] ) + | [%expr + [%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]] when String.contains spec_str '>' -> let vbs1, e1 = translate expr1 in (vbs1, [%expr TDSL.einsum1 ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1]]) @@ -98,8 +110,8 @@ let rec translate ?ident_label expr = | [%expr [%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s] [%e? - ({ pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ }) - as init_nd]] -> + ( { pexp_desc = Pexp_array _; _ } + | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ) as init_nd]] -> let pat, vb = make_vb_nd ~loc ~str_loc ~ident ~init_nd s in (Map.singleton (module String) ident vb, pat2expr pat) | [%expr @@ -110,17 +122,24 @@ let rec translate ?ident_label expr = | { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } -> let pat, vb = make_vb ~loc ~str_loc ~ident expr in (Map.singleton (module String) ident vb, pat2expr pat) - | { pexp_desc = Pexp_array _; _ } | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } -> + | { pexp_desc = Pexp_array _; _ } + | { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } -> (no_vbs, ndarray_op ?ident_label expr) | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] -> - (* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *) + (* We need to hardcode these two patterns to prevent the numbers from being converted to + tensors. *) let vbs, e1 = translate expr1 in ( vbs, - [%expr TDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] (Float.of_int [%e i])] - ) + [%expr + TDSL.O.( **. ) + ~label:[%e opt_pat2string_list ~loc ident_label] + [%e e1] + (Float.of_int [%e i])] ) | [%expr [%e? expr1] **. [%e? expr2]] -> let vbs, e1 = translate expr1 in - (vbs, [%expr TDSL.O.( **. 
) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]]) + ( vbs, + [%expr TDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]] + ) | [%expr [%e? expr1] [%e? expr2] [%e? expr3]] -> let vbs1, e1 = translate ?ident_label expr1 in let vbs2, e2 = translate expr2 in @@ -133,7 +152,8 @@ let rec translate ?ident_label expr = | [%expr fun ~config [%p? pat1] [%p? pat2] -> [%e? body]] -> (* TODO(#38): generalize config to any number of labeled arguments with any labels. *) let vbs, body = translate ?ident_label body in - (no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs [%expr fun [%p pat1] [%p pat2] -> [%e body]]]]) + ( no_vbs, + [%expr fun ~config -> [%e let_opt ~loc vbs [%expr fun [%p pat1] [%p pat2] -> [%e body]]]] ) | [%expr fun ~config [%p? pat] -> [%e? body]] -> (* TODO(#38): generalize config to any number of labeled arguments with any labels. *) let vbs, body = translate ?ident_label body in @@ -266,7 +286,8 @@ let translate_str ({ pstr_desc; pstr_loc = loc; _ } as str) = pvb_expr = [%expr let open! TDSL.O in - [%e if is_unused then [%expr Tensor.with_unchanged_roots ~f:(fun () -> [%e v])] else v]]; + [%e + if is_unused then [%expr Tensor.with_unchanged_roots ~f:(fun () -> [%e v])] else v]]; } in { str with pstr_desc = Pstr_value (recf, List.map bindings ~f) } diff --git a/lib/ppx_shared.ml b/lib/ppx_shared.ml index 21de8c4d..7751c9b1 100644 --- a/lib/ppx_shared.ml +++ b/lib/ppx_shared.ml @@ -30,8 +30,14 @@ let pat2string pat = in string_expr ~loc:pat.ppat_loc @@ loop pat -let opt_pat2string ~loc = function None -> [%expr None] | Some pat -> [%expr Some [%e pat2string pat]] -let opt_pat2string_list ~loc = function None -> [%expr []] | Some pat -> [%expr [ [%e pat2string pat] ]] +let opt_pat2string ~loc = function + | None -> [%expr None] + | Some pat -> [%expr Some [%e pat2string pat]] + +let opt_pat2string_list ~loc = function + | None -> [%expr []] + | Some pat -> [%expr [ [%e pat2string pat] ]] + let opt_expr ~loc = function None -> [%expr None] | Some expr -> [%expr Some [%e expr]] let rec pat2expr pat = @@ -39,7 +45,8 @@ let rec pat2expr pat = let loc = pat.ppat_loc in match pat.ppat_desc with | Ppat_constraint (pat', typ) -> Ast.pexp_constraint ~loc (pat2expr pat') typ - | Ppat_alias (_, ident) | Ppat_var ident -> Ast.pexp_ident ~loc { ident with txt = Lident ident.txt } + | Ppat_alias (_, ident) | Ppat_var ident -> + Ast.pexp_ident ~loc { ident with txt = Lident ident.txt } | Ppat_variant (ident, e_opt) -> Ast.pexp_variant ~loc ident @@ Option.map e_opt ~f:pat2expr | Ppat_constant c -> Ast.pexp_constant ~loc c | Ppat_construct (c, None) -> Ast.pexp_construct ~loc c None diff --git a/lib/row.ml b/lib/row.ml index ebee0d4b..39778fd3 100644 --- a/lib/row.ml +++ b/lib/row.ml @@ -109,24 +109,35 @@ type row_constraint = (** The row or remainder of a row, inclusive of the further row spec, has this many elements. *) [@@deriving equal, hash, compare, sexp, variants] -(** An entry implements inequalities [cur >= v >= subr] and/or an equality [v = solved]. [cur] and [subr] must - be sorted using the [@@deriving compare] comparison. *) +(** An entry implements inequalities [cur >= v >= subr] and/or an equality [v = solved]. [cur] and + [subr] must be sorted using the [@@deriving compare] comparison. 
 *)
 type dim_entry =
   | Solved_dim of dim
-  | Bounds_dim of { cur : dim_var list; subr : dim_var list; lub : dim option; constr : dim_constraint }
+  | Bounds_dim of {
+      cur : dim_var list;
+      subr : dim_var list;
+      lub : dim option;
+      constr : dim_constraint;
+    }
 [@@deriving sexp]

 type row_entry =
   | Solved_row of t
-  | Bounds_row of { cur : row_var list; subr : row_var list; lub : t option; constr : row_constraint }
+  | Bounds_row of {
+      cur : row_var list;
+      subr : row_var list;
+      lub : t option;
+      constr : row_constraint;
+    }
 [@@deriving sexp]

 type dim_env = dim_entry Map.M(Dim_var).t [@@deriving sexp]
 type row_env = row_entry Map.M(Row_var).t [@@deriving sexp]
 type environment = { dim_env : dim_env; row_env : row_env } [@@deriving sexp]

-(** The environment is only in resolved wrt. variables that are solved: [v -> Solved ...] do not appear
-    elsewhere in the environment. In particular, per-dim and per-row constraints might not have been applied. *)
+(** The environment is resolved only wrt. variables that are solved: [v -> Solved ...] do not
+    appear elsewhere in the environment. In particular, per-dim and per-row constraints might not
+    have been applied. *)

 type constraint_ =
   | Dim_eq of { d1 : dim; d2 : dim }
@@ -139,7 +150,8 @@ type constraint_ =
   | Terminal_row of t
 [@@deriving compare, equal, sexp, variants]

-type stage = Stage1 | Stage2 | Stage3 | Stage4 | Stage5 | Stage6 | Stage7 [@@deriving sexp, equal, compare]
+type stage = Stage1 | Stage2 | Stage3 | Stage4 | Stage5 | Stage6 | Stage7
+[@@deriving sexp, equal, compare]

 let is_stage2_up = function Stage1 -> false | _ -> true
 let is_stage3_up = function Stage1 | Stage2 -> false | _ -> true
@@ -151,19 +163,27 @@ let is_stage7 = function Stage7 -> true | _ -> false
 module Idx = Arrayjit.Indexing

 type error_trace = ..
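
For orientation in these row.ml hunks: a Bounds_dim entry keeps the still-open inequalities around a variable v, namely cur >= v >= subr, together with an optional least upper bound and a pending per-dim constraint. Below is a minimal, self-contained OCaml sketch of that reading; the types and the constraints_of helper are illustrative stand-ins (no constr field, no dimension labels), not the library's actual definitions:

type dim_var = int
type dim = Dim of int | Var of dim_var

type dim_entry =
  | Solved_dim of dim
  | Bounds_dim of { cur : dim_var list; subr : dim_var list; lub : dim option }

(* Spell out the inequalities an entry for [v] encodes: a triple
   (a, `Ge, b) reads "a >= b"; a solved entry is the equality v = d,
   stated here as the two inequalities v >= d and d >= v. *)
let constraints_of (v : dim_var) : dim_entry -> (dim * [ `Ge ] * dim) list = function
  | Solved_dim d -> [ (Var v, `Ge, d); (d, `Ge, Var v) ]
  | Bounds_dim { cur; subr; lub } ->
      List.map (fun c -> (Var c, `Ge, Var v)) cur
      @ List.map (fun s -> (Var v, `Ge, Var s)) subr
      @ (match lub with Some d -> [ (d, `Ge, Var v) ] | None -> [])
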
-type error_trace += Row_mismatch of t list | Dim_mismatch of dim list | Index_mismatch of Idx.axis_index list + +type error_trace += + | Row_mismatch of t list + | Dim_mismatch of dim list + | Index_mismatch of Idx.axis_index list let sexp_of_error_trace = function | Row_mismatch rs -> Sexp.List (Sexp.Atom "Row_mismatch" :: List.map rs ~f:sexp_of_t) | Dim_mismatch ds -> Sexp.List (Sexp.Atom "Dim_mismatch" :: List.map ds ~f:sexp_of_dim) - | Index_mismatch idcs -> Sexp.List (Sexp.Atom "Index_mismatch" :: List.map idcs ~f:Idx.sexp_of_axis_index) + | Index_mismatch idcs -> + Sexp.List (Sexp.Atom "Index_mismatch" :: List.map idcs ~f:Idx.sexp_of_axis_index) | _ -> Sexp.Atom "" exception Shape_error of string * error_trace list [@@deriving sexp_of] type source = Direct | Equation | Cur | Subr [@@deriving equal, sexp] -let dim_to_int_exn = function Dim { d; _ } -> d | Var _ -> invalid_arg "dim_to_int: dim still unknown" +let dim_to_int_exn = function + | Dim { d; _ } -> d + | Var _ -> invalid_arg "dim_to_int: dim still unknown" + let s_dim_one v ~value ~in_ = match in_ with Var v2 when equal_dim_var v v2 -> value | _ -> in_ (* For future flexibility *) @@ -180,24 +200,31 @@ let row_conjunction ?(id = phantom_row_id) constr1 constr2 = match (constr1, constr2) with | Unconstrained, _ -> Some ([], constr2) | _, Unconstrained -> Some ([], constr1) - | Total_elems { nominator = n1; divided_by = vars1 }, Total_elems { nominator = n2; divided_by = vars2 } + | ( Total_elems { nominator = n1; divided_by = vars1 }, + Total_elems { nominator = n2; divided_by = vars2 } ) when [%equal: Set.M(Dim_var).t] vars1 vars2 -> if n1 <> n2 then elems_mismatch n1 n2 else Some ([], constr2) - | Total_elems { nominator = n1; divided_by = vars1 }, Total_elems { nominator = n2; divided_by = vars2 } -> + | ( Total_elems { nominator = n1; divided_by = vars1 }, + Total_elems { nominator = n2; divided_by = vars2 } ) -> let shared = Set.inter vars1 vars2 |> Set.to_list in let extras ~keep_constr1 = (* If we keep constr1, then it has fewer divided_by, i.e. n1 > n2. 
*) let nominator = if keep_constr1 then n1 / n2 else n2 / n1 in if nominator <= 0 then elems_mismatch n1 n2 - else if nominator = 1 then List.map shared ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () }) + else if nominator = 1 then + List.map shared ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () }) else if List.is_empty shared then [] else let r = { dims = List.map shared ~f:(fun v -> Var v); bcast = Broadcastable; id } in - [ Row_constr { r; constr = Total_elems { nominator; divided_by = Set.empty (module Dim_var) } } ] + [ + Row_constr + { r; constr = Total_elems { nominator; divided_by = Set.empty (module Dim_var) } }; + ] in let subsum = Set.symmetric_diff vars1 vars2 in if Sequence.for_all ~f:Either.is_first subsum then Some (extras ~keep_constr1:false, constr2) - else if Sequence.for_all ~f:Either.is_second subsum then Some (extras ~keep_constr1:true, constr1) + else if Sequence.for_all ~f:Either.is_second subsum then + Some (extras ~keep_constr1:true, constr1) else None let apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim) (constr : dim_constraint) @@ -208,7 +235,8 @@ let apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim) (constr if d < d_min then raise @@ Shape_error - ("At_least_dim constraint failed, expected " ^ Int.to_string d_min, [ Dim_mismatch [ dim ] ]) + ( "At_least_dim constraint failed, expected " ^ Int.to_string d_min, + [ Dim_mismatch [ dim ] ] ) else ([], constr) | Var v, _ -> ( match Map.find env.dim_env v with @@ -222,11 +250,12 @@ let apply_dim_constraint ~(source : source) ~(stage : stage) (dim : dim) (constr | _, Unconstrained_dim -> ([], constr) in match (dim, constr, stage) with - | Var _, At_least_dim d, Stage4 -> (Dim_eq { d1 = dim; d2 = get_dim ~d () } :: extras, Unconstrained_dim) + | Var _, At_least_dim d, Stage4 -> + (Dim_eq { d1 = dim; d2 = get_dim ~d () } :: extras, Unconstrained_dim) | _ -> (extras, constr) -let%debug_sexp reduce_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims : dim list) : - row_constraint = +let%debug_sexp reduce_row_constraint (constr : row_constraint) ~(beg_dims : dim list) + ~(dims : dim list) : row_constraint = match constr with | Total_elems { nominator; divided_by } -> let ds, (vars : dim_var list) = @@ -249,8 +278,8 @@ let%debug_sexp reduce_row_constraint (constr : row_constraint) ~(beg_dims : dim | Unconstrained -> Unconstrained (* Inverts what [reduce_row_constraint] would do. 
*) -let%debug_sexp _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) ~(dims : dim list) : - row_constraint = +let%debug_sexp _lift_row_constraint (constr : row_constraint) ~(beg_dims : dim list) + ~(dims : dim list) : row_constraint = match constr with | Total_elems { nominator; divided_by } -> let ds, vars = @@ -288,7 +317,12 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons Map.set env.row_env ~key:v ~data: (Bounds_row - { constr = reduce constr ~beg_dims ~dims; cur = []; subr = []; lub = None }); + { + constr = reduce constr ~beg_dims ~dims; + cur = []; + subr = []; + lub = None; + }); }, true, false ) @@ -313,7 +347,8 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons constr, { env with - row_env = Map.set env.row_env ~key:v ~data:(Bounds_row { bounds with constr }); + row_env = + Map.set env.row_env ~key:v ~data:(Bounds_row { bounds with constr }); }, true, true ))) @@ -324,14 +359,17 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons | { dims; bcast = Broadcastable; _ }, Total_elems { nominator; divided_by } when Set.length divided_by <= 1 -> ( let (ds : int list), (vars : dim_var list) = - List.partition_map dims ~f:(function Dim { d; _ } -> Either.First d | Var v -> Either.Second v) + List.partition_map dims ~f:(function + | Dim { d; _ } -> Either.First d + | Var v -> Either.Second v) in let d : int = List.fold ds ~init:1 ~f:( * ) in let nominator : int = nominator / d in if nominator = 0 then raise @@ Shape_error - ("apply_row_constraint: Total_elems constraint failed, shape is too big", [ Dim_mismatch dims ]); + ( "apply_row_constraint: Total_elems constraint failed, shape is too big", + [ Dim_mismatch dims ] ); match (vars, Set.elements divided_by) with | [], [] -> if nominator = 1 then (extras, env) @@ -340,9 +378,11 @@ let apply_row_constraint ~stage:_ (r : row) (constr : row_constraint) env : cons @@ Shape_error ( "apply_row_constraint: Total_elems constraint failed, shape is too small", [ Row_mismatch [ r ] ] ) - | [ v ], [] | [], [ v ] -> (Dim_eq { d1 = Var v; d2 = get_dim ~d:nominator () } :: extras, env) + | [ v ], [] | [], [ v ] -> + (Dim_eq { d1 = Var v; d2 = get_dim ~d:nominator () } :: extras, env) | vs1, vs2 when nominator = 1 -> - (List.map ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () }) (vs1 @ vs2) @ extras, env) + ( List.map ~f:(fun v -> Dim_eq { d1 = Var v; d2 = get_dim ~d:1 () }) (vs1 @ vs2) @ extras, + env ) (* TODO: we can work harder making assumptions here if necessary... 
*) (* | v :: _, [] | [], v :: _ when (is_stage4_up stage) -> (Dim_eq { d1 = Var v; d2 = get_dim ~d:nominator () } :: extras, env) *) @@ -360,7 +400,9 @@ let s_dim_one_in_entry v ~value (in_ : dim_entry) : _ * dim_entry = let cur_v, cur = find_v cur in let subr_v, subr = find_v subr in let ineqs0 = - match (subr_v, lub) with _ :: _, Some lub -> [ Dim_ineq { cur = lub; subr = value } ] | _ -> [] + match (subr_v, lub) with + | _ :: _, Some lub -> [ Dim_ineq { cur = lub; subr = value } ] + | _ -> [] in let ineqs1 = if List.is_empty subr_v then [] @@ -371,7 +413,8 @@ let s_dim_one_in_entry v ~value (in_ : dim_entry) : _ * dim_entry = else List.map subr ~f:(fun subr -> Dim_ineq { subr = Var subr; cur = value }) in ( ineqs0 @ ineqs1 @ ineqs2, - Bounds_dim { cur; subr; lub = Option.map lub ~f:(fun in_ -> s_dim_one v ~value ~in_); constr } ) + Bounds_dim + { cur; subr; lub = Option.map lub ~f:(fun in_ -> s_dim_one v ~value ~in_); constr } ) let s_dim_one_in_row v ~value in_ = { in_ with dims = List.map in_.dims ~f:(fun in_ -> s_dim_one v ~value ~in_) } @@ -413,7 +456,11 @@ let s_row_one v ~value:{ dims = more_dims; bcast; id = _ } ~in_ = match bcast with | Broadcastable -> { dims = beg_dims @ more_dims @ dims; bcast; id } | Row_var { v = v3; beg_dims = more_beg_dims } -> - { dims = more_dims @ dims; bcast = Row_var { v = v3; beg_dims = beg_dims @ more_beg_dims }; id }) + { + dims = more_dims @ dims; + bcast = Row_var { v = v3; beg_dims = beg_dims @ more_beg_dims }; + id; + }) | _ -> in_ let s_row_one_in_row_constr _v ~value:_ ~in_ = match in_ with Unconstrained | Total_elems _ -> in_ @@ -423,12 +470,15 @@ let s_row_one_in_entry v ~value in_ = match in_ with | Solved_row in_ -> ([], Solved_row (s_row_one v ~value ~in_)) | Bounds_row { cur; subr; lub; constr } -> - (* TODO: audit code to ensure we don't lose the constraints associated with the bounds variables. *) + (* TODO: audit code to ensure we don't lose the constraints associated with the bounds + variables. 
*) let find_v side = List.partition_tf side ~f:(equal_row_var v) in let cur_v, cur = find_v cur in let subr_v, subr = find_v subr in let ineqs0 = - match (subr_v, lub) with _ :: _, Some lub -> [ Row_ineq { cur = lub; subr = value } ] | _ -> [] + match (subr_v, lub) with + | _ :: _, Some lub -> [ Row_ineq { cur = lub; subr = value } ] + | _ -> [] in let ineqs1 = if List.is_empty subr_v then [] @@ -436,11 +486,13 @@ let s_row_one_in_entry v ~value in_ = in let ineqs2 = if List.is_empty cur_v then [] - else List.map subr ~f:(fun subr -> Row_ineq { subr = row_of_var subr value.id; cur = value }) + else + List.map subr ~f:(fun subr -> Row_ineq { subr = row_of_var subr value.id; cur = value }) in let constr = s_row_one_in_row_constr v ~value ~in_:constr in ( ineqs0 @ ineqs1 @ ineqs2, - Bounds_row { cur; subr; lub = Option.map lub ~f:(fun in_ -> s_row_one v ~value ~in_); constr } ) + Bounds_row + { cur; subr; lub = Option.map lub ~f:(fun in_ -> s_row_one v ~value ~in_); constr } ) let subst_row (env : environment) ({ dims; bcast; id } : t) : t = let s_dims = List.map ~f:(subst_dim env) in @@ -456,15 +508,18 @@ let subst_row (env : environment) ({ dims; bcast; id } : t) : t = | Row_var { v; beg_dims } -> ( match Map.find env.row_env v with | None | Some (Bounds_row _) -> default - | Some (Solved_row { dims = []; bcast = Row_var { v = v2; beg_dims = [] }; _ }) when equal_row_var v v2 - -> + | Some (Solved_row { dims = []; bcast = Row_var { v = v2; beg_dims = [] }; _ }) + when equal_row_var v v2 -> default | Some (Solved_row ({ bcast = Row_var { v = v2; _ }; _ } as r2)) when equal_row_var v v2 -> - raise @@ Shape_error ("Infinite number of axes by self-reference", [ Row_mismatch [ default; r2 ] ]) + raise + @@ Shape_error + ("Infinite number of axes by self-reference", [ Row_mismatch [ default; r2 ] ]) | Some (Solved_row { dims = more_dims; bcast; id = _ }) -> ( (* Note: we assume env is idempotent (solved wrt. equalities). 
*) match bcast with - | Broadcastable -> { dims = beg_dims @ s_dims more_dims @ dims; bcast = Broadcastable; id } + | Broadcastable -> + { dims = beg_dims @ s_dims more_dims @ dims; bcast = Broadcastable; id } | Row_var { v = v2; beg_dims = more_beg_dims } -> { dims = s_dims more_dims @ dims; @@ -476,7 +531,9 @@ let rec unify_dim ~stage (eq : dim * dim) (env : environment) : constraint_ list let dim1 : dim = subst_dim env @@ fst eq and dim2 : dim = subst_dim env @@ snd eq in match (dim1, dim2) with | Dim { label = Some l1; _ }, Dim { label = Some l2; _ } when not (String.equal l1 l2) -> - raise @@ Shape_error ("solved dimensions for axis: different labels", [ Dim_mismatch [ dim1; dim2 ] ]) + raise + @@ Shape_error + ("solved dimensions for axis: different labels", [ Dim_mismatch [ dim1; dim2 ] ]) | Dim { d = d1; _ }, Dim { d = d2; _ } when d1 = d2 -> ([], env) | Var v1, Var v2 when equal_dim_var v1 v2 -> ([], env) | Var v, dim2 | dim2, Var v -> @@ -498,11 +555,13 @@ let rec unify_dim ~stage (eq : dim * dim) (env : environment) : constraint_ list | Some (Bounds_dim { cur; subr; lub; constr }) -> let dim_env = Map.map env.dim_env ~f in List.iter cur ~f:(fun cur -> ineqs := Dim_ineq { cur = Var cur; subr = dim2 } :: !ineqs); - List.iter subr ~f:(fun subr -> ineqs := Dim_ineq { subr = Var subr; cur = dim2 } :: !ineqs); + List.iter subr ~f:(fun subr -> + ineqs := Dim_ineq { subr = Var subr; cur = dim2 } :: !ineqs); Option.iter lub ~f:(fun lub -> ineqs := Dim_ineq { cur = lub; subr = dim2 } :: !ineqs); let extras, constr = apply_dim_constraint ~source:Equation ~stage dim2 constr env in let extras = - if is_unconstrained_dim constr then extras else Dim_constr { d = dim2; constr } :: extras + if is_unconstrained_dim constr then extras + else Dim_constr { d = dim2; constr } :: extras in ineqs := extras @ !ineqs; { @@ -536,12 +595,14 @@ let rec unify_row ~stage (eq : t * t) (env : environment) : constraint_ list * e | Row_eq { r1; r2 } -> let more_ineqs, env = unify_row ~stage (r1, r2) env in (more_ineqs @ ineqs, env) - | (Dim_ineq _ | Row_ineq _ | Dim_constr _ | Row_constr _ | Terminal_dim _ | Terminal_row _) as ineq -> + | (Dim_ineq _ | Row_ineq _ | Dim_constr _ | Row_constr _ | Terminal_dim _ | Terminal_row _) as + ineq -> (ineq :: ineqs, env) in let unify_suffix init dims1 dims2 len = let dims1 = take_from_end dims1 len and dims2 = take_from_end dims2 len in - List.fold ~init ~f:(fun acc (d1, d2) -> solve acc (Dim_eq { d1; d2 })) @@ List.zip_exn dims1 dims2 + List.fold ~init ~f:(fun acc (d1, d2) -> solve acc (Dim_eq { d1; d2 })) + @@ List.zip_exn dims1 dims2 in let r1 : t = subst_row env @@ fst eq and r2 : t = subst_row env @@ snd eq in let l = List.length in @@ -555,12 +616,15 @@ let rec unify_row ~stage (eq : t * t) (env : environment) : constraint_ list * e and beg_dims1_l = l beg_dims1 and beg_dims2_l = l beg_dims2 in if beg_dims1_l + dims1_l <> beg_dims2_l + dims2_l then - raise @@ Shape_error ("Infinite number of axes by self-reference", [ Row_mismatch [ r1; r2 ] ]); + raise + @@ Shape_error ("Infinite number of axes by self-reference", [ Row_mismatch [ r1; r2 ] ]); let result = unify_suffix ([], env) dims1 dims2 @@ min dims1_l dims2_l in unify_suffix result (List.rev beg_dims1) (List.rev beg_dims2) @@ min beg_dims1_l beg_dims2_l | ({ bcast = Row_var { v; beg_dims = beg_dims1 }; dims = dims1; id } as r1), r2 | r2, ({ bcast = Row_var { v; beg_dims = beg_dims1 }; dims = dims1; id } as r1) -> ( - let dims1_l : int = l dims1 and dims2_l : int = l r2.dims and beg_dims1_l : int = l beg_dims1 in 
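
The unify_row cases in the next hunks align two rows starting from their trailing axes (that is how broadcasting pairs axes up), and only afterwards reconcile row variables and the leading beg_dims. A minimal, stand-alone sketch of the suffix-zipping step follows; the real unify_suffix threads an environment and emits Dim_eq constraints instead of returning pairs:

(* Keep the last [len] elements of a list. *)
let take_from_end l len =
  let n = List.length l in
  List.filteri (fun i _ -> i >= n - len) l

(* Pair up trailing axes of two dimension lists, matching from the end as
   broadcasting does; e.g. [2; 3; 4] vs [3; 4] yields [(3, 3); (4, 4)],
   leaving the leading 2 to be absorbed by a row variable. *)
let unify_suffix dims1 dims2 =
  let len = min (List.length dims1) (List.length dims2) in
  List.combine (take_from_end dims1 len) (take_from_end dims2 len)
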
+ let dims1_l : int = l dims1 + and dims2_l : int = l r2.dims + and beg_dims1_l : int = l beg_dims1 in let beg_dims2_l : int = match r2.bcast with Row_var { beg_dims; _ } -> l beg_dims | Broadcastable -> 0 in @@ -575,19 +639,25 @@ let rec unify_row ~stage (eq : t * t) (env : environment) : constraint_ list * e | Row_var { v = v2; beg_dims = beg_dims2 } -> let result = try unify_suffix ([], env) dims1 r2.dims dims1_l - with Shape_error (s, trace) -> raise @@ Shape_error (s, Row_mismatch orig_rows :: trace) + with Shape_error (s, trace) -> + raise @@ Shape_error (s, Row_mismatch orig_rows :: trace) in let dims = drop_from_end r2.dims dims1_l in if equal_row_var v v2 then if List.is_empty dims && l beg_dims2 = l beg_dims1 then let bcast = Row_var { v; beg_dims = [] } in let value : row = { bcast; dims; id } in - (true, unify_suffix result (List.rev beg_dims1) (List.rev beg_dims2) @@ l beg_dims2, value) + ( true, + unify_suffix result (List.rev beg_dims1) (List.rev beg_dims2) @@ l beg_dims2, + value ) else raise - @@ Shape_error ("Infinite number of axes by self-reference", [ Row_mismatch orig_rows ]) + @@ Shape_error + ("Infinite number of axes by self-reference", [ Row_mismatch orig_rows ]) else - let result = unify_suffix result (List.rev beg_dims1) (List.rev beg_dims2) beg_dims_l in + let result = + unify_suffix result (List.rev beg_dims1) (List.rev beg_dims2) beg_dims_l + in let bcast = Row_var { v = v2; beg_dims = List.drop beg_dims2 beg_dims_l } in let value : row = { bcast; dims; id } in (beg_dims_l = l beg_dims1, result, value) @@ -599,7 +669,8 @@ let rec unify_row ~stage (eq : t * t) (env : environment) : constraint_ list * e let result = List.zip_exn beg_dims1 (List.take r2.dims beg_dims1_l) @ List.zip_exn dims1 (take_from_end r2.dims dims1_l) - |> List.fold ~init:([], env) ~f:(fun acc (d1, d2) -> solve acc (Dim_eq { d1; d2 })) + |> List.fold ~init:([], env) ~f:(fun acc (d1, d2) -> + solve acc (Dim_eq { d1; d2 })) in let value : row = { bcast = Broadcastable; dims; id } in (true, result, value) @@ -615,12 +686,18 @@ let rec unify_row ~stage (eq : t * t) (env : environment) : constraint_ list * e let result env = let row_env = Map.map env.row_env ~f in let unsolved, env = - if beg_handled then ([], { env with row_env = Map.set row_env ~key:v ~data:(Solved_row value) }) + if beg_handled then + ([], { env with row_env = Map.set row_env ~key:v ~data:(Solved_row value) }) else ( [ Row_eq { - r1 = { dims = []; bcast = Row_var { v; beg_dims = List.drop beg_dims1 beg_dims_l }; id }; + r1 = + { + dims = []; + bcast = Row_var { v; beg_dims = List.drop beg_dims1 beg_dims_l }; + id; + }; r2; }; ], @@ -648,13 +725,17 @@ let rec unify_row ~stage (eq : t * t) (env : environment) : constraint_ list * e | ( ({ bcast = Broadcastable; dims = dims1; id = _ } as r1), ({ bcast = Broadcastable; dims = dims2; id = _ } as r2) ) -> ( match List.zip dims1 dims2 with - | Unequal_lengths -> raise @@ Shape_error ("Mismatching number of axes", [ Row_mismatch [ r1; r2 ] ]) - | Ok eqs -> List.fold ~init:([], env) ~f:(fun acc (d1, d2) -> solve acc (Dim_eq { d1; d2 })) eqs) + | Unequal_lengths -> + raise @@ Shape_error ("Mismatching number of axes", [ Row_mismatch [ r1; r2 ] ]) + | Ok eqs -> + List.fold ~init:([], env) ~f:(fun acc (d1, d2) -> solve acc (Dim_eq { d1; d2 })) eqs) let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environment) : constraint_ list * environment = let nonredundant ?(more = []) v vs = - Utils.sorted_diff ~compare:compare_dim_var (List.dedup_and_sort 
~compare:compare_dim_var (v :: vs)) more + Utils.sorted_diff ~compare:compare_dim_var + (List.dedup_and_sort ~compare:compare_dim_var (v :: vs)) + more in let rec cyclic ~subr_v ~curs = (* TODO: it's somewhat inefficient *) @@ -669,7 +750,9 @@ let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environmen match (cur, subr) with | cur, subr when equal_dim cur subr -> ([], env) | Dim { label = Some l1; _ }, Dim { label = Some l2; _ } when not (String.equal l1 l2) -> - raise @@ Shape_error ("dimension comparison for axis: different labels", [ Dim_mismatch [ cur; subr ] ]) + raise + @@ Shape_error + ("dimension comparison for axis: different labels", [ Dim_mismatch [ cur; subr ] ]) | Dim { d = d1; _ }, Dim { d = d2; _ } when d1 = d2 -> ([], env) | _, Dim { d = 1; _ } -> ([], env) | (Dim { d = 1; _ } as cur), _ -> ([ Dim_eq { d1 = subr; d2 = cur } ], env) @@ -687,15 +770,20 @@ let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environmen env.dim_env |> Map.add_exn ~key:cur_v ~data: - (Bounds_dim { lub = None; cur = []; subr = [ subr_v ]; constr = Unconstrained_dim }) + (Bounds_dim + { lub = None; cur = []; subr = [ subr_v ]; constr = Unconstrained_dim }) |> Map.add_exn ~key:subr_v - ~data:(Bounds_dim { lub = None; cur = [ cur_v ]; subr = []; constr = Unconstrained_dim }); + ~data: + (Bounds_dim + { lub = None; cur = [ cur_v ]; subr = []; constr = Unconstrained_dim }); } ) | Some (Solved_dim _), _ | _, Some (Solved_dim _) -> assert false | Some (Bounds_dim { cur = cur1; subr = subr1; lub = lub1; constr = constr1 }), None -> let from_lub = Option.to_list lub1 |> List.map ~f:(fun cur -> Dim_ineq { cur; subr }) in let from_constr1, constr1 = apply_dim_constraint ~source:Subr ~stage subr constr1 env in - let from_constr2, constr2 = apply_dim_constraint ~source:Cur ~stage cur Unconstrained_dim env in + let from_constr2, constr2 = + apply_dim_constraint ~source:Cur ~stage cur Unconstrained_dim env + in ( from_constr1 @ from_constr2 @ from_lub, { env with @@ -704,7 +792,12 @@ let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environmen |> Map.set ~key:cur_v ~data: (Bounds_dim - { lub = lub1; cur = cur1; subr = nonredundant subr_v subr1; constr = constr1 }) + { + lub = lub1; + cur = cur1; + subr = nonredundant subr_v subr1; + constr = constr1; + }) |> Map.add_exn ~key:subr_v ~data:(Bounds_dim { lub = None; cur = [ cur_v ]; subr = []; constr = constr2 }); } ) @@ -717,7 +810,9 @@ let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environmen when cyclic ~subr_v ~curs -> ([ Dim_eq { d1 = subr; d2 = cur } ], env) | None, Some (Bounds_dim { cur = cur2; subr = subr2; lub = lub2; constr = constr2 }) -> - let from_constr1, constr1 = apply_dim_constraint ~source:Subr ~stage subr Unconstrained_dim env in + let from_constr1, constr1 = + apply_dim_constraint ~source:Subr ~stage subr Unconstrained_dim env + in let from_constr2, constr2 = apply_dim_constraint ~source:Cur ~stage cur constr2 env in ( from_constr2 @ from_constr1, { @@ -725,11 +820,17 @@ let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environmen dim_env = env.dim_env |> Map.add_exn ~key:cur_v - ~data:(Bounds_dim { lub = None; cur = []; subr = [ subr_v ]; constr = constr1 }) + ~data: + (Bounds_dim { lub = None; cur = []; subr = [ subr_v ]; constr = constr1 }) |> Map.set ~key:subr_v ~data: (Bounds_dim - { lub = lub2; cur = nonredundant cur_v cur2; subr = subr2; constr = constr2 }); + { + lub = lub2; + cur = nonredundant cur_v cur2; + subr = subr2; + constr = 
constr2; + }); } ) | ( Some (Bounds_dim { cur = cur1; subr = subr1; lub = lub1; constr = constr1 }), Some (Bounds_dim { cur = cur2; subr = subr2; lub = lub2; constr = constr2 }) ) -> @@ -768,7 +869,8 @@ let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environmen env with dim_env = Map.add_exn env.dim_env ~key:subr_v - ~data:(Bounds_dim { lub = Some cur; cur = []; subr = []; constr = Unconstrained_dim }); + ~data: + (Bounds_dim { lub = Some cur; cur = []; subr = []; constr = Unconstrained_dim }); } ) | Some (Solved_dim _) -> assert false | Some (Bounds_dim { cur = cur2; subr = subr2; lub = Some lub2; constr = constr2 }) -> @@ -801,19 +903,26 @@ let solve_dim_ineq ~(stage : stage) ~(cur : dim) ~(subr : dim) (env : environmen } )) | Var _, Dim _ (* when d2 > 1 *) -> ([ Dim_eq { d1 = cur; d2 = subr } ], env) | Dim _, Dim _ -> - raise @@ Shape_error ("dimension comparison for axis: mismatch", [ Dim_mismatch [ cur; subr ] ]) + raise + @@ Shape_error ("dimension comparison for axis: mismatch", [ Dim_mismatch [ cur; subr ] ]) let global_template_cache = Hashtbl.Poly.create () let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : constraint_ list * environment = let nonredundant ?(more = []) v vs = - Utils.sorted_diff ~compare:compare_row_var (List.dedup_and_sort ~compare:compare_row_var (v :: vs)) more + Utils.sorted_diff ~compare:compare_row_var + (List.dedup_and_sort ~compare:compare_row_var (v :: vs)) + more in let l = List.length in let cur_dims_l : int = l cur.dims and subr_dims_l : int = l subr.dims in - let cur_beg_dims = match cur.bcast with Row_var { beg_dims; _ } -> beg_dims | Broadcastable -> [] in - let subr_beg_dims = match subr.bcast with Row_var { beg_dims; _ } -> beg_dims | Broadcastable -> [] in + let cur_beg_dims = + match cur.bcast with Row_var { beg_dims; _ } -> beg_dims | Broadcastable -> [] + in + let subr_beg_dims = + match subr.bcast with Row_var { beg_dims; _ } -> beg_dims | Broadcastable -> [] + in let cur_beg_dims_l = l cur_beg_dims and subr_beg_dims_l = l subr_beg_dims in let beg_dims_l = min cur_beg_dims_l subr_beg_dims_l in let dims_l = min cur_dims_l subr_dims_l in @@ -837,7 +946,9 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : | { bcast = Row_var { v = cur_v; _ }; _ }, { bcast = Row_var { v = subr_v; _ }; _ } when equal_row_var cur_v subr_v -> if cur_dims_l + cur_beg_dims_l = subr_dims_l + subr_beg_dims_l then (ineqs, env) - else raise @@ Shape_error ("Infinite number of axes by self-reference", [ Row_mismatch [ cur; subr ] ]) + else + raise + @@ Shape_error ("Infinite number of axes by self-reference", [ Row_mismatch [ cur; subr ] ]) | { bcast = Row_var { v = cur_v; _ }; _ }, { bcast = Row_var { v = subr_v; _ }; _ } when cur_dims_l = subr_dims_l && cur_beg_dims_l = subr_beg_dims_l -> ( match (Map.find env.row_env cur_v, Map.find env.row_env subr_v) with @@ -850,7 +961,8 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : (Row_eq { r1 = row_of_var subr_v subr.id; r2 = row_of_var cur_v cur.id } :: ineqs, env) | Some (Bounds_row { subr = subr1; _ }), _ when List.mem ~equal:equal_row_var subr1 subr_v -> (ineqs, env) - | _, Some (Bounds_row { cur = cur2; _ }) when List.mem ~equal:equal_row_var cur2 cur_v -> (ineqs, env) + | _, Some (Bounds_row { cur = cur2; _ }) when List.mem ~equal:equal_row_var cur2 cur_v -> + (ineqs, env) | None, None -> ( ineqs, { @@ -858,9 +970,13 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : 
row_env = env.row_env |> Map.add_exn ~key:cur_v - ~data:(Bounds_row { cur = []; subr = [ subr_v ]; lub = None; constr = Unconstrained }) + ~data: + (Bounds_row + { cur = []; subr = [ subr_v ]; lub = None; constr = Unconstrained }) |> Map.add_exn ~key:subr_v - ~data:(Bounds_row { cur = [ cur_v ]; subr = []; lub = None; constr = Unconstrained }); + ~data: + (Bounds_row + { cur = [ cur_v ]; subr = []; lub = None; constr = Unconstrained }); } ) | Some (Bounds_row { cur = cur1; subr = subr1; lub = lub1; constr = constr1 }), None -> ( ineqs, @@ -871,9 +987,16 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : |> Map.set ~key:cur_v ~data: (Bounds_row - { cur = cur1; subr = nonredundant subr_v subr1; lub = lub1; constr = constr1 }) + { + cur = cur1; + subr = nonredundant subr_v subr1; + lub = lub1; + constr = constr1; + }) |> Map.add_exn ~key:subr_v - ~data:(Bounds_row { cur = [ cur_v ]; subr = []; lub = None; constr = Unconstrained }); + ~data: + (Bounds_row + { cur = [ cur_v ]; subr = []; lub = None; constr = Unconstrained }); } ) | None, Some (Bounds_row { cur = cur2; subr = subr2; lub = lub2; constr = constr2 }) -> ( ineqs, @@ -884,9 +1007,16 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : |> Map.set ~key:subr_v ~data: (Bounds_row - { cur = nonredundant cur_v cur2; subr = subr2; lub = lub2; constr = constr2 }) + { + cur = nonredundant cur_v cur2; + subr = subr2; + lub = lub2; + constr = constr2; + }) |> Map.add_exn ~key:cur_v - ~data:(Bounds_row { cur = []; subr = [ subr_v ]; lub = None; constr = Unconstrained }); + ~data: + (Bounds_row + { cur = []; subr = [ subr_v ]; lub = None; constr = Unconstrained }); } ) | ( Some (Bounds_row { cur = cur1; subr = subr1; lub = lub1; constr = constr1 }), Some (Bounds_row { cur = cur2; subr = subr2; lub = lub2; constr = constr2 }) ) -> @@ -898,23 +1028,37 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : |> Map.set ~key:cur_v ~data: (Bounds_row - { cur = cur1; subr = nonredundant subr_v subr1; lub = lub1; constr = constr1 }) + { + cur = cur1; + subr = nonredundant subr_v subr1; + lub = lub1; + constr = constr1; + }) |> Map.set ~key:subr_v ~data: (Bounds_row - { cur = nonredundant cur_v cur2; subr = subr2; lub = lub2; constr = constr2 }); + { + cur = nonredundant cur_v cur2; + subr = subr2; + lub = lub2; + constr = constr2; + }); } ) | Some (Solved_row _), _ | _, Some (Solved_row _) -> assert false) | { bcast = Row_var { v = cur_v; _ }; dims; _ }, _ when cur_dims_l + cur_beg_dims_l < subr_dims_l + subr_beg_dims_l -> let budget = subr_dims_l + subr_beg_dims_l - (cur_dims_l + cur_beg_dims_l) in let more_dims_l = min budget @@ max 0 (subr_dims_l - cur_dims_l) in - let more_dims : dim list = Array.(to_list @@ init more_dims_l ~f:(fun _ -> Var (get_var ()))) in + let more_dims : dim list = + Array.(to_list @@ init more_dims_l ~f:(fun _ -> Var (get_var ()))) + in let budget = budget - more_dims_l in let more_beg_dims_l = min budget @@ max 0 (subr_beg_dims_l - cur_beg_dims_l) in - let more_beg_dims : dim list = Array.(to_list @@ init more_beg_dims_l ~f:(fun _ -> Var (get_var ()))) in - (* The key of the template cache reflects that cur_v will end up substituted by {dims=more_dims; - bcast=Row_var templ_v}. TODO: should we cache more_dims also? 
*) + let more_beg_dims : dim list = + Array.(to_list @@ init more_beg_dims_l ~f:(fun _ -> Var (get_var ()))) + in + (* The key of the template cache reflects that cur_v will end up substituted by + {dims=more_dims; bcast=Row_var templ_v}. TODO: should we cache more_dims also? *) let templ_v : row_var = Hashtbl.find_or_add global_template_cache (cur_v, subr_dims_l - cur_dims_l, subr_beg_dims_l - cur_beg_dims_l) @@ -927,10 +1071,11 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : id = cur.id; } in - (* We don't need to add any dimension inequalities, because they'll be captured by the extra row - inequalities. *) + (* We don't need to add any dimension inequalities, because they'll be captured by the extra + row inequalities. *) ([ Row_eq { r1 = cur; r2 = template }; Row_ineq { cur = template; subr } ], env) - | { bcast = Broadcastable; _ }, _ when cur_dims_l + cur_beg_dims_l < subr_dims_l + subr_beg_dims_l -> + | { bcast = Broadcastable; _ }, _ when cur_dims_l + cur_beg_dims_l < subr_dims_l + subr_beg_dims_l + -> raise @@ Shape_error ("Too many axes in a subtensor", [ Row_mismatch [ cur; subr ] ]) | { bcast; dims; id }, { bcast = Row_var { v = subr_v; _ }; _ } when subr_dims_l <= cur_dims_l && subr_beg_dims_l <= cur_beg_dims_l -> ( @@ -947,7 +1092,8 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : env with row_env = Map.add_exn env.row_env ~key:subr_v - ~data:(Bounds_row { cur = []; subr = []; lub = Some r_cur; constr = Unconstrained }); + ~data: + (Bounds_row { cur = []; subr = []; lub = Some r_cur; constr = Unconstrained }); } ) | Some (Bounds_row { cur = cur2; subr = subr2; lub = None; constr = constr2 }) -> ( ineqs, @@ -956,7 +1102,8 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : row_env = env.row_env |> Map.set ~key:subr_v - ~data:(Bounds_row { cur = cur2; subr = subr2; lub = Some r_cur; constr = constr2 }); + ~data: + (Bounds_row { cur = cur2; subr = subr2; lub = Some r_cur; constr = constr2 }); } ) | Some (Bounds_row { cur = cur2; subr = subr2; lub = Some lub2; constr = constr2 }) -> let len1 = List.length r_cur.dims and len2 = List.length lub2.dims in @@ -983,11 +1130,14 @@ let solve_row_ineq ~(stage : stage) ~(cur : t) ~(subr : t) (env : environment) : row_env = env.row_env |> Map.set ~key:subr_v - ~data:(Bounds_row { cur = cur2; subr = subr2; lub = Some lub; constr = constr2 }); + ~data: + (Bounds_row { cur = cur2; subr = subr2; lub = Some lub; constr = constr2 }); } ) | Some (Solved_row _) -> assert false) - | _ when cur_beg_dims_l > beg_dims_l && not (is_stage7 stage) -> (Row_ineq { cur; subr } :: ineqs, env) - | _, { bcast = Broadcastable; _ } when subr_dims_l + subr_beg_dims_l <= cur_dims_l + cur_beg_dims_l -> + | _ when cur_beg_dims_l > beg_dims_l && not (is_stage7 stage) -> + (Row_ineq { cur; subr } :: ineqs, env) + | _, { bcast = Broadcastable; _ } + when subr_dims_l + subr_beg_dims_l <= cur_dims_l + cur_beg_dims_l -> (ineqs, env) | { bcast = Row_var _ | Broadcastable; _ }, { bcast = Row_var _ | Broadcastable; _ } -> (Row_ineq { cur; subr } :: ineqs, env) @@ -1000,7 +1150,8 @@ let close_dim_terminal ~(stage : stage) (env : environment) (dim : dim) : constr | Some (Solved_dim _) -> assert false | Some (Bounds_dim { lub = None; constr = Unconstrained_dim; _ }) when is_stage2_up stage -> [ Dim_eq { d1 = dim; d2 = get_dim ~d:1 () } ] - | Some (Bounds_dim { lub = Some lub; _ }) when is_stage3_up stage -> [ Dim_eq { d1 = dim; d2 = lub } ] + | Some (Bounds_dim { lub = Some 
lub; _ }) when is_stage3_up stage -> + [ Dim_eq { d1 = dim; d2 = lub } ] | _ when not (is_stage4_up stage) -> [ Terminal_dim dim ] | _ -> []) @@ -1037,7 +1188,8 @@ let rec eliminate_row_constraint ~lub (r : row) (constr : row_constraint) env : let ineqs, _env = apply_row_constraint ~stage:Stage5 lub constr env in List.concat_map ineqs ~f:(function | Row_constr { r = r'; constr } -> - if not (phys_equal r r') then eliminate_row_constraint ~lub:None r constr env else [] + if not (phys_equal r r') then eliminate_row_constraint ~lub:None r constr env + else [] | ineq -> [ ineq ]) | _, [ v ], _ -> no_further_axes :: [ Dim_eq { d1 = Var v; d2 = get_dim ~d () } ] | _ -> []) @@ -1169,7 +1321,8 @@ let solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (env : enviro | _ -> [] in let finalizing_entries : constraint_ list = - Map.fold env.dim_env ~init:[] ~f:(fun ~key ~data accu -> finalize_lower_bound key data @ accu) + Map.fold env.dim_env ~init:[] ~f:(fun ~key ~data accu -> + finalize_lower_bound key data @ accu) in solve (finalizing_entries @ ineqs) env | Stage5 -> @@ -1181,7 +1334,8 @@ let solve_inequalities ~(stage : stage) (ineqs : constraint_ list) (env : enviro | _ -> [] in let finalizing_entries : constraint_ list = - Map.fold env.row_env ~init:[] ~f:(fun ~key ~data accu -> finalize_total_elems key data @ accu) + Map.fold env.row_env ~init:[] ~f:(fun ~key ~data accu -> + finalize_total_elems key data @ accu) in solve (finalizing_entries @ ineqs) env @@ -1202,7 +1356,8 @@ let rec row_to_labels env = row_to_labels env { dims = beg_dims @ dims2 @ dims; bcast = Broadcastable; id } | Some (Solved_row { dims = dims2; bcast = Row_var { v = v2; beg_dims = beg_dims2 }; _ }) -> row_to_labels env - { dims = dims2 @ dims; bcast = Row_var { v = v2; beg_dims = beg_dims @ beg_dims2 }; id }) + { dims = dims2 @ dims; bcast = Row_var { v = v2; beg_dims = beg_dims @ beg_dims2 }; id } + ) | { dims; bcast = Broadcastable; id = _ } -> Array.of_list_map dims ~f (** *** Projection inference *** *) @@ -1220,7 +1375,8 @@ let fresh_row_proj r = in { r with dims = List.map r.dims ~f:fresh_dim } -(* let update_proj_classes pid1 pid2 proj_classes = Utils.union_add ~equal:Int.equal proj_classes pid1 pid2 *) +(* let update_proj_classes pid1 pid2 proj_classes = Utils.union_add ~equal:Int.equal proj_classes + pid1 pid2 *) type proj = Var of dim_var | Proj of { proj_id : int; d : int } | Solved of Idx.axis_index [@@deriving compare, equal, sexp] @@ -1228,7 +1384,8 @@ type proj = Var of dim_var | Proj of { proj_id : int; d : int } | Solved of Idx. type error_trace += Projection_mismatch of proj list let sexp_of_error_trace = function - | Projection_mismatch ps -> Sexp.List (Sexp.Atom "Projection_mismatch" :: List.map ps ~f:sexp_of_proj) + | Projection_mismatch ps -> + Sexp.List (Sexp.Atom "Projection_mismatch" :: List.map ps ~f:sexp_of_proj) | error_trace -> sexp_of_error_trace error_trace type proj_to_index = Idx.axis_index Map.M(Int).t [@@deriving sexp] @@ -1243,10 +1400,11 @@ type proj_env = { [@@deriving sexp] type proj_equation = - | Proj_eq of proj * proj (** Two projections are the same, e.g. two axes share the same iterator. *) + | Proj_eq of proj * proj + (** Two projections are the same, e.g. two axes share the same iterator. *) | Iterated of proj - (** The projection needs to be an iterator even if an axis is not matched with another axis, e.g. for - broadcasted-to axes of a tensor assigned a constant. 
*) + (** The projection needs to be an iterator even if an axis is not matched with another axis, + e.g. for broadcasted-to axes of a tensor assigned a constant. *) [@@deriving compare, equal, sexp] let get_proj_equations (inequalities : constraint_ list) proj_axis_env (env : environment) : @@ -1276,7 +1434,8 @@ let get_proj_equations (inequalities : constraint_ list) proj_axis_env (env : en let len1 = List.length dims1 in let len = min len1 (List.length dims2) in let extras = - if with_broadcasting then List.map ~f:(fun d -> Iterated (to_proj d)) @@ List.take dims1 (len1 - len) + if with_broadcasting then + List.map ~f:(fun d -> Iterated (to_proj d)) @@ List.take dims1 (len1 - len) else [] in extras @@ -1307,18 +1466,22 @@ let solve_proj_equations (eqs : proj_equation list) : proj_env = | Proj_eq (Proj { proj_id = p1; d }, Proj { proj_id = p2; _ }) when p1 = p2 -> p_dims := (p1, d) :: !p_dims | Proj_eq (Var v1, Var v2) when equal_dim_var v1 v2 -> () - | Proj_eq ((Proj { proj_id = p1; d = d1 } as proj1), (Proj { proj_id = p2; d = d2 } as proj2)) -> + | Proj_eq ((Proj { proj_id = p1; d = d1 } as proj1), (Proj { proj_id = p2; d = d2 } as proj2)) + -> if d1 <> d2 then raise @@ Shape_error - ("Conflicting dimensions for the same projection", [ Projection_mismatch [ proj1; proj2 ] ]); + ( "Conflicting dimensions for the same projection", + [ Projection_mismatch [ proj1; proj2 ] ] ); p_dims := (p1, d1) :: !p_dims; proj_classes := Utils.union_add ~equal:Int.equal !proj_classes p1 p2 - | Proj_eq (Proj p, Solved idx) | Proj_eq (Solved idx, Proj p) -> p_solved := (p.proj_id, idx) :: !p_solved + | Proj_eq (Proj p, Solved idx) | Proj_eq (Solved idx, Proj p) -> + p_solved := (p.proj_id, idx) :: !p_solved | Proj_eq (Solved idx1, Solved idx2) when Idx.equal_axis_index idx1 idx2 -> () | Proj_eq (Solved idx1, Solved idx2) -> raise - @@ Shape_error ("Conflicting indices for the same axis/projection", [ Index_mismatch [ idx1; idx2 ] ]) + @@ Shape_error + ("Conflicting indices for the same axis/projection", [ Index_mismatch [ idx1; idx2 ] ]) | Proj_eq (Var v, p) | Proj_eq (p, Var v) -> ( match Hashtbl.find v_env v with | None -> Hashtbl.add_exn v_env ~key:v ~data:p @@ -1342,7 +1505,8 @@ let solve_proj_equations (eqs : proj_equation list) : proj_env = Utils.mref_add projs ~key:repr ~data:idx ~or_:(fun idx2 -> if not @@ Idx.equal_axis_index idx idx2 then raise - @@ Shape_error ("Multiple constraints on the same projection", [ Index_mismatch [ idx; idx2 ] ]))); + @@ Shape_error + ("Multiple constraints on the same projection", [ Index_mismatch [ idx; idx2 ] ]))); let product_dim = ref @@ Map.empty (module Int) in List.iter !p_dims ~f:(fun (p, d) -> let repr, _ = Utils.union_find ~equal:Int.equal !proj_classes ~key:p ~rank:0 in @@ -1352,7 +1516,9 @@ let solve_proj_equations (eqs : proj_equation list) : proj_env = raise @@ Shape_error ( "Conflicting dimensions for the same projection", - [ Projection_mismatch [ Proj { proj_id = p; d }; Proj { proj_id = p; d = d2 } ] ] ))); + [ + Projection_mismatch [ Proj { proj_id = p; d }; Proj { proj_id = p; d = d2 } ]; + ] ))); Map.iteri !product_dim ~f:(fun ~key:p ~data:_ -> let repr, _ = Utils.union_find ~equal:Int.equal !proj_classes ~key:p ~rank:0 in Utils.mref_add_missing projs repr ~f:(fun () -> Idx.(Iterator (get_symbol ())))); @@ -1382,14 +1548,16 @@ let get_proj_index proj_env = function ( "projection_of_solved_dims: unknown projection", [ Projection_mismatch [ Proj { proj_id; d } ] ] )) -let proj_repr proj_env p = fst @@ Utils.union_find ~equal:Int.equal 
proj_env.proj_classes ~key:p ~rank:0 +let proj_repr proj_env p = + fst @@ Utils.union_find ~equal:Int.equal proj_env.proj_classes ~key:p ~rank:0 let get_product_proj proj_env dim = match dim with | Dim { d; _ } when not @@ Idx.iterated d -> None | Dim { proj_id = Some proj_id; d; _ } -> let repr = proj_repr proj_env proj_id in - if Map.mem proj_env.proj_to_index repr && (not @@ Set.mem proj_env.non_product repr) then Some (repr, d) + if Map.mem proj_env.proj_to_index repr && (not @@ Set.mem proj_env.non_product repr) then + Some (repr, d) else None | Dim { proj_id = None; _ } -> None | Var v -> @@ -1400,4 +1568,6 @@ let get_product_proj proj_env dim = [ Dim_mismatch [ dim ] ] ) let proj_to_iterator proj_env p = - match Map.find_exn proj_env.proj_to_index (proj_repr proj_env p) with Iterator s -> s | _ -> assert false + match Map.find_exn proj_env.proj_to_index (proj_repr proj_env p) with + | Iterator s -> s + | _ -> assert false diff --git a/lib/row.mli b/lib/row.mli index 07ac15dc..affb53bd 100644 --- a/lib/row.mli +++ b/lib/row.mli @@ -31,7 +31,8 @@ val get_row_var : unit -> row_var (** A bcast specifies how axes of a single kind in a shape (i.e. the row) can adapt to other shapes. *) type bcast = - | Row_var of { v : row_var; beg_dims : dim list } (** The row can be inferred to have more axes. *) + | Row_var of { v : row_var; beg_dims : dim list } + (** The row can be inferred to have more axes. *) | Broadcastable (** The shape does not have more axes of this kind, but is "polymorphic". *) [@@deriving equal, hash, compare, sexp, variants] @@ -60,16 +61,26 @@ type row_constraint = (** The row or remainder of a row, inclusive of the further row spec, has this many elements. *) [@@deriving equal, hash, compare, sexp, variants] -(** An entry implements inequalities [cur >= v >= subr] and/or an equality [v = solved]. [cur] and [subr] must - be sorted using the [@@deriving compare] comparison. *) +(** An entry implements inequalities [cur >= v >= subr] and/or an equality [v = solved]. [cur] and + [subr] must be sorted using the [@@deriving compare] comparison. 
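+
+    For illustration, a hedged sketch of such an entry (the variable names [cur_v] and [subr_v]
+    are hypothetical): a variable bounded above by [cur_v] and below by [subr_v], with no least
+    upper bound yet and no constraint, would be recorded as
+    {[
+      Bounds_dim { cur = [ cur_v ]; subr = [ subr_v ]; lub = None; constr = Unconstrained_dim }
+    ]}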
*) type dim_entry = | Solved_dim of dim - | Bounds_dim of { cur : dim_var list; subr : dim_var list; lub : dim option; constr : dim_constraint } + | Bounds_dim of { + cur : dim_var list; + subr : dim_var list; + lub : dim option; + constr : dim_constraint; + } [@@deriving sexp] type row_entry = | Solved_row of t - | Bounds_row of { cur : row_var list; subr : row_var list; lub : t option; constr : row_constraint } + | Bounds_row of { + cur : row_var list; + subr : row_var list; + lub : t option; + constr : row_constraint; + } [@@deriving sexp] type constraint_ = @@ -83,13 +94,17 @@ type constraint_ = | Terminal_row of t [@@deriving compare, equal, sexp, variants] -type stage = Stage1 | Stage2 | Stage3 | Stage4 | Stage5 | Stage6 | Stage7 [@@deriving sexp, equal, compare] +type stage = Stage1 | Stage2 | Stage3 | Stage4 | Stage5 | Stage6 | Stage7 +[@@deriving sexp, equal, compare] val subst_row : environment -> t -> t val unify_row : stage:stage -> t * t -> environment -> constraint_ list * environment val empty_env : environment val eliminate_variables : environment -> t -> constraint_ list -val solve_inequalities : stage:stage -> constraint_ list -> environment -> constraint_ list * environment + +val solve_inequalities : + stage:stage -> constraint_ list -> environment -> constraint_ list * environment + val row_to_labels : environment -> t -> string array type proj [@@deriving compare, equal, sexp] @@ -98,10 +113,11 @@ type proj_env [@@deriving sexp] val fresh_row_proj : t -> t type proj_equation = - | Proj_eq of proj * proj (** Two projections are the same, e.g. two axes share the same iterator. *) + | Proj_eq of proj * proj + (** Two projections are the same, e.g. two axes share the same iterator. *) | Iterated of proj - (** The projection needs to be an iterator even if an axis is not matched with another axis, e.g. for - broadcasted-to axes of a tensor assigned a constant. *) + (** The projection needs to be an iterator even if an axis is not matched with another axis, + e.g. for broadcasted-to axes of a tensor assigned a constant. *) [@@deriving compare, equal, sexp] val get_proj_equations : diff --git a/lib/shape.ml b/lib/shape.ml index 346e301c..70ab1142 100644 --- a/lib/shape.ml +++ b/lib/shape.ml @@ -10,12 +10,13 @@ module Debug_runtime = Arrayjit.Utils.Debug_runtime (** {2 Shape types and inference.} *) -(** An index pointing to any of a shape's axes, including the kind of the axis ([Batch, Input, Output]) and - the position (which is counted from the end to facilitate broadcasting). +(** An index pointing to any of a shape's axes, including the kind of the axis + ([Batch, Input, Output]) and the position (which is counted from the end to facilitate + broadcasting). - Note the following inconsistency due to differing conventions in function notation and matrix notation: - for label specifications and einsum notation, we write "batch|inputs->outputs", but when we convert a - shape to an [Ndarray] index we do it in the order [[batch; outputs; inputs]]. *) + Note the following inconsistency due to differing conventions in function notation and matrix + notation: for label specifications and einsum notation, we write "batch|inputs->outputs", but + when we convert a shape to an [Ndarray] index we do it in the order [[batch; outputs; inputs]]. 
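+
+    As a hedged example (sizes chosen arbitrarily): a shape with one batch axis of size 2, one
+    input axis of size 3 and one output axis of size 4 — "b|i->o" in the spec syntax — flattens
+    to the [Ndarray] dimensions [[| 2; 4; 3 |]] under this convention.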
*) module AxisKey = struct module T = struct type kind = [ `Batch | `Input | `Output ] [@@deriving equal, compare, sexp, hash] @@ -24,8 +25,8 @@ module AxisKey = struct in_axes : kind; pos : int; (** Indices start at [1], counted from the end if [from_end] is true. *) from_end : bool; - (** Axes are indexed from the front (rarely) or from the end (typically), to avoid reindexing when - broadcasting. *) + (** Axes are indexed from the front (rarely) or from the end (typically), to avoid + reindexing when broadcasting. *) } [@@deriving equal, compare, sexp] end @@ -49,11 +50,12 @@ type parsed_axis_labels = { labels : (string, int) Either.t axis_map; } [@@deriving compare, sexp, fields] -(** The labels are strings assigned to [AxisKey] axes. Moreover the [bcast_] fields represent whether - additional leading/middle axes are allowed (corresponding to the dot-ellipsis syntax for broadcasting). - The string can be used to identify a row variable, and defaults to ["batch"], ["input"], ["output"] - respectively when parsing ["..."]. The [given_] fields count the number of specified axes of the - corresponding kind in [labels] where [from_end=true], [given_beg_] where [from_end=false]. *) +(** The labels are strings assigned to [AxisKey] axes. Moreover the [bcast_] fields represent + whether additional leading/middle axes are allowed (corresponding to the dot-ellipsis syntax for + broadcasting). The string can be used to identify a row variable, and defaults to ["batch"], + ["input"], ["output"] respectively when parsing ["..."]. The [given_] fields count the number of + specified axes of the corresponding kind in [labels] where [from_end=true], [given_beg_] where + [from_end=false]. *) let axis_labels parsed = parsed.labels @@ -68,7 +70,9 @@ type t = { let row_of_kind = function `Batch -> batch | `Input -> input | `Output -> output -type deduce_within_shape = Not_constrained | Input_equals_output [@@deriving compare, sexp, variants] +type deduce_within_shape = Not_constrained | Input_equals_output +[@@deriving compare, sexp, variants] + type compose_type = Pointwise_bin | Compose | Einsum of string [@@deriving sexp, equal] type transpose_type = @@ -79,7 +83,9 @@ type transpose_type = [@@deriving equal, sexp] let identifier_multichar = Angstrom.take_while1 Char.is_alphanum -let opt_separators : _ Angstrom.t = Angstrom.take_while (fun c -> Char.is_whitespace c || Char.equal c ',') + +let opt_separators : _ Angstrom.t = + Angstrom.take_while (fun c -> Char.is_whitespace c || Char.equal c ',') let separators_with_comma = let open Angstrom in @@ -127,7 +133,8 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t = let row = lift3 (for_row ~kind in_axes) in opt_separators *> (row (return []) (lift Option.some ellipsis_spec) (axes_spec ~from_end:true) - <|> row (axes_spec ~from_end:false) (lift Option.some ellipsis_spec) (axes_spec ~from_end:true) + <|> row (axes_spec ~from_end:false) (lift Option.some ellipsis_spec) + (axes_spec ~from_end:true) <|> row (return []) (return None) (axes_spec ~from_end:true) <|> row (return []) (lift Option.some ellipsis_spec) (return [])) <* opt_separators @@ -165,41 +172,49 @@ let axis_labels_of_spec_parser ~multichar : parsed_axis_labels Angstrom.t = let axis_labels_of_spec spec = let multichar = String.contains spec ',' in match - Angstrom.(parse_string ~consume:Consume.All (axis_labels_of_spec_parser ~multichar <* end_of_input) spec) + Angstrom.( + parse_string ~consume:Consume.All (axis_labels_of_spec_parser ~multichar <* end_of_input) spec) 
with | Ok result -> result | Error msg -> - raise @@ Utils.User_error ("Shape.axis_labels_of_spec: while parsing: " ^ spec ^ " error: " ^ msg) + raise + @@ Utils.User_error ("Shape.axis_labels_of_spec: while parsing: " ^ spec ^ " error: " ^ msg) let einsum_of_spec_parser ~multichar : _ Angstrom.t = let open Angstrom in let p = axis_labels_of_spec_parser ~multichar in - lift3 (fun a b c -> (a, Some b, c)) (p "RHS1" <* char ';') (p "RHS2") (string "=>" *> (p "LHS")) + lift3 + (fun a b c -> (a, Some b, c)) + (p "RHS1" <* char ';') + (p "RHS2") + (string "=>" *> (p "LHS")) <|> lift2 (fun a c -> (a, None, c)) (p "RHS") (string "=>" *> (p "LHS")) let einsum_of_spec spec = let multichar = String.contains spec ',' in match - Angstrom.(parse_string ~consume:Consume.All (einsum_of_spec_parser ~multichar <* end_of_input) spec) + Angstrom.( + parse_string ~consume:Consume.All (einsum_of_spec_parser ~multichar <* end_of_input) spec) with | Ok result -> result - | Error msg -> raise @@ Utils.User_error ("Shape.einsum_of_spec: while parsing: " ^ spec ^ " error: " ^ msg) + | Error msg -> + raise @@ Utils.User_error ("Shape.einsum_of_spec: while parsing: " ^ spec ^ " error: " ^ msg) -(** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the tensor. Axes - are broadcast-expanded on a bottom-up update to fit the incoming shape. *) +(** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the + tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape. *) type logic = | Broadcast of compose_type * t * t (** Matches the shapes for a binary operation. - For [Broadcast (Einsum (ls1, ls2, ls3), s1, s2)], the labels of [s1] and [s2] must match according - to the [ls1], [ls2] lineup, and the resulting shape inherits the labels according to the [ls3] - lineup. *) + For [Broadcast (Einsum (ls1, ls2, ls3), s1, s2)], the labels of [s1] and [s2] must match + according to the [ls1], [ls2] lineup, and the resulting shape inherits the labels + according to the [ls3] lineup. *) | Transpose of transpose_type * t - (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of [s1], hence - the name. *) + (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of + [s1], hence the name. *) | Terminal of Arrayjit.Ops.init_op - (** Extracts any available shape information from the initialization. E.g. for [File_mapped fn], opens - the file [fn] to check its length. *) + (** Extracts any available shape information from the initialization. E.g. for + [File_mapped fn], opens the file [fn] to check its length. *) [@@deriving equal, sexp] let logic_to_spec = function @@ -228,22 +243,25 @@ let get_update_id = Update_id.Update_id !uid type update_step = { shape : t; logic : logic; id : update_id } [@@deriving sexp] -(** Data required for a shape inference update step. Ideally, an update should be performed at least twice, - the second time after all the other relevant updates have been performed for the first time. In OCANNL, - this is achieved by performing updates both as the tensors are constructed, and via lazy callbacks as the - corresponding [Arrayjit.Indexing] dimensions and projections are first accessed. *) +(** Data required for a shape inference update step. Ideally, an update should be performed at least + twice, the second time after all the other relevant updates have been performed for the first + time. 
In OCANNL, this is achieved by performing updates both as the tensors are constructed, and + via lazy callbacks as the corresponding [Arrayjit.Indexing] dimensions and projections are first + accessed. *) type Row.error_trace += Shape_mismatch of t list let with_error_trace = ref true -(** Converts an axes-keyed map into three arrays of values: batch axes, input axes, output axes. If the map is - incomplete, the result will likely be invalid: gaps in the array are filled with an arbitrary one of the - provided values. *) +(** Converts an axes-keyed map into three arrays of values: batch axes, input axes, output axes. If + the map is incomplete, the result will likely be invalid: gaps in the array are filled with an + arbitrary one of the provided values. *) let axis_map_to_dims_bio (type a) ?(default : a option) (idcs : a axis_map) = if Map.is_empty idcs then (([||], [||], [||]), ([||], [||], [||])) else - let witness = match default with Some witness -> witness | None -> snd @@ Map.min_elt_exn idcs in + let witness = + match default with Some witness -> witness | None -> snd @@ Map.min_elt_exn idcs + in let bch_axes, other = Map.partition_mapi idcs ~f:(fun ~key:{ in_axes; _ } ~data -> if Row.is_batch in_axes then Either.First data else Either.Second data) @@ -271,15 +289,17 @@ let axis_map_to_dims_bio (type a) ?(default : a option) (idcs : a axis_map) = let out, beg_out = make_row out_axes in ((bch, inp, out), (beg_bch, beg_inp, beg_out)) -(** Converts an axes-keyed map into an array of values using the [force_to_dims] semantics of axes. If the map - is incomplete and the [~default] is not given, the result might be invalid: gaps in the array are filled - with an arbitrary one of the provided values. *) +(** Converts an axes-keyed map into an array of values using the [force_to_dims] semantics of axes. + If the map is incomplete and the [~default] is not given, the result might be invalid: gaps in + the array are filled with an arbitrary one of the provided values. *) let axis_map_to_dims_index (type a) ?(default : a option) (idcs : a axis_map) : a array = let (bch, inp, out), (beg_bch, beg_inp, beg_out) = axis_map_to_dims_bio ?default idcs in Array.concat [ beg_bch; bch; beg_out; out; beg_inp; inp ] let axes_spec_to_dims_bio ~sh_id ~row_var_env ~dim_var_env:_ ~f labels = - let (b_dims, i_dims, o_dims), (beg_b_dims, beg_i_dims, beg_o_dims) = axis_map_to_dims_bio labels.labels in + let (b_dims, i_dims, o_dims), (beg_b_dims, beg_i_dims, beg_o_dims) = + axis_map_to_dims_bio labels.labels + in let to_dim kind = Array.(Fn.compose to_list @@ map ~f:(f kind)) in let to_bcast kind v beg_dims = let beg_dims = to_dim kind beg_dims in @@ -308,7 +328,8 @@ let einsum_slot_spec_to_dims_bio ~generative ~sh_id ~row_var_env ~dim_var_env la | Second i -> let var = Row.get_var () in let d = Row.Var var in - proj_env_update := Map.add_exn !proj_env_update ~key:var ~data:(Arrayjit.Indexing.Fixed_idx i); + proj_env_update := + Map.add_exn !proj_env_update ~key:var ~data:(Arrayjit.Indexing.Fixed_idx i); extras := Row.Dim_constr { d; constr = At_least_dim (i + 1) } :: !extras; d in @@ -412,8 +433,8 @@ let get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : update_step) : let proj_axis_env = Map.add_exn Row.dim_map_empty ~key:slice_v ~data:(Arrayjit.Indexing.Iterator static_symbol) in - (* Expand a batch row instead of reducing one because even if the dimensions are known, the equations - are also needed for projection inference. 
*) + (* Expand a batch row instead of reducing one because even if the dimensions are known, the + equations are also needed for projection inference. *) let expanded_batch = match cur_sh.batch.bcast with | Broadcastable -> @@ -493,7 +514,8 @@ let get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : update_step) : in let proj_env = let combine ~key:_ _ _ = assert false in - Map.merge_skewed ~combine proj_env_rhs1 @@ Map.merge_skewed ~combine proj_env_rhs2 proj_env_lhs + Map.merge_skewed ~combine proj_env_rhs1 + @@ Map.merge_skewed ~combine proj_env_rhs2 proj_env_lhs in (* Forget the old proj_env as it is not relevant after a propagate_shapes call completes. *) ( proj_env, @@ -539,10 +561,13 @@ let apply_env_t env sh = let apply_env_update ~eliminate_variables env update_step = iter_shapes update_step ~f:(apply_env_t env); - if eliminate_variables then List.concat_map ~f:(Row.eliminate_variables env) @@ all_rows update_step else [] + if eliminate_variables then + List.concat_map ~f:(Row.eliminate_variables env) @@ all_rows update_step + else [] let propagate_shapes (update_step : update_step) : unit = - (* Allow the derivation of constraints to depend on the shapes (currently, only Batch_slice does). *) + (* Allow the derivation of constraints to depend on the shapes (currently, only Batch_slice + does). *) ignore (apply_env_update ~eliminate_variables:false !state update_step); let _, ineqs = get_inequalities update_step in active_update_steps := update_step :: !active_update_steps; @@ -553,13 +578,16 @@ let propagate_shapes (update_step : update_step) : unit = state := env let finish_inference (() : unit) : unit = - (* TODO: optimize to keep all needed information in unsolved, rather than starting with all constraints. *) + (* TODO: optimize to keep all needed information in unsolved, rather than starting with all + constraints. *) let unsolved, env = Row.solve_inequalities ~stage:Stage2 !active_constraints !state in let unsolved, env = Row.solve_inequalities ~stage:Stage3 unsolved env in let unsolved, env = Row.solve_inequalities ~stage:Stage4 unsolved env in let unsolved, env = Row.solve_inequalities ~stage:Stage5 unsolved env in let unsolved, env = Row.solve_inequalities ~stage:Stage6 unsolved env in - let eliminated = List.concat_map ~f:(apply_env_update ~eliminate_variables:true env) !active_update_steps in + let eliminated = + List.concat_map ~f:(apply_env_update ~eliminate_variables:true env) !active_update_steps + in let unsolved, env = Row.solve_inequalities ~stage:Stage7 (eliminated @ unsolved) env in assert (List.is_empty unsolved); ignore @@ List.map ~f:(apply_env_update ~eliminate_variables:false env) !active_update_steps; @@ -625,17 +653,19 @@ let fresh_proj_ids update = fresh_shape sh1; fresh_shape sh2 -(** Computes the indexing into subtensors given the shape information of a tensor. [derive_projections] should - only be invoked when the shapes are fully inferred already! *) +(** Computes the indexing into subtensors given the shape information of a tensor. + [derive_projections] should only be invoked when the shapes are fully inferred already! *) let derive_projections (update_step : update_step) : Idx.projections = finish_inference (); fresh_proj_ids update_step; let _debug_update_step : update_step = update_step in - let (proj_axis_env, ineqs) : proj_axis_env * Row.constraint_ list = get_inequalities update_step in - (* We need to solve the equations/inequalities one last time because of fresh row variables potentially - generated by [get_inequalities]. 
Since the variables in the shapes must be substituted-out at this point, - the global state is already an empty env, but in principle we want to only find a local solution to not - contaminate projections across operations. *) + let (proj_axis_env, ineqs) : proj_axis_env * Row.constraint_ list = + get_inequalities update_step + in + (* We need to solve the equations/inequalities one last time because of fresh row variables + potentially generated by [get_inequalities]. Since the variables in the shapes must be + substituted-out at this point, the global state is already an empty env, but in principle we + want to only find a local solution to not contaminate projections across operations. *) let unsolved, local_env = Row.solve_inequalities ~stage:Stage1 ineqs Row.empty_env in let unsolved, local_env = Row.solve_inequalities ~stage:Stage2 unsolved local_env in let unsolved, local_env = Row.solve_inequalities ~stage:Stage3 unsolved local_env in @@ -686,7 +716,8 @@ let derive_projections (update_step : update_step) : Idx.projections = trace = [ ("derive_projections", Idx.unique_debug_id ()) ]; }; } - with Row.Shape_error (s, trace) -> raise @@ Row.Shape_error (s, Shape_mismatch (lhs :: rhs) :: trace) + with Row.Shape_error (s, trace) -> + raise @@ Row.Shape_error (s, Shape_mismatch (lhs :: rhs) :: trace) (** {2 Shape builders.} *) @@ -694,7 +725,11 @@ let make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_ax ?(deduced = Not_constrained) ~debug_name ~id () = let open Row in let make_dims kind ds = - { dims = List.map ~f:(fun d -> get_dim ~d ()) ds; bcast = Broadcastable; id = row_id ~sh_id:id ~kind } + { + dims = List.map ~f:(fun d -> get_dim ~d ()) ds; + bcast = Broadcastable; + id = row_id ~sh_id:id ~kind; + } in let make_axes kind ds = { @@ -704,7 +739,11 @@ let make ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_ax } in let make_unknown kind = - { dims = []; bcast = Row_var { v = get_row_var (); beg_dims = [] }; id = row_id ~sh_id:id ~kind } + { + dims = []; + bcast = Row_var { v = get_row_var (); beg_dims = [] }; + id = row_id ~sh_id:id ~kind; + } in let batch = match (batch_dims, batch_axes) with @@ -750,7 +789,8 @@ let shape_spec_to_dims_bio labels = in try Row.get_dim ~d:(Int.of_string dim) ~label () with _ -> invalid_arg "shape_spec_to_dims_bio: int expected after '='") - | First label -> Var (Hashtbl.find_or_add dim_var_env label ~default:(fun () -> Row.get_var ~label ())) + | First label -> + Var (Hashtbl.find_or_add dim_var_env label ~default:(fun () -> Row.get_var ~label ())) | Second d -> Row.get_dim ~d () in let row_var_env = Hashtbl.create (module String) in @@ -778,7 +818,10 @@ let to_string_hum ?(style = `Axis_size) (sh : t) = String.concat ~sep:"," @@ List.mapi dims ~f:(fun i d -> let num = - match kind with `Input -> n_batch + n_outputs + i | `Output -> n_batch + i | `Batch -> i + match kind with + | `Input -> n_batch + n_outputs + i + | `Output -> n_batch + i + | `Batch -> i in match style with | `Only_labels | `Axis_size -> Row.dim_to_string style d diff --git a/lib/shape.mli b/lib/shape.mli index abbefd3f..d3a67048 100644 --- a/lib/shape.mli +++ b/lib/shape.mli @@ -7,35 +7,36 @@ - Comes in two variants: single-character and multicharacter; - if there is a comma [','] anywhere in the initial text, the multicharacter version is used, - otherwise the single character version is used. - - Currently, the only non-whitespace, non-alphanumeric characters that make sense / are allowed in a spec - are: ['>', '|', '-', ',', '=', ';']. 
- - identifier: single alphanum character or '_' in single-char mode, a sequence of alphanum characters or - '_' otherwise (whitespace not allowed). + - Currently, the only non-whitespace, non-alphanumeric characters that make sense / are allowed + in a spec are: ['>', '|', '-', ',', '=', ';']. + - identifier: single alphanum character or '_' in single-char mode, a sequence of alphanum + characters or '_' otherwise (whitespace not allowed). - separators: a sequence of commas and whitespaces. - separators_with_comma: commas and whitespaces containing at least one comma. - axes_spec_single_char: separators? identifier+ separators? - axes_spec_multichar: separators? (identifier separators_with_comma)* identifier separators? - ellipsis_spec: '...' <|> '..' identifier '..' - row_spec: axes_spec <|> ellipsis_spec axes_spec <|> axes_spec ellipsis_spec axes_spec - - labels_spec: row_spec <|> row_spec '|' row_spec <|> row_spec '->' row_spec <|> row_spec '|' row_spec - '->' row_spec. + - labels_spec: row_spec <|> row_spec '|' row_spec <|> row_spec '->' row_spec <|> row_spec '|' + row_spec '->' row_spec. - permute_spec: labels_spec '=>' labels_spec - einsum_spec: labels_spec ';' labels_spec '=>' labels_spec - If labels_spec does not contain ["|"] nor ["->"], each label is of the kind [Output]. If the spec doesn't - contain ["|"], labels to the left of ["->"] are [Input] and to the right [Output]. Labels to the left of - ["|"] are [Batch], and between ["|"] and ["->"] are [Input]. + If labels_spec does not contain ["|"] nor ["->"], each label is of the kind [Output]. If the + spec doesn't contain ["|"], labels to the left of ["->"] are [Input] and to the right [Output]. + Labels to the left of ["|"] are [Batch], and between ["|"] and ["->"] are [Input]. - The labels [".."ident".."], ["..."] (where [ident] does not contain any of the special characters) are - only allowed once for a kind. They are used to enable (in-the-middle) broadcasting for the axis kind in - the einsum-related shape inference (like the ellipsis ["..."] in [numpy.einsum]), and are translated to - row variables. The ellipsis ["..."] is context dependent: in the batch row it is the same as - ["..batch.."], in the input row the same as ["..input.."], in the output row the same as ["..output.."]. - When the same row variable is used in multiple rows, the corresponding broadcasted axes are matched - pointwise in the resulting operation. + The labels [".."ident".."], ["..."] (where [ident] does not contain any of the special + characters) are only allowed once for a kind. They are used to enable (in-the-middle) + broadcasting for the axis kind in the einsum-related shape inference (like the ellipsis ["..."] + in [numpy.einsum]), and are translated to row variables. The ellipsis ["..."] is context + dependent: in the batch row it is the same as ["..batch.."], in the input row the same as + ["..input.."], in the output row the same as ["..output.."]. When the same row variable is used + in multiple rows, the corresponding broadcasted axes are matched pointwise in the resulting + operation. - The label ["_"] is a place-holder: it is not output to the resulting map but aligns the axes of other - labels. *) + The label ["_"] is a place-holder: it is not output to the resulting map but aligns the axes of + other labels. 
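+
+    For illustration, hedged examples consistent with the grammar above (not taken from this
+    changeset): ["ik;kj=>ij"] is a single-char einsum_spec whose symmetric difference of RHS
+    pseudo-labels equals the LHS, lining up a matrix-multiply-style contraction over [k]; and
+    ["height, width => width, height"] is a permute_spec where the comma triggers multichar mode,
+    so [height] and [width] are parsed as whole identifiers.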
*) (** {2 User-ish API.} *) @@ -53,18 +54,21 @@ type t = { type deduce_within_shape = Not_constrained | Input_equals_output [@@deriving compare, sexp] type compose_type = - | Pointwise_bin (** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *) + | Pointwise_bin + (** NumPy-style broadcast matching batch, input and output axes, e.g. as in [s1 + s2]. *) | Compose - (** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape of - [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix multiply). *) + (** Compose the outputs of the second shape with the inputs of the first shape, i.e. the shape + of [fun x -> s1(s2(x))], or [s1 * s2] where [*] is the inner product (e.g. matrix + multiply). *) | Einsum of string - (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications. Since - OCANNL's extended einsum notation supports both axis variables and row variables, it makes other - compose types redundant. The [axis_labels] use pseudo-labels local to the notation, to line up the - axes and row variables. The symmetric difference / disjunctive union of RHS1 and RHS2's - pseudo-labels should be equal to LHS pseudo-labels. - - Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs", "rhs1;rhs2=>lhs". *) + (** The binary "einsum" syntax: RHS1;RHS2=>LHS, where RHSi, LHS are labels specifications. + Since OCANNL's extended einsum notation supports both axis variables and row variables, it + makes other compose types redundant. The [axis_labels] use pseudo-labels local to the + notation, to line up the axes and row variables. The symmetric difference / disjunctive + union of RHS1 and RHS2's pseudo-labels should be equal to LHS pseudo-labels. + + Note: The "right-hand-side" is on the left! I.e. the syntax is "rhs=>lhs", + "rhs1;rhs2=>lhs". *) [@@deriving sexp, equal] type transpose_type = @@ -87,31 +91,33 @@ val make : unit -> t (** Creates a shape. [id] should be the id the associated tensor (if any). At most one of the pairs - [batch_dims], [batch_axes] etc. should be given: if none, the corresponding row will be inferred. - [batch_axes] etc. provide labels for the dimensions of the corresponding axes. Note that these are - dimensions labels and not axis labels: they need not be unique for a row, are inferred when provided, and - must match whenever the axis sizes must match. *) + [batch_dims], [batch_axes] etc. should be given: if none, the corresponding row will be + inferred. [batch_axes] etc. provide labels for the dimensions of the corresponding axes. Note + that these are dimensions labels and not axis labels: they need not be unique for a row, are + inferred when provided, and must match whenever the axis sizes must match. *) val to_string_hum : - ?style:[< `Axis_number_and_size | `Axis_size | `Only_labels > `Axis_size `Only_labels ] -> t -> string + ?style:[< `Axis_number_and_size | `Axis_size | `Only_labels > `Axis_size `Only_labels ] -> + t -> + string (** {2 Internal-ish API.} *) -(** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the tensor. Axes - are broadcast-expanded on a bottom-up update to fit the incoming shape. *) +(** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the + tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape. *) type logic = | Broadcast of compose_type * t * t (** Matches the shapes for a binary operation. 
- For [Broadcast (Einsum (ls1, ls2, ls3), s1, s2)], the labels of [s1] and [s2] must match according - to the [ls1], [ls2] lineup, and the resulting shape inherits the labels according to the [ls3] - lineup. *) + For [Broadcast (Einsum (ls1, ls2, ls3), s1, s2)], the labels of [s1] and [s2] must match + according to the [ls1], [ls2] lineup, and the resulting shape inherits the labels + according to the [ls3] lineup. *) | Transpose of transpose_type * t - (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of [s1], hence - the name. *) + (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of + [s1], hence the name. *) | Terminal of Arrayjit.Ops.init_op - (** Extracts any available shape information from the initialization. E.g. for [File_mapped fn], opens - the file [fn] to check its length. *) + (** Extracts any available shape information from the initialization. E.g. for + [File_mapped fn], opens the file [fn] to check its length. *) [@@deriving equal, sexp] type update_id [@@deriving equal, compare, hash, sexp] @@ -119,17 +125,18 @@ type update_id [@@deriving equal, compare, hash, sexp] val get_update_id : unit -> update_id type update_step = { shape : t; logic : logic; id : update_id } [@@deriving sexp] -(** Data required for a shape inference update step. Ideally, an update should be performed at least twice, - the second time after all the other relevant updates have been performed for the first time. In OCANNL, - this is achieved by performing updates both as the tensors are constructed, and via lazy callbacks as the - corresponding [Arrayjit.Indexing] dimensions and projections are first accessed. *) +(** Data required for a shape inference update step. Ideally, an update should be performed at least + twice, the second time after all the other relevant updates have been performed for the first + time. In OCANNL, this is achieved by performing updates both as the tensors are constructed, and + via lazy callbacks as the corresponding [Arrayjit.Indexing] dimensions and projections are first + accessed. *) val to_dims : t -> int array val propagate_shapes : update_step -> unit val derive_projections : update_step -> Arrayjit.Indexing.projections -(** Computes the indexing into subtensors given the shape information of a tensor. [derive_projections] should - only be invoked when the shapes are fully inferred already! *) +(** Computes the indexing into subtensors given the shape information of a tensor. + [derive_projections] should only be invoked when the shapes are fully inferred already! 
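+
+    A minimal driving sketch (hypothetical usage — [update_step] is assumed to have been
+    constructed elsewhere, during tensor building):
+    {[
+      let projections =
+        propagate_shapes update_step;
+        (* All shapes must be final before this call. *)
+        derive_projections update_step
+    ]}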
*) val of_spec : ?deduced:deduce_within_shape -> debug_name:string -> id:int -> string -> t val default_display_indices : t -> int array diff --git a/lib/tensor.ml b/lib/tensor.ml index c8673129..4a3d7f40 100644 --- a/lib/tensor.ml +++ b/lib/tensor.ml @@ -14,7 +14,8 @@ type projections = Arrayjit.Indexing.projections [%%global_debug_log_level Nothing] [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"] -type diff = { grad : (Tn.t[@sexp.opaque]); zero_grads : Asgns.t; backprop : Asgns.t } [@@deriving sexp_of] +type diff = { grad : (Tn.t[@sexp.opaque]); zero_grads : Asgns.t; backprop : Asgns.t } +[@@deriving sexp_of] type t = { forward : Asgns.t; @@ -39,7 +40,9 @@ let rec sexp_of_t t = and sexp_of_subtensor ch = Sexp.message "child" - [ (if ch.embedded then ("", sexp_of_t ch.subtensor) else ("ref-id", sexp_of_int ch.subtensor.id)) ] + [ + (if ch.embedded then ("", sexp_of_t ch.subtensor) else ("ref-id", sexp_of_int ch.subtensor.id)); + ] include Comparator.Make (struct type nonrec t = t @@ -60,7 +63,9 @@ let session_state = let is_fwd_root t = Map.mem session_state.forward_roots t.id let remove_fwd_root t = session_state.forward_roots <- Map.remove session_state.forward_roots t.id let is_bprop_root t = Map.mem session_state.backprop_roots t.id -let remove_bprop_root t = session_state.backprop_roots <- Map.remove session_state.backprop_roots t.id + +let remove_bprop_root t = + session_state.backprop_roots <- Map.remove session_state.backprop_roots t.id let with_unchanged_roots ~f = let fwd_roots = session_state.forward_roots in @@ -93,12 +98,16 @@ let session_error_printer = function let () = Stdlib.Printexc.register_printer session_error_printer let lazy_to_dims shape = lazy (Shape.to_dims shape) -let fetch_zeros array shape = Asgns.Fetch { array; fetch_op = Constant 0.; dims = lazy_to_dims shape } + +let fetch_zeros array shape = + Asgns.Fetch { array; fetch_op = Constant 0.; dims = lazy_to_dims shape } + let default_init_op = Arrayjit.Ops.Constant_fill { values = [| 0.0 |]; strict = false } let max_sublabel_length = ref 25 -let raw_binop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 : t) ~(rhs1_is_grad : bool) - ~(rhs1_is_merge : bool) ~(t2 : t) ~rhs2_is_grad ~rhs2_is_merge ~logic : Asgns.t = +let raw_binop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 : t) + ~(rhs1_is_grad : bool) ~(rhs1_is_merge : bool) ~(t2 : t) ~rhs2_is_grad ~rhs2_is_merge ~logic : + Asgns.t = let shape = t.shape in let shape_logic = Shape.Broadcast (logic, t1.shape, t2.shape) in let local_shape_update = Shape.{ shape; logic = shape_logic; id = get_update_id () } in @@ -111,8 +120,8 @@ let raw_binop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 let rhs2 = if rhs2_is_merge then Asgns.Merge_buffer rhs2 else Node rhs2 in Asgns.Accum_binop { initialize_neutral; accum; lhs; op; rhs1; rhs2; projections } -let raw_unop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 : t) ~(rhs_is_grad : bool) - ~(rhs_is_merge : bool) ~logic = +let raw_unop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 : t) + ~(rhs_is_grad : bool) ~(rhs_is_merge : bool) ~logic = let shape = t.shape in let shape_logic = Shape.Transpose (logic, t1.shape) in let local_shape_update = Shape.{ shape; logic = shape_logic; id = get_update_id () } in @@ -125,15 +134,16 @@ let raw_unop ~initialize_neutral ~accum ~(t : t) ~(lhs_is_grad : bool) ~op ~(t1 type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equal, variants] -let op ~(label : 
string list) ?(compose_op = Shape.Pointwise_bin) ?(transpose_op = Shape.Pointwise_un) - ?(init_op = default_init_op) ~op_asn ~grad_asn ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t - = +let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin) + ?(transpose_op = Shape.Pointwise_un) ?(init_op = default_init_op) ~op_asn ~grad_asn + ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t = let ordered_ts = List.dedup_and_sort orig_ts ~compare:(fun t1 t2 -> Int.ascending t1.id t2.id) in let children = List.folding_map orig_ts ~init:(Set.empty (module Int)) ~f:(fun used ti -> - (Set.add used ti.id, { subtensor = ti; embedded = is_fwd_root ti && not (Set.mem used ti.id) })) + ( Set.add used ti.id, + { subtensor = ti; embedded = is_fwd_root ti && not (Set.mem used ti.id) } )) in let id = session_state.next_id in session_state.next_id <- session_state.next_id + 1; @@ -162,7 +172,8 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin) ?(transpose_op List.iter ordered_ts ~f:(fun ti -> remove_fwd_root ti); if is_prohibit_grad grad_spec - || (Fn.non is_require_grad grad_spec && List.for_all orig_ts ~f:(fun ti -> Option.is_none ti.diff)) + || Fn.non is_require_grad grad_spec + && List.for_all orig_ts ~f:(fun ti -> Option.is_none ti.diff) then ( let tensor = { forward; diff = None; id; value = v; shape; children } in session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor; @@ -181,14 +192,19 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin) ?(transpose_op let is_bck_root ti = Map.mem session_state.backprop_roots ti.id in let zero_grads = let zero_g = dcode ~f:(fun diff -> diff.zero_grads) in - let zeros = List.map ordered_ts ~f:(fun ti -> if is_bck_root ti then zero_g ti else Asgns.Noop) in + let zeros = + List.map ordered_ts ~f:(fun ti -> if is_bck_root ti then zero_g ti else Asgns.Noop) + in Asgns.sequential @@ zeros @ [ fetch_zeros g shape ] in - (* The code needs to be included in the reverse order to which it was computed! This guarantees that all - ancestors of a node are backpropagated before the node is backpropagated, even for non-tree DAGs. *) + (* The code needs to be included in the reverse order to which it was computed! This guarantees + that all ancestors of a node are backpropagated before the node is backpropagated, even for + non-tree DAGs. 
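+       A hedged illustration: if tensors were built in the order [c = a + b] then [d = c * a],
+       reversing [bcks] runs [d]'s backprop before [c]'s, so [c]'s gradient is fully accumulated
+       before [c] propagates into [a] and [b] — even though [a] feeds both [c] and [d].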
*) let backprop = let bprop = dcode ~f:(fun diff -> diff.backprop) in - let bcks = List.map ordered_ts ~f:(fun ti -> if is_bck_root ti then bprop ti else Asgns.Noop) in + let bcks = + List.map ordered_ts ~f:(fun ti -> if is_bck_root ti then bprop ti else Asgns.Noop) + in Asgns.sequential @@ (grad_asn ~v ~g ~projections :: List.rev bcks) in List.iter ordered_ts ~f:(fun ti -> @@ -210,8 +226,8 @@ let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 = let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~projections in op ~label ?compose_op:None ?transpose_op ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1 ] -let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ?deduced - ?init_op ?fetch_op () = +let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes + ?deduced ?init_op ?fetch_op () = let op_asn ~v ~projections = let open Asgns in let dims = lazy (Lazy.force projections).Idx.lhs_dims in @@ -231,8 +247,8 @@ let term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?inp (match fetch_op with | Constant _ | Slice _ | Embed_symbol _ -> () | Imported _ -> - (* Note: [Imported] can be used for merging across devices. But, some use cases of [Imported] will - require a hosted tensor node. *) + (* Note: [Imported] can be used for merging across devices. But, some use cases of + [Imported] will require a hosted tensor node. *) Tn.update_memory_mode v Materialized 22); Fetch { array = v; fetch_op; dims } in @@ -257,8 +273,8 @@ let number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c = Tn.update_memory_mode t.value Effectively_constant 24; t -let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims ?batch_axes - ?input_axes ?output_axes ?(strict = true) values = +let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?output_dims + ?batch_axes ?input_axes ?output_axes ?(strict = true) values = let to_dim_list dims axes = Option.value ~default:[] @@ Option.first_some dims @@ Option.map axes ~f:(List.map ~f:snd) in @@ -271,19 +287,22 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ? let dims = Array.concat_map [| batch_ds; output_ds; input_ds |] ~f:Array.of_list in let ndarr = Nd.create_array Arrayjit.Ops.double ~dims (Constant_fill { values; strict }) in let ( ! 
) = List.length in - Nd.pp_array_inline ~num_batch_axes:!batch_ds ~num_output_axes:!output_ds ~num_input_axes:!input_ds - Stdlib.Format.str_formatter ndarr; + Nd.pp_array_inline ~num_batch_axes:!batch_ds ~num_output_axes:!output_ds + ~num_input_axes:!input_ds Stdlib.Format.str_formatter ndarr; Stdlib.Format.flush_str_formatter () in let op_label = if String.contains op_label '\n' then - "c" ^ Idx.dims_to_string @@ Array.concat_map [| batch_ds; output_ds; input_ds |] ~f:Array.of_list + "c" ^ Idx.dims_to_string + @@ Array.concat_map [| batch_ds; output_ds; input_ds |] ~f:Array.of_list else op_label in let label = op_label :: label in let batch_dims = Option.first_some batch_dims @@ Option.some_if (Option.is_none batch_axes) [] in let input_dims = Option.first_some input_dims @@ Option.some_if (Option.is_none input_axes) [] in - let output_dims = Option.first_some output_dims @@ Option.some_if (Option.is_none output_axes) [] in + let output_dims = + Option.first_some output_dims @@ Option.some_if (Option.is_none output_axes) [] + in let t = term ~label ~grad_spec ?batch_dims ?input_dims ?output_dims ?batch_axes ?input_axes ?output_axes ~deduced:Not_constrained @@ -293,21 +312,23 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ? Tn.update_memory_mode t.value Effectively_constant 24; t -let param ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?(strict = false) ?values label = +let param ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?(strict = false) ?values label + = let init_op = match values with | Some values -> Arrayjit.Ops.Constant_fill { values; strict } | None -> Standard_uniform in let t = - term ~label:[ label ] ~grad_spec:Require_grad ~batch_dims:[] ?input_dims ?output_dims ?input_axes - ?output_axes ?deduced ~init_op () + term ~label:[ label ] ~grad_spec:Require_grad ~batch_dims:[] ?input_dims ?output_dims + ?input_axes ?output_axes ?deduced ~init_op () in let v = t.value in (* It is convenient to use the param syntax for volatiles (mutable inputs). *) Tn.update_memory_mode v (Hosted Nonconstant) 24; - (* In principle, gradients can even be local, if a single jitted block does forward, backprop, and update - computations. Use-cases needing [Materialized] gradients need to request that before any jitting. *) + (* In principle, gradients can even be local, if a single jitted block does forward, backprop, and + update computations. Use-cases needing [Materialized] gradients need to request that before any + jitting. 
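+     E.g. — a hedged sketch, with an arbitrary provenance number — such a use-case could replace
+     the [Never_virtual] request below with [Tn.update_memory_mode g Materialized 27], mirroring
+     how [Materialized] is requested for values in [term] above.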
*) let g = (Option.value_exn ~here:[%here] t.diff).grad in Tn.update_memory_mode g Never_virtual 26; t @@ -321,7 +342,8 @@ let consume_forward_code t = if not @@ is_fwd_root t then raise @@ Session_error - ( "Tensor.consume_forward_code: tensor is not a root for tnode: " ^ debug_name t ~label:t.value.label, + ( "Tensor.consume_forward_code: tensor is not a root for tnode: " + ^ debug_name t ~label:t.value.label, Some t ); let unsafe_roots = Map.data session_state.forward_roots @@ -368,13 +390,16 @@ found potentially unsafe roots: %{String.concat ~sep:", " @@ List.map ~f:debug_n let header t = let v_dims_s = Tn.dims_to_string t.value in - let g_dims_s = match t.diff with None -> "" | Some diff -> Tn.dims_to_string diff.grad in + let g_dims_s = + match t.diff with None -> "" | Some diff -> Tn.dims_to_string diff.grad + in let dims_s = if String.equal v_dims_s g_dims_s then "dims " ^ v_dims_s else "dims val " ^ v_dims_s ^ " grad " ^ g_dims_s in "#" ^ Int.to_string t.id ^ " " ^ Tn.label t.value ^ " " ^ dims_s ^ " [" - ^ String.concat ~sep:"," (List.map t.children ~f:(fun { subtensor = { id; _ }; _ } -> Int.to_string id)) + ^ String.concat ~sep:"," + (List.map t.children ~f:(fun { subtensor = { id; _ }; _ } -> Int.to_string id)) ^ "]" (*^" "^PrintBox_text.to_string (PrintBox.Simple.to_box v.label)*) @@ -385,7 +410,8 @@ let lazy_optional_payload ~present ~missing v = | None -> `Vlist (false, [ `Text (missing ()); `Text "" ]) else `Vlist (false, [ `Text (missing ()); `Text " " ]) -type array_print_style = [ `Default | `Inline | `Label_layout of (string * int) list | `N5_layout of string ] +type array_print_style = + [ `Default | `Inline | `Label_layout of (string * int) list | `N5_layout of string ] let to_dag ?(single_node = false) ?entries_per_axis ~with_shape ~with_id ~with_value ~with_grad t = let rec to_dag { subtensor = t; embedded } : PrintBox_utils.dag = @@ -396,7 +422,8 @@ let to_dag ?(single_node = false) ?entries_per_axis ~with_shape ~with_id ~with_v let where_located a = match a.Tn.memory_mode with | None -> "" - | Some (m, prov) -> [%string "<%{Sexp.to_string_hum @@ Tn.sexp_of_memory_mode m} %{prov#Int}>"] + | Some (m, prov) -> + [%string "<%{Sexp.to_string_hum @@ Tn.sexp_of_memory_mode m} %{prov#Int}>"] in let txt = if with_id then "#" ^ id ^ " " ^ Tn.label t.value (* ^ " DEBUG: " ^ where_located t.value *) @@ -405,7 +432,8 @@ let to_dag ?(single_node = false) ?entries_per_axis ~with_shape ~with_id ~with_v let grad_txt diff = let label = Tn.label diff.grad in let label = - if String.is_substring (String.lowercase label) ~substring:"grad" then label else label ^ " Gradient" + if String.is_substring (String.lowercase label) ~substring:"grad" then label + else label ^ " Gradient" in if with_id then "#" ^ Int.to_string diff.grad.id ^ " " ^ label (* ^ " DEBUG: " ^ where_located diff.grad *) @@ -425,7 +453,8 @@ let to_dag ?(single_node = false) ?entries_per_axis ~with_shape ~with_id ~with_v let node = lazy_optional_payload t.value.array ~present:(fun v_array -> - `Box (Nd.render_array ~brief:true ~prefix:txt ?entries_per_axis ~labels ~indices v_array)) + `Box + (Nd.render_array ~brief:true ~prefix:txt ?entries_per_axis ~labels ~indices v_array)) ~missing:(fun () -> txt ^ " " ^ where_located t.value) in `Subtree_with_ID (id, `Tree (add_shape [ node ], children)) @@ -443,15 +472,17 @@ let to_dag ?(single_node = false) ?entries_per_axis ~with_shape ~with_id ~with_v let value = lazy_optional_payload t.value.array ~present:(fun v_array -> - `Box (Nd.render_array ~brief:true ~prefix:txt 
?entries_per_axis ~labels ~indices v_array)) + `Box + (Nd.render_array ~brief:true ~prefix:txt ?entries_per_axis ~labels ~indices + v_array)) ~missing:(fun () -> txt ^ " " ^ where_located t.value) in let grad = lazy_optional_payload diff.grad.array ~present:(fun g_array -> `Box - (Nd.render_array ~brief:true ~prefix:(grad_txt diff) ?entries_per_axis ~labels ~indices - g_array)) + (Nd.render_array ~brief:true ~prefix:(grad_txt diff) ?entries_per_axis ~labels + ~indices g_array)) ~missing:(fun () -> grad_txt diff ^ " " ^ where_located diff.grad) in `Vlist (false, [ value; grad ]) @@ -460,12 +491,13 @@ let to_dag ?(single_node = false) ?entries_per_axis ~with_shape ~with_id ~with_v in to_dag { subtensor = t; embedded = true } -let to_printbox ?single_node ?entries_per_axis ?(with_id = false) ?(with_shape = false) ?(with_value = true) - ~with_grad ~depth t = +let to_printbox ?single_node ?entries_per_axis ?(with_id = false) ?(with_shape = false) + ?(with_value = true) ~with_grad ~depth t = to_dag ?single_node ?entries_per_axis ~with_id ~with_shape ~with_value ~with_grad t |> PrintBox_utils.reformat_dag depth -let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) (style : array_print_style) t = +let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) + (style : array_print_style) t = let sh = t.shape in let label = Tn.label t.value in let prefix = @@ -475,7 +507,8 @@ let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) (styl in let grad_txt diff = let label = Tn.label diff.grad in - if String.is_substring (String.lowercase label) ~substring:"grad" then label else label ^ " Gradient" + if String.is_substring (String.lowercase label) ~substring:"grad" then label + else label ^ " Gradient" in let labels = Shape.to_labels t.shape in let indices = @@ -494,7 +527,8 @@ let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) (styl in let inv_labels = match inv_labels with - | `Duplicate_key l -> raise @@ Session_error ("`Label_layout found a repeating label: " ^ l, Some t) + | `Duplicate_key l -> + raise @@ Session_error ("`Label_layout found a repeating label: " ^ l, Some t) | `Ok inv_labels -> inv_labels in let result = Array.create ~len:(Array.length labels) 0 in @@ -519,8 +553,9 @@ let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) (styl match (style, t.value.array) with | `Inline, (lazy None) -> Stdlib.Format.printf "@ " | `Inline, (lazy (Some arr)) -> - Nd.pp_array_inline (Stdlib.Format.get_std_formatter ()) ~num_batch_axes ~num_input_axes - ~num_output_axes ?axes_spec arr + Nd.pp_array_inline + (Stdlib.Format.get_std_formatter ()) + ~num_batch_axes ~num_input_axes ~num_output_axes ?axes_spec arr | _, (lazy None) -> Stdlib.Format.printf "@ " | _, (lazy (Some arr)) -> Nd.pp_array (Stdlib.Format.get_std_formatter ()) ~prefix ~labels ~indices arr; @@ -532,11 +567,13 @@ let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) (styl else match (style, diff.grad.array) with | `Inline, (lazy (Some arr)) -> - Nd.pp_array_inline (Stdlib.Format.get_std_formatter ()) ~num_batch_axes ~num_input_axes - ~num_output_axes ?axes_spec arr; + Nd.pp_array_inline + (Stdlib.Format.get_std_formatter ()) + ~num_batch_axes ~num_input_axes ~num_output_axes ?axes_spec arr; Stdlib.Format.print_newline () | _, (lazy (Some arr)) -> - Nd.pp_array (Stdlib.Format.get_std_formatter ()) + Nd.pp_array + (Stdlib.Format.get_std_formatter ()) ~prefix:(prefix ^ " " ^ grad_txt diff) ~labels ~indices arr; 
Stdlib.Format.print_newline () @@ -544,7 +581,8 @@ let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) (styl if with_code then ( (match t.forward with | Noop -> () - | fwd_code -> Stdlib.Format.printf "@[Current forward body:%a@]@," (Asgns.fprint_hum ()) fwd_code); + | fwd_code -> + Stdlib.Format.printf "@[Current forward body:%a@]@," (Asgns.fprint_hum ()) fwd_code); match t.diff with | Some { backprop = Noop; _ } -> () | Some { backprop = bwd_code; _ } -> @@ -554,7 +592,8 @@ let print ~with_grad ~with_code ?(force = false) ?(with_low_level = false) (styl (match t.forward with | Noop -> () | fwd_code -> - Stdlib.Format.printf "@[Current forward low-level body:%a@]@," (Arrayjit.Low_level.fprint_hum ()) + Stdlib.Format.printf "@[Current forward low-level body:%a@]@," + (Arrayjit.Low_level.fprint_hum ()) @@ Asgns.to_low_level fwd_code); match t.diff with | Some { backprop = Noop; _ } -> () @@ -570,8 +609,8 @@ let print_forward_roots ~with_grad ~with_code (style : array_print_style) = assert (id = root.id); print ~with_grad ~with_code style root) -let print_tree ?entries_per_axis ?(with_backend_info = false) ?(with_id = true) ?(with_shape = false) - ?(with_value = true) ~with_grad ~depth t = +let print_tree ?entries_per_axis ?(with_backend_info = false) ?(with_id = true) + ?(with_shape = false) ?(with_value = true) ~with_grad ~depth t = (* FIXME: print backend info *) ignore with_backend_info; PrintBox_text.output Stdio.stdout @@ PrintBox_utils.dag_to_box @@ PrintBox_utils.boxify depth @@ -596,7 +635,8 @@ let grad_2d_points ?from_axis ~xdim ~ydim t = match t.diff with | None -> [||] | Some diff -> - Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_2d_points ?from_axis ~xdim ~ydim arr) + Option.value_map ~default:[||] ~f:(fun arr -> + Nd.retrieve_2d_points ?from_axis ~xdim ~ydim arr) @@ Lazy.force diff.grad.array let set_value t = Nd.set_from_float @@ Option.value_exn ~here:[%here] @@ Lazy.force t.value.array @@ -618,4 +658,5 @@ let set_values t values = @@ Option.value_exn ~here:[%here] @@ Lazy.force t.value.array) -let get_values t = Nd.(retrieve_flat_values @@ Option.value_exn ~here:[%here] @@ Lazy.force t.value.array) +let get_values t = + Nd.(retrieve_flat_values @@ Option.value_exn ~here:[%here] @@ Lazy.force t.value.array) diff --git a/lib/tensor.mli b/lib/tensor.mli index 89fbae26..bc80b2ea 100644 --- a/lib/tensor.mli +++ b/lib/tensor.mli @@ -10,10 +10,11 @@ type projections = Arrayjit.Indexing.projections type diff = { grad : tn; - zero_grads : asgns; (** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *) + zero_grads : asgns; + (** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *) backprop : asgns; - (** Backpropagates for the tensor and its descendants; which typically means adding partial gradients to - the gradient tensor of the subtensors, then for sub-subtensors etc. *) + (** Backpropagates for the tensor and its descendants, which typically means adding partial + gradients to the gradient tensor of the subtensors, then for sub-subtensors, etc. *) } type t = { @@ -22,8 +23,8 @@ type t = { id : int; (** Same as [value.id]. *) value : tn; shape : Shape.t; - (** The eventual shape of [t.value] and [t.diff.grad], incorporating the current state of shape - inference. *) + (** The eventual shape of [t.value] and [t.diff.grad], incorporating the current state of + shape inference. 
*) children : subtensor list; } [@@deriving sexp_of] @@ -122,12 +123,12 @@ val term : ?fetch_op:(v:tn -> fetch_op) -> unit -> t -(** A terminal: a constant, a parameter, an input of the model. The semantics of shape specification is the - same as in {!Shape.make}, and by default the shape will be inferred. *) +(** A terminal: a constant, a parameter, an input of the model. The semantics of shape specification + is the same as in {!Shape.make}, and by default the shape will be inferred. *) val number : ?label:string list -> ?axis_label:string -> ?grad_spec:grad_spec -> float -> t -(** A number: a tensor with a single axis of one dimension, initialized to the given value. [grad_spec] is by - default [Prohibit_grad]. *) +(** A number: a tensor with a single axis of one dimension, initialized to the given value. + [grad_spec] is by default [Prohibit_grad]. *) val ndarray : ?label:string list -> @@ -141,10 +142,10 @@ val ndarray : ?strict:bool -> float array -> t -(** A tensor with an explicit shape, initialized to the given values. Omitted shape rows default to no axes. - [grad_spec] is by default [Prohibit_grad]. If [strict] is [true] (the default), the given values must fill - the tensor's [value] node precisely; otherwise, the values will be looped over to populate the [value] - node. *) +(** A tensor with an explicit shape, initialized to the given values. Omitted shape rows default to + no axes. [grad_spec] is by default [Prohibit_grad]. If [strict] is [true] (the default), the + given values must fill the tensor's [value] node precisely; otherwise, the values will be looped + over to populate the [value] node. *) val param : ?input_dims:int list -> @@ -163,14 +164,14 @@ val iter_embedded_arrays : f:(tn -> unit) -> t -> unit val consume_forward_code : t -> asgns (** A forward root is a tensor that is not (currently) used to compute another tensor. - [consume_forward_code t] ensures [t] is a forward root, removes it from forward roots, and checks that - there are no other forward roots for tensors with children. *) + [consume_forward_code t] ensures [t] is a forward root, removes it from forward roots, and + checks that there are no other forward roots for tensors with children. *) val consume_backprop_code : t -> asgns * asgns -(** A backprop root is a tensor with a gradient that is not (currently) receiving gradients from another - tensor. I.e. it is not currently used to compute a tensor with a gradient. [consume_backprop_code t] - ensures [t] is a backprop root, removes it from backprop roots, and checks that there are no other - backprop roots for tensors with children. *) +(** A backprop root is a tensor with a gradient that is not (currently) receiving gradients from + another tensor. I.e. it is not currently used to compute a tensor with a gradient. + [consume_backprop_code t] ensures [t] is a backprop root, removes it from backprop roots, and + checks that there are no other backprop roots for tensors with children. *) (** {2 Printing.} *) @@ -179,34 +180,36 @@ val header : t -> string type array_print_style = [ `Default - (** The inner rectangles comprise both an input and an output axis, if available. Similarly, the outer - rectangle comprises a second-from-end input axis and a second-from-end output axis, if available. At - least one batch axis is output, when available. The axes that couldn't be output are printed at - position/dimension [0]. *) + (** The inner rectangles comprise both an input and an output axis, if available. 
Similarly, the + outer rectangle comprises a second-from-end input axis and a second-from-end output axis, if + available. At least one batch axis is output, when available. The axes that couldn't be + output are printed at position/dimension [0]. *) | `N5_layout of string - (** The string should provide exclusively non-negative integer pseudo-labels. The numbers [0]-[4] - represent the priorities of the axes to be printed out, where the priorities correspond to, from - highest: horizontal, vertical direction of the inner rectangle, horizontal, vertical direction of the - outer rectangle, repetition (see also [Node.pp_print]). The numbers [n >= 5] stand for the actual - positions [n - 5] within the corresponding axes. *) + (** The string should provide exclusively non-negative integer pseudo-labels. The numbers + [0]-[4] represent the priorities of the axes to be printed out, where the priorities + correspond to, from highest: horizontal, vertical direction of the inner rectangle, + horizontal, vertical direction of the outer rectangle, repetition (see also + [Node.pp_print]). The numbers [n >= 5] stand for the actual positions [n - 5] within the + corresponding axes. *) | `Label_layout of (string * int) list - (** The association from axis labels to integers. The negative numbers [-5] to [-1] represent the - priorities of the axes to be printed out, where the priorities correspond to, from highest: - horizontal, vertical direction of the inner rectangle, horizontal, vertical direction of the outer - rectangle, repetition (as above). The numbers [n >= 0] stand for the actual positions within the - corresponding axes. Unspecified axes are printed at position [0]. *) + (** The association from axis labels to integers. The negative numbers [-5] to [-1] represent + the priorities of the axes to be printed out, where the priorities correspond to, from + highest: horizontal, vertical direction of the inner rectangle, horizontal, vertical + direction of the outer rectangle, repetition (as above). The numbers [n >= 0] stand for the + actual positions within the corresponding axes. Unspecified axes are printed at position + [0]. *) | `Inline (** The tensors are printed linearly, in a bracketed manner, optionally prefixed with the labels - specification. Note that the syntax causes ambiguity for 1-dimensional input axes (underscores are - used for axes without explicit labels); when there is a 1-dimensional input axis, we output the labels - specification even if there are no axis labels as a way to display the number of axes. The axis - nesting is right-to-left (rightmost is innermost). The input axes are innermost and the batch axes - outermost. The input axes use [,] as a separator and [()] as axis delimiters, but the delimiter for - the outermost (i.e. leftmost) axis is omitted. The output axes use [;] as a separator and [[]] as axis - delimiters (obligatory). The batch axes use [;] as a separator and [[||]] as axis delimiters - (obligatory). *) ] -(** We print out up to 5 axes when printing a tensor, as a grid (outer rectangle) of (inner) rectangles, - possibly repeated (screens). *) + specification. Note that the syntax causes ambiguity for 1-dimensional input axes + (underscores are used for axes without explicit labels); when there is a 1-dimensional input + axis, we output the labels specification even if there are no axis labels as a way to + display the number of axes. The axis nesting is right-to-left (rightmost is innermost). The + input axes are innermost and the batch axes outermost. 
The input axes use [,] as a separator + and [()] as axis delimiters, but the delimiter for the outermost (i.e. leftmost) axis is + omitted. The output axes use [;] as a separator and [[]] as axis delimiters (obligatory). + The batch axes use [;] as a separator and [[||]] as axis delimiters (obligatory). *) ] +(** We print out up to 5 axes when printing a tensor, as a grid (outer rectangle) of (inner) + rectangles, possibly repeated (screens). *) val to_printbox : ?single_node:bool -> @@ -220,7 +223,13 @@ val to_printbox : PrintBox.t val print : - with_grad:bool -> with_code:bool -> ?force:bool -> ?with_low_level:bool -> array_print_style -> t -> unit + with_grad:bool -> + with_code:bool -> + ?force:bool -> + ?with_low_level:bool -> + array_print_style -> + t -> + unit val print_forward_roots : with_grad:bool -> with_code:bool -> array_print_style -> unit diff --git a/lib/train.ml b/lib/train.ml index 7fa8055f..3129293c 100644 --- a/lib/train.ml +++ b/lib/train.ml @@ -60,7 +60,9 @@ let fresh_backend ?backend_name ?(config = Arrayjit.Backend_types.Physical_devic backend let is_param t = - match t with { Tensor.children = []; diff = Some _; _ } -> not @@ Tn.known_not_param t.value | _ -> false + match t with + | { Tensor.children = []; diff = Some _; _ } -> not @@ Tn.known_not_param t.value + | _ -> false let get_params t = let rec loop accu { Tensor.subtensor = t; _ } = @@ -70,14 +72,16 @@ let get_params t = let save_params t = let file_name = - Option.value_or_thunk ~default:(fun () -> invalid_arg "Train.save_params: root tensor is not named") + Option.value_or_thunk ~default:(fun () -> + invalid_arg "Train.save_params: root tensor is not named") @@ Tn.ident_label t.Tensor.value in let with_name p = let v = p.Tensor.value in ( v, Option.value_or_thunk ~default:(fun () -> - invalid_arg @@ "Train.save_params: parameter is not named: " ^ Tn.name v ^ " " ^ Tn.label v) + invalid_arg @@ "Train.save_params: parameter is not named: " ^ Tn.name v ^ " " + ^ Tn.label v) @@ Tn.ident_label v ) in let with_names = get_params t |> Set.elements |> List.map ~f:with_name in @@ -88,14 +92,16 @@ let save_params t = let restore_params t = let file_name = - Option.value_or_thunk ~default:(fun () -> invalid_arg "Train.restore_params: root tensor is not named") + Option.value_or_thunk ~default:(fun () -> + invalid_arg "Train.restore_params: root tensor is not named") @@ Tn.ident_label t.Tensor.value in let with_name p = let v = p.Tensor.value in ( v, Option.value_or_thunk ~default:(fun () -> - invalid_arg @@ "Train.restore_params: parameter is not named: " ^ Tn.name v ^ " " ^ Tn.label v) + invalid_arg @@ "Train.restore_params: parameter is not named: " ^ Tn.name v ^ " " + ^ Tn.label v) @@ Tn.ident_label v ) in let with_names = get_params t |> Set.elements |> List.map ~f:with_name in @@ -116,8 +122,8 @@ let label_suffix label = @@ List.find ~f:(String.for_all ~f:(fun c -> Char.is_alphanum c || equal_char '_' c)) @@ List.rev label -(** Sets the tensor's value as "fully on host", returns the tensor's forward code with a label-derived - comment. *) +(** Sets the tensor's value as "fully on host", returns the tensor's forward code with a + label-derived comment. *) let forward ?(disable_rootness_check = false) t = let fwd = if disable_rootness_check then t.Tensor.forward else Tensor.consume_forward_code t in set_hosted t.Tensor.value; @@ -131,16 +137,18 @@ type updaten = { fwd_bprop : Asgns.t; } -(** Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived comments. 
- Sets the tensor's value as "fully on host". If [setup_for_parallel] is true (false by default), sets the - parameters and their gradients as "non-local" (on-device). *) +(** Returns the tensor's forward, zeroing gradients, and backprop code wrapped with label-derived + comments. Sets the tensor's value as "fully on host". If [setup_for_parallel] is true (false by + default), sets the parameters and their gradients as "non-local" (on-device). *) let grad_update ?(disable_rootness_check = false) ?(setup_for_parallel = false) loss = set_hosted loss.Tensor.value; let params = get_params loss in if setup_for_parallel then Set.iter params ~f:(fun p -> set_materialized (Option.value_exn ~here:[%here] p.diff).grad); let label = label_suffix loss.value.label in - let fwd = if disable_rootness_check then loss.Tensor.forward else Tensor.consume_forward_code loss in + let fwd = + if disable_rootness_check then loss.Tensor.forward else Tensor.consume_forward_code loss + in let fwd_bprop = match loss.Tensor.diff with | Some diff -> @@ -160,7 +168,8 @@ let grad_update ?(disable_rootness_check = false) ?(setup_for_parallel = false) init_grad; Block_comment (label ^ " bprop", bprop); ] )) - | None -> raise @@ Tensor.Session_error ("Train.grad_update: tensor is not differentiable", Some loss) + | None -> + raise @@ Tensor.Session_error ("Train.grad_update: tensor is not differentiable", Some loss) in { loss; label; params; fwd_bprop } @@ -186,8 +195,8 @@ let sgd_update ~learning_rate ?momentum ?weight_decay ?nesterov l = in Asgns.Block_comment (l.label ^ " sgd update", code) -(** All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings - without ranges remain at their initial values. *) +(** All and only bindings with associated ranges are iterated, with the binding's initial value + lost. Bindings without ranges remain at their initial values. *) let%track_sexp sequential_loop ~f lowered_bindings = let rec loop = function | [] -> f () @@ -202,10 +211,10 @@ let%track_sexp sequential_loop ~f lowered_bindings = in loop lowered_bindings -(** Distributes iterated indices to workers in a round-robin fashion. All and only bindings with associated - ranges are iterated, with the binding's initial value lost. Bindings without ranges remain at their - initial values. [sync] is called after each round of calling all workers, and at the end if needed, with - the number of workers called during the round. *) +(** Distributes iterated indices to workers in a round-robin fashion. All and only bindings with + associated ranges are iterated, with the binding's initial value lost. Bindings without ranges + remain at their initial values. [sync] is called after each round of calling all workers, and at + the end if needed, with the number of workers called during the round. 
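+
+    A minimal usage sketch, mirroring how [parallel_update] below drives this function (the
+    [routines] array, the [bindings] argument, and the [sync] callback are hypothetical
+    placeholders):
+    {[
+      let fs = Array.map routines ~f:(fun r () -> Tn.run debug_rt r.schedule) in
+      let lowered_bindings = Array.map routines ~f:(fun r -> r.bindings) in
+      round_robin fs lowered_bindings bindings
+        ~sync:(fun num_in_round -> Stdio.printf "synced %d workers\n%!" num_in_round)
+    ]}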
*) let%track_sexp round_robin fs parallel_jitbs jitbs ~sync : unit = let num_devices : int = Array.length fs in assert (Array.length parallel_jitbs = num_devices); @@ -255,22 +264,23 @@ let every_non_literal_on_host = Tensor.iter_embedded_arrays ~f:(fun a -> if Tn.mode_is_unspecified a && not (Tn.known_constant a) then set_hosted a) -let%debug_sexp all_host_to_device (type context) (module Backend : Backend_type with type context = context) - context = +let%debug_sexp all_host_to_device (type context) + (module Backend : Backend_type with type context = context) context = let f tn = ignore (Backend.from_host context tn : bool) in Tensor.iter_embedded_arrays ~f -let%debug_sexp all_device_to_host (type context) (module Backend : Backend_type with type context = context) - context = +let%debug_sexp all_device_to_host (type context) + (module Backend : Backend_type with type context = context) context = let f tn = ignore (Backend.to_host context tn : bool) in Tensor.iter_embedded_arrays ~f -(** Executes the jitted code and copies arrays embedded in the given tenosor from and to host, synchronizes - before copying to host. If [looping] is provided, loops over bindings and executes the given function - inside the loop after a run. All and only bindings with associated ranges are iterated, with the binding's - initial value lost. Bindings without ranges remain at their initial values. *) -let%track_sexp sync_run ?looping (type context) (module Backend : Backend_type with type context = context) - (routine : Backend.routine) t = +(** Executes the jitted code and copies arrays embedded in the given tensor from and to host, + synchronizes before copying to host. If [looping] is provided, loops over bindings and executes + the given function inside the loop after a run. All and only bindings with associated ranges are + iterated, with the binding's initial value lost. Bindings without ranges remain at their initial + values. *) +let%track_sexp sync_run ?looping (type context) + (module Backend : Backend_type with type context = context) (routine : Backend.routine) t = all_host_to_device (module Backend) routine.context t; (match looping with | None -> Tn.run debug_rt routine.schedule @@ -285,25 +295,27 @@ let%track_sexp sync_run ?looping (type context) (module Backend : Backend_type w module Lazy = Utils.Lazy -(** Performs one optimization step, potentially in parallel (if [grad_updates] are compiled for different - devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of - [grad_updates] in a round-robin fashion, and performs the following synchronization each time all - [grad_updates] have been called: - - 1. merges all gradients into the device of [grad_updates.(0)], 2. calls [sgd_update], 3. copies all - parameters from the [grad_updates.(0)] device to the other devices, if needed, 4. calls [post_sync] with - the number of devices synced since the previous sync. - - All and only bindings with associated ranges are iterated, with the binding's initial value lost. Bindings - without ranges remain at their initial values. *) -let%track_sexp parallel_update (type context) (module Backend : Backend_type with type context = context) - ~(grad_updates : Backend.routine array) ~(sgd_update : Backend.routine) ~post_sync updaten : unit -> unit - = +(** Performs one optimization step, potentially in parallel (if [grad_updates] are compiled for + different devices). All jitted code must have the same bindings. 
Iterates over bindings with + ranges, calling one of [grad_updates] in a round-robin fashion, and performs the following + synchronization each time all [grad_updates] have been called: + + 1. merges all gradients into the device of [grad_updates.(0)], 2. calls [sgd_update], 3. copies + all parameters from the [grad_updates.(0)] device to the other devices, if needed, 4. calls + [post_sync] with the number of devices synced since the previous sync. + + All and only bindings with associated ranges are iterated, with the binding's initial value + lost. Bindings without ranges remain at their initial values. *) +let%track_sexp parallel_update (type context) + (module Backend : Backend_type with type context = context) + ~(grad_updates : Backend.routine array) ~(sgd_update : Backend.routine) ~post_sync updaten : + unit -> unit = assert (not @@ Array.is_empty grad_updates); let num_devices : int = Array.length grad_updates in let bindings : Idx.static_symbol list = List.map ~f:fst sgd_update.bindings in let occupancies = Array.init num_devices ~f:(fun _ -> Array.create ~len:num_devices false) in - (* to_, from positions correspond to the contexts (and devices) of grad_updates at the position. *) + (* to_, from positions correspond to the contexts (and devices) of grad_updates at the + position. *) let dry_merge ~from ~to_ = occupancies.(from).(to_) <- true in let dry_sync devices_to_sync = Arrayjit.Utils.parallel_merge dry_merge devices_to_sync in round_robin_dry_run ~num_devices sgd_update.bindings ~dry_sync; @@ -322,7 +334,8 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit in let grad_merges_to = Array.map ctxs ~f:(fun ctx -> - snd @@ Backend.link_batch ctx @@ Backend.compile_batch ~shared:true ~occupancy Idx.Empty grad_merges) + snd @@ Backend.link_batch ctx + @@ Backend.compile_batch ~shared:true ~occupancy Idx.Empty grad_merges) in (* We can cache scheduling, because merging and copying does not depend on static indexing. *) let loss_merge = @@ -340,12 +353,14 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit Array.iteri all_params ~f:(fun i p -> let grad_merge = Option.value_exn ~here:[%here] grad_merges_to.(to_).(i) in assert ( - Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer:Copy - ~dst:grad_merge.context ~src:ctxs.(from)); + Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad + ~into_merge_buffer:Copy ~dst:grad_merge.context ~src:ctxs.(from)); (Tn.run debug_rt grad_merge.schedule : unit)) in let merge_loss ~src = - assert (Backend.device_to_device updaten.loss.value ~into_merge_buffer:Copy ~dst:loss_merge.context ~src); + assert ( + Backend.device_to_device updaten.loss.value ~into_merge_buffer:Copy ~dst:loss_merge.context + ~src); Tn.run debug_rt loss_merge.schedule in (* FIXME: missing backcopy. 
*) @@ -361,12 +376,15 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit for to_ = 1 to num_devices - 1 do Array.iter all_params ~f:(fun p -> assert ( - Backend.device_to_device p.value ~into_merge_buffer:No ~dst:ctxs.(to_) ~src:sgd_update.context)) + Backend.device_to_device p.value ~into_merge_buffer:No ~dst:ctxs.(to_) + ~src:sgd_update.context)) done; post_sync ~num_synced_devices:devices_to_sync in let lowered_bindings = [%debug_notrace Array.map grad_updates ~f:(fun upd -> upd.bindings)] in - let fs = [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Tn.run debug_rt upd.schedule)] in + let fs = + [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Tn.run debug_rt upd.schedule)] + in fun () -> round_robin fs lowered_bindings sgd_update.bindings ~sync let debug_name t = Tn.(debug_name ~id:t.Tensor.value.id ~label:t.value.label) @@ -386,8 +404,8 @@ let get_all_suggested_devices (type device) ?max_num_devices |> Array.concat_map ~f:Fn.id let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init_lr ?lr_schedule - ?max_num_devices ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn ~weight_decay ?per_batch_callback - ?per_epoch_callback (backend : (module Backend_type)) () = + ?max_num_devices ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn ~weight_decay + ?per_batch_callback ?per_epoch_callback (backend : (module Backend_type)) () = let module TDSL = Operation.TDSL in let module NTDSL = Operation.NTDSL in Rand.init seed; @@ -451,21 +469,26 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init learning_rates := learning_rate.@[0] :: !learning_rates; epoch_losses := !epoch_loss :: !epoch_losses; Option.iter per_epoch_callback ~f:(fun f -> - f ~at_step:!step_ref ~at_epoch:epoch ~learning_rate:learning_rate.@[0] ~epoch_loss:!epoch_loss) + f ~at_step:!step_ref ~at_epoch:epoch ~learning_rate:learning_rate.@[0] + ~epoch_loss:!epoch_loss) done; let%op model_result = model "infer" in let infer_fwd = - if disable_rootness_check then model_result.Tensor.forward else Tensor.consume_forward_code model_result + if disable_rootness_check then model_result.Tensor.forward + else Tensor.consume_forward_code model_result in set_on_host Volatile model_result.Tensor.value; - (* By using sgd_update.context here, maybe we don't need to copy the parameters back to the host. *) + (* By using sgd_update.context here, maybe we don't need to copy the parameters back to the + host. *) let routine = Backend.( - link sgd_update.context @@ compile IDX.empty @@ Block_comment (debug_name model_result, infer_fwd)) + link sgd_update.context @@ compile IDX.empty + @@ Block_comment (debug_name model_result, infer_fwd)) in let infer_callback values = Tensor.set_values infer values; - (* For the gccjit backend, infer is only on host, not on device. For cuda, this will be needed. *) + (* For the gccjit backend, infer is only on host, not on device. For cuda, this will be + needed. 
*) assert (Backend.from_host routine.context infer.value); run routine; assert (Backend.to_host routine.context model_result.value); diff --git a/npy/src/npy.ml b/npy/src/npy.ml index d49dd98c..b7305b49 100644 --- a/npy/src/npy.ml +++ b/npy/src/npy.ml @@ -32,7 +32,9 @@ let dtype ~packed_kind = let map_file file_descr ~pos kind layout shared shape = let is_scalar = Array.length shape = 0 in - let array = Unix.map_file file_descr ~pos kind layout shared (if is_scalar then [| 1 |] else shape) in + let array = + Unix.map_file file_descr ~pos kind layout shared (if is_scalar then [| 1 |] else shape) + in if is_scalar then Bigarray.reshape array [||] else array let fortran_order (type a) ~(layout : a Bigarray.layout) = @@ -76,7 +78,8 @@ let with_file filename flags mask ~f = let write ?header_len bigarray filename = with_file filename [ O_CREAT; O_TRUNC; O_RDWR ] 0o640 ~f:(fun file_descr -> let full_header = - full_header () ?header_len ~layout:(Bigarray.Genarray.layout bigarray) + full_header () ?header_len + ~layout:(Bigarray.Genarray.layout bigarray) ~packed_kind:(P (Bigarray.Genarray.kind bigarray)) ~dims:(Bigarray.Genarray.dims bigarray) in @@ -85,7 +88,8 @@ let write ?header_len bigarray filename = raise Cannot_write; let file_array = map_file ~pos:(Int64.of_int full_header_len) file_descr (Bigarray.Genarray.kind bigarray) - (Bigarray.Genarray.layout bigarray) true (Bigarray.Genarray.dims bigarray) + (Bigarray.Genarray.layout bigarray) + true (Bigarray.Genarray.dims bigarray) in Bigarray.Genarray.blit bigarray file_array) @@ -106,8 +110,9 @@ module Batch_writer = struct let file_array = map_file ~pos:(Int64.of_int t.bytes_written_so_far) - t.file_descr (Bigarray.Genarray.kind bigarray) (Bigarray.Genarray.layout bigarray) true - (Bigarray.Genarray.dims bigarray) + t.file_descr (Bigarray.Genarray.kind bigarray) + (Bigarray.Genarray.layout bigarray) + true (Bigarray.Genarray.dims bigarray) in Bigarray.Genarray.blit bigarray file_array; let size_in_bytes = Bigarray.Genarray.size_in_bytes bigarray in @@ -125,7 +130,8 @@ module Batch_writer = struct | _ :: d, _ :: d' -> d <> d' in if incorrect_dimensions then - Printf.sprintf "Incorrect dimensions %s vs %s." (shape ~dims) (shape ~dims:dims') |> failwith; + Printf.sprintf "Incorrect dimensions %s vs %s." 
(shape ~dims) (shape ~dims:dims') + |> failwith; dims.(0) <- dims.(0) + dims'.(0) let create filename = @@ -147,7 +153,8 @@ let really_read fd len = let buffer = Bytes.create len in let rec loop offset = let read = Unix.read fd buffer offset (len - offset) in - if read + offset < len then loop (read + offset) else if read = 0 then read_error "unexpected eof" + if read + offset < len then loop (read + offset) + else if read = 0 then read_error "unexpected eof" in loop 0; Bytes.to_string buffer @@ -193,7 +200,8 @@ module Header = struct |> List.filter (fun s -> String.length s > 0) |> List.map (fun header_field -> match split header_field ~on:':' with - | [ name; value ] -> (trim name ~on:[ '\''; ' ' ], trim value ~on:[ '\''; ' '; '('; ')' ]) + | [ name; value ] -> + (trim name ~on:[ '\''; ' ' ], trim value ~on:[ '\''; ' '; '('; ')' ]) | _ -> read_error "unable to parse field %s" header_field) in let find_field field = diff --git a/npy/src/npy.mli b/npy/src/npy.mli index cf6540d8..091d8fe1 100644 --- a/npy/src/npy.mli +++ b/npy/src/npy.mli @@ -4,12 +4,10 @@ val write1 : ('a, 'b, 'c) Bigarray.Array1.t -> string -> unit val write2 : ('a, 'b, 'c) Bigarray.Array2.t -> string -> unit val write3 : ('a, 'b, 'c) Bigarray.Array3.t -> string -> unit -(** [write ?header_len bigarray filename] writes a npy file [filename] - with the content of [bigarray]. - [header_len] can be used to override the npy header length. This is - only useful for testing. -*) val write : ?header_len:int -> ('a, 'b, 'c) Bigarray.Genarray.t -> string -> unit +(** [write ?header_len bigarray filename] writes an npy file [filename] with the content of + [bigarray]. [header_len] can be used to override the npy header length. This is only useful for + testing. *) module Batch_writer : sig type t @@ -24,10 +22,9 @@ type packed_array1 = P1 : (_, _, _) Bigarray.Array1.t -> packed_array1 type packed_array2 = P2 : (_, _, _) Bigarray.Array2.t -> packed_array2 type packed_array3 = P3 : (_, _, _) Bigarray.Array3.t -> packed_array3 -(** [read_mmap filename ~shared] returns a packed bigarray mmaped to the content - of [filename]. If [shared] is [true] modifications made to the array are reflected - to the file. *) val read_mmap : string -> shared:bool -> packed_array +(** [read_mmap filename ~shared] returns a packed bigarray mmapped to the content of [filename]. If + [shared] is [true], modifications made to the array are reflected in the file. *) val read_mmap1 : string -> shared:bool -> packed_array1 val read_mmap2 : string -> shared:bool -> packed_array2 @@ -43,9 +40,8 @@ module Npz : sig val open_in : string -> in_file val read : ?suffix:string -> in_file -> string -> packed_array + val restore : ?suffix:string -> in_file -> string -> ('a, 'b, 'c) Bigarray.Genarray.t -> unit (** Like {!read}, but stores the data directly in the provided bigarray. 
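+
+      A hedged sketch of the intended call pattern (the archive name, entry name, and the
+      pre-allocated [dst] bigarray are hypothetical):
+      {[
+        let file = Npy.Npz.open_in "data.npz" in
+        Npy.Npz.restore file "weights" dst;
+        Npy.Npz.close_in file
+      ]}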
*) - val restore : - ?suffix:string -> in_file -> string -> ('a, 'b, 'c) Bigarray.Genarray.t -> unit val entries : in_file -> string list val close_in : in_file -> unit @@ -53,43 +49,38 @@ module Npz : sig type out_file val open_out : string -> out_file - - val write : - ?suffix:string -> out_file -> string -> ('a, 'b, 'c) Bigarray.Genarray.t -> unit - + val write : ?suffix:string -> out_file -> string -> ('a, 'b, 'c) Bigarray.Genarray.t -> unit val close_out : out_file -> unit end (** Conversion functions from packed arrays to bigarrays *) -(** [to_bigarray layout kind packed_array] returns [Some a] with - [a] a [Bigarray.Genarray.t] if the layout and the kind of [packed_array] - were equal to the [layout] and [kind] arguments. Otherwise, [to_bigarray] - returns [None] -*) val to_bigarray : - 'c Bigarray.layout - -> ('a, 'b) Bigarray.kind - -> packed_array - -> ('a, 'b, 'c) Bigarray.Genarray.t option + 'c Bigarray.layout -> + ('a, 'b) Bigarray.kind -> + packed_array -> + ('a, 'b, 'c) Bigarray.Genarray.t option +(** [to_bigarray layout kind packed_array] returns [Some a] with [a] a [Bigarray.Genarray.t] if the + layout and the kind of [packed_array] were equal to the [layout] and [kind] arguments. + Otherwise, [to_bigarray] returns [None] *) -(** Same as {!to_bigarray} for [Bigarray.Array1.t] *) val to_bigarray1 : - 'c Bigarray.layout - -> ('a, 'b) Bigarray.kind - -> packed_array1 - -> ('a, 'b, 'c) Bigarray.Array1.t option + 'c Bigarray.layout -> + ('a, 'b) Bigarray.kind -> + packed_array1 -> + ('a, 'b, 'c) Bigarray.Array1.t option +(** Same as {!to_bigarray} for [Bigarray.Array1.t] *) -(** Same as {!to_bigarray} for [Bigarray.Array2.t] *) val to_bigarray2 : - 'c Bigarray.layout - -> ('a, 'b) Bigarray.kind - -> packed_array2 - -> ('a, 'b, 'c) Bigarray.Array2.t option + 'c Bigarray.layout -> + ('a, 'b) Bigarray.kind -> + packed_array2 -> + ('a, 'b, 'c) Bigarray.Array2.t option +(** Same as {!to_bigarray} for [Bigarray.Array2.t] *) -(** Same as {!to_bigarray} for [Bigarray.Array3.t] *) val to_bigarray3 : - 'c Bigarray.layout - -> ('a, 'b) Bigarray.kind - -> packed_array3 - -> ('a, 'b, 'c) Bigarray.Array3.t option + 'c Bigarray.layout -> + ('a, 'b) Bigarray.kind -> + packed_array3 -> + ('a, 'b, 'c) Bigarray.Array3.t option +(** Same as {!to_bigarray} for [Bigarray.Array3.t] *) diff --git a/test/dune b/test/dune index 87a71ec7..31bf34b9 100644 --- a/test/dune +++ b/test/dune @@ -4,7 +4,12 @@ (inline_tests (deps ocannl_config)) (libraries base dynlink ocannl) - (modules einsum_trivia hello_world_op micrograd_demo zero2hero_1of7 moons_demo_parallel) + (modules + einsum_trivia + hello_world_op + micrograd_demo + zero2hero_1of7 + moons_demo_parallel) (preprocess (pps ppx_jane ppx_expect ppx_inline_test ppx_ocannl)) (modes native)) diff --git a/test/einsum_trivia.ml b/test/einsum_trivia.ml index 6abe4d8c..fb7b4ff1 100644 --- a/test/einsum_trivia.ml +++ b/test/einsum_trivia.ml @@ -45,7 +45,9 @@ let%expect_test "einsum1 permute axes" = ││ │ 2.00e+0 1.40e+1 │ 5.00e+0 1.70e+1 │ 8.00e+0 2.00e+1 │ 1.10e+1 2.30e+1 ││ │└──────┴──────────────────┴──────────────────┴──────────────────┴──────────────────┘│ └────────────────────────────────────────────────────────────────────────────────────┘ |}]; - let hey2 = TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] () in + let hey2 = + TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] () + in let%op ho2 = hey2 ++ "ab|cd->ef => cf|ae->db" in Train.forward_and_forget backend ctx ho2; Tensor.print 
~force:true ~with_code:false ~with_grad:false `Default @@ hey2; @@ -293,11 +295,13 @@ let%expect_test "einsum1 sum out axes" = ││ │ 6.60e+1 7.00e+1 7.40e+1 ││ │└──────┴───────────────────────────┘│ └────────────────────────────────────┘ |}]; - let hey2 = TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] () in + let hey2 = + TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] () + in let%op ho2 = hey2 ++ "ab|cd->ef => c|a->d" in Train.forward_and_forget backend ctx ho2; - (* Axis 5 of hey2, i.e. d in the einsum spec, has the lowest variation (progresses by 1), that's why axis 1 - of ho2 appears nearly constant. *) + (* Axis 5 of hey2, i.e. d in the einsum spec, has the lowest variation (progresses by 1), that's + why axis 1 of ho2 appears nearly constant. *) Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ho2; [%expect {| @@ -702,7 +706,9 @@ let%expect_test "einsum1 broadcast or sum out prefix axes" = │└──────┴───────────────────────────┴───────────────────────────┴───────────────────────────┴───────────────────────────┘│ └────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ |}]; - let hey2 = TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] () in + let hey2 = + TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] () + in let%op ho3 = hey2 ++ "...b|...i->...o => ...i|...o->...b" in Train.forward_and_forget backend ctx ho3; Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ hey2; @@ -1497,8 +1503,12 @@ let%expect_test "einsum with a leftmost input axis preserved as output axis" = let backend = (module Backend : Train.Backend_type with type context = Backend.context) in let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in let ctx = Backend.init device in - let a = TDSL.range_of_shape ~label:[ "a" ] ~batch_dims:[ 3 ] ~input_dims:[ 4 ] ~output_dims:[ 2 ] () in - let b = TDSL.range_of_shape ~label:[ "b" ] ~batch_dims:[ 3 ] ~input_dims:[ 2; 3 ] ~output_dims:[ 4 ] () in + let a = + TDSL.range_of_shape ~label:[ "a" ] ~batch_dims:[ 3 ] ~input_dims:[ 4 ] ~output_dims:[ 2 ] () + in + let b = + TDSL.range_of_shape ~label:[ "b" ] ~batch_dims:[ 3 ] ~input_dims:[ 2; 3 ] ~output_dims:[ 4 ] () + in let%op c = a *+ "...|i->1; ...|j...->i => ...|ij" b in Train.forward_and_forget backend ctx c; Tensor.print ~with_code:false ~with_grad:false `Default @@ a; @@ -1575,7 +1585,8 @@ let%expect_test "einsum permuting two leftmost input axes as output axes" = let%op c = a *+ "i->1; ij...->0 => ...->ji" b in Train.forward_and_forget backend ctx c; Tensor.print ~with_code:false ~with_grad:false `Default @@ a; - [%expect {| + [%expect + {| ┌────────────────────────────┐ │[62]: r2x2_a shape 1:2->0:2 │ │┌──────┬──────────────────┐ │ @@ -1586,7 +1597,8 @@ let%expect_test "einsum permuting two leftmost input axes as output axes" = │└──────┴──────────────────┘ │ └────────────────────────────┘ |}]; Tensor.print ~with_code:false ~with_grad:false `Default @@ b; - [%expect {| + [%expect + {| ┌───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │[63]: r2x2x3x4_b shape 1:2,2:3,3:4->0:2 │ │┌──────┬────────────────────────────────────┬────────────────────────────────────┬────────────────────────────────────┐│ @@ -1601,7 +1613,8 @@ let%expect_test "einsum permuting two leftmost input axes as output axes" = 
│└──────┴────────────────────────────────────┴────────────────────────────────────┴────────────────────────────────────┘│ └───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘ |}]; Tensor.print ~with_code:false ~with_grad:false `Default @@ c; - [%expect {| + [%expect + {| ┌───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐ │[64]: ;=>_c shape 2:4->0:3,1:2 │ │┌──────┬────────────────────────────────────┬────────────────────────────────────┬────────────────────────────────────┐│ diff --git a/test/hello_world_op.ml b/test/hello_world_op.ml index 976b083e..af062c49 100644 --- a/test/hello_world_op.ml +++ b/test/hello_world_op.ml @@ -451,7 +451,9 @@ let%expect_test "Very big tensor" = let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in let ctx = Backend.init device in Rand.init 0; - let hey = TDSL.range_of_shape ~batch_dims:[ 6 ] ~input_dims:[ 7; 8; 9 ] ~output_dims:[ 10; 11 ] () in + let hey = + TDSL.range_of_shape ~batch_dims:[ 6 ] ~input_dims:[ 7; 8; 9 ] ~output_dims:[ 10; 11 ] () + in let%op hoo = (hey * (1 + 1)) - 10 in Train.forward_and_forget backend ctx hoo; Tensor.print ~with_code:false ~with_grad:false `Default hey; diff --git a/test/micrograd_demo.ml b/test/micrograd_demo.ml index aa170824..c5147bd7 100644 --- a/test/micrograd_demo.ml +++ b/test/micrograd_demo.ml @@ -102,7 +102,9 @@ let%expect_test "Micrograd half-moons example" = let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in let step_n, bindings = IDX.get_static_symbol bindings in (* FIXME: should also work with explicit batch shape. *) - let moons_flat = TDSL.init_const ~l:"moons_flat" (* ~b:[ n_batches; batch_size ] *) ~o:[ 2 ] moons_flat in + let moons_flat = + TDSL.init_const ~l:"moons_flat" (* ~b:[ n_batches; batch_size ] *) ~o:[ 2 ] moons_flat + in let moons_classes = Array.init (len * 2) ~f:(fun i -> if i % 2 = 0 then 1. else -1.) in (* FIXME: should also work with explicit batch shape. *) let moons_classes = @@ -119,7 +121,8 @@ let%expect_test "Micrograd half-moons example" = let log_losses = ref [] in let learning_rates = ref [] in let%op margin_loss = ?/(1 - (moons_class *. mlp moons_input)) in - (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update computation. *) + (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update + computation. *) let weight_decay = 0.0001 in let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in let update = Train.grad_update scalar_loss in @@ -137,8 +140,9 @@ let%expect_test "Micrograd half-moons example" = assert (Backend.to_host sgd_routine.context learning_rate.value); assert (Backend.to_host sgd_routine.context scalar_loss.value); Backend.await device; - (* let batch_ref = IDX.find_exn sgd_jitted.bindings batch_n in Stdio.printf "Epoch=%d, step=%d, - batch=%d, lr=%f, loss=%f\n%!" epoch !step_ref !batch_ref learning_rate.@[0] scalar_loss.@[0]; *) + (* let batch_ref = IDX.find_exn sgd_jitted.bindings batch_n in Stdio.printf "Epoch=%d, + step=%d, batch=%d, lr=%f, loss=%f\n%!" epoch !step_ref !batch_ref learning_rate.@[0] + scalar_loss.@[0]; *) learning_rates := ~-.(learning_rate.@[0]) :: !learning_rates; losses := scalar_loss.@[0] :: !losses; log_losses := Float.max (-10.) 
(Float.log scalar_loss.@[0]) :: !log_losses; @@ -151,11 +155,13 @@ let%expect_test "Micrograd half-moons example" = Train.set_on_host Volatile mlp_result.value; let result_routine = Backend.( - link sgd_routine.context @@ compile IDX.empty @@ Block_comment ("moons infer", mlp_result.forward)) + link sgd_routine.context @@ compile IDX.empty + @@ Block_comment ("moons infer", mlp_result.forward)) in let callback (x, y) = Tensor.set_values point [| x; y |]; - (* For the gccjit backend, point is only on host, not on device. For cuda, this will be needed. *) + (* For the gccjit backend, point is only on host, not on device. For cuda, this will be + needed. *) assert (Backend.from_host result_routine.context point.value); Train.run result_routine; assert (Backend.to_host result_routine.context mlp_result.value); diff --git a/test/moons_demo_parallel.ml b/test/moons_demo_parallel.ml index 08323ade..cfc2a487 100644 --- a/test/moons_demo_parallel.ml +++ b/test/moons_demo_parallel.ml @@ -38,23 +38,25 @@ let%expect_test "Half-moons data parallel" = let%op mlp x = "b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in (* let%op mlp x = "b" + ("w" * x) in *) let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in - (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update computation. *) + (* We don't need a regression loss formula thanks to weight_decay built into the sgd_update + computation. *) let weight_decay = 0.0002 in (* So that we can inspect them. *) let backend = Train.fresh_backend () in let per_batch_callback ~at_batch ~at_step ~learning_rate ~batch_loss ~epoch_loss = if (at_batch + 1) % 20 = 0 then - Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step learning_rate - batch_loss epoch_loss + Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step + learning_rate batch_loss epoch_loss in (* Tn.print_accessible_headers (); *) let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss = - Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate epoch_loss + Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" 
at_epoch at_step learning_rate + epoch_loss in let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates = - Train.example_train_loop ~seed ~batch_size ~max_num_devices:(batch_size / 2) ~init_lr ~data_len:len - ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay ~per_batch_callback - ~per_epoch_callback backend () + Train.example_train_loop ~seed ~batch_size ~max_num_devices:(batch_size / 2) ~init_lr + ~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn + ~weight_decay ~per_batch_callback ~per_epoch_callback backend () in [%expect {| |}]; let points = Tensor.value_2d_points ~xdim:0 ~ydim:1 inputs in @@ -96,7 +98,8 @@ let%expect_test "Half-moons data parallel" = [ Line_plot { - points = Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x)); + points = + Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x)); pixel = "-"; }; ] diff --git a/test/zero2hero_1of7.ml b/test/zero2hero_1of7.ml index 09e829ab..e98f2de0 100644 --- a/test/zero2hero_1of7.ml +++ b/test/zero2hero_1of7.ml @@ -70,7 +70,9 @@ let%expect_test "Graph drawing recompile" = let ys = Array.map xs ~f:(fun v -> (* This is inefficient because it compiles the argument update inside the loop. *) - let assign_x = Backend.(link f_bprop.context @@ compile IDX.empty ~name:"assign_x" [%cd x =: !.v]) in + let assign_x = + Backend.(link f_bprop.context @@ compile IDX.empty ~name:"assign_x" [%cd x =: !.v]) + in Train.sync_run (module Backend) assign_x x; Train.sync_run (module Backend) f_bprop f; Backend.await device; @@ -232,8 +234,8 @@ let%expect_test "Simple gradients hosted" = let%op e = "a" [ 2 ] *. "b" [ -3 ] in let%op d = e + "c" [ 10 ] in let%op l = d *. "f" [ -2 ] in - (* We need to either call `grad_update` before introducing `learning_rate`, or disable the rootness - check. *) + (* We need to either call `grad_update` before introducing `learning_rate`, or disable the + rootness check. *) let grad = Train.grad_update l in let%op learning_rate = 0.1 in Train.every_non_literal_on_host l; @@ -261,8 +263,8 @@ let%expect_test "Simple gradients hosted" = 2.00e+0 │ -3.00e+0 │ │ #76 grad_a│#78 grad_b│ │ 0.00e+0 │ 0.00e+0 │ │ |}]; - (* Do not update the params: all values and gradients will be at initial points, which are specified in the - tensor in the brackets. *) + (* Do not update the params: all values and gradients will be at initial points, which are + specified in the tensor in the brackets. *) Train.sync_run backend grad_routine l; Tensor.print_tree ~with_grad:true ~depth:9 l; [%expect @@ -283,8 +285,9 @@ let%expect_test "Simple gradients hosted" = 2.00e+0 │ -3.00e+0 │ │ #76 grad_a│#78 grad_b│ │ 6.00e+0 │ -4.00e+0 │ │ |}]; - (* Now we update the params, but we are not doing the forward and backward passes: only params values will - change, compared to the above. The update is in the opposite direction of the gradient. *) + (* Now we update the params, but we are not doing the forward and backward passes: only params + values will change, compared to the above. The update is in the opposite direction of the + gradient. 
*) Train.sync_run backend sgd_routine l; Tensor.print_tree ~with_grad:true ~depth:9 l; [%expect @@ -306,8 +309,8 @@ let%expect_test "Simple gradients hosted" = #76 grad_a│#78 grad_b│ │ 6.00e+0 │ -4.00e+0 │ │ |}]; - (* Now the params will remain as above, but both param gradients and the values and gradients of other nodes - will change thanks to the forward and backward passes. *) + (* Now the params will remain as above, but both param gradients and the values and gradients of + other nodes will change thanks to the forward and backward passes. *) Train.sync_run backend grad_routine l; Tensor.print_tree ~with_grad:true ~depth:9 l; [%expect @@ -338,10 +341,10 @@ let%expect_test "Simple gradients virtual" = let%op e = "a" [ 2 ] *. "b" [ -3 ] in let%op d = e + "c" [ 10 ] in let%op l = d *. "f" [ -2 ] in - (* We pretend this is for parallel updates, to force materializing gradients, because our SGD update is - compiled separately from our gradient update. Alternatively we could mark all - [Assignments.recurrent_nodes sgd] as materialized. Or, the best non-parallel option is to compile - grad_update and sgd_update together.*) + (* We pretend this is for parallel updates, to force materializing gradients, because our SGD + update is compiled separately from our gradient update. Alternatively we could mark all + [Assignments.recurrent_nodes sgd] as materialized. Or, the best non-parallel option is to + compile grad_update and sgd_update together.*) let grad = Train.grad_update ~setup_for_parallel:true l in let%op learning_rate = 0.1 in let sgd = Train.sgd_update ~learning_rate grad in @@ -386,8 +389,8 @@ let%expect_test "Simple gradients virtual" = 2.00e+0 │ -3.00e+0 │ │ #112 grad_a │#114 grad_b │ │ │ │ |}]; - (* Do not update the params: all values and gradients will be at initial points, which are specified in the - tensor in the brackets. *) + (* Do not update the params: all values and gradients will be at initial points, which are + specified in the tensor in the brackets. *) Train.sync_run backend grad_routine l; Tensor.print_tree ~with_grad:true ~depth:9 l; [%expect @@ -410,9 +413,9 @@ let%expect_test "Simple gradients virtual" = │ │ |}]; (* Only now compile the SGD update. *) let sgd_routine = Backend.(link grad_routine.context @@ compile IDX.empty sgd) in - (* Now we update the params, but are not doing the forward and backward passes: only params values will - change, compared to the above. Since virtual tensors are computed by-need, they will always be recomputed - using the latest parameter state. *) + (* Now we update the params, but are not doing the forward and backward passes: only params values + will change, compared to the above. Since virtual tensors are computed by-need, they will + always be recomputed using the latest parameter state. *) Train.sync_run backend sgd_routine l; Tensor.print_tree ~with_grad:true ~depth:9 l; [%expect @@ -433,8 +436,8 @@ let%expect_test "Simple gradients virtual" = 1.40e+0 │ -2.60e+0 │ │ #112 grad_a │#114 grad_b │ │ │ │ |}]; - (* Now the params will remain as above, but both param gradients and the values and gradients of other nodes - will change thanks to the forward and backward passes. *) + (* Now the params will remain as above, but both param gradients and the values and gradients of + other nodes will change thanks to the forward and backward passes. 
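+     (Because these gradients are virtual, any subsequent read recomputes them from the freshly
+     updated parameter values instead of reusing state from the previous run.)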
*) Train.sync_run backend grad_routine l; Tensor.print_tree ~with_grad:true ~depth:9 l; [%expect diff --git a/test_ppx/test_ppx_op_expected.ml b/test_ppx/test_ppx_op_expected.ml index f8ef0b7d..dc00aad5 100644 --- a/test_ppx/test_ppx_op_expected.ml +++ b/test_ppx/test_ppx_op_expected.ml @@ -1,44 +1,48 @@ open Base open Ocannl module TDSL = Operation.TDSL + let y0 = let open! TDSL.O in - let hey1 = TDSL.param ?values:None "hey1" in - ((+) ~label:["y0"]) - ((( *. ) ~label:[]) (TDSL.number (Float.of_int 2)) hey1) - (TDSL.number (Float.of_int 3)) + let hey1 = TDSL.param ?values:None "hey1" in + (( + ) ~label:[ "y0" ]) + ((( *. ) ~label:[]) (TDSL.number (Float.of_int 2)) hey1) + (TDSL.number (Float.of_int 3)) + let y1 = let open! TDSL.O in - let hey2 = TDSL.param ?values:None "hey2" in - fun x -> - ((+) ~label:["y1"]) - ((( * ) ~label:[]) hey2 (TDSL.number (Float.of_int 2))) x + let hey2 = TDSL.param ?values:None "hey2" in + fun x -> (( + ) ~label:[ "y1" ]) ((( * ) ~label:[]) hey2 (TDSL.number (Float.of_int 2))) x + let y2 = let open! TDSL.O in - let hey3 = TDSL.param ?values:None "hey3" in - fun x1 -> fun x2 -> ((+) ~label:["y2"]) ((( *. ) ~label:[]) x1 hey3) x2 + let hey3 = TDSL.param ?values:None "hey3" in + fun x1 x2 -> (( + ) ~label:[ "y2" ]) ((( *. ) ~label:[]) x1 hey3) x2 + let a = let open! TDSL.O in - TDSL.ndarray ~label:["a"] ~batch_dims:[] ~input_dims:[3] ~output_dims: - [2] - [|(Float.of_int 1);(Float.of_int 2);(Float.of_int 3);(Float.of_int 4);( - Float.of_int 5);(Float.of_int 6)|] + TDSL.ndarray ~label:[ "a" ] ~batch_dims:[] ~input_dims:[ 3 ] ~output_dims:[ 2 ] + [| + Float.of_int 1; Float.of_int 2; Float.of_int 3; Float.of_int 4; Float.of_int 5; Float.of_int 6; + |] + let b = let open! TDSL.O in - TDSL.ndarray ~label:["b"] ~batch_dims:[2] ~input_dims:[] ~output_dims: - [2] - [|(Float.of_int 7);(Float.of_int 8);(Float.of_int 9);(Float.of_int 10)|] + TDSL.ndarray ~label:[ "b" ] ~batch_dims:[ 2 ] ~input_dims:[] ~output_dims:[ 2 ] + [| Float.of_int 7; Float.of_int 8; Float.of_int 9; Float.of_int 10 |] + let y = let open! TDSL.O in - let hey4 = TDSL.param ?values:None "hey4" in - ((+) ~label:["y"]) - ((( * ) ~label:[]) hey4 (TDSL.number ~label:[] ~axis_label:"q" 2.0)) - (TDSL.number ~label:[] ~axis_label:"p" 1.0) + let hey4 = TDSL.param ?values:None "hey4" in + (( + ) ~label:[ "y" ]) + ((( * ) ~label:[]) hey4 (TDSL.number ~label:[] ~axis_label:"q" 2.0)) + (TDSL.number ~label:[] ~axis_label:"p" 1.0) + let z = let open! TDSL.O in - let hey5 = TDSL.param ?values:None "hey5" - and hey6 = TDSL.param ?values:None "hey6" in - ((+) ~label:["z"]) - ((( * ) ~label:[]) (TDSL.number ~label:[] ~axis_label:"q" 2.0) hey5) - ((( * ) ~label:[]) hey6 (TDSL.number ~label:[] ~axis_label:"p" 1.0)) + let hey5 = TDSL.param ?values:None "hey5" and hey6 = TDSL.param ?values:None "hey6" in + (( + ) ~label:[ "z" ]) + ((( * ) ~label:[]) (TDSL.number ~label:[] ~axis_label:"q" 2.0) hey5) + ((( * ) ~label:[]) hey6 (TDSL.number ~label:[] ~axis_label:"p" 1.0)) + let () = ignore (y0, y1, y2, a, b, y, z)
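For reference, the expected file above is the ppx-expanded, formatter-normalized output. The originating test_ppx_op.ml is not part of this diff, so the following is an inferred sketch: judging from the expansions, the first two declarations presumably come from %op source along the lines of

  let%op y0 = (2 *. "hey1") + 3
  let%op y1 x = ("hey2" * 2) + x

where string literals introduce parameters (here hey1 and hey2) and number literals become TDSL.number terms labeled with the bound variable's name.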