Skip to content

Commit

Permalink
remove robust flag
Browse files Browse the repository at this point in the history
  • Loading branch information
Karl Wessel committed Nov 15, 2024
1 parent 7fba695 commit 3bf685a
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 33 deletions.
66 changes: 40 additions & 26 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,25 @@ function recursive_hasoperator(op, O)
end
end

function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
if robust || occurrences == nothing
"""
executediff(D, arg, simplify=false; occurrences=nothing)
Apply the passed Differential D on the passed argument.
This function differs to `expand_derivatives` in that in only expands the
passed differential and not any other Differentials it encounters.
# Arguments
- `D::Differential`: The differential to apply
- `arg::Symbolic`: The symbolic expression to apply the differential on.
- `simplify::Bool=false`: Whether to simplify the resulting expression using
[`SymbolicUtils.simplify`](@ref).
- `occurrences=nothing`: Information about the occurrences of the independent
variable in the argument of the derivative. This is used internally for
optimization purposes.
"""
function executediff(D, arg, simplify=false; occurrences=nothing)
if occurrences == nothing
occurrences = occursin_info(D.x, arg)
end

Expand All @@ -166,15 +183,15 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
return D(arg) # base case if any argument is directly equal to the i.v.
else
return sum(inner_args, init=0) do a
return executediff(Differential(a), arg; robust) *
executediff(D, a; robust)
return executediff(Differential(a), arg) *
executediff(D, a)
end
end
elseif op === (IfElse.ifelse)
args = arguments(arg)
O = op(args[1],
executediff(D, args[2], simplify; robust, occurrences=arguments(occurrences)[2]),
executediff(D, args[3], simplify; robust, occurrences=arguments(occurrences)[3]))
executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2]),
executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3]))
return O
elseif isa(op, Differential)
# The recursive expand_derivatives was not able to remove
Expand All @@ -184,20 +201,21 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
if isequal(op.x, D.x)
return D(arg)
else
inner = executediff(D, arguments(arg)[1], false; robust)
inner = executediff(D, arguments(arg)[1], false)
# if the inner expression is not expandable either, return
if iscall(inner) && operation(inner) isa Differential
return D(arg)
else
return expand_derivatives(op(inner), simplify; robust) # TODO
# otherwise give the nested Differential another try
return executediff(op, inner, simplify)
end
end
elseif isa(op, Integral)
if isa(op.domain.domain, AbstractInterval)
domain = op.domain.domain
a, b = DomainSets.endpoints(domain)
c = 0
inner_function = expand_derivatives(arguments(arg)[1]; robust) # TODO
inner_function = arguments(arg)[1]
if iscall(value(a))
t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => value(a)))
t2 = D(a)
Expand All @@ -208,7 +226,7 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
t2 = D(b)
c += t1*t2
end
inner = executediff(D, arguments(arg)[1]; robust)
inner = executediff(D, arguments(arg)[1])
c += op(inner)
return value(c)
end
Expand All @@ -220,7 +238,7 @@ function executediff(D, arg, simplify=false; robust=false, occurrences=nothing)
c = 0

for i in 1:l
t2 = executediff(D, inner_args[i],false; robust, occurrences=arguments(occurrences)[i])
t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i])

x = if _iszero(t2)
t2
Expand Down Expand Up @@ -265,9 +283,6 @@ and other derivative rules to expand any derivatives it encounters.
- `O::Symbolic`: The symbolic expression to expand.
- `simplify::Bool=false`: Whether to simplify the resulting expression using
[`SymbolicUtils.simplify`](@ref).
- `occurrences=nothing`: Information about the occurrences of the independent
variable in the argument of the derivative. This is used internally for
optimization purposes.
# Examples
```jldoctest
Expand All @@ -283,30 +298,29 @@ julia> dfx=expand_derivatives(Dx(f))
(k*((2abs(x - y)) / y - 2z)*IfElse.ifelse(signbit(x - y), -1, 1)) / y
```
"""
function expand_derivatives(O::Symbolic, simplify=false; robust=false, occurrences=nothing)
function expand_derivatives(O::Symbolic, simplify=false)
if iscall(O) && isa(operation(O), Differential)
arg = only(arguments(O))
arg = expand_derivatives(arg, false; robust)
arg = expand_derivatives(arg, false)
return executediff(operation(O), arg, simplify)
elseif iscall(O) && isa(operation(O), Integral)
return operation(O)(expand_derivatives(arguments(O)[1]; robust))
return operation(O)(expand_derivatives(arguments(O)[1]))
elseif !hasderiv(O)
return O
else
args = map(a->expand_derivatives(a, false; robust), arguments(O))
args = map(a->expand_derivatives(a, false), arguments(O))
O1 = operation(O)(args...)
return simplify ? SymbolicUtils.simplify(O1) : O1
end

executediff(operation(O), arg, simplify; robust, occurrences)
end
function expand_derivatives(n::Num, simplify=false; robust=false, occurrences=nothing)
wrap(expand_derivatives(value(n), simplify; robust, occurrences))
function expand_derivatives(n::Num, simplify=false)
wrap(expand_derivatives(value(n), simplify))
end
function expand_derivatives(n::Complex{Num}, simplify=false; robust=false, occurrences=nothing)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; robust, occurrences),
expand_derivatives(imag(n), simplify; robust, occurrences)))
function expand_derivatives(n::Complex{Num}, simplify=false)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify),
expand_derivatives(imag(n), simplify)))
end
expand_derivatives(x, simplify=false; robust=false, occurrences=nothing) = x
expand_derivatives(x, simplify=false) = x

_iszero(x) = false
_isone(x) = false
Expand Down
13 changes: 6 additions & 7 deletions test/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,8 @@ let
D = Differential(t)
expr = b - ((D(b))^2) * D(D(b))
expr2 = D(expr)
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
@test isequal(expand_derivatives(expr2; robust=true), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
@test isequal(expand_derivatives(expr2; robust=true), expand_derivatives(expr2))
@test isequal(expand_derivatives(expr), expr)
@test isequal(expand_derivatives(expr2), D(b) - (D(b)^2)*D(D(D(b))) - 2D(b)*(D(D(b))^2))
end

# 1126
Expand All @@ -370,13 +369,13 @@ let
expr_gen = (fun) -> D(D(((-D(D(fun))) / g(y))))

expr = expr_gen(g(y))
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
# just make sure that no errors are thrown in the following, the results are to complicated to compare
expand_derivatives(expr)
expr = expr_gen(h(y))
@test isequal(expand_derivatives(expr), expand_derivatives(expr; robust=true))
expand_derivatives(expr)

expected = substitute(expand_derivatives(expr; robust=true), h(y) => f(y))
expr = expr_gen(f(y))
@test isequal(expand(expand_derivatives(expr)), expand(expand_derivatives(expr; robust=true)))
expand_derivatives(expr)
end

# Check `is_derivative` function
Expand Down

0 comments on commit 3bf685a

Please sign in to comment.