Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minimal symbolic clean up #268

Merged
merged 19 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix Num typepiracy
  • Loading branch information
oameye committed Oct 13, 2024
commit a716f5850f37c83e260dd1b9bea40f705755e424
39 changes: 20 additions & 19 deletions src/Symbolics/Symbolics_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,15 @@ function simplify_complex(x::BasicSymbolic)
end
end

# sometimes, expressions get stored as Complex{Num} with no way to decode what real(x) and imag(x)
# this overloads the Num constructor to return a Num if x.re and x.im have similar arguments
function Num(x::Complex{Num})::Num
if x.re.val isa Float64 && x.im.val isa Float64
return Num(x.re.val)
else
if isequal(x.re.val.arguments, x.im.val.arguments)
Num(first(x.re.val.arguments))
else
error("Cannot convert Complex{Num} " * string(x) * " to Num")
end
end
end
# ^ This function commits type-piracy with Symbolics.jl. We should change this.

"""
$(TYPEDSIGNATURES)

Perform substitutions in `rules` on `x`.
`include_derivatives=true` also includes all derivatives of the variables of the keys of `rules`.
"""
function substitute_all(
x::T, rules::Dict; include_derivatives=true
)::T where {T<:Union{Equation,Num}}
subtype=Union{Num,Equation,BasicSymbolic}
function substitute_all(x::subtype, rules::Dict; include_derivatives=true)
if include_derivatives
rules = merge(
rules,
Expand All @@ -71,7 +56,24 @@ function substitute_all(dict::Dict, rules::Dict)::Dict
end
Collections = Union{Dict,Pair,Vector,OrderedDict}
substitute_all(v::AbstractArray, rules) = [substitute_all(x, rules) for x in v]
substitute_all(x::Union{Num,Equation}, rules::Collections) = substitute_all(x, Dict(rules))
substitute_all(x::subtype, rules::Collections) = substitute_all(x, Dict(rules))
# Collections = Union{Dict,OrderedDict}
# function substitute_all(x, rules::Collections; include_derivatives=true)
# if include_derivatives
# rules = merge(
# rules,
# Dict([Differential(var) => Differential(rules[var]) for var in keys(rules)]),
# )
# end
# return substitute(x, rules)
# end
# "Variable substitution - dictionary"
# function substitute_all(dict::Dict, rules::Dict)::Dict
# new_keys = substitute_all.(keys(dict), rules)
# new_values = substitute_all.(values(dict), rules)
# return Dict(zip(new_keys, new_values))
# end
# substitute_all(v::AbstractArray, rules::Collections) = [substitute_all(x, rules) for x in v]


get_independent(x::Num, t::Num) = get_independent(x.val, t)
Expand Down Expand Up @@ -125,6 +127,5 @@ end
is_harmonic(x::Equation, t::Num) = is_harmonic(x.lhs, t) && is_harmonic(x.rhs, t)
is_harmonic(x, t) = is_harmonic(Num(x), Num(t))


"Return true if `f` is a function of `var`."
is_function(f, var) = any(isequal.(get_variables(f), var))
1 change: 0 additions & 1 deletion src/Symbolics/drop_powers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ function drop_powers(expr::Num, vars::Vector{Num}, deg::Int)
removal = Dict([ϵ^d => Num(0) for d in deg:max_deg])
res = substitute_all(substitute_all(subs_expr, removal), Dict(ϵ => Num(1)))
return Symbolics.expand(res)
#res isa Complex ? Num(res.re.val.arguments[1]) : res
end

function drop_powers(expr::Vector{Num}, var::Vector{Num}, deg::Int)
Expand Down
6 changes: 3 additions & 3 deletions src/Symbolics/fourier.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ function trig_to_exp(x::Num)
end

result = Symbolics.substitute(x, Dict(rules))
#result = result isa Complex ? Num(first(result.re.val.arguments)) : result
result = Num(result)
return result
return convert_to_Num(result)
end
convert_to_Num(x::Complex{Num})::Num = Num(first(x.re.val.arguments))
convert_to_Num(x::Num)::Num = x

function exp_to_trig(x::BasicSymbolic)
if isadd(x) || isdiv(x) || ismul(x)
Expand Down
4 changes: 2 additions & 2 deletions src/transform_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ function to_lab_frame(soln, res, times)::Vector{AbstractFloat}
timetrace = zeros(length(times))

for var in res.problem.eom.variables
val = unwrap(substitute_all(_remove_brackets(var), soln))
ω = unwrap(substitute_all(var.ω, soln))
val = real(substitute_all(unwrap(_remove_brackets(var)), soln))
ω = real(unwrap(substitute_all(var.ω, soln)))
if var.type == "u"
timetrace .+= val * cos.(ω * times)
elseif var.type == "v"
Expand Down
35 changes: 17 additions & 18 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@ using Random
const SEED = 0xd8e5d8df
Random.seed!(SEED)

# @testset "Code quality" begin
# using ExplicitImports, Aqua
# ignore_deps = [:Random, :LinearAlgebra, :Printf, :Test, :Pkg]
@testset "Code quality" begin
using ExplicitImports, Aqua
ignore_deps = [:Random, :LinearAlgebra, :Printf, :Test, :Pkg]

# @test check_no_stale_explicit_imports(HarmonicBalance) == nothing
# @test check_all_explicit_imports_via_owners(HarmonicBalance) == nothing
# Aqua.test_ambiguities(HarmonicBalance)
# Aqua.test_all(
# HarmonicBalance;
# deps_compat=(
# ignore=ignore_deps,
# check_extras=(ignore=ignore_deps,),
# check_weakdeps=(ignore=ignore_deps,),
# ),
# piracies=(treat_as_own=[HarmonicBalance.Num],),
# ambiguities=false,
# )
# end
@test check_no_stale_explicit_imports(HarmonicBalance) == nothing
@test check_all_explicit_imports_via_owners(HarmonicBalance) == nothing
Aqua.test_ambiguities(HarmonicBalance)
Aqua.test_all(
HarmonicBalance;
deps_compat=(
ignore=ignore_deps,
check_extras=(ignore=ignore_deps,),
check_weakdeps=(ignore=ignore_deps,),
),
ambiguities=false,
)
end

@testset "Code linting" begin
using JET
Expand All @@ -34,7 +33,7 @@ end
end

@testset "Symbolics customised" begin
include("Symbolics.jl")
include("symbolics.jl")
end

@testset "IO" begin
Expand Down
Loading