Skip to content

Commit

Permalink
up TMLE dep
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Aug 2, 2024
1 parent 7574a20 commit 58156a2
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 27 deletions.
16 changes: 11 additions & 5 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.10.3"
julia_version = "1.10.4"
manifest_format = "2.0"
project_hash = "b62f3762a6535538ed6b74bdb53c6ce6255c32d6"

Expand Down Expand Up @@ -156,6 +156,12 @@ git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be"
uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
version = "0.1.0"

[[deps.AutoHashEquals]]
deps = ["Pkg"]
git-tree-sha1 = "daaeb6f7f77b88c072a83a2451801818acb5c63b"
uuid = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
version = "2.1.0"

[[deps.BSON]]
git-tree-sha1 = "4c3e506685c527ac6a54ccc0c8c76fd6f91b42fb"
uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Expand Down Expand Up @@ -1633,12 +1639,12 @@ uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c"
version = "7.2.1+1"

[[deps.TMLE]]
deps = ["AbstractDifferentiation", "CategoricalArrays", "Combinatorics", "Distributions", "GLM", "Graphs", "HypothesisTests", "LogExpFunctions", "MLJBase", "MLJGLMInterface", "MLJModels", "MetaGraphsNext", "Missings", "PrecompileTools", "Random", "SplitApplyCombine", "Statistics", "TableOperations", "Tables", "Zygote"]
git-tree-sha1 = "7e2a58b276c5468de5ac5194b41d880657919eea"
repo-rev = "agnostic_composed"
deps = ["AbstractDifferentiation", "AutoHashEquals", "CategoricalArrays", "Combinatorics", "Distributions", "GLM", "Graphs", "HypothesisTests", "LogExpFunctions", "MLJBase", "MLJGLMInterface", "MLJModels", "MetaGraphsNext", "Missings", "OrderedCollections", "PrecompileTools", "Random", "SplitApplyCombine", "Statistics", "TableOperations", "Tables", "Zygote"]
git-tree-sha1 = "4ae1fe54bc361c2073cbb76f9f9a07c294655eed"
repo-rev = "treatment_values"
repo-url = "https://github.com/TARGENE/TMLE.jl"
uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf"
version = "0.16.1"
version = "0.17.0"
weakdeps = ["JSON", "YAML"]

[deps.TMLE.extensions]
Expand Down
23 changes: 12 additions & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
#####  Read TMLE Estimands Configuration ####
#####################################################################

function convert_treatment_values(treatment_levels::NamedTuple{names, <:Tuple{Vararg{NamedTuple}}}, treatment_types) where names
return [(
case = convert(treatment_types[tn], treatment_levels[tn].case),
control = convert(treatment_types[tn], treatment_levels[tn].control)
)
for tn in names]
convert_estimand_treatment_values(Ψ, treatment_types) = Dict(
T => convert_treatment_values(val, treatment_types[T]) for (T, val) Ψ.treatment_values
)

function convert_treatment_values(treatment_levels::NamedTuple{(:control, :case)}, treatment_type)
return (
control = convert(treatment_type, treatment_levels.control),
case = convert(treatment_type, treatment_levels.case)
)
end

convert_treatment_values(treatment_levels::NamedTuple{names,}, treatment_types) where names =
[convert(treatment_types[tn], treatment_levels[tn]) for tn in names]
convert_treatment_values(treatment_level, treatment_type) = convert(treatment_type, treatment_level)

MissingSCMError() = ArgumentError(string("A Structural Causal Model should be provided in the configuration file in order to identify causal estimands."))

Expand All @@ -38,6 +40,7 @@ end
wrapped_type(x) = x
wrapped_type(x::Type{<:CategoricalValue{T,}}) where T = T
wrapped_type(x::Type{Union{Missing, T}}) where T = wrapped_type(T)

"""
Uses the values found in the dataset to create a new estimand with adjusted values.
"""
Expand All @@ -46,9 +49,7 @@ function fix_treatment_values!(treatment_types::AbstractDict, Ψ, dataset)
for tn in treatment_names
haskey(treatment_types, tn) ? nothing : treatment_types[tn] = wrapped_type(eltype(dataset[!, tn]))
end
new_treatment = NamedTuple{treatment_names}(
convert_treatment_values.treatment_values, treatment_types)
)
new_treatment = convert_estimand_treatment_values(Ψ, treatment_types)
return typeof(Ψ)(
outcome = Ψ.outcome,
treatment_values = new_treatment,
Expand Down
25 changes: 14 additions & 11 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@ include(joinpath(TESTDIR, "testutils.jl"))

@testset "Test convert_treatment_values" begin
treatment_types = Dict(:T₁=> Union{Missing, Bool}, :T₂=> Int)
newT = TargetedEstimation.convert_treatment_values((T₁=1,), treatment_types)
@test newT isa Vector{Bool}
@test newT == [1]

newT = TargetedEstimation.convert_treatment_values((T₁=(case=1, control=0.),), treatment_types)
@test newT isa Vector{NamedTuple{(:case, :control), Tuple{Bool, Bool}}}
@test newT == [(case = true, control = false)]
Ψ = CM(;outcome = :Y, treatment_values=Dict(:T₁=>1, :T₂=>false))
newT = TargetedEstimation.convert_estimand_treatment_values(Ψ, treatment_types)
@test newT[:T₁] === true !== 1
@test newT[:T₂] === 0 !== false

newT = TargetedEstimation.convert_treatment_values((T₁=(case=1, control=0.), T₂=(case=true, control=0)), treatment_types)
@test newT isa Vector{NamedTuple{(:case, :control)}}
@test newT == [(case = true, control = false), (case = 1, control = 0)]
Ψ = ATE(;outcome = :Y, treatment_values=Dict(:T₁ => (case=1, control=0.)), )
newT = TargetedEstimation.convert_estimand_treatment_values(Ψ, treatment_types)
@test newT[:T₁] === (control=false, case=true) !== (control=0, case=1)

Ψ = IATE(;outcome = :Y, treatment_values=Dict(:T₁ => (case=1, control=0.), :T₂ => (case=true, control=0)), )
newT = TargetedEstimation.convert_estimand_treatment_values(Ψ, treatment_types)
@test newT[:T₁] === (control=false, case=true) !== (control=0, case=1)
@test newT[:T₂] === (control=0, case=1) !== (control=false, case=true)
end

@testset "Test treatments_from_estimands" begin
Expand Down Expand Up @@ -66,10 +69,10 @@ end
estimands = TargetedEstimation.proofread_estimands(config, dataset)
for estimand in estimands
if haskey(estimand.treatment_values, :T1)
check_type(estimand.treatment_values.T1, Float64)
check_type(estimand.treatment_values[:T1], Float64)
end
if haskey(estimand.treatment_values, :T2)
check_type(estimand.treatment_values.T2, Bool)
check_type(estimand.treatment_values[:T2], Bool)
end
end
# Clean estimands file
Expand Down

0 comments on commit 58156a2

Please sign in to comment.