From aca03fbd04bade6690f26ea98337cb545da42720 Mon Sep 17 00:00:00 2001 From: contradict Date: Wed, 4 Sep 2024 15:01:29 -0700 Subject: [PATCH] Add registration for `Base.clamp` --- src/extra_functions.jl | 10 ++++++++++ test/overloads.jl | 3 +++ 2 files changed, 13 insertions(+) diff --git a/src/extra_functions.jl b/src/extra_functions.jl index 53fc0d2f7..0a78b96f7 100644 --- a/src/extra_functions.jl +++ b/src/extra_functions.jl @@ -68,6 +68,16 @@ end @register_symbolic Base.rand(x) @register_symbolic Base.randn(x) +@register_symbolic Base.clamp(x, y, z) + +function derivative(::typeof(Base.clamp), args::NTuple{3, Any}, ::Val{1}) + x, l, h = args + T = promote_type(symtype(x), symtype(l), symtype(h)) + z = zero(T) + o = one(T) + ifelse(xh, z, o)) +end + @register_symbolic Distributions.pdf(dist,x) @register_symbolic Distributions.logpdf(dist,x) @register_symbolic Distributions.cdf(dist,x) diff --git a/test/overloads.jl b/test/overloads.jl index 0bfbc0a90..be893876e 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -186,6 +186,9 @@ x = Num.(randn(10)) @test norm(x, 1) == norm(Symbolics.value.(x), 1) @test norm(x, 1.2) == norm(Symbolics.value.(x), 1.2) +@test clamp.(x, 0, 1) == clamp.(Symbolics.value.(x), 0, 1) +@test isequal(Symbolics.derivative(clamp(a, 0, 1), a), ifelse(a < 0, 0, ifelse(a>1, 0, 1))) + @variables x[1:2] @test isequal(scalarize(norm(x)), sqrt(abs2(x[1]) + abs2(x[2]))) @test isequal(scalarize(norm(x, Inf)), max(abs(x[1]), abs(x[2])))