From 8a75cc57b176604e874af5b8fa288b86e4780881 Mon Sep 17 00:00:00 2001
From: kta-intel
Date: Mon, 18 Nov 2024 08:06:43 -0800
Subject: [PATCH] fix docstrings and remove commented out lines

Signed-off-by: kta-intel
---
 openfl-workspace/xgb_higgs/src/taskrunner.py | 39 +++++++++++++++-----
 openfl/federated/task/runner_xgb.py          | 34 +++++++++++++----
 2 files changed, 56 insertions(+), 17 deletions(-)

diff --git a/openfl-workspace/xgb_higgs/src/taskrunner.py b/openfl-workspace/xgb_higgs/src/taskrunner.py
index 410c4f49c9..520e303be2 100644
--- a/openfl-workspace/xgb_higgs/src/taskrunner.py
+++ b/openfl-workspace/xgb_higgs/src/taskrunner.py
@@ -13,18 +13,24 @@
 
 class XGBoostRunner(XGBoostTaskRunner):
     """
-    Simple CNN for classification.
+    A class to run XGBoost training and validation tasks.
 
-    PyTorchTaskRunner inherits from nn.module, so you can define your model
-    in the same way that you would for PyTorch
-    """
+    This class inherits from XGBoostTaskRunner and provides methods to train and validate
+    an XGBoost model using federated learning.
 
+    Attributes:
+        bst (xgb.Booster): The XGBoost model.
+        params (dict): Parameters for the XGBoost model.
+        num_rounds (int): Number of boosting rounds.
+    """
 
     def __init__(self, params=None, num_rounds=1, **kwargs):
-        """Initialize.
+        """
+        Initialize the XGBoostRunner.
 
         Args:
-            **kwargs: Additional arguments to pass to the function
-
+            params (dict, optional): Parameters for the XGBoost model. Defaults to None.
+            num_rounds (int, optional): Number of boosting rounds. Defaults to 1.
+            **kwargs: Additional arguments to pass to the function.
         """
         super().__init__(**kwargs)
 
@@ -33,7 +39,15 @@ def __init__(self, params=None, num_rounds=1, **kwargs):
         self.num_rounds = num_rounds
 
     def train_(self, train_dataloader) -> Metric:
-        """Train model."""
+        """
+        Train the XGBoost model.
+
+        Args:
+            train_dataloader (dict): A dictionary containing the training data with the key 'dmatrix'.
+
+        Returns:
+            Metric: A Metric object containing the training loss.
+        """
         dtrain = train_dataloader['dmatrix']
         evals = [(dtrain, 'train')]
         evals_result = {}
@@ -45,8 +59,15 @@ def train_(self, train_dataloader) -> Metric:
         return Metric(name=self.params['eval_metric'], value=np.array(loss))
 
     def validate_(self, validation_dataloader) -> Metric:
-        """Validate model."""
+        """
+        Validate the XGBoost model.
 
+        Args:
+            validation_dataloader (dict): A dictionary containing the validation data with keys 'dmatrix' and 'labels'.
+
+        Returns:
+            Metric: A Metric object containing the validation accuracy.
+        """
         dtest = validation_dataloader['dmatrix']
         y_test = validation_dataloader['labels']
         preds = self.bst.predict(dtest)
diff --git a/openfl/federated/task/runner_xgb.py b/openfl/federated/task/runner_xgb.py
index cdc2972a87..782a5b8213 100644
--- a/openfl/federated/task/runner_xgb.py
+++ b/openfl/federated/task/runner_xgb.py
@@ -4,9 +4,6 @@
 
 """XGBoostTaskRunner module."""
 
-# from copy import deepcopy
-# from typing import Iterator, Tuple
-
 import json
 
 import numpy as np
@@ -20,10 +17,16 @@
 class XGBoostTaskRunner(TaskRunner):
 
     def __init__(self, **kwargs):
-        """Initializes the XGBoostTaskRunner object.
+        """
+        A class to manage XGBoost tasks in a federated learning environment.
 
-        Args:
-            **kwargs: Additional parameters to pass to the functions.
+        This class inherits from TaskRunner and provides methods to initialize and manage
+        the global model and required tensor keys for XGBoost tasks.
+
+        Attributes:
+            global_model (xgb.Booster): The global XGBoost model.
+            required_tensorkeys_for_function (dict): A dictionary to store required tensor keys for each function.
+            training_round_completed (bool): A flag to indicate if the training round is completed.
         """
         super().__init__(**kwargs)
         self.global_model = None
@@ -352,7 +355,15 @@ def save_native(
         self.bst.save_model(filepath)
 
     def train_(self, train_dataloader) -> Metric:
-        """Train model."""
+        """
+        Train the XGBoost model.
+
+        Args:
+            train_dataloader (dict): A dictionary containing the training data with the key 'dmatrix'.
+
+        Returns:
+            Metric: A Metric object containing the training loss.
+        """
         dtrain = train_dataloader["dmatrix"]
         evals = [(dtrain, "train")]
         evals_result = {}
@@ -371,8 +382,15 @@ def train_(self, train_dataloader) -> Metric:
         return Metric(name=self.loss_fn.__name__, value=np.array(loss))
 
     def validate_(self, validation_dataloader) -> Metric:
-        """Validate model."""
+        """
+        Validate the XGBoost model.
 
+        Args:
+            validation_dataloader (dict): A dictionary containing the validation data with keys 'dmatrix' and 'labels'.
+
+        Returns:
+            Metric: A Metric object containing the validation accuracy.
+        """
         dtest = validation_dataloader["dmatrix"]
         y_test = validation_dataloader["labels"]
         preds = self.bst.predict(dtest)