Skip to content

Commit

Permalink
Fix categorical warnings and datetime dists (#20)
Browse files: browse the repository at this point in the history
* Fix categorical warnings and datetime dists

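* Breaks backward compatibility: the disclosure datetime classes now import metasyn's renamed distributions (`DateTimeUniformDistribution`, `TimeUniformDistribution`, `DateUniformDistribution`) instead of the old `Uniform*Distribution` names, and the CI job pinned to `metasyn==0.6.0` is commented out.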

---------

Co-authored-by: Raoul Schram <[email protected]>
  • Loading branch information
qubixes and qubixes authored Dec 20, 2023
1 parent d73ff96 commit cfdb407
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ jobs:
include:
- metasyn-version: "git+https://github.com/sodascience/metasyn.git@main"
python-version: "3.11"
- metasyn-version: "metasyn==0.6.0"
python-version: "3.11"
# - metasyn-version: "metasyn==0.6.0"
# python-version: "3.11"

steps:
- uses: actions/checkout@v2
Expand Down
1 change: 1 addition & 0 deletions metasyncontrib/disclosure/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ def _fit(cls, values: pl.Series, n_avg: int = 11):
probs = dist.probs[dist.probs >= n_avg/len(values)]
if len(probs) == 0 or probs.max() >= 0.9:
return cls.default_distribution()
probs /= probs.sum()
return cls(labels, probs)
12 changes: 6 additions & 6 deletions metasyncontrib/disclosure/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@

import polars as pl

from metasyn.distribution.datetime import UniformDateTimeDistribution
from metasyn.distribution.datetime import UniformTimeDistribution
from metasyn.distribution.datetime import UniformDateDistribution
from metasyn.distribution.datetime import DateTimeUniformDistribution
from metasyn.distribution.datetime import TimeUniformDistribution
from metasyn.distribution.datetime import DateUniformDistribution
# from metasyncontrib.disclosure.base import BaseDisclosureDistribution
from metasyncontrib.disclosure.utils import micro_aggregate
from metasyncontrib.disclosure.base import metadist_disclosure


@metadist_disclosure()
class DisclosureDateTime(UniformDateTimeDistribution):
class DisclosureDateTime(DateTimeUniformDistribution):
"""Disclosure implementation for the datetime distribution."""

@classmethod
Expand All @@ -24,7 +24,7 @@ def _fit(cls, values: pl.Series, n_avg: int = 11) -> DisclosureDateTime:


@metadist_disclosure()
class DisclosureTime(UniformTimeDistribution):
class DisclosureTime(TimeUniformDistribution):
"""Disclosure implementation for the time distribution."""

@classmethod
Expand All @@ -40,7 +40,7 @@ def _fit(cls, values: pl.Series, n_avg: int = 11):


@metadist_disclosure()
class DisclosureDate(UniformDateDistribution):
class DisclosureDate(DateUniformDistribution):
"""Disclosure implementation for the date distribution."""

@classmethod
Expand Down
14 changes: 7 additions & 7 deletions tests/test_other_dist.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import polars as pl
from metasyn.distribution.datetime import UniformDateDistribution
from metasyn.distribution.datetime import UniformDateTimeDistribution
from metasyn.distribution.datetime import UniformTimeDistribution
from metasyn.distribution.datetime import DateUniformDistribution
from metasyn.distribution.datetime import DateTimeUniformDistribution
from metasyn.distribution.datetime import TimeUniformDistribution
from metasyncontrib.disclosure.datetime import DisclosureDate, DisclosureDateTime, DisclosureTime
from pytest import mark
from metasyn.distribution.categorical import MultinoulliDistribution
Expand All @@ -11,9 +11,9 @@

@mark.parametrize(
"class_norm,class_disc",
[(UniformDateDistribution, DisclosureDate),
(UniformDateTimeDistribution, DisclosureDateTime),
(UniformTimeDistribution, DisclosureTime)]
[(DateUniformDistribution, DisclosureDate),
(DateTimeUniformDistribution, DisclosureDateTime),
(TimeUniformDistribution, DisclosureTime)]
)
def test_datetime(class_norm, class_disc):
dist_norm = class_norm.default_distribution()
Expand All @@ -22,7 +22,7 @@ def test_datetime(class_norm, class_disc):
dist_disc = class_disc.fit(series)
assert dist_norm.start < dist_disc.start
assert dist_norm.end > dist_disc.end
if not isinstance(dist_norm, UniformDateDistribution):
if not isinstance(dist_norm, DateUniformDistribution):
assert dist_norm.precision == dist_disc.precision


Expand Down

0 comments on commit cfdb407

Please sign in to comment.