Fixes #245: report used memory
Note: CUDA backend migration to `Tnode.sharing` is still broken.
lukstafi committed Oct 15, 2024
1 parent a09e2d7 commit e2780a6
Showing 7 changed files with 122 additions and 50 deletions.
4 changes: 1 addition & 3 deletions arrayjit/lib/cc_backend.ml
@@ -49,9 +49,7 @@ let is_initialized, initialize =
let finalize _ctx = ()

let init ~label =
let result =
{ label; arrays = { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tn) } }
in
let result = { label; arrays = empty_ctx_arrays } in
Stdlib.Gc.finalise finalize result;
result

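The replacement relies on a shared `empty_ctx_arrays` value. Judging from the inline record it replaces, that value is presumably along these lines (a sketch inferred from the deleted code, not the definition shipped in this commit):

let empty_ctx_arrays =
  (* Fresh used-memory counter and no context arrays; inferred from the
     removed inline record above. *)
  { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tn) }
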
5 changes: 3 additions & 2 deletions bin/compilation_speed.ml
@@ -24,6 +24,7 @@ let benchmark_overhead backend () =
(* Train.every_non_literal_on_host f; *)
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.init stream in
let init_mem = Backend.(get_used_memory @@ get_stream_device stream) in
let update_f = Train.grad_update f in
(* Initialize the context with a mock update of x to ensure that it is not optimized as a
constant. *)
@@ -60,13 +61,13 @@
in
let final_time = Time_now.nanoseconds_since_unix_epoch () in
let time_in_sec = Int63.(to_float @@ (final_time - init_time)) /. 1000_000_000. in
let mem_in_bytes = Backend.(get_used_memory @@ get_stream_device stream) - init_mem in
let result =
PrintBox_utils.Benchmark
{
bench_title = Backend.name ^ " overhead";
time_in_sec;
(* FIXME: global mem consumption *)
mem_in_bytes = 0;
mem_in_bytes;
result_label = "x, f(x)";
result =
[%sexp_of: (float * float) list]
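
The measurement pattern introduced here — snapshot the device's used memory before the run, report the delta afterwards — could be factored into a helper. A hypothetical sketch, not part of this commit; it assumes the `Backend_type` signature used in lib/train.ml:

let with_mem_delta (type device stream)
    (module Backend : Backend_type with type device = device and type stream = stream)
    stream f =
  let dev = Backend.get_stream_device stream in
  let init_mem = Backend.get_used_memory dev in
  let result = f () in
  (* Memory attributed to [f]: bytes allocated on [dev] since the snapshot. *)
  (result, Backend.get_used_memory dev - init_mem)
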
63 changes: 45 additions & 18 deletions bin/moons_benchmark.ml
@@ -15,15 +15,15 @@ let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~backend_name
let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~backend_name
~value_prec ~grad_prec () =
[%track_sexp
let _debug : string = "started" in
(fun (started : unit) -> started) ()];
(* ignore seed; *)
let bench_title =
[%string
"seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_devices#Int}, batch \
"seed %{seed#Int}, inline %{inlining_cutoff#Int}, parallel %{num_streams#Int}, batch \
%{batch_size#Int}, backend %{backend_name}, val prec %{Ops.prec_string value_prec}, grad \
prec %{Ops.prec_string grad_prec}"]
in
@@ -45,11 +45,11 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
(* TINY for debugging: *)
(* let data_len = 3 * 4 in *)
let flat_len = data_len / 2 in
(* Note: [minibatch_size = batch_size / num_devices] is the actual per-device batch used. *)
(* Note: [minibatch_size = batch_size / num_streams] is the actual per-device batch used. *)
(* let epochs = 200 in *)
let epochs = 100 in
(* let epochs = 100 in *)
(* TINY for debugging: *)
(* let epochs = 2 in *)
let epochs = 2 in
(* let epochs = 1 in *)
(* let init_lr = 0.1 in *)
let init_lr = 0.01 in
@@ -78,7 +78,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
let%op loss_fn ~output ~expectation = ?/(!..1 - (expectation *. output)) in
let start_time = ref None in
let weight_decay = 0.0002 in
Arrayjit.Backends.sync_suggested_num_streams := num_devices;
Arrayjit.Backends.sync_suggested_num_streams := num_streams;
let backend = Arrayjit.Backends.fresh_backend ~backend_name () in
let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
@@ -90,8 +90,17 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
in
let module Backend = (val backend) in
Backend.initialize Train.BT.Most_parallel_streams;
let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_devices ~data_len
let {
Train.inputs;
outputs;
model_result;
infer_callback;
batch_losses;
epoch_losses;
learning_rates;
used_memory;
} =
Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_streams ~data_len
~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
~per_batch_callback ~per_epoch_callback
(module Backend)
@@ -161,8 +170,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
{
bench_title;
time_in_sec;
(* FIXME: implement total mem assessment. *)
mem_in_bytes = 0;
mem_in_bytes = used_memory;
result_label = "init time in sec, min loss, last loss";
result =
[%sexp_of: float * float * float]
@@ -176,11 +184,11 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b

let _suspend () =
ignore
@@ classify_moons ~seed:0 ~on_device:true ~inlining_cutoff:3 ~num_devices:8 ~batch_size:16
@@ classify_moons ~seed:0 ~on_device:true ~inlining_cutoff:3 ~num_streams:8 ~batch_size:16
~backend_name:"gccjit" ~value_prec:CDSL.single ~grad_prec:CDSL.double ()

let cuda_benchmarks =
List.concat_map [ 1; 3; 6; 12; 16; 20 (* 32; 64 *) ] ~f:(fun num_devices ->
let _cuda_benchmarks =
List.concat_map [ 1; 3; 6; 12; 16; 20 (* 32; 64 *) ] ~f:(fun num_streams ->
List.concat_map
[
(* TINY for debugging: *)
@@ -194,7 +202,26 @@ let cuda_benchmarks =
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_devices
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
]))))))

let _mem_benchmarks =
List.concat_map [ 1; 3; 6; 12; 16 (* ; 20; 32; 64 *) ] ~f:(fun num_streams ->
List.concat_map
[
(* TINY for debugging: *)
(* 3 * 2 *)
3 * 5 * 16 (* ; 3 * 5 * 32; 3 * 5 * 64 *);
]
~f:(fun batch_size ->
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ] ~f:(fun backend_name ->
List.concat_map [ CDSL.double; CDSL.single (* ; CDSL.half *) ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
]))))))

@@ -204,7 +231,7 @@
(nth - 1) *)

let fixed_seed_search seed =
classify_moons ~seed ~on_device:true ~inlining_cutoff:3 ~num_devices:1 ~batch_size:20
classify_moons ~seed ~on_device:true ~inlining_cutoff:3 ~num_streams:1 ~batch_size:20
~backend_name:"cuda" ~value_prec:CDSL.single ~grad_prec:CDSL.single ()

let _suspended () =
@@ -213,8 +240,8 @@ let _suspended () =
(* let () = List.map benchmarks ~f:(nth_best 2) |> PrintBox_utils.table |> PrintBox_text.output
Stdio.stdout *)

let benchmark () =
List.map cuda_benchmarks ~f:(fun bench -> bench ())
let benchmark benchmarks =
List.map benchmarks ~f:(fun bench -> bench ())
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout

let () = benchmark ()
let () = benchmark _mem_benchmarks
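
Since `benchmark` now takes the suite as an argument, switching runs is a one-line edit at the entry point; for example, the stream-scaling suite could be restored with (hypothetical alternative to the committed last line):

let () = benchmark _cuda_benchmarks
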
13 changes: 11 additions & 2 deletions bin/moons_demo_parallel.ml
@@ -56,7 +56,16 @@ let experiment ~seed ~backend_name ~config () =
epoch_loss
in
let module Backend = (val backend) in
let inputs, outputs, model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
let {
Train.inputs;
outputs;
model_result;
infer_callback;
batch_losses;
epoch_losses;
learning_rates;
used_memory;
} =
Train.example_train_loop ~seed ~batch_size ~max_num_streams:(batch_size / 2) ~init_lr
~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
~weight_decay ~per_batch_callback ~per_epoch_callback
Expand All @@ -68,7 +77,7 @@ let experiment ~seed ~backend_name ~config () =
let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
Stdio.print_endline "\n******** mlp_result **********";
Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 model_result;
Stdio.printf "\n********\n%!";
Stdio.printf "\n********\nUsed memory: %d\n%!" used_memory;
let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
let plot_moons =
let open PrintBox_utils in
63 changes: 42 additions & 21 deletions lib/train.ml
@@ -405,43 +405,53 @@ let%track3_sexp parallel_update (type context)

let get_all_suggested_streams ?(max_num_streams : int option) (type device stream)
(backend : (module Backend_type with type device = device and type stream = stream)) :
stream array =
device array * stream array =
let max_num_streams = Option.value max_num_streams ~default:Int.max_value_30_bits in
let module Backend =
(val backend : Backend_type with type device = device and type stream = stream)
in
let num_devices = min max_num_streams @@ Backend.num_devices () in
let devices = Array.init num_devices ~f:(fun ordinal -> Backend.get_device ~ordinal) in
Array.folding_mapi devices ~init:0 ~f:(fun ordinal num_collected device ->
let remaining_devices = num_devices - ordinal - 1 in
let max_current = Backend.suggested_num_streams device in
let take_current = min max_current @@ (max_num_streams - remaining_devices) in
( num_collected + take_current,
Array.init take_current ~f:(fun _subordinal -> Backend.new_stream device) ))
|> Array.concat_map ~f:Fn.id
let result =
Array.folding_mapi devices ~init:0 ~f:(fun ordinal num_collected device ->
let remaining_devices = num_devices - ordinal - 1 in
let max_current = Backend.suggested_num_streams device in
let take_current = min max_current @@ (max_num_streams - remaining_devices) in
( num_collected + take_current,
Array.init take_current ~f:(fun _subordinal -> Backend.new_stream device) ))
|> Array.concat_map ~f:Fn.id
in
(devices, result)

let to_routine (type context) (module Backend : Backend_type with type context = context)
(context : context) ?shared ?name bindings comp =
Backend.link context @@ Backend.compile ?shared ?name bindings comp

type example_train_result = {
inputs : Tensor.t;
outputs : Tensor.t;
model_result : Tensor.t;
infer_callback : float array -> float array;
(** Note: infer_callback is significantly less efficient than using the model via arrayjit. *)
batch_losses : float list;
epoch_losses : float list;
learning_rates : float list;
used_memory : int;
}

let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init_lr ?lr_schedule
?(copy_to_merge = false) ?max_num_streams ~data_len ~epochs ~inputs ~outputs ~model ~loss_fn
~weight_decay ?per_batch_callback ?per_epoch_callback (type context)
?(prior_contexts : context array option)
(backend : (module Backend_type with type context = context)) () =
let module TDSL = Operation.TDSL in
let module NTDSL = Operation.NTDSL in
Rand.init seed;
let module Backend = (val backend : Backend_type with type context = context) in
let prior_contexts =
match prior_contexts with
| Some contexts -> contexts
| None ->
let devices = get_all_suggested_streams ?max_num_streams (module Backend) in
Array.map devices ~f:Backend.init
in
let num_devices = Array.length prior_contexts in
let minibatch_size = batch_size / num_devices in
let devices, streams = get_all_suggested_streams ?max_num_streams (module Backend) in
let num_streams = Array.length streams in
let contexts = Array.map streams ~f:Backend.init in
let init_mem = Array.fold devices ~init:0 ~f:(fun acc dev -> acc + Backend.get_used_memory dev) in
let minibatch_size = batch_size / num_streams in
let n_minibatches = data_len / minibatch_size in
let inputs = inputs ~b:[ n_minibatches; minibatch_size ] in
let outputs = outputs ~b:[ n_minibatches; minibatch_size ] in
@@ -468,7 +478,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
set_hosted learning_rate.value;
let sgd = sgd_update ~learning_rate ~weight_decay update in
let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in
let grad_updates = Array.map prior_contexts ~f:(fun ctx -> Backend.link ctx grad_update) in
let grad_updates = Array.map contexts ~f:(fun ctx -> Backend.link ctx grad_update) in
let sgd_update = to_routine (module Backend) grad_updates.(0).context bindings sgd in
Tensor.log_debug_info ~from_log_level:2 inputs;
Tensor.log_debug_info ~from_log_level:2 outputs;
@@ -534,8 +544,19 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
Backend.(await @@ get_ctx_stream routine.context);
Tensor.get_values model_result
in
(* Note: infer_callback is significantly less efficient than using the model via arrayjit. *)
(inputs, outputs, model_result, infer_callback, !batch_losses, !epoch_losses, !learning_rates)
let used_memory =
Array.fold devices ~init:0 ~f:(fun acc dev -> acc + Backend.get_used_memory dev) - init_mem
in
{
inputs;
outputs;
model_result;
infer_callback;
batch_losses = !batch_losses;
epoch_losses = !epoch_losses;
learning_rates = !learning_rates;
used_memory;
}

let%track3_sexp forward_and_ctx ?(disable_rootness_check = false) (type context)
(module Backend : Backend_type with type context = context) ctx ?(bindings = IDX.empty) t =
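
The switch from a 7-tuple to the `example_train_result` record lets call sites bind only the fields they need, as the updated tests below show. A self-contained illustration of the pattern, with made-up field names unrelated to this commit:

type example_result = { loss : float; used_memory : int; steps : int }

let run () = { loss = 0.42; used_memory = 1_024; steps = 100 }

let () =
  (* Record patterns can ignore the remaining fields with [_], unlike tuples,
     which force a positional wildcard for every unused component. *)
  let { loss; used_memory; _ } = run () in
  Printf.printf "loss %.2f, used %d bytes\n" loss used_memory
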
12 changes: 10 additions & 2 deletions test/moons_demo_parallel.ml
@@ -54,8 +54,16 @@ let main () =
epoch_loss
in
let module Backend = (val backend) in
let inputs, outputs, _model_result, infer_callback, _batch_losses, _epoch_losses, _learning_rates
=
let {
Train.inputs;
outputs;
model_result = _;
infer_callback;
batch_losses = _;
epoch_losses = _;
learning_rates = _;
used_memory = _;
} =
Train.example_train_loop ~seed ~batch_size ~max_num_streams:(batch_size / 2) ~init_lr
~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
~weight_decay ~per_batch_callback ~per_epoch_callback
12 changes: 10 additions & 2 deletions test/moons_demo_parallel_run.ml
@@ -53,8 +53,16 @@ let main () =
epoch_loss
in
let module Backend = (val backend) in
let inputs, outputs, _model_result, infer_callback, _batch_losses, _epoch_losses, _learning_rates
=
let {
Train.inputs;
outputs;
model_result = _;
infer_callback;
batch_losses = _;
epoch_losses = _;
learning_rates = _;
used_memory = _;
} =
Train.example_train_loop ~seed ~batch_size ~max_num_streams:(batch_size / 2) ~init_lr
~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
~weight_decay ~per_batch_callback ~per_epoch_callback
