Skip to content

Commit

Permalink
Merge pull request #42 from msamsami/use-uv-refactor-workflows
Browse files Browse the repository at this point in the history
patch: use `uv` for package management, refactor github workflows, fix bug in handling complex arrays
  • Loading branch information
msamsami authored Jan 17, 2025
2 parents 504f7b5 + 61d33d0 commit 2770e16
Show file tree
Hide file tree
Showing 20 changed files with 1,200 additions and 155 deletions.
20 changes: 0 additions & 20 deletions .github/workflows/black.yml

This file was deleted.

30 changes: 14 additions & 16 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# This workflow will upload a Python Package using Twine when a release is created

name: build

on:
Expand All @@ -10,31 +8,31 @@ permissions:
contents: read

jobs:
test:
uses: ./.github/workflows/run-tests.yml
ci:
uses: ./.github/workflows/ci.yml

publish:
runs-on: ubuntu-latest
needs: test
needs: ci

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: '3.x'
python-version: '3.8'

- name: Set up uv
uses: astral-sh/setup-uv@v5

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
pip install -r requirements.txt
run: uv sync

- name: Build the package
run: python -m build
- name: Build
run: uv build

- name: Publish the package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
- name: Publish
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
29 changes: 16 additions & 13 deletions .github/workflows/run-tests.yml → .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,36 +1,39 @@
# This workflow will run the unit tests on a PR

name: Run Tests
name: CI

on:
pull_request:
branches: [ main ]
workflow_call:
workflow_dispatch:

permissions:
contents: read

jobs:
build-and-test:
test:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Set up uv
uses: astral-sh/setup-uv@v5

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
pip install -r requirements.txt
pip install -r requirements-dev.txt
run: uv sync

- name: Run ruff linter
run: uv run ruff check --output-format=github .

- name: Run black format check
run: uv run black --check --diff .

- name: Run tests
run: |
pytest
run: uv run pytest -vv
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
BSD License
===========

Copyright (c) 2024 by Mehdi Samsami.
Copyright (c) 2025 by Mehdi Samsami.
All rights reserved.

Redistribution and use in source and binary forms, with or without
Expand Down
18 changes: 3 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

<div align="center">

![Lastest Release](https://img.shields.io/badge/release-v0.5.0-green)
![Lastest Release](https://img.shields.io/badge/release-v0.5.1-green)
[![PyPI Version](https://img.shields.io/pypi/v/wnb)](https://pypi.org/project/wnb/)
![Python Versions](https://img.shields.io/badge/python-3.8%20%7C%203.9%20%7C%203.10%20%7C%203.11%20%7C%203.12%20%7C%203.13-blue)<br>
![GitHub Workflow Status (build)](https://github.com/msamsami/wnb/actions/workflows/build.yml/badge.svg)
Expand Down Expand Up @@ -36,21 +36,9 @@ Ensure that Python 3.8 or higher is installed on your machine before installing
pip install wnb
```

### Poetry
### uv
```bash
poetry add wnb
```

### GitHub
```bash
# Clone the repository
git clone https://github.com/msamsami/wnb.git

# Navigate into the project directory
cd wnb

# Install the package
pip install .
uv add wnb
```

## Getting started ⚡️
Expand Down
20 changes: 9 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ dependencies = [
"typing-extensions>=4.8.0; python_full_version < '3.11'",
]

[project.urls]
Homepage = "https://github.com/msamsami/wnb"
Source = "https://github.com/msamsami/wnb"

[project.optional-dependencies]
[dependency-groups]
dev = [
"pytest>=7.0.0",
"black>=24.8.0",
"tqdm",
"pre-commit>=3.5.0",
"isort",
"pre-commit>=3.5.0",
"pytest>=7.0.0",
"ruff>=0.9.2",
"tqdm",
]

[project.urls]
Homepage = "https://github.com/msamsami/wnb"
Source = "https://github.com/msamsami/wnb"

[tool.hatch.version]
path = "wnb/__init__.py"

Expand All @@ -63,9 +64,6 @@ packages = ["wnb"]
[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.sdist]
include = ["/README.md", "/wnb"]

[tool.pytest.ini_options]
testpaths = ["tests"]
filterwarnings = ["ignore"]
Expand Down
7 changes: 4 additions & 3 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pytest>=7.0.0
black>=24.8.0
tqdm
pre-commit>=3.5.0
isort
pre-commit>=3.5.0
pytest>=7.0.0
ruff>=0.9.2
tqdm
54 changes: 0 additions & 54 deletions setup.py

This file was deleted.

5 changes: 4 additions & 1 deletion tests/benchmarks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
try:
from tqdm import tqdm
except ModuleNotFoundError:
tqdm = lambda iterable, *args, **kwargs: iterable

def tqdm(iterable, *args, **kwargs):
return iterable


warnings.filterwarnings("ignore")

Expand Down
15 changes: 15 additions & 0 deletions tests/test_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ def test_normal_pdf():
assert_array_almost_equal(norm_wnb(X), norm_scipy.pdf(X), decimal=10)


@pytest.mark.parametrize("epsilon", [1e-10, 1e-9, 1e-6, 1e-3])
def test_normal_with_epsilon(epsilon: float):
"""
Test whether epsilon is correctly applied for `NormalDist`.
"""
norm_1 = NormalDist(mu=1, sigma=0)
norm_2 = NormalDist(mu=1, sigma=0, epsilon=epsilon)
norm_3 = NormalDist(mu=1, sigma=np.sqrt(epsilon))
assert norm_1.sigma == norm_2.sigma == 0
assert norm_3.sigma == np.sqrt(epsilon)
X = np.random.uniform(-100, 100, size=10000)
assert np.isnan(norm_1(X)).all()
assert_array_almost_equal(norm_2(X), norm_3(X), decimal=10)


def test_lognormal_pdf(random_uniform):
"""
Test whether pdf method of `LognormalDist` returns the same result as pdf method of `scipy.stats.lognorm`.
Expand Down
34 changes: 34 additions & 0 deletions tests/test_gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,37 @@ def test_gnb_invalid_dist():
msg = r"Distribution .* is not supported"
with pytest.raises(ValueError, match=msg):
clf.fit(X, y)


def test_gnb_var_smoothing():
"""
Test whether var_smoothing parameter properly affects the variances of normal distributions.
"""
X = np.array([[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]]) # First feature has variance 2.0
y = np.array([1, 1, 2, 2, 2])

clf1 = GeneralNB(var_smoothing=0.0)
clf1.fit(X, y)

clf2 = GeneralNB(var_smoothing=1.0)
clf2.fit(X, y)

test_point = np.array([[2.5, 0]])
prob1 = clf1.predict_proba(test_point)
prob2 = clf2.predict_proba(test_point)

assert not np.allclose(prob1, prob2)
assert clf1.epsilon_ == 0.0
assert clf2.epsilon_ > clf1.epsilon_


def test_gnb_var_smoothing_non_numeric():
"""
Test that var_smoothing is ignored for non-numeric features.
"""
X = np.array([["a", 1], ["b", 2], ["a", 2], ["b", 1]])
y = np.array([1, 1, 2, 2])

clf = GeneralNB(distributions=[D.CATEGORICAL, D.CATEGORICAL], var_smoothing=1e-6)
clf.fit(X, y)
assert clf.epsilon_ == 0
Loading

0 comments on commit 2770e16

Please sign in to comment.