diff --git a/src/MatrixFields/MatrixFields.jl b/src/MatrixFields/MatrixFields.jl index b437af042c..fe66741f70 100644 --- a/src/MatrixFields/MatrixFields.jl +++ b/src/MatrixFields/MatrixFields.jl @@ -43,7 +43,7 @@ multiples of `LinearAlgebra.I`. This comes with the following functionality: """ module MatrixFields -import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec, mul!, inv, norm +import LinearAlgebra: I, UniformScaling, Adjoint, AdjointAbsVec, mul!, inv, norm, dot import StaticArrays: SMatrix, SVector import BandedMatrices: BandedMatrix, band, _BandedMatrix import RecursiveArrayTools: recursive_bottom_eltype diff --git a/src/MatrixFields/matrix_multiplication.jl b/src/MatrixFields/matrix_multiplication.jl index 80863a7555..75280f3a52 100644 --- a/src/MatrixFields/matrix_multiplication.jl +++ b/src/MatrixFields/matrix_multiplication.jl @@ -478,3 +478,17 @@ if hasfield(Method, :recursion_relation) m.recursion_relation = dont_limit end end + +# TODO: Explain this +function Base.Broadcast.broadcasted( + ::typeof(dot), # This is LinearAlgebra.dot + arg1, + arg2, +) + args_bced = (Base.Broadcast.broadcastable(arg1), Base.Broadcast.broadcastable(arg2)) + if eltype(args_bced[1]) <: BandMatrixRow + error("Detected usage of LinearAlgebra.dot in MatrixFields. Did you mean to `import ClimaCore.MatrixFields: ⋅`?") + else + Base.Broadcast.broadcasted(Base.Broadcast.combine_styles(args_bced...), dot, args_bced...) + end +end diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 3b31dde3b2..65c30f1618 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3432,3 +3432,4 @@ Base.@propagate_inbounds function apply_stencil!( end return field_out end + diff --git a/test/MatrixFields/linearalgebra_dot.jl b/test/MatrixFields/linearalgebra_dot.jl new file mode 100644 index 0000000000..52e8ff77c7 --- /dev/null +++ b/test/MatrixFields/linearalgebra_dot.jl @@ -0,0 +1,29 @@ +# LinearAlgebra exports ⋅ and it is easy to get confused with MatrixFields.⋅. +# This test checks that we have a mechanism in place to inform the user when +# they are using the wrong ⋅. + +import ClimaCore: Domains, Geometry, Meshes, Spaces, MatrixFields +import ClimaCore.MatrixFields: @name + +import LinearAlgebra: ⋅, I + +@testset "LinearAlgebra dot" begin + n_elem_z = 2 + + domain = Domains.IntervalDomain(Geometry.ZPoint(0.0), Geometry.ZPoint(1.0), boundary_names = (:bottom, :top)) + mesh = Meshes.IntervalMesh(domain, nelems = 2) + space = Spaces.FaceFiniteDifferenceSpace(mesh) + + diverg = Operators.DivergenceC2F(; bottom = Operators.SetDivergence(0.0), top = Operators.SetDivergence(0.0)) + grad = Operators.GradientF2C() + + diverg_matrix = MatrixFields.operator_matrix(diverg) + grad_matrix = MatrixFields.operator_matrix(grad) + + name = @name(u) + jacobian = MatrixFields.FieldMatrix( + (@name(u), @name(u)) => similar(zeros(space), ClimaCore.MatrixFields.TridiagonalMatrixRow{Float64}), + ) + + @test_throws ErrorException @. jacobian[name, name] = diverg_matrix() ⋅ grad_matrix() - (I,) +end