Commit
In progress toward #286: type Tnode.sharing
lukstafi committed Oct 13, 2024
1 parent cdc7196 commit be9a299
Showing 7 changed files with 75 additions and 40 deletions.
19 changes: 0 additions & 19 deletions arrayjit/lib/cuda_backend.cudajit.ml
@@ -1,22 +1,3 @@
(** In the current design of the CUDA backend, unlike in the CPU backends, context arrays for
incomparable contexts do not need to be disjoint, as long as they share a device. If a tensor node
is read-only for all contexts, its array will be shared even by incomparable contexts. The
particular design is as follows, within a single device:
- If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a
cross-stream sharing candidate.
- If a cross-stream sharing candidate is read-only for another context, whose parent does not
have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream
shared, and the same array is reused.
- If a tensor node is writable by a context, and it is not cross-stream shared, it is marked as
non-cross-stream, and the array is removed from cross-stream sharing candidates if present. If it
is cross-stream shared, it is recorded as owned by the corresponding stream. It is an error if
the node was already owned by a different stream.
If a tensor node is cross-stream shared, within-device copying is a NOOP as source and
destination pointers are in that case identical.
FIXME(#286): this should be controllable via {!Tnode.memory_mode}. *)

open Base
module Tn = Tnode
module Lazy = Utils.Lazy
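For orientation, the decision procedure this doc comment describes (now relocated to arrayjit/lib/tnode.ml, see below) can be read as a small per-node state machine. The following is a minimal, self-contained sketch for illustration only: the status constructors, the integer stream ids, and the on_read_only / on_writable helpers are assumptions of the sketch, not the backend's actual data structures.

(* Illustrative model of the per-device sharing decision described above. In the
   real design the "different stream?" test is whether the parent context already
   holds the array; here an integer stream id stands in for that. *)
type status =
  | Candidate of int (* read-only so far; remembers the stream that first read it *)
  | Shared (* read-only from two different streams: one array per device *)
  | Owned of int (* cross-stream shared, but now written by exactly one stream *)
  | Non_cross_stream (* written before ever being shared: separate arrays per stream *)

(* A read-only use from [stream]: an unrecorded node becomes a sharing candidate;
   a candidate seen from a different stream becomes cross-stream shared, so the
   same array is reused. *)
let on_read_only ~stream = function
  | None -> Some (Candidate stream)
  | Some (Candidate s) when s <> stream -> Some Shared
  | other -> other

(* A writable use from [stream]: a node that was never shared becomes
   non-cross-stream; a cross-stream shared node becomes owned by [stream];
   a node already owned by a different stream is an error. *)
let on_writable ~stream = function
  | None | Some (Candidate _) | Some Non_cross_stream -> Some Non_cross_stream
  | Some Shared -> Some (Owned stream)
  | Some (Owned s) when s = stream -> Some (Owned s)
  | Some (Owned _) -> invalid_arg "already owned by a different stream"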
3 changes: 2 additions & 1 deletion arrayjit/lib/low_level.ml
@@ -271,7 +271,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
else Tn.update_memory_mode tn Materialized 35);
if Hashtbl.exists traced.accesses ~f:is_recurrent then (
traced.read_before_write <- true;
if Tn.mode_is_unspecified tn then Tn.update_memory_mode tn (Hosted Changed_on_devices) 38
if Tn.mode_is_unspecified tn then
Tn.update_memory_mode tn (Hosted (Changed_on_devices Unset)) 38
else Tn.update_memory_mode tn Materialized 36))

let%diagn_sexp check_and_store_virtual traced static_indices top_llc =
78 changes: 64 additions & 14 deletions arrayjit/lib/tnode.ml
@@ -8,10 +8,30 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

(** A possible algorithm for deciding sharing within a single device:
- If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a
cross-stream sharing candidate.
- If a cross-stream sharing candidate is read-only for another context, whose parent does not
have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream
shared, and the same array is reused.
- If a tensor node is writable by a context, and it is not cross-stream shared, it is marked as
non-cross-stream, and the array is removed from cross-stream sharing candidates if present. If it
is cross-stream shared, it is recorded as owned by the corresponding stream. It is an error if
the node was already owned by a different stream.
If a tensor node is shared cross-stream, within-device copying is a NOOP as source and
destination pointers are in that case identical. *)
type sharing =
| Unset
| Per_stream (** The tensor node has separate arrays for each stream. *)
| Shared_cross_stream (** The tensor node has a single array per device. *)
[@@deriving sexp, compare, equal]

type memory_type =
| Constant (** The tensor node does not change after initialization. *)
| Nonconstant (** One of: [Changed_on_devices], [Volatile]. *)
| Changed_on_devices (** The tensor node will only change on host via a [to_host] call. *)
| Changed_on_devices of sharing
(** The tensor node will only change on host via a [to_host] call. *)
| Volatile
(** The tensor node will only change on any device via a [from_host] call possibly followed by
[device_to_device]. *)
@@ -25,7 +45,7 @@ type memory_mode =
(** The full tensor node is cached for the duration of a computation but not persisted across
calls to compiled functions. It is not available for merging across devices. *)
| Device_only (** One of: [Local], [On_device]. *)
| On_device
| On_device of sharing
(** The tensor node is stored on the devices that compute with it and persisted across
function calls. It is available for merging across devices (for devices that support
merging / P2P), but not (directly) for visualization or storing to disk. *)
@@ -112,13 +132,17 @@ let log_debug_info ~from_log_level tn =
| (lazy (Some nd)) -> Nd.log_debug_info ~from_log_level nd
else [%log "<not-in-yet>"]]]

(** The one exception to "most local" is the sharing property: it defaults to [Shared_cross_stream]. *)
let default_to_most_local tn provenance =
match tn.memory_mode with
| None | Some (Effectively_constant, _) -> tn.memory_mode <- Some (Virtual, provenance)
| Some (Never_virtual, _) -> tn.memory_mode <- Some (Local, provenance)
| Some (Device_only, _) -> tn.memory_mode <- Some (Local, provenance)
| Some (Materialized, _) -> tn.memory_mode <- Some (On_device, provenance)
| Some ((Virtual | Local | On_device | Hosted _), _) -> ()
| Some (Materialized, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
| Some (On_device Unset, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
| Some (Hosted (Changed_on_devices Unset), _) ->
tn.memory_mode <- Some (Hosted (Changed_on_devices Shared_cross_stream), provenance)
| Some ((Virtual | Local | On_device _ | Hosted _), _) -> ()

let is_virtual_force tn provenance =
default_to_most_local tn provenance;
@@ -128,7 +152,7 @@ let is_hosted_force ?specifically tn provenance =
default_to_most_local tn provenance;
match (tn.memory_mode, specifically) with
| None, _ -> assert false
| Some ((Virtual | Local | Device_only | On_device), _), _ -> false
| Some ((Virtual | Local | Device_only | On_device _), _), _ -> false
| Some (Hosted _, _), None -> true
| Some (Hosted memtyp, _), Some query -> equal_memory_type memtyp query
| Some ((Never_virtual | Materialized | Effectively_constant), _), _ -> assert false
@@ -138,7 +162,7 @@ let is_materialized_force tn provenance =
match tn.memory_mode with
| None -> assert false
| Some ((Virtual | Local), _) -> false
| Some ((On_device | Hosted _ | Materialized), _) -> true
| Some ((On_device _ | Hosted _ | Materialized), _) -> true
| Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false

let is_in_context_force tn provenance =
@@ -164,7 +188,7 @@ let known_non_virtual tn =
let known_not_param tn =
match tn.memory_mode with
| Some
( ( Virtual | Local | Effectively_constant | Device_only | On_device
( ( Virtual | Local | Effectively_constant | Device_only | On_device _
| Hosted (Constant | Volatile) ),
_ ) ->
true
@@ -190,9 +214,9 @@ let update_memory_mode tn mode provenance =
| Some (Effectively_constant, _), (Never_virtual | Materialized | Hosted Constant) ->
tn.memory_mode <- Some (Hosted Constant, provenance)
| Some (Effectively_constant, _), Virtual -> tn.memory_mode <- Some (mode, provenance)
| Some (Hosted Nonconstant, _), Hosted (Changed_on_devices | Volatile) ->
| Some (Hosted Nonconstant, _), Hosted (Changed_on_devices _ | Volatile) ->
tn.memory_mode <- Some (mode, provenance)
| Some (Hosted (Changed_on_devices | Volatile), _), Hosted Nonconstant -> ()
| Some (Hosted (Changed_on_devices _ | Volatile), _), Hosted Nonconstant -> ()
| Some (Never_virtual, _), mode -> tn.memory_mode <- Some (mode, provenance)
| Some (Virtual, prov2), Never_virtual ->
raise
@@ -201,18 +225,44 @@
"Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \
tn} is already virtual"]
| Some (_, _), Never_virtual -> ()
| Some (Device_only, _), (Local | On_device) -> tn.memory_mode <- Some (mode, provenance)
| Some (Materialized, _), (On_device | Hosted _) -> tn.memory_mode <- Some (mode, provenance)
| Some ((Local | On_device), _), Device_only -> ()
| Some ((On_device | Hosted _), _), Materialized -> ()
| Some (Device_only, _), (Local | On_device _) -> tn.memory_mode <- Some (mode, provenance)
| Some (Materialized, _), (On_device _ | Hosted _) -> tn.memory_mode <- Some (mode, provenance)
| Some ((Local | On_device _), _), Device_only -> ()
| Some ((On_device _ | Hosted _), _), Materialized -> ()
| Some (Device_only, _), Materialized | Some (Materialized, _), Device_only ->
tn.memory_mode <- Some (On_device, provenance)
tn.memory_mode <- Some (On_device Unset, provenance)
| Some (_, prov2), _ ->
invalid_arg
[%string
"Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
%{debug_name tn}"]

let update_memory_sharing tn sharing provenance =
match (tn.memory_mode, sharing) with
| None, _ -> tn.memory_mode <- Some (On_device sharing, provenance)
| Some (On_device Per_stream, prov2), Shared_cross_stream ->
raise
@@ Utils.User_error
[%string
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
%{debug_name tn} -- change from non-shared to shared is currently not permitted"]
| Some ((On_device _ | Device_only | Materialized), _), Shared_cross_stream ->
tn.memory_mode <- Some (On_device sharing, provenance)
| Some (Hosted (Changed_on_devices Per_stream), prov2), Shared_cross_stream ->
raise
@@ Utils.User_error
[%string
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
%{debug_name tn} (hosted) -- change from non-shared to shared is currently not \
permitted"]
| Some (Hosted (Changed_on_devices _), _), _ ->
tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance)
| Some (_, prov2), _ ->
invalid_arg
[%string
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
%{debug_name tn} -- not materialized on the devices"]

let update_prec ?only_if tn prec =
let do_update =
match only_if with
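Taken together, the tnode.ml changes above add a sharing axis to the memory mode. Below is a hedged usage sketch based only on the signatures visible in this diff: the [Tn] alias, the exposure of the constructors in the interface, and the assumption that the nodes' memory modes are still undecided are all assumptions of the sketch, and the provenance integers (100-102) are arbitrary labels in the style of this commit's own calls.

(* Hedged sketch: exercises transitions handled by the new update_memory_sharing. *)
let choose_shared (tn : Tn.t) =
  (* Device-resident, with the sharing decision left open. *)
  Tn.update_memory_mode tn (Tn.On_device Tn.Unset) 100;
  (* Had sharing stayed [Unset], default_to_most_local would resolve it to
     [Shared_cross_stream] -- the one exception to "most local". Here we make the
     same choice explicitly. *)
  Tn.update_memory_sharing tn Tn.Shared_cross_stream 101

(* The reverse direction is the restricted one: for a node whose sharing is
   already [Per_stream], promoting it to [Shared_cross_stream] raises
   Utils.User_error. *)
let can_promote (tn : Tn.t) =
  try
    Tn.update_memory_sharing tn Tn.Shared_cross_stream 102;
    true
  with Utils.User_error _ -> false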
2 changes: 1 addition & 1 deletion bin/micrograd_demo.ml
@@ -122,7 +122,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
let classes = Tensor.value_1d_points ~xdim:0 moons_classes in
let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
let%op mlp_result = mlp "point" in
Train.set_on_host Changed_on_devices mlp_result.value;
Train.set_on_host mlp_result.value;
(* By using jitted.context here, we don't need to copy the parameters back to the host. *)
let result_routine =
Train.to_routine
2 changes: 1 addition & 1 deletion bin/moons_demo.ml
@@ -103,7 +103,7 @@ let demo () =
in

let%op mlp_result = mlp "point" in
Train.set_on_host Changed_on_devices mlp_result.value;
Train.set_on_host mlp_result.value;
let result_routine =
Train.to_routine
(module Backend)
9 changes: 6 additions & 3 deletions lib/train.ml
@@ -105,12 +105,15 @@ let restore_params t =
let f arr = Npy.Npz.restore in_file name arr in
Nd.map { f } @@ Option.value_exn ~here:[%here] @@ Lazy.force v.array)

let set_on_host memtype (a : Tn.t) = Tn.update_memory_mode a (Hosted memtype) 27
let set_on_host ?(from_device = true) (a : Tn.t) =
let memtype = if from_device then Tn.(Changed_on_devices Unset) else Volatile in
Tn.update_memory_mode a (Hosted memtype) 27

let set_materialized (a : Tn.t) = Tn.update_memory_mode a Materialized 28

let set_hosted (a : Tn.t) =
if Tn.known_constant a then Tn.update_memory_mode a (Hosted Constant) 41
else Tn.update_memory_mode a (Hosted Changed_on_devices) 41
else Tn.update_memory_mode a (Hosted (Changed_on_devices Unset)) 41

(** Sets the tensor's value as "fully on host", returns the tensor's forward code with a
label-derived comment. *)
@@ -510,7 +513,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
else Tensor.consume_forward_code model_result
in
if not disable_rootness_check then Tensor.remove_bprop_root model_result;
set_on_host Changed_on_devices model_result.Tensor.value;
set_on_host model_result.Tensor.value;
(* By using sgd_update.context, maybe we don't need to copy the parameters back to the host. *)
let routine =
Backend.(
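The call-site updates in bin/micrograd_demo.ml, bin/moons_demo.ml, and test/micrograd_demo.ml follow from this change to set_on_host: the memory type is no longer passed explicitly and defaults to Hosted (Changed_on_devices Unset). A hedged migration sketch follows, where the host_result wrapper and its volatile flag are illustrative and not part of the commit.

(* Hedged sketch of the two ways to call Train.set_on_host after this commit;
   [value] stands for any tensor-node value, e.g. mlp_result.value in the demos. *)
let host_result ?(volatile = false) (value : Tn.t) =
  if volatile then
    (* Selects Hosted Volatile instead of the default. *)
    Train.set_on_host ~from_device:false value
  else
    (* Default: Hosted (Changed_on_devices Unset); the explicit Changed_on_devices
       argument from before this commit is no longer passed. *)
    Train.set_on_host value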
2 changes: 1 addition & 1 deletion test/micrograd_demo.ml
@@ -161,7 +161,7 @@ let%expect_test "Micrograd half-moons example" =
let classes = Tensor.value_1d_points ~xdim:0 moons_classes in
let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
let%op mlp_result = mlp "point" in
Train.set_on_host Changed_on_devices mlp_result.value;
Train.set_on_host mlp_result.value;
let result_routine =
Train.to_routine
(module Backend)
