Skip to content

Commit

Permalink
More refined events for tensor node syncing
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Nov 15, 2024
1 parent 9041153 commit b116838
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 33 deletions.
8 changes: 5 additions & 3 deletions arrayjit/lib/backend_impl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,10 @@ struct
latest_stream_id = -1;
released = Atomic.make false;
cross_stream_candidates = Hashtbl.create (module Tnode);
owner_streams = Hashtbl.create (module Tnode);
writer_streams = Hashtbl.create (module Tnode);
reader_streams = Hashtbl.create (module Tnode);
owner_stream = Hashtbl.create (module Tnode);
shared_writer_streams = Hashtbl.create (module Tnode);
host_reading_streams = Hashtbl.create (module Tnode);
host_writing_streams = Hashtbl.create (module Tnode);
}

let make_stream device runner ~stream_id =
Expand All @@ -120,6 +121,7 @@ struct
stream_id;
allocated_buffer = None;
updating_for = Hashtbl.create (module Tnode);
reader_streams = Hashtbl.create (module Tnode);
}

let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
Expand Down
50 changes: 31 additions & 19 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,23 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device = {
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
shrink. *)
owner_streams : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t;
(** The stream owning a given node. This map can only grow.
Currently, if the memory mode of a node is inferred, only this stream will modify a
cross-stream shared array. But memory modes can also be set manually. *)
writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) the node, and the
associated update completion event. The completed events are removed opportunistically. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from the node, and the associated use
completion events. The completed events are removed opportunistically. *)
owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t;
(** The stream owning a given node. This map can only grow. Currently, if the memory mode of a
node is inferred, only this stream will modify a cross-stream shared array. But memory
modes can also be set manually. *)
shared_writer_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) a
cross-stream-shared node, and the associated update completion event. The completed events
are removed opportunistically. *)
host_reading_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from a node's on-host array. The
completed events are removed opportunistically. *)
host_writing_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been writing to a node's on-host array. The completed
events are removed opportunistically. *)
}
[@@deriving sexp_of]

Expand All @@ -116,7 +122,11 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
stream_id : int; (** An ID unique within the device. *)
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
(* The completion event for updating (writing to) a node via this stream, if any. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams, other than this stream, that most recently have been reading from a node in
this stream's context, and the associated use completion events. The completed events are
removed opportunistically. *)
}
[@@deriving sexp_of]

Expand Down Expand Up @@ -214,8 +224,7 @@ module type Backend_device_common = sig
(** Schedules waiting for the given event on the context's stream.
NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it is typically
called internally when necessary. But there is one exception, see {!device_to_device} when
[into_merge_buffer=Streaming]. *)
called internally when necessary. *)

val get_used_memory : device -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
Expand Down Expand Up @@ -257,12 +266,15 @@ module type With_buffer_retrieval_and_syncing = sig
- If the node is absent from the [src] context and either it is present in the [dst] context
or [into_merge_buffer] is different from [No]: raises an error.
- If the node is absent from [dst] and [into_merge_buffer=No]: returns false.
- Executes [will_wait_for dst (work_for src tn)].
- If [into_merge_buffer=No]: schedules a copy of the tensor node from [src] to [dst].
- Schedules waiting for writing into the tensor node on [src] to finish, if any.
- If [into_merge_buffer=No]: schedules a copy of the tensor node from [src] to [dst] and
updates the writer event for the node.
- If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to the
given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source
node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying
from [src] to the merge buffer of [dst]'s stream.
given node.
- If [into_merge_buffer=Streaming], remembers the buffer pointer of the source node to use for
streaming.
- If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s
stream, and registers [dst.stream] with a reader event for the node.
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
buffer but before scheduling work on [src] that modifies [tn], execute
Expand Down
32 changes: 21 additions & 11 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,32 +30,42 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
|> List.iter ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)

let update_writer_event s tn =
let update_writer_event ?(from_host = false) s tn =
let e = Backend.all_work s in
Hashtbl.update s.device.writer_streams tn ~f:(fun l -> (s, e) :: Option.value ~default:[] l);
if from_host then
Hashtbl.update s.device.host_writing_streams tn ~f:(fun l ->
(s, e) :: Option.value ~default:[] l);
(* To be on the safe side, record events for potentially cross-stream nodes. *)
if Tn.potentially_cross_stream tn then
Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
(s, e) :: Option.value ~default:[] l);
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)

let add_reader s tn =
let add_reader s tn from =
let e = Backend.all_work s in
Hashtbl.update s.device.reader_streams tn ~f:(fun l -> (s, e) :: Option.value ~default:[] l)
let f l = (s, e) :: Option.value ~default:[] l in
match from with
| `Host -> Hashtbl.update s.device.host_reading_streams tn ~f
| `Src src -> Hashtbl.update src.reader_streams tn ~f

let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
wait_for_all ctx ctx.stream.device.reader_streams tn;
wait_for_all ctx ctx.stream.reader_streams tn;
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
update_writer_event ctx.stream tn;
update_writer_event ~from_host:true ctx.stream tn;
add_reader ctx.stream tn @@ `Host;
true
| _ -> false

let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
wait_for_all ctx ctx.stream.device.writer_streams tn;
if Tn.potentially_cross_stream tn then
wait_for_all ctx ctx.stream.device.shared_writer_streams tn;
[%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
Backend.to_host ~src_ptr:src ~src:ctx hosted;
add_reader ctx.stream tn;
true
| _ -> false

Expand Down Expand Up @@ -343,13 +353,13 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
Map.add_exn ctx_arrays ~key ~data)
else if Tn.known_shared_cross_stream key then (
if Hashtbl.mem device.owner_streams key then (
if not (equal_stream stream (Hashtbl.find_exn device.owner_streams key)) then
if Hashtbl.mem device.owner_stream key then (
if not (equal_stream stream (Hashtbl.find_exn device.owner_stream key)) then
raise
@@ Utils.User_error
("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
^ " assumed to be cross-stream-shared but then written to on multiple devices"))
else Hashtbl.add_exn device.owner_streams ~key ~data:stream;
else Hashtbl.add_exn device.owner_stream ~key ~data:stream;
let data = Hashtbl.find_exn device.cross_stream_candidates key in
Map.add_exn ctx_arrays ~key ~data)
else (
Expand Down
2 changes: 2 additions & 0 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ let known_non_cross_stream tn =
| Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) -> true
| _ -> false

(* A node might be shared across streams unless it is known to never be
   materialized, or known to be stream-local. *)
let potentially_cross_stream tn =
  (not (known_not_materialized tn)) && not (known_non_cross_stream tn)

let mode_is_unspecified tn =
match tn.memory_mode with
| None | Some ((Never_virtual | Effectively_constant), _) -> true
Expand Down

0 comments on commit b116838

Please sign in to comment.