Skip to content

Commit

Permalink
add unit tests for epsilon in normal dist and var_smoothing in generalnb
Browse files Browse the repository at this point in the history
  • Loading branch information
msamsami committed Jan 17, 2025
1 parent e79ee0f commit 61d33d0
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/test_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,21 @@ def test_normal_pdf():
assert_array_almost_equal(norm_wnb(X), norm_scipy.pdf(X), decimal=10)


@pytest.mark.parametrize("epsilon", [1e-10, 1e-9, 1e-6, 1e-3])
def test_normal_with_epsilon(epsilon: float):
    """
    Test whether epsilon is correctly applied for `NormalDist`.

    Three distributions centred at 1 are compared:

    * zero sigma without epsilon -- expected to evaluate to NaN everywhere;
    * zero sigma with `epsilon` -- expected to keep `sigma == 0` but evaluate like the third one;
    * sigma equal to `sqrt(epsilon)` and no epsilon.
    """
    norm_1 = NormalDist(mu=1, sigma=0)
    norm_2 = NormalDist(mu=1, sigma=0, epsilon=epsilon)
    norm_3 = NormalDist(mu=1, sigma=np.sqrt(epsilon))

    # epsilon must not overwrite the stored sigma; it is only expected to act at evaluation time.
    assert norm_1.sigma == norm_2.sigma == 0
    assert norm_3.sigma == np.sqrt(epsilon)

    # Use a seeded generator so that a failing run can be reproduced with the exact same sample.
    rng = np.random.default_rng(0)
    X = rng.uniform(-100, 100, size=10000)

    assert np.isnan(norm_1(X)).all()
    assert_array_almost_equal(norm_2(X), norm_3(X), decimal=10)


def test_lognormal_pdf(random_uniform):
"""
Test whether pdf method of `LognormalDist` returns the same result as pdf method of `scipy.stats.lognorm`.
Expand Down
34 changes: 34 additions & 0 deletions tests/test_gnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,37 @@ def test_gnb_invalid_dist():
msg = r"Distribution .* is not supported"
with pytest.raises(ValueError, match=msg):
clf.fit(X, y)


def test_gnb_var_smoothing():
    """
    Check that `var_smoothing` influences the fitted `GeneralNB` model.

    Two classifiers are fitted on the same data, one with `var_smoothing=0.0` and one with
    `var_smoothing=1.0`; their predicted probabilities must differ, and the fitted `epsilon_`
    must be zero for the former and strictly larger for the latter.
    """
    # First column has (population) variance 2.0; second column is constant.
    features = np.array([[1, 0], [2, 0], [3, 0], [4, 0], [5, 0]])
    labels = np.array([1, 1, 2, 2, 2])
    query = np.array([[2.5, 0]])

    unsmoothed = GeneralNB(var_smoothing=0.0)
    unsmoothed.fit(features, labels)

    smoothed = GeneralNB(var_smoothing=1.0)
    smoothed.fit(features, labels)

    proba_unsmoothed = unsmoothed.predict_proba(query)
    proba_smoothed = smoothed.predict_proba(query)

    # NOTE(review): np.allclose returns False whenever either array contains NaN, so this
    # assertion would also pass if the unsmoothed model yields NaN for the constant second
    # feature -- confirm that this is the intended check.
    assert not np.allclose(proba_unsmoothed, proba_smoothed)
    assert unsmoothed.epsilon_ == 0.0
    assert smoothed.epsilon_ > unsmoothed.epsilon_


def test_gnb_var_smoothing_non_numeric():
    """
    Verify that `epsilon_` stays at zero when every feature is declared categorical,
    even though a non-zero `var_smoothing` is supplied.
    """
    samples = np.array([["a", 1], ["b", 2], ["a", 2], ["b", 1]])
    targets = np.array([1, 1, 2, 2])

    model = GeneralNB(
        distributions=[D.CATEGORICAL, D.CATEGORICAL],
        var_smoothing=1e-6,
    )
    model.fit(samples, targets)

    assert model.epsilon_ == 0

0 comments on commit 61d33d0

Please sign in to comment.