Skip to content

Commit

Permalink
skipped rebuild model for round 0
Browse files Browse the repository at this point in the history
Signed-off-by: yes <[email protected]>
  • Loading branch information
tanwarsh committed Nov 27, 2024
1 parent 9c294e1 commit ba9924d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/interactive-tensorflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ jobs:
- name: Interactive API - tensorflow_mnist
run: |
python setup.py build_grpc
pip install tensorflow==2.13
pip install tensorflow==2.18.0
python -m tests.github.interactive_api_director.experiments.tensorflow_mnist.run
12 changes: 7 additions & 5 deletions openfl/federated/task/runner_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from warnings import catch_warnings, simplefilter

import numpy as np
import copy

from openfl.federated.task.runner import TaskRunner
from openfl.utilities import Metric, TensorKey, change_tags
Expand Down Expand Up @@ -58,9 +59,8 @@ def rebuild_model(self, round_num, input_tensor_dict, validation=False):
to False.
"""
if self.opt_treatment == "RESET":
# TODO issue while reseting the optimizer variables
self.reset_opt_vars()
self.set_tensor_dict(input_tensor_dict, with_opt_vars=False)
self.set_tensor_dict(input_tensor_dict, with_opt_vars=True)
elif round_num > 0 and self.opt_treatment == "CONTINUE_GLOBAL" and not validation:
self.set_tensor_dict(input_tensor_dict, with_opt_vars=True)
else:
Expand Down Expand Up @@ -98,7 +98,8 @@ def train(
raise KeyError("metrics must be defined")

# rebuild model with updated weights
self.rebuild_model(round_num, input_tensor_dict)
if round_num > 0:
self.rebuild_model(round_num, input_tensor_dict)
for epoch in range(epochs):
self.logger.info("Run %s epoch of %s round", epoch, round_num)
results = self.train_iteration(
Expand Down Expand Up @@ -221,7 +222,8 @@ def validate(self, col_name, round_num, input_tensor_dict, **kwargs):
else:
batch_size = 1

self.rebuild_model(round_num, input_tensor_dict, validation=True)
if round_num > 0:
self.rebuild_model(round_num, input_tensor_dict, validation=True)
param_metrics = kwargs["metrics"]

self.model.evaluate(self.data_loader.get_valid_loader(batch_size), verbose=1)
Expand Down Expand Up @@ -300,7 +302,7 @@ def _get_weights_dict(obj, suffix=""):
weights_dict = {}
if isinstance(obj, ke.optimizers.Optimizer):
weight_names = [weight.name for weight in obj.variables]
weights_dict = {weight_names[i] + suffix: weight.numpy() for i, weight in enumerate(obj.variables)}
weights_dict = {weight_names[i] + suffix: weight.numpy() for i, weight in enumerate(copy.deepcopy(obj.variables))}
else:
weight_names = [layer.name + "/" + weight.name for layer in obj.layers for weight in layer.weights]
weight_name_index = 0
Expand Down

0 comments on commit ba9924d

Please sign in to comment.