Skip to content

Commit

Permalink
Merge pull request #172 from daniel-thom/assert-op
Browse files Browse the repository at this point in the history
Add assert_op macro
  • Loading branch information
daniel-thom authored Nov 13, 2020
2 parents f14f135 + 0aba3d2 commit aadcc52
Show file tree
Hide file tree
Showing 15 changed files with 88 additions and 28 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/pr_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: ['1', 'nightly']
julia-version: ['1']
julia-arch: [x64]
os: [ubuntu-latest, windows-latest, macOS-latest]

Expand All @@ -24,7 +24,6 @@ jobs:
env:
PYTHON: ""
- uses: julia-actions/julia-runtest@latest
continue-on-error: ${{ matrix.julia-version == 'nightly' }}
env:
PYTHON: ""
- uses: julia-actions/julia-processcoverage@v1
Expand Down
19 changes: 16 additions & 3 deletions docs/src/style.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,22 @@ One exception is the case where one file has all single-line functions.
an exception and terminate the application.
* Do not use try/catch to handle retrieving a potentially-missing key from a
dictionary.
* Use @assert statements to guard against programming errors. Do not use them
after detecting bad user input. Note that they may be compiled out in release
builds.

## Asserts
* Use `@assert` statements to guard against programming errors. Do not use them
after detecting bad user input. An assert tripping should indicate that there
is a bug in the code. Note that they may be compiled out in optimized builds in
the future.
* Consider using `InfrastructureSystems.@assert_op` instead of the standard
`@assert` because it will automatically print the value of the expression.
Unlike the standard `@assert` the Julia compiler will never exclude
`@assert_op` in optimized builds.

```julia
julia> a = 3; b = 4;
julia> @assert_op a == b
ERROR: AssertionError: 3 == 4
```

## Globals

Expand Down
1 change: 1 addition & 0 deletions src/InfrastructureSystems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ get_internal(value::InfrastructureSystemsComponent) = value.internal

