+
+ @property
+ defcandidate_models(self)->Dict[str,torch.nn.Module]:
+"""A set of candidate models."""
+ return{
+ "resnet10":mnn.ResNet10(num_classes=self.n_class),
+ "resnet18":mnn.ResNet18(num_classes=self.n_class),
+ }
+
+ @property
+ defdoi(self)->List[str]:
+"""DOI(s) related to the dataset."""
+ return["10.1109/cvpr.2009.5206848"]
+
+ @property
+ defurl(self)->str:
+"""URL for downloading the original dataset."""
+ return"http://cs231n.stanford.edu/tiny-imagenet-200.zip"
+
+ @property
+ deflabel_map(self)->dict:
+"""Label map for the dataset."""
+ return{
+ idx:self._wnid2label.get(label,label)
+ foridx,labelinenumerate(self._dataset_info["features"]["label"]["names"])
+ }
+
+
[docs]defview_image(self,client_idx:int,image_idx:int)->None:
+"""View a single image.
+
+ Parameters
+ ----------
+ client_idx : int
+ Index of the client on which the image is located.
+ image_idx : int
+ Index of the image in the client.
+
+ Returns
+ -------
+ None
+
+ """
+ importmatplotlib.pyplotasplt
+
+ ifclient_idx>=self.num_clients:
+ raiseValueError(f"client_idx must be less than {self.num_clients}, got {client_idx}")
+
+ total_num_images=len(self.indices["train"][client_idx])+len(self.indices["test"][client_idx])
+ ifimage_idx>=total_num_images:
+ raiseValueError(f"image_idx must be less than {total_num_images}, got {image_idx}")
+ ifimage_idx<len(self.indices["train"][client_idx]):
+ image=self._train_data_dict[self._IMGAE][self.indices["train"][client_idx][image_idx]]
+ label=self._train_data_dict[self._LABEL][self.indices["train"][client_idx][image_idx]]
+ image_idx=self.indices["train"][client_idx][image_idx]
+ else:
+ image_idx-=len(self.indices["train"][client_idx])
+ image=self._test_data_dict[self._IMGAE][self.indices["test"][client_idx][image_idx]]
+ label=self._test_data_dict[self._LABEL][self.indices["test"][client_idx][image_idx]]
+ image_idx=self.indices["test"][client_idx][image_idx]
+ # image: channel first to channel last
+ image=image.transpose(1,2,0)
+ plt.imshow(image)
+ plt.title(f"image_idx: {image_idx}, label: {label} ({self.label_map[int(label)]}")
+ plt.show()
+
+
[docs]defrandom_grid_view(self,nrow:int,ncol:int,save_path:Optional[Union[str,Path]]=None)->None:
+"""Select randomly `nrow` x `ncol` images from the dataset
+ and plot them in a grid.
+
+ Parameters
+ ----------
+ nrow : int
+ Number of rows in the grid.
+ ncol : int
+ Number of columns in the grid.
+ save_path : Union[str, Path], optional
+ Path to save the figure. If ``None``, do not save the figure.
+
+ Returns
+ -------
+ None
+
+ """
+ importmatplotlib.pyplotasplt
+
+ rng=np.random.default_rng()
+
+ fig,axes=plt.subplots(nrow,ncol,figsize=(ncol*1,nrow*1))
+ selected=[]
+ foriinrange(nrow):
+ forjinrange(ncol):
+ whileTrue:
+ client_idx=rng.integers(self.num_clients)
+ image_idx=rng.integers(len(self.indices["train"][client_idx]))
+ if(client_idx,image_idx)notinselected:
+ selected.append((client_idx,image_idx))
+ break
+ image=self._train_data_dict[self._IMGAE][self.indices["train"][client_idx][image_idx]]
+ axes[i,j].imshow(image.transpose(1,2,0))
+ axes[i,j].axis("off")
+ ifsave_pathisnotNone:
+ fig.savefig(save_path,bbox_inches="tight",dpi=600)
+ plt.tight_layout()
+ plt.show()
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/_modules/fl_sim/models/nn.html b/_modules/fl_sim/models/nn.html
index 114b7bb..0caed7f 100644
--- a/_modules/fl_sim/models/nn.html
+++ b/_modules/fl_sim/models/nn.html
@@ -185,6 +185,7 @@
# ``Detected call of `lr_scheduler.step()` before `optimizer.step()`.``.# The risk is one has to check that scheduler.step() is called after# optimizer.step() in the training loop by himself.
- optimizer.step._with_counter=True
+ ifpackaging.version.parse(torch_version)<packaging.version.parse("2.4.0"):
+ optimizer.step._with_counter=True
+ else:
+ # NOTE: new in torch 2.4.0,
+ # the check by `optimizer.step._with_counter` is replaced by
+ # `optimizer.step._wrapped_by_lr_sched`
+ optimizer.step._wrapped_by_lr_sched=Truereturnoptimizertry:
@@ -554,7 +563,10 @@
[docs]classCSVLogger(BaseLogger):
-"""Logger that logs to a CSV file.
-
- Parameters
- ----------
- algorithm, dataset, model : str
- Used to form the prefix of the log file.
- log_dir : str or pathlib.Path, optional
- Directory to save the log file
- log_suffix : str, optional
- Suffix of the log file.
- verbose : int, default 1
- The verbosity level.
- Not used in this logger,
- but is kept for compatibility with other loggers.
-
- """
-
- __name__="CSVLogger"
-
- def__init__(
- self,
- algorithm:str,
- dataset:str,
- model:str,
- log_dir:Optional[Union[str,Path]]=None,
- log_suffix:Optional[str]=None,
- verbose:int=1,
- )->None:
- assertall([isinstance(x,str)forxin[algorithm,dataset,model]]),"algorithm, dataset, model must be str"
- self.log_prefix=re.sub("[\\s]+","_",f"{algorithm}-{dataset}-{model}")
- self._log_dir=self.set_log_dir(log_dir)
- iflog_suffixisNone:
- self.log_suffix=""
- else:
- self.log_suffix=f"_{log_suffix}"
- self.log_file=f"{self.log_prefix}_{get_date_str()}{self.log_suffix}.csv"
- self.logger=pd.DataFrame()
- self.step=-1
- self._flushed=True
-
-
[docs]defreset(self)->None:
-"""Reset the logger.
-
- Close the current logger and create a new one,
- with new log file name.
- """
- self.close()
- self.log_file=f"{self.log_prefix}_{get_date_str()}{self.log_suffix}.csv"
- self.logger=pd.DataFrame()
- self.step=-1
- self._flushed=True
[docs]@classmethod
- deffrom_config(cls,config:Dict[str,Any])->"CSVLogger":
-"""Create a :class:`CSVLogger` instance from a configuration.
-
- Parameters
- ----------
- config : dict
- Configuration for the logger. The following keys are used:
-
- - ``"algorithm"``: :obj:`str`,
- name of the algorithm.
- - ``"dataset"``: :obj:`str`,
- name of the dataset.
- - ``"model"``: :obj:`str`,
- name of the model.
- - ``"log_dir"``: :obj:`str` or :class:`pathlib.Path`, optional,
- directory to save the log file.
- - ``"log_suffix"``: :obj:`str`, optional,
- suffix of the log file.
-
- Returns
- -------
- CSVLogger
- A :class:`CSVLogger` instance.
-
- """
- returncls(**config)
[docs]classJsonLogger(BaseLogger):"""Logger that logs to a JSON file, or a yaml file.
@@ -1097,19 +970,6 @@
Source code for fl_sim.utils.loggers
))
- def_add_csv_logger(self)->None:
-"""Add a :class:`CSVLogger` instance to the manager."""
- self.loggers.append(
- CSVLogger(
- self._algorith,
- self._dataset,
- self._model,
- self._log_dir,
- self._log_suffix,
- self._verbose,
- )
- )
-
def_add_json_logger(self,fmt:str="json")->None:"""Add a :class:`JsonLogger` instance to the manager."""self.loggers.append(
@@ -1202,8 +1062,6 @@
Source code for fl_sim.utils.loggers
suffix of the log files. - ``"txt_logger"``: :obj:`bool`, optional, whether to add a :class:`TxtLogger` instance.
- - ``"csv_logger"``: :obj:`bool`, optional,
- whether to add a :class:`CSVLogger` instance. - ``"json_logger"``: :obj:`bool`, optional, whether to add a :class:`JsonLogger` instance. - ``"fmt"``: {"json", "yaml"}, optional,
@@ -1228,11 +1086,6 @@
Source code for fl_sim.utils.loggers
)ifconfig.get("txt_logger",True):lm._add_txt_logger()
- ifconfig.get("csv_logger",False):
- # for federated learning, csv logger has too many empty values,
- # resulting in a very large csv file,
- # hence it is not recommended to use csv logger.
- lm._add_csv_logger()ifconfig.get("json_logger",True):lm._add_json_logger(fmt=config.get("fmt",get_kwargs(JsonLogger)["fmt"]))returnlm
The Tiny ImageNet dataset is a subset of the ImageNet dataset. It consists of 200 classes, each with 500 training
+images and 50 validation images and 50 test images. The images are downsampled to 64x64 pixels.
+
The original dataset [1] contains the test images while the hugingface dataset [3] does not contain the test images.
+We use the hugingface dataset [3] for simplicity, and treat the validation set as the test set.
+
+
Parameters:
+
+
datadir (Union[pathlib.Path, str], optional) – Directory to store data.
+If None, use default directory.
+
num_clients (int, default 100) – Number of clients.
+
alpha (float, default 0.5) – Concentration parameter for the Dirichlet distribution.
+
transform (Union[str, Callable], default "none") – Transform to apply to data. Conventions:
+"none" means no transform, using TensorDataset.
+
seed (int, default 0) – Random seed for data partitioning.
+
**extra_config (dict, optional) – Extra configurations.
Get local dataloader at client client_idx or get the global dataloader.
+
+
Parameters:
+
+
train_bs (int, optional) – Batch size for training dataloader.
+If None, use default batch size.
+
test_bs (int, optional) – Batch size for testing dataloader.
+If None, use default batch size.
+
client_idx (int, optional) – Index of the client to get dataloader.
+If None, get the dataloader containing all data.
+Usually used for centralized training.