From 8f3a27f2a3278b349f05bfae10a2903bfec742fc Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Tue, 2 Jul 2024 14:59:57 +0200 Subject: [PATCH] Cc_backend fixes: setup globals for `compile_batch`; merge buffer accesses --- arrayjit/lib/cc_backend.ml | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/arrayjit/lib/cc_backend.ml b/arrayjit/lib/cc_backend.ml index fdccae3e..0fd56538 100644 --- a/arrayjit/lib/cc_backend.ml +++ b/arrayjit/lib/cc_backend.ml @@ -348,6 +348,11 @@ let compile_main ~traced_store info ppf llc : unit = ^ "v" ^ Int.to_string id.scope_id in (v ^ "{=%f}", [ `Value v ]) + | Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) -> + let tn = Option.value_exn @@ Tn.find ~id:source_node_id in + let node = get_node tn in + let v = sprintf "@[<2>merge_buffer[%s@;<0 -2>]@]" (array_offset_to_string (idcs, node.dims)) in + ("merge_buffer[%u]{=%f}", [ `Accessor (idcs, node.dims); `Value v ]) | Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET" | Get (tn, idcs) -> let ident = info.get_ident tn in @@ -371,12 +376,18 @@ let compile_main ~traced_store info ppf llc : unit = in pp_ll ppf llc -let%track_sexp compile_globals ~get_ident ppf info = +let%track_sexp compile_globals ~get_ident ppf infos = let open Stdlib.Format in - fprintf ppf {|@[#include "stdio.h"@,#include "stdlib.h"@,/* Global declarations. */@,|}; - Hash_set.to_list info.used_tensors + fprintf ppf {|@[#include <stdio.h>@,#include <stdlib.h>@,/* Global declarations. 
*/@,|}; + let infos = Array.filter_opt infos in + let used = + Array.map infos ~f:(fun i -> i.used_tensors) + |> Array.fold ~init:(Hash_set.create (module Tn)) ~f:Hash_set.union + in + let get_node tn = Array.find_map infos ~f:(fun info -> Hashtbl.find info.nodes tn) in + Hash_set.to_list used |> List.iter ~f:(fun tn -> - let node = Hashtbl.find_exn info.nodes tn in + let node = Option.value_exn @@ get_node tn in match node.mem with | Constant_from_host -> let nd = Option.value_exn @@ Lazy.force tn.Tn.array in @@ -404,8 +415,10 @@ let%track_sexp compile_proc ~name info ppf idx_params Low_level.{ traced_store; if Utils.settings.debug_log_from_routines then [ ("const char* log_file_name", Log_file_name) ] else [] in let merge_param = - Option.(to_list @@ map merge_node ~f:(fun tn -> - ("const " ^ Ops.cuda_typ_of_prec tn.prec ^ " *merge_buffer", Merge_buffer))) + Option.( + to_list + @@ map merge_node ~f:(fun tn -> + ("const " ^ Ops.cuda_typ_of_prec tn.prec ^ " *merge_buffer", Merge_buffer))) in let params = log_file @ merge_param @ idx_params @ params in fprintf ppf "@[@[void %s(@,@[%a@]@;<0 -4>)@] {@ " name @@ -475,7 +488,7 @@ let%track_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_ prepare_nodes info ctx_nodes lowered; let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in let base_name = Filename_base.chop_extension pp_file.f_name in - compile_globals ~get_ident:info.get_ident pp_file.ppf info; + compile_globals ~get_ident:info.get_ident pp_file.ppf [| Some info |]; let params = compile_proc ~name info pp_file.ppf idx_params lowered in pp_file.finalize (); let log_fname = base_name ^ ".log" in @@ -528,6 +541,7 @@ let%track_sexp compile_batch ~names ~opt_ctx_arrays bindings (lowereds : Low_lev @@ common_prefix (Array.to_list @@ Array.concat_map ~f:Option.to_array names)) in let pp_file = Utils.pp_file ~base_name ~extension:".c" in + compile_globals ~get_ident pp_file.ppf infos; let params = Array.mapi lowereds ~f:(fun i lowered -> 
Option.map2 names.(i) infos.(i) ~f:(fun name info ->