diff --git a/lib/train.ml b/lib/train.ml
index 21b6f790..38d05e8e 100644
--- a/lib/train.ml
+++ b/lib/train.ml
@@ -284,105 +284,6 @@ let%track_sexp sync_run ?looping (type context) (module Backend : Backend_type w
 module Lazy = Utils.Lazy
 
-(** {[
-    let merge ?name_prefix tn ~accum ~(src : context) : code option =
-      let bindings = Indexing.Empty in
-      merge ?name_prefix tn ~accum ~src bindings |> Option.map ~f:(fun routine : code -> Compiled routine)
-
-    let merge_batch ?name_prefixes ~occupancy tns ~accum ~(srcs : context array) =
-      let bindings = Indexing.Empty in
-      merge_batch ?name_prefixes ~occupancy tns ~accum ~srcs bindings
-      |> Hashtbl.map ~f:(fun procs ->
-             Array.map procs ~f:(Option.map ~f:(fun routine : code -> Compiled routine)))
-
-    let%track_sexp merge_from_global ~unoptim_ll_source ~ll_source ~name ~dst ~accum ~src =
-      let global =
-        match src with
-        | None -> Arrayjit.Ops.Merge_buffer src.id
-        | Some src -> External_unsafe { ptr = src; prec = dst.Arrayjit.Tnode.prec; dims = dst.dims }
-      in
-      let open Arrayjit.Low_level in
-      let body idcs =
-        Set
-          {
-            tn = dst;
-            idcs;
-            llv = Binop (accum, Get (dst, idcs), Get_global (global, Some idcs));
-            debug = "";
-          }
-      in
-      let llc = loop_over_dims (Lazy.force dst.dims) ~body in
-      optimize_proc ~unoptim_ll_source ~ll_source ~name [] llc
-
-    let compile_merges ?name_prefixes:_ ~occupancy:_
-        (* Tnode.t -> dst_n:int -> dst:context -> src_n:int -> src:context -> Utils.requirement *) _tns
-        ~accum:_ ~dsts:_ ~srcs:_ =
-      failwith "NOT IMPLEMENTED YET"
-
-    let%track_sexp merge (type context) ?name_prefix la ~accum ~(src : context) bindings =
-      (* FIXME: reconstruct name if missing *)
-      let name = Option.value name_prefix ~default:"" in
-      let unoptim_ll_source = Utils.get_debug_formatter ~fname:(name ^ "-unoptimized.ll") in
-      let ll_source = Utils.get_debug_formatter ~fname:(name ^ ".ll") in
-      let name_prefix : string = Option.value ~default:"" @@ Option.map name_prefix ~f:(fun s -> s ^ "_") in
-      Option.map (Map.find src.arrays la) ~f:(fun (src : Ndarray.t) ->
-          let name = [%string "%{name_prefix}into_%{Tn.name la}"] in
-          compile ~opt_ctx_arrays:None ~name bindings
-          @@ merge_from_global ~unoptim_ll_source ~ll_source ~name ~dst:la ~accum
-               ~src:(Ndarray.get_voidptr src))
-
-    let%track_sexp merge_batch ?(name_prefixes : string array option) ~occupancy tns ~accum
-        ~(srcs : context array) bindings =
-      (* FIXME: reconstruct name if missing *)
-      let name =
-        String.(
-          strip ~drop:(equal_char '_')
-          @@ common_prefix @@ Array.to_list
-          @@ Option.value name_prefixes ~default:[||])
-      in
-      let unoptim_ll_source = Utils.get_debug_formatter ~fname:(name ^ "-unoptimized.ll") in
-      let ll_source = Utils.get_debug_formatter ~fname:(name ^ ".ll") in
-      let complete =
-        Array.concat_map (Array.of_list tns) ~f:(fun tn ->
-            Array.mapi srcs ~f:(fun i src ->
-                match (occupancy tn ~src_n:i ~src, Map.find src.arrays tn) with
-                | Utils.Skip, _ -> ((tn, i), (None, None))
-                | Optional { callback_if_missing }, None ->
-                    callback_if_missing ();
-                    ((tn, i), (None, None))
-                | Required, None ->
-                    failwith @@ "Gccjit_backend.merge_batch: missing tnode " ^ Tn.name tn ^ " in context "
-                    ^ src.label
-                | _, Some src ->
-                    let prefix = match name_prefixes with Some ns -> ns.(i) ^ "_" | None -> "" in
-                    let name = [%string "%{prefix}into_%{Tn.name tn}"] in
-                    ( (tn, i),
-                      ( Some name,
-                        Some
-                          (merge_from_global ~unoptim_ll_source ~ll_source ~name ~dst:tn ~accum
-                             ~src:(Ndarray.get_voidptr src)) ) )))
-      in
-      let ids, compileds = Array.unzip complete in
-      let len = Array.length srcs in
-      let together =
-        Hashtbl.of_alist_exn (module Tn) @@ List.map tns ~f:(fun tn -> (tn, Array.create ~len None))
-      in
-      if Array.is_empty compileds then together
-      else
-        let names, compileds = Array.unzip compileds in
-        let result = compile_batch ~opt_ctx_arrays:None ~names bindings compileds in
-        Array.iter2_exn ids result ~f:(fun (tn, i) res ->
-            let r = Hashtbl.find_exn together tn in
-            assert (Option.is_none r.(i));
-            r.(i) <- res);
-        together
-  ]} *)
-
-let collapse_merges merges =
-  Hashtbl.data merges
-  |> List.map ~f:(Array.map ~f:Option.to_list)
-  |> List.reduce_exn ~f:(Array.map2_exn ~f:( @ ))
-
 (** Performs one optimization step, potentially in parallel (if [grad_updates] are compiled for
     different devices). All jitted code must have the same bindings. Iterates over bindings with
     ranges, calling one of [grad_updates] in a round-robin fashion, and performs the following
     synchronization each time all
@@ -409,69 +310,47 @@ let%track_sexp parallel_update (type context) (module Backend : Backend_type wit
     assert (
       Array.for_all grad_updates ~f:(fun upd ->
          [%equal: Idx.static_symbol list] bindings @@ List.map ~f:fst upd.bindings))];
-  let all_params : Tensor.t list = Set.to_list updaten.params in
-  let param_vals = [%debug_notrace List.map all_params ~f:(fun t -> t.value)] in
-  let param_grads = [%debug_notrace List.map all_params ~f:(fun t -> (Option.value_exn t.diff).grad)] in
+  let all_params = Set.to_array updaten.params in
   let ctxs = [%debug_notrace Array.map grad_updates ~f:(fun upd -> upd.context)] in
-  let occupancy _tn ~src_n ~src:_ =
-    if Array.exists ~f:Fn.id occupancies.(src_n) then Utils.Required else Utils.Skip
-  in
-  let name_prefixes = Array.create ~len:num_devices "grad_merge" in
-  let merge_batch ~name_prefixes:_ ~occupancy:_ _tns ~accum:_ ~srcs:_ = failwith "NOT IMPLEMENTED YET" in
-  let grad_merges =
-    collapse_merges @@ merge_batch ~name_prefixes ~occupancy param_grads ~accum:Arrayjit.Ops.Add ~srcs:ctxs
-  in
-  let grad_merges =
-    Array.init num_devices ~f:(fun (to_ : int) ->
-        Array.init num_devices ~f:(fun (from : int) ->
-            List.map grad_merges.(from) ~f:(fun c -> (Backend.link ctxs.(to_) c).schedule)))
+  let occupancy ~name:_ ~src_n = Array.exists ~f:Fn.id occupancies.(src_n) in
+  (* let names = Array.create ~len:num_devices "grad_merge" in *)
+  let grad_merges = Array.map all_params ~f:(fun p -> [%cd p.grad =+ p.grad.merge]) in
+  let grad_merges_to =
+    Array.map ctxs ~f:(fun ctx ->
+        snd @@ Backend.link_batch ctx @@ Backend.compile_batch ~shared:true ~occupancy Idx.Empty grad_merges)
   in
   (* We can cache scheduling, because merging and copying does not depend on static indexing. *)
-  let name_prefixes = Array.create ~len:num_devices "loss_merge" in
-  let loss_merges =
-    collapse_merges
-    @@ merge_batch ~name_prefixes ~occupancy [ updaten.loss.value ] ~accum:Arrayjit.Ops.Add ~srcs:ctxs
-  in
-  let loss_merges =
-    Array.init num_devices ~f:(fun (to_ : int) ->
-        Array.init num_devices ~f:(fun (from : int) ->
-            match loss_merges.(from) with
-            | [] -> None
-            | [ c ] -> Some (Backend.link ctxs.(to_) c).schedule
-            | _ -> assert false))
+  (* let names = Array.create ~len:num_devices "loss_merge" in *)
+  let loss_merge =
+    Backend.(
+      link sgd_update.context @@ compile Idx.Empty [%cd updaten.loss.value =+ updaten.loss.value.merge])
   in
-  let merge ~(from : int) ~(to_ : int) : unit =
+  (* FIXME: need to iterate over params in the outer loop. *)
+  let merge_grads ~(from : int) ~(to_ : int) : unit =
+    (* FIXME: do we need to sync already? *)
     Backend.(await @@ get_ctx_device ctxs.(from));
-    Option.iter ~f:(Tn.run debug_rt) loss_merges.(to_).(from);
-    List.iter ~f:(Tn.run debug_rt) grad_merges.(to_).(from)
-  in
-  let needed_on_host = ref @@ Set.empty (module Tn) in
-  (* Backends may choose to not store parameters on devices other than the 0th. *)
-  let occupancy p ~src_n:_ ~src:_ =
-    Utils.Optional { callback_if_missing = (fun () -> needed_on_host := Set.add !needed_on_host p) }
+    Array.iteri all_params ~f:(fun i p ->
+        assert (Backend.device_to_device p.value ~into_merge_buffer:Copy ~dst:ctxs.(to_) ~src:ctxs.(from));
+        (Tn.run debug_rt (Option.value_exn grad_merges_to.(to_).(i)).schedule : unit))
   in
-  let copies =
-    collapse_merges
-    @@ merge_batch ~name_prefixes:[| "param_copy" |] ~occupancy param_vals ~accum:Arrayjit.Ops.Arg2
-         ~srcs:[| sgd_update.context |]
-  in
-  let copies =
-    assert (Array.length copies = 1);
-    copies.(0)
-  in
-  let copies =
-    Array.init (num_devices - 1) ~f:(fun (to_m_1 : int) ->
-        List.map copies ~f:(fun c -> (Backend.link ctxs.(to_m_1 + 1) c).schedule))
+  let merge_loss ~src =
+    assert (Backend.device_to_device updaten.loss.value ~into_merge_buffer:Copy ~dst:sgd_update.context ~src);
+    Tn.run debug_rt loss_merge.schedule
   in
+  (* FIXME: missing backcopy. *)
+  let needed_on_host = ref @@ Set.empty (module Tn) in
   let%track_sexp sync (devices_to_sync : int) : unit =
-    Arrayjit.Utils.parallel_merge merge devices_to_sync;
+    Arrayjit.Utils.parallel_merge merge_grads devices_to_sync;
     Tn.run debug_rt sgd_update.schedule;
     (* We need to wait, because copying happens on other devices. *)
+    Array.iteri ctxs ~f:(fun i src -> if i <> 0 then merge_loss ~src);
    Set.iter !needed_on_host ~f:(fun p -> assert (Backend.to_host sgd_update.context p));
    Backend.(await @@ get_ctx_device sgd_update.context);
    (* We will need to update params on all devices! Not only the ones that computed gradients. *)
    for to_ = 1 to num_devices - 1 do
-      List.iter copies.(to_ - 1) ~f:(Tn.run debug_rt)
+      Array.iter all_params ~f:(fun p ->
+          assert (
+            Backend.device_to_device p.value ~into_merge_buffer:No ~dst:ctxs.(to_) ~src:sgd_update.context))
    done;
    post_sync ~num_synced_devices:devices_to_sync
  in
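
The doc comment kept as context above describes the driver around this code: gradient updates are dispatched to devices in round-robin order over the bindings range, and `sync` runs each time every device has taken a step. The following is a minimal standalone OCaml sketch of that control flow, not code from this diff; the names `round_robin_sync`, `run_grad_update` and `steps` are hypothetical placeholders for whatever the surrounding Train module provides.

(* Sketch only: round-robin dispatch with a synchronization point once every
   device has run, mirroring the behaviour described in the doc comment. *)
let round_robin_sync ~num_devices ~steps
    ~(run_grad_update : device:int -> step:int -> unit) ~(sync : int -> unit) : unit =
  let pending = ref 0 in
  for step = 0 to steps - 1 do
    (* Dispatch this step's gradient update to the next device in turn. *)
    run_grad_update ~device:(step mod num_devices) ~step;
    incr pending;
    (* Once every device has taken a step, merge gradients and re-broadcast params. *)
    if !pending = num_devices then (
      sync !pending;
      pending := 0)
  done;
  (* Synchronize whatever partial round remains at the end. *)
  if !pending > 0 then sync !pending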
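The new `merge_grads` path stages a peer device's buffer with `device_to_device ~into_merge_buffer:Copy` and then runs a routine of the form `[%cd p.grad =+ p.grad.merge]`. Below is a toy sketch of that accumulation on plain float arrays, purely for illustration and not taken from the diff: the values are made up, and in the real code the merge buffer lives on the destination device rather than in an OCaml array.

(* Illustration only: what "copy into merge buffer, then accumulate" amounts to. *)
let () =
  let local_grad = [| 0.1; 0.2; 0.3 |] in
  (* Staged copy of another device's gradient for the same parameter,
     i.e. the effect of device_to_device ~into_merge_buffer:Copy. *)
  let merge_buffer = [| 0.4; 0.5; 0.6 |] in
  (* The [%cd p.grad =+ p.grad.merge] routine: add the staged values in place. *)
  Array.iteri (fun i m -> local_grad.(i) <- local_grad.(i) +. m) merge_buffer;
  Array.iter (fun g -> Printf.printf "%g " g) local_grad;
  print_newline ()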