diff --git a/arrayjit/lib/backend_impl.ml b/arrayjit/lib/backend_impl.ml
index fdbc449..36f3783 100644
--- a/arrayjit/lib/backend_impl.ml
+++ b/arrayjit/lib/backend_impl.ml
@@ -77,7 +77,8 @@ end
 module Device_types (Device_config : Device_config) = struct
   include Device_config
 
-  type nonrec device = (Device_config.buffer_ptr, Device_config.dev, Device_config.event) device
+  type nonrec device =
+    (Device_config.buffer_ptr, Device_config.dev, Device_config.runner, Device_config.event) device
   [@@deriving sexp_of]
 
   type nonrec stream =
@@ -106,7 +107,8 @@ struct
       released = Atomic.make false;
       cross_stream_candidates = Hashtbl.create (module Tnode);
       owner_streams = Hashtbl.create (module Tnode);
-      stream_working_on = Hashtbl.create (module Tnode);
+      writer_stream = Hashtbl.create (module Tnode);
+      reader_streams = Hashtbl.create (module Tnode);
     }
 
   let make_stream device runner ~stream_id =
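For orientation, here is a minimal standalone sketch of what make_device now tracks per device. The tracking record, make_tracking, the string keys, and the placeholder stream and event types are illustrative stand-ins only; the real code keys these tables by tensor node via Base's Hashtbl.M(Tnode) and lives inside the device record changed below.

(* A minimal sketch, not the library's types: a stdlib Hashtbl keyed by a node
   name stands in for Hashtbl.M(Tnode), and stream/event are placeholders.
   It only illustrates the two tracking tables that make_device now creates. *)
type stream = { stream_id : int; device_ordinal : int }
type event = unit

type tracking = {
  writer_stream : (string, stream * event option) Hashtbl.t;
      (* The last stream that wrote to a node, with its completion event if recorded. *)
  reader_streams : (string, stream * event option) Hashtbl.t;
      (* The stream that most recently read a node, tracked for cross-stream shared nodes. *)
}

let make_tracking () =
  { writer_stream = Hashtbl.create 16; reader_streams = Hashtbl.create 16 }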
diff --git a/arrayjit/lib/backend_intf.ml b/arrayjit/lib/backend_intf.ml
index 32c9cc5..c59b13e 100644
--- a/arrayjit/lib/backend_intf.ml
+++ b/arrayjit/lib/backend_intf.ml
@@ -78,7 +78,7 @@ module type Device_config = sig
   val name : string
 end
 
-type ('buffer_ptr, 'dev, 'event) device = {
+type ('buffer_ptr, 'dev, 'runner, 'event) device = {
   dev : 'dev;
   ordinal : int;
   mutable shared_merge_buffer : 'buffer_ptr buffer option;
@@ -91,32 +91,41 @@ type ('buffer_ptr, 'dev, 'event) device = {
   cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
       (** Freshly created arrays that might be shared across streams. The map can both grow and
           shrink. *)
-  owner_streams : int Hashtbl.M(Tnode).t;
-      (** The streams owning the given nodes. This map can only grow. *)
-  stream_working_on : (int * 'event) option Hashtbl.M(Tnode).t;
-      (** The stream that most recently has been updating the node, and the associated update
-          completion event. An entry for a tensor node is only populated when
-          {!field-queried_work_for} is also populated. *)
+  owner_streams : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t;
+      (** The stream owning a given node. This map can only grow.
+
+          Currently, if the memory mode of a node is inferred, only this stream will modify a
+          cross-stream shared array. But memory modes can also be set manually. *)
+  writer_stream : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event option) Hashtbl.M(Tnode).t;
+      (** The stream that most recently has been updating (writing to) the node, and the associated
+          update completion event. *)
+  reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event option) Hashtbl.M(Tnode).t;
+      (** The streams that most recently have been reading from the node, and the associated use
+          completion events. An entry is only populated for cross-stream shared nodes, and an event
+          only populated when multiple streams worked with the node. *)
 }
 [@@deriving sexp_of]
 
-type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
-  device : ('buffer_ptr, 'dev, 'event) device;
+and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
+  device : ('buffer_ptr, 'dev, 'runner, 'event) device;
   runner : 'runner;
   merge_buffer : 'buffer_ptr buffer option ref;
       (** Depending on backend implementations, either the currently used merge buffer, or the one
           most recently scheduled. *)
   mutable scheduled_merge_node : Tnode.t option;
       (** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. *)
-  stream_id : int;
+  stream_id : int;  (** An ID unique within the device. *)
   mutable allocated_buffer : 'buffer_ptr buffer option;
   queried_work_for : 'event option Hashtbl.M(Tnode).t;
-      (* The completion event for updating the node via this stream, if any. Only existing entries
-         are updated, and an entry is populated when {!work_for} is called for the first time on the
-         tensor node. *)
+      (* The completion event for updating (writing to) a node via this stream, if any. An entry
+         is populated when {!work_for} is called for the first time on the tensor node, or
+         {!field-writer_stream} needs to be updated with this stream. Otherwise, only existing
+         entries are updated. *)
 }
 [@@deriving sexp_of]
 
+let equal_stream s1 s2 = s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal
+
 type ('buffer_ptr, 'stream) context = {
   stream : 'stream;
   parent : ('buffer_ptr, 'stream) context option;
@@ -130,7 +139,7 @@ type ('buffer_ptr, 'stream) context = {
 module type Device_types = sig
   include Device_config
 
-  type nonrec device = (buffer_ptr, dev, event) device [@@deriving sexp_of]
+  type nonrec device = (buffer_ptr, dev, runner, event) device [@@deriving sexp_of]
   type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
   type nonrec context = (buffer_ptr, stream) context [@@deriving sexp_of]
 end
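Since equal_stream is relied on below in backends.ml, a small self-contained sketch of its intent may help; the device and stream records here are simplified placeholders, not the records from this interface. Stream IDs are only unique within a device, so stream identity must also compare the device ordinal.

(* Simplified placeholder records: stream_id alone is only unique within one
   device, so stream identity also compares the owning device's ordinal. *)
type device = { ordinal : int }
type stream = { device : device; stream_id : int }

let equal_stream s1 s2 =
  s1.stream_id = s2.stream_id && s1.device.ordinal = s2.device.ordinal

let () =
  let d0 = { ordinal = 0 } and d1 = { ordinal = 1 } in
  let s0 = { device = d0; stream_id = 3 } in
  (* The same id on another device denotes a different stream. *)
  assert (not (equal_stream s0 { device = d1; stream_id = 3 }));
  assert (equal_stream s0 { device = d0; stream_id = 3 })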
diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml
index 1a90feb..b5a26c8 100644
--- a/arrayjit/lib/backends.ml
+++ b/arrayjit/lib/backends.ml
@@ -30,28 +30,49 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
       | None | Some None -> default ()
       | Some (Some _ as event) -> event)
 
+  let wait_for_users ctx tn =
+    let s = ctx.stream in
+    let worked_multi_streams =
+      Hashtbl.find s.device.writer_stream tn
+      |> Option.map ~f:snd |> Option.join |> Option.is_some |> ref
+    in
+    Hashtbl.find s.device.writer_stream tn
+    |> Option.iter ~f:(fun (work_stream, e) ->
+           if not (equal_stream work_stream s) then (
+             worked_multi_streams := true;
+             Backend.will_wait_for ctx @@ Option.value e ~default:(Backend.all_work work_stream)));
+    !worked_multi_streams
+
+  let wait_for_writers ctx tn =
+    let s = ctx.stream in
+    Hashtbl.find s.device.writer_stream tn
+    |> Option.iter ~f:(fun (work_stream, e) ->
+           if not (equal_stream work_stream s) then
+             Backend.will_wait_for ctx @@ Option.value e ~default:(Backend.all_work work_stream))
+
+  let update_writer_event ~worked_multi_streams s tn =
+    if Hashtbl.mem s.queried_work_for tn || worked_multi_streams then (
+      let e = Backend.all_work s in
+      Hashtbl.update s.device.writer_stream tn ~f:(fun _ -> (s, Some e));
+      Hashtbl.update s.queried_work_for tn ~f:(Option.map ~f:(fun _ -> e)))
+    else Hashtbl.update s.device.writer_stream tn ~f:(fun _ -> (s, None))
+
   let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
     match (tn, Map.find ctx.ctx_arrays tn) with
     | { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
-        [%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
-        let s = ctx.stream in
-        (* Wait for readers of the array before copying, if any are recorded. FIXME: this is
-           invalid. *)
-        Hashtbl.find s.device.stream_working_on tn
-        |> Option.join
-        |> Option.iter ~f:(fun (_work_stream_id, e) -> Backend.will_wait_for ctx e);
-        Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
-        (* Update the latest work event for the node. *)
-        if Hashtbl.mem s.queried_work_for tn then (
-          let e = Backend.all_work s in
-          Hashtbl.update s.device.stream_working_on tn ~f:(fun _ -> Some (s.stream_id, e));
-          Hashtbl.update s.queried_work_for tn ~f:(fun _ -> Some e));
+        ((* Wait for all users of the array before copying. *)
+         let worked_multi_streams = wait_for_users ctx tn in
+         [%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
+         Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
+         update_writer_event ~worked_multi_streams ctx.stream tn);
         true
     | _ -> false
 
   let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
     match (tn, Map.find ctx.ctx_arrays tn) with
     | { Tn.array = (lazy (Some hosted)); _ }, Some src ->
+        (* Only wait for writers of the array before copying. *)
+        wait_for_writers ctx tn;
         [%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
         Backend.to_host ~src_ptr:src ~src:ctx hosted;
         true
@@ -342,12 +363,12 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
             Map.add_exn ctx_arrays ~key ~data)
           else if Tn.known_shared_cross_stream key then (
            if Hashtbl.mem device.owner_streams key then (
-             if not (stream.stream_id = Hashtbl.find_exn device.owner_streams key) then
+             if not (equal_stream stream (Hashtbl.find_exn device.owner_streams key)) then
               raise
               @@ Utils.User_error
                    ("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key
                   ^ " assumed to be cross-stream-shared but then written to on multiple devices"))
-           else Hashtbl.add_exn device.owner_streams ~key ~data:stream.stream_id;
+           else Hashtbl.add_exn device.owner_streams ~key ~data:stream;
            let data = Hashtbl.find_exn device.cross_stream_candidates key in
            Map.add_exn ctx_arrays ~key ~data)
          else (
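To recap the copying protocol these hunks introduce, here is a hedged, simplified model rather than the backend's actual API: Sync, wait_for, all_work, writer_stream, record_writer, and the from_host wrapper are stubs and placeholder names. The idea it sketches: a copy issued from a stream other than a node's last writer first waits on the writer's recorded event (or on all of the writer's work), and a host-to-device copy then records the current stream as the node's new writer.

(* A simplified, self-contained model of the syncing pattern; every name here
   is a placeholder, not the backend's API. *)
module Sync = struct
  type event = { id : int }
  type stream = { stream_id : int; device_ordinal : int }

  let equal_stream a b =
    a.stream_id = b.stream_id && a.device_ordinal = b.device_ordinal

  (* Stubs standing in for the backend's synchronization primitives. *)
  let all_work s = { id = s.stream_id }
  let wait_for (_waiting : stream) (_e : event) = ()

  (* node name -> (last writer, optional completion event) *)
  let writer_stream : (string, stream * event option) Hashtbl.t = Hashtbl.create 16

  let wait_for_writer current tn =
    match Hashtbl.find_opt writer_stream tn with
    | Some (w, e) when not (equal_stream w current) ->
        (* Wait on the recorded event, or on everything the writer has scheduled. *)
        wait_for current (Option.value e ~default:(all_work w))
    | _ -> ()

  let record_writer ?event current tn = Hashtbl.replace writer_stream tn (current, event)

  (* Host-to-device copy: wait for the previous writer, copy, then become the writer. *)
  let from_host current tn =
    wait_for_writer current tn;
    (* ... the actual buffer copy would happen here ... *)
    record_writer current tn
end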