From 8a90383e1da459249f68b1a5499bf08f98b2acc3 Mon Sep 17 00:00:00 2001 From: nick7 Date: Tue, 28 Nov 2023 14:19:23 +0100 Subject: [PATCH] fix branch input size for multi-channel inputs Signed-off-by: nick7 --- src/torchphysics/models/deeponet/branchnets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchphysics/models/deeponet/branchnets.py b/src/torchphysics/models/deeponet/branchnets.py index 05d34d0a..e8c30c5f 100644 --- a/src/torchphysics/models/deeponet/branchnets.py +++ b/src/torchphysics/models/deeponet/branchnets.py @@ -25,7 +25,7 @@ def __init__(self, function_space, discretization_sampler): super().__init__(function_space, output_space=None) self.output_neurons = 0 self.discretization_sampler = discretization_sampler - self.input_dim = len(self.discretization_sampler) + self.input_dim = len(self.discretization_sampler) * function_space.output_space.dim self.current_out = torch.empty(0) def finalize(self, output_space, output_neurons):