
Commit

chore: restore cifar_utils.py
kcelia committed May 17, 2024
1 parent e300bd9 commit 3773515
Showing 1 changed file with 13 additions and 64 deletions.
77 changes: 13 additions & 64 deletions use_case_examples/cifar/cifar_brevitas_finetuning/cifar_utils.py
@@ -91,30 +91,6 @@
             ]
         ),
     },
-    "MNIST": {
-        "dataset": datasets.MNIST,
-        "mean": (0.5),
-        "std": (0.5),
-        "train_transform": transforms.Compose(
-            [
-                transforms.Pad(1, padding_mode="edge"),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5,), (0.5,)),
-                # transforms.RandomRotation(5, fill=(1,)),
-                transforms.GaussianBlur(kernel_size=(3, 3)),
-                # transforms.RandomHorizontalFlip(0.5),
-                # transforms.Resize(20),
-            ]
-        ),
-        "test_transform": transforms.Compose(
-            [
-                transforms.Pad(1, padding_mode="edge"),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5,), (0.5,)),
-                # transforms.Resize(20),
-            ]
-        ),
-    },
 }


@@ -164,14 +140,11 @@ def get_dataloader(
         Tuple[DataLoader, DataLoader]: Training and test data loaders.
     """
 
-    g = None
-
-    if param["seed"]:
-        g = torch.Generator()
-        g.manual_seed(param["seed"])
-        np.random.seed(param["seed"])
-        torch.manual_seed(param["seed"])
-        random.seed(param["seed"])
+    g = torch.Generator()
+    g.manual_seed(param["seed"])
+    np.random.seed(param["seed"])
+    torch.manual_seed(param["seed"])
+    random.seed(param["seed"])
 
     max_examples = param.get("dataset_size", None)
     train_dataset = get_torchvision_dataset(
@@ -360,10 +333,8 @@ def train(
         nn.Module: the trained model.
     """
 
-    if param["seed"]:
-
-        torch.manual_seed(param["seed"])
-        random.seed(param["seed"])
+    torch.manual_seed(param["seed"])
+    random.seed(param["seed"])
 
     model = model.to(device)
 
@@ -372,10 +343,6 @@
         optimizer, milestones=param["milestones"], gamma=param["gamma"]
     )
 
-    # Save the state_dict
-    dir = Path(".") / param["dir"] / param["training"]
-    dir.mkdir(parents=True, exist_ok=True)
-
     # To avoid breaking up the tqdm bar
     with tqdm(total=param["epochs"], file=sys.stdout) as pbar:
 
@@ -428,17 +395,12 @@ def train(
             )
             pbar.update(step)
 
-            torch.save(
-                model.state_dict(),
-                f"{dir}/{param['dataset_name']}_{param['training']}_state_dict.pt",
-            )
-
-            print("Save in:", f"{dir}/{param['dataset_name']}_{param['training']}_state_dict.pt")
-
+    # Save the state_dict
+    dir = Path(".") / param["dir"] / param["training"]
+    dir.mkdir(parents=True, exist_ok=True)
     torch.save(
         model.state_dict(), f"{dir}/{param['dataset_name']}_{param['training']}_state_dict.pt"
     )
-    import pickle as pkl
 
     with open(f"{dir}/{param['dataset_name']}_history.pkl", "wb") as f:
         pkl.dump(param, f)
@@ -479,24 +441,12 @@ def torch_inference(
     return np.mean(np.vstack(correct), dtype="float64")
 
 
-def fhe_compatibility(
-    model: Callable,
-    data: DataLoader,
-    rounding_threshold_bits: Optional[int] = None,
-    show_mlir: bool = False,
-    output_onnx_file: str = "test.onnx",
-) -> Callable:
+def fhe_compatibility(model: Callable, data: DataLoader) -> Callable:
     """Test if the model is FHE-compatible.
 
     Args:
         model (Callable): The Brevitas model.
         data (DataLoader): The data loader.
-        rounding_threshold_bits (Optiona[int]): if not None, every accumulators in the model are
-            rounded down to the given bits of precision.
-        show_mlir (bool): if set, the MLIR produced by the converter and which is going
-            to be sent to the compiler backend is shown on the screen, e.g., for debugging or demo.
-        output_onnx_file (str): temporary file to store ONNX model. If None a temporary file
-            is generated.
 
     Returns:
         Callable: Quantized model.
@@ -506,9 +456,8 @@ def fhe_compatibility(
         model.to("cpu"),
         # Training
         torch_inputset=data,
-        show_mlir=show_mlir,
-        output_onnx_file=output_onnx_file,
-        rounding_threshold_bits=rounding_threshold_bits,
+        show_mlir=False,
+        output_onnx_file="test.onnx",
     )
 
     return qmodel
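Note: after this change, fhe_compatibility takes only the Brevitas model and a data loader, and forwards fixed show_mlir=False / output_onnx_file="test.onnx" settings to compile_brevitas_qat_model. A minimal usage sketch, assuming a Brevitas QAT model named quant_model and a calibration DataLoader named calib_loader (both placeholder names, not part of this commit), and assuming Concrete ML's usual fhe_circuit attribute on the compiled module:

    # Sketch only: quant_model and calib_loader are assumed to exist
    # (a Brevitas QAT model and a DataLoader of representative inputs).
    qmodel = fhe_compatibility(quant_model, calib_loader)

    # The returned quantized module exposes the compiled circuit; its maximum
    # integer bit-width is a common check for FHE feasibility.
    print(qmodel.fhe_circuit.graph.maximum_integer_bit_width())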
