Performance of Wrapped & Structured Arrays #373

Open
3 tasks
avik-pal opened this issue Dec 13, 2024 · 1 comment

@avik-pal
Collaborator

Problem Description

Currently, we handle wrapped arrays as follows (taking mul! as an example):

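function mul!(C::TracedRArray, B::AnyTracedRArray, A::AnyTracedRArray)
    # Turn each possibly wrapped operand (e.g. a Diagonal) into a plain dense traced array
    B = materialize_traced_array(B)
    A = materialize_traced_array(A)
    # Dense contraction of the materialized operands (arguments elided)
    Ops.dot_general(....)
    return C
end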

This ensures the code works for any wrapper type that implements materialize_traced_array, but it is not the most efficient solution. The simple case of a Diagonal wrapper makes this easy to see: materializing it builds a dense matrix that is zero everywhere off the diagonal, only to feed it into a full dot_general (see @mofeing's comment in #369 for an implementation of Diagonal using dot_general).
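To make the cost concrete outside of tracing, here is a plain-Julia sketch with no Reactant types involved; `mul_materialized!` and `mul_diagonal!` are hypothetical names used only for illustration. The first mirrors the materialize-then-multiply fallback; the second uses the fact that `Diagonal(d) * A` only scales row `i` of `A` by `d[i]`. That a traced version of the second method would lower to an elementwise multiply instead of a `dot_general` is an assumption, not something the issue states.

```julia
using LinearAlgebra

# Hypothetical fallback mirroring the issue: materialize every operand to a
# dense Matrix, then run a full matrix product. For an n×n Diagonal this
# allocates an n×n matrix that is zero off the diagonal and does O(n^2 m) work.
function mul_materialized!(C::Matrix, B::AbstractMatrix, A::AbstractMatrix)
    mul!(C, Matrix(B), Matrix(A))
    return C
end

# Hypothetical wrapper-aware method: (Diagonal(d) * A)[i, j] == d[i] * A[i, j],
# so broadcasting the stored diagonal down the rows of A gives the same
# result with O(n m) work and no dense n×n intermediate.
function mul_diagonal!(C::Matrix, B::Diagonal, A::AbstractMatrix)
    C .= B.diag .* A
    return C
end

d = [1.0, 2.0, 3.0]
A = reshape(collect(1.0:6.0), 3, 2)
C_dense = mul_materialized!(zeros(3, 2), Diagonal(d), A)
C_diag = mul_diagonal!(zeros(3, 2), Diagonal(d), A)
```

Both paths compute the same product; only the specialized one avoids building the dense diagonal matrix.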

Current list of slow fallbacks

  • mul!
  • diag
  • diagm
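The same idea applies to diag. The plain-Julia sketch below (`fast_diag` is a hypothetical name) keeps a materializing fallback for arbitrary matrices and adds two cheap methods: a `Diagonal` already stores its diagonal, and transposing a matrix leaves its main diagonal unchanged, so a `Transpose` can forward to its parent. The issue does not say which wrappers need special-casing; these two are chosen only for illustration.

```julia
using LinearAlgebra

# Fallback: materialize the wrapper into a dense Matrix, then read its diagonal.
fast_diag(A::AbstractMatrix) = diag(Matrix(A))

# A Diagonal stores its diagonal directly, so just copy that vector.
fast_diag(D::Diagonal) = copy(D.diag)

# Transposition leaves the main diagonal unchanged (and Transpose does not
# conjugate), so forward to the parent instead of materializing the transpose.
fast_diag(T::Transpose{<:Any,<:AbstractMatrix}) = fast_diag(parent(T))

M = [1 2 3; 4 5 6]
dT = fast_diag(transpose(M))
dD = fast_diag(Diagonal([7, 8, 9]))
```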

@mofeing
Collaborator

mofeing commented Dec 13, 2024

This is really an $N \times M$ problem: every combination of array type and method is a candidate for specialization. For example, a PermutedDimsArray passed to a matrix multiplication, or to a more general einsum, can also be handled more efficiently without materializing it. On the other hand, for some array types or methods the default fallback may be perfectly fine.
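One way to picture this $N \times M$ structure is multiple dispatch on (wrapper type, method) pairs: a generic method materializes, and only the pairs where it pays off get a specialized method. The plain-Julia sketch below uses a hypothetical `contract_dims` loop as a stand-in for a `dot_general`-style contraction (it is not Reactant's `Ops.dot_general`) and a hypothetical `matmul`. For a 2-D `PermutedDimsArray` with permutation `(2, 1)`, the specialized method contracts over the parent's first dimension instead of building the permuted copy.

```julia
# Hypothetical stand-in for a dot_general-style op restricted to matrices:
# contracts dimension `lc` of A with dimension `rc` of B (each 1 or 2).
function contract_dims(A::AbstractMatrix, B::AbstractMatrix, lc::Int, rc::Int)
    fa, fb = 3 - lc, 3 - rc  # the remaining (free) dimension of each operand
    n = size(A, lc)
    n == size(B, rc) || throw(DimensionMismatch("contracted dimensions differ"))
    C = zeros(promote_type(eltype(A), eltype(B)), size(A, fa), size(B, fb))
    for k in 1:size(B, fb), i in 1:size(A, fa), j in 1:n
        a = lc == 2 ? A[i, j] : A[j, i]
        b = rc == 1 ? B[j, k] : B[k, j]
        C[i, k] += a * b
    end
    return C
end

# Generic method: materialize both operands, then contract A's columns with B's rows.
matmul(A::AbstractMatrix, B::AbstractMatrix) = contract_dims(Array(A), Array(B), 2, 1)

# Specialized (wrapper, method) pair: a 2-D PermutedDimsArray with perm (2, 1)
# is the transpose of its parent, so contract over the parent's first dimension
# and never build the permuted copy.
matmul(A::PermutedDimsArray{T,2,(2, 1)}, B::AbstractMatrix) where {T} =
    contract_dims(parent(A), Array(B), 1, 1)

X = [1.0 2.0; 3.0 4.0; 5.0 6.0]   # 3×2
P = PermutedDimsArray(X, (2, 1))  # lazy 2×3 view of X's transpose
Y = [1.0 0.0; 0.0 1.0; 1.0 1.0]   # 3×2
```

With this layout, supporting a new wrapper or a new method means adding methods only for the pairs that matter; every other pair keeps working through the generic fallback.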

Also note that, thanks to the high-level optimization passes we added in Enzyme-JAX, some of the default implementations may already perform as well as hand-written ones. Still, I prefer to emit optimal code directly.
