Lots of new functions on NDArrays #78
Conversation
Thanks for looking into this. I like the idea overall; I have a few comments on the implementation.
templates/NDArray.kt (Outdated)

```kotlin
 * @param rtol the maximum relative (i.e. fractional) difference to allow between elements
 * @param atol the maximum absolute difference to allow between elements
 */
fun <R> allClose(other: NDArray<R>, rtol: Double = 1e-05, atol: Double = 1e-08): Boolean {
```
---
Adding `allClose` directly to the `NDArray` interface means it will be available to all `NDArray`s, including e.g. `NDArray<String>`. One of the goals of implementing it as an extension function on `Matrix` was that it would only appear when a user had a supported numeric matrix type. For example, `Matrix<Double>.allClose` would compile but `Matrix<Int>.allClose` wouldn't. With the implementation in this PR, `ndArrayOf("a","b").allClose(ndArrayOf("a","b"))` compiles and throws an exception at runtime, so we lose type safety.
I agree with removing the `Matrix` version of `allClose` (since `NDArray` is a supertype), but I'd recommend keeping the implementation as a set of extension functions, ideally implemented only for floating-point types (since, as you pointed out in the issue, these functions don't make sense for integral types). If the difficulty in doing so is the currently complicated codegen implementation, I can take over from here and work that in if you'd like.
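For illustration, here's a minimal self-contained sketch of the extension-function approach being proposed, using `List<Double>` as a stand-in for the real koma types (the function name and tolerance rule here are assumptions based on the snippet above, not koma's actual API):

```kotlin
import kotlin.math.abs

// Hypothetical sketch: allClose exposed as an extension only on
// Double-element containers, so calling it on a non-numeric element
// type is a compile-time error rather than a runtime exception.
fun List<Double>.allCloseSketch(
    other: List<Double>,
    rtol: Double = 1e-05,
    atol: Double = 1e-08
): Boolean {
    if (size != other.size) return false
    // Element-wise tolerance rule matching the rtol/atol docstring:
    // |a - b| <= atol + rtol * |b|
    return indices.all { i -> abs(this[i] - other[i]) <= atol + rtol * abs(other[i]) }
}
```

With this shape, `listOf(1.0, 2.0).allCloseSketch(listOf(1.0, 2.0))` compiles, while the analogous call on a `List<String>` doesn't resolve at all.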
---
How about defining it as an extension function for type `<out Number>`? That way you could still use it to compare between different numerical types, like I do in the test case.
---
Sounds good to me, as long as we can be careful to implement it so we don't box every element.
---
Unless we can guarantee that `getDouble()` will always work, I think boxing is unavoidable. What about modifying the generic version to cast between numeric types, just like the specialized ones do? For example, in `DefaultGenericNDArray.getDouble()` change

```kotlin
if (ele is Double)
    return ele
```

to

```kotlin
if (ele is Number)
    return ele.toDouble()
```

Likewise, require that any future backends also support casting. Then we can just call `getDouble()` and know it will be implemented in whatever way is most efficient given the storage of the particular array.
---
I was thinking we could dispatch in a `when` block, similar to the `NDArray.invoke` block, that checks the reified type and calls a specific accessor for each primitive, i.e. for an `NDArray<Int>` we could call `.getInt(..).toDouble()`. I haven't tried it to see if it would break down, though.
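A rough self-contained sketch of that dispatch idea, with `IntArray`/`DoubleArray` standing in for backend storage (the function and accessor names here are hypothetical, not koma's real API):

```kotlin
// Hypothetical sketch of reified-type dispatch: the when block inspects
// T::class and picks a primitive-specific access path, avoiding boxing
// for the types it knows about.
inline fun <reified T : Number> getAsDouble(storage: Any, i: Int): Double =
    when (T::class) {
        Int::class -> (storage as IntArray)[i].toDouble()
        Double::class -> (storage as DoubleArray)[i]
        else -> throw UnsupportedOperationException("no accessor for ${T::class}")
    }
```

For example, `getAsDouble<Int>(intArrayOf(1, 2, 3), 1)` returns `2.0`.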
```kotlin
if (ele is Number)
    return ele.toDouble()
```

I think this route would work as well and might be cleaner. As long as we only expose the extension functions for NDArrays containing primitive types and we don't box in the specialized implementations, no objections here.
---
Ok, I'll change it to work like that.
---
Another option would be to not even have the extension function, and just put the whole implementation into the `allclose()` function.
---
Inside the interface, you only get an opportunity to reify `<R>`, so you're still doing needlessly expensive number-conversion logic on the contents of the NDArray.
---
I don't see why? If `getDouble()` is declared to return a `Double`, on the JVM it will be implemented with a method whose return type is a primitive `double`. It's up to each implementation of the interface to provide a value of that type. In the case of an array of primitive doubles, no conversion or boxing should be needed.
---
I think I misread much of the above. You guys were talking about boxing and I misread it as being about type conversion. My comment was trying to avoid doing O(n) primitive casts from whatever the true underlying data type is to double, which seems like it might be important given:

```kotlin
>>> Long.MAX_VALUE.toDouble() == Long.MAX_VALUE.toDouble() + 1
true
```
templates/NDArray.kt (Outdated)

```kotlin
if (!(shape().toIntArray() contentEquals other.shape().toIntArray()))
    return false
for (i in 0 until size) {
    val a = getDouble(i)
```
---
`getDouble` is only guaranteed to be implemented by `NDArray<Double>` backends, so this function is only guaranteed to work with `Double` currently. The current `Default` implementations for numerical types do cast between primitives, but the `DefaultGenericNDArray` implementation doesn't, and it's unclear what future NDArray implementations like #70 will do. If one ends up needing to call e.g. `getDouble` in a context where the type is only known to be a generic `NDArray<T>`, a try/catch is probably necessary to make sure the user doesn't get an undiagnosable error (for a `DefaultGenericNDArray` this method currently gives them `Exception in thread "main" java.lang.IllegalStateException: Double methods not implemented for generic NDArray`, which is difficult for a user to decipher).
Note that the goal of the `get$PRIMITIVE` methods is to have a non-boxing accessor that should only be called by code that knows the primitive type is correct. In particular, they were designed to be called by extension functions with a known primitive type, such as `fun NDArray<Double>.foo(...) = this.getDouble()`, thus retaining type safety for the user. Ideally these `get$PRIMITIVE` methods would be marked `internal` so users wouldn't be able to see them at all, but there are some issues with module access that haven't been resolved yet (it looks like they are resolved in the new mpp, though, so perhaps #77 will allow us to).
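A simplified, self-contained sketch of this accessor pattern (all names here are stand-ins, not koma's real declarations):

```kotlin
// Hypothetical sketch: the interface exposes a primitive accessor that
// returns an unboxed double on the JVM, and a type-restricted extension
// function is the only entry point users are meant to call.
interface SimpleNDArray<T> {
    val size: Int
    fun getDouble(i: Int): Double   // non-boxing accessor
}

class DenseDoubleArray(private val data: DoubleArray) : SimpleNDArray<Double> {
    override val size get() = data.size
    override fun getDouble(i: Int) = data[i]   // no boxing or conversion needed
}

// Only resolves when the element type is statically known to be Double.
fun SimpleNDArray<Double>.sumSketch(): Double {
    var total = 0.0
    for (i in 0 until size) total += getDouble(i)
    return total
}
```

Here `sumSketch()` compiles against a `SimpleNDArray<Double>` but not against a `SimpleNDArray<String>`, which is the type-safety property being described.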
templates/NDArray.kt (Outdated)

```kotlin
 * Find the linear index of the minimum element in this array.
 * If the array contains non-comparable values, this throws an exception.
 */
fun argMin(): Int
```
---
These could also be implemented as extension functions with a `T : Comparable` upper bound on their applicability, that is, something like `fun <T : Comparable<T>> NDArray<T>.argMin()`. Thus if the user has an `NDArray` holding something that is comparable, `argMin` is a valid method; otherwise it's a compile-time error.

For `Matrix` these were defined on the interface itself because matrices are guaranteed to be numerical, so they applied to all `Matrix` implementations. For NDArray, there's a wide class of comparable types that aren't numerical that these would apply to, and an even wider class of objects that aren't ordered at all, in which case this function doesn't apply.
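A minimal sketch of such a bounded extension, with `List<T>` standing in for `NDArray<T>` (the name `argMinSketch` is hypothetical):

```kotlin
// Hypothetical sketch: argMin as an extension with a Comparable upper
// bound, so it only resolves for element types that have an ordering.
fun <T : Comparable<T>> List<T>.argMinSketch(): Int {
    require(isNotEmpty()) { "argMin of an empty array is undefined" }
    var best = 0
    for (i in 1 until size) if (this[i] < this[best]) best = i
    return best
}
```

`listOf(3, 1, 2).argMinSketch()` returns `1`, and it works equally for any comparable type (e.g. strings), while calling it on a list of some non-comparable class fails to compile.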
---
I assumed you made them methods rather than extensions so different backends could provide their own implementations? I can change them to extensions if you prefer, but that will rule out the possibility of a different backend having an optimized implementation.
---
How about a hybrid approach, similar to how we implement the indexed access operators? The user would see an extension function `NDArray<Double>.argMin`, which would in turn delegate the action to a method on the instance with a different name, like `argMinDouble` (for index accessors, the analog is an `NDArray<Double>.get(i)` extension function that forwards to `getDouble(i)`).
The disadvantage is that right now the should-not-be-used-by-the-end-user-directly type-specific operations are still exposed to the user. There are a few potential ways they could be hidden, though:

- Marking the delegated methods as `protected`. This is unavailable since we are using an interface to implement `NDArray` instead of an abstract base (which would cause other issues).
- Marking the delegated methods as `internal`. This causes problems because the platform-specific implementations are in a different project, which doesn't see internal declarations. However, this might be changing in the new multiplatform gradle plugin.
- Hiding them all on some internal class that the user likely won't dive into and that could be overridden, e.g. declaring a `NumericalOperations` class that has all the platform-specific stuff on it. I've considered doing this a few times to solve other problems in the past.

I think for now we could just do the straightforward hybrid approach with publicly delegated functions, like the array accessors do now, and then convert all of them together in the future if we ever decide to.
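A self-contained sketch of that hybrid approach (all type and method names here are hypothetical stand-ins for koma's):

```kotlin
// Hypothetical sketch: the user-facing, type-safe extension delegates
// to a differently-named virtual method that backends can override.
interface HybridNDArray<T> {
    fun argMinDouble(): Int   // backend-overridable; not meant for direct use
}

class DoubleBackedArray(private val data: DoubleArray) : HybridNDArray<Double> {
    override fun argMinDouble(): Int {
        var best = 0
        for (i in 1 until data.size) if (data[i] < data[best]) best = i
        return best
    }
}

// What the user actually sees: only available on HybridNDArray<Double>.
fun HybridNDArray<Double>.argMin(): Int = argMinDouble()
```

A GPU or BLAS backend could override `argMinDouble` with its own implementation while the user-visible `argMin` extension stays the same.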
---
Another perspective on this: BLAS level 1 operations like argmin are almost certainly memory-bound and don't benefit from any sort of computational optimization a backend could bring. So I don't think there will be any immediate advantage to delegating these to the backends to do the work (my own benchmarks on the cblas k/native backend showed that implementing things like matrix addition directly in Kotlin made no difference over making the call into openblas). So maybe the right answer is that for O(N) operations like finding the minimum, we just go ahead and implement them once in the extension function.

The only concern there is that, for future GPU backends, such an implementation of `argmin` will require recalling all the memory from the GPU to the CPU to do the work in Kotlin. However, that's a bridge we can cross when we get there.
---
> BLAS level 1 operations like argmin are almost certainly memory bound and don't benefit from any sort of computational optimization a backend could bring.

Actually, there's a lot of room for optimization. If the reduction is over the whole array, or the last axis, then memory access is fast: you're accessing everything in order, so the cache makes it efficient. Otherwise you break the array up into tiles and compute several output values at once; that way you still only have to load each cache line once. Then you vectorize it and optionally multithread it.

Also, if accessing an individual element from Kotlin involves a native method call, putting the loop over elements in Kotlin rather than native code will be much slower.
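To illustrate the sequential-access point, here's a rough sketch of a per-column argmin over a row-major matrix computed in one ordered pass over memory, rather than scanning each column separately with a large stride (the function name and layout convention are assumptions):

```kotlin
// Hypothetical sketch: column-wise argmin over a row-major rows x cols
// matrix. The row-major traversal touches memory sequentially, so each
// cache line is loaded once while all column outputs advance together.
fun argMinPerColumn(data: DoubleArray, rows: Int, cols: Int): IntArray {
    val best = IntArray(cols)                 // best row index per column, starts at row 0
    for (r in 1 until rows) {
        for (c in 0 until cols) {
            if (data[r * cols + c] < data[best[c] * cols + c]) best[c] = r
        }
    }
    return best
}
```

For the 2x2 matrix [[3, 1], [0, 5]] this yields row indices [1, 0]: column 0's minimum is in row 1, column 1's in row 0.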
---
Not sure I'm following, but that may be due to my being too brief in reply again and therefore being unclear. I'm good with having a delegating virtual method on the NDArray instance that the extension function calls to leave our options open, so let's go with that.
---
Ok, I think I just misunderstood. So I'll do what you described in #78 (comment). And for the moment I won't worry about trying to hide the internal method.
---
Just to clarify: the base NDArray class will define all possible reductions for all types (`argMinDouble()`, `argMinFloat()`, `argMinInt()`, etc., and the same for argMax, min, and max)? Then there will be an extension function `argMin()` on `NDArray<Double>` that calls `argMinDouble()`. Is that right?

It's not obvious to me what advantage that has over just a single method for each one. Perhaps I'm still misunderstanding.
---
Sorry, I'm bad at explaining myself. Let me see if I can shed some light on the idea...

The goal of using extension functions instead of just having members on NDArray is type safety. Suppose we define:

```kotlin
fun <T : Comparable<T>> NDArray<T>.argMin(): Int {
    ...
}
```

then we would have

```kotlin
ndArrayOf(1, 2, 3).argMin()      // Okay
ndArrayOf(1.0, 2.0).argMin()     // Okay
ndArrayOf(Image("a.jpg"), Image("b.jpg")).argMin() // Compile-time error, not Comparable
```
Thus trying to use `argMin` on something non-comparable becomes an error at compile time. If `argMin` were defined on the instance itself (as it currently is in this PR), then `(foo as NDArray<Image>).argMin()` doesn't fail until runtime. This is noted in the current docstring: "If the array contains non-comparable values, this throws an exception." It would be better to get a compile-time error than a runtime exception, hence the recommendation to use extension functions.
So we could just end it there, and have the extension function contain a full implementation of `argMin`. However, as you noted, there might be use cases where the backend wants to take control of that implementation. In that case we'll need to delegate the extension function implementation to virtual methods on `NDArray` that can be overridden by the implementation. Then we'd have:

```kotlin
fun <T : Comparable<T>> NDArray<T>.argMin(): Int {
    return this.argMinHiddenFromUser()
}
```
The virtual method should be "hidden" from the user one way or another, with the extension functions being what they see and call (the hiding part is TBD). Doing it this way, with a user-facing extension function backed by a "hidden" virtual method on `NDArray`, allows us to present the user with type safety and still delegate the implementation to the backends.

I don't think we need virtual methods `argMin$PRIMITIVE` for each primitive. We probably just need a single `argMinHiddenFromUser()` virtual method on `NDArray`, which is overridden by the backends to do the operation in whatever way is most efficient for their particular data configuration, similar to what you currently have for `fun argMin(): Int` (except renamed to something hidden, so it doesn't shadow the extension function). The duplicate-methods-for-each-primitive approach is only needed when we would otherwise box every individual element. For example, with the array accessors we did need separate methods for each primitive, because otherwise array access would be extremely slow (it was actually re-implemented this way after a user reported that it was too slow to use). The decision as to whether we need one-per-primitive overloads of a method should hinge on whether we're incurring a box per element. For aggregate operations like finding the argmin over the entire array, there is no need. If there are methods that pass elements in or out one by one, then we may need to look at what's getting boxed to decide.
Does that help clarify at all?
koma-core-api/common/src/koma/internal/default/generated/ndarray/DefaultLongNDArray.kt
Ok, I think I've made all the changes.

Looks good now, thanks!
This implements #74.

For the moment, `mean()` and `sum()` are only implemented as functions, not methods. So you can write `sum(x)`, but not `x.sum()`. Once the codegen issues are figured out, that will be easy to change.

For all the reduction functions (`min()`, `argMin()`, `sum()`, etc.) I've included two versions: one that works over the whole array, and a second that does the reduction over a single axis. I modelled the API on NumPy.
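As an illustration of the two flavors, here's a hedged sketch for a row-major 2-D array, with the axis convention borrowed from NumPy (none of these are the actual koma signatures):

```kotlin
// Hypothetical sketch of the two reduction variants: one over the whole
// array, one along a single axis, for a row-major rows x cols matrix.
fun sumAll(data: DoubleArray): Double = data.sum()

// axis = 0 collapses the rows (one result per column);
// axis = 1 collapses the columns (one result per row).
fun sumAxis(data: DoubleArray, rows: Int, cols: Int, axis: Int): DoubleArray =
    when (axis) {
        0 -> DoubleArray(cols) { c -> (0 until rows).sumOf { r -> data[r * cols + c] } }
        1 -> DoubleArray(rows) { r -> (0 until cols).sumOf { c -> data[r * cols + c] } }
        else -> throw IllegalArgumentException("axis must be 0 or 1")
    }
```

For the 2x2 matrix [[1, 2], [3, 4]], `sumAll` gives 10.0, `sumAxis(..., axis = 0)` gives the column sums [4.0, 6.0], and `sumAxis(..., axis = 1)` gives the row sums [3.0, 7.0].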