Skip to content

Commit

Permalink
fix: didn't expand logprs
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjoa committed Dec 22, 2024
1 parent 892275d commit d4f6143
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion qc/benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ function produce_loss(rs::RunState, m::FeatureSpecEntropyLossMgr, epoch::Integer
if m.p.train_feature
[lpr_eq * empirical_feature_logpr, empirical_feature_logpr]
else
lpr_eq = Dice.expand_logprs(l, lpr_eq)
[lpr_eq * compute(a, lpr_eq), empirical_feature_logpr]
end
else
Expand All @@ -463,7 +464,8 @@ function produce_loss(rs::RunState, m::FeatureSpecEntropyLossMgr, epoch::Integer
)
push!(m.num_meeting, num_meeting / length(samples))

loss = Dice.expand_logprs(l, loss) / length(samples)
# loss = Dice.expand_logprs(l, loss) / length(samples)
loss = loss / length(samples)
m.current_loss = loss
m.current_actual_loss = actual_loss
m.current_samples = samples
Expand Down

0 comments on commit d4f6143

Please sign in to comment.