Skip to content

Commit

Permalink
Fixes #286: use Tnode.sharing in the cuda backend
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 13, 2024
1 parent be9a299 commit bd0dc98
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 24 deletions.
26 changes: 9 additions & 17 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ type device = {
cross_stream_candidates : ctx_array Hashtbl.M(Tn).t;
(** Freshly created arrays that might be shared across streams. The map can both grow and
shrink. See the explanation on top of this file. *)
cross_stream_shared : Hash_set.M(Tn).t;
(** Tensor nodes known to be cross-stream shared. This set can only grow. *)
non_cross_stream : Hash_set.M(Tn).t;
(** Tensor nodes known to not be cross-stream shared. This set can only grow. *)
owner_stream_subordinal : int Hashtbl.M(Tn).t;
(** The streams owning the given nodes. This map can only grow. *)
}
Expand Down Expand Up @@ -77,7 +73,6 @@ let is_done event = Cu.Delimited_event.query event
let will_wait_for context event = Cu.Delimited_event.wait context.stream.cu_stream event
let sync event = Cu.Delimited_event.synchronize event
let all_work stream = Cu.Delimited_event.record stream.cu_stream

let scheduled_merge_node stream = Option.map ~f:snd stream.merge_buffer

let is_initialized, initialize =
Expand Down Expand Up @@ -152,8 +147,6 @@ let get_device ~(ordinal : int) : device =
copy_merge_buffer_capacity;
released = Atomic.make false;
cross_stream_candidates = (Hashtbl.create (module Tn) : ctx_array Hashtbl.M(Tn).t);
cross_stream_shared = Hash_set.create (module Tn);
non_cross_stream = Hash_set.create (module Tn);
owner_stream_subordinal = Hashtbl.create (module Tn);
}
in
Expand Down Expand Up @@ -275,9 +268,7 @@ let%diagn2_l_sexp rec device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : co
~src:s_arr.ptr ~src_ctx:src.ctx dst.stream.cu_stream
in
if
same_device
&& (src.stream.subordinal = dst.stream.subordinal
|| Hash_set.mem dst.stream.device.cross_stream_shared tn)
same_device && (src.stream.subordinal = dst.stream.subordinal || Tn.known_shared_cross_stream tn)
then false
else
match Map.find src.ctx_arrays tn with
Expand Down Expand Up @@ -559,24 +550,25 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
work;
} )

let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
let%track3_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
if is_in_context node && not (Map.mem ctx_arrays key) then (
[%log Tn.debug_name key, "read_only", (node.read_only : bool)];
let default () =
[%log2 Tn.debug_name key, "read_only", (node.read_only : bool)];
[%log3 (key : Tn.t)];
let default () : ctx_array =
set_ctx ctx;
let ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes:(Tn.size_in_bytes key) in
{ ptr; tracking = None }
in
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
let device = stream.device in
if node.read_only then
if Hash_set.mem device.non_cross_stream key then add_new ()
if Tn.known_non_cross_stream key then add_new ()
else (
if Hashtbl.mem device.cross_stream_candidates key then
Hash_set.add device.cross_stream_shared key;
Tn.update_memory_sharing key Tn.Shared_cross_stream 40;
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
Map.add_exn ctx_arrays ~key ~data)
else if Hash_set.mem device.cross_stream_shared key then (
else if Tn.known_shared_cross_stream key then (
if Hashtbl.mem device.owner_stream_subordinal key then
if Hashtbl.find_exn device.owner_stream_subordinal key <> stream.subordinal then
raise
Expand All @@ -587,7 +579,7 @@ let%diagn2_sexp alloc_if_needed ctx stream ~key ~data:node ctx_arrays =
let data = Hashtbl.find_exn device.cross_stream_candidates key in
Map.add_exn ctx_arrays ~key ~data)
else (
Hash_set.add device.non_cross_stream key;
Tn.update_memory_sharing key Tn.Per_stream 41;
Hashtbl.remove device.cross_stream_candidates key;
add_new ()))
else ctx_arrays
Expand Down
21 changes: 15 additions & 6 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,13 @@ let log_debug_info ~from_log_level tn =
| (lazy (Some nd)) -> Nd.log_debug_info ~from_log_level nd
else [%log "<not-in-yet>"]]]

(** The one exception to "most local" is the sharing property: defaults to [Shared_cross_stream]. *)
(** The one exception to "most local" is that the sharing property is kept at [Unset]. *)
let default_to_most_local tn provenance =
match tn.memory_mode with
| None | Some (Effectively_constant, _) -> tn.memory_mode <- Some (Virtual, provenance)
| Some (Never_virtual, _) -> tn.memory_mode <- Some (Local, provenance)
| Some (Device_only, _) -> tn.memory_mode <- Some (Local, provenance)
| Some (Materialized, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
| Some (On_device Unset, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
| Some (Hosted (Changed_on_devices Unset), _) ->
tn.memory_mode <- Some (Hosted (Changed_on_devices Shared_cross_stream), provenance)
| Some (Materialized, _) -> tn.memory_mode <- Some (On_device Unset, provenance)
| Some ((Virtual | Local | On_device _ | Hosted _), _) -> ()

let is_virtual_force tn provenance =
Expand Down Expand Up @@ -194,6 +191,18 @@ let known_not_param tn =
true
| _ -> false

(** [known_shared_cross_stream tn] is [true] exactly when the memory mode recorded on [tn] is
    [On_device Shared_cross_stream] or [Hosted (Changed_on_devices Shared_cross_stream)]. The
    provenance stored alongside the mode is ignored. An unset memory mode, or any other mode,
    gives [false]. *)
let known_shared_cross_stream tn =
  match tn.memory_mode with
  | Some (On_device Shared_cross_stream, _) -> true
  | Some (Hosted (Changed_on_devices Shared_cross_stream), _) -> true
  | None | Some _ -> false

(** [known_non_cross_stream tn] is [true] exactly when the memory mode recorded on [tn] is
    [On_device Per_stream] or [Hosted (Changed_on_devices Per_stream)]. The provenance stored
    alongside the mode is ignored. An unset memory mode, or any other mode, gives [false]. *)
let known_non_cross_stream tn =
  match tn.memory_mode with
  | Some (On_device Per_stream, _) -> true
  | Some (Hosted (Changed_on_devices Per_stream), _) -> true
  | None | Some _ -> false

let mode_is_unspecified tn =
match tn.memory_mode with
| None | Some ((Never_virtual | Effectively_constant), _) -> true
Expand Down Expand Up @@ -246,7 +255,7 @@ let update_memory_sharing tn sharing provenance =
[%string
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
%{debug_name tn} -- change from non-shared to shared is currently not permitted"]
| Some ((On_device _ | Device_only | Materialized), _), Shared_cross_stream ->
| Some ((On_device _ | Device_only | Materialized), _), _ ->
tn.memory_mode <- Some (On_device sharing, provenance)
| Some (Hosted (Changed_on_devices Per_stream), prov2), Shared_cross_stream ->
raise
Expand Down
2 changes: 1 addition & 1 deletion bin/hello_world.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module Rand = Arrayjit.Rand.Lib
let hello1 () =
Rand.init 0;
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
Utils.set_log_level 2;
(* Utils.set_log_level 2; *)
(* Utils.settings.output_debug_files_in_build_directory <- true; *)
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.init stream in
Expand Down

0 comments on commit bd0dc98

Please sign in to comment.