Skip to content

Commit

Permalink
get_used_memory depends on the device
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 15, 2024
1 parent 1953872 commit a09e2d7
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 90 deletions.
24 changes: 20 additions & 4 deletions arrayjit/lib/backend_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

(** Type definitions and helpers for backends whose context arrays are plain {!Ndarray.t}
    values. *)
module No_device_types = struct
  (** A single array held by a context. *)
  type ctx_array = Ndarray.t [@@deriving sexp_of]

  (** The arrays of a context, keyed by tensor node, together with an atomic integer
      [used_memory] (presumably a running total of the memory taken by these arrays -- TODO
      confirm against the code that updates it). *)
  type ctx_arrays = {
    used_memory : Utils.atomic_int;
    ctx_arrays : ctx_array Map.M(Tnode).t;
  }
  [@@deriving sexp_of]

  (** A [ctx_arrays] holding no arrays, with [used_memory] starting at [0].

      NOTE(review): the atomic cell is allocated once, when this module is initialized, so
      every record obtained from this value with [{ ... with ctx_arrays = ... }] shares the
      same [used_memory] cell -- confirm that this sharing is intended. *)
  let empty_ctx_arrays =
    let used_memory = Atomic.make 0 in
    let ctx_arrays = Map.empty (module Tnode) in
    { used_memory; ctx_arrays }

  (** [get_array arrays tn] is the array stored for [tn] in [arrays], or [None] when [tn] has
      no entry. *)
  let get_array { ctx_arrays; _ } tn = Map.find ctx_arrays tn
end

module Types = struct
type 'context routine = {
context : 'context;
Expand Down Expand Up @@ -168,6 +178,9 @@ module type Backend = sig
val init : stream -> context
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr

val get_used_memory : device -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

val await : stream -> unit
(** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)

Expand Down Expand Up @@ -198,11 +211,12 @@ module type Lowered_no_device_backend = sig
type procedure [@@deriving sexp_of]
type ctx_array [@@deriving sexp_of]
type buffer_ptr [@@deriving sexp_of]
type ctx_arrays = ctx_array Map.M(Tnode).t [@@deriving sexp_of]
type ctx_arrays [@@deriving sexp_of]

val buffer_ptr : ctx_array -> buffer_ptr
val ctx_arrays : context -> ctx_arrays
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
val ctx_arrays : context -> ctx_arrays
val get_array : ctx_arrays -> Tnode.t -> ctx_array option

val is_in_context : Low_level.traced_array -> bool
(** If true, the node is required to be in the contexts linked with code that uses it.
Expand Down Expand Up @@ -246,6 +260,7 @@ module type Lowered_backend = sig
type code [@@deriving sexp_of]
type code_batch [@@deriving sexp_of]
type ctx_array [@@deriving sexp_of]
type ctx_arrays [@@deriving sexp_of]
type event

val sync : event -> unit
Expand All @@ -268,7 +283,8 @@ module type Lowered_backend = sig
code_batch

val is_in_context : Low_level.traced_array -> bool
val ctx_arrays : context -> ctx_array Map.M(Tnode).t
val ctx_arrays : context -> ctx_arrays
val get_array : ctx_arrays -> Tnode.t -> ctx_array option
val link : context -> code -> context * Indexing.lowered_bindings * Task.t

val link_batch :
Expand Down Expand Up @@ -298,7 +314,7 @@ module type Lowered_backend = sig

val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr

val get_used_memory : unit -> int
val get_used_memory : device -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)

val init : stream -> context
Expand Down
24 changes: 13 additions & 11 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()

let get_used_memory = Backend.get_used_memory
let get_used_memory _device = Backend.get_used_memory ()

type device = stream [@@deriving sexp_of]
type code = Backend.code [@@deriving sexp_of]
Expand Down Expand Up @@ -370,9 +370,10 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()

let get_used_memory = Backend.get_used_memory

type device = CPU [@@deriving sexp_of]

let get_used_memory CPU = Backend.get_used_memory ()

type code = Backend.code [@@deriving sexp_of]
type code_batch = Backend.code_batch [@@deriving sexp_of]

Expand Down Expand Up @@ -534,14 +535,14 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
)
else (None, None))

let verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context traced_stores
=
let verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
traced_stores =
let olds = ctx_arrays prior_context in
Set.iter from_prior_context ~f:(fun tn ->
let node = Array.find_map traced_stores ~f:(fun store -> Hashtbl.find store tn) in
if
Option.value_map node ~default:false ~f:(fun node ->
is_in_context node && not (Map.mem olds tn))
is_in_context node && not (Option.is_some @@ get_array olds tn))
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))

