From bce86efd73fc5a24850407f43bdd9a566bc8700e Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Mon, 29 Jul 2024 23:08:43 -0400 Subject: [PATCH] fix zeros in hypot and norm --- ext/AccessorsLinearAlgebraExt.jl | 5 ++++- src/functionlenses.jl | 5 ++++- test/test_functionlenses.jl | 4 ++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/ext/AccessorsLinearAlgebraExt.jl b/ext/AccessorsLinearAlgebraExt.jl index ca47b95..9ab4a23 100644 --- a/ext/AccessorsLinearAlgebraExt.jl +++ b/ext/AccessorsLinearAlgebraExt.jl @@ -4,7 +4,10 @@ import Accessors: set, @set using LinearAlgebra: norm, normalize, diag, diagind set(arr, ::typeof(normalize), val) = norm(arr) * val -set(arr, ::typeof(norm), val) = map(Base.Fix2(*, val / norm(arr)), arr) # should we check val is positive? +function set(arr, ::typeof(norm), val) + omul = iszero(val) ? one(norm(arr)) : norm(arr) + map(Base.Fix2(*, val / omul), arr) +end set(A, ::typeof(diag), val) = @set A[diagind(A)] = val diff --git a/src/functionlenses.jl b/src/functionlenses.jl index 2d3b0d5..fd865b7 100644 --- a/src/functionlenses.jl +++ b/src/functionlenses.jl @@ -150,7 +150,10 @@ set(x::AbstractString, f::Base.Fix1{typeof(parse), Type{T}}, y::T) where {T} = s set(f, ::typeof(inverse), invf) = setinverse(f, invf) set(obj, ::typeof(Base.splat(atan)), val) = @set Tuple(obj) = hypot(obj...) .* sincos(val) -set(obj, ::typeof(Base.splat(hypot)), val) = map(Base.Fix2(*, val / hypot(obj...)), obj) +function set(obj, ::typeof(Base.splat(hypot)), val) + omul = iszero(val) ? one(hypot(obj...)) : hypot(obj...) + map(Base.Fix2(*, val / omul), obj) +end ################################################################################ ##### strings diff --git a/test/test_functionlenses.jl b/test/test_functionlenses.jl index c9c9947..4500d41 100644 --- a/test/test_functionlenses.jl +++ b/test/test_functionlenses.jl @@ -268,7 +268,9 @@ end @test @set(abs(-2u"m") = 1u"m") === -1u"m" @test @set(abs(x) = 10) ≈ 6 + 8im @test @set(angle(x) = π/2) ≈ 5im + @test set(0, abs, 0) == 0 @test set(0, abs, 10) == 10 + @test set(0+0im, abs, 0) == 0+0im @test set(0+0im, abs, 10) == 10 @test set(0+1e-100im, abs, 10) == 10im @test_throws DomainError @set(abs(x) = -10) @@ -300,7 +302,9 @@ end @test set((3, 4), norm, 10) === (6., 8.) @test set((a=3, b=4), norm, 10) === (a=6., b=8.) test_getset_laws(norm, (3, 4), 10, 12) + test_getset_laws(norm, (0, 0), 0, 0) test_getset_laws(Base.splat(hypot), (3, 4), 10, 12) + test_getset_laws(Base.splat(hypot), (0, 0), 0, 0.) test_getset_laws(!(@optic _.a), (a=true,), false, true) test_getset_laws(!(@optic _[1]), (a=true,), false, true)