Skip to content

Commit

Permalink
Merge pull request #33 from msamsami/maint-utils-improve
Browse files Browse the repository at this point in the history
Add utility functions + minor improvements
  • Loading branch information
msamsami authored Aug 3, 2024
2 parents 697696d + 9710a85 commit 542a90c
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 56 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
<img src="https://raw.githubusercontent.com/msamsami/weighted-naive-bayes/main/docs/logo.png" alt="wnb logo" width="275" />
</div>

<div align="center"> <b>General and weighted naive Bayes classifiers</b> </div> <br>
<div align="center"> <b>General and weighted naive Bayes classifiers</b> </div>
<div align="center">Scikit-learn-compatible</div> <br>

<div align="center">

![Lastest Release](https://img.shields.io/badge/release-v0.2.4-green)
![Latest Release](https://img.shields.io/badge/release-v0.2.5-green)
[![PyPI Version](https://img.shields.io/pypi/v/wnb)](https://pypi.org/project/wnb/)
![Python Versions](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12-blue)<br>
![GitHub Workflow Status (build)](https://github.com/msamsami/weighted-naive-bayes/actions/workflows/python-publish.yml/badge.svg)
Expand Down
28 changes: 14 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ keywords = [
"probabilistic",
]
classifiers = [
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Software Development :: Libraries :: Python Modules",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: BSD License",
]
"Intended Audience :: Science/Research",
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
"Topic :: Software Development :: Libraries :: Python Modules",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"License :: OSI Approved :: BSD License",
]
requires-python = ">=3.8,<3.13"
dependencies = [
"pandas>=1.4.1",
Expand All @@ -43,7 +43,7 @@ dependencies = [

[project.optional-dependencies]
dev = [
"pytest>=7.3.1",
"pytest>=7.0.0",
"black>=24.4.2",
"tqdm>=4.65.0",
"pre-commit>=3.7.1",
Expand Down
2 changes: 1 addition & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ numpy<2.0.0
scipy>=1.8.0
scikit-learn>=1.0.2
typing-extensions>=4.8.0
pytest>=7.3.1
pytest>=7.0.0
black>=24.4.2
tqdm>=4.65.0
pre-commit>=3.7.1
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ install_requires =

[options.extras_require]
dev =
pytest>=7.3.1
pytest>=7.0.0
black>=24.4.2
tqdm>=4.65.0
pre-commit>=3.7.1
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
],
extras_require={
"dev": [
"pytest>=7.3.1",
"pytest>=7.0.0",
"black>=24.4.2",
"tqdm>=4.65.0",
"pre-commit>=3.7.1",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,6 @@ def test_gnb_invalid_dist():
"""
clf = GeneralNB(distributions=["Normal", "Borel"])

msg = "Distribution 'Borel' is not supported"
msg = r"Distribution .* is not supported"
with pytest.raises(ValueError, match=msg):
clf.fit(X, y)
2 changes: 1 addition & 1 deletion wnb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Python library for the implementations of general and weighted naive Bayes (WNB) classifiers.
"""

__version__ = "0.2.4"
__version__ = "0.2.5"
__author__ = "Mehdi Samsami"


Expand Down
41 changes: 11 additions & 30 deletions wnb/gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
from typing_extensions import Self

from ._typing import ArrayLike, DistibutionLike, Float, MatrixLike
from .dist import AllDistributions, NonNumericDistributions
from .dist import NonNumericDistributions
from .enums import Distribution
from .utils import get_dist_class, is_dist_supported

__all__ = [
"GeneralNB",
Expand Down Expand Up @@ -182,9 +183,9 @@ def _prepare_parameters(self):

self.class_prior_ = self.priors

# Convert to NumPy array if input priors is in a list
if type(self.class_prior_) is list:
self.class_prior_ = np.array(self.class_prior_)
# Convert to NumPy array if input priors is in a list/tuple/set
if isinstance(self.class_prior_, (list, tuple, set)):
self.class_prior_ = np.array(list(self.class_prior_))

# Set distributions if not specified
if self.distributions is None:
Expand All @@ -199,34 +200,14 @@ def _prepare_parameters(self):
)

# Check that all specified distributions are supported
for dist in self.distributions:
if not (
isinstance(dist, Distribution)
or dist in Distribution.__members__.values()
or (hasattr(dist, "from_data") and hasattr(dist, "__call__"))
):
raise ValueError(f"Distribution '{dist}' is not supported.")
for i, dist in enumerate(self.distributions):
if not is_dist_supported(dist):
raise ValueError(
f"Distribution '{dist}' at index {i} is not supported."
)

self.distributions_ = self.distributions

@staticmethod
def _get_dist_object(name_or_obj):
if (
isinstance(name_or_obj, Distribution)
or name_or_obj in Distribution.__members__.values()
):
return AllDistributions[name_or_obj]
elif isinstance(name_or_obj, str) and name_or_obj.upper() in [
d.name for d in Distribution
]:
return AllDistributions[Distribution.__members__[name_or_obj.upper()]]
elif isinstance(name_or_obj, str) and name_or_obj.title() in [
d.value for d in Distribution
]:
return AllDistributions[Distribution(name_or_obj.title())]
else:
return name_or_obj

def fit(self, X: MatrixLike, y: ArrayLike) -> Self:
"""Fits general Naive Bayes classifier according to X, y.
Expand Down Expand Up @@ -261,7 +242,7 @@ def fit(self, X: MatrixLike, y: ArrayLike) -> Self:

self.likelihood_params_ = {
c: [
self._get_dist_object(self.distributions_[i]).from_data(
get_dist_class(self.distributions_[i]).from_data(
X[y == c, i], alpha=self.alpha
)
for i in range(self.n_features_in_)
Expand Down
10 changes: 5 additions & 5 deletions wnb/gwnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ def _prepare_parameters(self, X, y):
else:
self.class_prior_ = self.priors

# Convert to NumPy array if input priors is in a list
if type(self.class_prior_) is list:
self.class_prior_ = np.array(self.class_prior_)
# Convert to NumPy array if input priors is in a list/tuple/set
if isinstance(self.class_prior_, (list, tuple, set)):
self.class_prior_ = np.array(list(self.class_prior_))

# Update if no error weights is provided
if self.error_weights is None:
Expand Down Expand Up @@ -309,7 +309,7 @@ def fit(self, X: MatrixLike, y: ArrayLike) -> Self:
self.n_iter_ = 0
for self.n_iter_ in range(self.max_iter):
# Predict on X
y_hat = self.__predict(X)
y_hat = self._predict(X)

# Calculate cost
self.cost_hist_[self.n_iter_], _lambda = self._calculate_cost(
Expand Down Expand Up @@ -412,7 +412,7 @@ def _calculate_grad_slow(self, X, _lambda):
_grad += _lambda[i] * _log_p
return _grad

def __predict(self, X):
def _predict(self, X):
p_hat = self.predict_log_proba(X)
return np.argmax(p_hat, axis=1)

Expand Down
49 changes: 49 additions & 0 deletions wnb/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import contextlib
from typing import Optional

from ._base import DistMixin
from ._typing import DistibutionLike
from .dist import AllDistributions
from .enums import Distribution

__all__ = ["is_dist_supported", "get_dist_class"]


def is_dist_supported(dist: DistibutionLike) -> bool:
    """Check whether a distribution specification can be used by the classifiers.

    A specification is accepted when it is any of the following:

    * a class that is a subclass of ``DistMixin``;
    * a ``Distribution`` enum member, or a value that compares equal to one;
    * a duck-typed object exposing ``from_data``, ``support`` and ``__call__``.

    Parameters
    ----------
    dist : DistibutionLike
        The distribution specification to validate.

    Returns
    -------
    bool
        True if `dist` is supported, False otherwise.
    """
    # `issubclass` raises TypeError when `dist` is not a class (e.g. a string
    # or an enum member); in that case fall through to the remaining checks.
    # The result must be tested explicitly: returning unconditionally here
    # would report *every* class (e.g. `int`) as a supported distribution.
    with contextlib.suppress(TypeError):
        if issubclass(dist, DistMixin):
            return True

    return bool(
        isinstance(dist, Distribution)
        or dist in Distribution.__members__.values()
        or all(
            hasattr(dist, attr_name)
            for attr_name in ("from_data", "support", "__call__")
        )
    )


def get_dist_class(name_or_type: DistibutionLike) -> Optional[DistMixin]:
    """Resolve a distribution specification to the object providing ``from_data``.

    Parameters
    ----------
    name_or_type : DistibutionLike
        One of: a ``DistMixin`` subclass, a ``Distribution`` enum member (or a
        value equal to one), a string matching an enum member name
        (case-insensitive) or an enum value (after title-casing), or a
        duck-typed object exposing ``from_data``, ``support`` and ``__call__``.

    Returns
    -------
    Optional[DistMixin]
        The entry of ``AllDistributions`` for enum members and strings, the
        input itself for ``DistMixin`` subclasses and duck-typed objects, or
        None when the specification cannot be resolved.
    """
    # `issubclass` raises TypeError for non-class inputs (strings, enum
    # members, instances); suppress it and continue with the other lookups.
    # The result is tested explicitly so that unrelated classes are not
    # returned as if they were distributions.
    with contextlib.suppress(TypeError):
        if issubclass(name_or_type, DistMixin):
            return name_or_type

    if (
        isinstance(name_or_type, Distribution)
        or name_or_type in Distribution.__members__.values()
    ):
        return AllDistributions[name_or_type]

    if isinstance(name_or_type, str):
        # Match by enum member name first (e.g. "normal" -> NORMAL) ...
        upper_name = name_or_type.upper()
        if upper_name in [d.name for d in Distribution]:
            return AllDistributions[Distribution.__members__[upper_name]]

        # ... then by enum value after title-casing (e.g. "normal" -> "Normal").
        title_name = name_or_type.title()
        if title_name in [d.value for d in Distribution]:
            return AllDistributions[Distribution(title_name)]

    # Duck-typed distributions pass the same attribute check used for
    # validation, so hand them back unchanged instead of resolving to None.
    if all(
        hasattr(name_or_type, attr_name)
        for attr_name in ("from_data", "support", "__call__")
    ):
        return name_or_type

    return None

0 comments on commit 542a90c

Please sign in to comment.