let from_prior_context_batch comps =
Expand Down Expand Up @@ -646,7 +647,8 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back

let link ~merge_buffer (prior_context : context) (code : code) =
let verify from_prior_context =
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
verify_prior_context ~get_array:Backend.get_array ~ctx_arrays ~is_in_context ~prior_context
~from_prior_context
[| get_traced_store code |]
in
let context, bindings, schedule, name =
Expand All @@ -673,7 +675,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back

let link_batch ~merge_buffer (prior_context : context) (code_batch : code_batch) =
let verify from_prior_context =
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
@@ get_traced_stores code_batch
in
let _opt_ctx_arrays, procs =
Expand Down Expand Up @@ -703,7 +705,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
| None -> (context, None))

let get_buffer tn context =
Map.find (Backend.ctx_arrays context) tn |> Option.map ~f:Backend.buffer_ptr
Backend.(ctx_arrays context |> Fn.flip get_array tn |> Option.map ~f:buffer_ptr)

let get_used_memory = Ndarray.get_used_memory
end
Expand Down Expand Up @@ -776,7 +778,7 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
}

let link context (code : code) =
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context
verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context:context
~from_prior_context:code.from_prior_context [| code.traced_store |];
let context, bindings, schedule = link context code.code in
let schedule =
Expand All @@ -788,7 +790,7 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
{ context; schedule; bindings; name }

