From 313ebd20508e787d8f52379d06fa9cd40267b4eb Mon Sep 17 00:00:00 2001 From: Bora Uyar <bora.uyar@mdc-berlin.de> Date: Mon, 4 Mar 2024 20:03:32 +0100 Subject: [PATCH] condition no longer needed --- flexynesis/models/direct_pred_gcnn.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/flexynesis/models/direct_pred_gcnn.py b/flexynesis/models/direct_pred_gcnn.py index b6b3c92..cb14f0d 100644 --- a/flexynesis/models/direct_pred_gcnn.py +++ b/flexynesis/models/direct_pred_gcnn.py @@ -61,28 +61,17 @@ def __init__( # NOTE: For now we use matrices, so number of node input features is 1. input_dims = [1 for _ in range(len(layers))] - if self.gnn_conv_type == 'GCNN': - self.encoders = nn.ModuleList( - [ - GCNN( + self.encoders = nn.ModuleList( + [ + GNNs( input_dim=input_dims[i], hidden_dim=int(self.config["hidden_dim"]), # int because of pyg output_dim=self.config["latent_dim"], + act = self.config['activation'], + conv = self.gnn_conv_type ) - for i in range(len(layers)) - ]) - else: - self.encoders = nn.ModuleList( - [ - GNNs( - input_dim=input_dims[i], - hidden_dim=int(self.config["hidden_dim"]), # int because of pyg - output_dim=self.config["latent_dim"], - act = self.config['activation'], - conv = self.gnn_conv_type - ) - for i in range(len(layers)) - ]) + for i in range(len(layers)) + ]) # Init output layers self.MLPs = nn.ModuleDict()