Skip to content

Commit

Permalink
test: use @jit for simplified testing code
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 14, 2024
1 parent 40f305b commit b1aee66
Showing 1 changed file with 27 additions and 57 deletions.
84 changes: 27 additions & 57 deletions test/reactant/loss_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,8 @@
fn1(x) = LuxOps.xlogx.(x)
fn2(x, y) = LuxOps.xlogy.(x, y)

fn1_compiled = @compile fn1(x_ra)
@test fn1(x) fn1_compiled(x_ra)

fn2_compiled = @compile fn2(x_ra, y_ra)
@test fn2(x, y) fn2_compiled(x_ra, y_ra)
@test fn1(x) @jit(fn1(x_ra))
@test fn2(x, y) @jit(fn2(x_ra, y_ra))
end

@testset "Regression Loss" begin
Expand All @@ -43,14 +40,9 @@
loss_sum = eval(Symbol(loss * "Loss"))(; agg=sum)
loss_sum2 = eval(Symbol(loss * "Loss"))(; agg=(args...) -> sum(args...))

loss_mean_compiled = @compile loss_mean(ŷ_ra, y_ra)
@test loss_mean(ŷ, y) loss_mean_compiled(ŷ_ra, y_ra)

loss_sum_compiled = @compile loss_sum(ŷ_ra, y_ra)
@test loss_sum(ŷ, y) loss_sum_compiled(ŷ_ra, y_ra)

loss_sum2_compiled = @compile loss_sum2(ŷ_ra, y_ra)
@test loss_sum2(ŷ, y) loss_sum2_compiled(ŷ_ra, y_ra)
@test loss_mean(ŷ, y) @jit(loss_mean(ŷ_ra, y_ra))
@test loss_sum(ŷ, y) @jit(loss_sum(ŷ_ra, y_ra))
@test loss_sum2(ŷ, y) @jit(loss_sum2(ŷ_ra, y_ra))
end

@testset "MSLE" begin
Expand All @@ -61,8 +53,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

loss_msle = MSLELoss()
loss_msle_compiled = @compile loss_msle(ŷ_ra, y_ra)
@test loss_msle(ŷ, y) loss_msle_compiled(ŷ_ra, y_ra)
@test loss_msle(ŷ, y) @jit(loss_msle(ŷ_ra, y_ra))
end
end

Expand All @@ -75,39 +66,31 @@

@testset "CrossEntropyLoss" begin
celoss = CrossEntropyLoss()
celoss_compiled = @compile celoss(ŷ_ra, y_ra)
@test celoss(ŷ, y) celoss_compiled(ŷ_ra, y_ra)
@test celoss(ŷ, y) @jit(celoss(ŷ_ra, y_ra))

celoss_ls = CrossEntropyLoss(; label_smoothing=0.1)
celoss_ls_compiled = @compile celoss_ls(ŷ_ra, y_ra)
@test celoss_ls(ŷ, y) celoss_ls_compiled(ŷ_ra, y_ra)
@test celoss_ls(ŷ, y) @jit(celoss_ls(ŷ_ra, y_ra))

celoss_lp = CrossEntropyLoss(; logits=Val(true))
celoss_lp_compiled = @compile celoss_lp(log.(ŷ_ra), y_ra)
@test celoss_lp(log.(ŷ), y) celoss_lp_compiled(log.(ŷ_ra), y_ra)
@test celoss_lp(log.(ŷ), y) @jit(celoss_lp(log.(ŷ_ra), y_ra))

celoss_lp_ls = CrossEntropyLoss(; logits=Val(true), label_smoothing=0.1)
celoss_lp_ls_compiled = @compile celoss_lp_ls(log.(ŷ_ra), y_ra)
@test celoss_lp_ls(log.(ŷ), y) celoss_lp_ls_compiled(log.(ŷ_ra), y_ra)
@test celoss_lp_ls(log.(ŷ), y) @jit(celoss_lp_ls(log.(ŷ_ra), y_ra))
end

@testset "Binary CrossEntropyLoss" begin
bceloss = BinaryCrossEntropyLoss()
bceloss_compiled = @compile bceloss(ŷ_ra, y_ra)
@test bceloss(ŷ, y) bceloss_compiled(ŷ_ra, y_ra)
@test bceloss(ŷ, y) @jit(bceloss(ŷ_ra, y_ra))

bceloss_ls = BinaryCrossEntropyLoss(; label_smoothing=0.1)
bceloss_ls_compiled = @compile bceloss_ls(ŷ_ra, y_ra)
@test bceloss_ls(ŷ, y) bceloss_ls_compiled(ŷ_ra, y_ra)
@test bceloss_ls(ŷ, y) @jit(bceloss_ls(ŷ_ra, y_ra))

