Skip to content

Commit

Permalink
Cc_backend fixes: setup globals for compile_batch; merge buffer acc…
Browse files Browse the repository at this point in the history
…esses
  • Loading branch information
lukstafi committed Jul 2, 2024
1 parent 9947972 commit 8f3a27f
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions arrayjit/lib/cc_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,11 @@ let compile_main ~traced_store info ppf llc : unit =
^ "v" ^ Int.to_string id.scope_id
in
(v ^ "{=%f}", [ `Value v ])
| Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
let tn = Option.value_exn @@ Tn.find ~id:source_node_id in
let node = get_node tn in
let v = sprintf "@[<2>merge_buffer[%s@;<0 -2>]@]" (array_offset_to_string (idcs, node.dims)) in
("merge_buffer[%u]{=%f}", [ `Accessor (idcs, node.dims); `Value v ])
| Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET"
| Get (tn, idcs) ->
let ident = info.get_ident tn in
Expand All @@ -371,12 +376,18 @@ let compile_main ~traced_store info ppf llc : unit =
in
pp_ll ppf llc
let%track_sexp compile_globals ~get_ident ppf info =
let%track_sexp compile_globals ~get_ident ppf infos =
let open Stdlib.Format in
fprintf ppf {|@[<v 0>#include "stdio.h"@,#include "stdlib.h"@,/* Global declarations. */@,|};
Hash_set.to_list info.used_tensors
fprintf ppf {|@[<v 0>#include <stdio.h>@,#include <stdlib.h>@,/* Global declarations. */@,|};
let infos = Array.filter_opt infos in
let used =
Array.map infos ~f:(fun i -> i.used_tensors)
|> Array.fold ~init:(Hash_set.create (module Tn)) ~f:Hash_set.union
in
let get_node tn = Array.find_map infos ~f:(fun info -> Hashtbl.find info.nodes tn) in
Hash_set.to_list used
|> List.iter ~f:(fun tn ->
let node = Hashtbl.find_exn info.nodes tn in
let node = Option.value_exn @@ get_node tn in
match node.mem with
| Constant_from_host ->
let nd = Option.value_exn @@ Lazy.force tn.Tn.array in
Expand Down Expand Up @@ -404,8 +415,10 @@ let%track_sexp compile_proc ~name info ppf idx_params Low_level.{ traced_store;
if Utils.settings.debug_log_from_routines then [ ("const char* log_file_name", Log_file_name) ] else []
in
let merge_param =
Option.(to_list @@ map merge_node ~f:(fun tn ->
("const " ^ Ops.cuda_typ_of_prec tn.prec ^ " *merge_buffer", Merge_buffer)))
Option.(
to_list
@@ map merge_node ~f:(fun tn ->
("const " ^ Ops.cuda_typ_of_prec tn.prec ^ " *merge_buffer", Merge_buffer)))
in
let params = log_file @ merge_param @ idx_params @ params in
fprintf ppf "@[<v 2>@[<hv 4>void %s(@,@[<hov 0>%a@]@;<0 -4>)@] {@ " name
Expand Down Expand Up @@ -475,7 +488,7 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
prepare_nodes info ctx_nodes lowered;
let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in
let base_name = Filename_base.chop_extension pp_file.f_name in
compile_globals ~get_ident:info.get_ident pp_file.ppf info;
compile_globals ~get_ident:info.get_ident pp_file.ppf [| Some info |];
let params = compile_proc ~name info pp_file.ppf idx_params lowered in
pp_file.finalize ();
let log_fname = base_name ^ ".log" in
Expand Down Expand Up @@ -528,6 +541,7 @@ let%track_sexp compile_batch ~names ~opt_ctx_arrays bindings (lowereds : Low_lev
@@ common_prefix (Array.to_list @@ Array.concat_map ~f:Option.to_array names))
in
let pp_file = Utils.pp_file ~base_name ~extension:".c" in
compile_globals ~get_ident pp_file.ppf infos;
let params =
Array.mapi lowereds ~f:(fun i lowered ->
Option.map2 names.(i) infos.(i) ~f:(fun name info ->
Expand Down

0 comments on commit 8f3a27f

Please sign in to comment.