Skip to content

Commit

Permalink
Merge pull request #102 from samuelfneumann/main
Browse files Browse the repository at this point in the history
  • Loading branch information
mkschleg authored Jun 28, 2023
2 parents 862e0d4 + e269b50 commit f180060
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/configs/arg_iter_config_sql.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ steps=102902
[sweep_args]
opt2 = [5,6,7,8,9]
opt1 = "1:50"
"opt3+opt4" = [["a", 1], ["b", 2], ["c", 3]]
"opt3+opt4" = [["a", 1], ["b", 2], ["c", 3]]
9 changes: 6 additions & 3 deletions examples/experiment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import Reproduce:
- `opt1::Int`: The first argument. Experiment errors on `opt1 == 2`
- `opt2::Int`: The second argument.
- `opt2::String`: The Third argument. This is required to be a string
- `opt4::Int`: The fourth argument.
- `opt4::Int`: The fourth argument.
"""
opt1 => 1
opt2 => 2
Expand All @@ -41,10 +41,13 @@ function main_experiment(config::Dict, extra_arg = nothing; progress=false, test
throw("Oh No!!!")
end

Dict("mean"=>0.1,
Dict(
"mean"=>0.1,
"vec"=>rand(100),
"mat"=>rand(10, 10),
"vec_vec"=>[rand(10) for _ in 1:10])
"3darr"=>reshape(collect(1:27), 3, 3, 3),
"vec_vec"=>[rand(10) for _ in 1:10],
)
end
end

Expand Down
47 changes: 35 additions & 12 deletions src/save/sql_manager.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ end
Create the tables to store the results.
"""
function create_results_tables(dbm::DBManager, results)

tbl_name = get_results_table_name()
if table_exists(dbm, tbl_name)
return
Expand All @@ -64,7 +64,7 @@ function create_results_tables(dbm::DBManager, results)
# add Hash
push!(names, HASH_KEY)
push!(types, get_hash_type())

for k in keys(results)
if results[k] isa AbstractArray
# Do crazy things...
Expand All @@ -78,22 +78,22 @@ function create_results_tables(dbm::DBManager, results)

# push!(names, string(k))
# push!(types, "BOOLEAN NOT NULL DEFAULT 0")

# create_results_subtable(dbm, k, eltype(results[k]))

# elseif results[k] isa DataType && results[k] <: AbstractVector

# push!(names, string(k))
# push!(types, "BOOLEAN NOT NULL DEFAULT 0")

# create_results_subtable(dbm, k, results[k].parameters[1])

else # add to types

nms, dtys = get_sql_schema(string(k), results[k])
append!(names, nms isa String ? [nms] : nms)
append!(types, dtys isa String ? [dtys] : dtys)

end
end

Expand All @@ -106,7 +106,7 @@ function create_results_subtable(dbm::DBManager, key, elt)
if table_exists(dbm, tbl_name)
return
end

create_array_table(dbm, tbl_name, elt)

end
Expand Down Expand Up @@ -135,12 +135,14 @@ function get_array_table_sql_statement(tbl_name, data::AbstractArray)
sql *= "step_$(i) INT UNSIGNED, "
end
sql *= "INDEX (_HASH));"
@show sql
return sql
end

function create_array_table(dbm::DBManager, tbl_name, data)

sql = get_array_table_sql_statement(tbl_name, data)

try
close!(execute(dbm, sql))
catch err
Expand All @@ -163,8 +165,8 @@ end
function save_params(dbm::DBManager, params; filter_keys = String[], use_git_info = true) # returns hash

p_names, p_values = get_sql_names_values(params)


# hash key
pms_hash = hash_params(params; filter_keys=filter_keys)
push!(p_names, HASH_KEY)
Expand Down Expand Up @@ -196,7 +198,7 @@ function save_results(dbm::DBManager, pms_hash, results)

# save to sub table
save_sub_results(dbm, pms_hash, k, results[k])

names *= "$(k)"
values *= "true"

Expand Down Expand Up @@ -230,7 +232,7 @@ function save_results(dbm::DBManager, pms_hash, results)
end

append_row(dbm, get_results_table_name(), names, values)

end


Expand All @@ -255,6 +257,27 @@ function save_sub_results(dbm::DBManager, pms_hash, key, results::AbstractMatrix

end

function save_sub_results(dbm::DBManager, pms_hash, key, results::AbstractArray)
tbl_name = get_results_subtable_name(key)

indices = map(x -> range(1, x), size(results))

for inds in Iterators.product(indices...)
v = results[inds...]

names = "(" * HASH_KEY * ", "
values = "($(pms_hash), "
for i in 1:length(inds)
names *= "step_$i, "
values *= "$(inds[i]), "
end
names *= "data)"
values *= "$(v))"

append_row(dbm, tbl_name, names, values)
end
end

function save_sub_results(dbm::DBManager, pms_hash, key, results::AbstractVector{V}) where {V<:AbstractVector}
tbl_name = get_results_subtable_name(key)
for (i, vec) in enumerate(results)
Expand Down

0 comments on commit f180060

Please sign in to comment.