Skip to content

Commit

Permalink
fix for coverity issue 656283: Unsafe deserialization.
Browse files Browse the repository at this point in the history
Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh committed Nov 8, 2024
1 parent fa3c516 commit 306455c
Showing 1 changed file with 37 additions and 12 deletions.
49 changes: 37 additions & 12 deletions openfl-tutorials/experimental/Privacy_Meter/cifar10_PM.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
from cifar10_loader import CIFAR10
import warnings

import hmac
import hashlib
import secrets

warnings.filterwarnings("ignore")

batch_size_train = 32
Expand All @@ -47,6 +51,8 @@
random_seed = 10
torch.manual_seed(random_seed)

# HMAC Keys
model_hmac_key = {}

class Net(nn.Module):
def __init__(self, num_classes=10):
Expand Down Expand Up @@ -173,14 +179,14 @@ def optimizer_to_device(optimizer, device):
raise (ValueError("No dict keys in optimizer state: please check"))


def load_previous_round_model_and_optimizer_and_perform_testing(
def load_previous_round_model_and_optimizer_and_perform_testing_with_hmac(
model, global_model, optimizer, collaborator_name, round_num, device
):
"""
Load pickle file to retrieve the model and optimizer state dictionary
from the previous round for each collaborator
and perform several validation routines with current
round state dictionaries to test the flow loop.
round state dictionaries to test the flow loop and check file integrity using HMAC.
Note: this functionality can be enabled through the command line argument
by setting "--flow_internal_loop_test=True".
Expand All @@ -203,7 +209,19 @@ def load_previous_round_model_and_optimizer_and_perform_testing(
f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num-1}.pickle",
"rb",
) as f:
model_prevround_config = pickle.load(f)
file_content = f.read()
serialized_data = file_content[:-64] # SHA-256 HMAC is 64 hex characters
received_hmac_hex = file_content[-64:].decode()
# Create an HMAC for the serialized data
hmac_obj = hmac.new(model_hmac_key[f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num-1}.pickle"], serialized_data, hashlib.sha256)
computed_hmac_hex = hmac_obj.hexdigest()
# Verify the HMAC
if hmac.compare_digest(computed_hmac_hex, received_hmac_hex):
# Deserialize the data
model_prevround_config = pickle.loads(serialized_data)
else:
raise ValueError('HMAC verification failed. Data integrity compromised.')

model_prevround.load_state_dict(model_prevround_config["model_state_dict"])
optimizer_prevround.load_state_dict(
model_prevround_config["optim_state_dict"]
Expand Down Expand Up @@ -281,7 +299,7 @@ def load_previous_round_model_and_optimizer_and_perform_testing(
raise (ValueError("No such name of pickle file exists"))


def save_current_round_model_and_optimizer_for_next_round_testing(
def save_current_round_model_and_optimizer_for_next_round_testing_with_hmac(
model, optimizer, collaborator_name, round_num
):
"""
Expand All @@ -291,7 +309,8 @@ def save_current_round_model_and_optimizer_for_next_round_testing(
for later retieving and verifying its correctness.
This provide the user the ability to verify the fields
in the model and optimizer state dictionary and
may provide confidence on the results of privacy auditing.
may provide confidence on the results of privacy auditing
and check file integrity using HMAC.
Note: this functionality can be enabled through the command line
argument by setting "--flow_internal_loop_test=True".
Expand All @@ -305,11 +324,17 @@ def save_current_round_model_and_optimizer_for_next_round_testing(
"model_state_dict": model.state_dict(),
"optim_state_dict": optimizer.state_dict(),
}
with open(
f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle",
"wb",
) as f:
pickle.dump(model_config, f)
# Serialize the data
serialized_data = pickle.dumps(model_config)
# Create an HMAC for the serialized data
model_hmac_key[f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle"] = secrets.token_bytes(32)
hmac_obj = hmac.new(model_hmac_key[f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle"], serialized_data, hashlib.sha256)
hmac_hex = hmac_obj.hexdigest()

# Save the serialized data and HMAC to a file
with open(f"Collaborator_{collaborator_name}_model_config_roundnumber_{round_num}.pickle", 'wb') as f:
f.write(serialized_data)
f.write(hmac_hex.encode())


class FederatedFlow(FLSpec):
Expand Down Expand Up @@ -380,7 +405,7 @@ def train(self):
optimizer_to_device(optimizer=self.optimizer, device=self.device)

if self.flow_internal_loop_test:
load_previous_round_model_and_optimizer_and_perform_testing(
load_previous_round_model_and_optimizer_and_perform_testing_with_hmac(
self.model,
self.global_model,
self.optimizer,
Expand All @@ -407,7 +432,7 @@ def train(self):
self.training_completed = True

if self.flow_internal_loop_test:
save_current_round_model_and_optimizer_for_next_round_testing(
save_current_round_model_and_optimizer_for_next_round_testing_with_hmac(
self.model, self.optimizer, self.collaborator_name, self.round_num
)

Expand Down

0 comments on commit 306455c

Please sign in to comment.