From 651ca04941903eb9ace02e038b8de11a918f5b6d Mon Sep 17 00:00:00 2001 From: pat-alt Date: Wed, 12 Jun 2024 16:44:17 +0200 Subject: [PATCH 1/2] improved docstrings --- src/model.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 28648ec..bec4698 100644 --- a/src/model.jl +++ b/src/model.jl @@ -38,16 +38,25 @@ function class_loss(jem::JointEnergyModel, x, y; loss_fun = logitcrossentropy, a end """ - get_samples(jem::JointEnergyModel, x) + get_samples(jem::JointEnergyModel, x)::Tuple{AbstractArray,AbstractArray} Gets samples from the sampler buffer. + +# Arguments + +- `jem::JointEnergyModel`: The joint energy model. +- `x`: The input data. """ function get_samples(jem::JointEnergyModel, x) + # Determine the size of the batch: + # Either the size of the input data (training batch size) or the total size of the buffer, whichever is smaller. size_sample = minimum([size(x)[end], size(jem.sampler.buffer, ndims(jem.sampler.buffer))]) + # If the input batch is larger than the buffer, we need to sample a subset of the input data. 
if size_sample < size(x)[end] x = selectdim(x, ndims(x), rand(1:size(x)[end], size_sample)) end + # Get the `size_sample` samples from the buffer that were last added: xsample = selectdim(jem.sampler.buffer, ndims(jem.sampler.buffer), 1:size_sample) @assert size(xsample) == size(x) return x, xsample From 13017f9ba4b946b40c5f2e98866c1098d2442633 Mon Sep 17 00:00:00 2001 From: pat-alt Date: Mon, 9 Sep 2024 10:24:14 +0200 Subject: [PATCH 2/2] this was still open locally --- Project.toml | 2 +- src/model.jl | 114 +++++++++++++++++++++++++++++++++++++++++++++--- src/samplers.jl | 12 ++--- 3 files changed, 115 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index ab07394..fa42ce8 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ ChainRulesCore = "1.16" ComputationalResources = "0.3" Distributions = "0.25" Flux = "0.13, 0.14" -MLJFlux = "0.2, 0.3, 0.4.0" +MLJFlux = "0.2, 0.3, 0.4, 0.5" MLJModelInterface = "1.8" MLUtils = "0.4" ProgressMeter = "1.7" diff --git a/src/model.jl b/src/model.jl index bec4698..e5d7091 100644 --- a/src/model.jl +++ b/src/model.jl @@ -3,6 +3,7 @@ using Flux using Flux.Losses: logitcrossentropy using TaijaBase.Samplers: ImproperSGLD +"A joint energy model, wrapping a classifier `Chain` and a `Sampler` used to draw samples from the model's energy function." struct JointEnergyModel chain::Chain sampler::AbstractSampler @@ -10,6 +11,27 @@ struct JointEnergyModel sampling_steps::Int end +""" + JointEnergyModel( + chain::Union{Chain,Nothing}, + sampler::AbstractSampler; + sampling_rule = ImproperSGLD(), + sampling_steps = sampling_rule isa ImproperSGLD ? 10 : 1000, + ) + +Constructs a `JointEnergyModel` object. The `JointEnergyModel` object is a wrapper around a `Chain` model and a `Sampler` object. The `Sampler` object is used to generate samples from the model's energy function. The `sampling_rule` and `sampling_steps` parameters are used to specify the sampling rule and the number of sampling steps, respectively. + +# Arguments + +- `chain::Union{Chain,Nothing}`: The `Chain` model. 
+- `sampler::AbstractSampler`: The `Sampler` object. +- `sampling_rule::AbstractSamplingRule`: The sampling rule to use. Default is `ImproperSGLD()`. +- `sampling_steps::Int`: The number of sampling steps. + +# Returns + +- `jem::JointEnergyModel`: The `JointEnergyModel` object. +""" function JointEnergyModel( chain::Union{Chain,Nothing}, sampler::AbstractSampler; @@ -21,6 +43,11 @@ end Flux.@functor JointEnergyModel +""" + (jem::JointEnergyModel)(x) + +Computes the output of the joint energy model. +""" function (jem::JointEnergyModel)(x) jem.chain(x) end @@ -28,7 +55,19 @@ end @doc raw""" class_loss(jem::JointEnergyModel, x, y) -Computes the classification loss. +Computes the classification loss. The (default) classification loss is the cross-entropy loss between the predicted and target labels. The loss is aggregated using the `agg` function. + +# Arguments + +- `jem::JointEnergyModel`: The joint energy model. +- `x`: The input data. +- `y`: The target data. +- `loss_fun`: The loss function to use. +- `agg`: The aggregation function to use for the loss. + +# Returns + +- `ℓ`: The classification loss. """ function class_loss(jem::JointEnergyModel, x, y; loss_fun = logitcrossentropy, agg = mean) ŷ = jem(x) @@ -40,12 +79,17 @@ end """ get_samples(jem::JointEnergyModel, x)::Tuple{AbstractArray,AbstractArray} -Gets samples from the sampler buffer. +Gets samples from the sampler buffer. The number of samples is determined by the size of the input data `x` and the buffer. If the batch of input data is larger than the buffer, a subset of the input data is sampled. # Arguments - `jem::JointEnergyModel`: The joint energy model. - `x`: The input data. + +# Returns + +- `x`: The input data. +- `xsample`: The samples from the buffer. """ function get_samples(jem::JointEnergyModel, x) # Determine the size of the batch: @@ -65,7 +109,16 @@ end @doc raw""" gen_loss(jem::JointEnergyModel, x) -Computes the generative loss. +Computes the generative loss. 
The generative loss is the difference between the energy of the input data and the energy of the generated samples from the replay buffer. + +# Arguments + +- `jem::JointEnergyModel`: The joint energy model. +- `x`: The input data. + +# Returns + +- `ℓ`: The generative loss, which is the difference between the energy of the input data and the energy of the generated samples from the replay buffer. """ function gen_loss(jem::JointEnergyModel, x, y) x, xsample = get_samples(jem, x) @@ -77,7 +130,17 @@ end @doc raw""" reg_loss(jem::JointEnergyModel, x) -Computes the regularization loss. +Computes the regularization loss. The regularization loss is the sum of the squared energies of the input data and the generated samples from the replay buffer. This loss is used to prevent the model from overfitting with respect to the generative loss. + +# Arguments + +- `jem::JointEnergyModel`: The joint energy model. +- `x`: The input data. +- `y`: The target data. + +# Returns + +- `ℓ`: The regularization loss, which is the sum of the squared energies of the input data and the generated samples from the replay buffer. """ function reg_loss(jem::JointEnergyModel, x, y) x, xsample = get_samples(jem, x) @@ -89,7 +152,23 @@ end @doc raw""" loss(jem::JointEnergyModel, x, y; agg=mean) -Computes the total loss. +Computes the total loss. The total loss is a weighted sum of the classification, generative, and regularization losses. The weights are determined by the `α` parameter. + +# Arguments + +- `jem::JointEnergyModel`: The joint energy model. +- `x`: The input data. +- `y`: The target data. +- `agg`: The aggregation function to use for the loss. +- `α`: The weights for the classification, generative, and regularization losses. +- `use_class_loss`: Whether to use the classification loss. +- `use_gen_loss`: Whether to use the generative loss. +- `use_reg_loss`: Whether to use the regularization loss. +- `class_loss_fun`: The classification loss function to use. 
+ +# Returns + +- `loss`: The total loss. """ function loss( jem::JointEnergyModel, @@ -123,7 +202,17 @@ end """ generate_samples(jem::JointEnergyModel, n::Int; kwargs...) -A convenience function for generating samples for a given energy model. If `n` is `missing`, then the sampler's `batch_size` is used. +A convenience function for generating samples for a given energy model. If `n` is `missing`, then the sampler's `batch_size` is used. The `kwargs` are passed to the sampler when it is called. + +# Arguments + +- `jem::JointEnergyModel`: The joint energy model. +- `n::Int`: The number of samples to generate. +- `kwargs`: Additional keyword arguments to pass to the sampler when it is called. + +# Returns + +- `samples`: The generated samples. """ function generate_samples(jem::JointEnergyModel, n::Int; kwargs...) n = ismissing(n) ? nothing : n @@ -135,9 +224,20 @@ function generate_samples(jem::JointEnergyModel, n::Int; kwargs...) end """ - generate_conditional_samples(model, rule::Flux.Optimise.AbstractOptimiser, n::Int, y::Int; kwargs...) + generate_conditional_samples(jem::JointEnergyModel, n::Int, y::Int; kwargs...) A convenience function for generating conditional samples for a given model, sampler and sampling rule. If `n` is `missing`, then the sampler's `batch_size` is used. The conditioning value `y` needs to be specified. + +# Arguments + +- `jem::JointEnergyModel`: The joint energy model. +- `n::Int`: The number of samples to generate. +- `y::Int`: The conditioning value. +- `kwargs`: Additional keyword arguments to pass to the sampler when it is called. + +# Returns + +- `samples`: The generated samples. """ function generate_conditional_samples(jem::JointEnergyModel, n::Int, y::Int; kwargs...) 
@assert typeof(jem.sampler) <: ConditionalSampler "sampler must be a ConditionalSampler" diff --git a/src/samplers.jl b/src/samplers.jl index e48a06b..64488dc 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -2,13 +2,15 @@ using CategoricalArrays using Distributions """ - ConditionalSampler( - X::AbstractArray, y::AbstractArray; - batch_size::Int, - max_len::Int=10000, prob_buffer::AbstractFloat=0.95 + TaijaBase.Samplers.ConditionalSampler( + X::Union{Tables.MatrixTable,AbstractMatrix}, + y::Union{CategoricalArray,AbstractMatrix}; + batch_size::Int = 1, + max_len::Int = 10000, + prob_buffer::AbstractFloat = 0.95, ) -Outer constructor for `ConditionalSampler`. +Overloads the `ConditionalSampler` constructor to preprocess the input data and return a `ConditionalSampler` object. """ function TaijaBase.Samplers.ConditionalSampler( X::Union{Tables.MatrixTable,AbstractMatrix},