From 2c4df752893c8d5e78a180aa62ee3df1c39b5111 Mon Sep 17 00:00:00 2001 From: Robin Manhaeve Date: Fri, 5 Jul 2024 09:52:45 +0200 Subject: [PATCH] Fix batching network inputs of mixed type --- src/deepproblog/network.py | 33 +++++++++-------- .../tests/test_neural_predicate.py | 35 ++++++++++++------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/src/deepproblog/network.py b/src/deepproblog/network.py index 00bfcaff..8a72a3bf 100644 --- a/src/deepproblog/network.py +++ b/src/deepproblog/network.py @@ -17,13 +17,13 @@ class Network(object): """Wraps a PyTorch neural network for use with DeepProblog""" def __init__( - self, - network_module: torch.nn.Module, - name: str, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler=None, - k: Optional[int] = None, - batching: bool = False, + self, + network_module: torch.nn.Module, + name: str, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler=None, + k: Optional[int] = None, + batching: bool = False, ): """Create a Network object @@ -121,13 +121,17 @@ def __call__(self, to_evaluate: list) -> list: :return: """ if self.batching: - batched_inputs: List[torch.Tensor] = [ - self.function(*e)[0] for e in to_evaluate - ] - stacked_inputs = torch.stack(batched_inputs) - if self.is_cuda: - stacked_inputs = stacked_inputs.cuda(device=self.device) - evaluated = self.network_module(stacked_inputs) + inputs = (self.function(*e) for e in to_evaluate) + stacked_inputs = list() + for inputs in zip(*inputs): + try: + inputs = torch.stack(inputs) + if self.is_cuda: + inputs.cuda(device=self.device) + except TypeError: + inputs = list(inputs) + stacked_inputs.append(inputs) + evaluated = self.network_module(*stacked_inputs) else: evaluated = [self.network_module(*self.function(*e)) for e in to_evaluate] return evaluated @@ -169,7 +173,6 @@ def get_hyperparameters(self): } return parameters - # class NetworkEvaluation(object): # """ # An object that keeps track of which inputs the neural networks need to be evaluated on. diff --git a/src/deepproblog/tests/test_neural_predicate.py b/src/deepproblog/tests/test_neural_predicate.py index 156a3c03..319af08f 100644 --- a/src/deepproblog/tests/test_neural_predicate.py +++ b/src/deepproblog/tests/test_neural_predicate.py @@ -12,7 +12,7 @@ nn(dummy1,[X],Y,[a,b,c]) :: net1(X,Y). nn(dummy2,[X]) :: net2(X). nn(dummy3,[X],Y) :: net3(X,Y). -nn(dummy4,[X,Y],Z,[a,b]) :: net4(X,Y,Z). +nn(dummy4,[X,Y]) :: net4(X,Y). test1(X1,Y1,X2,Y2) :- net1(X1,Y1), net1(X2,Y2). test2(X1,X2) :- net2(X1), net2(X2). @@ -28,9 +28,19 @@ dummy_values3 = {Term("i1"): [1.0, 2.0, 3.0, 4.0], Term("i2"): [-1.0, 0.0, 1.0]} dummy_net3 = Network(DummyNet(dummy_values3), "dummy3") -dummy_net4 = Network(DummyTensorNet(batching=True), "dummy4", batching=True) -tensors = {(Constant(0),): torch.Tensor([0.2]), (Constant(1),): torch.Tensor([0.8])} +dummy_tensors = {(Term("a"),): torch.Tensor([0.1, 0.2, 0.3, 0.4]), (Term("b"),): torch.Tensor([0.25, 0.25, 0.25, 0.25])} + + +class IndexNet(torch.nn.Module): + + def forward(self, t, index): + # index = int(index) + index = torch.LongTensor([int(i) for i in index]) + return t.index_select(dim=1, index=index) + + +dummy_net4 = Network(IndexNet(), "dummy4", batching=True) @pytest.fixture( @@ -53,7 +63,7 @@ def model(request) -> Model: model = Model(program, [dummy_net1, dummy_net2, dummy_net3, dummy_net4], load=False) engine = request.param["engine_factory"](model) model.set_engine(engine, cache=request.param["cache"]) - model.add_tensor_source('dummy', tensors) + model.add_tensor_source('dummy', dummy_tensors) return model @@ -108,13 +118,12 @@ def test_det_network_substitution(model: Model): assert all(r1.detach().numpy() == [1.0, 2.0, 3.0, 4.0]) assert all(r2.detach().numpy() == [-1.0, 0.0, 1.0]) -def test_double_input(model: Model): - terms = lambda x: Term("net4", - Term("tensor",Term("dummy", Constant(0))), - Term("tensor",Term("dummy", Constant(1))), - x) - results = model.solve([Query(terms(Var("X")))]) - r1 = float(results[0].result[terms(Term("a"))]) - r2 = float(results[0].result[terms(Term("b"))]) +def test_multi_input_network(model: Model): + dummy_tensor = lambda x: Term("tensor", Term("dummy", x)) + q1 = Query(Term("net4", dummy_tensor(Term("a")), Constant(1))) + q2 = Query(Term("net4", dummy_tensor(Term("b")), Constant(2))) + results = model.solve([q1, q2]) + r1 = float(results[0].result[q1.query]) + r2 = float(results[1].result[q2.query]) assert pytest.approx(0.2) == r1 - assert pytest.approx(0.8) == r2 \ No newline at end of file + assert pytest.approx(0.25) == r2 \ No newline at end of file