Improve training #28

Open · wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion Project.toml
@@ -27,7 +27,7 @@ ChainRulesCore = "1.16"
ComputationalResources = "0.3"
Distributions = "0.25"
Flux = "0.13, 0.14"
MLJFlux = "0.2, 0.3, 0.4.0"
MLJFlux = "0.2, 0.3, 0.4, 0.5"
MLJModelInterface = "1.8"
MLUtils = "0.4"
ProgressMeter = "1.7"
125 changes: 117 additions & 8 deletions src/model.jl
@@ -3,13 +3,35 @@ using Flux
using Flux.Losses: logitcrossentropy
using TaijaBase.Samplers: ImproperSGLD

"Base class for joint energy models."
struct JointEnergyModel
chain::Chain
sampler::AbstractSampler
sampling_rule::AbstractSamplingRule
sampling_steps::Int
end

"""
JointEnergyModel(
chain::Union{Chain,Nothing},
sampler::AbstractSampler;
sampling_rule = ImproperSGLD(),
sampling_steps = sampling_rule isa ImproperSGLD ? 10 : 1000,
)

Constructs a `JointEnergyModel`: a wrapper around a `Chain` model and a `Sampler` that generates samples from the model's energy function. The `sampling_rule` and `sampling_steps` keyword arguments specify the sampling rule and the number of sampling steps, respectively.

# Arguments

- `chain::Union{Chain,Nothing}`: The `Chain` model.
- `sampler::AbstractSampler`: The `Sampler` object.
- `sampling_rule::AbstractSamplingRule`: The sampling rule to use. Default is `ImproperSGLD()`.
- `sampling_steps::Int`: The number of sampling steps. Defaults to `10` for `ImproperSGLD` and to `1000` otherwise.

# Returns

- `jem::JointEnergyModel`: The `JointEnergyModel` object.
"""
function JointEnergyModel(
chain::Union{Chain,Nothing},
sampler::AbstractSampler;
@@ -21,14 +43,31 @@ end

Flux.@functor JointEnergyModel
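To make the new constructor docstring concrete, here is a minimal usage sketch. It assumes the package exporting `JointEnergyModel` is loaded, that `ConditionalSampler` and `ImproperSGLD` are available from `TaijaBase.Samplers`, and that features are stored column-wise; adjust names and data layout to your setup.

```julia
using Flux
using CategoricalArrays: categorical
using TaijaBase.Samplers: ConditionalSampler, ImproperSGLD

# Toy data (assumed layout: features × observations).
X = rand(Float32, 2, 100)
y = categorical(rand(["a", "b"], 100))

chain = Chain(Dense(2 => 32, relu), Dense(32 => 2))
sampler = ConditionalSampler(X, y; batch_size = 32)

jem = JointEnergyModel(chain, sampler; sampling_rule = ImproperSGLD(), sampling_steps = 10)
ŷ = jem(X)    # the forward pass simply delegates to the wrapped chain
```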

"""
(jem::JointEnergyModel)(x)

Computes the output of the joint energy model.
"""
function (jem::JointEnergyModel)(x)
jem.chain(x)
end

@doc raw"""
class_loss(jem::JointEnergyModel, x, y)

Computes the classification loss.
Computes the classification loss. By default, this is the logit cross-entropy between the predicted and target labels; the per-sample losses are aggregated with the `agg` function.

# Arguments

- `jem::JointEnergyModel`: The joint energy model.
- `x`: The input data.
- `y`: The target data.
- `loss_fun`: The loss function to use. Defaults to `logitcrossentropy`.
- `agg`: The aggregation function to use for the loss. Defaults to `mean`.

# Returns

- `ℓ`: The classification loss.
"""
function class_loss(jem::JointEnergyModel, x, y; loss_fun = logitcrossentropy, agg = mean)
ŷ = jem(x)
@@ -38,16 +77,30 @@ function class_loss(jem::JointEnergyModel, x, y; loss_fun = logitcrossentropy, agg = mean)
end
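As a standalone illustration of the default classification term, the sketch below applies `logitcrossentropy` to dummy logits and one-hot targets; the array shapes are made up for the example.

```julia
using Flux
using Statistics: mean

logits  = randn(Float32, 3, 8)                  # 3 classes × 8 samples
targets = Flux.onehotbatch(rand(1:3, 8), 1:3)   # one-hot encoded labels

# Matches the defaults `loss_fun = logitcrossentropy` and `agg = mean`:
ℓ = Flux.logitcrossentropy(logits, targets; agg = mean)
```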

"""
get_samples(jem::JointEnergyModel, x)
get_samples(jem::JointEnergyModel, x)::Tuple{AbstractArray,AbstractArray}

Gets samples from the sampler buffer. The number of samples is determined by the size of the input data `x` and the buffer. If the batch of input data is larger than the buffer, a subset of the input data is sampled.

Gets samples from the sampler buffer.
# Arguments

- `jem::JointEnergyModel`: The joint energy model.
- `x`: The input data.

# Returns

- `x`: The input data, subsampled if the batch is larger than the buffer.
- `xsample`: The most recently added samples from the buffer.
"""
function get_samples(jem::JointEnergyModel, x)
# Determine the size of the batch:
# Either the size of the input data (training batch size) or the total size of the buffer, whichever is smaller.
size_sample =
minimum([size(x)[end], size(jem.sampler.buffer, ndims(jem.sampler.buffer))])
# If the input batch is larger than the buffer, we need to sample a subset of the input data.
if size_sample < size(x)[end]
x = selectdim(x, ndims(x), rand(1:size(x)[end], size_sample))
end
# Get the `size_sample` samples from the buffer that were last added:
xsample = selectdim(jem.sampler.buffer, ndims(jem.sampler.buffer), 1:size_sample)
@assert size(xsample) == size(x)
return x, xsample
@@ -56,7 +109,16 @@ end
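The buffer-matching logic can be reproduced on plain arrays. The sketch below mirrors the steps above with made-up shapes (a batch of 16 inputs against a buffer of 8 samples), so the returned pair ends up with matching sizes.

```julia
x      = rand(Float32, 2, 16)   # input batch: 2 features × 16 observations
buffer = rand(Float32, 2, 8)    # replay buffer holding 8 samples

# Use at most as many samples as the smaller of batch size and buffer size:
size_sample = min(size(x)[end], size(buffer, ndims(buffer)))

# Subsample the inputs if the batch is larger than the buffer:
if size_sample < size(x)[end]
    x = selectdim(x, ndims(x), rand(1:size(x)[end], size_sample))
end

# Take the first `size_sample` entries of the buffer (the most recently added ones):
xsample = selectdim(buffer, ndims(buffer), 1:size_sample)
@assert size(xsample) == size(x)   # both are now 2 × 8
```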
@doc raw"""
gen_loss(jem::JointEnergyModel, x, y)

Computes the generative loss.
Computes the generative loss. The generative loss is the difference between the energy of the input data and the energy of the generated samples from the replay buffer.

# Arguments

- `jem::JointEnergyModel`: The joint energy model.
- `x`: The input data.
- `y`: The target data.

# Returns

- `ℓ`: The generative loss, which is the difference between the energy of the input data and the energy of the generated samples from the replay buffer.
"""
function gen_loss(jem::JointEnergyModel, x, y)
x, xsample = get_samples(jem, x)
@@ -68,7 +130,17 @@ end
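The energy function itself is not part of this hunk. A common choice for joint energy models is the negative log-sum-exp of the classifier logits; the sketch below uses that convention purely as an assumption, to show how the contrastive difference is formed.

```julia
using Flux
using Statistics: mean

# Assumed JEM-style energy: E(x) = -logsumexp over the class logits (stable version).
function energy(f, x)
    z = f(x)
    m = maximum(z; dims = 1)
    return vec(-(log.(sum(exp.(z .- m); dims = 1)) .+ m))
end

f = Chain(Dense(2 => 16, relu), Dense(16 => 2))
x_data   = rand(Float32, 2, 8)   # observed inputs
x_sample = rand(Float32, 2, 8)   # stand-in for samples drawn from the replay buffer

# Generative term: push the energy of real data down relative to buffer samples.
ℓ_gen = mean(energy(f, x_data)) - mean(energy(f, x_sample))
```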
@doc raw"""
reg_loss(jem::JointEnergyModel, x, y)

Computes the regularization loss.
Computes the regularization loss. The regularization loss is the sum of the squared energies of the input data and the generated samples from the replay buffer. This term penalizes large energy magnitudes and helps prevent the model from overfitting to the generative objective.

# Arguments

- `jem::JointEnergyModel`: The joint energy model.
- `x`: The input data.
- `y`: The target data.

# Returns

- `ℓ`: The regularization loss, which is the sum of the squared energies of the input data and the generated samples from the replay buffer.
"""
function reg_loss(jem::JointEnergyModel, x, y)
x, xsample = get_samples(jem, x)
@@ -80,7 +152,23 @@ end
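Continuing with the hypothetical energy function from the previous sketch, the regularization term reduces to a penalty on squared energies. The aggregation with `mean` below is an assumption; the package may sum instead.

```julia
using Statistics: mean

# Illustrative energy values; in the previous sketch these would be
# `energy(f, x_data)` and `energy(f, x_sample)`.
e_data   = randn(Float32, 8)
e_sample = randn(Float32, 8)

# Penalize large energy magnitudes on both the real and the generated batch:
ℓ_reg = mean(e_data .^ 2 .+ e_sample .^ 2)
```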
@doc raw"""
loss(jem::JointEnergyModel, x, y; agg=mean)

Computes the total loss.
Computes the total loss. The total loss is a weighted sum of the classification, generative, and regularization losses. The weights are determined by the `α` parameter.

# Arguments

- `jem::JointEnergyModel`: The joint energy model.
- `x`: The input data.
- `y`: The target data.
- `agg`: The aggregation function to use for the loss.
- `α`: The weights for the classification, generative, and regularization losses.
- `use_class_loss`: Whether to use the classification loss.
- `use_gen_loss`: Whether to use the generative loss.
- `use_reg_loss`: Whether to use the regularization loss.
- `class_loss_fun`: The classification loss function to use.

# Returns

- `loss`: The total loss.
"""
function loss(
jem::JointEnergyModel,
@@ -114,7 +202,17 @@ end
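Schematically, the total objective combines the three terms with the weights in `α` and the boolean toggles. The weights and default values below are illustrative only, not the package's actual defaults.

```julia
# Weighted combination of (classification, generative, regularization) terms.
function total_loss(ℓ_clf, ℓ_gen, ℓ_reg;
                    α = (1.0f0, 1.0f0, 0.1f0),
                    use_class_loss = true, use_gen_loss = true, use_reg_loss = true)
    (use_class_loss ? α[1] * ℓ_clf : 0.0f0) +
    (use_gen_loss   ? α[2] * ℓ_gen : 0.0f0) +
    (use_reg_loss   ? α[3] * ℓ_reg : 0.0f0)
end

total_loss(0.7f0, -0.2f0, 0.1f0)   # returns a single scalar loss value
```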
"""
generate_samples(jem::JointEnergyModel, n::Int; kwargs...)

A convenience function for generating samples for a given energy model. If `n` is `missing`, then the sampler's `batch_size` is used.
A convenience function for generating samples from a given energy model. If `n` is `missing`, the sampler's `batch_size` is used. Any `kwargs` are passed to the sampler when it is called.

# Arguments

- `jem::JointEnergyModel`: The joint energy model.
- `n::Int`: The number of samples to generate.
- `kwargs`: Additional keyword arguments to pass to the sampler when it is called.

# Returns

- `samples`: The generated samples.
"""
function generate_samples(jem::JointEnergyModel, n::Int; kwargs...)
n = ismissing(n) ? nothing : n
@@ -126,9 +224,20 @@ function generate_samples(jem::JointEnergyModel, n::Int; kwargs...)
end
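A hypothetical call, continuing from the construction sketch earlier in this diff; keyword arguments, if any, are forwarded to the sampler.

```julia
# Draw 100 samples from the learned energy function (unconditional).
samples = generate_samples(jem, 100)
```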

"""
generate_conditional_samples(model, rule::Flux.Optimise.AbstractOptimiser, n::Int, y::Int; kwargs...)
generate_conditional_samples(jem::JointEnergyModel, n::Int, y::Int; kwargs...)

A convenience function for generating conditional samples from a given joint energy model. If `n` is `missing`, the sampler's `batch_size` is used. The conditioning value `y` needs to be specified.

# Arguments

- `jem::JointEnergyModel`: The joint energy model.
- `n::Int`: The number of samples to generate.
- `y::Int`: The conditioning value.
- `kwargs`: Additional keyword arguments to pass to the sampler when it is called.

# Returns

- `samples`: The generated samples.
"""
function generate_conditional_samples(jem::JointEnergyModel, n::Int, y::Int; kwargs...)
@assert typeof(jem.sampler) <: ConditionalSampler "sampler must be a ConditionalSampler"
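Again continuing from the earlier construction sketch, a hypothetical conditional call; the class label `1` is illustrative.

```julia
# Requires `jem.sampler` to be a `ConditionalSampler`:
samples_class1 = generate_conditional_samples(jem, 100, 1)
```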
12 changes: 7 additions & 5 deletions src/samplers.jl
@@ -2,13 +2,15 @@ using CategoricalArrays
using Distributions

"""
ConditionalSampler(
X::AbstractArray, y::AbstractArray;
batch_size::Int,
max_len::Int=10000, prob_buffer::AbstractFloat=0.95
TaijaBase.Samplers.ConditionalSampler(
X::Union{Tables.MatrixTable,AbstractMatrix},
y::Union{CategoricalArray,AbstractMatrix};
batch_size::Int = 1,
max_len::Int = 10000,
prob_buffer::AbstractFloat = 0.95,
)

Outer constructor for `ConditionalSampler`.
Overloads the `ConditionalSampler` constructor to preprocess the input data and return a `ConditionalSampler` object.
"""
function TaijaBase.Samplers.ConditionalSampler(
X::Union{Tables.MatrixTable,AbstractMatrix},
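A minimal construction sketch for the overloaded constructor, assuming a plain feature matrix with observations in columns and a `CategoricalArray` target; the keyword values shown are illustrative (the signature above lists the defaults).

```julia
using CategoricalArrays: categorical
using TaijaBase.Samplers: ConditionalSampler

X = rand(Float32, 2, 100)               # assumed layout: 2 features × 100 observations
y = categorical(rand(["a", "b"], 100))

sampler = ConditionalSampler(X, y; batch_size = 32, max_len = 10_000, prob_buffer = 0.95)
```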