From 2765a596eb70249725ccceeaeaa146697252910c Mon Sep 17 00:00:00 2001
From: Evening
Date: Wed, 21 Feb 2024 11:47:31 +0800
Subject: [PATCH] Expose ImageNet Scaling option

Add an `imagenet_scaling` flag to the constructor so callers can opt in
to the adapted ImageNet input normalization, and flesh out the __init__
docstring.

The flag is stored on the instance under the SAME name as the flag and
the staticmethod is left named `transform_input`: an earlier revision of
this patch renamed the method to `imagenet_scaling`, which collided with
the flag. Without an attribute assignment, `self.imagenet_scaling`
resolved to the (always-truthy) bound method, so the flag was silently
ignored and scaling ran unconditionally; with an assignment, the bool
would shadow the method and `self.imagenet_scaling(x)` would raise
`TypeError: 'bool' object is not callable`.

NOTE(review): the `self.imagenet_scaling = imagenet_scaling` hunk below
was reconstructed without sight of the full __init__ body — rebase its
context lines and line numbers onto the actual file before applying.
---
 src/frdc/models/inceptionv3.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/src/frdc/models/inceptionv3.py b/src/frdc/models/inceptionv3.py
index 4ee47ed6..3d3b3e49 100644
--- a/src/frdc/models/inceptionv3.py
+++ b/src/frdc/models/inceptionv3.py
@@ -25,11 +25,18 @@ def __init__(
         x_scaler: StandardScaler,
         y_encoder: OrdinalEncoder,
         ema_lr: float = 0.001,
+        imagenet_scaling: bool = False,
     ):
         """Initialize the InceptionV3 model.
 
         Args:
-            n_classes: The number of output classes
+            in_channels: The number of input channels.
+            n_classes: The number of classes.
+            lr: The learning rate.
+            x_scaler: The X input StandardScaler.
+            y_encoder: The Y input OrdinalEncoder.
+            ema_lr: The learning rate for the EMA model.
+            imagenet_scaling: Whether to use the adapted ImageNet scaling.
 
         Notes:
             - Min input size: 299 x 299.
@@ -40,2 +47,3 @@ def __init__(
         """
+        self.imagenet_scaling = imagenet_scaling
         self.lr = lr
@@ -181,7 +189,9 @@ def forward(self, x: torch.Tensor):
             f"Got: {x.shape[2]} x {x.shape[3]}."
         )
 
-        x = self.transform_input(x)
+        # Optionally apply the adapted ImageNet normalization (opt-in flag).
+        if self.imagenet_scaling:
+            x = self.transform_input(x)
         # During training, the auxiliary outputs are used for auxiliary loss,
         # but during testing, only the main output is used.
         if self.training: