Skip to content

Commit

Permalink
In progress: first pass of adapting the data-parallel update to the n…
Browse files Browse the repository at this point in the history
…ew merge buffers
  • Loading branch information
lukstafi committed Jul 1, 2024
1 parent 9b20243 commit 90d2027
Showing 1 changed file with 27 additions and 148 deletions.
175 changes: 27 additions & 148 deletions lib/train.ml
Original file line number Diff line number Diff line change
Expand Up @@ -284,105 +284,6 @@ let%track_sexp sync_run ?looping (type context) (module Backend : Backend_type w

module Lazy = Utils.Lazy

(** {[
let merge ?name_prefix tn ~accum ~(src : context) : code option =
let bindings = Indexing.Empty in
merge ?name_prefix tn ~accum ~src bindings |> Option.map ~f:(fun routine : code -> Compiled routine)
let merge_batch ?name_prefixes ~occupancy tns ~accum ~(srcs : context array) =
let bindings = Indexing.Empty in
merge_batch ?name_prefixes ~occupancy tns ~accum ~srcs bindings
|> Hashtbl.map ~f:(fun procs ->
Array.map procs ~f:(Option.map ~f:(fun routine : code -> Compiled routine)))
let%track_sexp merge_from_global ~unoptim_ll_source ~ll_source ~name ~dst ~accum ~src =
let global =
match src with
| None -> Arrayjit.Ops.Merge_buffer src.id
| Some src -> External_unsafe { ptr = src; prec = dst.Arrayjit.Tnode.prec; dims = dst.dims }
in
let open Arrayjit.Low_level in
let body idcs =
Set
{
tn = dst;
idcs;
llv = Binop (accum, Get (dst, idcs), Get_global (global, Some idcs));
debug = "";
}
in
let llc = loop_over_dims (Lazy.force dst.dims) ~body in
optimize_proc ~unoptim_ll_source ~ll_source ~name [] llc
let compile_merges ?name_prefixes:_ ~occupancy:_
(* Tnode.t -> dst_n:int -> dst:context -> src_n:int -> src:context -> Utils.requirement *) _tns
~accum:_ ~dsts:_ ~srcs:_ =
failwith "NOT IMPLEMENTED YET"
let%track_sexp merge (type context) ?name_prefix la ~accum ~(src : context) bindings =
(* FIXME: reconstruct name if missing *)
let name = Option.value name_prefix ~default:"" in
let unoptim_ll_source = Utils.get_debug_formatter ~fname:(name ^ "-unoptimized.ll") in
let ll_source = Utils.get_debug_formatter ~fname:(name ^ ".ll") in
let name_prefix : string = Option.value ~default:"" @@ Option.map name_prefix ~f:(fun s -> s ^ "_") in
Option.map (Map.find src.arrays la) ~f:(fun (src : Ndarray.t) ->
let name = [%string "%{name_prefix}into_%{Tn.name la}"] in
compile ~opt_ctx_arrays:None ~name bindings
@@ merge_from_global ~unoptim_ll_source ~ll_source ~name ~dst:la ~accum
~src:(Ndarray.get_voidptr src))
let%track_sexp merge_batch ?(name_prefixes : string array option) ~occupancy tns ~accum
~(srcs : context array) bindings =
(* FIXME: reconstruct name if missing *)
let name =
String.(
strip ~drop:(equal_char '_')
@@ common_prefix @@ Array.to_list
@@ Option.value name_prefixes ~default:[||])
in
let unoptim_ll_source = Utils.get_debug_formatter ~fname:(name ^ "-unoptimized.ll") in
let ll_source = Utils.get_debug_formatter ~fname:(name ^ ".ll") in
let complete =
Array.concat_map (Array.of_list tns) ~f:(fun tn ->
Array.mapi srcs ~f:(fun i src ->
match (occupancy tn ~src_n:i ~src, Map.find src.arrays tn) with
| Utils.Skip, _ -> ((tn, i), (None, None))
| Optional { callback_if_missing }, None ->
callback_if_missing ();
((tn, i), (None, None))
| Required, None ->
failwith @@ "Gccjit_backend.merge_batch: missing tnode " ^ Tn.name tn ^ " in context "
^ src.label
| _, Some src ->
let prefix = match name_prefixes with Some ns -> ns.(i) ^ "_" | None -> "" in
let name = [%string "%{prefix}into_%{Tn.name tn}"] in
( (tn, i),
( Some name,
Some
(merge_from_global ~unoptim_ll_source ~ll_source ~name ~dst:tn ~accum
~src:(Ndarray.get_voidptr src)) ) )))
in
let ids, compileds = Array.unzip complete in
let len = Array.length srcs in
let together =
Hashtbl.of_alist_exn (module Tn) @@ List.map tns ~f:(fun tn -> (tn, Array.create ~len None))
in
if Array.is_empty compileds then together
else
let names, compileds = Array.unzip compileds in
let result = compile_batch ~opt_ctx_arrays:None ~names bindings compileds in
Array.iter2_exn ids result ~f:(fun (tn, i) res ->
let r = Hashtbl.find_exn together tn in
assert (Option.is_none r.(i));
r.(i) <- res);
together
]} *)

let collapse_merges merges =
Hashtbl.data merges
|> List.map ~f:(Array.map ~f:Option.to_list)
|> List.reduce_exn ~f:(Array.map2_exn ~f:( @ ))

(** Performs one optimization step, potentially in parallel (if [grad_updates] are compiled for different
devices). All jitted code must have the same bindings. Iterates over bindings with ranges, calling one of
[grad_updates] in a round-robin fashion, and performs the following synchronization each time all
Expand All @@ -409,69 +310,47 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit
assert (
Array.for_all grad_updates ~f:(fun upd ->
[%equal: Idx.static_symbol list] bindings @@ List.map ~f:fst upd.bindings))];
let all_params : Tensor.t list = Set.to_list updaten.params in
let param_vals = [%debug_notrace List.map all_params ~f:(fun t -> t.value)] in
let param_grads = [%debug_notrace List.map all_params ~f:(fun t -> (Option.value_exn t.diff).grad)] in
let all_params = Set.to_array updaten.params in
let ctxs = [%debug_notrace Array.map grad_updates ~f:(fun upd -> upd.context)] in
let occupancy _tn ~src_n ~src:_ =
if Array.exists ~f:Fn.id occupancies.(src_n) then Utils.Required else Utils.Skip
in
let name_prefixes = Array.create ~len:num_devices "grad_merge" in
let merge_batch ~name_prefixes:_ ~occupancy:_ _tns ~accum:_ ~srcs:_ = failwith "NOT IMPLEMENTED YET" in
let grad_merges =
collapse_merges @@ merge_batch ~name_prefixes ~occupancy param_grads ~accum:Arrayjit.Ops.Add ~srcs:ctxs
in
let grad_merges =
Array.init num_devices ~f:(fun (to_ : int) ->
Array.init num_devices ~f:(fun (from : int) ->
List.map grad_merges.(from) ~f:(fun c -> (Backend.link ctxs.(to_) c).schedule)))
let occupancy ~name:_ ~src_n = Array.exists ~f:Fn.id occupancies.(src_n) in
(* let names = Array.create ~len:num_devices "grad_merge" in *)
let grad_merges = Array.map all_params ~f:(fun p -> [%cd p.grad =+ p.grad.merge]) in
let grad_merges_to =
Array.map ctxs ~f:(fun ctx ->
snd @@ Backend.link_batch ctx @@ Backend.compile_batch ~shared:true ~occupancy Idx.Empty grad_merges)
in
(* We can cache scheduling, because merging and copying does not depend on static indexing. *)
let name_prefixes = Array.create ~len:num_devices "loss_merge" in
let loss_merges =
collapse_merges
@@ merge_batch ~name_prefixes ~occupancy [ updaten.loss.value ] ~accum:Arrayjit.Ops.Add ~srcs:ctxs
in
let loss_merges =
Array.init num_devices ~f:(fun (to_ : int) ->
Array.init num_devices ~f:(fun (from : int) ->
match loss_merges.(from) with
| [] -> None
| [ c ] -> Some (Backend.link ctxs.(to_) c).schedule
| _ -> assert false))
(* let names = Array.create ~len:num_devices "loss_merge" in *)
let loss_merge =
Backend.(
link sgd_update.context @@ compile Idx.Empty [%cd updaten.loss.value =+ updaten.loss.value.merge])
in
let merge ~(from : int) ~(to_ : int) : unit =
(* FIXME: need to iterate over params in the outer loop. *)
let merge_grads ~(from : int) ~(to_ : int) : unit =
(* FIXME: do we need to sync already? *)
Backend.(await @@ get_ctx_device ctxs.(from));
Option.iter ~f:(Tn.run debug_rt) loss_merges.(to_).(from);
List.iter ~f:(Tn.run debug_rt) grad_merges.(to_).(from)
in
let needed_on_host = ref @@ Set.empty (module Tn) in
(* Backends may choose to not store parameters on devices other than the 0th. *)
let occupancy p ~src_n:_ ~src:_ =
Utils.Optional { callback_if_missing = (fun () -> needed_on_host := Set.add !needed_on_host p) }
Array.iteri all_params ~f:(fun i p ->
assert (Backend.device_to_device p.value ~into_merge_buffer:Copy ~dst:ctxs.(to_) ~src:ctxs.(from));
(Tn.run debug_rt (Option.value_exn grad_merges_to.(to_).(i)).schedule : unit))
in
let copies =
collapse_merges
@@ merge_batch ~name_prefixes:[| "param_copy" |] ~occupancy param_vals ~accum:Arrayjit.Ops.Arg2
~srcs:[| sgd_update.context |]
in
let copies =
assert (Array.length copies = 1);
copies.(0)
in
let copies =
Array.init (num_devices - 1) ~f:(fun (to_m_1 : int) ->
List.map copies ~f:(fun c -> (Backend.link ctxs.(to_m_1 + 1) c).schedule))
let merge_loss ~src =
assert (Backend.device_to_device updaten.loss.value ~into_merge_buffer:Copy ~dst:sgd_update.context ~src);
Tn.run debug_rt loss_merge.schedule
in
(* FIXME: missing backcopy. *)
let needed_on_host = ref @@ Set.empty (module Tn) in
let%track_sexp sync (devices_to_sync : int) : unit =
Arrayjit.Utils.parallel_merge merge devices_to_sync;
Arrayjit.Utils.parallel_merge merge_grads devices_to_sync;
Tn.run debug_rt sgd_update.schedule;
(* We need to wait, because copying happens on other devices. *)
Array.iteri ctxs ~f:(fun i src -> if i <> 0 then merge_loss ~src);
Set.iter !needed_on_host ~f:(fun p -> assert (Backend.to_host sgd_update.context p));
Backend.(await @@ get_ctx_device sgd_update.context);
(* We will need to update params on all devices! Not only the ones that computed gradients. *)
for to_ = 1 to num_devices - 1 do
List.iter copies.(to_ - 1) ~f:(Tn.run debug_rt)
Array.iter all_params ~f:(fun p ->
assert (
Backend.device_to_device p.value ~into_merge_buffer:No ~dst:ctxs.(to_) ~src:sgd_update.context))
done;
post_sync ~num_synced_devices:devices_to_sync
in
Expand Down

0 comments on commit 90d2027

Please sign in to comment.