Add pairwise losses (MSE, Hinge, Logistic, SoftZeroOne) #32
Conversation
First pass at the super class and utils.
Thanks for the updates!
The comments for pairwise hinge loss apply to the other ones too.
```
.replace(
    "{{explanation}}",
    """\033[A
```
Does that actually work? Is it just because the dash `-` has to come right after the `"""`?
Yep, it works!
Computes pairwise hinge loss between true labels and predicted scores.
This loss function is designed for ranking tasks, where the goal is to
correctly order items within each list. It computes the loss by comparing
pairs of items within each list, penalizing cases where an item with a
higher true label has a lower predicted score than an item with a lower
true label.
For each list of predicted scores `s` in `y_pred` and the corresponding list
of true labels `y` in `y_true`, the loss is computed as follows:
```
loss = sum_{i} sum_{j} I(y_i > y_j) * (s_i - s_j)^2
```
where:
- `y_i` and `y_j` are the true labels of items `i` and `j`, respectively.
- `s_i` and `s_j` are the predicted scores of items `i` and `j`,
respectively.
- `I(y_i > y_j)` is an indicator function that equals 1 if `y_i > y_j`,
and 0 otherwise.
- `(s_i - s_j)^2` is the squared difference between the predicted scores
of items `i` and `j`, which penalizes discrepancies between the
predicted order of items relative to their true order.
Args:
reduction: Type of reduction to apply to the loss. In almost all cases
this should be `"sum_over_batch_size"`. Supported options are
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
`"mean_with_sample_weight"` or `None`. `"sum"` sums the loss,
`"sum_over_batch_size"` and `"mean"` sum the loss and divide by the
sample size, and `"mean_with_sample_weight"` sums the loss and
divides by the sum of the sample weights. `"none"` and `None`
perform no aggregation. Defaults to `"sum_over_batch_size"`.
name: Optional name for the loss instance.
dtype: The dtype of the loss's computations. Defaults to `None`, which
means using `keras.backend.floatx()`. `keras.backend.floatx()` is
`"float32"` unless set to a different value
(via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is
provided, then the `compute_dtype` will be utilized.
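As an aside, the squared-difference formula quoted in the docstring can be sanity-checked with a small plain-Python sketch (an illustration of the math only, not the keras-rs implementation):

```python
def pairwise_squared_loss(y_true, y_pred):
    """Plain-Python sketch of: sum_{i} sum_{j} I(y_i > y_j) * (s_i - s_j)^2."""
    total = 0.0
    for y_i, s_i in zip(y_true, y_pred):
        for y_j, s_j in zip(y_true, y_pred):
            if y_i > y_j:  # indicator I(y_i > y_j)
                total += (s_i - s_j) ** 2
    return total

# One list: labels [2, 1], scores [3.0, 1.0]. Only the pair (i=0, j=1)
# satisfies y_i > y_j, contributing (3.0 - 1.0)^2 = 4.0.
print(pairwise_squared_loss([2, 1], [3.0, 1.0]))  # 4.0
```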
What is the thing you pasted here?
Does it work with IDEs (e.g. VSCode, for code completion and when you mouse over a symbol)?
Does it work on keras.io (the thing that generates the documentation)?
Those special codes are a terminal concept, so it works when you print it. But outside of a terminal, it's unlikely to work; there is no equivalent concept in markdown, HTML, etc.
> What is the thing you pasted here?

`print(PairwiseMeanSquaredError.__doc__)`

> Those special codes are a terminal concept. So it works when you print it. But outside of a terminal, it's unlikely to work, there is no equivalent concept in markdown, html...

Ah, I see.
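For anyone reading along, the behavior under discussion is easy to reproduce: `\033[A` is the ANSI "cursor up" escape sequence, which a terminal interprets when the string is printed, while the raw docstring (what IDEs and keras.io read) still contains the characters:

```python
# "\033[A" is the ANSI "cursor up" escape. When printed to a terminal it
# moves the cursor up one line, visually hiding the blank first line of the
# triple-quoted string. The characters are still literally in the string.
doc = """\033[A
Computes pairwise hinge loss."""

print(repr(doc))  # the escape is part of the raw string that tools see
```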
keras_rs/src/losses/pairwise_loss.py
Outdated
```
pairwise_loss_subclass_doc_string = (
    " Computes pairwise hinge loss between true labels and predicted scores."
```
Shouldn't this simply be in the triple quotes?
Yeah, the reason I put it separately is that black formats it to be on the same line as `=`, which in turn makes it go beyond the 80-character line limit.
```
)
.replace(
    "{{extra_args}}",
    "\033[A",
```
Does that work for removing the empty line?
Yeah
keras_rs/src/losses/pairwise_loss.py
Outdated
```
{explanation}

Args:
{extra_args}
```
It looks like `{extra_args}` is never actually used (always empty).
Yeah, keeping it in case we need to add other losses which might have extra args
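For context, the templating pattern under discussion looks roughly like this (the template text is abridged here; the real one in `pairwise_loss.py` carries the full `Args:` section):

```python
# Abridged sketch of the docstring templating: each subclass swaps in its
# own explanation; "{{extra_args}}" is kept as a hook for future losses
# even though the current ones leave it empty.
template = """{{explanation}}

Args:{{extra_args}}
    reduction: Type of reduction to apply to the loss.
"""

doc = template.replace(
    "{{explanation}}",
    "Computes pairwise hinge loss between true labels and predicted scores.",
).replace("{{extra_args}}", "")

print(doc)
```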
```
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * (s_i - s_j)^2"
explanation = """\033[A
```
Remove the two `\033[A` and just use `"` instead of `"""`.
Did it slightly differently, please take a look
```
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * (1 - sigmoid(s_i - s_j))"
explanation = """\033[A
```
Remove the two `\033[A` and just use `"` instead of `"""`.
```
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * max(0, 1 - (s_i - s_j))"
explanation = """\033[A
```
Remove the two `\033[A` and just use `"` instead of `"""`.
```
formula = "loss = sum_{i} sum_{j} I(y_i > y_j) * log(1 + exp(-(s_i - s_j)))"
explanation = """\033[A
```
Remove the two `\033[A` and just use `"` instead of `"""`.
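Putting the four formulas from the diff hunks side by side, a plain-Python sketch (an illustration of the math only; the actual keras-rs losses are Keras `Loss` subclasses operating on batched tensors) could look like:

```python
import math

def pairwise_loss(y_true, y_pred, kind):
    """Sketch of the four pairwise formulas, summed over pairs with y_i > y_j."""
    total = 0.0
    for y_i, s_i in zip(y_true, y_pred):
        for y_j, s_j in zip(y_true, y_pred):
            if y_i > y_j:                     # indicator I(y_i > y_j)
                d = s_i - s_j
                if kind == "mse":
                    total += d ** 2
                elif kind == "hinge":
                    total += max(0.0, 1.0 - d)
                elif kind == "logistic":
                    total += math.log(1.0 + math.exp(-d))
                elif kind == "soft_zero_one":
                    total += 1.0 - 1.0 / (1.0 + math.exp(-d))  # 1 - sigmoid(d)
    return total

# Labels [1, 0], scores [2.0, 0.0]: the single pair has s_i - s_j = 2.0.
print(pairwise_loss([1, 0], [2.0, 0.0], "mse"))    # 4.0
print(pairwise_loss([1, 0], [2.0, 0.0], "hinge"))  # 0.0
```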
To be discussed offline:
- `y_true`: can use `sample_weight`, but have a few questions about that [keep it for now].
- `dim=0` expansion [but check TFR].
- Use `cls.__doc__` to avoid repetition? [yes]
- TODO: `ops.stop_gradient()`.