Skip to content

Commit

Permalink
In progress: the synchronization graph
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Nov 15, 2024
1 parent 2459d43 commit fb04bc0
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 31 deletions.
6 changes: 4 additions & 2 deletions arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ end
module Device_types (Device_config : Device_config) = struct
include Device_config

type nonrec device = (Device_config.buffer_ptr, Device_config.dev, Device_config.event) device
type nonrec device =
(Device_config.buffer_ptr, Device_config.dev, Device_config.runner, Device_config.event) device
[@@deriving sexp_of]

type nonrec stream =
Expand Down Expand Up @@ -106,7 +107,8 @@ struct
released = Atomic.make false;
cross_stream_candidates = Hashtbl.create (module Tnode);
owner_streams = Hashtbl.create (module Tnode);
stream_working_on = Hashtbl.create (module Tnode);
writer_stream = Hashtbl.create (module Tnode);
reader_streams = Hashtbl.create (module Tnode);
}

let make_stream device runner ~stream_id =
Expand Down
37 changes: 23 additions & 14 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ module type Device_config = sig
val name : string
end

type ('buffer_ptr, 'dev, 'event) device = {
type ('buffer_ptr, 'dev, 'runner, 'event) device = {
dev : 'dev;
ordinal : int;
mutable shared_merge_buffer : 'buffer_ptr buffer option;
Expand All @@ -91,32 +91,41 @@ type ('buffer_ptr, 'dev, 'event) device = {
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
shrink. *)
owner_streams : int Hashtbl.M(Tnode).t;
(** The streams owning the given nodes. This map can only grow. *)
stream_working_on : (int * 'event) option Hashtbl.M(Tnode).t;
(** The stream that most recently has been updating the node, and the associated update
completion event. An entry for a tensor node is only populated when
{!field-queried_work_for} is also populated. *)
owner_streams : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t;
(** The stream owning a given node. This map can only grow.
Currently, if the memory mode of a node is inferred, only this stream will modify a
cross-stream shared array. But memory modes can also be set manually. *)
writer_stream : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event option) Hashtbl.M(Tnode).t;
(** The stream that most recently has been updating (writing to) the node, and the associated
update completion event. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event option) Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from the node, and the associated use
completion events. An entry is only populated for cross-stream shared nodes, and an event
only populated when multiple streams worked with the node. *)
}
[@@deriving sexp_of]

type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
device : ('buffer_ptr, 'dev, 'event) device;
and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
device : ('buffer_ptr, 'dev, 'runner, 'event) device;
runner : 'runner;
merge_buffer : 'buffer_ptr buffer option ref;
(** Depending on backend implementations, either the currently used merge buffer, or the one
most recently scheduled. *)
mutable scheduled_merge_node : Tnode.t option;
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. *)
stream_id : int;
stream_id : int; (** An ID unique within the device. *)
mutable allocated_buffer : 'buffer_ptr buffer option;
queried_work_for : 'event option Hashtbl.M(Tnode).t;
(* The completion event for updating the node via this stream, if any. Only existing entries
are updated, and an entry is populated when {!work_for} is called for the first time on the
tensor node. *)
(* The completion event for updating (writing to) a node via this stream, if any. An entry
is populated when {!work_for} is called for the first time on the tensor node, or
{!field-writer_stream} needs to be updated with this stream. Otherwise, only existing
         entries are updated. *)
}
[@@deriving sexp_of]

