-
-
Notifications
You must be signed in to change notification settings - Fork 612
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make loss(f,x,y) == loss(f(x), y)
#2090
base: master
Are you sure you want to change the base?
Conversation
docs/src/models/losses.md
Outdated
All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`. | ||
This is convenient for [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`: | ||
|
||
```julia | ||
loss(ŷ, y) # defaults to `mean` | ||
loss(ŷ, y, agg=sum) # use `sum` for reduction | ||
loss(ŷ, y, agg=x->sum(x, dims=2)) # partial reduction | ||
loss(ŷ, y, agg=x->mean(w .* x)) # weighted mean | ||
loss(ŷ, y, agg=identity) # no aggregation. | ||
loss(model, x, y) = loss(model(x), y) | ||
``` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GH won't let me suggest on this easily, but right now, it almost reads like you need to define the 3-arg loss
to work with train!
(which is the exact opposite intent!). Something like
All loss functions in Flux have a method which takes the model as the first argument, and calculates the prediction `ŷ = model(x)`, and finally the loss `loss(ŷ, y)`. This is convenient for passing the loss function directly to [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`. For a custom loss, you can replicate this as:
```julia
myloss(model, x, y) = myloss(model(x), y)
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I wondered this too. In this doc section "loss" is an example of any built-in one.
I wonder if it should use say mse
everywhere, and say "Flux has a method like this already defined:"?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, maybe it is clearer to start this section by saying something like "Using Flux.Losses.mse
as an example, ...". Then say, for this specific point,
All loss functions in Flux have a method which takes the model as the first argument, and calculates the loss such that
```julia
Flux.Losses.mse(model, x, y) == Flux.Losses.mse(model(x), y)
```
This is convenient for passing the loss function directly to [`train!`](@ref Flux.train)`(loss, model, [(x,y), (x2,y2), ...], opt)`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Turns out I was half-done with changing this section locally to work through defining a new one, rather than listing properties of existing ones. See what you think? Agree that if it does discuss existing ones, it should be ==
.
A NEWS entry for this feature would be good too |
""" | ||
$($loss)(model, x, y) | ||
|
||
This method calculates `ŷ = model(x)`. Accepts the same keyword arguments. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kept this docstring short. Not so sure whether or not it will show up in the doc listing, nor whether it should.
Sorry, I have to say that I'm really not a fan of this signature because it excludes a bunch of models while adding one more thing to know for loss function authors. For example, what does Given that the existing |
Yes I agree it's specialised to some uses. It just seems slightly weird to force people to define a function which is the just adjusting the signature to work, not doing any work or making any choices. They are forced to do so now because, in addition, this function closes over the model. So it must be re-defined if you change the model. I suppose it seems especially odd if the "official" documented way is that you must name this trivial function. And perhaps writing always something like this would be less odd:
However, that's still quite a bit of boilerplate to say "use mse". And I know some people find the |
If it were just a matter of clarifying how the |
Right now this is worse, For implicit-Flux, having methods like For explicit-Flux, we could have We could also just make |
Yeah, that's a good argument for having the data = [(x1,y1), (x2,y2), ...]
train!((m, x, y) -> mse(m(x), y), model, data, opt) Most users can directly copy-paste this, and those who have more complex forward passes can either define a separate function or ease into learning the train!((m, x, y) -> mse(m(x), y) + Optimisers.total(norm, m), model, data, opt) |
OK, https://fluxml.ai/Flux.jl/previews/PR2114/training/training/ takes this view that we should just always make an anon. function. It emphasises gradient + update over |
If
train!
stops accepting implicit parameters, as in #2082, then its loss function needs to accept the model as an argument, rather than close over it.This makes all the built-in ones do so, to avoid defining
loss(m,x,y) = mse(m(x), y)
etc. yourself every time.(Defining
loss(x,y) = mse(model(x), y)
every time used to be the idiom for closing over the model, and IMO this is pretty confusing. It means "loss function" means two things. Cleaner to delete this entirely than to update it to a 3-arg version.)PR Checklist