include("common.jl")
include("internal.jl")
include("utils/assert_op.jl")
include("utils/recorder_events.jl")
include("utils/flatten_iterator_wrapper.jl")
include("utils/generate_structs.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/component.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function _get_columns(start_time, count, ts_metadata::ForecastMetadata)
if window_count > 1
index = Int(offset / interval) + 1
else
@assert interval == Dates.Millisecond(0)
@assert_op interval == Dates.Millisecond(0)
index = 1
end
if count === nothing
Expand Down
2 changes: 1 addition & 1 deletion src/components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ function get_components(
iter = FlattenIteratorWrapper(T, _components)
end

@assert eltype(iter) == T
@assert_op eltype(iter) == T
return iter
end

Expand Down
2 changes: 1 addition & 1 deletion src/deterministic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ get_window(f::Deterministic, initial_time; len = nothing) =

function make_time_array(forecast::Deterministic)
# Artificial limitation to reduce scope.
@assert get_count(forecast) == 1
@assert_op get_count(forecast) == 1
timestamps = range(
get_initial_timestamp(forecast);
step = get_resolution(forecast),
Expand Down
4 changes: 2 additions & 2 deletions src/deterministic_single_time_series.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ function _translate_deterministic_offsets(
s_index = (columns.start - 1) * interval_offset + 1
e_index = (columns.stop - 1) * interval_offset + horizon
@debug "translated offsets" horizon columns s_index e_index last_index
@assert s_index <= last_index "s_index = $s_index last_index = $last_index"
@assert e_index <= last_index "e_index = $e_index last_index = $last_index"
@assert_op s_index <= last_index
@assert_op e_index <= last_index
return UnitRange(s_index, e_index)
end
2 changes: 1 addition & 1 deletion src/forecasts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function get_window_common(
if ndims(data) == 2
# This is necessary because the Deterministic and Probabilistic are 3D Arrays
# We need to do this to make the data a 2D TimeArray. In a get_window the data is always count = 1
@assert size(data)[1] <= len
@assert_op size(data)[1] <= len
data = @view data[1:len, :]
else
data = @view data[1:len]
Expand Down
16 changes: 8 additions & 8 deletions src/hdf5_time_series_storage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ function deserialize_time_series(
uuid = get_time_series_uuid(ts_metadata)
path = _get_time_series_path(root, uuid)
attributes = _read_time_series_attributes(storage, path, rows, T)
@assert attributes["type"] == T
@assert_op attributes["type"] == T
@debug "deserializing a StaticTimeSeries" T
data_type = attributes["data_type"]
data = get_hdf_array(path["data"], data_type, rows)
Expand Down Expand Up @@ -337,7 +337,7 @@ function deserialize_time_series(
end

attributes = _read_time_series_attributes(storage, path, rows, T)
@assert attributes["type"] == T
@assert_op attributes["type"] == T
@debug "deserializing a Forecast" T
data_type = attributes["data_type"]
data = get_hdf_array(path["data"], data_type, attributes, rows, columns)
Expand Down Expand Up @@ -510,8 +510,8 @@ function deserialize_time_series(
uuid = get_time_series_uuid(ts_metadata)
path = _get_time_series_path(root, uuid)
attributes = _read_time_series_attributes(storage, path, rows, T)
@assert attributes["type"] == T
@assert length(attributes["dataset_size"]) == 3
@assert_op attributes["type"] == T
@assert_op length(attributes["dataset_size"]) == 3
@debug "deserializing a Forecast" T
data = SortedDict{Dates.DateTime, Matrix{attributes["data_type"]}}()
initial_timestamp = attributes["start_time"]
Expand Down Expand Up @@ -553,8 +553,8 @@ function deserialize_time_series(
uuid = get_time_series_uuid(ts_metadata)
path = _get_time_series_path(root, uuid)
attributes = _read_time_series_attributes(storage, path, rows, T)
@assert attributes["type"] == T
@assert length(attributes["dataset_size"]) == 3
@assert_op attributes["type"] == T
@assert_op length(attributes["dataset_size"]) == 3
@debug "deserializing a Forecast" T
data = SortedDict{Dates.DateTime, Matrix{attributes["data_type"]}}()
initial_timestamp = attributes["start_time"]
Expand Down Expand Up @@ -624,7 +624,7 @@ function _append_item!(path::HDF5.HDF5Group, name::AbstractString, value::Abstra
push!(values, value)

ret = HDF5.o_delete(path, name)
@assert ret == 0
@assert_op ret == 0

path[name] = values
@debug "Appended $value to $name" values
Expand All @@ -646,7 +646,7 @@ function _remove_item!(path::HDF5.HDF5Group, name::AbstractString, value::Abstra
end

ret = HDF5.o_delete(path, name)
@assert ret == 0
@assert_op ret == 0

if isempty(values)
is_empty = true
Expand Down
4 changes: 2 additions & 2 deletions src/time_series_formats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ function get_unique_timestamps(::Type{T}, file::CSV.File) where {T <: TimeSeries
end
end

@assert length(timestamps) > 0
@assert_op length(timestamps) > 0
for timestamp in timestamps[2:end]
@assert timestamp["count"] == timestamps[1]["count"]
@assert_op timestamp["count"] == timestamps[1]["count"]
end

return timestamps
Expand Down
37 changes: 37 additions & 0 deletions src/utils/assert_op.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Throw an `AssertionError` if conditions like `op(exp1, exp2)` are `false`, where `op` is a conditional infix operator.
# Examples
```
julia> a = 3; b = 4;
julia> @assert_op a == b
ERROR: AssertionError: 3 == 4
julia> @assert_op a + 3 > b + 4
ERROR: AssertionError: 6 > 8
```
"""
macro assert_op(expr)
assert_op(expr)
end

function assert_op(expr::Expr)
# Only special case expressions of the form `expr1 == expr2`
if length(expr.args) == 3 && expr.head == :call
return assert_op(expr.args[1], expr.args[2], expr.args[3])
else
return :(@assert $(expr))
end
end

function assert_op(op, exp1, exp2)
return :(
if !$op($(esc(exp1)), $(esc(exp2)))
val1 = $(esc(exp1))
val2 = $(esc(exp2))
op_str = $(esc(op))
throw(AssertionError("$val1 $op_str $val2"))
end
)
end
10 changes: 5 additions & 5 deletions src/utils/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ function get_initial_times(
if count == 0
return []
elseif count == 1
@assert interval == Dates.Second(0) "interval=$interval"
@assert_op interval == Dates.Second(0)
return range(initial_timestamp; stop = initial_timestamp, step = Dates.Second(1))
end
@assert interval != Dates.Second(0) "initial_timestamp=$initial_timestamp interval=$interval count=$count"
Expand Down Expand Up @@ -397,7 +397,7 @@ end
function transform_array_for_hdf(data::SortedDict{Dates.DateTime, Vector{POLYNOMIAL}})
lin_cost = hcat(values(data)...)
rows, cols = size(lin_cost)
@assert length(first(lin_cost)) == 2
@assert_op length(first(lin_cost)) == 2
t_lin_cost = Array{Float64}(undef, rows, cols, 2)
for r in 1:rows, c in 1:cols
tuple = lin_cost[r, c]
Expand All @@ -410,7 +410,7 @@ end

function transform_array_for_hdf(data::Vector{POLYNOMIAL})
rows = length(data)
@assert length(first(data)) == 2
@assert_op length(first(data)) == 2
t_lin_cost = Array{Float64}(undef, rows, 1, 2)
for r in 1:rows
tuple = data[r]
Expand All @@ -425,7 +425,7 @@ function transform_array_for_hdf(data::SortedDict{Dates.DateTime, Vector{PWL}})
quad_cost = hcat(values(data)...)
rows, cols = size(quad_cost)
tuple_length = length(first(quad_cost))
@assert length(first(first(quad_cost))) == 2
@assert_op length(first(first(quad_cost))) == 2
t_quad_cost = Array{Float64}(undef, rows, cols, 2, tuple_length)
for r in 1:rows, c in 1:cols
tuple_array = quad_cost[r, c]
Expand All @@ -441,7 +441,7 @@ end
function transform_array_for_hdf(data::Vector{PWL})
rows = length(data)
tuple_length = length(first(data))
@assert length(first(first(data))) == 2
@assert_op length(first(first(data))) == 2
t_quad_cost = Array{Float64}(undef, rows, 1, 2, tuple_length)
for r in 1:rows
tuple_array = data[r, 1]
Expand Down
2 changes: 1 addition & 1 deletion src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ function check_limits(
field_value,
) where {T <: Union{Nothing, NamedTuple}}
# Validates up/down, min/max, from/to named tuples.
@assert length(field_value) == 2
@assert_op length(field_value) == 2
result1 = check_limits_impl(valid_info, field_value[1])
result2 = check_limits_impl(valid_info, field_value[2])
return result1 && result2
Expand Down
2 changes: 1 addition & 1 deletion test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function create_system_data(; with_time_series = false, time_series_in_memory =
file,
)
time_series = get_all_time_series(data)
@assert length(time_series) > 0
IS.@assert_op length(time_series) > 0
end

return data
Expand Down
10 changes: 10 additions & 0 deletions test/test_assert_op.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
@testset "Test assert_op" begin
a = 2
b = 2
IS.@assert_op a == b
IS.@assert_op a + 2 == b + 2
IS.@assert_op isequal(a + 2, b + 2)

@test_throws AssertionError IS.@assert_op a + 3 == b + 2
@test_throws AssertionError IS.@assert_op isequal(a + 3, b + 2)
end

0 comments on commit aadcc52

Please sign in to comment.