Skip to content

Commit

Permalink
test: compile functions in tests correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 14, 2024
1 parent b1aee66 commit 5dd9c16
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
12 changes: 8 additions & 4 deletions test/reactant/loss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,12 @@
@test celoss_ls(ŷ, y) ≈ @jit(celoss_ls(ŷ_ra, y_ra))

celoss_lp = CrossEntropyLoss(; logits=Val(true))
@test celoss_lp(log.(ŷ), y) ≈ @jit(celoss_lp(log.(ŷ_ra), y_ra))
logit_celoss_lp = (ŷ, y) -> celoss_lp(log.(ŷ), y)
@test logit_celoss_lp(ŷ, y) ≈ @jit(logit_celoss_lp(ŷ_ra, y_ra))

celoss_lp_ls = CrossEntropyLoss(; logits=Val(true), label_smoothing=0.1)
@test celoss_lp_ls(log.(ŷ), y) ≈ @jit(celoss_lp_ls(log.(ŷ_ra), y_ra))
logit_celoss_lp_ls = (ŷ, y) -> celoss_lp_ls(log.(ŷ), y)
@test logit_celoss_lp_ls(ŷ, y) ≈ @jit(logit_celoss_lp_ls(ŷ_ra, y_ra))
end

@testset "Binary CrossEntropyLoss" begin
Expand All @@ -86,11 +88,13 @@
@test bceloss_ls(ŷ, y) ≈ @jit(bceloss_ls(ŷ_ra, y_ra))

bceloss_lp = BinaryCrossEntropyLoss(; logits=Val(true))
@test bceloss_lp(log.(ŷ), y) ≈ @jit(bceloss_lp(log.(ŷ_ra), y_ra))
logit_bceloss_lp = (ŷ, y) -> bceloss_lp(log.(ŷ), y)
@test logit_bceloss_lp(ŷ, y) ≈ @jit(logit_bceloss_lp(ŷ_ra, y_ra))

bceloss_lp_ls = BinaryCrossEntropyLoss(;
logits=Val(true), label_smoothing=0.1)
@test bceloss_lp_ls(log.(ŷ), y) ≈ @jit(bceloss_lp_ls(log.(ŷ_ra), y_ra))
logit_bceloss_lp_ls = (ŷ, y) -> bceloss_lp_ls(log.(ŷ), y)
@test logit_bceloss_lp_ls(ŷ, y) ≈ @jit(logit_bceloss_lp_ls(ŷ_ra, y_ra))
end

@testset "BinaryFocalLoss" begin
Expand Down
19 changes: 10 additions & 9 deletions test/reactant/training_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,22 @@
ps, st = Lux.setup(StableRNG(1234), model) |> xdev

x_ra = randn(Float32, 2, 32) |> xdev
y_ra = rand(Float32, 2, 32) |> xdev

inference_fn = @compile model(x_ra, ps, Lux.testmode(st))
inference_loss_fn = (xᵢ, yᵢ, mode, ps, st) -> begin
ŷᵢ, _ = model(xᵢ, ps, Lux.testmode(st))
return MSELoss()(ŷᵢ, yᵢ)
end
inference_loss_fn_compiled = @compile inference_loss_fn(
x_ra, y_ra, model, ps, st
)

x = [rand(Float32, 2, 32) for _ in 1:32]
y = [xᵢ .^ 2 for xᵢ in x]

dataloader = DeviceIterator(xdev, zip(x, y))

total_initial_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
ŷᵢ, _ = inference_fn(xᵢ, ps, Lux.testmode(st))
return MSELoss()(ŷᵢ, yᵢ)
end
total_initial_loss = mapreduce(inference_loss_fn_compiled, +, dataloader)

train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

Expand All @@ -51,10 +55,7 @@
end
end

total_final_loss = mapreduce(+, dataloader) do (xᵢ, yᵢ)
ŷᵢ, _ = inference_fn(xᵢ, train_state.parameters, Lux.testmode(st))
return MSELoss()(ŷᵢ, yᵢ)
end
total_final_loss = mapreduce(inference_loss_fn_compiled, +, dataloader)

@test total_final_loss < 100 * total_initial_loss
end
Expand Down

0 comments on commit 5dd9c16

Please sign in to comment.