try fix the sampling issue related to missing levels
olivierlabayle committed Aug 29, 2024
1 parent 6eb8173 commit 4d649f2
Showing 7 changed files with 137 additions and 57 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -0,0 +1 @@
{}
4 changes: 1 addition & 3 deletions src/samplers/density_estimate_sampler.jl
@@ -99,9 +99,7 @@ function sample_roots_and_treatments(sampler::DensityEstimateSampler, origin_dat
)
# Sample Roots
sampled_dataset = sample_from(origin_dataset, sampler.roots;
n=n,
min_occurences=min_occurences,
variables_to_check=[],
n=n,
max_attempts=max_attempts,
verbosity=verbosity
)
30 changes: 20 additions & 10 deletions src/samplers/null_sampler.jl
@@ -1,9 +1,9 @@
"""
The Permutation-Null-Sampler keeps the marginal distributions of each variable in the original dataset
The NullSampler keeps the marginal distributions of each variable in the original dataset
intact while disrupting the causal relationships between them. This is done by:
1. Sampling from (W, C)
2. Permuting each T
3. Permuting Y
2. Sampling from each T independently
3. Sampling from each Y independently
"""
struct NullSampler
confounders_and_covariates
@@ -29,6 +29,21 @@ function NullSampler(outcome, treatments;
return NullSampler(variables)
end

"""
sample_from(sampler::NullSampler, origin_dataset;
n=100,
min_occurences=10,
max_attempts=1000,
verbosity=1
)
The procedure tries to:
1. Sample jointly from (W, C):
- The levels of sampled factor variables should match the levels in the original data.
2. Sample independently for each T and Y:
- The levels of sampled factor variables should match the levels in the original data.
- The lowest populated sampled level of each factor variable should have at least `min_occurences` occurrences.
Each step is retried for at most `max_attempts` attempts.
"""
function sample_from(sampler::NullSampler, origin_dataset;
n=100,
min_occurences=10,
@@ -37,21 +52,16 @@ function sample_from(sampler::NullSampler, origin_dataset;
)
sampled_dataset = sample_from(origin_dataset, collect(sampler.confounders_and_covariates);
n=n,
min_occurences=min_occurences,
max_attempts=max_attempts,
verbosity=verbosity
)
# Independently sample the rest of variables
variables_to_check = [var for var in sampler.other_variables if isfactor(origin_dataset[!, var])]
for variable in sampler.other_variables
sampled_variable_df = sample_from(origin_dataset, [variable];
sampled_dataset[!, variable] = sample_from(origin_dataset, variable;
n=n,
variables_to_check=variables_to_check,
min_occurences=min_occurences,
max_attempts=max_attempts,
verbosity=verbosity
)
sampled_dataset[!, variable] = sampled_variable_df[!, variable]
)
end
return sampled_dataset
end
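Aside: to make the revised NullSampler scheme concrete, here is a minimal standalone sketch of the idea in plain Julia (not this package's implementation; the column names W, C, T, Y and the toy data are hypothetical). It keeps the joint distribution of the confounders and covariates and the marginal distribution of each treatment and outcome, while breaking any dependence between them.

using DataFrames, Random, StatsBase

rng = MersenneTwister(0)
origin = DataFrame(
    W = rand(rng, ["ww", "Ww", "WW"], 500),  # confounder (multiclass)
    C = randn(rng, 500),                     # continuous covariate
    T = rand(rng, ["t0", "t1"], 500),        # treatment
    Y = rand(rng, ["y0", "y1"], 500),        # outcome
)
n = 200
# 1. Jointly resample (W, C) rows: their dependence structure is preserved.
rows = sample(rng, 1:nrow(origin), n, replace=true)
null_sample = origin[rows, [:W, :C]]
# 2. and 3. Resample each T and Y independently: the marginals are kept,
# but any causal link to (W, C) and to each other is destroyed.
for v in (:T, :Y)
    null_sample[!, v] = sample(rng, collect(skipmissing(origin[!, v])), n, replace=true)
end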
103 changes: 74 additions & 29 deletions src/utils.jl
@@ -79,72 +79,118 @@ confounders_and_covariates(variables) = vcat(collect(variables.confounders), col

countuniques(dataset, colname) = DataFrames.combine(groupby(dataset, colname, skipmissing=true), nrow)

function dataset_is_too_extreme(sampled_dataset, origin_dataset, variables_to_check; min_occurences=10)
for var in variables_to_check
# Check all levels are present in the smapled dataset
sampled_levels = Set(skipmissing(sampled_dataset[!, var]))
origin_levels = Set(skipmissing(origin_dataset[!, var]))
if sampled_levels != origin_levels
return true, string("Missing levels for variable: ", var)
end
# Check all levels occur at least `min_occurences` of times
n_uniques = countuniques(sampled_dataset, var)
if minimum(n_uniques.nrow) < min_occurences
return true, string("Not enough occurences for variable: ", var)
levels_missing(sampled_vector, origin_vector) = Set(skipmissing(sampled_vector)) != Set(skipmissing(origin_vector))

"""
Checks that the multiclass variables have all their levels present in the sampled dataset
"""
function check_sampled_levels(sampled_dataset, origin_dataset, variables_to_check)
for variable in variables_to_check
if levels_missing(sampled_dataset[!, variable], origin_dataset[!, variable])
return true, string(variable)
end
end
return false, ""
end
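Aside: the level check above reduces to a set comparison after dropping missing values. A toy illustration in plain Julia (standalone, not part of utils.jl; the vectors are made up):

original = ["AA", "AC", "CC", missing]
sampled  = ["AA", "AC", "AC", "AA"]                       # the rare "CC" level was never drawn
Set(skipmissing(sampled)) != Set(skipmissing(original))   # true, i.e. a level is missing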

isfactor(col; nlevels=5) = length(levels(col; skipmissing=true)) < nlevels
ismulticlass(col) = autotype(col, :few_to_finite) <: Union{Missing, <:Multiclass}

multiclass_variables(origin_dataset, variables) =
filter(v -> ismulticlass(origin_dataset[!, v]), variables)

"""
sample_from(origin_dataset::DataFrame, variables;
n=100,
min_occurences=10,
max_attempts=1000,
verbosity = 1
)
This method jointly samples with replacement n samples of `variables` from `origin_dataset` after dropping missing values.
It ensures that each level of each sampled factor variable is present at least `min_occurences` of times.
Otherwise a new sampling attempt is made and up to `max_attempts`.
Tries to jointly sample `n` non-missing values of `variables` from `origin_dataset`, for a maximum of `max_attempts` attempts.
Each sampled dataset is checked as follows:
1. The levels of sampled factor variables should match the levels in the original data.
"""
function sample_from(origin_dataset::DataFrame, variables;
n=100,
variables_to_check=[],
min_occurences=10,
variables_to_check=multiclass_variables(origin_dataset, variables),
max_attempts=1000,
verbosity = 1
)
variables = collect(variables)
variables_to_check = intersect(variables, variables_to_check)
nomissing = dropmissing(origin_dataset[!, variables])
too_extreme, msg = Simulations.dataset_is_too_extreme(nomissing, origin_dataset, variables_to_check; min_occurences=min_occurences)
if too_extreme
levels_are_missing, msg = Simulations.check_sampled_levels(nomissing, origin_dataset, variables_to_check)
if levels_are_missing
msg = string(
"Filtering of missing values resulted in a too extreme dataset. In particular: ", msg,
"Filtering of missing values resulted in missing levels for variable ", msg,
".\n Consider lowering or setting the `call_threshold` to `nothing`."
)
throw(ErrorException(msg))
end
# Resample until the dataset is not too extreme
for attempt in 1:max_attempts
sample_rows = StatsBase.sample(1:nrow(nomissing), n, replace=true)
sampled_dataset = nomissing[sample_rows, variables]
too_extreme, msg = dataset_is_too_extreme(sampled_dataset, nomissing, variables_to_check; min_occurences=min_occurences)
if !too_extreme
levels_are_missing, msg = check_sampled_levels(sampled_dataset, nomissing, variables_to_check)
if !levels_are_missing
return sampled_dataset
end
verbosity > 0 && @info(string("Sampled dataset after attempt ", attempt, " was too extreme. In particular: ", msg, ".\n Retrying."))
verbosity > 0 && @info(string("Sampled dataset after attempt ", attempt, " had missing levels. In particular: ", msg, ".\n Retrying."))
end
msg = string(
"Could not sample a dataset which wasn't too extreme after: ", max_attempts,
" attempts. Possible solutions: increase `sample_size`, change your simulation estimands of increase `max_attempts`."
"Could not sample a dataset with all variables' levels in `variables_to_check` after: ", max_attempts,
" attempts. Possible solutions: increase `sample_size`, change your simulation estimands or increase `max_attempts`."
)
throw(ErrorException(msg))
end
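Aside: a short standalone sketch of why the retry loop above is needed (hypothetical numbers; `sample` comes from StatsBase). A level carried by only a few rows is frequently absent from a bootstrap draw, so a single attempt would often end in the error above:

using Random, StatsBase

rng = MersenneTwister(1)
A = vcat(fill("AA", 90), fill("AC", 9), ["CC"])   # "CC" appears in 1 of 100 rows
n, reps = 20, 10_000
miss_rate = count(_ -> "CC" ∉ sample(rng, A, n, replace=true), 1:reps) / reps
# miss_rate ≈ (99/100)^20 ≈ 0.82: most draws of size 20 lose the rare level,
# hence the resampling loop with up to `max_attempts` attempts.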

function sampled_vector_has_enough_occurences(sampled_vector, origin_vector;
min_occurences=10
)
n_uniques = countmap(sampled_vector)
if length(n_uniques) != length(levels(origin_vector))
return false, "missing levels."
end
if minimum(values(n_uniques)) < min_occurences
return false, "not enough occurrences for each level."
end
return true, ""
end

"""
sample_from(origin_dataset::DataFrame, variable::Union{Symbol, AbstractString};
n=100,
min_occurences=10,
max_attempts=1000,
verbosity=1
)
Tries to sample `n` non-missing values of `variable` from `origin_dataset`, for a maximum of `max_attempts` attempts.
Each sampled vector is checked as follows:
1. The levels of sampled factor variables should match the levels in the original data.
2. The lowest populated sampled level of each factor variable should have at least `min_occurences` occurrences.
"""
function sample_from(origin_dataset::DataFrame, variable::Union{Symbol, AbstractString};
n=100,
min_occurences=10,
max_attempts=1000,
verbosity = 1
)
origin_vector = collect(skipmissing(origin_dataset[!, variable]))
for attempt in 1:max_attempts
sampled_vector = StatsBase.sample(origin_vector, n, replace=true)
# If binary or multiclass: check levels and occurrences
if length(levels(sampled_vector)) == 2 || ismulticlass(sampled_vector)
has_enough_occurences, msg = Simulations.sampled_vector_has_enough_occurences(sampled_vector, origin_vector;
min_occurences=min_occurences
)
if !has_enough_occurences
verbosity > 0 && @info(string("The sampled vector for variable ", variable, " had ", msg, "\nRetrying."))
continue
end
end
return sampled_vector
end
throw(ErrorException(string("Could not sample variable ", variable, " because it either did not have enough occurences or some levels were missing after ", max_attempts,
" attempts.\nConsider increasing the sample size or changing your estimands.")))
end
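Aside: a standalone sketch of the occurrence check used by this per-variable sampler (hypothetical data; `countmap` and `sample` come from StatsBase, and `unique` stands in for `levels` on a plain vector):

using StatsBase

origin_vector  = vcat(fill("y0", 95), fill("y1", 5))     # imbalanced binary outcome
sampled_vector = sample(origin_vector, 100, replace=true)
min_occurences = 10
counts = countmap(sampled_vector)                        # e.g. Dict("y0" => 96, "y1" => 4)
ok = length(counts) == length(unique(origin_vector)) &&  # no level lost
     minimum(values(counts)) >= min_occurences           # every level populated enough
# With only ~5% "y1", `ok` is usually false, triggering a resample and, after
# `max_attempts` failures, the "Could not sample variable ..." error seen in the tests.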

variables_from_args(outcome, treatments, confounders, outcome_extra_covariates) = (
outcome = Symbol(outcome),
treatments = Symbol.(Tuple(treatments)),
@@ -159,7 +205,6 @@ transpose_table(X) = Float32.(Tables.matrix(X, transpose=true))
transpose_table(estimator, X) =
transpose_table(MLJBase.transform(estimator.encoder, X))


function get_conditional_densities_variables(estimands)
conditional_densities_variables = Set{Pair}([])
for Ψ in estimands
1 change: 1 addition & 0 deletions test/null_simulation.jl
@@ -45,6 +45,7 @@ include(joinpath(TESTDIR, "testutils.jl"))
rngs = [1, 2]
sample_sizes = [100, 200]
n_bootstraps = 2
(output_index, (estimator, rng, sample_size)) = first(enumerate(Iterators.product(estimators, rngs, sample_sizes)))
for (output_index, (estimator, rng, sample_size)) ∈ enumerate(Iterators.product(estimators, rngs, sample_sizes))
out = joinpath(outdir, string("null_estimation_results_", output_index, ".hdf5"))
copy!(ARGS, [
2 changes: 1 addition & 1 deletion test/samplers/null_sampler.jl
@@ -42,7 +42,7 @@ include(joinpath(TESTDIR, "testutils.jl"))
@test true_effects[estimands[3]] == [0, 0]
# Raise if min_occurences not statisfied
msg = string(
"Could not sample a dataset which wasn't too extreme after: 1 attempts. Possible solutions: increase `sample_size`, change your simulation estimands of increase `max_attempts`."
"Could not sample variable Ybin because it either did not have enough occurences or some levels were missing after 1 attempts.\nConsider increasing the sample size or changing your estimands."
)
@test_throws ErrorException(msg) sample_from(sampler, origin_dataset, n=10, min_occurences = 10, max_attempts=1, verbosity=0)

53 changes: 39 additions & 14 deletions test/utils.jl
@@ -127,27 +127,52 @@ end
C = ["AA", "AC", "CC", "CC", "AA", "AC", "AC"],
D = 1:7
)
# Dropping missing results in CC not present in the dataset
error_msg = string(
"Filtering of missing values resulted in a too extreme dataset. In particular: Missing levels for variable: B.",
"\n Consider lowering or setting the `call_threshold` to `nothing`."
)
@test_throws ErrorException(error_msg) sample_from(origin_dataset, [:A, :B]; variables_to_check=[:A, :B], n=2, min_occurences=0, verbosity=0)
# if min_occurences = 10, A won't have enough occurences and will raise first
error_msg = string(
"Filtering of missing values resulted in a too extreme dataset. In particular: Not enough occurences for variable: A.",
"\n Consider lowering or setting the `call_threshold` to `nothing`."
)
@test_throws ErrorException(error_msg) sample_from(origin_dataset, [:A, :B]; variables_to_check=[:A, :B], n=2, min_occurences=10, verbosity=0)
# This will work
# Test Sample from a DataFrame
## No problem here
variables = [:A, :C, :D]
sampled_dataset = sample_from(origin_dataset, variables; n=4, variables_to_check=[:A, :C], min_occurences=0, verbosity=0)
sampled_dataset = sample_from(origin_dataset, variables; n=4, verbosity=0)
all_rows = collect(eachrow(origin_dataset[!, variables]))
for row in eachrow(sampled_dataset)
@test row ∈ all_rows
end
@test length(unique(sampled_dataset.A)) == 3
@test length(unique(sampled_dataset.C)) == 3
## Dropping missing results in CC not present in the dataset for B
error_msg = string(
"Filtering of missing values resulted in missing levels for variable B.",
"\n Consider lowering or setting the `call_threshold` to `nothing`."
)
@test_throws ErrorException(error_msg) sample_from(origin_dataset, [:A, :B]; n=2, verbosity=0)

# Now sample from a column Vector
## Continuous variables always work
@test sample_from(origin_dataset, :D;
n=10,
min_occurences=10,
max_attempts=10,
verbosity=0
) isa Vector{Int64}
# Reasonable min_occurences
@test sample_from(origin_dataset, "B";
n=100,
min_occurences=1,
max_attempts=10,
verbosity=0
) isa Vector{String}
## Raises because too many required occurences
@test_throws ErrorException sample_from(origin_dataset, :B;
n=100,
min_occurences=100,
max_attempts=1,
verbosity = 1
)
## Raises because not enough levels
@test_throws ErrorException sample_from(origin_dataset, "B";
n=2,
min_occurences=0,
max_attempts=10,
verbosity=1
)
end

@testset "Test coerce_parents_and_outcome!" begin
