From 9dc4dad6e10e57d646be26a413f6d892478fd170 Mon Sep 17 00:00:00 2001 From: yes Date: Tue, 26 Nov 2024 04:39:12 -0800 Subject: [PATCH] keras and tf updated Signed-off-by: yes --- .../keras_cnn_mnist/plan/plan.yaml | 14 +-- .../keras_cnn_mnist/requirements.txt | 3 +- .../{tfmnist_inmemory.py => dataloader.py} | 0 .../src/{keras_cnn.py => taskrunner.py} | 27 ++---- openfl/federated/task/runner_keras.py | 95 ++++++++----------- 5 files changed, 61 insertions(+), 78 deletions(-) rename openfl-workspace/keras_cnn_mnist/src/{tfmnist_inmemory.py => dataloader.py} (100%) rename openfl-workspace/keras_cnn_mnist/src/{keras_cnn.py => taskrunner.py} (75%) diff --git a/openfl-workspace/keras_cnn_mnist/plan/plan.yaml b/openfl-workspace/keras_cnn_mnist/plan/plan.yaml index e1c661343e..8ec4f9b79c 100644 --- a/openfl-workspace/keras_cnn_mnist/plan/plan.yaml +++ b/openfl-workspace/keras_cnn_mnist/plan/plan.yaml @@ -1,13 +1,13 @@ -# Copyright (C) 2020-2021 Intel Corporation +# Copyright (C) 2020-2024 Intel Corporation # Licensed subject to the terms of the separately executed evaluation license agreement between Intel Corporation and you. 
aggregator : defaults : plan/defaults/aggregator.yaml template : openfl.component.Aggregator settings : - init_state_path : save/keras_cnn_mnist_init.pbuf - best_state_path : save/keras_cnn_mnist_best.pbuf - last_state_path : save/keras_cnn_mnist_last.pbuf + init_state_path : save/init.pbuf + best_state_path : save/best.pbuf + last_state_path : save/last.pbuf rounds_to_train : 10 collaborator : @@ -19,7 +19,7 @@ collaborator : data_loader : defaults : plan/defaults/data_loader.yaml - template : src.tfmnist_inmemory.TensorFlowMNISTInMemory + template : src.dataloader.TensorFlowMNISTInMemory settings : collaborator_count : 2 data_group_name : mnist @@ -27,7 +27,7 @@ data_loader : task_runner : defaults : plan/defaults/task_runner.yaml - template : src.keras_cnn.KerasCNN + template : src.taskrunner.KerasCNN network : defaults : plan/defaults/network.yaml @@ -39,4 +39,4 @@ tasks : defaults : plan/defaults/tasks_keras.yaml compression_pipeline : - defaults : plan/defaults/compression_pipeline.yaml + defaults : plan/defaults/compression_pipeline.yaml \ No newline at end of file diff --git a/openfl-workspace/keras_cnn_mnist/requirements.txt b/openfl-workspace/keras_cnn_mnist/requirements.txt index af80212eeb..74efde6582 100644 --- a/openfl-workspace/keras_cnn_mnist/requirements.txt +++ b/openfl-workspace/keras_cnn_mnist/requirements.txt @@ -1 +1,2 @@ -tensorflow==2.13 +tensorflow==2.18.0 +keras==3.6.0 \ No newline at end of file diff --git a/openfl-workspace/keras_cnn_mnist/src/tfmnist_inmemory.py b/openfl-workspace/keras_cnn_mnist/src/dataloader.py similarity index 100% rename from openfl-workspace/keras_cnn_mnist/src/tfmnist_inmemory.py rename to openfl-workspace/keras_cnn_mnist/src/dataloader.py diff --git a/openfl-workspace/keras_cnn_mnist/src/keras_cnn.py b/openfl-workspace/keras_cnn_mnist/src/taskrunner.py similarity index 75% rename from openfl-workspace/keras_cnn_mnist/src/keras_cnn.py rename to openfl-workspace/keras_cnn_mnist/src/taskrunner.py index 
35a71f7734..b56a472c75 100644 --- a/openfl-workspace/keras_cnn_mnist/src/keras_cnn.py +++ b/openfl-workspace/keras_cnn_mnist/src/taskrunner.py @@ -3,11 +3,10 @@ """You may copy this file as the starting point of your own model.""" -import tensorflow.keras as ke -from tensorflow.keras import Sequential -from tensorflow.keras.layers import Conv2D -from tensorflow.keras.layers import Dense -from tensorflow.keras.layers import Flatten +from keras.models import Sequential +from keras.layers import Conv2D +from keras.layers import Dense +from keras.layers import Flatten from openfl.federated import KerasTaskRunner @@ -50,7 +49,7 @@ def build_model(self, num_classes (int): The number of classes of the dataset Returns: - tensorflow.python.keras.engine.sequential.Sequential: The model defined in Keras + keras.models.Sequential: The model defined in Keras """ model = Sequential() @@ -72,14 +71,8 @@ def build_model(self, model.add(Dense(num_classes, activation='softmax')) - model.compile(loss=ke.losses.categorical_crossentropy, - optimizer=ke.optimizers.legacy.Adam(), - metrics=['accuracy']) - - # initialize the optimizer variables - opt_vars = model.optimizer.variables() - - for v in opt_vars: - v.initializer.run(session=self.sess) - - return model + model.compile(loss="categorical_crossentropy", + optimizer="adam", + metrics=["accuracy"]) + + return model \ No newline at end of file diff --git a/openfl/federated/task/runner_keras.py b/openfl/federated/task/runner_keras.py index df56fd97e1..f85114f786 100644 --- a/openfl/federated/task/runner_keras.py +++ b/openfl/federated/task/runner_keras.py @@ -18,7 +18,7 @@ with catch_warnings(): simplefilter(action="ignore") import tensorflow as tf - import tensorflow.keras as ke + import keras as ke class KerasTaskRunner(TaskRunner): @@ -39,14 +39,13 @@ def __init__(self, **kwargs): """ super().__init__(**kwargs) - self.model = ke.Model() + self.model = ke.models.Model() self.model_tensor_names = [] # this is a map of all of the 
required tensors for each of the public # functions in KerasTaskRunner self.required_tensorkeys_for_function = {} - ke.backend.clear_session() def rebuild_model(self, round_num, input_tensor_dict, validation=False): """Parse tensor names and update weights of model. Handles the @@ -59,6 +58,7 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False): to False. """ if self.opt_treatment == "RESET": + # TODO issue while resetting the optimizer variables self.reset_opt_vars() self.set_tensor_dict(input_tensor_dict, with_opt_vars=False) elif round_num > 0 and self.opt_treatment == "CONTINUE_GLOBAL" and not validation: @@ -161,6 +161,7 @@ def train( if self.opt_treatment == "CONTINUE_GLOBAL": self.initialize_tensorkeys_for_functions(with_opt_vars=True) + self.update_tensorkeys_for_functions() return global_tensor_dict, local_tensor_dict def train_iteration(self, batch_generator, metrics: list = None, **kwargs): @@ -181,16 +182,16 @@ def train_iteration(self, batch_generator, metrics: list = None, **kwargs): # initialization (build_model). # If metrics are added (i.e. not a subset of what was originally # defined) then the model must be recompiled. - model_metrics_names = self.model.metrics_names + results = self.model.get_metrics_result() # TODO if there are new metrics in the flplan that were not included # in the originally # compiled model, that behavior is not currently handled. for param in metrics: - if param not in model_metrics_names: + if param not in results: raise ValueError( f"KerasTaskRunner does not support specifying new metrics. 
" - f"Param_metrics = {metrics}, model_metrics_names = {model_metrics_names}" + f"Param_metrics = {metrics}" ) history = self.model.fit(batch_generator, verbose=1, **kwargs) @@ -223,20 +224,16 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): self.rebuild_model(round_num, input_tensor_dict, validation=True) param_metrics = kwargs["metrics"] - vals = self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1) - model_metrics_names = self.model.metrics_names - if type(vals) is not list: - vals = [vals] - ret_dict = dict(zip(model_metrics_names, vals)) - + self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1) + results = self.model.get_metrics_result() # TODO if there are new metrics in the flplan that were not included in # the originally compiled model, that behavior is not currently # handled. for param in param_metrics: - if param not in model_metrics_names: + if param not in results: raise ValueError( f"KerasTaskRunner does not support specifying new metrics. " - f"Param_metrics = {param_metrics}, model_metrics_names = {model_metrics_names}" + f"Param_metrics = {param_metrics}" ) origin = col_name @@ -248,7 +245,7 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs): tags = ("metric",) tags = change_tags(tags, add_field=suffix) output_tensor_dict = { - TensorKey(metric, origin, round_num, True, tags): np.array(ret_dict[metric]) + TensorKey(metric, origin, round_num, True, tags): np.array(results[metric]) for metric in param_metrics } @@ -281,7 +278,10 @@ def _get_weights_names(obj): Returns: weight_names (list): The weight name list. 
""" - weight_names = [weight.name for weight in obj.weights] + if isinstance(obj, ke.optimizers.Optimizer): + weight_names = [weight.name for weight in obj.variables] + else: + weight_names = [layer.name + "/" + weight.name for layer in obj.layers for weight in layer.weights] return weight_names @staticmethod @@ -298,10 +298,17 @@ def _get_weights_dict(obj, suffix=""): weights_dict (dict): The weight dictionary. """ weights_dict = {} - weight_names = [weight.name for weight in obj.weights] - weight_values = obj.get_weights() - for name, value in zip(weight_names, weight_values): - weights_dict[name + suffix] = value + if isinstance(obj, ke.optimizers.Optimizer): + weight_names = [weight.name for weight in obj.variables] + weights_dict = {weight_names[i] + suffix: weight.numpy() for i, weight in enumerate(obj.variables)} + else: + weight_names = [layer.name + "/" + weight.name for layer in obj.layers for weight in layer.weights] + weight_name_index = 0 + for layer in obj.layers: + if weight_name_index < len(weight_names) and len(layer.get_weights()) > 0: + for weight in layer.get_weights(): + weights_dict[weight_names[weight_name_index] + suffix] = weight + weight_name_index += 1 return weights_dict @staticmethod @@ -313,8 +320,12 @@ def _set_weights_dict(obj, weights_dict): the weights. weights_dict (dict): The weight dictionary. 
""" - weight_names = [weight.name for weight in obj.weights] - weight_values = [weights_dict[name] for name in weight_names] + if isinstance(obj, ke.optimizers.Optimizer): + weight_names = [weight.name for weight in obj.variables] + weight_values = [weights_dict[name] for name in weight_names] + else: + weight_names = [layer.name + "/" + weight.name for layer in obj.layers for weight in layer.weights] + weight_values = [weights_dict[name] for name in weight_names] obj.set_weights(weight_values) def get_tensor_dict(self, with_opt_vars, suffix=""): @@ -348,45 +359,23 @@ def set_tensor_dict(self, tensor_dict, with_opt_vars): if with_opt_vars is False: # It is possible to pass in opt variables from the input tensor # dict. This will make sure that the correct layers are updated - model_weight_names = [weight.name for weight in self.model.weights] + model_weight_names = self._get_weights_names(self.model) model_weights_dict = {name: tensor_dict[name] for name in model_weight_names} self._set_weights_dict(self.model, model_weights_dict) else: - model_weight_names = [weight.name for weight in self.model.weights] + model_weight_names = self._get_weights_names(self.model) model_weights_dict = {name: tensor_dict[name] for name in model_weight_names} - opt_weight_names = [weight.name for weight in self.model.optimizer.weights] + opt_weight_names = self._get_weights_names(self.model.optimizer) opt_weights_dict = {name: tensor_dict[name] for name in opt_weight_names} self._set_weights_dict(self.model, model_weights_dict) self._set_weights_dict(self.model.optimizer, opt_weights_dict) def reset_opt_vars(self): """Resets the optimizer variables.""" - for var in self.model.optimizer.variables(): + for var in self.model.optimizer.variables: var.assign(tf.zeros_like(var)) self.logger.debug("Optimizer variables reset") - def set_required_tensorkeys_for_function(self, func_name, tensor_key, **kwargs): - """ - Set the required tensors for specified function that could be called as part 
of a task. - - By default, this is just all of the layers and optimizer of the model. - Custom tensors should be added to this function. - - Args: - func_name (str): The function name. - tensor_key (TensorKey): The tensor key. - **kwargs: Any function arguments. - """ - # TODO there should be a way to programmatically iterate through all - # of the methods in the class and declare the tensors. - # For now this is done manually - - if func_name == "validate": - # Should produce 'apply=global' or 'apply=local' - local_model = "apply" + kwargs["apply"] - self.required_tensorkeys_for_function[func_name][local_model].append(tensor_key) - else: - self.required_tensorkeys_for_function[func_name].append(tensor_key) def get_required_tensorkeys_for_function(self, func_name, **kwargs): """Get the required tensors for specified function that could be called @@ -423,17 +412,17 @@ def update_tensorkeys_for_functions(self): tensor_names = model_layer_names + opt_names self.logger.debug("Updating model tensor names: %s", tensor_names) self.required_tensorkeys_for_function["train"] = [ - TensorKey(tensor_name, "GLOBAL", 0, ("model",)) for tensor_name in tensor_names + TensorKey(tensor_name, "GLOBAL", 0, False, ("model",)) for tensor_name in tensor_names ] # Validation may be performed on local or aggregated (global) model, # so there is an extra lookup dimension for kwargs self.required_tensorkeys_for_function["validate"] = {} - self.required_tensorkeys_for_function["validate"]["local_model=True"] = [ - TensorKey(tensor_name, "LOCAL", 0, ("trained",)) for tensor_name in tensor_names + self.required_tensorkeys_for_function["validate"]["apply=local"] = [ + TensorKey(tensor_name, "LOCAL", 0, False, ("trained",)) for tensor_name in tensor_names ] - self.required_tensorkeys_for_function["validate"]["local_model=False"] = [ - TensorKey(tensor_name, "GLOBAL", 0, ("model",)) for tensor_name in tensor_names + self.required_tensorkeys_for_function["validate"]["apply=global"] = [ + 
TensorKey(tensor_name, "GLOBAL", 0, False, ("model",)) for tensor_name in tensor_names ] def initialize_tensorkeys_for_functions(self, with_opt_vars=False):