Skip to content

Commit

Permalink
update estimators keys in output NamedTuple
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Jul 26, 2024
1 parent 1fe317a commit 7574a20
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/cli.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ function cli_settings()
"--estimator-key"
arg_type = String
help = "Estimator to use to proceed with sieve variance correction."
default = "TMLE"
default = "1"
end

@add_arg_table! s["merge"] begin
Expand Down
22 changes: 7 additions & 15 deletions src/models/registry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,32 +144,24 @@ function estimators_from_string(;config_string="wtmle-ose", treatment_variables=
## There is an estimator specification, a Q specification and a G specification
q_string = components[2]
g_string = components[3]
Q_continuous = model_from_string(string(q_string, "_REGRESSOR"), treatment_variables)
Q_binary = model_from_string(string(q_string, "_CLASSIFIER"), treatment_variables)
G = model_from_string(string(g_string, "_CLASSIFIER"), treatment_variables; interactions=false)
elseif length(components) == 2
## There is an estimator specification and a single model specification
model_string = components[2]
Q_continuous = model_from_string(string(model_string, "_REGRESSOR"), treatment_variables)
Q_binary = model_from_string(string(model_string, "_CLASSIFIER"), treatment_variables)
G = model_from_string(string(model_string, "_CLASSIFIER"), treatment_variables; interactions=false)
q_string = g_string = components[2]
else
## There is an estimator specification
Q_continuous = model_from_string("GLMNET_REGRESSOR", treatment_variables)
Q_binary = model_from_string("GLMNET_CLASSIFIER", treatment_variables)
G = model_from_string("GLMNET_CLASSIFIER", treatment_variables; interactions=false)
q_string = g_string = "GLMNET"
end
models = TMLE.default_models(
## For the estimation of E[Y|W, T]: continuous outcome
Q_continuous = Q_continuous,
Q_continuous = model_from_string(string(q_string, "_REGRESSOR"), treatment_variables),
## For the estimation of E[Y|W, T]: binary outcome
Q_binary = Q_binary,
Q_binary = model_from_string(string(q_string, "_CLASSIFIER"), treatment_variables),
## For the estimation of p(T| W)
G = G,
G = model_from_string(string(g_string, "_CLASSIFIER"), treatment_variables; interactions=false),
)
# Create Estimators
resampling = RESAMPLING(treatment_variables)
estimators_strings = split(components[1], "-")
estimators = [estimator_from_string(estimator_string, models, resampling) for estimator_string in estimators_strings]
return NamedTuple{Tuple(Symbol.(estimators_strings))}(estimators)
estimator_names = Tuple(Symbol(estimator_string, :_, q_string, :_, g_string) for estimator_string in estimators_strings)
return NamedTuple{estimator_names}(estimators)
end
17 changes: 9 additions & 8 deletions src/sieve_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function update_work_lists_with!(result, sample_ids, batch_results, grm_ids, res
end
end

function build_work_list(prefix, grm_ids; estimator_key=:TMLE)
function build_work_list(prefix, grm_ids; estimator_key=1)
dirname_, prefix_ = splitdir(prefix)
dirname__ = dirname_ == "" ? "." : dirname_
hdf5files = filter(
Expand Down Expand Up @@ -189,16 +189,16 @@ end
corrected_stderrors(variances) =
sqrt.(view(maximum(variances, dims=1), 1, :))

with_updated_std(estimate::T, std) where T = T(
with_updated_std(estimate::T, std) where T <: TMLE.Estimate = T(
estimate.estimand,
estimate.estimate,
convert(Float64, std),
estimate.n,
Float64[]
)

with_updated_std(results, stds, estimator_key) =
[NamedTuple{(estimator_key,)}([with_updated_std(result, std)]) for (result, std) in zip(results, stds)]
with_updated_std(results::AbstractVector, stds) =
[NamedTuple{(:SVP,)}([with_updated_std(result, std)]) for (result, std) in zip(results, stds)]


"""
Expand All @@ -208,7 +208,7 @@ with_updated_std(results, stds, estimator_key) =
verbosity=0,
n_estimators=10,
max_tau=0.8,
estimator_key="TMLE"
estimator_key=1
)
Sieve Variance Plateau CLI.
Expand All @@ -232,9 +232,10 @@ function sieve_variance_plateau(input_prefix::String;
verbosity::Int=0,
n_estimators::Int=10,
max_tau::Float64=0.8,
estimator_key::String="TMLE"
estimator_key::String="1"
)
estimator_key = Symbol(estimator_key)
estimator_key_ = tryparse(Int, estimator_key)
estimator_key = estimator_key_ === nothing ? Symbol(estimator_key) : estimator_key_
τs = default_τs(n_estimators;max_τ=max_tau)
grm, grm_ids = readGRM(grm_prefix)
verbosity > 0 && @info "Preparing work list."
Expand All @@ -244,7 +245,7 @@ function sieve_variance_plateau(input_prefix::String;
verbosity > 0 && @info "Computing variance estimates."
variances = compute_variances(influence_curves, grm, τs, n_obs)
std_errors = corrected_stderrors(variances)
results = with_updated_std(results, std_errors, estimator_key)
results = with_updated_std(results, std_errors)
else
variances = Float32[]
end
Expand Down
26 changes: 13 additions & 13 deletions test/models/registry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ using MLJLinearModels
# Default configuration results in GLMNets with interactions of order 2
estimators = TargetedEstimation.estimators_from_string(config_string="wtmle-ose", treatment_variables=Set([:T1, :T2]))
## Check estimators
@test estimators.WTMLE isa TMLEE
@test estimators.WTMLE.weighted === true
@test estimators.WTMLE.resampling === nothing
@test estimators.WTMLE_GLMNET_GLMNET isa TMLEE
@test estimators.WTMLE_GLMNET_GLMNET.weighted === true
@test estimators.WTMLE_GLMNET_GLMNET.resampling === nothing

@test estimators.OSE isa OSE
@test estimators.OSE.resampling === nothing
@test estimators.OSE_GLMNET_GLMNET isa OSE
@test estimators.OSE_GLMNET_GLMNET.resampling === nothing
## Check models
expected_resampling = JointStratifiedCV(
patterns = Regex[r"^T2$", r"^T1$"],
Expand Down Expand Up @@ -47,12 +47,12 @@ end
resampling=StratifiedCV(nfolds=3)
)
## Check estimators
@test estimators.CVTMLE isa TMLEE
@test estimators.CVTMLE.weighted === false
@test estimators.CVTMLE.resampling == expected_resampling
@test estimators.CVTMLE_TUNEDXGBOOST_TUNEDXGBOOST isa TMLEE
@test estimators.CVTMLE_TUNEDXGBOOST_TUNEDXGBOOST.weighted === false
@test estimators.CVTMLE_TUNEDXGBOOST_TUNEDXGBOOST.resampling == expected_resampling

@test estimators.CVOSE isa OSE
@test estimators.CVOSE.resampling == expected_resampling
@test estimators.CVOSE_TUNEDXGBOOST_TUNEDXGBOOST isa OSE
@test estimators.CVOSE_TUNEDXGBOOST_TUNEDXGBOOST.resampling == expected_resampling
## Check models
for estimator in estimators
@test estimator.models.Q_binary_default.probabilistic_tuned_model.model isa XGBoostClassifier
Expand All @@ -65,9 +65,9 @@ end
# 2 model is provided for nuisance functions
estimators = TargetedEstimation.estimators_from_string(config_string="tmle--sl--glm", treatment_variables=["Coco"])
## Check estimators
@test estimators.TMLE isa TMLEE
@test estimators.TMLE.weighted === false
@test estimators.TMLE.resampling === nothing
@test estimators.TMLE_SL_GLM isa TMLEE
@test estimators.TMLE_SL_GLM.weighted === false
@test estimators.TMLE_SL_GLM.resampling === nothing
## Check models
expected_resampling = JointStratifiedCV(
patterns = Regex[r"^Coco$"],
Expand Down
3 changes: 1 addition & 2 deletions test/runner.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ end
results_from_json = TMLE.read_json(output, use_mmap=false)
n_IC_empties = 0
for result in results_from_json
if result[:OSE].IC != []
if result[:OSE_GLM_GLM].IC != []
n_IC_empties += 1
end
end
Expand Down Expand Up @@ -284,7 +284,6 @@ end
end
end


end;

true
20 changes: 10 additions & 10 deletions test/sieve_variance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,12 +307,12 @@ end
src_results = [tmleout1..., tmleout2...]

for svp_result in svp_results
src_result_index = findall(x.TMLE.estimand == svp_result.TMLE.estimand for x in src_results)
src_result_index = findall(x.TMLE.estimand == svp_result.SVP.estimand for x in src_results)
src_result = src_results[only(src_result_index)]
@test src_result.TMLE.std != svp_result.TMLE.std
@test src_result.TMLE.estimate == svp_result.TMLE.estimate
@test src_result.TMLE.n == svp_result.TMLE.n
@test svp_result.TMLE.IC == []
@test src_result.TMLE.std != svp_result.SVP.std
@test src_result.TMLE.estimate == svp_result.SVP.estimate
@test src_result.TMLE.n == svp_result.SVP.n
@test svp_result.SVP.IC == []
end

close(io)
Expand Down Expand Up @@ -350,14 +350,14 @@ end
svp_results = io["results"]
standalone_estimates = svp_results[1:2]
from_composite = svp_results[3:4]
@test standalone_estimates[1].OSE.estimand == from_composite[1].OSE.estimand
@test standalone_estimates[2].OSE.estimand == from_composite[2].OSE.estimand
@test standalone_estimates[1].SVP.estimand == from_composite[1].SVP.estimand
@test standalone_estimates[2].SVP.estimand == from_composite[2].SVP.estimand

# Check std has been updated
for i in 1:2
@test standalone_estimates[i].OSE.estimand == src_results[i].OSE.estimand
@test standalone_estimates[i].OSE.estimate == src_results[i].OSE.estimate
@test standalone_estimates[i].OSE.std != src_results[i].OSE.std
@test standalone_estimates[i].SVP.estimand == src_results[i].OSE.estimand
@test standalone_estimates[i].SVP.estimate == src_results[i].OSE.estimate
@test standalone_estimates[i].SVP.std != src_results[i].OSE.std
end

close(io)
Expand Down
2 changes: 1 addition & 1 deletion test/summary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ include(joinpath(TESTDIR, "testutils.jl"))

@test length(inputs) == 9
for (input, jls_output, hdf5_out, json_output) in zip(inputs, jls_outputs, hdf5_outputs, json_outputs)
@test input.WTMLE.estimand == jls_output.WTMLE.estimand == hdf5_out.WTMLE.estimand == json_output[:WTMLE].estimand
@test input.WTMLE_GLMNET_GLMNET.estimand == jls_output.WTMLE_GLMNET_GLMNET.estimand == hdf5_out.WTMLE_GLMNET_GLMNET.estimand == json_output[:WTMLE_GLMNET_GLMNET].estimand
end
end

Expand Down

0 comments on commit 7574a20

Please sign in to comment.