diff --git a/CHANGES.md b/CHANGES.md index 83391539..867b6fdd 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,6 +10,7 @@ - Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments. - Renaming: (virtual) `device` -> `stream`, `physical_device` -> `device`. - New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; renamed `backend_utils.ml` to `c_syntax.ml`. +- Removed half-static verification of merge buffer nodes inside `device_to_device`. - TODO: Moved the multicore backend from a `device = stream` model to a single device model. - TODO: Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`. - TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations. diff --git a/arrayjit/lib/backend_types.ml b/arrayjit/lib/backend_types.ml index d09a5528..c40c6721 100644 --- a/arrayjit/lib/backend_types.ml +++ b/arrayjit/lib/backend_types.ml @@ -156,11 +156,8 @@ module type Backend = sig given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s stream. - - If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge - buffer code, the [device_to_device] call should fail immediately if there's a mismatch with - [into_merge_buffer]. - NOTE: If [into_merge_buffer:Streaming], after scheduling the work on [dst] using the merge + NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge buffer but before scheduling work on [src] that modifies [tn], execute [will_wait_for src (all_work (get_ctx_stream dst))]. 
*) @@ -287,7 +284,7 @@ module type Lowered_backend = sig val device_to_device : Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool - (** If the tensor node is in both contexts, copies from [dst] to [src]. *) + (** See {!Backend.device_to_device}. *) type buffer_ptr [@@deriving sexp_of] diff --git a/arrayjit/lib/backends.ml b/arrayjit/lib/backends.ml index 44704f35..437b0963 100644 --- a/arrayjit/lib/backends.ml +++ b/arrayjit/lib/backends.ml @@ -280,15 +280,6 @@ struct let device_to_device tn ~into_merge_buffer ~dst ~src = let dev = dst.stream in - if - (not (equal_merge_buffer_use into_merge_buffer No)) - && not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node) - then - raise - @@ Utils.User_error - ("Multicore_backend.device_to_device: merge node mismatch, expected " - ^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node) - ^ ", actual " ^ Tnode.debug_name tn); let schedule dst = let work = (* TODO: log the operation if [Utils.settings.with_log_level > 0]. *) @@ -479,15 +470,6 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types. let device_to_device tn ~into_merge_buffer ~dst ~src = let dev = dst.stream in - if - (not (equal_merge_buffer_use into_merge_buffer No)) - && not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node) - then - raise - @@ Utils.User_error - ("Multicore_backend.device_to_device: merge node mismatch, expected " - ^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node) - ^ ", actual " ^ Tnode.debug_name tn); (* TODO: log the operation if [Utils.settings.with_log_level > 0]. *) match (Backend.get_buffer tn dst.ctx, Backend.get_buffer tn src.ctx) with | None, _ | _, None -> false @@ -855,15 +837,6 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types. 
let to_host context tn = to_host context.ctx tn let device_to_device tn ~into_merge_buffer ~dst ~src = - if - (not (equal_merge_buffer_use into_merge_buffer No)) - && not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node) - then - raise - @@ Utils.User_error - ("Multicore_backend.device_to_device: merge node mismatch, expected " - ^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node) - ^ ", actual " ^ Tnode.debug_name tn); device_to_device tn ~into_merge_buffer ~dst:dst.ctx ~src:src.ctx end diff --git a/bin/moons_demo_parallel.ml b/bin/moons_demo_parallel.ml index 3658b3d0..c1bd1cc4 100644 --- a/bin/moons_demo_parallel.ml +++ b/bin/moons_demo_parallel.ml @@ -13,16 +13,16 @@ let experiment ~seed ~backend_name ~config () = (* Utils.set_log_level 3; *) (* Utils.settings.output_debug_files_in_build_directory <- true; *) (* Utils.settings.debug_log_from_routines <- true; *) - (* let hid_dim = 16 in *) - let hid_dim = 4 in + let hid_dim = 16 in + (* let hid_dim = 4 in *) (* let batch_size = 120 in *) - (* let batch_size = 60 in *) - let batch_size = 20 in + let batch_size = 60 in + (* let batch_size = 20 in *) let len = batch_size * 20 in let init_lr = 0.1 in (* let epochs = 10 in *) - (* let epochs = 40 in *) - let epochs = 1 in + let epochs = 40 in + (* let epochs = 1 in *) let noise () = Rand.float_range (-0.1) 0.1 in let moons_flat = Array.concat_map (Array.create ~len ()) diff --git a/lib/train.ml b/lib/train.ml index 592fab9c..38cb5d7a 100644 --- a/lib/train.ml +++ b/lib/train.ml @@ -364,14 +364,16 @@ let%track3_sexp parallel_update (type context) let grad_merge = Option.value_exn ~here:[%here] ~message:(Tn.debug_name p.value) grad_merges_to.(to_).(i) in + (* NOTE: we no longer have to pass [grad_merge.context] as [dst]. 
*) assert ( Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer - ~dst:grad_merge.context ~src:ctxs.(from)); + ~dst:ctxs.(to_) ~src:ctxs.(from)); (Task.run grad_merge.schedule : unit)) in let merge_loss ~src = + (* NOTE: we no longer have to pass [loss_merge.context] as [dst]. *) assert ( - Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:loss_merge.context ~src); + Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:sgd_update.context ~src); Task.run loss_merge.schedule in (* FIXME: missing backcopy. *)