diff --git a/sonar/models/mutox/builder.py b/sonar/models/mutox/builder.py index f021c85..7e6577f 100644 --- a/sonar/models/mutox/builder.py +++ b/sonar/models/mutox/builder.py @@ -40,25 +40,29 @@ def __init__( self.config = config self.device, self.dtype = device, dtype - def build_model(self, activation=nn.ReLU()) -> MutoxClassifier: + def build_model(self) -> MutoxClassifier: model_h1 = nn.Sequential( nn.Dropout(0.01), nn.Linear(self.config.input_size, 512), ) model_h2 = nn.Sequential( - activation, + nn.ReLU(), nn.Linear(512, 128), ) + model_h3 = nn.Sequential( + nn.ReLU(), + nn.Linear(128, 1), + ) + model_all = nn.Sequential( model_h1, model_h2, + model_h3, ) - return MutoxClassifier( - model_all, - ).to( + return MutoxClassifier(model_all,).to( device=self.device, dtype=self.dtype, ) diff --git a/sonar/models/mutox/classifier.py b/sonar/models/mutox/classifier.py index ada2efe..9ae7ebe 100644 --- a/sonar/models/mutox/classifier.py +++ b/sonar/models/mutox/classifier.py @@ -22,12 +22,13 @@ def __init__( self.model_all = model_all def forward(self, inputs: torch.Tensor, output_prob: bool = False) -> torch.Tensor: + outputs = self.model_all(inputs) + if output_prob: - self.model_all.add_module("sigmoid", nn.Sigmoid()) - else: - self.model_all.add_module("linear", nn.Linear(128, 1)) + outputs = torch.sigmoid(outputs) + + return outputs - return self.model_all(inputs) @dataclass