Skip to content

Commit

Permalink
Fixes sharing update: Hosted Nonconstant -> `Hosted (Changed_on_dev…
Browse files Browse the repository at this point in the history
…ices ...)` if sharing specified
  • Loading branch information
lukstafi committed Oct 15, 2024
1 parent e2780a6 commit 6e11ff9
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
31 changes: 25 additions & 6 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,16 @@ let known_not_param tn =

let known_shared_cross_stream tn =
match tn.memory_mode with
| Some ((On_device Shared_cross_stream | Hosted (Changed_on_devices Shared_cross_stream)), _) ->
| Some
( ( On_device Shared_cross_stream
| Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream) ),
_ ) ->
true
| _ -> false

let known_non_cross_stream tn =
match tn.memory_mode with
| Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) ->
true
| Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) -> true
| _ -> false

let mode_is_unspecified tn =
Expand Down Expand Up @@ -246,6 +248,9 @@ let update_memory_mode tn mode provenance =
"Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
%{debug_name tn}"]

(** [update_memory_sharing tn sharing provenance] preserves the memory mode of [tn] while updating
the cross-stream sharing property, except that [Hosted Nonconstant] is further specialized to
[Hosted (Changed_on_devices sharing)]. *)
let update_memory_sharing tn sharing provenance =
match (tn.memory_mode, sharing) with
| None, _ -> tn.memory_mode <- Some (On_device sharing, provenance)
Expand All @@ -264,13 +269,27 @@ let update_memory_sharing tn sharing provenance =
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
%{debug_name tn} (hosted) -- change from non-shared to shared is currently not \
permitted"]
| Some (Hosted (Changed_on_devices _), _), _ ->
| Some (Hosted (Constant | Volatile), prov2), Per_stream ->
raise
@@ Utils.User_error
[%string
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
%{debug_name tn} (hosted) -- currently hosted nodes not changed on devices must be \
shared cross-stream"]
| Some (Hosted (Constant | Volatile), _), Shared_cross_stream -> ()
| Some (Hosted (Nonconstant | Changed_on_devices _), _), _ ->
tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance)
| Some (_, prov2), _ ->
| Some (_, prov2), Unset ->
invalid_arg
[%string
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \
tn} -- currently unsetting of sharing not allowed"]
| Some (mem_mode, prov2), _ ->
invalid_arg
[%string
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
%{debug_name tn} -- not materialized on the devices"]
%{debug_name tn} -- not materialized on the devices: %{Sexp.to_string_hum @@ \
sexp_of_memory_mode mem_mode}"]

let update_prec ?only_if tn prec =
let do_update =
Expand Down
4 changes: 2 additions & 2 deletions bin/moons_benchmark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ let _mem_benchmarks =
~f:(fun batch_size ->
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ] ~f:(fun backend_name ->
List.concat_map [ CDSL.double; CDSL.single (* ; CDSL.half *) ]
List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
Expand Down

0 comments on commit 6e11ff9

Please sign in to comment.