-
Notifications
You must be signed in to change notification settings - Fork 755
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add truncated normal distribution for torch distributions #2970
Add truncated normal distribution for torch distributions #2970
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@melopeo why does TruncatedNormal
appear twice, both in distributions
and distributions/utils
?
Also what is the logic? Is this simply truncating samples between the bounds (when sampling) and raising errors for values outside the bounds when computing log density? Then maybe one could have a generic wrapper to apply to any distribution. (nevermind: I see there's more to it)
e3c8dcc
to
c10f145
Compare
bf5953a
to
c072d92
Compare
Co-authored-by: Lorenzo Stella <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Thank you!
Description of changes:
Add truncated normal distribution for torch distributions
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup