diff --git a/arrayjit/lib/backend_impl.ml b/arrayjit/lib/backend_impl.ml index bafeda3..61de5d6 100644 --- a/arrayjit/lib/backend_impl.ml +++ b/arrayjit/lib/backend_impl.ml @@ -106,9 +106,10 @@ struct latest_stream_id = -1; released = Atomic.make false; cross_stream_candidates = Hashtbl.create (module Tnode); - owner_streams = Hashtbl.create (module Tnode); - writer_streams = Hashtbl.create (module Tnode); - reader_streams = Hashtbl.create (module Tnode); + owner_stream = Hashtbl.create (module Tnode); + shared_writer_streams = Hashtbl.create (module Tnode); + host_reading_streams = Hashtbl.create (module Tnode); + host_writing_streams = Hashtbl.create (module Tnode); } let make_stream device runner ~stream_id = @@ -120,6 +121,7 @@ struct stream_id; allocated_buffer = None; updating_for = Hashtbl.create (module Tnode); + reader_streams = Hashtbl.create (module Tnode); } let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"] diff --git a/arrayjit/lib/backend_intf.ml b/arrayjit/lib/backend_intf.ml index 5353d32..9a37a90 100644 --- a/arrayjit/lib/backend_intf.ml +++ b/arrayjit/lib/backend_intf.ml @@ -91,17 +91,23 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device = { cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t; (** Freshly created arrays that might be shared across streams. The map can both grow and shrink. *) - owner_streams : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t; - (** The stream owning a given node. This map can only grow. - - Currently, if the memory mode of a node is inferred, only this stream will modify a - cross-stream shared array. But memory modes can also be set manually. *) - writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; - (** The streams that most recently have been scheduled to update (write to) the node, and the - associated update completion event. The completed events are removed opportunistically. *) - reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; - (** The streams that most recently have been reading from the node, and the associated use - completion events. The completed events are removed opportunistically. *) + owner_stream : ('buffer_ptr, 'dev, 'runner, 'event) stream Hashtbl.M(Tnode).t; + (** The stream owning a given node. This map can only grow. Currently, if the memory mode of a + node is inferred, only this stream will modify a cross-stream shared array. But memory + modes can also be set manually. *) + shared_writer_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams that most recently have been scheduled to update (write to) a + cross-stream-shared node, and the associated update completion event. The completed events + are removed opportunistically. *) + host_reading_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams that most recently have been reading from a node's on-host array. The + completed events are removed opportunistically. *) + host_writing_streams : + (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams that most recently have been writing to a node's on-host array. The completed + events are removed opportunistically. *) } [@@deriving sexp_of] @@ -116,7 +122,11 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = { stream_id : int; (** An ID unique within the device. *) mutable allocated_buffer : 'buffer_ptr buffer option; updating_for : 'event Hashtbl.M(Tnode).t; - (* The completion event for updating (writing to) a node via this stream, if any. *) + (* The completion event for updating (writing to) a node via this stream, if any. *) + reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t; + (** The streams, other than this stream, that most recently have been reading from a node in + this stream's context, and the associated use completion events. The completed events are + removed opportunistically. *) } [@@deriving sexp_of] @@ -214,8 +224,7 @@ module type Backend_device_common = sig (** Schedules waiting for the given event on the context's stream. NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it is typically - called internally when necessary. But there is one exception, see {!device_to_device} when - [into_merge_buffer=Streaming]. *) + called internally when necessary. *) val get_used_memory : device -> int (** Returns (an upper bound of) the memory used for arrays, in bytes. *) @@ -257,12 +266,15 @@ module type With_buffer_retrieval_and_syncing = sig - If the node is absent from the [src] context and either it is present in the [dst] context or [into_merge_buffer] is different from [No]: raises an error. - If the node is absent from [dst] and [into_merge_buffer=No]: returns false. - - Executes [will_wait_for dst (work_for src tn)]. - - If [into_merge_buffer=No]: schedules a copy of the tensor node from [src] to [dst]. + - Schedules waiting for writing into the tensor node on [src] to finish, if any. + - If [into_merge_buffer=No]: schedules a copy of the tensor node from [src] to [dst] and + updates the writer event for the node. - If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to the - given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source - node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying - from [src] to the merge buffer of [dst]'s stream. + given node. + - If [into_merge_buffer=Streaming], remembers the buffer pointer of the source node to use for + streaming. + - If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s + stream, and registers [dst.stream] with a reader event for the node. NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge buffer but before scheduling work on [src] that modifies [tn], execute diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index c699c85..27d5846 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -30,32 +30,42 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin |> List.iter ~f:(fun (work_stream, e) -> if not (equal_stream work_stream s) then Backend.will_wait_for ctx e) - let update_writer_event s tn = + let update_writer_event ?(from_host = false) s tn = let e = Backend.all_work s in - Hashtbl.update s.device.writer_streams tn ~f:(fun l -> (s, e) :: Option.value ~default:[] l); + if from_host then + Hashtbl.update s.device.host_writing_streams tn ~f:(fun l -> + (s, e) :: Option.value ~default:[] l); + (* To be on the safe side, record events for potentially cross-stream nodes. *) + if Tn.potentially_cross_stream tn then + Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l -> + (s, e) :: Option.value ~default:[] l); Hashtbl.update s.updating_for tn ~f:(fun _ -> e) - let add_reader s tn = + let add_reader s tn from = let e = Backend.all_work s in - Hashtbl.update s.device.reader_streams tn ~f:(fun l -> (s, e) :: Option.value ~default:[] l) + let f l = (s, e) :: Option.value ~default:[] l in + match from with + | `Host -> Hashtbl.update s.device.host_reading_streams tn ~f + | `Src src -> Hashtbl.update src.reader_streams tn ~f let%diagn2_l_sexp from_host (ctx : Backend.context) tn = match (tn, Map.find ctx.ctx_arrays tn) with | { Tn.array = (lazy (Some hosted)); _ }, Some dst -> - wait_for_all ctx ctx.stream.device.reader_streams tn; + wait_for_all ctx ctx.stream.reader_streams tn; [%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"]; Backend.from_host ~dst_ptr:dst ~dst:ctx hosted; - update_writer_event ctx.stream tn; + update_writer_event ~from_host:true ctx.stream tn; + add_reader ctx.stream tn @@ `Host; true | _ -> false let%diagn2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) = match (tn, Map.find ctx.ctx_arrays tn) with | { Tn.array = (lazy (Some hosted)); _ }, Some src -> - wait_for_all ctx ctx.stream.device.writer_streams tn; + if Tn.potentially_cross_stream tn then + wait_for_all ctx ctx.stream.device.shared_writer_streams tn; [%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"]; Backend.to_host ~src_ptr:src ~src:ctx hosted; - add_reader ctx.stream tn; true | _ -> false @@ -343,13 +353,13 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in Map.add_exn ctx_arrays ~key ~data) else if Tn.known_shared_cross_stream key then ( - if Hashtbl.mem device.owner_streams key then ( - if not (equal_stream stream (Hashtbl.find_exn device.owner_streams key)) then + if Hashtbl.mem device.owner_stream key then ( + if not (equal_stream stream (Hashtbl.find_exn device.owner_stream key)) then raise @@ Utils.User_error ("Cuda_backend.alloc_if_needed: node " ^ Tn.debug_name key ^ " assumed to be cross-stream-shared but then written to on multiple devices")) - else Hashtbl.add_exn device.owner_streams ~key ~data:stream; + else Hashtbl.add_exn device.owner_stream ~key ~data:stream; let data = Hashtbl.find_exn device.cross_stream_candidates key in Map.add_exn ctx_arrays ~key ~data) else ( diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml index 91afe01..5fb7e7c 100644 --- a/arrayjit/lib/tnode.ml +++ b/arrayjit/lib/tnode.ml @@ -247,6 +247,8 @@ let known_non_cross_stream tn = | Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) -> true | _ -> false +let potentially_cross_stream tn = not (known_not_materialized tn || known_non_cross_stream tn) + let mode_is_unspecified tn = match tn.memory_mode with | None | Some ((Never_virtual | Effectively_constant), _) -> true