Go back to using ints to identify streams
lukstafi committed Oct 25, 2024
1 parent 1775098 commit 2806622
Showing 5 changed files with 43 additions and 56 deletions.
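
For orientation, here is a minimal self-contained sketch of the scheme this commit switches to (the record fields and the backend name below are illustrative stand-ins, not the repository's definitions): a stream carries an integer stream_id, and its printable name of the form backend:device-ordinal:stream-id is derived on demand instead of being stored as a unique_name string.

(* Sketch only: [stream], [backend_name] and [device_ordinal] are stand-ins. *)
type stream = { backend_name : string; device_ordinal : int; stream_id : int }

let get_name { backend_name; device_ordinal; stream_id } =
  Printf.sprintf "%s:%d:%d" backend_name device_ordinal stream_id

let () =
  let s = { backend_name = "multicore_cc"; device_ordinal = 0; stream_id = 3 } in
  assert (String.equal (get_name s) "multicore_cc:0:3")

One consequence of deriving the name instead of storing it is that the per-backend "used names" bookkeeping disappears, as the diffs below show.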
11 changes: 5 additions & 6 deletions arrayjit/lib/backend_types.ml
@@ -127,7 +127,7 @@ type ('buffer_ptr, 'device, 'stream_state, 'runner, 'event) stream = {
device : 'device;
state : 'stream_state;
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
unique_name : string;
stream_id : int;
mutable allocated_buffer : 'buffer_ptr buffer option;
runner : 'runner;
requested_work_for : 'event option Hashtbl.M(Tnode).t;
@@ -155,8 +155,7 @@ module type Device = sig
include Device_types
include Alloc_buffer with type buffer_ptr := buffer_ptr and type stream := stream

val make_stream :
device:device -> state:stream_state -> unique_name:string -> runner:runner -> stream
val make_stream : device:device -> state:stream_state -> stream_id:int -> runner:runner -> stream
end

module Device_types (Device_config : Device_config) = struct
@@ -173,14 +172,14 @@ struct
include Device_types
include Alloc_buffer

let make_stream ~device ~state ~unique_name ~runner =
let make_stream ~device ~state ~stream_id ~runner =
{
device;
state;
merge_buffer = ref None;
unique_name : string;
stream_id;
allocated_buffer = None;
runner : 'runner;
runner;
requested_work_for = Hashtbl.create (module Tnode);
}
end
53 changes: 22 additions & 31 deletions arrayjit/lib/backends.ml
@@ -149,7 +149,8 @@ module Multicore_backend (Backend : No_device_backend) = struct

let is_dev_queue_empty state = Queue.size state.Device_config.queue = 0
let is_idle stream = is_dev_queue_empty stream.state && stream.state.is_ready
let name = "multicore " ^ name
let name = "multicore_" ^ name
let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]

let%track3_l_sexp await stream =
assert (Domain.is_main_domain ());
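
For readers unfamiliar with ppx_string, the %string interpolation introduced above desugars roughly as follows; the :0: component is the device ordinal, hard-coded because the multicore backend exposes a single device. The definitions below are local stand-ins, not the module's actual ones.

(* Rough desugaring of [%string "%{name}:0:%{stream.stream_id#Int}"]; the #Int
   conversion corresponds to Int.to_string. *)
let name = "multicore_cc"
type stream = { stream_id : int }
let get_name stream = name ^ ":0:" ^ Int.to_string stream.stream_id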
@@ -162,18 +163,17 @@ module Multicore_backend (Backend : No_device_backend) = struct
Stdlib.Condition.wait d.host_wait_for_idle d.mut
done;
Mut.unlock d.mut;
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ name ^ " " ^ stream.unique_name))
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream))

(** TODO: Returns the event indicating if any currently running or scheduled computations on the
stream have completed. *)
let all_work _stream = Device_config.Not_implemented_yet

let%track3_l_sexp schedule_task stream task =
assert (Domain.is_main_domain ());
[%log_result "schedule_task", Task.describe task, stream.unique_name];
[%log_result "schedule_task", Task.describe task, get_name stream];
let d = stream.state in
Option.iter d.Device_config.stream_error ~f:(fun e ->
Exn.reraise e @@ name ^ " " ^ stream.unique_name);
Option.iter d.Device_config.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream);
if not d.keep_spinning then invalid_arg "Multicore_backend: stream not available";
if not @@ Queue.try_push d.queue task then (
await stream;
@@ -185,7 +185,7 @@ module Multicore_backend (Backend : No_device_backend) = struct

let global_run_no = ref 0

let%track3_l_sexp spinup_stream ~unique_name : stream =
let%track3_l_sexp spinup_stream ~stream_id : stream =
Int.incr global_run_no;
let state =
{
@@ -217,17 +217,17 @@ module Multicore_backend (Backend : No_device_backend) = struct
with e ->
state.stream_error <- Some e;
state.keep_spinning <- false;
[%log1 unique_name, "exception", Exn.to_string e];
[%log1 "stream", (stream_id : int), "exception", Exn.to_string e];
(* TODO: we risk raising this error multiple times because await and schedule_task raise
stream_error. But this is fine if we assume all exceptions are fatal. *)
raise e
in
make_stream ~device:Device_config.CPU ~state ~unique_name ~runner:(Domain.spawn worker)
make_stream ~device:Device_config.CPU ~state ~stream_id ~runner:(Domain.spawn worker)

type nonrec context = { stream : stream; ctx : context } [@@deriving sexp_of]

let ctx_arrays context = ctx_arrays context.ctx
let init stream = { stream; ctx = init (name ^ " " ^ stream.unique_name) }
let init stream = { stream; ctx = init (get_name stream) }
let initialize = initialize
let is_initialized = is_initialized

@@ -237,7 +237,6 @@ module Multicore_backend (Backend : No_device_backend) = struct

let compile = compile
let compile_batch = compile_batch
let get_name stream = stream.unique_name

let link { ctx; stream } code =
let task = link ~merge_buffer:stream.merge_buffer ctx code in
@@ -264,30 +263,25 @@ module Multicore_backend (Backend : No_device_backend) = struct

let num_devices () = 1
let suggested_num_streams Device_config.CPU = Domain.recommended_domain_count () - 1
let used_names = Hash_set.create (module String)

let cleanup_stream stream =
assert (Domain.is_main_domain ());
await stream;
stream.state.keep_spinning <- false;
Stdlib.Condition.broadcast stream.state.dev_wait_for_work;
Hash_set.remove used_names stream.unique_name;
Domain.join stream.runner

let get_device ~ordinal =
if ordinal <> 0 then
invalid_arg [%string "Multicore_backend.get_device %{ordinal#Int}: only device 0 exists"];
Device_config.CPU

let latest_stream_id = ref (-1)

let new_stream Device_config.CPU =
assert (Domain.is_main_domain ());
let rec unique_name suffix =
let name = "stream " ^ Int.to_string suffix in
if Hash_set.mem used_names name then unique_name (suffix + 1) else name
in
let unique_name = unique_name 0 in
Hash_set.add used_names unique_name;
let stream = spinup_stream ~unique_name in
Int.incr latest_stream_id;
let stream = spinup_stream ~stream_id:!latest_stream_id in
Stdlib.Gc.finalise cleanup_stream stream;
stream

@@ -300,14 +294,13 @@ module Multicore_backend (Backend : No_device_backend) = struct
(* TODO: pass description to from_host. *)
schedule_task dst.stream
(Task.Task
{ context_lifetime = dst; description = "from_host on " ^ dst.stream.unique_name; work })
{ context_lifetime = dst; description = "from_host on " ^ get_name dst.stream; work })

let to_host ~src_ptr ~src hosted =
let work () = buffer_to_host hosted ~src:src_ptr in
(* TODO: pass description to to_host. *)
schedule_task src.stream
(Task.Task
{ context_lifetime = src; description = "to_host on " ^ src.stream.unique_name; work })
(Task.Task { context_lifetime = src; description = "to_host on " ^ get_name src.stream; work })

let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let dev = dst.stream in
@@ -332,8 +325,8 @@ module Multicore_backend (Backend : No_device_backend) = struct
buffer_to_buffer ~dst:merge_ptr ~src:src_ptr ~size_in_bytes
in
let description =
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ dev.unique_name ^ " src "
^ src.stream.unique_name
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ get_name dev ^ " src "
^ get_name src.stream
in
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
end
@@ -383,20 +376,18 @@ module Sync_backend (Backend : No_device_backend) = struct
let num_devices () = 1
let suggested_num_streams Device_config.CPU = !sync_suggested_num_streams
let get_used_memory Device_config.CPU = Backend.get_used_memory ()
let next_stream = ref 0
let latest_stream_id = ref (-1)

let new_stream Device_config.CPU : stream =
Int.incr next_stream;
make_stream ~device:Device_config.CPU ~state:()
~unique_name:("stream " ^ Int.to_string (!next_stream - 1))
~runner:()
Int.incr latest_stream_id;
make_stream ~device:Device_config.CPU ~state:() ~stream_id:!latest_stream_id ~runner:()

type code = Backend.code [@@deriving sexp_of]
type code_batch = Backend.code_batch [@@deriving sexp_of]

let all_work _stream = ()
let is_idle _stream = true
let name = "sync " ^ Backend.name
let name = "sync_" ^ Backend.name
let await _stream = ()
(* let global_run_no = ref 0 *)

@@ -422,7 +413,7 @@ module Sync_backend (Backend : No_device_backend) = struct
Array.map routines
~f:(Option.map ~f:(fun task -> { task with context = { ctx = task.context; stream } })) )

let get_name stream = stream.unique_name
let get_name stream = [%string "%{name}:0:%{stream.stream_id#Int}"]
let from_host ~dst_ptr ~dst:_ hosted = host_to_buffer hosted ~dst:dst_ptr
let to_host ~src_ptr ~src:_ hosted = buffer_to_host hosted ~src:src_ptr

26 changes: 11 additions & 15 deletions arrayjit/lib/cuda_backend.cudajit.ml
@@ -39,12 +39,12 @@ module Device_config = struct
primary_context : Cu.Context.t;
mutable copy_merge_buffer : buffer_ptr;
mutable copy_merge_buffer_capacity : int;
used_names : Hash_set.M(String).t; (** Unique names of streams. *)
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : buffer_ptr Hashtbl.M(Tn).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
shrink. See the explanation on top of this file. *)
owner_streams : string Hashtbl.M(Tn).t;
owner_streams : int Hashtbl.M(Tn).t;
(** The streams owning the given nodes. This map can only grow. *)
}
[@@deriving sexp_of]
@@ -166,7 +166,7 @@ let%track3_sexp get_device ~(ordinal : int) : device =
{
dev;
ordinal;
used_names = Hash_set.create (module String);
latest_stream_id = -1;
primary_context;
copy_merge_buffer;
copy_merge_buffer_capacity;
@@ -184,16 +184,11 @@ let%track3_sexp get_device ~(ordinal : int) : device =
if Atomic.get result.released then default () else result

let%track3_sexp new_stream (device : device) : stream =
let rec unique_name suffix =
let name = "stream " ^ Int.to_string suffix in
if Hash_set.mem device.used_names name then unique_name (suffix + 1) else name
in
let unique_name = unique_name 0 in
Hash_set.add device.used_names unique_name;
device.latest_stream_id <- device.latest_stream_id + 1;
(* Strange that we need ctx_set_current even with a single device! *)
set_ctx device.primary_context;
let cu_stream = Cu.Stream.create ~non_blocking:true () in
make_stream ~device ~state:() ~unique_name ~runner:cu_stream
make_stream ~device ~state:() ~stream_id:device.latest_stream_id ~runner:cu_stream

let cuda_properties =
let cache =
@@ -216,7 +211,10 @@ let suggested_num_streams device =
let get_ctx_stream { stream; _ } = stream
let get_stream_device { device; _ } = device
let to_ordinal Device_config.{ ordinal; _ } = ordinal
let get_name stream = stream.unique_name
let name = "cuda"

let get_name stream =
[%string "%{name}:%{stream.device.Device_config.ordinal#Int}:%{stream.stream_id#Int}"]

let await stream : unit =
set_ctx stream.device.Device_config.primary_context;
@@ -543,12 +541,12 @@ let%track3_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
Map.add_exn ctx_arrays ~key ~data)
else if Tn.known_shared_cross_stream key then (
if Hashtbl.mem device.owner_streams key then
if not @@ String.equal stream.unique_name @@ Hashtbl.find_exn device.owner_streams key then
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
raise
@@ Utils.User_error
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
^ " assumed to be cross-stream-shared but then written to on multiple devices")
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.unique_name;
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
let data = Hashtbl.find_exn device.cross_stream_candidates key in
Map.add_exn ctx_arrays ~key ~data)
else (
@@ -601,5 +599,3 @@ let%track3_sexp link_batch prior_context (code_batch : code_batch) : context * _
((context, ctx_arrays), Some task)))
in
(context, lowered_bindings, procs)

let name = "cuda"
8 changes: 4 additions & 4 deletions lib/attic.mld
@@ -382,7 +382,7 @@ Old copying mechanism in backends.ml Multicore_backend:
{
context_lifetime = context;
description =
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.unique_name;
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.stream_id;
work;
});
true
@@ -415,7 +415,7 @@ Old copying mechanism in backends.ml Multicore_backend:
{
context_lifetime = context;
description =
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.unique_name;
"from_host " ^ Tnode.debug_name tn ^ " dst " ^ context.stream.stream_id;
work;
});
true
@@ -453,8 +453,8 @@ Old copying mechanism in backends.ml Multicore_backend:
Backend.to_buffer tn ~dst:merge_ptr ~src:src.ctx
in
let description =
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ dev.unique_name ^ " src "
^ src.stream.unique_name
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ dev.stream_id ^ " src "
^ src.stream.stream_id
in
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
in
1 change: 1 addition & 0 deletions lib/train.ml
@@ -404,6 +404,7 @@ let%track3_sexp parallel_update (type context)
let fs = [%debug_notrace Array.map grad_updates ~f:(fun upd () -> Task.run upd.schedule)] in
fun () -> round_robin fs lowered_bindings sgd_update.bindings ~sync

(* Note: this type signature looks ugly, but it will get simple again with modular explicits. *)
let get_all_suggested_streams ?(max_num_streams : int option)
(type buffer_ptr device stream_state runner event)
(module Backend : Backend_type
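
The new comment refers to the boilerplate of passing a first-class module whose signature mentions many type parameters: each one has to be bound with (type ...) and repeated in the with-constraints. A toy sketch of the same pattern follows; Container and sum are hypothetical stand-ins, unrelated to Backend_type.

module type Container = sig
  type elt
  type t
  val to_list : t -> elt list
end

(* Every type parameter of the module must be named with (type ...) and threaded
   through the constraint by hand; modular explicits would remove this noise. *)
let sum (type elt t) (module C : Container with type elt = elt and type t = t)
    (c : t) ~(zero : elt) ~(add : elt -> elt -> elt) : elt =
  List.fold_left add zero (C.to_list c)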
Expand Down

0 comments on commit 2806622

Please sign in to comment.