diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index 35fcc6ee..7d8fc68a 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -1,22 +1,3 @@ -(** In the current design of the CUDA backend, unlike in the CPU backends, context arrays for - incomparable contexts do not need be disjoint, as long as they share a device. If a tensor node - is read-only for all contexts, its array will be shared even by incomparable contexts. The - particular design is as follows, within a single device: - - If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a - cross-stream sharing candidate. - - If a cross-stream sharing candidate is read-only for another context, whose parent does not - have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream - shared, and the same array is reused. - - If a tensor node is writable by a context, and it is not cross-stream shared, it is marked as - non-cross-stream, the array is removed from cross-stream sharing candidates if present. If it - is cross-stream shared, it is recorded as owned by the corresponding stream. It is an error if - the node was already owned by a different stream. - - If a tensor node is cross-stream shared, within-device copying is a NOOP as source and - destination pointers are in that case identical. - - FIXME(#286): this should be controllable via {!Tnode.memory_mode}. *) - open Base module Tn = Tnode module Lazy = Utils.Lazy diff --git a/arrayjit/lib/low_level.ml b/arrayjit/lib/low_level.ml index 821197af..9348c017 100644 --- a/arrayjit/lib/low_level.ml +++ b/arrayjit/lib/low_level.ml @@ -271,7 +271,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc = else Tn.update_memory_mode tn Materialized 35); if Hashtbl.exists traced.accesses ~f:is_recurrent then ( traced.read_before_write <- true; - if Tn.mode_is_unspecified tn then Tn.update_memory_mode tn (Hosted Changed_on_devices) 38 + if Tn.mode_is_unspecified tn then + Tn.update_memory_mode tn (Hosted (Changed_on_devices Unset)) 38 else Tn.update_memory_mode tn Materialized 36)) let%diagn_sexp check_and_store_virtual traced static_indices top_llc = diff --git a/arrayjit/lib/tnode.ml b/arrayjit/lib/tnode.ml index 7d7418d8..12a70e99 100644 --- a/arrayjit/lib/tnode.ml +++ b/arrayjit/lib/tnode.ml @@ -8,10 +8,30 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime [%%global_debug_log_level 9] [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"] +(** A possible algorithm for deciding sharing within a single device: + - If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a + cross-stream sharing candidate. + - If a cross-stream sharing candidate is read-only for another context, whose parent does not + have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream + shared, and the same array is reused. + - If a tensor node is writable by a context, and it is not cross-stream shared, it is marked as + non-cross-stream, the array is removed from cross-stream sharing candidates if present. If it + is cross-stream shared, it is recorded as owned by the corresponding stream. It is an error if + the node was already owned by a different stream. + + If a tensor node is shared cross-stream, within-device copying is a NOOP as source and + destination pointers are in that case identical. 
*)
+type sharing =
+  | Unset
+  | Per_stream  (** The tensor node has separate arrays for each stream. *)
+  | Shared_cross_stream  (** The tensor node has a single array per device. *)
+[@@deriving sexp, compare, equal]
+
 type memory_type =
   | Constant  (** The tensor node does not change after initialization. *)
   | Nonconstant  (** One of: [Changed_on_devices], [Volatile]. *)
-  | Changed_on_devices  (** The tensor node will only change on host via a [to_host] call. *)
+  | Changed_on_devices of sharing
+      (** The tensor node will only change on host via a [to_host] call. *)
   | Volatile
       (** The tensor node will only change on any device via a [from_host] call possibly followed
           by [device_to_device]. *)
@@ -25,7 +45,7 @@ type memory_mode =
       (** The full tensor node is cached for the duration of a computation but not persisted across
           calls to compiled functions. It is not available for merging across devices. *)
   | Device_only  (** One of: [Local], [On_device]. *)
-  | On_device
+  | On_device of sharing
       (** The tensor node is stored on the devices that compute with it and persisted across
           function calls. It is available for merging across devices (for devices that support
           merging / P2P), but not (directly) for visualization or storing to disk. *)
@@ -112,13 +132,17 @@ let log_debug_info ~from_log_level tn =
       | (lazy (Some nd)) -> Nd.log_debug_info ~from_log_level nd
     else [%log ""]]]
 
+(** The one exception to "most local" is the sharing property: it defaults to [Shared_cross_stream]. *)
 let default_to_most_local tn provenance =
   match tn.memory_mode with
   | None | Some (Effectively_constant, _) -> tn.memory_mode <- Some (Virtual, provenance)
   | Some (Never_virtual, _) -> tn.memory_mode <- Some (Local, provenance)
   | Some (Device_only, _) -> tn.memory_mode <- Some (Local, provenance)
-  | Some (Materialized, _) -> tn.memory_mode <- Some (On_device, provenance)
-  | Some ((Virtual | Local | On_device | Hosted _), _) -> ()
+  | Some (Materialized, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
+  | Some (On_device Unset, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
+  | Some (Hosted (Changed_on_devices Unset), _) ->
+      tn.memory_mode <- Some (Hosted (Changed_on_devices Shared_cross_stream), provenance)
+  | Some ((Virtual | Local | On_device _ | Hosted _), _) -> ()
 
 let is_virtual_force tn provenance =
   default_to_most_local tn provenance;
@@ -128,7 +152,7 @@ let is_hosted_force ?specifically tn provenance =
   default_to_most_local tn provenance;
   match (tn.memory_mode, specifically) with
   | None, _ -> assert false
-  | Some ((Virtual | Local | Device_only | On_device), _), _ -> false
+  | Some ((Virtual | Local | Device_only | On_device _), _), _ -> false
   | Some (Hosted _, _), None -> true
   | Some (Hosted memtyp, _), Some query -> equal_memory_type memtyp query
   | Some ((Never_virtual | Materialized | Effectively_constant), _), _ -> assert false
@@ -138,7 +162,7 @@ let is_materialized_force tn provenance =
   match tn.memory_mode with
   | None -> assert false
   | Some ((Virtual | Local), _) -> false
-  | Some ((On_device | Hosted _ | Materialized), _) -> true
+  | Some ((On_device _ | Hosted _ | Materialized), _) -> true
   | Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false
 
 let is_in_context_force tn provenance =
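NOTE (illustration only, not part of the patch): the [Unset] placeholder above is a deferred
decision. A minimal sketch of the intended resolution, assuming tnode.ml exposes
[update_memory_mode] and [default_to_most_local] as shown in this hunk, the node starts with no
recorded mode, and with made-up provenance numbers 91/92:

    (* Mark the node device-resident without committing to a sharing mode. *)
    let resolve_sharing_example (tn : Tnode.t) =
      Tnode.update_memory_mode tn Tnode.(On_device Unset) 91;
      (* "Most local" defaulting has one exception: sharing defaults to
         [Shared_cross_stream] rather than [Per_stream]. *)
      Tnode.default_to_most_local tn 92
    (* Afterwards: tn.memory_mode = Some (On_device Shared_cross_stream, 92). *)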
@@ -164,7 +188,7 @@ let known_non_virtual tn =
 let known_not_param tn =
   match tn.memory_mode with
   | Some
-      ( ( Virtual | Local | Effectively_constant | Device_only | On_device
+      ( ( Virtual | Local | Effectively_constant | Device_only | On_device _
         | Hosted (Constant | Volatile) ),
         _ ) ->
       true
@@ -190,9 +214,9 @@ let update_memory_mode tn mode provenance =
   | Some (Effectively_constant, _), (Never_virtual | Materialized | Hosted Constant) ->
       tn.memory_mode <- Some (Hosted Constant, provenance)
   | Some (Effectively_constant, _), Virtual -> tn.memory_mode <- Some (mode, provenance)
-  | Some (Hosted Nonconstant, _), Hosted (Changed_on_devices | Volatile) ->
+  | Some (Hosted Nonconstant, _), Hosted (Changed_on_devices _ | Volatile) ->
       tn.memory_mode <- Some (mode, provenance)
-  | Some (Hosted (Changed_on_devices | Volatile), _), Hosted Nonconstant -> ()
+  | Some (Hosted (Changed_on_devices _ | Volatile), _), Hosted Nonconstant -> ()
   | Some (Never_virtual, _), mode -> tn.memory_mode <- Some (mode, provenance)
   | Some (Virtual, prov2), Never_virtual ->
       raise
@@ -201,18 +225,44 @@
           "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \
            tn} is already virtual"]
   | Some (_, _), Never_virtual -> ()
-  | Some (Device_only, _), (Local | On_device) -> tn.memory_mode <- Some (mode, provenance)
-  | Some (Materialized, _), (On_device | Hosted _) -> tn.memory_mode <- Some (mode, provenance)
-  | Some ((Local | On_device), _), Device_only -> ()
-  | Some ((On_device | Hosted _), _), Materialized -> ()
+  | Some (Device_only, _), (Local | On_device _) -> tn.memory_mode <- Some (mode, provenance)
+  | Some (Materialized, _), (On_device _ | Hosted _) -> tn.memory_mode <- Some (mode, provenance)
+  | Some ((Local | On_device _), _), Device_only -> ()
+  | Some ((On_device _ | Hosted _), _), Materialized -> ()
   | Some (Device_only, _), Materialized | Some (Materialized, _), Device_only ->
-      tn.memory_mode <- Some (On_device, provenance)
+      tn.memory_mode <- Some (On_device Unset, provenance)
   | Some (_, prov2), _ ->
       invalid_arg
         [%string
          "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
           %{debug_name tn}"]
 
+let update_memory_sharing tn sharing provenance =
+  match (tn.memory_mode, sharing) with
+  | None, _ -> tn.memory_mode <- Some (On_device sharing, provenance)
+  | Some (On_device Per_stream, prov2), Shared_cross_stream ->
+      raise
+      @@ Utils.User_error
+           [%string
+             "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
+              %{debug_name tn} -- change from non-shared to shared is currently not permitted"]
+  | Some ((On_device _ | Device_only | Materialized), _), Shared_cross_stream ->
+      tn.memory_mode <- Some (On_device sharing, provenance)
+  | Some (Hosted (Changed_on_devices Per_stream), prov2), Shared_cross_stream ->
+      raise
+      @@ Utils.User_error
+           [%string
+             "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
+              %{debug_name tn} (hosted) -- change from non-shared to shared is currently not \
+              permitted"]
+  | Some (Hosted (Changed_on_devices _), _), _ ->
+      tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance)
+  | Some (_, prov2), _ ->
+      invalid_arg
+        [%string
+          "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
+           %{debug_name tn} -- not materialized on the devices"]
+
 let update_prec ?only_if tn prec =
   let do_update =
     match only_if with
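NOTE (illustration only, not part of the patch): a sketch of the one-way promotion rule that
[update_memory_sharing] enforces above, with made-up provenance numbers. A node with no recorded
mode may be recorded as shared, but a [Per_stream] node may not be promoted:

    (* A fresh node becomes [On_device Shared_cross_stream]. *)
    let share (tn : Tnode.t) = Tnode.update_memory_sharing tn Tnode.Shared_cross_stream 93

    (* Once [Per_stream], promotion raises [Utils.User_error]
       ("change from non-shared to shared is currently not permitted"). *)
    let illegal_promotion (tn : Tnode.t) =
      Tnode.update_memory_sharing tn Tnode.Per_stream 94;
      Tnode.update_memory_sharing tn Tnode.Shared_cross_stream 95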
diff --git a/bin/micrograd_demo.ml b/bin/micrograd_demo.ml
index 00b4b032..b2984465 100644
--- a/bin/micrograd_demo.ml
+++ b/bin/micrograd_demo.ml
@@ -122,7 +122,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let classes = Tensor.value_1d_points ~xdim:0 moons_classes in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
   let%op mlp_result = mlp "point" in
-  Train.set_on_host Changed_on_devices mlp_result.value;
+  Train.set_on_host mlp_result.value;
   (* By using jitted.context here, we don't need to copy the parameters back to the host. *)
   let result_routine =
     Train.to_routine
diff --git a/bin/moons_demo.ml b/bin/moons_demo.ml
index 37c7bcf9..9e6509a0 100644
--- a/bin/moons_demo.ml
+++ b/bin/moons_demo.ml
@@ -103,7 +103,7 @@ let demo () =
   in
 
   let%op mlp_result = mlp "point" in
-  Train.set_on_host Changed_on_devices mlp_result.value;
+  Train.set_on_host mlp_result.value;
   let result_routine =
     Train.to_routine
       (module Backend)
diff --git a/lib/train.ml b/lib/train.ml
index 38cb5d7a..c2e947e7 100644
--- a/lib/train.ml
+++ b/lib/train.ml
@@ -105,12 +105,15 @@ let restore_params t =
     let f arr = Npy.Npz.restore in_file name arr in
     Nd.map { f } @@ Option.value_exn ~here:[%here] @@ Lazy.force v.array)
 
-let set_on_host memtype (a : Tn.t) = Tn.update_memory_mode a (Hosted memtype) 27
+let set_on_host ?(from_device = true) (a : Tn.t) =
+  let memtype = if from_device then Tn.(Changed_on_devices Unset) else Volatile in
+  Tn.update_memory_mode a (Hosted memtype) 27
+
 let set_materialized (a : Tn.t) = Tn.update_memory_mode a Materialized 28
 
 let set_hosted (a : Tn.t) =
   if Tn.known_constant a then Tn.update_memory_mode a (Hosted Constant) 41
-  else Tn.update_memory_mode a (Hosted Changed_on_devices) 41
+  else Tn.update_memory_mode a (Hosted (Changed_on_devices Unset)) 41
 
 (** Sets the tensor's value as "fully on host", returns the tensor's forward code with a
     label-derived comment. *)
@@ -510,7 +513,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     else Tensor.consume_forward_code model_result
   in
   if not disable_rootness_check then Tensor.remove_bprop_root model_result;
-  set_on_host Changed_on_devices model_result.Tensor.value;
+  set_on_host model_result.Tensor.value;
   (* By using sgd_update.context, maybe we don't need to copy the parameters back to the host. *)
   let routine =
     Backend.(
diff --git a/test/micrograd_demo.ml b/test/micrograd_demo.ml
index e27e74bb..24a16fad 100644
--- a/test/micrograd_demo.ml
+++ b/test/micrograd_demo.ml
@@ -161,7 +161,7 @@ let%expect_test "Micrograd half-moons example" =
   let classes = Tensor.value_1d_points ~xdim:0 moons_classes in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
   let%op mlp_result = mlp "point" in
-  Train.set_on_host Changed_on_devices mlp_result.value;
+  Train.set_on_host mlp_result.value;
   let result_routine =
     Train.to_routine
       (module Backend)
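NOTE (illustration only, not part of the patch): with the [memtype] argument gone, the call sites
above shrink to [Train.set_on_host mlp_result.value]. A hypothetical use of the new optional flag,
where [result] and [streamed] stand for arbitrary tensors:

    (* Default: marks the value [Hosted (Changed_on_devices Unset)];
       the sharing mode is decided later. *)
    let host_result (result : Tensor.t) = Train.set_on_host result.value

    (* Opting out of [Changed_on_devices] for a node the host overwrites:
       [~from_device:false] maps to [Hosted Volatile]. *)
    let host_streamed (streamed : Tensor.t) = Train.set_on_host ~from_device:false streamed.value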