bceloss_lp = BinaryCrossEntropyLoss(; logits=Val(true))
bceloss_lp_compiled = @compile bceloss_lp(log.(ŷ_ra), y_ra)
@test bceloss_lp(log.(ŷ), y) bceloss_lp_compiled(log.(ŷ_ra), y_ra)
@test bceloss_lp(log.(ŷ), y) @jit(bceloss_lp(log.(ŷ_ra), y_ra))

bceloss_lp_ls = BinaryCrossEntropyLoss(;
logits=Val(true), label_smoothing=0.1)
bceloss_lp_ls_compiled = @compile bceloss_lp_ls(log.(ŷ_ra), y_ra)
@test bceloss_lp_ls(log.(ŷ), y) bceloss_lp_ls_compiled(log.(ŷ_ra), y_ra)
@test bceloss_lp_ls(log.(ŷ), y) @jit(bceloss_lp_ls(log.(ŷ_ra), y_ra))
end

@testset "BinaryFocalLoss" begin
Expand All @@ -120,8 +103,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

bfl = BinaryFocalLoss()
bfl_compiled = @compile bfl(ŷ_ra, y_ra)
@test bfl(ŷ, y) bfl_compiled(ŷ_ra, y_ra)
@test bfl(ŷ, y) @jit(bfl(ŷ_ra, y_ra))
end

@testset "FocalLoss" begin
Expand All @@ -134,8 +116,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

fl = FocalLoss()
fl_compiled = @compile fl(ŷ_ra, y_ra)
@test fl(ŷ, y) fl_compiled(ŷ_ra, y_ra)
@test fl(ŷ, y) @jit(fl(ŷ_ra, y_ra))
end
end

Expand All @@ -148,8 +129,7 @@
ŷ_ra = Reactant.to_rarray(ŷ)

kldl = KLDivergenceLoss()
kldl_compiled = @compile kldl(ŷ_ra, y_ra)
@test kldl(ŷ, y) kldl_compiled(ŷ_ra, y_ra)
@test kldl(ŷ, y) @jit(kldl(ŷ_ra, y_ra))
end

@testset "HingeLoss" begin
Expand All @@ -160,12 +140,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

hl = HingeLoss()
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) @jit(hl(ŷ_ra, y_ra))

hl = HingeLoss(; agg=mean)
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) @jit(hl(ŷ_ra, y_ra))
end

@testset "SquaredHingeLoss" begin
Expand All @@ -176,12 +154,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

hl = SquaredHingeLoss()
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) @jit(hl(ŷ_ra, y_ra))

hl = SquaredHingeLoss(; agg=mean)
hl_compiled = @compile hl(ŷ_ra, y_ra)
@test hl(ŷ, y) hl_compiled(ŷ_ra, y_ra)
@test hl(ŷ, y) @jit(hl(ŷ_ra, y_ra))
end

@testset "PoissonLoss" begin
Expand All @@ -192,12 +168,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

pl = PoissonLoss()
pl_compiled = @compile pl(ŷ_ra, y_ra)
@test pl(ŷ, y) pl_compiled(ŷ_ra, y_ra)
@test pl(ŷ, y) @jit(pl(ŷ_ra, y_ra))

pl = PoissonLoss(; agg=mean)
pl_compiled = @compile pl(ŷ_ra, y_ra)
@test pl(ŷ, y) pl_compiled(ŷ_ra, y_ra)
@test pl(ŷ, y) @jit(pl(ŷ_ra, y_ra))
end

@testset "DiceCoeffLoss" begin
Expand All @@ -208,12 +182,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

dl = DiceCoeffLoss()
dl_compiled = @compile dl(ŷ_ra, y_ra)
@test dl(ŷ, y) dl_compiled(ŷ_ra, y_ra)
@test dl(ŷ, y) @jit(dl(ŷ_ra, y_ra))

dl = DiceCoeffLoss(; agg=mean)
dl_compiled = @compile dl(ŷ_ra, y_ra)
@test dl(ŷ, y) dl_compiled(ŷ_ra, y_ra)
@test dl(ŷ, y) @jit(dl(ŷ_ra, y_ra))
end

@testset "Siamese Contrastive Loss" begin
Expand All @@ -228,12 +200,10 @@
ŷ_ra = Reactant.to_rarray(ŷ)

sl = SiameseContrastiveLoss()
sl_compiled = @compile sl(ŷ_ra, y_ra)
@test sl(ŷ, y) sl_compiled(ŷ_ra, y_ra)
@test sl(ŷ, y) @jit(sl(ŷ_ra, y_ra))

sl = SiameseContrastiveLoss(; agg=mean)
sl_compiled = @compile sl(ŷ_ra, y_ra)
@test sl(ŷ, y) sl_compiled(ŷ_ra, y_ra)
@test sl(ŷ, y) @jit(sl(ŷ_ra, y_ra))
end
end
end
Expand Down

0 comments on commit b1aee66

Please sign in to comment.