Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add truncated normal distribution for torch distributions #2970

Conversation

melopeo
Copy link
Contributor

@melopeo melopeo commented Aug 17, 2023

Description of changes:
Add truncated normal distribution for torch distributions

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

Please tag this pr with at least one of these labels to make our release process faster: BREAKING, new feature, bug fix, other change, dev setup

@melopeo melopeo added the enhancement New feature or request label Aug 17, 2023
Copy link
Contributor

@lostella lostella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@melopeo why does TruncatedNormal appear twice, both in distributions and distributions/utils?

Also what is the logic? Is this simply truncating samples between the bounds (when sampling) and raising errors for values outside the bounds when computing log density? Then maybe one could have a generic wrapper to apply to any distribution. (nevermind: I see there's more to it)

src/gluonts/torch/distributions/utils/truncated_normal.py Outdated Show resolved Hide resolved
src/gluonts/torch/distributions/utils/truncated_normal.py Outdated Show resolved Hide resolved
@lostella lostella added new feature (one of pr required labels) torch This concerns the PyTorch side of GluonTS labels Aug 18, 2023
@melopeo melopeo force-pushed the add_truncate_normal_distribution_for_torch_distributions branch from e3c8dcc to c10f145 Compare August 22, 2023 12:16
@melopeo melopeo force-pushed the add_truncate_normal_distribution_for_torch_distributions branch from bf5953a to c072d92 Compare August 22, 2023 14:21
Copy link
Contributor

@lostella lostella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Thank you!

@melopeo melopeo merged commit db14c0f into awslabs:dev Aug 22, 2023
21 checks passed
@melopeo melopeo deleted the add_truncate_normal_distribution_for_torch_distributions branch August 22, 2023 14:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request new feature (one of pr required labels) torch This concerns the PyTorch side of GluonTS
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants