diff --git a/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py b/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py index a6fdb0524e..3ec55c125b 100644 --- a/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py +++ b/openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py @@ -35,6 +35,10 @@ from cifar10_loader import CIFAR10 import warnings +import hmac +import hashlib +import secrets + warnings.filterwarnings("ignore") batch_size_train = 32 @@ -47,6 +51,8 @@ random_seed = 10 torch.manual_seed(random_seed) +# HMAC Keys +model_hmac_key = {} class Net(nn.Module): def __init__(self, num_classes=10): @@ -173,14 +179,14 @@ def optimizer_to_device(optimizer, device): raise (ValueError("No dict keys in optimizer state: please check")) -def load_previous_round_model_and_optimizer_and_perform_testing( +def load_previous_round_model_and_optimizer_and_perform_testing_with_hmac( model, global_model, optimizer, collaborator_name, round_num, device ): """ Load pickle file to retrieve the model and optimizer state dictionary from the previous round for each collaborator and perform several validation routines with current - round state dictionaries to test the flow loop. + round state dictionaries to test the flow loop and check file integrity using HMAC. Note: this functionality can be enabled through the command line argument by setting "--flow_internal_loop_test=True". @@ -203,7 +209,19 @@ def load_previous_round_model_and_optimizer_and_perform_testing( f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num-1}.pickle", "rb", ) as f: - model_prevround_config = pickle.load(f) + file_content = f.read() + serialized_data = file_content[:-64] # SHA-256 HMAC is 64 hex characters + received_hmac_hex = file_content[-64:].decode() + # Create an HMAC for the serialized data + hmac_obj = hmac.new(model_hmac_key[f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num-1}.pickle"], serialized_data, hashlib.sha256) + computed_hmac_hex = hmac_obj.hexdigest() + # Verify the HMAC + if hmac.compare_digest(computed_hmac_hex, received_hmac_hex): + # Deserialize the data + model_prevround_config = pickle.loads(serialized_data) + else: + raise ValueError('HMAC verification failed. Data integrity compromised.') + model_prevround.load_state_dict(model_prevround_config["model_state_dict"]) optimizer_prevround.load_state_dict( model_prevround_config["optim_state_dict"] @@ -281,7 +299,7 @@ def load_previous_round_model_and_optimizer_and_perform_testing( raise (ValueError("No such name of pickle file exists")) -def save_current_round_model_and_optimizer_for_next_round_testing( +def save_current_round_model_and_optimizer_for_next_round_testing_with_hmac( model, optimizer, collaborator_name, round_num ): """ @@ -291,7 +309,8 @@ def save_current_round_model_and_optimizer_for_next_round_testing( for later retieving and verifying its correctness. This provide the user the ability to verify the fields in the model and optimizer state dictionary and - may provide confidence on the results of privacy auditing. + may provide confidence on the results of privacy auditing + and check file integrity using HMAC. Note: this functionality can be enabled through the command line argument by setting "--flow_internal_loop_test=True". @@ -305,11 +324,17 @@ def save_current_round_model_and_optimizer_for_next_round_testing( "model_state_dict": model.state_dict(), "optim_state_dict": optimizer.state_dict(), } - with open( - f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle", - "wb", - ) as f: - pickle.dump(model_config, f) + # Serialize the data + serialized_data = pickle.dumps(model_config) + # Create an HMAC for the serialized data + model_hmac_key[f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle"] = secrets.token_bytes(32) + hmac_obj = hmac.new(model_hmac_key[f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle"], serialized_data, hashlib.sha256) + hmac_hex = hmac_obj.hexdigest() + + # Save the serialized data and HMAC to a file + with open(f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle", 'wb') as f: + f.write(serialized_data) + f.write(hmac_hex.encode()) class FederatedFlow(FLSpec): @@ -380,7 +405,7 @@ def train(self): optimizer_to_device(optimizer=self.optimizer, device=self.device) if self.flow_internal_loop_test: - load_previous_round_model_and_optimizer_and_perform_testing( + load_previous_round_model_and_optimizer_and_perform_testing_with_hmac( self.model, self.global_model, self.optimizer, @@ -407,7 +432,7 @@ def train(self): self.training_completed = True if self.flow_internal_loop_test: - save_current_round_model_and_optimizer_for_next_round_testing( + save_current_round_model_and_optimizer_for_next_round_testing_with_hmac( self.model, self.optimizer, self.collaborator_name, self.round_num )