Skip to content

Commit

Permalink
cuda backend: Fix: unsafe_cleanup was working with a destroyed context / finalized device
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Sep 30, 2024
1 parent 58b4a60 commit 387acd3
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,7 @@ let devices = ref @@ Core.Weak.create 0

(* Unlike [devices] above, [initialized_devices] never forgets its entries: it is a regular
   [Hash_set] of [Int]s, whereas [devices] is a [Core.Weak] array.
   NOTE(review): the elements are presumably device ordinals -- confirm against the code that
   adds to this set. *)
let initialized_devices = Hash_set.create (module Int)

(* Switches the current context to [ctx], skipping the [Cu.Context.set_current] call when [ctx] is
   already physically equal to the context reported by [Cu.Context.get_current]. *)
let set_ctx ctx =
  let already_current = phys_equal ctx (Cu.Context.get_current ()) in
  if not already_current then Cu.Context.set_current ctx
(* Makes [ctx] the current context by calling [Cu.Context.set_current] unconditionally, i.e.
   without first checking which context is current. *)
let set_ctx ctx = Cu.Context.set_current ctx

(* It's not actually used, but it's required by the [Backend] interface. *)
let alloc_buffer ?old_buffer ~size_in_bytes device =
Expand All @@ -135,6 +132,17 @@ let opt_alloc_merge_buffer ~size_in_bytes phys_dev =
phys_dev.copy_merge_buffer <- Cu.Deviceptr.mem_alloc ~size_in_bytes;
phys_dev.copy_merge_buffer_capacity <- size_in_bytes)

(* Frees the device memory tracked by the physical device [device].

   In order: makes [device.primary_context] the current context and synchronizes it, invokes the
   [Utils.advance_captured_logs] callback when one is installed, then frees
   [device.copy_merge_buffer] and the [ptr] of every entry of [device.cross_device_candidates].

   This function does not inspect or update [device.released]; callers are expected to guard the
   call themselves. *)
let cleanup_physical device =
  Cu.Context.set_current device.primary_context;
  Cu.Context.synchronize ();
  (match !Utils.advance_captured_logs with
  | Some advance_logs -> advance_logs ()
  | None -> ());
  Cu.Deviceptr.mem_free device.copy_merge_buffer;
  Hashtbl.iter device.cross_device_candidates ~f:(fun candidate ->
      Cu.Deviceptr.mem_free candidate.ptr)

(* Runs [cleanup_physical] on [device] at most once: the cleanup happens only when this call is
   the one that atomically flips [device.released] from [false] to [true]; otherwise it is a
   no-op. *)
let finalize_physical device =
  match Atomic.compare_and_set device.released false true with
  | true -> cleanup_physical device
  | false -> ()

let get_device ~(ordinal : int) : physical_device =
if num_physical_devices () <= ordinal then
invalid_arg [%string "Exec_as_cuda.get_device %{ordinal#Int}: not enough devices"];
Expand Down Expand Up @@ -166,6 +174,7 @@ let get_device ~(ordinal : int) : physical_device =
owner_device_subordinal = Hashtbl.create (module Tn);
}
in
Stdlib.Gc.finalise finalize_physical result;
Core.Weak.set !devices ordinal (Some result);
result)

Expand Down Expand Up @@ -243,18 +252,14 @@ let init device =
let unsafe_cleanup () =
let len = Core.Weak.length !devices in
(* NOTE: releasing the context should free its resources, there's no need to finalize the
remaining contexts, and [finalize] will not do anything for a [released] physical device. *)
remaining contexts, and [finalize], [finalize_physical] will not do anything for a [released]
physical device. *)
for i = 0 to len - 1 do
Option.iter (Core.Weak.get !devices i) ~f:(fun device ->
if Atomic.compare_and_set device.released false true then (
Cu.Context.set_current device.primary_context;
Cu.Context.synchronize ();
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
Hashtbl.iter device.cross_device_candidates ~f:(fun ctx_array ->
Cu.Deviceptr.mem_free ctx_array.ptr);
Cu.Device.primary_ctx_release device.dev))
if Atomic.compare_and_set device.released false true then cleanup_physical device)
done;
Core.Weak.fill !devices 0 len None
Core.Weak.fill !devices 0 len None;
Stdlib.Gc.compact ()

let%diagn_l_sexp from_host (ctx : context) tn =
match (tn, Map.find ctx.ctx_arrays tn) with
Expand Down

0 comments on commit 387acd3

Please sign in to comment.