diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml index 97bc4f86..09d0c85b 100644 --- a/arrayjit/lib/tnode.ml +++ b/arrayjit/lib/tnode.ml @@ -193,14 +193,16 @@ let known_not_param tn = let known_shared_cross_stream tn = match tn.memory_mode with - | Some ((On_device Shared_cross_stream | Hosted (Changed_on_devices Shared_cross_stream)), _) -> + | Some + ( ( On_device Shared_cross_stream + | Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream) ), + _ ) -> true | _ -> false let known_non_cross_stream tn = match tn.memory_mode with - | Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) -> - true + | Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) -> true | _ -> false let mode_is_unspecified tn = @@ -246,6 +248,9 @@ let update_memory_mode tn mode provenance = "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \ %{debug_name tn}"] +(** [update_memory_sharing tn sharing provenance] preserves the memory mode of [tn] while updating + the cross-stream sharing property, except that [Hosted Nonconstant] is further specialized to + [Hosted (Changed_on_devices sharing)]. *) let update_memory_sharing tn sharing provenance = match (tn.memory_mode, sharing) with | None, _ -> tn.memory_mode <- Some (On_device sharing, provenance) @@ -264,13 +269,27 @@ let update_memory_sharing tn sharing provenance = "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \ %{debug_name tn} (hosted) -- change from non-shared to shared is currently not \ permitted"] - | Some (Hosted (Changed_on_devices _), _), _ -> + | Some (Hosted (Constant | Volatile), prov2), Per_stream -> + raise + @@ Utils.User_error + [%string + "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \ + %{debug_name tn} (hosted) -- currently hosted nodes not changed on devices must be \ + shared cross-stream"] + | Some (Hosted (Constant | Volatile), _), Shared_cross_stream -> () + | Some (Hosted (Nonconstant | Changed_on_devices _), _), _ -> tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance) - | Some (_, prov2), _ -> + | Some (_, prov2), Unset -> + invalid_arg + [%string + "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \ + tn} -- currently unsetting of sharing not allowed"] + | Some (mem_mode, prov2), _ -> invalid_arg [%string "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \ - %{debug_name tn} -- not materialized on the devices"] + %{debug_name tn} -- not materialized on the devices: %{Sexp.to_string_hum @@ \ + sexp_of_memory_mode mem_mode}"] let update_prec ?only_if tn prec = let do_update = diff --git a/bin/moons_benchmark.ml b/bin/moons_benchmark.ml index 4b32342b..6a43ea19 100644 --- a/bin/moons_benchmark.ml +++ b/bin/moons_benchmark.ml @@ -217,8 +217,8 @@ let _mem_benchmarks = ~f:(fun batch_size -> List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff -> List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed -> - List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ] ~f:(fun backend_name -> - List.concat_map [ CDSL.double; CDSL.single (* ; CDSL.half *) ] + List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name -> + List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ] ~f:(fun value_prec -> [ classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams