Skip to content

Commit

Permalink
Formatting update
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 25, 2024
1 parent 233a7b2 commit 864cc24
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 19 deletions.
5 changes: 1 addition & 4 deletions arrayjit/lib/backends.mli
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ val reinitialize : (module Backend_types.Backend) -> Backend_types.config -> uni
(** Initializes the backend, and if it was already initialized, performs garbage collection. *)

val fresh_backend :
?backend_name:string ->
?config:Backend_types.config ->
unit ->
(module Backend_types.Backend)
?backend_name:string -> ?config:Backend_types.config -> unit -> (module Backend_types.Backend)
(** Reinitializes and returns a backend corresponding to [backend_name], or if omitted, selected via
the global [backend] setting. See {!reinitialize}. *)
9 changes: 8 additions & 1 deletion arrayjit/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
(-> cuda_backend.missing.ml))
ppx_minidebug.runtime)
(preprocess
(pps ppx_compare ppx_hash ppx_here ppx_sexp_conv ppx_string ppx_variants_conv ppx_minidebug))
(pps
ppx_compare
ppx_hash
ppx_here
ppx_sexp_conv
ppx_string
ppx_variants_conv
ppx_minidebug))
(modules
utils
rand
Expand Down
10 changes: 3 additions & 7 deletions arrayjit/lib/task.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

type t =
| Task : {
context_lifetime : ('a[@sexp.opaque]);
description : string;
work : unit -> unit;
}
-> t
| Task : { context_lifetime : ('a[@sexp.opaque]); description : string; work : unit -> unit } -> t
[@@deriving sexp_of]

let describe (Task task) = task.description
Expand All @@ -32,7 +27,8 @@ let prepend ~work (Task task) =
task.work ());
}

let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream (Task { description; _ } as task) =
let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream
(Task { description; _ } as task) =
[%log_result "enschedule", description, "on", get_stream_name stream];
let work () = schedule_task stream task in
Task
Expand Down
2 changes: 1 addition & 1 deletion bin/moons_benchmark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ let _mem_benchmarks =
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; *) "cc"; "cuda" ] ~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single ; CDSL.half ]
List.concat_map [ (* CDSL.double; *) CDSL.single; CDSL.half ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
Expand Down
5 changes: 1 addition & 4 deletions bin/zero2hero_1of7.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,7 @@ let _suspended () =
let%op f5 = f 5 in
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
Train.every_non_literal_on_host f5;
Train.forward_and_forget
(module Backend)
Backend.(init @@ new_stream @@ get_device ~ordinal:0)
f5;
Train.forward_and_forget (module Backend) Backend.(init @@ new_stream @@ get_device ~ordinal:0) f5;
Stdio.printf "\n%!";
Tensor.print_tree ~with_grad:false ~depth:9 f5;
Stdio.printf "\n%!"
Expand Down
4 changes: 2 additions & 2 deletions lib/tensor.mli
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ val consume_backprop_code : t -> asgns * comp

val iter_embedded : f:(tn -> unit) -> t -> unit
(** [iter_embedded t] iterates over all descendant nodes that are embedded, i.e. are not members of
[t.forward.embedded_nodes] or '[t.diff.backprop.embedded_nodes]' (if any). Note: [iter_embedded] should only be
called after shape inference finishes. *)
[t.forward.embedded_nodes] or '[t.diff.backprop.embedded_nodes]' (if any). Note: [iter_embedded]
should only be called after shape inference finishes. *)

val unsafe_reinitialize : unit -> unit
(** Bring global state to its initialization values. This invalidates any previously defined tensors
Expand Down

0 comments on commit 864cc24

Please sign in to comment.