Skip to content

Commit

Permalink
Add registration for Base.clamp
Browse files Browse the repository at this point in the history
  • Loading branch information
contradict committed Sep 4, 2024
1 parent fb0265e commit aca03fb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/extra_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ end
@register_symbolic Base.rand(x)
@register_symbolic Base.randn(x)

@register_symbolic Base.clamp(x, y, z)

function derivative(::typeof(Base.clamp), args::NTuple{3, Any}, ::Val{1})
x, l, h = args
T = promote_type(symtype(x), symtype(l), symtype(h))
z = zero(T)
o = one(T)
ifelse(x<l, z, ifelse(x>h, z, o))
end

@register_symbolic Distributions.pdf(dist,x)
@register_symbolic Distributions.logpdf(dist,x)
@register_symbolic Distributions.cdf(dist,x)
Expand Down
3 changes: 3 additions & 0 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ x = Num.(randn(10))
@test norm(x, 1) == norm(Symbolics.value.(x), 1)
@test norm(x, 1.2) == norm(Symbolics.value.(x), 1.2)

@test clamp.(x, 0, 1) == clamp.(Symbolics.value.(x), 0, 1)
@test isequal(Symbolics.derivative(clamp(a, 0, 1), a), ifelse(a < 0, 0, ifelse(a>1, 0, 1)))

@variables x[1:2]
@test isequal(scalarize(norm(x)), sqrt(abs2(x[1]) + abs2(x[2])))
@test isequal(scalarize(norm(x, Inf)), max(abs(x[1]), abs(x[2])))
Expand Down

0 comments on commit aca03fb

Please sign in to comment.