(* Streams are identified by their device's ordinal together with their per-device ID. *)
let equal_stream s1 s2 =
  s1.device.ordinal = s2.device.ordinal && s1.stream_id = s2.stream_id

type ('buffer_ptr, 'stream) context = {
stream : 'stream;
parent : ('buffer_ptr, 'stream) context option;
Expand All @@ -130,7 +139,7 @@ type ('buffer_ptr, 'stream) context = {
module type Device_types = sig
include Device_config

type nonrec device = (buffer_ptr, dev, event) device [@@deriving sexp_of]
type nonrec device = (buffer_ptr, dev, runner, event) device [@@deriving sexp_of]
type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
end
Expand Down
51 changes: 36 additions & 15 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,49 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
| None | Some None -> default ()
| Some (Some _ as event) -> event)

(* [wait_for_users ctx tn] makes [ctx]'s stream wait for the most recent recorded writer
   of [tn] when that writer is a different stream, and returns whether multiple streams
   have worked with [tn]: [true] when the recorded writer already carries a completion
   event, or when the writer is a foreign stream that [ctx] had to wait for.
   Fix vs. previous version: [writer_stream] was looked up twice and the result threaded
   through a [ref]; a single lookup with a [match] is equivalent.
   NOTE(review): despite the name, only {!field-writer_stream} is consulted here —
   {!field-reader_streams} entries are not waited for; confirm whether that is intended. *)
let wait_for_users ctx tn =
  let s = ctx.stream in
  match Hashtbl.find s.device.writer_stream tn with
  | None -> false
  | Some (work_stream, e) ->
      if equal_stream work_stream s then
        (* Same stream: nothing to synchronize; multi-stream use is signalled purely by
           the presence of a recorded completion event. *)
        Option.is_some e
      else (
        (* Foreign writer: block on its completion event, falling back to all of the
           writer stream's scheduled work when no event was recorded. *)
        Backend.will_wait_for ctx @@ Option.value e ~default:(Backend.all_work work_stream);
        true)

(* [wait_for_writers ctx tn] makes [ctx]'s stream wait for the most recent recorded
   writer of [tn], if any and if it is a different stream: it blocks on the writer's
   completion event, or on all of the writer stream's scheduled work when no event was
   recorded. *)
let wait_for_writers ctx tn =
  let s = ctx.stream in
  match Hashtbl.find s.device.writer_stream tn with
  | None -> ()
  | Some (work_stream, e) ->
      if not (equal_stream work_stream s) then (
        let event = Option.value e ~default:(Backend.all_work work_stream) in
        Backend.will_wait_for ctx event)

(* Records stream [s] as the latest writer of [tn] on [s]'s device. A fresh completion
   event is created (via [Backend.all_work s]) only when someone may need to wait on it:
   either this stream's event for [tn] has been queried before ([queried_work_for] has an
   entry), or [tn] was recently used by multiple streams ([worked_multi_streams]);
   otherwise only the writer identity is recorded, with no event. *)
let update_writer_event ~worked_multi_streams s tn =
  if Hashtbl.mem s.queried_work_for tn || worked_multi_streams then (
    let e = Backend.all_work s in
    (* Overwrite any previous writer record with this stream and the fresh event. *)
    Hashtbl.update s.device.writer_stream tn ~f:(fun _ -> (s, Some e));
    (* Refresh an already-queried event in place. NOTE(review): when
       [worked_multi_streams] is true but [tn] has no entry here, [Hashtbl.update] still
       inserts a [None] binding, so the [mem] test above succeeds on later calls —
       confirm this insertion is intended. *)
    Hashtbl.update s.queried_work_for tn ~f:(Option.map ~f:(fun _ -> e)))
  else Hashtbl.update s.device.writer_stream tn ~f:(fun _ -> (s, None))

let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
let s = ctx.stream in
(* Wait for readers of the array before copying, if any are recorded. FIXME: this is
invalid. *)
Hashtbl.find s.device.stream_working_on tn
|> Option.join
|> Option.iter ~f:(fun (_work_stream_id, e) -> Backend.will_wait_for ctx e);
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
(* Update the latest work event for the node. *)
if Hashtbl.mem s.queried_work_for tn then (
let e = Backend.all_work s in
Hashtbl.update s.device.stream_working_on tn ~f:(fun _ -> Some (s.stream_id, e));
Hashtbl.update s.queried_work_for tn ~f:(fun _ -> Some e));
((* Wait for all users of the array before copying. *)
let worked_multi_streams = wait_for_users ctx tn in
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
update_writer_event ~worked_multi_streams ctx.stream tn);
true
| _ -> false

let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
(* Only wait for writers of the array before copying. *)
wait_for_writers ctx tn;
[%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
Backend.to_host ~src_ptr:src ~src:ctx hosted;
true
Expand Down Expand Up @@ -342,12 +363,12 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
Map.add_exn ctx_arrays ~key ~data)
else if Tn.known_shared_cross_stream key then (
if Hashtbl.mem device.owner_streams key then (
if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
if not (equal_stream stream (Hashtbl.find_exn device.owner_streams key)) then
raise
@@ Utils.User_error
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
^ " assumed to be cross-stream-shared but then written to on multiple devices"))
else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
else Hashtbl.add_exn device.owner_streams ~key ~data:stream;
let data = Hashtbl.find_exn device.cross_stream_candidates key in
Map.add_exn ctx_arrays ~key ~data)
else (
Expand Down

0 comments on commit fb04bc0

Please sign in to comment.