Skip to content

Commit

Permalink
Remove verification of merge buffer nodes inside device_to_device
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 11, 2024
1 parent e866289 commit 2858d24
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments.
- Renaming: (virtual) `device` -> `stream`, `physical_device` -> `device`.
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; renamed `backend_utils.ml` to `c_syntax.ml`.
- Removed half-static verification of merge buffer nodes inside `device_to_device`.
- TODO: Moved the multicore backend from a `device = stream` model to a single device model.
- TODO: Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`.
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.
Expand Down
7 changes: 2 additions & 5 deletions arrayjit/lib/backend_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,8 @@ module type Backend = sig
given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source
node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying
from [src] to the merge buffer of [dst]'s stream.
- If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
[into_merge_buffer].
NOTE: If [into_merge_buffer:Streaming], after scheduling the work on [dst] using the merge
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
buffer but before scheduling work on [src] that modifies [tn], execute
[will_wait_for src (all_work (get_ctx_stream dst))]. *)

Expand Down Expand Up @@ -287,7 +284,7 @@ module type Lowered_backend = sig

val device_to_device :
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
(** If the tensor node is in both contexts, copies from [src] to [dst]. *)
(** See {!Backend.device_to_device}. *)

type buffer_ptr [@@deriving sexp_of]

Expand Down
27 changes: 0 additions & 27 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,6 @@ struct

let device_to_device tn ~into_merge_buffer ~dst ~src =
let dev = dst.stream in
if
(not (equal_merge_buffer_use into_merge_buffer No))
&& not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node)
then
raise
@@ Utils.User_error
("Multicore_backend.device_to_device: merge node mismatch, expected "
^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node)
^ ", actual " ^ Tnode.debug_name tn);
let schedule dst =
let work =
(* TODO: log the operation if [Utils.settings.with_log_level > 0]. *)
Expand Down Expand Up @@ -479,15 +470,6 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.

let device_to_device tn ~into_merge_buffer ~dst ~src =
let dev = dst.stream in
if
(not (equal_merge_buffer_use into_merge_buffer No))
&& not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node)
then
raise
@@ Utils.User_error
("Multicore_backend.device_to_device: merge node mismatch, expected "
^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node)
^ ", actual " ^ Tnode.debug_name tn);
(* TODO: log the operation if [Utils.settings.with_log_level > 0]. *)
match (Backend.get_buffer tn dst.ctx, Backend.get_buffer tn src.ctx) with
| None, _ | _, None -> false
Expand Down Expand Up @@ -855,15 +837,6 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
let to_host context tn = to_host context.ctx tn

let device_to_device tn ~into_merge_buffer ~dst ~src =
if
(not (equal_merge_buffer_use into_merge_buffer No))
&& not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node)
then
raise
@@ Utils.User_error
("Multicore_backend.device_to_device: merge node mismatch, expected "
^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node)
^ ", actual " ^ Tnode.debug_name tn);
device_to_device tn ~into_merge_buffer ~dst:dst.ctx ~src:src.ctx
end

Expand Down
12 changes: 6 additions & 6 deletions bin/moons_demo_parallel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ let experiment ~seed ~backend_name ~config () =
(* Utils.set_log_level 3; *)
(* Utils.settings.output_debug_files_in_build_directory <- true; *)
(* Utils.settings.debug_log_from_routines <- true; *)
(* let hid_dim = 16 in *)
let hid_dim = 4 in
let hid_dim = 16 in
(* let hid_dim = 4 in *)
(* let batch_size = 120 in *)
(* let batch_size = 60 in *)
let batch_size = 20 in
let batch_size = 60 in
(* let batch_size = 20 in *)
let len = batch_size * 20 in
let init_lr = 0.1 in
(* let epochs = 10 in *)
(* let epochs = 40 in *)
let epochs = 1 in
let epochs = 40 in
(* let epochs = 1 in *)
let noise () = Rand.float_range (-0.1) 0.1 in
let moons_flat =
Array.concat_map (Array.create ~len ())
Expand Down
6 changes: 4 additions & 2 deletions lib/train.ml
Original file line number Diff line number Diff line change
Expand Up @@ -364,14 +364,16 @@ let%track3_sexp parallel_update (type context)
let grad_merge =
Option.value_exn ~here:[%here] ~message:(Tn.debug_name p.value) grad_merges_to.(to_).(i)
in
(* NOTE: we no longer have to pass [grad_merge.context] as [dst]. *)
assert (
Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer
~dst:grad_merge.context ~src:ctxs.(from));
~dst:ctxs.(to_) ~src:ctxs.(from));
(Task.run grad_merge.schedule : unit))
in
let merge_loss ~src =
(* NOTE: we no longer have to pass [loss_merge.context] as [dst]. *)
assert (
Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:loss_merge.context ~src);
Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:sgd_update.context ~src);
Task.run loss_merge.schedule
in
(* FIXME: missing backcopy. *)
Expand Down

0 comments on commit 2858d24

Please sign in to comment.