diff --git a/examples/configs/arg_iter_config_sql.toml b/examples/configs/arg_iter_config_sql.toml index 044e6af..7fc06fe 100644 --- a/examples/configs/arg_iter_config_sql.toml +++ b/examples/configs/arg_iter_config_sql.toml @@ -17,4 +17,4 @@ steps=102902 [sweep_args] opt2 = [5,6,7,8,9] opt1 = "1:50" -"opt3+opt4" = [["a", 1], ["b", 2], ["c", 3]] \ No newline at end of file +"opt3+opt4" = [["a", 1], ["b", 2], ["c", 3]] diff --git a/examples/experiment.jl b/examples/experiment.jl index 076159e..8123f3a 100644 --- a/examples/experiment.jl +++ b/examples/experiment.jl @@ -17,7 +17,7 @@ import Reproduce: - `opt1::Int`: The first argument. Experiment errors on `opt1 == 2` - `opt2::Int`: The second argument. - `opt2::String`: The Third argument. This is required to be a string - - `opt4::Int`: The fourth argument. + - `opt4::Int`: The fourth argument. """ opt1 => 1 opt2 => 2 @@ -41,10 +41,13 @@ function main_experiment(config::Dict, extra_arg = nothing; progress=false, test throw("Oh No!!!") end - Dict("mean"=>0.1, + Dict( + "mean"=>0.1, "vec"=>rand(100), "mat"=>rand(10, 10), - "vec_vec"=>[rand(10) for _ in 1:10]) + "3darr"=>reshape(collect(1:27), 3, 3, 3), + "vec_vec"=>[rand(10) for _ in 1:10], + ) end end diff --git a/src/save/sql_manager.jl b/src/save/sql_manager.jl index b505786..682fa4c 100644 --- a/src/save/sql_manager.jl +++ b/src/save/sql_manager.jl @@ -51,7 +51,7 @@ end Create the tables to store the results. """ function create_results_tables(dbm::DBManager, results) - + tbl_name = get_results_table_name() if table_exists(dbm, tbl_name) return @@ -64,7 +64,7 @@ function create_results_tables(dbm::DBManager, results) # add Hash push!(names, HASH_KEY) push!(types, get_hash_type()) - + for k in keys(results) if results[k] isa AbstractArray # Do crazy things... @@ -78,22 +78,22 @@ function create_results_tables(dbm::DBManager, results) # push!(names, string(k)) # push!(types, "BOOLEAN NOT NULL DEFAULT 0") - + # create_results_subtable(dbm, k, eltype(results[k])) # elseif results[k] isa DataType && results[k] <: AbstractVector # push!(names, string(k)) # push!(types, "BOOLEAN NOT NULL DEFAULT 0") - + # create_results_subtable(dbm, k, results[k].parameters[1]) - + else # add to types nms, dtys = get_sql_schema(string(k), results[k]) append!(names, nms isa String ? [nms] : nms) append!(types, dtys isa String ? [dtys] : dtys) - + end end @@ -106,7 +106,7 @@ function create_results_subtable(dbm::DBManager, key, elt) if table_exists(dbm, tbl_name) return end - + create_array_table(dbm, tbl_name, elt) end @@ -135,12 +135,14 @@ function get_array_table_sql_statement(tbl_name, data::AbstractArray) sql *= "step_$(i) INT UNSIGNED, " end sql *= "INDEX (_HASH));" + @show sql + return sql end function create_array_table(dbm::DBManager, tbl_name, data) sql = get_array_table_sql_statement(tbl_name, data) - + try close!(execute(dbm, sql)) catch err @@ -163,8 +165,8 @@ end function save_params(dbm::DBManager, params; filter_keys = String[], use_git_info = true) # returns hash p_names, p_values = get_sql_names_values(params) - - + + # hash key pms_hash = hash_params(params; filter_keys=filter_keys) push!(p_names, HASH_KEY) @@ -196,7 +198,7 @@ function save_results(dbm::DBManager, pms_hash, results) # save to sub table save_sub_results(dbm, pms_hash, k, results[k]) - + names *= "$(k)" values *= "true" @@ -230,7 +232,7 @@ function save_results(dbm::DBManager, pms_hash, results) end append_row(dbm, get_results_table_name(), names, values) - + end @@ -255,6 +257,27 @@ function save_sub_results(dbm::DBManager, pms_hash, key, results::AbstractMatrix end +function save_sub_results(dbm::DBManager, pms_hash, key, results::AbstractArray) + tbl_name = get_results_subtable_name(key) + + indices = map(x -> range(1, x), size(results)) + + for inds in Iterators.product(indices...) + v = results[inds...] + + names = "(" * HASH_KEY * ", " + values = "($(pms_hash), " + for i in 1:length(inds) + names *= "step_$i, " + values *= "$(inds[i]), " + end + names *= "data)" + values *= "$(v))" + + append_row(dbm, tbl_name, names, values) + end +end + function save_sub_results(dbm::DBManager, pms_hash, key, results::AbstractVector{V}) where {V<:AbstractVector} tbl_name = get_results_subtable_name(key) for (i, vec) in enumerate(results)