diff --git a/pearl/policy_learners/contextual_bandits/linear_bandit.py b/pearl/policy_learners/contextual_bandits/linear_bandit.py index a9c385bb..75d23ce0 100644 --- a/pearl/policy_learners/contextual_bandits/linear_bandit.py +++ b/pearl/policy_learners/contextual_bandits/linear_bandit.py @@ -37,7 +37,29 @@ class LinearBandit(ContextualBanditBase): """ - Policy Learner for Contextual Bandit with Linear Policy + Policy Learner for Contextual Bandit with Linear Policy. + + This class implements a policy learner for a contextual bandit problem where the policy is + linear. It supports learning through linear regression and can apply discounting to observations + based on the number of weighted data points processed. The learner also supports exploration + modules for acting based on learned policies. + + Attributes: + model (LinearRegression): Linear regression model used for learning. + apply_discounting_interval (float): Interval for applying discounting to the data points. + last_sum_weight_when_discounted (float): The counter for the last data point when discounting was applied. + + Args: + feature_dim (int): Dimension of the feature space. + exploration_module (Optional[ExplorationModule]): Module for exploring actions. + l2_reg_lambda (float): L2 regularization parameter for the linear regression model. + gamma (float): Discount factor for discounting observations. + apply_discounting_interval (float): number of (weighted observations) for applying discounting to the data points. + Set to 0.0 to disable. + force_pinv (bool): If True, use pseudo-inverse for matrix inversion in the linear model. + training_rounds (int): Number of training rounds. + batch_size (int): Size of the batches used during training. + action_representation_module (Optional[ActionRepresentationModule]): Module for representing actions. """ def __init__( @@ -46,9 +68,8 @@ def __init__( exploration_module: Optional[ExplorationModule] = None, l2_reg_lambda: float = 1.0, gamma: float = 1.0, - apply_discounting_interval: float = 0.0, # discounting will be applied after this many - # observations (weighted) are processed. set to 0 to disable - force_pinv: bool = False, # If True, use pseudo inverse instead of regular inverse for `A` + apply_discounting_interval: float = 0.0, + force_pinv: bool = False, training_rounds: int = 100, batch_size: int = 128, action_representation_module: Optional[ActionRepresentationModule] = None,