From 387acd3d5968760f16eae100a42533ac6ad422c0 Mon Sep 17 00:00:00 2001 From: Lukasz Stafiniak Date: Mon, 30 Sep 2024 22:38:02 +0200 Subject: [PATCH] cuda backend: Fix: unsafe_cleanup was working with a destroyed context / finalized device --- arrayjit/lib/cuda_backend.cudajit.ml | 31 ++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/arrayjit/lib/cuda_backend.cudajit.ml b/arrayjit/lib/cuda_backend.cudajit.ml index c7d92042..8193e5ff 100644 --- a/arrayjit/lib/cuda_backend.cudajit.ml +++ b/arrayjit/lib/cuda_backend.cudajit.ml @@ -111,10 +111,7 @@ let devices = ref @@ Core.Weak.create 0 (* Unlike [devices] above, [initialized_devices] never forgets its entries. *) let initialized_devices = Hash_set.create (module Int) - -let set_ctx ctx = - let cur_ctx = Cu.Context.get_current () in - if not @@ phys_equal ctx cur_ctx then Cu.Context.set_current ctx +let set_ctx ctx = Cu.Context.set_current ctx (* It's not actually used, but it's required by the [Backend] interface. *) let alloc_buffer ?old_buffer ~size_in_bytes device = @@ -135,6 +132,17 @@ let opt_alloc_merge_buffer ~size_in_bytes phys_dev = phys_dev.copy_merge_buffer <- Cu.Deviceptr.mem_alloc ~size_in_bytes; phys_dev.copy_merge_buffer_capacity <- size_in_bytes) +let cleanup_physical device = + Cu.Context.set_current device.primary_context; + Cu.Context.synchronize (); + Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ()); + Cu.Deviceptr.mem_free device.copy_merge_buffer; + Hashtbl.iter device.cross_device_candidates ~f:(fun ctx_array -> + Cu.Deviceptr.mem_free ctx_array.ptr) + +let finalize_physical device = + if Atomic.compare_and_set device.released false true then cleanup_physical device + let get_device ~(ordinal : int) : physical_device = if num_physical_devices () <= ordinal then invalid_arg [%string "Exec_as_cuda.get_device %{ordinal#Int}: not enough devices"]; @@ -166,6 +174,7 @@ let get_device ~(ordinal : int) : physical_device = owner_device_subordinal = Hashtbl.create (module Tn); } in + Stdlib.Gc.finalise finalize_physical result; Core.Weak.set !devices ordinal (Some result); result) @@ -243,18 +252,14 @@ let init device = let unsafe_cleanup () = let len = Core.Weak.length !devices in (* NOTE: releasing the context should free its resources, there's no need to finalize the - remaining contexts, and [finalize] will not do anything for a [released] physical device. *) + remaining contexts, and [finalize], [finalize_physical] will not do anything for a [released] + physical device. *) for i = 0 to len - 1 do Option.iter (Core.Weak.get !devices i) ~f:(fun device -> - if Atomic.compare_and_set device.released false true then ( - Cu.Context.set_current device.primary_context; - Cu.Context.synchronize (); - Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ()); - Hashtbl.iter device.cross_device_candidates ~f:(fun ctx_array -> - Cu.Deviceptr.mem_free ctx_array.ptr); - Cu.Device.primary_ctx_release device.dev)) + if Atomic.compare_and_set device.released false true then cleanup_physical device) done; - Core.Weak.fill !devices 0 len None + Core.Weak.fill !devices 0 len None; + Stdlib.Gc.compact () let%diagn_l_sexp from_host (ctx : context) tn = match (tn, Map.find ctx.ctx_arrays tn) with