diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index 7d8fc68a..16db6903 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -30,10 +30,6 @@ type device = { cross_stream_candidates : ctx_array Hashtbl.M(Tn).t; (** Freshly created arrays that might be shared across streams. The map can both grow and shrink. See the explanation on top of this file. *) - cross_stream_shared : Hash_set.M(Tn).t; - (** Tensor nodes known to be cross-stream shared. This set can only grow. *) - non_cross_stream : Hash_set.M(Tn).t; - (** Tensor nodes known to not be cross-stream shared. This set can only grow. *) owner_stream_subordinal : int Hashtbl.M(Tn).t; (** The streams owning the given nodes. This map can only grow. *) } @@ -77,7 +73,6 @@ let is_done event = Cu.Delimited_event.query event let will_wait_for context event = Cu.Delimited_event.wait context.stream.cu_stream event let sync event = Cu.Delimited_event.synchronize event let all_work stream = Cu.Delimited_event.record stream.cu_stream - let scheduled_merge_node stream = Option.map ~f:snd stream.merge_buffer let is_initialized, initialize = @@ -152,8 +147,6 @@ let get_device ~(ordinal : int) : device = copy_merge_buffer_capacity; released = Atomic.make false; cross_stream_candidates = (Hashtbl.create (module Tn) : ctx_array Hashtbl.M(Tn).t); - cross_stream_shared = Hash_set.create (module Tn); - non_cross_stream = Hash_set.create (module Tn); owner_stream_subordinal = Hashtbl.create (module Tn); } in @@ -275,9 +268,7 @@ let%diagn2_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : co ~src:s_arr.ptr ~src_ctx:src.ctx dst.stream.cu_stream in if - same_device - && (src.stream.subordinal = dst.stream.subordinal - || Hash_set.mem dst.stream.device.cross_stream_shared tn) + same_device && (src.stream.subordinal = dst.stream.subordinal || Tn.known_shared_cross_stream tn) then false else match Map.find src.ctx_arrays tn with @@ -559,10 +550,11 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx work; } ) -let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays = +let%track3_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays = if is_in_context node && not (Map.mem ctx_arrays key) then ( - [%log Tn.debug_name key, "read_only", (node.read_only : bool)]; - let default () = + [%log2 Tn.debug_name key, "read_only", (node.read_only : bool)]; + [%log3 (key : Tn.t)]; + let default () : ctx_array = set_ctx ctx; let ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key) in { ptr; tracking = None } @@ -570,13 +562,13 @@ let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays = let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in let device = stream.device in if node.read_only then - if Hash_set.mem device.non_cross_stream key then add_new () + if Tn.known_non_cross_stream key then add_new () else ( if Hashtbl.mem device.cross_stream_candidates key then - Hash_set.add device.cross_stream_shared key; + Tn.update_memory_sharing key Tn.Shared_cross_stream 40; let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in Map.add_exn ctx_arrays ~key ~data) - else if Hash_set.mem device.cross_stream_shared key then ( + else if Tn.known_shared_cross_stream key then ( if Hashtbl.mem device.owner_stream_subordinal key then if Hashtbl.find_exn device.owner_stream_subordinal key <> stream.subordinal then raise @@ -587,7 +579,7 @@ let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays = let data = Hashtbl.find_exn device.cross_stream_candidates key in Map.add_exn ctx_arrays ~key ~data) else ( - Hash_set.add device.non_cross_stream key; + Tn.update_memory_sharing key Tn.Per_stream 41; Hashtbl.remove device.cross_stream_candidates key; add_new ())) else ctx_arrays diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml index 12a70e99..97bc4f86 100644 --- a/arrayjit/lib/tnode.ml +++ b/arrayjit/lib/tnode.ml @@ -132,16 +132,13 @@ let log_debug_info ~from_log_level tn = | (lazy (Some nd)) -> Nd.log_debug_info ~from_log_level nd else [%log ""]]] -(** The one exception to "most local" is the sharing property: defaults to [Shared_cross_stream]. *) +(** The one exception to "most local" is that the sharing property is kept at [Unset]. *) let default_to_most_local tn provenance = match tn.memory_mode with | None | Some (Effectively_constant, _) -> tn.memory_mode <- Some (Virtual, provenance) | Some (Never_virtual, _) -> tn.memory_mode <- Some (Local, provenance) | Some (Device_only, _) -> tn.memory_mode <- Some (Local, provenance) - | Some (Materialized, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance) - | Some (On_device Unset, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance) - | Some (Hosted (Changed_on_devices Unset), _) -> - tn.memory_mode <- Some (Hosted (Changed_on_devices Shared_cross_stream), provenance) + | Some (Materialized, _) -> tn.memory_mode <- Some (On_device Unset, provenance) | Some ((Virtual | Local | On_device _ | Hosted _), _) -> () let is_virtual_force tn provenance = @@ -194,6 +191,18 @@ let known_not_param tn = true | _ -> false +let known_shared_cross_stream tn = + match tn.memory_mode with + | Some ((On_device Shared_cross_stream | Hosted (Changed_on_devices Shared_cross_stream)), _) -> + true + | _ -> false + +let known_non_cross_stream tn = + match tn.memory_mode with + | Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) -> + true + | _ -> false + let mode_is_unspecified tn = match tn.memory_mode with | None | Some ((Never_virtual | Effectively_constant), _) -> true @@ -246,7 +255,7 @@ let update_memory_sharing tn sharing provenance = [%string "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \ %{debug_name tn} -- change from non-shared to shared is currently not permitted"] - | Some ((On_device _ | Device_only | Materialized), _), Shared_cross_stream -> + | Some ((On_device _ | Device_only | Materialized), _), _ -> tn.memory_mode <- Some (On_device sharing, provenance) | Some (Hosted (Changed_on_devices Per_stream), prov2), Shared_cross_stream -> raise diff --git a/bin/hello_world.ml b/bin/hello_world.ml index d1c4ec47..32556541 100644 --- a/bin/hello_world.ml +++ b/bin/hello_world.ml @@ -11,7 +11,7 @@ module Rand = Arrayjit.Rand.Lib let hello1 () = Rand.init 0; let module Backend = (val Arrayjit.Backends.fresh_backend ()) in - Utils.set_log_level 2; + (* Utils.set_log_level 2; *) (* Utils.settings.output_debug_files_in_build_directory <- true; *) let stream = Backend.(new_stream @@ get_device ~ordinal:0) in let ctx = Backend.init stream in