Skip to content

Commit

Permalink
add code to store 3D in HDF
Browse files Browse the repository at this point in the history
  • Loading branch information
jd-lara committed Sep 29, 2023
1 parent 22d3e72 commit ecf9a0a
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 19 deletions.
31 changes: 26 additions & 5 deletions src/core/dataset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,43 @@ function make_system_state(
return InMemoryDataset(NaN, timestamp, resolution, 0, 1, columns)
end

function get_dataset_value(s::InMemoryDataset, date::Dates.DateTime)
function get_dataset_value(
s::T,
date::Dates.DateTime,
) where {T <: Union{InMemoryDataset{1}, InMemoryDataset{2}}}
s_index = find_timestamp_index(s.timestamps, date)
if isnothing(s_index)
error("Request time stamp $date not in the state")
end
return s.values[:, s_index]
end

function get_dataset_value(s::InMemoryDataset{3}, date::Dates.DateTime)
s_index = find_timestamp_index(s.timestamps, date)
if isnothing(s_index)
error("Request time stamp $date not in the state")

Check warning on line 128 in src/core/dataset.jl

View check run for this annotation

Codecov / codecov/patch

src/core/dataset.jl#L125-L128

Added lines #L125 - L128 were not covered by tests
end
return s.values[:, :, s_index]

Check warning on line 130 in src/core/dataset.jl

View check run for this annotation

Codecov / codecov/patch

src/core/dataset.jl#L130

Added line #L130 was not covered by tests
end

function get_column_names(k::OptimizationContainerKey, s::InMemoryDataset)
return get_column_names(k, s.values)
end

function get_last_recorded_value(s::InMemoryDataset)
function get_last_recorded_value(s::InMemoryDataset{2})
if get_last_recorded_row(s) == 0
error("The Dataset hasn't been written yet")
end
return s.values[:, get_last_recorded_row(s)]
end

function get_last_recorded_value(s::InMemoryDataset{3})
if get_last_recorded_row(s) == 0
error("The Dataset hasn't been written yet")

Check warning on line 146 in src/core/dataset.jl

View check run for this annotation

Codecov / codecov/patch

src/core/dataset.jl#L144-L146

Added lines #L144 - L146 were not covered by tests
end
return s.values[:, :, get_last_recorded_row(s)]

Check warning on line 148 in src/core/dataset.jl

View check run for this annotation

Codecov / codecov/patch

src/core/dataset.jl#L148

Added line #L148 was not covered by tests
end

function get_end_of_step_timestamp(s::InMemoryDataset)
return s.timestamps[s.end_of_step_index]
end
Expand All @@ -153,17 +171,20 @@ function get_value_timestamp(s::InMemoryDataset, date::Dates.DateTime)
return s.timestamps[s_index]
end

function set_value!(s::InMemoryDataset, vals::DenseAxisArray{Float64, 2}, index::Int)
# These set_value! methods expect a single time_step value because they are used to update
#the state so the incoming vals will have one dimension less than the DataSet. The exception
# is for vals of Dimension 1 which are still stored in DataSets of dimension 2.
function set_value!(s::InMemoryDataset{2}, vals::DenseAxisArray{Float64, 2}, index::Int)

Check warning on line 177 in src/core/dataset.jl

View check run for this annotation

Codecov / codecov/patch

src/core/dataset.jl#L177

Added line #L177 was not covered by tests
s.values[:, index] = vals[:, index]
return
end

function set_value!(s::InMemoryDataset, vals::DenseAxisArray{Float64, 1}, index::Int)
function set_value!(s::InMemoryDataset{2}, vals::DenseAxisArray{Float64, 1}, index::Int)
s.values[:, index] = vals
return
end

function set_value!(s::InMemoryDataset, vals::DenseAxisArray{Float64, 3}, index::Int)
function set_value!(s::InMemoryDataset{3}, vals::DenseAxisArray{Float64, 2}, index::Int)
s.values[:, :, index] = vals
return

Check warning on line 189 in src/core/dataset.jl

View check run for this annotation

Codecov / codecov/patch

src/core/dataset.jl#L187-L189

Added lines #L187 - L189 were not covered by tests
end
Expand Down
1 change: 1 addition & 0 deletions src/core/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const JuMPVariableMatrix = DenseAxisArray{
JuMP.Containers._AxisLookup{Tuple{Int64, Int64}},
},
}
const JuMPFloatMatrix = DenseAxisArray{Float64, 2}
const JuMPFloatArray = DenseAxisArray{Float64}
const JuMPVariableArray = DenseAxisArray{JuMP.VariableRef}

Expand Down
2 changes: 1 addition & 1 deletion src/parameters/update_parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ end

function _fix_parameter_value!(
container::OptimizationContainer,
parameter_array::JuMPFloatArray,
parameter_array::DenseAxisArray{Float64, 2},
parameter_attributes::VariableValueAttributes,
)
affected_variable_keys = parameter_attributes.affected_keys
Expand Down
61 changes: 49 additions & 12 deletions src/simulation/hdf_simulation_store.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,10 @@ function write_result!(
::Dates.DateTime,
data::DenseAxisArray{Float64, 3, <:NTuple{3, Any}},
)
@show "Here"
#=
output_cache = get_output_cache(store.cache, model_name, key)
cur_size = get_size(store.cache)

