diff --git a/estimators/ccb/__init__.py b/estimators/ccb/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/estimators/ccb/base.py b/estimators/ccb/base.py new file mode 100644 index 0000000..065d07e --- /dev/null +++ b/estimators/ccb/base.py @@ -0,0 +1,51 @@ +""" Interface for implementation of conditional contextual bandits estimators """ + +from abc import ABC, abstractmethod +from typing import List + +class Estimator(ABC): + """ Interface for implementation of conditional contextual bandits estimators """ + + @abstractmethod + def add_example(self, p_log: List, r: List, p_pred: List, count: float) -> None: + """ + Args: + p_log: List of probability of the logging policy + r: List of reward for choosing an action in the given context + p_pred: List of predicted probability of making decision + count: weight + """ + ... + + @abstractmethod + def get(self) -> float: + """ Calculates the selected estimator + + Returns: + The estimator value + """ + ... + +class Interval(ABC): + """ Interface for implementation of conditional contextual bandits estimators interval """ + + @abstractmethod + def add_example(self, p_log: List, r: List, p_pred: List, count: float) -> None: + """ + Args: + p_log: List of probability of the logging policy + r: List of reward for choosing an action in the given context + p_pred: List of predicted probability of making decision + count: weight + """ + ... + + @abstractmethod + def get(self, alpha: float) -> List: + """ Calculates the CI + Args: + alpha: alpha value + Returns: + Returns the confidence interval as list[float] + """ + ... diff --git a/setup.py b/setup.py index 534df4d..ed5554e 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ "Operating System :: OS Independent", "Topic :: Scientific/Engineering" ], - packages=["estimators", "estimators.bandits", "estimators.slates", "estimators.utils"], + packages=["estimators", "estimators.bandits", "estimators.ccb", "estimators.slates", "estimators.utils"], install_requires= ['scipy>=0.9'], tests_require=['pytest'], python_requires=">=3.6",