Skip to content

Commit

Permalink
Rename backend_utils -> c_syntax, uniformly validate merge nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Oct 11, 2024
1 parent d54b5e0 commit 0f0336b
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 31 deletions.
6 changes: 5 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
- Migrated to cudajit 0.5.
- Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments.
- Renaming: (virtual) `device` -> `stream`, `physical_device` -> `device`.
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; TODO: renamed `backend_utils.ml` to `c_syntax.ml`.
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; renamed `backend_utils.ml` to `c_syntax.ml`.
- TODO: Moved the multicore backend from a `device = stream` model to a single device model.
- TODO: Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`.
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.

### Fixed

- Validating merge nodes for the CUDA backend.

## [0.4.1] -- 2024-09-17

### Added
Expand Down
11 changes: 7 additions & 4 deletions arrayjit/lib/backend_types.ml
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ module type Backend = sig
[will_wait_for src (all_work (get_ctx_stream dst))]. *)

type device
type stream
type stream [@@deriving sexp_of]

val init : stream -> context
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
Expand All @@ -180,7 +180,6 @@ module type Backend = sig
val is_idle : stream -> bool
(** Whether the stream is currently waiting for work. *)

val sexp_of_stream : stream -> Sexp.t
val get_device : ordinal:int -> device
val num_devices : unit -> int

Expand Down Expand Up @@ -298,14 +297,18 @@ module type Lowered_backend = sig
val get_buffer : Tnode.t -> context -> buffer_ptr option

type device
type stream
type stream [@@deriving sexp_of]

val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
val init : stream -> context
val await : stream -> unit
val is_idle : stream -> bool
val all_work : stream -> event
val sexp_of_stream : stream -> Sexplib.Sexp.t

val scheduled_merge_node : stream -> Tnode.t option
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge
buffer right after [await stream]. *)

val num_devices : unit -> int
val suggested_num_streams : device -> int
val get_device : ordinal:int -> device
Expand Down
44 changes: 37 additions & 7 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

let check_merge_buffer ~scheduled_node ~code_node =
let name = function Some tn -> Tnode.debug_name tn | None -> "none" in
match (scheduled_node, code_node) with
| _, None -> ()
| Some actual, Some expected when Tnode.equal actual expected -> ()
| _ ->
raise
@@ Utils.User_error
("Merge buffer mismatch, on stream: " ^ name scheduled_node ^ ", expected by code: "
^ name code_node)

module Multicore_backend (Backend : Backend_types.No_device_backend) : Backend_types.Backend =
struct
module Domain = Domain [@warning "-3"]
Expand Down Expand Up @@ -690,6 +701,11 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
verify from_prior_context;
link_compiled ~merge_buffer prior_context proc
in
let schedule =
Task.prepend schedule ~work:(fun () ->
check_merge_buffer ~scheduled_node:(Option.map !merge_buffer ~f:snd)
~code_node:(expected_merge_node code))
in
{ context; schedule; bindings; name }

let link_batch ~merge_buffer (prior_context : context) (code_batch : code_batch) =
Expand All @@ -711,9 +727,15 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
verify from_prior_context;
procs
in
Array.fold_map procs ~init:prior_context ~f:(fun context -> function
let code_nodes = expected_merge_nodes code_batch in
Array.fold_mapi procs ~init:prior_context ~f:(fun i context -> function
| Some proc ->
let context, bindings, schedule, name = link_compiled ~merge_buffer context proc in
let schedule =
Task.prepend schedule ~work:(fun () ->
check_merge_buffer ~scheduled_node:(Option.map !merge_buffer ~f:snd)
~code_node:code_nodes.(i))
in
(context, Some { context; schedule; bindings; name })
| None -> (context, None))

Expand Down Expand Up @@ -800,6 +822,12 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context.ctx
~from_prior_context:code.from_prior_context [| code.traced_store |];
let ctx, bindings, schedule = link context.ctx code.code in
let schedule =
Task.prepend schedule ~work:(fun () ->
check_merge_buffer
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context.ctx)
~code_node:(expected_merge_node code))
in
{ context = { ctx; expected_merge_node = code.expected_merge_node }; schedule; bindings; name }

let link_batch context code_batch =
Expand All @@ -809,12 +837,14 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
( { ctx; expected_merge_node = context.expected_merge_node },
Array.mapi schedules ~f:(fun i ->
Option.map ~f:(fun schedule ->
{
context = { ctx; expected_merge_node = code_batch.expected_merge_nodes.(i) };
schedule;
bindings;
name;
})) )
let expected_merge_node = code_batch.expected_merge_nodes.(i) in
let schedule =
Task.prepend schedule ~work:(fun () ->
check_merge_buffer
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context.ctx)
~code_node:expected_merge_node)
in
{ context = { ctx; expected_merge_node }; schedule; bindings; name })) )

let init stream = { ctx = init stream; expected_merge_node = None }
let get_ctx_stream context = get_ctx_stream context.ctx
Expand Down
12 changes: 0 additions & 12 deletions arrayjit/lib/backend_utils.ml → arrayjit/lib/c_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -390,15 +390,3 @@ struct
fprintf ppf "@;<0 -2>}@]@.";
params
end
let check_merge_buffer ~merge_buffer ~code_node =
let stream_node = Option.map !merge_buffer ~f:snd in
let name = function Some tn -> Tn.debug_name tn | None -> "none" in
match (stream_node, code_node) with
| _, None -> ()
| Some actual, Some expected when Tn.equal actual expected -> ()
| _ ->
raise
@@ Utils.User_error
("Merge buffer mismatch, on stream: " ^ name stream_node ^ ", expected by code: "
^ name code_node)
4 changes: 2 additions & 2 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
else ctx_arrays
| Some _ -> ctx_arrays))
in
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let for_lowereds = [| lowered |]
let opt_ctx_arrays = opt_ctx_arrays
end)) in
Expand Down Expand Up @@ -185,7 +185,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
else ctx_arrays
| Some _ -> ctx_arrays)))
in
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let for_lowereds = for_lowereds
let opt_ctx_arrays = opt_ctx_arrays
end)) in
Expand Down
6 changes: 4 additions & 2 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ let will_wait_for context event = Cu.Delimited_event.wait context.stream.cu_stre
let sync event = Cu.Delimited_event.synchronize event
let all_work stream = Cu.Delimited_event.record stream.cu_stream

let scheduled_merge_node stream = Option.map ~f:snd stream.merge_buffer

let is_initialized, initialize =
let initialized = ref false in
let init (config : config) : unit =
Expand Down Expand Up @@ -462,7 +464,7 @@ end
let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let for_lowereds = [| lowered |]
end)) in
let idx_params = Indexing.bound_symbols bindings in
Expand All @@ -477,7 +479,7 @@ let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =

let compile_batch ~names bindings lowereds =
let for_lowereds = Array.filter_map ~f:Fn.id lowereds in
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let for_lowereds = for_lowereds
end)) in
let idx_params = Indexing.bound_symbols bindings in
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
assignments
task
backend_types
backend_utils
c_syntax
cc_backend
gcc_backend
cuda_backend
Expand Down
1 change: 0 additions & 1 deletion arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,6 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
in
let%diagn_l_sexp work () : unit =
[%log_result name];
Backend_utils.check_merge_buffer ~merge_buffer ~code_node:code.expected_merge_node;
Indexing.apply run_variadic ();
if Utils.debug_log_from_routines () then (
Utils.log_trace_tree (Stdio.In_channel.read_lines log_file_name);
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/writing_a_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ type mem_properties =
| Constant_from_host (** The array is read directly from the host. *)
```

while the CC and CUDA backends do it implicitly via the input to the `Backend_utils.C_syntax` functor:
while the CC and CUDA backends do it implicitly via the input to the `C_syntax.C_syntax` functor:

```ocaml
module C_syntax (B : sig
Expand Down

0 comments on commit 0f0336b

Please sign in to comment.