From 0ec14fa8f14ff5fb76a013dc111ea8f06f95f7a2 Mon Sep 17 00:00:00 2001 From: javier-garcia-tilburg <114025442+javier-garcia-tilburg@users.noreply.github.com> Date: Mon, 25 Nov 2024 22:11:00 +0100 Subject: [PATCH] Fix Gaussian integrator on jax Gaussian and GaussLegendre integrators throw an error on jax if anp.prod is called on a list instead of an array. See #214 --- torchquad/integration/gaussian.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchquad/integration/gaussian.py b/torchquad/integration/gaussian.py index 0d3c9612..17051681 100644 --- a/torchquad/integration/gaussian.py +++ b/torchquad/integration/gaussian.py @@ -65,7 +65,7 @@ def _weights(self, N, dim, backend, requires_grad=False): ).ravel() else: return anp.prod( - anp.meshgrid(*([weights] * dim), like=backend), axis=0 + anp.stack(anp.meshgrid(*([weights] * dim), like=backend)), axis=0 ).ravel() def _roots(self, N, backend, requires_grad=False):