From 7db1fac089c4e46d638ca11515f045176e865904 Mon Sep 17 00:00:00 2001
From: Avik Pal <avikpal@mit.edu>
Date: Mon, 15 Jul 2024 18:37:39 -0700
Subject: [PATCH] feat: catch scalar indexing failures early

---
 Project.toml              | 2 ++
 ext/NNlibCUDAExt/utils.jl | 7 +++++++
 src/NNlib.jl              | 1 +
 src/conv.jl               | 6 +++++-
 src/utils.jl              | 8 ++++++++
 5 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 60e2efada..ec0af5b48 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,6 +4,7 @@ version = "0.9.20"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
 Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
@@ -31,6 +32,7 @@ NNlibFFTWExt = "FFTW"
 [compat]
 AMDGPU = "0.9.4"
 Adapt = "3.2, 4"
+ArrayInterface = "7.10"
 Atomix = "0.1"
 CUDA = "4, 5"
 ChainRulesCore = "1.13"
diff --git a/ext/NNlibCUDAExt/utils.jl b/ext/NNlibCUDAExt/utils.jl
index a9eaa8dbe..49739fda2 100644
--- a/ext/NNlibCUDAExt/utils.jl
+++ b/ext/NNlibCUDAExt/utils.jl
@@ -34,3 +34,10 @@ function NNlib.reverse_indices(idx::AnyCuArray{<:Any,N}) where N
     NNlib.reverse_indices!(rev, idx)
     return map(cu, rev)
 end
+
+# CUDA-specific scalar-indexing error hooks: the message hints that cuDNN may not be loaded.
+for conv_op in (:conv!, :∇conv_data!, :∇conv_filter!, :depthwiseconv!, :∇depthwiseconv_data!, :∇depthwiseconv_filter!)
+    msg = "`$(conv_op)` requires all arguments to support fast scalar indexing. You might be missing an `using cuDNN` or `import cuDNN` statement."
+    @eval NNlib.special_scalar_indexing_error(::Val{$(QuoteNode(conv_op))}, ::CUDA.AnyCuArray) =
+        throw(AssertionError($msg))
+end
diff --git a/src/NNlib.jl b/src/NNlib.jl
index 8cf66370f..3244e7a8c 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -3,6 +3,7 @@ module NNlib
 import Atomix
 import ChainRulesCore: rrule
 
+using ArrayInterface: ArrayInterface
 using Base.Broadcast: broadcasted
 using Base.Threads
 using ChainRulesCore
diff --git a/src/conv.jl b/src/conv.jl
index 3fecb9151..28a6a869e 100644
--- a/src/conv.jl
+++ b/src/conv.jl
@@ -192,6 +192,7 @@ for (front_name, backend, signature) in (
                 @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "!  ",
                         "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
             end
+            assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)
 
             x_cs = Iterators.partition(1:size(in1, 4),
                                     channels_in(cdims) ÷ groupcount(cdims))
@@ -233,7 +234,7 @@ for (front_name, backend, signature) in (
                 @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "!  ",
                         "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
             end
-
+            assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)
 
             dx_cs = Iterators.partition(1:size(out, 4),
                                         channels_in(cdims) ÷ groupcount(cdims))
@@ -276,6 +277,7 @@ for (front_name, backend, signature) in (
                 @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "!  ",
                         "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
             end
+            assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)
 
             dw_cs = Iterators.partition(1:size(out, 5),
                                         channels_out(cdims) ÷ groupcount(cdims))
@@ -327,6 +329,8 @@ for (front_name, backend, signature) in (
                 @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "!  ",
                         "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
             end
+            assert_all_fast_scalar_indexing(Val($(Meta.quot(Symbol(front_name, "!")))), out, in1, in2)
+
             $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
         end
     end
diff --git a/src/utils.jl b/src/utils.jl
index 3d23e7383..2be2d0f8e 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -162,3 +162,11 @@ if VERSION < v"1.7.0-DEV.793"
     end
 end
 
+# Throw unless every array in `args` reports fast scalar indexing; `call` names the op.
+function assert_all_fast_scalar_indexing(call::Val{C}, args::AbstractArray...) where {C}
+    all(ArrayInterface.fast_scalar_indexing, args) && return nothing
+    foreach(arr -> special_scalar_indexing_error(call, arr), args)  # hooks may throw first
+    throw(AssertionError("`$(C)` requires all arguments to support fast scalar indexing"))
+end
+
+special_scalar_indexing_error(::Val, ::AbstractArray) = nothing  # generic fallback: no-op