Currently, our approach to dealing with wrapped arrays is the following (considering the case of `mul!`):

```julia
function mul!(C::TracedRArray, B::AnyTracedRArray, A::AnyTracedRArray)
    B = materialize_traced_array(B)
    A = materialize_traced_array(A)
    Ops.dot_general(....)
    return C
end
```
This ensures that the code works as long as a wrapper type implements `materialize_traced_array`. But it is not the most efficient solution, which is easy to see in the simple case of a `Diagonal` wrapper (see @mofeing's comment #369 (comment) for an implementation of `Diagonal` using `dot_general`).
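For illustration, a structure-aware specialization could avoid materializing the `Diagonal` entirely: multiplying by a diagonal matrix is just a row scaling, which lowers to elementwise ops instead of a full `dot_general`. This is only a sketch; the method names beyond Base/LinearAlgebra (`TracedRArray`, `AnyTracedRArray`, `materialize_traced_array`) are taken from the issue, and the assumption is that broadcasting on traced arrays lowers to elementwise StableHLO ops.

```julia
# Hypothetical specialization, sketched under the assumptions above:
# D * B for a Diagonal wrapper scales each row of B by the corresponding
# diagonal entry, touching only O(n) data instead of materializing the
# full n x n matrix and emitting a dot_general.
function mul!(C::TracedRArray, D::Diagonal{<:Any,<:AnyTracedRArray}, B::AnyTracedRArray)
    d = materialize_traced_array(D.diag)  # only the diagonal vector
    C .= d .* B                           # row scaling via broadcast
    return C
end
```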
Current list of slow fallbacks:

- `mul!`
- `diag`
- `diagm`
This is more of an $N \times M$ problem: it's the combination of array types with methods. For example, a `PermutedDimsArray` in a matrix multiplication, or in a more general einsum, can also have a more efficient implementation without materializing it. But there might be some array types or methods where using the default is fine.
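As a concrete sketch of the `PermutedDimsArray` case: a transpose can be folded into the contraction dimensions of `dot_general` instead of being materialized as a copy. The keyword name `contracting_dimensions` and the way the result is written back into `C` are assumptions here, not confirmed Reactant API:

```julia
# Sketch under the assumptions above: for a transposed left operand,
# contract over dim 1 of the un-permuted parent instead of emitting an
# explicit transpose of it. PermutedDimsArray{T,2,(2,1)} is a 2-D array
# whose dims are swapped relative to parent(A).
function mul!(C::TracedRArray, A::PermutedDimsArray{<:Any,2,(2,1)}, B::AnyTracedRArray)
    # Contracting over dim 1 of parent(A) is equivalent to contracting
    # over dim 2 of A, so no copy of the parent is needed.
    res = Ops.dot_general(parent(A), materialize_traced_array(B);
                          contracting_dimensions=([1], [1]))  # assumed kwarg
    C.mlir_data = res.mlir_data  # assumed mechanism for writing into C
    return C
end
```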
Also, note that thanks to the high-level optimization passes we added in Enzyme-JAX, some of the default implementations could match the performance of hand-written code. But I would prefer to directly emit optimal code.