Skip to content

Commit

Permalink
fix: de-specialize 3 arg mul!
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 29, 2024
1 parent 89fe3d4 commit 51c0d6e
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function create_result(tocopy::D, path, result_stores) where {K,V,D<:AbstractDic
end

function create_result(
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol},
tocopy::Union{Integer,AbstractFloat,AbstractString,Nothing,Type,Symbol,Char},
path,
result_stores,
)
Expand Down
36 changes: 19 additions & 17 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,24 +119,26 @@ end
# LinearAlgebra.jl overloads
## `_mul!` goes through too many layers of abstractions and we aren't able to overload
## without specializing on every possible combination of types
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::AbstractVector, A::AbstractMatrix, B::AbstractVector, α::Number, β::Number
for (cT, aT, bT) in (
(:AbstractVector, :AbstractMatrix, :AbstractVector),
(:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat),
)
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
end
return C
end
@eval begin
@reactant_overlay @noinline function LinearAlgebra.mul!(
C::$cT, A::$aT, B::$bT, α::Number, β::Number
)
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
end
return C
end

@reactant_overlay @noinline function LinearAlgebra.mul!(
C::AbstractMatrix, A::AbstractMatrix, B::AbstractVecOrMat, α::Number, β::Number
)
if any(Base.Fix2(isa, TracedRArray) ancestor, (C, A, B))
TracedLinearAlgebra.overloaded_mul!(C, A, B, α, β)
else
LinearAlgebra._mul!(C, A, B, α, β)
# Needed mostly for 1.10 where 3-arg mul is often specialized
@reactant_overlay @noinline function LinearAlgebra.mul!(C::$cT, A::$aT, B::$bT)
call_with_reactant(LinearAlgebra.mul!, C, A, B, true, false)
return C
end
end
return C
end
30 changes: 30 additions & 0 deletions test/wrapped_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,33 @@ end
@test all(iszero, y_res)
end
end

function lower_triangular_write(x)
y = LowerTriangular(copy(x))
@. y *= 2
return y
end

function upper_triangular_write(x)
y = UpperTriangular(copy(x))
@. y *= 2
return y
end

function tridiagonal_write(x)
y = Tridiagonal(copy(x))
@. y *= 2
return y
end

@testset "Broadcasted Multiply and Alloate" begin
@testset "$(aType)" for (aType, fn) in [
("LowerTriangular", lower_triangular_write),
("UpperTriangular", upper_triangular_write),
("Tridiagonal", tridiagonal_write),
]
x = rand(4, 4)
x_ra = Reactant.to_rarray(x)
@test @jit(fn(x_ra)) fn(x)
end
end

0 comments on commit 51c0d6e

Please sign in to comment.