Commit: clean up taskrunner

Signed-off-by: kta-intel <[email protected]>
kta-intel committed Nov 18, 2024
1 parent 238448f commit 16cd7e1
Showing 2 changed files with 19 additions and 54 deletions.

openfl-workspace/xgb_higgs/src/taskrunner.py (10 changes: 5 additions & 5 deletions)

@@ -38,7 +38,7 @@ def __init__(self, params=None, num_rounds=1, **kwargs):
         self.params = params
         self.num_rounds = num_rounds
 
-    def train_(self, train_dataloader) -> Metric:
+    def train_(self, data) -> Metric:
         """
         Train the XGBoost model.
 
@@ -48,7 +48,7 @@ def train_(self, train_dataloader) -> Metric:
         Returns:
             Metric: A Metric object containing the training loss.
         """
-        dtrain = train_dataloader['dmatrix']
+        dtrain = data['dmatrix']
         evals = [(dtrain, 'train')]
         evals_result = {}
 
@@ -58,7 +58,7 @@ def train_(self, train_dataloader) -> Metric:
         loss = evals_result['train']['logloss'][-1]
         return Metric(name=self.params['eval_metric'], value=np.array(loss))
 
-    def validate_(self, validation_dataloader) -> Metric:
+    def validate_(self, data) -> Metric:
         """
         Validate the XGBoost model.
 
@@ -68,8 +68,8 @@ def validate_(self, validation_dataloader) -> Metric:
         Returns:
             Metric: A Metric object containing the validation accuracy.
         """
-        dtest = validation_dataloader['dmatrix']
-        y_test = validation_dataloader['labels']
+        dtest = data['dmatrix']
+        y_test = data['labels']
         preds = self.bst.predict(dtest)
         y_pred_binary = np.where(preds > 0.5, 1, 0)
         acc = accuracy_score(y_test, y_pred_binary)
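
For context on the rename above: the `data` argument in both methods is the dictionary returned by the workspace data loader, carrying the DMatrix and, for validation, the raw labels. A minimal sketch of that contract, assuming a HIGGS-style binary task; the loader itself is not touched by this commit, so the helper below is hypothetical:

    import numpy as np
    import xgboost as xgb

    # Hypothetical stand-in for the workspace loader's get_valid_dmatrix();
    # the real loader lives in the xgb_higgs workspace, outside this diff.
    def get_valid_dmatrix(X, y):
        return {"dmatrix": xgb.DMatrix(X, label=y), "labels": y}

    X = np.random.rand(256, 28)            # HIGGS has 28 features
    y = np.random.randint(0, 2, size=256)  # binary labels
    data = get_valid_dmatrix(X, y)

    # train_ reads data["dmatrix"]; validate_ reads both keys.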

openfl/federated/task/runner_xgb.py (63 changes: 14 additions & 49 deletions)
@@ -26,12 +26,10 @@ def __init__(self, **kwargs):
         Attributes:
             global_model (xgb.Booster): The global XGBoost model.
             required_tensorkeys_for_function (dict): A dictionary to store required tensor keys for each function.
-            training_round_completed (bool): A flag to indicate if the training round is completed.
         """
         super().__init__(**kwargs)
         self.global_model = None
         self.required_tensorkeys_for_function = {}
-        self.training_round_completed = False
 
     def rebuild_model(self, input_tensor_dict):
         """
@@ -73,7 +71,7 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
             local_output_dict (dict): Tensors to maintain in the local
                 TensorDB.
         """
-        loader = self.data_loader.get_valid_dmatrix()
+        data = self.data_loader.get_valid_dmatrix()
 
         # during agg validation, self.bst will still be None. during local validation, it will have a value - no need to rebuild
         if self.bst is None:
@@ -85,7 +83,7 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
             # TODO: this is not robust, especially if using a loss metric
             metric = Metric(name="accuracy", value=np.array(0))
         else:
-            metric = self.validate_(loader)
+            metric = self.validate_(data)
 
         origin = col_name
         suffix = "validate"
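
Pulling the two hunks above together, the dispatch in validate_task reads roughly as follows. This is a sketch reconstructed from the visible context lines; the rebuild step between the two checks is elided in the hunks shown, so it is an assumption, not the verbatim method body:

    data = self.data_loader.get_valid_dmatrix()

    # During aggregated validation self.bst starts out as None and the model
    # must be rebuilt from the aggregator's tensors; during local validation
    # it already holds the just-trained booster.
    if self.bst is None:
        self.rebuild_model(input_tensor_dict)

    if self.bst is None:
        # No global model exists yet (first round): report a default value.
        # The TODO above notes this is not robust for loss-style metrics.
        metric = Metric(name="accuracy", value=np.array(0))
    else:
        metric = self.validate_(data)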
@@ -95,7 +93,7 @@ def validate_task(self, col_name, round_num, input_tensor_dict, **kwargs):
             suffix += "_agg"
         tags = ("metric",)
         tags = change_tags(tags, add_field=suffix)
-        # TODO figure out a better way to pass in metric for this pytorch
+
         # validate function
         output_tensor_dict = {TensorKey(metric.name, origin, round_num, True, tags): metric.value}
 
@@ -125,8 +123,8 @@ def train_task(
                 TensorDB.
         """
         self.rebuild_model(input_tensor_dict)
-        loader = self.data_loader.get_train_dmatrix()
-        metric = self.train_(loader)
+        data = self.data_loader.get_train_dmatrix()
+        metric = self.train_(data)
         # Output metric tensors (scalar)
         origin = col_name
         tags = ("trained",)
@@ -167,26 +165,6 @@ def train_task(
             **next_local_tensorkey_model_dict,
         }
 
-        # Update the required tensors if they need to be pulled from the
-        # aggregator
-        # TODO this logic can break if different collaborators have different
-        # roles between rounds.
-        # For example, if a collaborator only performs validation in the first
-        # round but training in the second, it has no way of knowing the
-        # optimizer state tensor names to request from the aggregator because
-        # these are only created after training occurs. A work around could
-        # involve doing a single epoch of training on random data to get the
-        # optimizer names, and then throwing away the model.
-        if self.opt_treatment == "CONTINUE_GLOBAL":
-            self.initialize_tensorkeys_for_functions()
-
-            # This will signal that the optimizer values are now present,
-            # and can be loaded when the model is rebuilt
-            self.training_round_completed = True
-
-        # Return global_tensor_dict, local_tensor_dict
-        # import pdb; pdb.set_trace()
-        # TODO it is still decodable from here with .tobytes().decode('utf-8')
         return global_tensor_dict, local_tensor_dict
 
     def get_tensor_dict(self, with_opt_vars=False):
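
The deleted debug note ("still decodable ... .tobytes().decode('utf-8')") refers to the booster travelling between parties as a byte tensor. A sketch of that round trip with the public xgboost API, assuming a JSON raw format; the exact encoding used by get_tensor_dict is outside the hunks shown here:

    import numpy as np
    import xgboost as xgb

    # Train a throwaway booster so the example is self-contained.
    X = np.random.rand(64, 4)
    y = np.random.randint(0, 2, size=64)
    bst = xgb.train({"objective": "binary:logistic"},
                    xgb.DMatrix(X, label=y), num_boost_round=2)

    # Serialize to a numpy tensor; the JSON raw format keeps the bytes
    # human-decodable, which is what the removed TODO was pointing out.
    raw = bst.save_raw(raw_format="json")        # bytearray
    tensor = np.frombuffer(raw, dtype=np.uint8)

    # Rebuild on the receiving side.
    restored = xgb.Booster()
    restored.load_model(bytearray(tensor.tobytes()))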
@@ -266,30 +244,17 @@ def initialize_tensorkeys_for_functions(self, with_opt_vars=False):
         Custom tensors should be added to this function.
 
         Args:
-            with_opt_vars (bool): Flag to check if optimizer variables are
-                included. Defaults to False.
+            with_opt_vars (bool): N/A for XGBoost (Default=False).
 
         Returns:
             None
         """
-        # TODO there should be a way to programmatically iterate through
-        # all of the methods in the class and declare the tensors.
-        # For now this is done manually
-
-        output_model_dict = self.get_tensor_dict(with_opt_vars=with_opt_vars)
+        output_model_dict = self.get_tensor_dict()
         global_model_dict, local_model_dict = split_tensor_dict_for_holdouts(
             self.logger, output_model_dict, **self.tensor_dict_split_fn_kwargs
         )
-        if not with_opt_vars:
-            global_model_dict_val = global_model_dict
-            local_model_dict_val = local_model_dict
-        else:
-            output_model_dict = self.get_tensor_dict(with_opt_vars=False)
-            global_model_dict_val, local_model_dict_val = split_tensor_dict_for_holdouts(
-                self.logger,
-                output_model_dict,
-                **self.tensor_dict_split_fn_kwargs,
-            )
+        global_model_dict_val = global_model_dict
+        local_model_dict_val = local_model_dict
 
         self.required_tensorkeys_for_function["train_task"] = [
             TensorKey(tensor_name, "GLOBAL", 0, False, ("model",))
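
The dropped if/else computed the same dictionaries on both paths once optimizer variables are out of the picture, which is the whole simplification: XGBoost keeps no optimizer state, so the validation-time tensor dicts are just the model dicts. A minimal illustration with a stand-in split function; the real split_tensor_dict_for_holdouts lives in openfl.utilities and takes a logger plus holdout kwargs, and the tensor name below is hypothetical:

    import numpy as np

    # Stand-in for split_tensor_dict_for_holdouts: partition a tensor dict
    # into tensors shared with the aggregator vs. tensors held out locally.
    def split_for_holdouts(tensor_dict, holdout_names=()):
        shared = {k: v for k, v in tensor_dict.items() if k not in holdout_names}
        holdout = {k: v for k, v in tensor_dict.items() if k in holdout_names}
        return shared, holdout

    # Hypothetical tensor name for the serialized booster bytes.
    output_model_dict = {"local_tree": np.zeros(4, dtype=np.uint8)}
    global_model_dict, local_model_dict = split_for_holdouts(output_model_dict)

    # With no optimizer state, the validation-time dicts are identical:
    global_model_dict_val = global_model_dict
    local_model_dict_val = local_model_dict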
@@ -354,7 +319,7 @@ def save_native(
         """
         self.bst.save_model(filepath)
 
-    def train_(self, train_dataloader) -> Metric:
+    def train_(self, data) -> Metric:
         """
         Train the XGBoost model.
 
@@ -364,7 +329,7 @@ def train_(self, train_dataloader) -> Metric:
         Returns:
             Metric: A Metric object containing the training loss.
         """
-        dtrain = train_dataloader["dmatrix"]
+        dtrain = data["dmatrix"]
         evals = [(dtrain, "train")]
         evals_result = {}
 
@@ -381,7 +346,7 @@ def train_(self, train_dataloader) -> Metric:
         loss = evals_result["train"]["logloss"][-1]
         return Metric(name=self.loss_fn.__name__, value=np.array(loss))
 
-    def validate_(self, validation_dataloader) -> Metric:
+    def validate_(self, data) -> Metric:
         """
         Validate the XGBoost model.
 
@@ -391,8 +356,8 @@ def validate_(self, validation_dataloader) -> Metric:
         Returns:
             Metric: A Metric object containing the validation accuracy.
         """
-        dtest = validation_dataloader["dmatrix"]
-        y_test = validation_dataloader["labels"]
+        dtest = data["dmatrix"]
+        y_test = data["labels"]
         preds = self.bst.predict(dtest)
         y_pred_binary = np.where(preds > 0.5, 1, 0)
         acc = accuracy_score(y_test, y_pred_binary)
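
As a usage note on validate_: for a binary:logistic objective, the booster's predict() returns probabilities, which are thresholded at 0.5 before scoring with sklearn's accuracy_score, exactly as in the hunk above. A standalone sketch of the same computation:

    import numpy as np
    import xgboost as xgb
    from sklearn.metrics import accuracy_score

    X = np.random.rand(128, 28)
    y = np.random.randint(0, 2, size=128)
    dtrain = xgb.DMatrix(X, label=y)

    bst = xgb.train({"objective": "binary:logistic", "eval_metric": "logloss"},
                    dtrain, num_boost_round=5)

    preds = bst.predict(xgb.DMatrix(X))          # probabilities in [0, 1]
    y_pred_binary = np.where(preds > 0.5, 1, 0)  # hard labels at the 0.5 cutoff
    print(accuracy_score(y, y_pred_binary))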
