Skip to content

Commit

Permalink
fix docstrings and remove commented out lines
Browse files Browse the repository at this point in the history
Signed-off-by: kta-intel <[email protected]>
  • Loading branch information
kta-intel committed Nov 18, 2024
1 parent 326069d commit 8a75cc5
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 17 deletions.
39 changes: 30 additions & 9 deletions openfl-workspace/xgb_higgs/src/taskrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,24 @@

class XGBoostRunner(XGBoostTaskRunner):
"""
Simple CNN for classification.
A class to run XGBoost training and validation tasks.
PyTorchTaskRunner inherits from nn.module, so you can define your model
in the same way that you would for PyTorch
"""
This class inherits from XGBoostTaskRunner and provides methods to train and validate
an XGBoost model using federated learning.
Attributes:
bst (xgb.Booster): The XGBoost model.
params (dict): Parameters for the XGBoost model.
num_rounds (int): Number of boosting rounds.
"""
def __init__(self, params=None, num_rounds=1, **kwargs):
"""Initialize.
"""
Initialize the XGBoostRunner.
Args:
**kwargs: Additional arguments to pass to the function
params (dict, optional): Parameters for the XGBoost model. Defaults to None.
num_rounds (int, optional): Number of boosting rounds. Defaults to 1.
**kwargs: Additional arguments to pass to the function.
"""
super().__init__(**kwargs)

Expand All @@ -33,7 +39,15 @@ def __init__(self, params=None, num_rounds=1, **kwargs):
self.num_rounds = num_rounds

def train_(self, train_dataloader) -> Metric:
"""Train model."""
"""
Train the XGBoost model.
Args:
train_dataloader (dict): A dictionary containing the training data with keys 'dmatrix'.
Returns:
Metric: A Metric object containing the training loss.
"""
dtrain = train_dataloader['dmatrix']
evals = [(dtrain, 'train')]
evals_result = {}
Expand All @@ -45,8 +59,15 @@ def train_(self, train_dataloader) -> Metric:
return Metric(name=self.params['eval_metric'], value=np.array(loss))

def validate_(self, validation_dataloader) -> Metric:
"""Validate model."""
"""
Validate the XGBoost model.
Args:
validation_dataloader (dict): A dictionary containing the validation data with keys 'dmatrix' and 'labels'.
Returns:
Metric: A Metric object containing the validation accuracy.
"""
dtest = validation_dataloader['dmatrix']
y_test = validation_dataloader['labels']
preds = self.bst.predict(dtest)
Expand Down
34 changes: 26 additions & 8 deletions openfl/federated/task/runner_xgb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

"""XGBoostTaskRunner module."""

# from copy import deepcopy
# from typing import Iterator, Tuple

import json

import numpy as np
Expand All @@ -20,10 +17,16 @@

class XGBoostTaskRunner(TaskRunner):
def __init__(self, **kwargs):
"""Initializes the XGBoostTaskRunner object.
"""
A class to manage XGBoost tasks in a federated learning environment.
Args:
**kwargs: Additional parameters to pass to the functions.
This class inherits from TaskRunner and provides methods to initialize and manage
the global model and required tensor keys for XGBoost tasks.
Attributes:
global_model (xgb.Booster): The global XGBoost model.
required_tensorkeys_for_function (dict): A dictionary to store required tensor keys for each function.
training_round_completed (bool): A flag to indicate if the training round is completed.
"""
super().__init__(**kwargs)
self.global_model = None
Expand Down Expand Up @@ -352,7 +355,15 @@ def save_native(
self.bst.save_model(filepath)

def train_(self, train_dataloader) -> Metric:
"""Train model."""
"""
Train the XGBoost model.
Args:
train_dataloader (dict): A dictionary containing the training data with keys 'dmatrix'.
Returns:
Metric: A Metric object containing the training loss.
"""
dtrain = train_dataloader["dmatrix"]
evals = [(dtrain, "train")]
evals_result = {}
Expand All @@ -371,8 +382,15 @@ def train_(self, train_dataloader) -> Metric:
return Metric(name=self.loss_fn.__name__, value=np.array(loss))

def validate_(self, validation_dataloader) -> Metric:
"""Validate model."""
"""
Validate the XGBoost model.
Args:
validation_dataloader (dict): A dictionary containing the validation data with keys 'dmatrix' and 'labels'.
Returns:
Metric: A Metric object containing the validation accuracy.
"""
dtest = validation_dataloader["dmatrix"]
y_test = validation_dataloader["labels"]
preds = self.bst.predict(dtest)
Expand Down

0 comments on commit 8a75cc5

Please sign in to comment.