Skip to content

Commit

Permalink
fix zeros in hypot and norm
Browse files Browse the repository at this point in the history
  • Loading branch information
aplavin committed Jul 30, 2024
1 parent 4c297d1 commit bce86ef
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
5 changes: 4 additions & 1 deletion ext/AccessorsLinearAlgebraExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import Accessors: set, @set
using LinearAlgebra: norm, normalize, diag, diagind

set(arr, ::typeof(normalize), val) = norm(arr) * val
set(arr, ::typeof(norm), val) = map(Base.Fix2(*, val / norm(arr)), arr) # should we check val is positive?
function set(arr, ::typeof(norm), val)
omul = iszero(val) ? one(norm(arr)) : norm(arr)
map(Base.Fix2(*, val / omul), arr)
end

set(A, ::typeof(diag), val) = @set A[diagind(A)] = val

Expand Down
5 changes: 4 additions & 1 deletion src/functionlenses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ set(x::AbstractString, f::Base.Fix1{typeof(parse), Type{T}}, y::T) where {T} = s
set(f, ::typeof(inverse), invf) = setinverse(f, invf)

set(obj, ::typeof(Base.splat(atan)), val) = @set Tuple(obj) = hypot(obj...) .* sincos(val)
set(obj, ::typeof(Base.splat(hypot)), val) = map(Base.Fix2(*, val / hypot(obj...)), obj)
function set(obj, ::typeof(Base.splat(hypot)), val)
omul = iszero(val) ? one(hypot(obj...)) : hypot(obj...)
map(Base.Fix2(*, val / omul), obj)
end

################################################################################
##### strings
Expand Down
4 changes: 4 additions & 0 deletions test/test_functionlenses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,9 @@ end
@test @set(abs(-2u"m") = 1u"m") === -1u"m"
@test @set(abs(x) = 10) 6 + 8im
@test @set(angle(x) = π/2) 5im
@test set(0, abs, 0) == 0
@test set(0, abs, 10) == 10
@test set(0+0im, abs, 0) == 0+0im
@test set(0+0im, abs, 10) == 10
@test set(0+1e-100im, abs, 10) == 10im
@test_throws DomainError @set(abs(x) = -10)
Expand Down Expand Up @@ -300,7 +302,9 @@ end
@test set((3, 4), norm, 10) === (6., 8.)
@test set((a=3, b=4), norm, 10) === (a=6., b=8.)
test_getset_laws(norm, (3, 4), 10, 12)
test_getset_laws(norm, (0, 0), 0, 0)
test_getset_laws(Base.splat(hypot), (3, 4), 10, 12)
test_getset_laws(Base.splat(hypot), (0, 0), 0, 0.)

test_getset_laws(!(@optic _.a), (a=true,), false, true)
test_getset_laws(!(@optic _[1]), (a=true,), false, true)
Expand Down

0 comments on commit bce86ef

Please sign in to comment.