From ccc61d4b278d2aea7fde3588c3835b747bde541c Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 28 Dec 2024 22:00:46 -0500 Subject: [PATCH] fix: dispatches --- src/stdlibs/LinearAlgebra.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index 457c73a43..4b9ca80aa 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -14,7 +14,7 @@ using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_m using LinearAlgebra # Various Wrapper Arrays defined in LinearAlgebra -function materialize_traced_array( +function TracedUtils.materialize_traced_array( x::Transpose{TracedRNumber{T},TracedRArray{T,N}} ) where {T,N} px = parent(x) @@ -22,16 +22,16 @@ function materialize_traced_array( return permutedims(A, (2, 1)) end -function materialize_traced_array( +function TracedUtils.materialize_traced_array( x::Adjoint{TracedRNumber{T},TracedRArray{T,N}} ) where {T,N} return conj(materialize_traced_array(transpose(parent(x)))) end -function materialize_traced_array( - x::LinearAlgebra.Diagonal{TracedRNumber{T},TracedRArray{T,1}} +function TracedUtils.materialize_traced_array( + x::Diagonal{TracedRNumber{T},TracedRArray{T,1}} ) where {T} - return LinearAlgebra.diagm(parent(x)) + return diagm(parent(x)) end function TracedUtils.materialize_traced_array(x::Tridiagonal{T,TracedRArray{T,1}}) where {T} @@ -42,7 +42,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) uAT = Symbol(:Unit, AT) @eval begin function TracedUtils.materialize_traced_array( - x::$(AT){T,TracedRArray{T,2}} + x::$(AT){TracedRNumber{T},TracedRArray{T,2}} ) where {T} m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) @@ -52,7 +52,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) end function TracedUtils.materialize_traced_array( - x::$(uAT){T,TracedRArray{T,2}} + x::$(uAT){TracedRNumber{T},TracedRArray{T,2}} ) where {T} m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) @@ -64,7 +64,9 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE")) end end -function TracedUtils.materialize_traced_array(x::Symmetric{T,TracedRArray{T,2}}) where {T} +function TracedUtils.materialize_traced_array( + x::Symmetric{TracedRNumber{T},TracedRArray{T,2}} +) where {T} m, n = size(x) row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1) col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2) @@ -107,7 +109,9 @@ function TracedUtils.set_mlir_data!( return x end -function TracedUtils.set_mlir_data!(x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data) where {T} +function TracedUtils.set_mlir_data!( + x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}, data +) where {T} parent(x).mlir_data = diag(TracedRArray{T}(data)).mlir_data return x end @@ -119,7 +123,7 @@ for (AT, dcomp, ocomp) in ( (:UnitUpperTriangular, "LT", "GE"), ) @eval function TracedUtils.set_mlir_data!( - x::LinearAlgebra.$(AT){T,TracedRArray{T,2}}, data + x::$(AT){TracedRNumber{T},TracedRArray{T,2}}, data ) where {T} tdata = TracedRArray{T}(data) z = zero(tdata) @@ -137,17 +141,19 @@ for (AT, dcomp, ocomp) in ( end function TracedUtils.set_mlir_data!( - x::LinearAlgebra.Symmetric{T,TracedRArray{T,2}}, data + x::Symmetric{TracedRNumber{T},TracedRArray{T,2}}, data ) where {T} if x.uplo == 'L' - set_mlir_data!(LinearAlgebra.LowerTriangular(parent(x)), data) + set_mlir_data!(LowerTriangular(parent(x)), data) else - set_mlir_data!(LinearAlgebra.UpperTriangular(parent(x)), data) + set_mlir_data!(UpperTriangular(parent(x)), data) end return x end -function TracedUtils.set_mlir_data!(x::Tridiagonal{T,TracedRArray{T,1}}, data) where {T} +function TracedUtils.set_mlir_data!( + x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}, data +) where {T} tdata = TracedRArray{T}(data) set_mlir_data!(x.dl, diag(tdata, -1).mlir_data) set_mlir_data!(x.d, diag(tdata, 0).mlir_data)