Skip to content

Commit

Permalink
feat: support indexing Tuple parameters, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 26, 2024
1 parent ed4bce0 commit de8b19b
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 69 deletions.
8 changes: 6 additions & 2 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ argument version of this function returns the parameter value at index `i`. The
two-argument version of this function will default to returning
`parameter_values(p)[i]`.
If this function is called with an `AbstractArray`, it will return the same array.
If this function is called with an `AbstractArray` or `Tuple`, it will return the same
array/tuple.
"""
function parameter_values end

parameter_values(arr::AbstractArray) = arr
parameter_values(arr::Tuple) = arr

Check warning on line 16 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L16

Added line #L16 was not covered by tests
parameter_values(arr::AbstractArray, i) = arr[i]
parameter_values(arr::Tuple, i) = arr[i]

Check warning on line 18 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L18

Added line #L18 was not covered by tests
parameter_values(prob, i) = parameter_values(parameter_values(prob), i)

"""
Expand Down Expand Up @@ -77,7 +80,8 @@ See: [`parameter_values`](@ref)
"""
function set_parameter! end

function set_parameter!(sys::AbstractArray, val, idx)
# Tuple only included for the error message
function set_parameter!(sys::Union{AbstractArray, Tuple}, val, idx)

Check warning on line 84 in src/parameter_indexing.jl

View check run for this annotation

Codecov / codecov/patch

src/parameter_indexing.jl#L84

Added line #L84 was not covered by tests
sys[idx] = val
end
set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
Expand Down
142 changes: 75 additions & 67 deletions test/parameter_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,78 +17,86 @@ function SymbolicIndexingInterface.finalize_parameters_hook!(fi::FakeIntegrator,
end

sys = SymbolCache([:x, :y, :z], [:a, :b, :c], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, copy(p), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
([:a, :b], p[1:2], new_p[1:2], true),
(1:2, p[1:2], new_p[1:2], true),
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
get = getp(sys, sym)
set! = setp(sys, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1
for pType in [Vector, Tuple]
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(sys, pType(copy(p)), Ref(0))
new_p = [4.0, 5.0, 6.0]
@test parameter_timeseries(fi) == [0]
for (sym, oldval, newval, check_inference) in [
(:a, p[1], new_p[1], true),
(1, p[1], new_p[1], true),
([:a, :b], p[1:2], new_p[1:2], true),
(1:2, p[1:2], new_p[1:2], true),
((1, 2), Tuple(p[1:2]), Tuple(new_p[1:2]), true),
([:a, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([:a, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((:a, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((:a, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true),
([1, [:b, :c]], [p[1], p[2:3]], [new_p[1], new_p[2:3]], false),
([1, (:b, :c)], [p[1], (p[2], p[3])], [new_p[1], (new_p[2], new_p[3])], false),
((1, [:b, :c]), (p[1], p[2:3]), (new_p[1], new_p[2:3]), true),
((1, (:b, :c)), (p[1], (p[2], p[3])), (new_p[1], (new_p[2], new_p[3])), true)
]
get = getp(sys, sym)
set! = setp(sys, sym)
if check_inference
@inferred get(fi)
end
@test get(fi) == fi.ps[sym]
@test get(fi) == oldval

@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2
if pType === Tuple
@test_throws MethodError set!(fi, newval)
continue
end

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4
@test fi.counter[] == 0
if check_inference
@inferred set!(fi, newval)
else
set!(fi, newval)
end
@test fi.counter[] == 1

if check_inference
@inferred get(p)
end
@test get(p) == oldval
if check_inference
@inferred set!(p, newval)
else
set!(p, newval)
@test get(fi) == newval
set!(fi, oldval)
@test get(fi) == oldval
@test fi.counter[] == 2

fi.ps[sym] = newval
@test get(fi) == newval
@test fi.counter[] == 3
fi.ps[sym] = oldval
@test get(fi) == oldval
@test fi.counter[] == 4

if check_inference
@inferred get(p)
end
@test get(p) == oldval
if check_inference
@inferred set!(p, newval)
else
set!(p, newval)
end
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end
@test get(p) == newval
set!(p, oldval)
@test get(p) == oldval
@test fi.counter[] == 4
fi.counter[] = 0
end

for (sym, val) in [
([:a, :b, :c], p),
([:c, :a], p[[3, 1]]),
((:b, :a), p[[2, 1]]),
((1, :c), p[[1, 3]])
]
buffer = zeros(length(sym))
get = getp(sys, sym)
@inferred get(buffer, fi)
@test buffer == val
for (sym, val) in [
([:a, :b, :c], p),
([:c, :a], p[[3, 1]]),
((:b, :a), p[[2, 1]]),
((1, :c), p[[1, 3]])
]
buffer = zeros(length(sym))
get = getp(sys, sym)
@inferred get(buffer, fi)
@test buffer == val
end
end

struct FakeSolution
Expand Down

0 comments on commit de8b19b

Please sign in to comment.