Skip to content

Commit

Permalink
Bump version number, bump cudajit to 0.4, fix cuda compilation
Browse files — browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jul 5, 2024
1 parent cc790ae commit 3a4c24d
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 21 deletions.
11 changes: 8 additions & 3 deletions CHANGES.md
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,16 +4,20 @@

- TODO: API improvements for mixed precision computations.
- TODO(#262): "term punning" for `%cd`.
- TODO: CUDA streaming multiprocessor parallelism via streams <-> virtual devices.

## [0.4.0] -- 2024-06-30
### Fixed

- TODO: Proper implementation of half precision. Requires OCaml 5.2.

## [0.4.0] -- 2024-07-07

### Added

- A new backend "cc": C based on a configurable C compiler command, defaulting to `cc`.
- TODO: Merge buffers representational abstraction (one per virtual device):
- Merge buffers representational abstraction (one per virtual device):
- backends just need to support device-to-device transfers,
- merging gets implemented in "user space".
- TODO: CUDA streaming multiprocessor parallelism via streams <-> virtual devices.

### Changed

Expand All @@ -26,6 +30,7 @@
- Further refactoring of the `Backends` API:
- split the `device` type into virtual `device` and `physical_device`,
- removed the direct support for `merge`, instead relying on merge buffers.
- Updated to cudajit 0.4.

## [0.3.3] -- 2024-04-24

Expand Down
4 changes: 2 additions & 2 deletions arrayjit.opam
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "0.3.3"
version: "0.4.0"
synopsis:
"An array language compiler with multiple backends (CPU, Cuda), staged compilation"
description:
Expand Down Expand Up @@ -31,7 +31,7 @@ depends: [
"odoc" {with-doc}
]
depopts: [
"cudajit"
"cudajit" {>= "0.4.0"}
"gccjit" {>= "0.3.2"}
]
build: [
Expand Down
21 changes: 9 additions & 12 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,10 @@ let%debug_sexp prepare_node traced_store info tn =
(* let tn = Low_level.get_node traced_store v in *)
(* TODO: We will need tn to perform more refined optimizations. *)
let dims = Lazy.force tn.dims in
let size_in_elems = Array.fold ~init:1 ~f:( * ) dims in
let prec = tn.prec in
let size_in_bytes = size_in_elems * Ops.prec_in_bytes prec in
let is_on_host = Tn.is_hosted_force tn 31 in
let is_materialized = Tn.is_hosted_force tn 32 in
assert (Bool.(Option.is_some (Lazy.force tn.array) = is_on_host));
let num_typ = Ops.cuda_typ_of_prec prec in
let num_typ = Ops.cuda_typ_of_prec tn.prec in
let mem = if not is_materialized then Local_only else Global in
let global = if is_local_only mem then None else Some (Tn.name tn) in
let local = Option.some_if (is_local_only mem) @@ Tn.name tn ^ "_local" in
Expand All @@ -320,15 +317,15 @@ let%debug_sexp prepare_node traced_store info tn =
"mem",
(backend_info : Sexp.t),
"prec",
(prec : Ops.prec),
(tn.prec : Ops.prec),
"on-host",
(is_on_host : bool),
"is-global",
(Option.is_some global : bool)];
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
let zero_initialized = (Hashtbl.find_exn traced_store tn).Low_level.zero_initialized in
{ tn; local; mem; dims; size_in_bytes; size_in_elems; num_typ; global; zero_initialized })
{ tn; local; mem; dims; num_typ; global; zero_initialized })

let compile_main traced_store info ppf llc : unit =
let open Stdlib.Format in
Expand Down Expand Up @@ -584,11 +581,11 @@ let%track_sexp compile_proc ~name ~get_ident ppf idx_params
match node.mem with
| Local_only ->
Option.iter node.local ~f:(fun t_name ->
fprintf ppf "%s %s[%d]%s;@," node.num_typ t_name node.size_in_elems
fprintf ppf "%s %s[%d]%s;@," node.num_typ t_name (Tn.num_elems tn)
(if (Hashtbl.find_exn traced_store tn).zero_initialized then " = {0}" else ""))
| Global when node.zero_initialized ->
Option.iter node.global ~f:(fun t_name ->
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " t_name node.size_in_bytes)
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " t_name (Tn.size_in_bytes tn))
| _ -> ());
fprintf ppf "/* Main logic. */@,";
compile_main traced_store info ppf llc;
Expand Down Expand Up @@ -630,8 +627,8 @@ let alloc_buffer ?old_buffer ~size_in_bytes () =
| Some (old_ptr, old_size) when size_in_bytes <= old_size -> old_ptr
| Some (old_ptr, _old_size) ->
Cudajit.mem_free old_ptr;
Cudajit.mem_alloc ~byte_size:size_in_bytes
| None -> Cudajit.mem_alloc ~byte_size:size_in_bytes
Cudajit.mem_alloc ~size_in_bytes
| None -> Cudajit.mem_alloc ~size_in_bytes

let%diagn_sexp link_proc (old_context : context) ~name info ptx =
let module Cu = Cudajit in
Expand All @@ -642,7 +639,7 @@ let%diagn_sexp link_proc (old_context : context) ~name info ptx =
Option.value_map ~default:globals node.global ~f:(fun _name ->
if Utils.settings.with_debug_level > 0 then [%log "mem_alloc", _name];
set_ctx ctx;
let ptr () = Cu.mem_alloc ~byte_size:node.size_in_bytes in
let ptr () = Cu.mem_alloc ~size_in_bytes:(Tn.size_in_bytes node.tn) in
Map.update globals key ~f:(fun old -> Option.value_or_thunk old ~default:ptr)))
in
let run_module = Cu.module_load_data_ex ptx [] in
Expand Down Expand Up @@ -736,7 +733,7 @@ let link old_context (code : code) =
if Hash_set.mem code.info.used_tensors key then
let node = Map.find_exn all_arrays key in
if node.zero_initialized then
Cu.memset_d8 ptr Unsigned.UChar.zero ~length:node.size_in_bytes);
Cu.memset_d8 ptr Unsigned.UChar.zero ~length:(Tn.size_in_bytes key));
[%log "launching the kernel"];
(* if Utils.settings.debug_log_from_routines then Cu.ctx_set_limit CU_LIMIT_PRINTF_FIFO_SIZE
4096; *)
Expand Down
5 changes: 3 additions & 2 deletions dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

(name ocannl)

(version 0.3.3)
(version 0.4.0)

(generate_opam_files true)

Expand Down Expand Up @@ -67,7 +67,8 @@
(ppx_minidebug
(>= 1.5)))
(depopts
cudajit
(cudajit
(>= 0.4.0))
(gccjit
(>= 0.3.2)))
(tags
Expand Down
2 changes: 1 addition & 1 deletion neural_nets_lib.opam
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "0.3.3"
version: "0.4.0"
synopsis:
"A from-scratch Deep Learning framework with an optimizing compiler, shape inference, concise syntax"
description:
Expand Down
2 changes: 1 addition & 1 deletion ocannl_npy.opam
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is generated by dune, edit dune-project instead
opam-version: "2.0"
version: "0.3.3"
version: "0.4.0"
synopsis: "Numpy file format support for ocaml"
maintainer: ["Lukasz Stafiniak <[email protected]>"]
authors: ["Laurent Mazare"]
Expand Down

0 comments on commit 3a4c24d

Please sign in to comment.