From 132e922cbadc318db0e89b21f59885a260a9dd33 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sat, 14 Oct 2023 16:37:34 +0000 Subject: [PATCH] fixing many mypy issues --- baselines/fjord/__init__.py | 1 - baselines/fjord/fjord/client.py | 6 ++++-- baselines/fjord/fjord/dataset.py | 2 +- baselines/fjord/fjord/main.py | 6 +++--- baselines/fjord/fjord/models.py | 20 ++++++++++++------- baselines/fjord/fjord/od/layers/batch_norm.py | 4 ++-- baselines/fjord/fjord/od/layers/conv.py | 6 +++--- baselines/fjord/fjord/od/layers/linear.py | 2 +- baselines/fjord/fjord/od/models/utils.py | 15 ++++++-------- .../fjord/fjord/od/samplers/base_sampler.py | 4 ++-- baselines/fjord/fjord/od/samplers/fixed_od.py | 2 +- baselines/fjord/fjord/strategy.py | 8 ++++---- baselines/fjord/fjord/utils/logger.py | 4 ++-- 13 files changed, 42 insertions(+), 38 deletions(-) delete mode 100644 baselines/fjord/__init__.py diff --git a/baselines/fjord/__init__.py b/baselines/fjord/__init__.py deleted file mode 100644 index 5725b08adac4..000000000000 --- a/baselines/fjord/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""FjORD baseline package.""" diff --git a/baselines/fjord/fjord/client.py b/baselines/fjord/fjord/client.py index b7d2f3a412d6..0a9c3768a7b2 100644 --- a/baselines/fjord/fjord/client.py +++ b/baselines/fjord/fjord/client.py @@ -101,10 +101,12 @@ def get_agg_config( # Define Flower client -class FjORDClient(fl.client.NumPyClient): +class FjORDClient( + fl.client.NumPyClient +): # pylint: disable=too-many-instance-attributes """Flower client training on CIFAR-10.""" - def __init__( + def __init__( # pylint: disable=too-many-arguments self, cid: int, model_name: str, diff --git a/baselines/fjord/fjord/dataset.py b/baselines/fjord/fjord/dataset.py index b7e7ea2622f7..adedbf5d542a 100644 --- a/baselines/fjord/fjord/dataset.py +++ b/baselines/fjord/fjord/dataset.py @@ -75,7 +75,7 @@ def __len__(self): class FLCifar10(CIFAR10): """CIFAR10 Federated Dataset.""" - def __init__( + def __init__( # pylint: disable=too-many-arguments self, root: str, train: Optional[bool] = True, diff --git a/baselines/fjord/fjord/main.py b/baselines/fjord/fjord/main.py index df2cc60bdabc..8e607938f198 100644 --- a/baselines/fjord/fjord/main.py +++ b/baselines/fjord/fjord/main.py @@ -57,7 +57,7 @@ def get_eval_fn( def evaluate( server_round: int, parameters: fl.common.NDArrays, - config: Dict[str, fl.common.Scalar], + config: Dict[str, fl.common.Scalar], # pylint: disable=unused-argument ) -> Optional[Tuple[float, Dict[str, fl.common.Scalar]]]: if server_round and (server_round % args.evaluate_every == 0): net = get_net(args.model, args.p_s, device) @@ -82,7 +82,7 @@ def evaluate( return evaluate -def get_client_fn( +def get_client_fn( # pylint: disable=too-many-arguments args: Any, model_path: str, cid_to_max_p: Dict[int, float], @@ -137,7 +137,7 @@ def __init__(self, cid_to_max_p: Dict[int, float]) -> None: Args: :param cid_to_max_p: Dictionary mapping client id to max p-value """ - super(FjORDBalancedClientManager, self).__init__() + super().__init__() self.cid_to_max_p = cid_to_max_p self.p_s = sorted(set(self.cid_to_max_p.values())) diff --git a/baselines/fjord/fjord/models.py b/baselines/fjord/fjord/models.py index fbf106b58a98..bb9dfbf76935 100644 --- a/baselines/fjord/fjord/models.py +++ b/baselines/fjord/fjord/models.py @@ -25,8 +25,10 @@ class BasicBlock(nn.Module): expansion = 1 - def __init__(self, od, p_s, in_planes, planes, stride=1): - super(BasicBlock, self).__init__() + def __init__( + self, od, p_s, in_planes, planes, stride=1 + ): # pylint: disable=too-many-arguments + super().__init__() self.od = od self.conv1 = create_conv_layer( od, @@ -87,7 +89,7 @@ def forward(self, x, sampler): # Adapted from: # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py -class ResNet(nn.Module): +class ResNet(nn.Module): # pylint: disable=too-many-instance-attributes """ResNet in PyTorch. Reference: @@ -95,8 +97,10 @@ class ResNet(nn.Module): Deep Residual Learning for Image Recognition. arXiv:1512.03385 """ - def __init__(self, od, p_s, block, num_blocks, num_classes=10): - super(ResNet, self).__init__() + def __init__( + self, od, p_s, block, num_blocks, num_classes=10 + ): # pylint: disable=too-many-arguments + super().__init__() self.od = od self.in_planes = 64 @@ -110,7 +114,9 @@ def __init__(self, od, p_s, block, num_blocks, num_classes=10): self.layer4 = self._make_layer(od, p_s, block, 512, num_blocks[3], stride=2) self.linear = create_linear_layer(od, False, 512 * block.expansion, num_classes) - def _make_layer(self, od, p_s, block, planes, num_blocks, stride): + def _make_layer( + self, od, p_s, block, planes, num_blocks, stride + ): # pylint: disable=too-many-arguments strides = [stride] + [1] * (num_blocks - 1) layers = [] for stride in strides: @@ -179,7 +185,7 @@ def get_net( return net -def train( +def train( # pylint: disable=too-many-locals, too-many-arguments net: Module, trainloader: DataLoader, know_distill: bool, diff --git a/baselines/fjord/fjord/od/layers/batch_norm.py b/baselines/fjord/fjord/od/layers/batch_norm.py index f0bf33da845c..510bc76d192c 100644 --- a/baselines/fjord/fjord/od/layers/batch_norm.py +++ b/baselines/fjord/fjord/od/layers/batch_norm.py @@ -8,7 +8,7 @@ __all__ = ["ODBatchNorm2d"] -class ODBatchNorm2d(nn.Module): +class ODBatchNorm2d(nn.Module): # pylint: disable=too-many-instance-attributes """Ordered Dropout BatchNorm2d.""" def __init__( @@ -19,7 +19,7 @@ def __init__( *args, **kwargs, ) -> None: - super(ODBatchNorm2d, self).__init__() + super().__init__() self.p_s = p_s self.is_od = False # no sampling is happening here self.num_features = num_features diff --git a/baselines/fjord/fjord/od/layers/conv.py b/baselines/fjord/fjord/od/layers/conv.py index 6ea75cda3f0f..66724096c57f 100644 --- a/baselines/fjord/fjord/od/layers/conv.py +++ b/baselines/fjord/fjord/od/layers/conv.py @@ -48,7 +48,7 @@ class ODConv1d(nn.Conv1d): def __init__(self, is_od: bool = True, *args, **kwargs) -> None: self.is_od = is_od - super(ODConv1d, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.width = self.out_channels self.last_input_dim = None self.last_output_dim = None @@ -75,7 +75,7 @@ class ODConv2d(nn.Conv2d): def __init__(self, is_od: bool = True, *args, **kwargs) -> None: self.is_od = is_od - super(ODConv2d, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.width = self.out_channels self.last_input_dim = None self.last_output_dim = None @@ -102,7 +102,7 @@ class ODConv3d(nn.Conv3d): def __init__(self, is_od: bool = True, *args, **kwargs) -> None: self.is_od = is_od - super(ODConv3d, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.width = self.out_channels self.last_input_dim = None self.last_output_dim = None diff --git a/baselines/fjord/fjord/od/layers/linear.py b/baselines/fjord/fjord/od/layers/linear.py index 7455e0a14a45..4655ac324a43 100644 --- a/baselines/fjord/fjord/od/layers/linear.py +++ b/baselines/fjord/fjord/od/layers/linear.py @@ -15,7 +15,7 @@ class ODLinear(nn.Linear): """Ordered Dropout Linear.""" def __init__(self, is_od: bool = True, *args, **kwargs) -> None: - super(ODLinear, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.is_od = is_od self.width = self.out_features self.last_input_dim = None diff --git a/baselines/fjord/fjord/od/models/utils.py b/baselines/fjord/fjord/od/models/utils.py index fb83a0c42400..d88f4a2ed216 100644 --- a/baselines/fjord/fjord/od/models/utils.py +++ b/baselines/fjord/fjord/od/models/utils.py @@ -15,8 +15,8 @@ def create_linear_layer(od, is_od, *args, **kwargs): """ if od: return ODLinear(is_od, *args, **kwargs) - else: - return nn.Linear(*args, **kwargs) + + return nn.Linear(*args, **kwargs) def create_conv_layer(od, is_od, *args, **kwargs): @@ -30,8 +30,8 @@ def create_conv_layer(od, is_od, *args, **kwargs): """ if od: return ODConv2d(is_od, *args, **kwargs) - else: - return nn.Conv2d(*args, **kwargs) + + return nn.Conv2d(*args, **kwargs) def create_bn_layer(od, p_s, *args, **kwargs): @@ -45,16 +45,13 @@ def create_bn_layer(od, p_s, *args, **kwargs): """ if od: return ODBatchNorm2d(p_s, *args, **kwargs) - else: - return nn.BatchNorm2d(*args, **kwargs) + + return nn.BatchNorm2d(*args, **kwargs) class SequentialWithSampler(nn.Sequential): """Implements sequential model with sampler.""" - def __init__(self, *args, **kwargs): - super(SequentialWithSampler, self).__init__(*args, **kwargs) - def forward(self, x, sampler=None): """Forward method for custom Sequential. diff --git a/baselines/fjord/fjord/od/samplers/base_sampler.py b/baselines/fjord/fjord/od/samplers/base_sampler.py index 902fe8420c8c..6458e5dba887 100644 --- a/baselines/fjord/fjord/od/samplers/base_sampler.py +++ b/baselines/fjord/fjord/od/samplers/base_sampler.py @@ -45,5 +45,5 @@ def __call__(self): """Call sampler.""" if self.with_layer: return next(self.width_samples), next(self.layer_samples) - else: - return next(self.width_samples) + + return next(self.width_samples) diff --git a/baselines/fjord/fjord/od/samplers/fixed_od.py b/baselines/fjord/fjord/od/samplers/fixed_od.py index a1b53b8e2bbe..b90912a7b5c2 100644 --- a/baselines/fjord/fjord/od/samplers/fixed_od.py +++ b/baselines/fjord/fjord/od/samplers/fixed_od.py @@ -15,7 +15,7 @@ class ODSampler(BaseSampler): """ def __init__(self, p_s: List[float], max_p: float, *args, **kwargs) -> None: - super(ODSampler, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.p_s = np.array([p for p in p_s if p <= max_p]) self.max_p = max_p diff --git a/baselines/fjord/fjord/strategy.py b/baselines/fjord/fjord/strategy.py index 8772b66d16d3..11ec5e00817c 100644 --- a/baselines/fjord/fjord/strategy.py +++ b/baselines/fjord/fjord/strategy.py @@ -59,7 +59,7 @@ def get_p_layer_updates( return layer_updates_p, num_examples_p -def fjord_average( +def fjord_average( # pylint: disable=too-many-arguments i: int, layer_updates: List[np.ndarray], num_examples: List[int], @@ -93,9 +93,9 @@ def fjord_average( ) if len(layer_updates_p) == 0: return update - else: - assert num_examples_p > 0 - return reduce(np.add, layer_updates_p) / num_examples_p + + assert num_examples_p > 0 + return reduce(np.add, layer_updates_p) / num_examples_p elif fjord_config["layer"][i] in ["ODLinear", "ODConv2d", "ODBatchNorm2d"]: # perform nested updates for p in p_s[::-1]: diff --git a/baselines/fjord/fjord/utils/logger.py b/baselines/fjord/fjord/utils/logger.py index 03ec7891d8b0..a9d92f1c1a38 100644 --- a/baselines/fjord/fjord/utils/logger.py +++ b/baselines/fjord/fjord/utils/logger.py @@ -49,8 +49,8 @@ def get(cls, logger_name="default"): """ if logger_name in cls.registered_loggers: return cls.registered_loggers[logger_name] - else: - return cls(logger_name) + + return cls(logger_name) def __init__(self, logger_name="default"): """Initialise logger not previously registered.