Commit
Improve LinearBandit documentation
Summary: Improve LinearBandit documentation

Reviewed By: jb3618columbia

Differential Revision: D62759074

fbshipit-source-id: ee361128334c2ce327d373c8a6c57b74df18998b
rodrigodesalvobraz authored and facebook-github-bot committed Oct 23, 2024
1 parent 90720b6 commit 69d7cc4
Showing 1 changed file with 25 additions and 4 deletions.
29 changes: 25 additions & 4 deletions pearl/policy_learners/contextual_bandits/linear_bandit.py
@@ -37,7 +37,29 @@

 class LinearBandit(ContextualBanditBase):
     """
-    Policy Learner for Contextual Bandit with Linear Policy
+    Policy Learner for Contextual Bandit with Linear Policy.
+
+    This class implements a policy learner for a contextual bandit problem where the policy is
+    linear. It supports learning through linear regression and can apply discounting to observations
+    based on the number of weighted data points processed. The learner also supports exploration
+    modules for acting based on learned policies.
+
+    Attributes:
+        model (LinearRegression): Linear regression model used for learning.
+        apply_discounting_interval (float): Interval for applying discounting to the data points.
+        last_sum_weight_when_discounted (float): Sum of observation weights at the point when
+            discounting was last applied.
+
+    Args:
+        feature_dim (int): Dimension of the feature space.
+        exploration_module (Optional[ExplorationModule]): Module for exploring actions.
+        l2_reg_lambda (float): L2 regularization parameter for the linear regression model.
+        gamma (float): Discount factor for discounting observations.
+        apply_discounting_interval (float): Number of (weighted) observations after which
+            discounting is applied to the data points. Set to 0.0 to disable.
+        force_pinv (bool): If True, use pseudo-inverse for matrix inversion in the linear model.
+        training_rounds (int): Number of training rounds.
+        batch_size (int): Size of the batches used during training.
+        action_representation_module (Optional[ActionRepresentationModule]): Module for representing actions.
     """

     def __init__(
@@ -46,9 +68,8 @@ def __init__(
         exploration_module: Optional[ExplorationModule] = None,
         l2_reg_lambda: float = 1.0,
         gamma: float = 1.0,
-        apply_discounting_interval: float = 0.0,  # discounting will be applied after this many
-        # observations (weighted) are processed. set to 0 to disable
-        force_pinv: bool = False,  # If True, use pseudo inverse instead of regular inverse for `A`
+        apply_discounting_interval: float = 0.0,
+        force_pinv: bool = False,
         training_rounds: int = 100,
         batch_size: int = 128,
         action_representation_module: Optional[ActionRepresentationModule] = None,
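
The interaction between `gamma`, `apply_discounting_interval`, and `last_sum_weight_when_discounted` described in the docstring can be illustrated with a small standalone sketch. This is not Pearl's implementation: the class `DiscountedLinearStats` and its methods are hypothetical, and the choices to scale the accumulated weight by `gamma` and to skip discounting when `gamma >= 1` are assumptions made for illustration. It only shows one plausible way to discount weighted least-squares statistics once enough weighted observations have accumulated.

```python
class DiscountedLinearStats:
    """Toy weighted least-squares statistics with periodic discounting.

    Hypothetical illustration only; this is not Pearl's LinearRegression.
    """

    def __init__(self, feature_dim, gamma=1.0, apply_discounting_interval=0.0):
        self.gamma = gamma
        self.apply_discounting_interval = apply_discounting_interval
        # A accumulates sum(w * x x^T); b accumulates sum(w * x * y).
        self.A = [[0.0] * feature_dim for _ in range(feature_dim)]
        self.b = [0.0] * feature_dim
        self.sum_weight = 0.0
        # Value of sum_weight right after discounting was last applied.
        self.last_sum_weight_when_discounted = 0.0

    def update(self, x, y, weight=1.0):
        """Add one weighted observation (x, y), then discount if due."""
        d = len(x)
        for i in range(d):
            self.b[i] += weight * x[i] * y
            for j in range(d):
                self.A[i][j] += weight * x[i] * x[j]
        self.sum_weight += weight
        self._maybe_apply_discounting()

    def _maybe_apply_discounting(self):
        # An interval of 0.0 disables discounting; gamma >= 1 makes it a no-op
        # (assumption: skip it entirely in that case).
        if self.apply_discounting_interval <= 0.0 or self.gamma >= 1.0:
            return
        processed = self.sum_weight - self.last_sum_weight_when_discounted
        if processed >= self.apply_discounting_interval:
            # Assumption: all accumulated statistics are scaled by gamma.
            self.A = [[self.gamma * a for a in row] for row in self.A]
            self.b = [self.gamma * v for v in self.b]
            self.sum_weight *= self.gamma
            self.last_sum_weight_when_discounted = self.sum_weight


stats = DiscountedLinearStats(feature_dim=1, gamma=0.5, apply_discounting_interval=2.0)
stats.update([1.0], 1.0)  # total weight 1.0: below the interval, no discounting
stats.update([1.0], 1.0)  # total weight 2.0: interval reached, statistics halved
```

With these settings, the second update triggers discounting, so older observations contribute less to `A` and `b` than newer ones. Setting `apply_discounting_interval=0.0` leaves the statistics undiscounted, matching the "Set to 0.0 to disable" note in the docstring.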
