Skip to content

Commit

Permalink
Fix Gaussian integrator on jax
Browse files Browse the repository at this point in the history
Gaussian and GaussLegendre integrators throw an error on jax if anp.prod is called on a list instead of an array. See esa#214
  • Loading branch information
javier-garcia-tilburg authored Nov 25, 2024
1 parent bf8ed5c commit 0ec14fa
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchquad/integration/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _weights(self, N, dim, backend, requires_grad=False):
).ravel()
else:
return anp.prod(
anp.meshgrid(*([weights] * dim), like=backend), axis=0
anp.stack(anp.meshgrid(*([weights] * dim), like=backend)), axis=0
).ravel()

def _roots(self, N, backend, requires_grad=False):
Expand Down

0 comments on commit 0ec14fa

Please sign in to comment.