let link_batch context code_batch =
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context
verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context:context
~from_prior_context:code_batch.from_prior_context code_batch.traced_stores;
let context, bindings, schedules = link_batch context code_batch.code_batch in
( context,
Expand Down
6 changes: 4 additions & 2 deletions arrayjit/lib/c_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ module C_syntax (B : sig
val for_lowereds : Low_level.optimized array

type ctx_array
type ctx_arrays

val opt_ctx_arrays : ctx_array Map.M(Tnode).t option
val opt_ctx_arrays : ctx_arrays option
val get_array : ctx_arrays -> Tn.t -> ctx_array option
val hardcoded_context_ptr : (ctx_array -> string) option
val is_in_context : Low_level.traced_array -> bool
val host_ptrs_for_readonly : bool
Expand Down Expand Up @@ -86,7 +88,7 @@ struct
| true, Some get_ptr, Some ctx_arrays, _, _, _ ->
let ident = get_ident node.tn in
let ctx_array =
Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays node.tn
Option.value_exn ~here:[%here] ~message:ident @@ B.get_array ctx_arrays node.tn
in
fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array;
Hash_set.add is_global node.tn
Expand Down
69 changes: 38 additions & 31 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

include Backend_types.No_device_types
open Backend_types.Types

let name = "cc"
Expand All @@ -18,8 +19,6 @@ let compiler_command () = Utils.get_global_arg ~default:"cc" ~arg_name:"cc_backe

module Tn = Tnode

type ctx_array = Ndarray.t [@@deriving sexp_of]
type ctx_arrays = ctx_array Map.M(Tn).t [@@deriving sexp_of]
type context = { label : string; arrays : ctx_arrays } [@@deriving sexp_of]

let ctx_arrays context = context.arrays
Expand All @@ -36,7 +35,7 @@ let alloc_buffer ?old_buffer ~size_in_bytes () =
| None -> assert false

let to_buffer tn ~dst ~src =
let src = Map.find_exn src.arrays tn in
let src = Map.find_exn src.arrays.ctx_arrays tn in
Ndarray.map2 { f2 = Ndarray.A.blit } src dst

let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
Expand All @@ -50,7 +49,9 @@ let is_initialized, initialize =
let finalize _ctx = ()

let init ~label =
let result = { label; arrays = Map.empty (module Tn) } in
let result =
{ label; arrays = { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tn) } }
in
Stdlib.Gc.finalise finalize result;
result

Expand All @@ -61,7 +62,7 @@ type procedure = {
name : string;
result : library;
params : (string * param_source) list;
opt_ctx_arrays : Ndarray.t Map.M(Tn).t option;
opt_ctx_arrays : ctx_arrays option;
}
[@@deriving sexp_of]

Expand Down Expand Up @@ -105,13 +106,14 @@ let c_compile_and_load ~f_name =

module C_syntax_config (Input : sig
val for_lowereds : Low_level.optimized array
val opt_ctx_arrays : (Tn.t, buffer_ptr, Tn.comparator_witness) Base.Map.t option
val opt_ctx_arrays : ctx_arrays option
end) =
struct
let for_lowereds = Input.for_lowereds

type nonrec ctx_array = ctx_array
type nonrec ctx_arrays = ctx_arrays

let get_array = get_array
let for_lowereds = Input.for_lowereds
let opt_ctx_arrays = Input.opt_ctx_arrays
let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string
let is_in_context = is_in_context
Expand All @@ -133,15 +135,15 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
let opt_ctx_arrays =
Option.map opt_ctx_arrays ~f:(fun ctx_arrays ->
Hashtbl.fold lowered.traced_store ~init:ctx_arrays ~f:(fun ~key:tn ~data:node ctx_arrays ->
match Map.find ctx_arrays tn with
match Map.find ctx_arrays.ctx_arrays tn with
| None ->
if is_in_context node then
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
let data =
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) ~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.add_exn ctx_arrays ~key:tn ~data
{ ctx_arrays with ctx_arrays = Map.add_exn ctx_arrays.ctx_arrays ~key:tn ~data }
else ctx_arrays
| Some _ -> ctx_arrays))
in
Expand All @@ -162,22 +164,25 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
(lowereds : Low_level.optimized option array) =
let for_lowereds = Array.filter_map ~f:Fn.id lowereds in
let opt_ctx_arrays =
Option.map opt_ctx_arrays ~f:(fun ctx_arrays ->
Array.fold for_lowereds ~init:ctx_arrays ~f:(fun ctx_arrays lowered ->
Hashtbl.fold lowered.traced_store ~init:ctx_arrays
~f:(fun ~key:tn ~data:node ctx_arrays ->
match Map.find ctx_arrays tn with
| None ->
if is_in_context node then
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
let data =
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec)
~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.add_exn ctx_arrays ~key:tn ~data
else ctx_arrays
| Some _ -> ctx_arrays)))
Option.map opt_ctx_arrays ~f:(fun arrays ->
let ctx_arrays =
Array.fold for_lowereds ~init:arrays.ctx_arrays ~f:(fun ctx_arrays lowered ->
Hashtbl.fold lowered.traced_store ~init:ctx_arrays
~f:(fun ~key:tn ~data:node ctx_arrays ->
match Map.find ctx_arrays tn with
| None ->
if is_in_context node then
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
let data =
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec)
~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.add_exn ctx_arrays ~key:tn ~data
else ctx_arrays
| Some _ -> ctx_arrays))
in
{ arrays with ctx_arrays })
in
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let for_lowereds = for_lowereds
Expand All @@ -186,7 +191,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
(* FIXME: do we really want all of them, or only the used ones? *)
let idx_params = Indexing.bound_symbols bindings in
let global_ctx_arrays =
ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> Map.empty (module Tn))
ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> empty_ctx_arrays)
in
let base_name =
String.(
Expand All @@ -206,7 +211,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
let opt_ctx_arrays = Option.map opt_ctx_arrays ~f:(fun _ -> !global_ctx_arrays) in
( opt_ctx_arrays,
Array.mapi params ~f:(fun i params ->
Option.map names.(i) ~f:(fun name ->
Option.map names.(i) ~f:(fun name ->
{
result;
params = Option.value_exn ~here:[%here] params;
Expand All @@ -219,7 +224,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
context * _ * _ * string =
let label : string = prior_context.label in
let name : string = code.name in
let arrays : Ndarray.t Base.Map.M(Tn).t =
let arrays =
match code with
| { opt_ctx_arrays = Some arrays; _ } -> arrays
| { params; _ } ->
Expand All @@ -232,7 +237,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) ~dims:(Lazy.force tn.dims)
@@ Constant_fill { values = [| 0. |]; strict = false }
in
Map.update ctx_arrays tn ~f
{ ctx_arrays with ctx_arrays = Map.update ctx_arrays.ctx_arrays tn ~f }
| _ -> ctx_arrays)
in
let context = { label; arrays } in
Expand All @@ -258,7 +263,9 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
let get_ptr (buffer, _) = Ndarray.get_voidptr_not_managed buffer in
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
| bs, Param_ptr tn :: ps ->
let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in
let nd =
match get_array (ctx_arrays context) tn with Some nd -> nd | None -> assert false
in
let c_ptr = Ndarray.get_voidptr_not_managed nd in
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
in
Expand Down
Loading

0 comments on commit a09e2d7

Please sign in to comment.