From 3773515fb9b3a629eef866b9e2d522b213761f60 Mon Sep 17 00:00:00 2001
From: kcelia
Date: Fri, 17 May 2024 09:45:21 +0200
Subject: [PATCH] chore: restore cifar_utils.py

---
 .../cifar_brevitas_finetuning/cifar_utils.py | 77 ++++---------------
 1 file changed, 13 insertions(+), 64 deletions(-)

diff --git a/use_case_examples/cifar/cifar_brevitas_finetuning/cifar_utils.py b/use_case_examples/cifar/cifar_brevitas_finetuning/cifar_utils.py
index 9b6dfc84e..641ec530b 100644
--- a/use_case_examples/cifar/cifar_brevitas_finetuning/cifar_utils.py
+++ b/use_case_examples/cifar/cifar_brevitas_finetuning/cifar_utils.py
@@ -91,30 +91,6 @@
             ]
         ),
     },
-    "MNIST": {
-        "dataset": datasets.MNIST,
-        "mean": (0.5),
-        "std": (0.5),
-        "train_transform": transforms.Compose(
-            [
-                transforms.Pad(1, padding_mode="edge"),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5,), (0.5,)),
-                # transforms.RandomRotation(5, fill=(1,)),
-                transforms.GaussianBlur(kernel_size=(3, 3)),
-                # transforms.RandomHorizontalFlip(0.5),
-                # transforms.Resize(20),
-            ]
-        ),
-        "test_transform": transforms.Compose(
-            [
-                transforms.Pad(1, padding_mode="edge"),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5,), (0.5,)),
-                # transforms.Resize(20),
-            ]
-        ),
-    },
 }
 
 
@@ -164,14 +140,11 @@ def get_dataloader(
         Tuple[DataLoader, DataLoader]: Training and test data loaders.
     """
 
-    g = None
-
-    if param["seed"]:
-        g = torch.Generator()
-        g.manual_seed(param["seed"])
-        np.random.seed(param["seed"])
-        torch.manual_seed(param["seed"])
-        random.seed(param["seed"])
+    g = torch.Generator()
+    g.manual_seed(param["seed"])
+    np.random.seed(param["seed"])
+    torch.manual_seed(param["seed"])
+    random.seed(param["seed"])
 
     max_examples = param.get("dataset_size", None)
     train_dataset = get_torchvision_dataset(
@@ -360,10 +333,8 @@ def train(
         nn.Module: the trained model.
     """
 
-    if param["seed"]:
-
-        torch.manual_seed(param["seed"])
-        random.seed(param["seed"])
+    torch.manual_seed(param["seed"])
+    random.seed(param["seed"])
 
     model = model.to(device)
 
@@ -372,10 +343,6 @@ def train(
         optimizer, milestones=param["milestones"], gamma=param["gamma"]
     )
 
-    # Save the state_dict
-    dir = Path(".") / param["dir"] / param["training"]
-    dir.mkdir(parents=True, exist_ok=True)
-
     # To avoid breaking up the tqdm bar
     with tqdm(total=param["epochs"], file=sys.stdout) as pbar:
 
@@ -428,17 +395,12 @@ def train(
             )
             pbar.update(step)
 
-    torch.save(
-        model.state_dict(),
-        f"{dir}/{param['dataset_name']}_{param['training']}_state_dict.pt",
-    )
-
-    print("Save in:", f"{dir}/{param['dataset_name']}_{param['training']}_state_dict.pt")
-
+    # Save the state_dict
+    dir = Path(".") / param["dir"] / param["training"]
+    dir.mkdir(parents=True, exist_ok=True)
     torch.save(
         model.state_dict(), f"{dir}/{param['dataset_name']}_{param['training']}_state_dict.pt"
    )
-    import pickle as pkl
 
     with open(f"{dir}/{param['dataset_name']}_history.pkl", "wb") as f:
         pkl.dump(param, f)
@@ -479,24 +441,12 @@ def torch_inference(
     return np.mean(np.vstack(correct), dtype="float64")
 
 
-def fhe_compatibility(
-    model: Callable,
-    data: DataLoader,
-    rounding_threshold_bits: Optional[int] = None,
-    show_mlir: bool = False,
-    output_onnx_file: str = "test.onnx",
-) -> Callable:
+def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
     """Test if the model is FHE-compatible.
 
     Args:
         model (Callable): The Brevitas model.
         data (DataLoader): The data loader.
-        rounding_threshold_bits (Optional[int]): if not None, every accumulator in the model is
-            rounded down to the given bits of precision.
-        show_mlir (bool): if set, the MLIR produced by the converter, which is sent to the
-            compiler backend, is shown on the screen, e.g., for debugging or demo purposes.
-        output_onnx_file (str): temporary file to store the ONNX model. If None, a temporary
-            file is generated.
 
     Returns:
         Callable: Quantized model.
@@ -506,9 +456,8 @@ def fhe_compatibility(
         model.to("cpu"),
         # Training
         torch_inputset=data,
-        show_mlir=show_mlir,
-        output_onnx_file=output_onnx_file,
-        rounding_threshold_bits=rounding_threshold_bits,
+        show_mlir=False,
+        output_onnx_file="test.onnx",
     )
 
     return qmodel
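Below is a minimal usage sketch of the restored helper, for sanity-checking the simplified signature. Everything in it is illustrative rather than taken from the patch: the toy Brevitas model, the random CIFAR-shaped calibration tensor, and the assumption that the helper's underlying compile call (presumably Concrete ML's compile_brevitas_qat_model, given the torch_inputset/show_mlir/output_onnx_file keywords in the hunk above) accepts a plain tensor inputset.

    # Illustrative only: toy Brevitas QAT model and random calibration data;
    # neither comes from the patch above.
    import brevitas.nn as qnn
    import torch
    from torch import nn

    from cifar_utils import fhe_compatibility  # helper restored by this patch

    toy_model = nn.Sequential(
        qnn.QuantIdentity(bit_width=4, return_quant_tensor=True),
        qnn.QuantConv2d(3, 8, kernel_size=3, weight_bit_width=4),
        qnn.QuantReLU(bit_width=4),
    )

    # CIFAR-shaped stand-in inputs; the docstring asks for a DataLoader, but a
    # representative tensor batch is assumed to work as the compile inputset.
    calib_data = torch.randn(16, 3, 32, 32)

    qmodel = fhe_compatibility(toy_model, calib_data)
    print("Compiled for FHE:", qmodel is not None)

Hardcoding show_mlir=False and output_onnx_file="test.onnx" restores the helper's original two-argument surface, so callers only pass the model and the calibration data, as above.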