Check warning on line 596 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L595-L596

Added lines #L595 - L596 were not covered by tests
add_result!(output_cache, index, to_matrix(data), is_full(store.cache, cur_size))

add_result!(output_cache, index, data.data, is_full(store.cache, cur_size))

Check warning on line 598 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L598

Added line #L598 was not covered by tests

if get_dirty_size(output_cache) >= get_min_flush_size(store.cache)
discard = !should_keep_in_cache(output_cache)

Check warning on line 601 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L600-L601

Added lines #L600 - L601 were not covered by tests
Expand All @@ -614,7 +613,6 @@ function write_result!(
#end

@debug "write_result" get_size(store.cache) encode_key_as_string(key)
=#
return

Check warning on line 616 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L615-L616

Added lines #L615 - L616 were not covered by tests
end

Expand All @@ -627,10 +625,10 @@ function write_result!(
key::OptimizationContainerKey,
index::EmulationModelIndexType,
simulation_time::Dates.DateTime,
data::Matrix{Float64},
data::Array{Float64},
)
dataset = _get_em_dataset(store, key)
_write_dataset!(dataset.values, data, index:index)
_write_dataset!(dataset.values, data, index)
set_last_recorded_row!(dataset, index)
set_update_timestamp!(dataset, simulation_time)
return
Expand All @@ -642,7 +640,21 @@ function write_result!(
key::OptimizationContainerKey,
index::EmulationModelIndexType,
simulation_time::Dates.DateTime,
data,
data::DenseAxisArray{Float64, 2},
)
data_array = Array{Float64, 3}(undef, size(data)[1], size(data)[2], 1)
data_array[:, :, 1] = data
write_result!(store, model_name, key, index, simulation_time, data_array)
return

Check warning on line 648 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L645-L648

Added lines #L645 - L648 were not covered by tests
end

function write_result!(
store::HdfSimulationStore,
model_name::Symbol,
key::OptimizationContainerKey,
index::EmulationModelIndexType,
simulation_time::Dates.DateTime,
data::DenseAxisArray{Float64, 1},
)
write_result!(store, model_name, key, index, simulation_time, to_matrix(data))
return
Expand Down Expand Up @@ -984,7 +996,7 @@ function _read_length(::Type{OptimizerStats}, store::HdfSimulationStore)
end

function _write_dataset!(
dataset,
dataset::HDF5.Dataset,
array::Matrix{Float64},
row_range::UnitRange{Int64},
::Val{3},
Expand All @@ -995,7 +1007,7 @@ function _write_dataset!(
end

function _write_dataset!(
dataset,
dataset::HDF5.Dataset,
array::Matrix{Float64},
row_range::UnitRange{Int64},
::Val{2},
Expand All @@ -1005,13 +1017,38 @@ function _write_dataset!(
return
end

function _write_dataset!(dataset, array::Matrix{Float64}, row_range::UnitRange{Int64})
_write_dataset!(dataset, array, row_range, Val{ndims(dataset)}())
function _write_dataset!(

Check warning on line 1020 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L1020

Added line #L1020 was not covered by tests
dataset::HDF5.Dataset,
array::Array{Float64, 3},
row_range::UnitRange{Int64},
::Val{3},
)
dataset[row_range, :, :] = array
@debug "wrote dataset" dataset row_range
return

Check warning on line 1028 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L1026-L1028

Added lines #L1026 - L1028 were not covered by tests
end

function _write_dataset!(dataset::HDF5.Dataset, array::Array{Float64}, index::Int)
_write_dataset!(dataset, array, index:index, Val{ndims(dataset)}())
return
end

function _write_dataset!(dataset, array::Array{Float64, 3}, row_range::UnitRange{Int64})
function _write_dataset!(
dataset::HDF5.Dataset,
array::Array{Float64, 3},
row_range::UnitRange{Int64},
)
dataset[:, :, row_range] = array
@debug "wrote dataset" dataset row_range
return
end

function _write_dataset!(

Check warning on line 1046 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L1046

Added line #L1046 was not covered by tests
dataset::HDF5.Dataset,
array::Array{Float64, 4},
row_range::UnitRange{Int64},
)
dataset[:, :, :, row_range] = array
@debug "wrote dataset" dataset row_range
return

Check warning on line 1053 in src/simulation/hdf_simulation_store.jl

View check run for this annotation

Codecov / codecov/patch

src/simulation/hdf_simulation_store.jl#L1051-L1053

Added lines #L1051 - L1053 were not covered by tests
end
3 changes: 2 additions & 1 deletion src/simulation/simulation_state.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ function update_system_state!(
set_update_timestamp!(system_dataset, ts)
# Keep coordination between fields. System state is an array of size 1
system_dataset.timestamps[1] = ts
set_dataset_values!(state, key, 1, get_dataset_value(decision_dataset, simulation_time))
data_set_value = get_dataset_value(decision_dataset, simulation_time)
set_dataset_values!(state, key, 1, data_set_value)
# This value shouldn't be other than one and after one execution is no-op.
set_last_recorded_row!(system_dataset, 1)
return
Expand Down

0 comments on commit ecf9a0a

Please sign in to comment.