diff --git a/arrayjit/lib/backend_types.ml b/arrayjit/lib/backend_types.ml index 8734d949..19e87bbf 100644 --- a/arrayjit/lib/backend_types.ml +++ b/arrayjit/lib/backend_types.ml @@ -7,6 +7,16 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime [%%global_debug_log_level 9] [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"] +module No_device_types = struct + type ctx_array = Ndarray.t [@@deriving sexp_of] + + type ctx_arrays = { used_memory : Utils.atomic_int; ctx_arrays : ctx_array Map.M(Tnode).t } + [@@deriving sexp_of] + + let empty_ctx_arrays = { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tnode) } + let get_array arrays = Map.find arrays.ctx_arrays +end + module Types = struct type 'context routine = { context : 'context; @@ -168,6 +178,9 @@ module type Backend = sig val init : stream -> context val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr + val get_used_memory : device -> int + (** Returns (an upper bound of) the memory used for arrays, in bytes. *) + val await : stream -> unit (** Blocks till the stream becomes idle, i.e. synchronizes the stream. *) @@ -198,11 +211,12 @@ module type Lowered_no_device_backend = sig type procedure [@@deriving sexp_of] type ctx_array [@@deriving sexp_of] type buffer_ptr [@@deriving sexp_of] - type ctx_arrays = ctx_array Map.M(Tnode).t [@@deriving sexp_of] + type ctx_arrays [@@deriving sexp_of] val buffer_ptr : ctx_array -> buffer_ptr - val ctx_arrays : context -> ctx_arrays val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr + val ctx_arrays : context -> ctx_arrays + val get_array : ctx_arrays -> Tnode.t -> ctx_array option val is_in_context : Low_level.traced_array -> bool (** If true, the node is required to be in the contexts linked with code that uses it. @@ -246,6 +260,7 @@ module type Lowered_backend = sig type code [@@deriving sexp_of] type code_batch [@@deriving sexp_of] type ctx_array [@@deriving sexp_of] + type ctx_arrays [@@deriving sexp_of] type event val sync : event -> unit @@ -268,7 +283,8 @@ module type Lowered_backend = sig code_batch val is_in_context : Low_level.traced_array -> bool - val ctx_arrays : context -> ctx_array Map.M(Tnode).t + val ctx_arrays : context -> ctx_arrays + val get_array : ctx_arrays -> Tnode.t -> ctx_array option val link : context -> code -> context * Indexing.lowered_bindings * Task.t val link_batch : @@ -298,7 +314,7 @@ module type Lowered_backend = sig val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr - val get_used_memory : unit -> int + val get_used_memory : device -> int (** Returns (an upper bound of) the memory used for arrays, in bytes. *) val init : stream -> context diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index 32ff98dd..d9a52734 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -76,7 +76,7 @@ struct let alloc_buffer ?old_buffer ~size_in_bytes _stream = Backend.alloc_buffer ?old_buffer ~size_in_bytes () - let get_used_memory = Backend.get_used_memory + let get_used_memory _device = Backend.get_used_memory () type device = stream [@@deriving sexp_of] type code = Backend.code [@@deriving sexp_of] @@ -370,9 +370,10 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types. let alloc_buffer ?old_buffer ~size_in_bytes _stream = Backend.alloc_buffer ?old_buffer ~size_in_bytes () - let get_used_memory = Backend.get_used_memory - type device = CPU [@@deriving sexp_of] + + let get_used_memory CPU = Backend.get_used_memory () + type code = Backend.code [@@deriving sexp_of] type code_batch = Backend.code_batch [@@deriving sexp_of] @@ -534,14 +535,14 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l = ) else (None, None)) -let verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context traced_stores - = +let verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context + traced_stores = let olds = ctx_arrays prior_context in Set.iter from_prior_context ~f:(fun tn -> let node = Array.find_map traced_stores ~f:(fun store -> Hashtbl.find store tn) in if Option.value_map node ~default:false ~f:(fun node -> - is_in_context node && not (Map.mem olds tn)) + is_in_context node && not (Option.is_some @@ get_array olds tn)) then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn)) let from_prior_context_batch comps = @@ -646,7 +647,8 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back let link ~merge_buffer (prior_context : context) (code : code) = let verify from_prior_context = - verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context + verify_prior_context ~get_array:Backend.get_array ~ctx_arrays ~is_in_context ~prior_context + ~from_prior_context [| get_traced_store code |] in let context, bindings, schedule, name = @@ -673,7 +675,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back let link_batch ~merge_buffer (prior_context : context) (code_batch : code_batch) = let verify from_prior_context = - verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context + verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context @@ get_traced_stores code_batch in let _opt_ctx_arrays, procs = @@ -703,7 +705,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back | None -> (context, None)) let get_buffer tn context = - Map.find (Backend.ctx_arrays context) tn |> Option.map ~f:Backend.buffer_ptr + Backend.(ctx_arrays context |> Fn.flip get_array tn |> Option.map ~f:buffer_ptr) let get_used_memory = Ndarray.get_used_memory end @@ -776,7 +778,7 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types. } let link context (code : code) = - verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context + verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context:context ~from_prior_context:code.from_prior_context [| code.traced_store |]; let context, bindings, schedule = link context code.code in let schedule = @@ -788,7 +790,7 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types. { context; schedule; bindings; name } let link_batch context code_batch = - verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context + verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context:context ~from_prior_context:code_batch.from_prior_context code_batch.traced_stores; let context, bindings, schedules = link_batch context code_batch.code_batch in ( context, diff --git a/arrayjit/lib/c_syntax.ml b/arrayjit/lib/c_syntax.ml index 91957010..fb059a2f 100644 --- a/arrayjit/lib/c_syntax.ml +++ b/arrayjit/lib/c_syntax.ml @@ -13,8 +13,10 @@ module C_syntax (B : sig val for_lowereds : Low_level.optimized array type ctx_array + type ctx_arrays - val opt_ctx_arrays : ctx_array Map.M(Tnode).t option + val opt_ctx_arrays : ctx_arrays option + val get_array : ctx_arrays -> Tn.t -> ctx_array option val hardcoded_context_ptr : (ctx_array -> string) option val is_in_context : Low_level.traced_array -> bool val host_ptrs_for_readonly : bool @@ -86,7 +88,7 @@ struct | true, Some get_ptr, Some ctx_arrays, _, _, _ -> let ident = get_ident node.tn in let ctx_array = - Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays node.tn + Option.value_exn ~here:[%here] ~message:ident @@ B.get_array ctx_arrays node.tn in fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array; Hash_set.add is_global node.tn diff --git a/arrayjit/lib/cc_backend.ml b/arrayjit/lib/cc_backend.ml index 442611c4..4d80abe2 100644 --- a/arrayjit/lib/cc_backend.ml +++ b/arrayjit/lib/cc_backend.ml @@ -7,6 +7,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime [%%global_debug_log_level 9] [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"] +include Backend_types.No_device_types open Backend_types.Types let name = "cc" @@ -18,8 +19,6 @@ let compiler_command () = Utils.get_global_arg ~default:"cc" ~arg_name:"cc_backe module Tn = Tnode -type ctx_array = Ndarray.t [@@deriving sexp_of] -type ctx_arrays = ctx_array Map.M(Tn).t [@@deriving sexp_of] type context = { label : string; arrays : ctx_arrays } [@@deriving sexp_of] let ctx_arrays context = context.arrays @@ -36,7 +35,7 @@ let alloc_buffer ?old_buffer ~size_in_bytes () = | None -> assert false let to_buffer tn ~dst ~src = - let src = Map.find_exn src.arrays tn in + let src = Map.find_exn src.arrays.ctx_arrays tn in Ndarray.map2 { f2 = Ndarray.A.blit } src dst let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst @@ -50,7 +49,9 @@ let is_initialized, initialize = let finalize _ctx = () let init ~label = - let result = { label; arrays = Map.empty (module Tn) } in + let result = + { label; arrays = { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tn) } } + in Stdlib.Gc.finalise finalize result; result @@ -61,7 +62,7 @@ type procedure = { name : string; result : library; params : (string * param_source) list; - opt_ctx_arrays : Ndarray.t Map.M(Tn).t option; + opt_ctx_arrays : ctx_arrays option; } [@@deriving sexp_of] @@ -105,13 +106,14 @@ let c_compile_and_load ~f_name = module C_syntax_config (Input : sig val for_lowereds : Low_level.optimized array - val opt_ctx_arrays : (Tn.t, buffer_ptr, Tn.comparator_witness) Base.Map.t option + val opt_ctx_arrays : ctx_arrays option end) = struct - let for_lowereds = Input.for_lowereds - type nonrec ctx_array = ctx_array + type nonrec ctx_arrays = ctx_arrays + let get_array = get_array + let for_lowereds = Input.for_lowereds let opt_ctx_arrays = Input.opt_ctx_arrays let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string let is_in_context = is_in_context @@ -133,7 +135,7 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_ let opt_ctx_arrays = Option.map opt_ctx_arrays ~f:(fun ctx_arrays -> Hashtbl.fold lowered.traced_store ~init:ctx_arrays ~f:(fun ~key:tn ~data:node ctx_arrays -> - match Map.find ctx_arrays tn with + match Map.find ctx_arrays.ctx_arrays tn with | None -> if is_in_context node then let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in @@ -141,7 +143,7 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_ Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) ~dims:(Lazy.force tn.dims) @@ Constant_fill { values = [| 0. |]; strict = false } in - Map.add_exn ctx_arrays ~key:tn ~data + { ctx_arrays with ctx_arrays = Map.add_exn ctx_arrays.ctx_arrays ~key:tn ~data } else ctx_arrays | Some _ -> ctx_arrays)) in @@ -162,22 +164,25 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings (lowereds : Low_level.optimized option array) = let for_lowereds = Array.filter_map ~f:Fn.id lowereds in let opt_ctx_arrays = - Option.map opt_ctx_arrays ~f:(fun ctx_arrays -> - Array.fold for_lowereds ~init:ctx_arrays ~f:(fun ctx_arrays lowered -> - Hashtbl.fold lowered.traced_store ~init:ctx_arrays - ~f:(fun ~key:tn ~data:node ctx_arrays -> - match Map.find ctx_arrays tn with - | None -> - if is_in_context node then - let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in - let data = - Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) - ~dims:(Lazy.force tn.dims) - @@ Constant_fill { values = [| 0. |]; strict = false } - in - Map.add_exn ctx_arrays ~key:tn ~data - else ctx_arrays - | Some _ -> ctx_arrays))) + Option.map opt_ctx_arrays ~f:(fun arrays -> + let ctx_arrays = + Array.fold for_lowereds ~init:arrays.ctx_arrays ~f:(fun ctx_arrays lowered -> + Hashtbl.fold lowered.traced_store ~init:ctx_arrays + ~f:(fun ~key:tn ~data:node ctx_arrays -> + match Map.find ctx_arrays tn with + | None -> + if is_in_context node then + let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in + let data = + Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) + ~dims:(Lazy.force tn.dims) + @@ Constant_fill { values = [| 0. |]; strict = false } + in + Map.add_exn ctx_arrays ~key:tn ~data + else ctx_arrays + | Some _ -> ctx_arrays)) + in + { arrays with ctx_arrays }) in let module Syntax = C_syntax.C_syntax (C_syntax_config (struct let for_lowereds = for_lowereds @@ -186,7 +191,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings (* FIXME: do we really want all of them, or only the used ones? *) let idx_params = Indexing.bound_symbols bindings in let global_ctx_arrays = - ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> Map.empty (module Tn)) + ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> empty_ctx_arrays) in let base_name = String.( @@ -206,7 +211,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings let opt_ctx_arrays = Option.map opt_ctx_arrays ~f:(fun _ -> !global_ctx_arrays) in ( opt_ctx_arrays, Array.mapi params ~f:(fun i params -> - Option.map names.(i) ~f:(fun name -> + Option.map names.(i) ~f:(fun name -> { result; params = Option.value_exn ~here:[%here] params; @@ -219,7 +224,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro context * _ * _ * string = let label : string = prior_context.label in let name : string = code.name in - let arrays : Ndarray.t Base.Map.M(Tn).t = + let arrays = match code with | { opt_ctx_arrays = Some arrays; _ } -> arrays | { params; _ } -> @@ -232,7 +237,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) ~dims:(Lazy.force tn.dims) @@ Constant_fill { values = [| 0. |]; strict = false } in - Map.update ctx_arrays tn ~f + { ctx_arrays with ctx_arrays = Map.update ctx_arrays.ctx_arrays tn ~f } | _ -> ctx_arrays) in let context = { label; arrays } in @@ -258,7 +263,9 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro let get_ptr (buffer, _) = Ndarray.get_voidptr_not_managed buffer in Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs)) | bs, Param_ptr tn :: ps -> - let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in + let nd = + match get_array (ctx_arrays context) tn with Some nd -> nd | None -> assert false + in let c_ptr = Ndarray.get_voidptr_not_managed nd in Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs)) in diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index 748442ba..ef7d2285 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -19,6 +19,10 @@ type event = Cu.Delimited_event.t type ctx_array = { ptr : buffer_ptr; mutable tracking : (event[@sexp.opaque]) option } [@@deriving sexp_of] +type ctx_arrays = ctx_array Map.M(Tnode).t [@@deriving sexp_of] + +let get_array = Map.find + type device = { dev : (Cu.Device.t[@sexp.opaque]); ordinal : int; @@ -51,7 +55,7 @@ and context = { run_module : (Cu.Module.t[@sexp.opaque]) option; (** Code jitted for this context, typically independent of the parent and child contexts, but shared by batch linked contexts. *) - ctx_arrays : ctx_array Map.M(Tn).t; + ctx_arrays : ctx_arrays; (** This map contains arrays used in this context or an ancestor context (they might be unique but might also be cross-stream shared. *) finalized : Utils.atomic_bool; @@ -103,16 +107,17 @@ let alloc_buffer ?old_buffer ~size_in_bytes stream = set_ctx stream.device.primary_context; Cu.Deviceptr.mem_alloc ~size_in_bytes -let get_used_memory () = +let get_used_memory dev = + set_ctx dev.primary_context; let free, total = Cudajit.Device.get_free_and_total_mem () in total - free -let opt_alloc_merge_buffer ~size_in_bytes phys_dev = - if phys_dev.copy_merge_buffer_capacity < size_in_bytes then ( - set_ctx phys_dev.primary_context; - Cu.Deviceptr.mem_free phys_dev.copy_merge_buffer; - phys_dev.copy_merge_buffer <- Cu.Deviceptr.mem_alloc ~size_in_bytes; - phys_dev.copy_merge_buffer_capacity <- size_in_bytes) +let opt_alloc_merge_buffer ~size_in_bytes dev = + if dev.copy_merge_buffer_capacity < size_in_bytes then ( + set_ctx dev.primary_context; + Cu.Deviceptr.mem_free dev.copy_merge_buffer; + dev.copy_merge_buffer <- Cu.Deviceptr.mem_alloc ~size_in_bytes; + dev.copy_merge_buffer_capacity <- size_in_bytes) let cleanup_device device = Cu.Context.set_current device.primary_context; @@ -367,8 +372,10 @@ end) = struct let for_lowereds = Input.for_lowereds - type nonrec ctx_array = buffer_ptr + type nonrec ctx_array = buffer_ptr [@@deriving sexp_of] + type nonrec ctx_arrays = ctx_array Map.M(Tn).t [@@deriving sexp_of] + let get_array = Map.find let opt_ctx_arrays = None let hardcoded_context_ptr = None let is_in_context = is_in_context diff --git a/arrayjit/lib/cuda_backend.missing.ml b/arrayjit/lib/cuda_backend.missing.ml index 243c912f..a74f2731 100644 --- a/arrayjit/lib/cuda_backend.missing.ml +++ b/arrayjit/lib/cuda_backend.missing.ml @@ -5,13 +5,14 @@ type context = Unimplemented_ctx [@@deriving sexp_of] type code = Indexing.unit_bindings [@@deriving sexp_of] type code_batch = Indexing.unit_bindings array [@@deriving sexp_of] type ctx_array = | [@@deriving sexp_of] +type ctx_arrays = ctx_array Map.M(Tnode).t [@@deriving sexp_of] type event = unit let sync () = () let is_done () = true let work_for _ctx _tn = Some () let will_wait_for _ctx () = () -let initialize (_config : Backend_utils.Types.config) = () +let initialize (_config : Backend_types.Types.config) = () let is_initialized () = true let finalize _context = () let compile ~name:_ bindings _optimized = bindings diff --git a/arrayjit/lib/gcc_backend.gccjit.ml b/arrayjit/lib/gcc_backend.gccjit.ml index 6331646b..7bbd968f 100644 --- a/arrayjit/lib/gcc_backend.gccjit.ml +++ b/arrayjit/lib/gcc_backend.gccjit.ml @@ -7,6 +7,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime [%%global_debug_log_level 9] [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"] +include Backend_types.No_device_types open Backend_types.Types let name = "gccjit" @@ -24,9 +25,7 @@ type mem_properties = let root_ctx = ref None module Tn = Tnode - -type ctx_array = Ndarray.t [@@deriving sexp_of] -type ctx_arrays = ctx_array Map.M(Tn).t [@@deriving sexp_of] +include Backend_types.No_device_types type buffer_ptr = ctx_array [@@deriving sexp_of] (** Alternative approach: @@ -50,7 +49,7 @@ type context = { let ctx_arrays context = context.arrays let to_buffer tn ~dst ~src = - let src = Map.find_exn src.arrays tn in + let src = Option.value_exn ~here:[%here] @@ get_array (ctx_arrays src) tn in Ndarray.map2 { f2 = Ndarray.A.blit } src dst let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst @@ -76,7 +75,7 @@ let finalize ctx = Option.iter ctx.result ~f:Result.release let init ~label = - let result = { label; result = None; arrays = Map.empty (module Tn) } in + let result = { label; result = None; arrays = empty_ctx_arrays } in Stdlib.Gc.finalise finalize result; result @@ -114,7 +113,7 @@ type procedure = { bindings : Indexing.unit_bindings; name : string; result : (Gccjit.result[@sexp.opaque]); - opt_ctx_arrays : Ndarray.t Map.M(Tn).t option; + opt_ctx_arrays : ctx_arrays option; params : param_source list; } [@@deriving sexp_of] @@ -649,7 +648,7 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident let ctx_nodes : ctx_nodes = match opt_ctx_arrays with | None -> Param_ptrs params - | Some ctx_arrays -> Ctx_arrays (ref ctx_arrays) + | Some ctx_arrays -> Ctx_arrays (ref ctx_arrays.ctx_arrays) in let initializations = ref [] in let nodes = Hashtbl.create (module Tn) in @@ -708,7 +707,9 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident Block.jump init_block main_block; Block.return_void after_proc; let opt_ctx_arrays = - match ctx_nodes with Param_ptrs _ -> None | Ctx_arrays { contents } -> Some contents + match (opt_ctx_arrays, ctx_nodes) with + | None, _ | _, Param_ptrs _ -> None + | Some arrays, Ctx_arrays { contents } -> Some { arrays with ctx_arrays = contents } in (ctx_info, opt_ctx_arrays, params) @@ -729,14 +730,7 @@ let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optim Context.dump_to_file ctx ~update_locs:true f_name); let result = Context.compile ctx in Context.release ctx; - { - info; - result; - bindings; - name; - opt_ctx_arrays; - params = List.map ~f:snd params; - } + { info; result; bindings; name; opt_ctx_arrays; params = List.map ~f:snd params } let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bindings (lowereds : Low_level.optimized option array) = @@ -774,14 +768,7 @@ let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bind ( opt_ctx_arrays, Array.mapi funcs ~f:(fun i -> Option.map2 names.(i) ~f:(fun name (info, opt_ctx_arrays, params) -> - { - info; - result; - bindings; - name; - opt_ctx_arrays; - params = List.map ~f:snd params; - })) ) + { info; result; bindings; name; opt_ctx_arrays; params = List.map ~f:snd params })) ) let alloc_buffer ?old_buffer ~size_in_bytes () = (* FIXME: NOT IMPLEMENTED YET but should not be needed for the streaming case. *) @@ -794,11 +781,11 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro context * _ * _ * string = let label : string = prior_context.label in let name : string = code.name in - let arrays : Ndarray.t Base.Map.M(Tn).t = + let ctx_arrays : Ndarray.t Base.Map.M(Tn).t = match code with - | { opt_ctx_arrays = Some arrays; _ } -> arrays + | { opt_ctx_arrays = Some arrays; _ } -> arrays.ctx_arrays | { params; _ } -> - List.fold params ~init:prior_context.arrays ~f:(fun ctx_arrays -> function + List.fold params ~init:(ctx_arrays prior_context).ctx_arrays ~f:(fun arrays -> function | Param_ptr tn -> let f = function | Some arr -> arr @@ -807,10 +794,12 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) ~dims:(Lazy.force tn.dims) @@ Constant_fill { values = [| 0. |]; strict = false } in - Map.update ctx_arrays tn ~f - | _ -> ctx_arrays) + Map.update arrays tn ~f + | _ -> arrays) + in + let context = + { label; arrays = { prior_context.arrays with ctx_arrays }; result = Some code.result } in - let context = { label; arrays; result = Some code.result } in let log_file_name = Utils.diagn_log_file [%string "debug-%{label}-%{code.name}.log"] in let run_variadic = [%log_level @@ -831,7 +820,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro | bs, Log_file_name :: ps -> Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs)) | bs, Param_ptr tn :: ps -> - let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in + let nd = match Map.find ctx_arrays tn with Some nd -> nd | None -> assert false in let c_ptr = Ndarray.get_voidptr_not_managed nd in Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs)) | bs, Merge_buffer :: ps -> diff --git a/arrayjit/lib/utils.ml b/arrayjit/lib/utils.ml index 45980757..393a8a45 100644 --- a/arrayjit/lib/utils.ml +++ b/arrayjit/lib/utils.ml @@ -391,10 +391,15 @@ let parallel_merge merge (num_devices : int) = in loop (num_devices - 1) +let ( !@ ) = Atomic.get + type atomic_bool = bool Atomic.t let sexp_of_atomic_bool flag = sexp_of_bool @@ Atomic.get flag -let ( !@ ) = Atomic.get + +type atomic_int = int Atomic.t + +let sexp_of_atomic_int flag = sexp_of_int @@ Atomic.get flag let sexp_append ~elem = function | Sexp.List l -> Sexp.List (elem :: l)