Skip to content

Commit

Permalink
Add constant distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
vankesteren committed Feb 23, 2024
1 parent 26152da commit 1b664eb
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 8 deletions.
64 changes: 64 additions & 0 deletions metasyncontrib/disclosure/constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"Module for constant distributions"
import polars as pl
from metasyncontrib.disclosure.base import metadist_disclosure
from metasyn.distribution.constant import (
ConstantDistribution,
DiscreteConstantDistribution,
StringConstantDistribution,
DateTimeConstantDistribution,
TimeConstantDistribution,
DateConstantDistribution,
)


def disclosure_constant(cls):
    """Decorator that overrides the ``_fit`` method for constant distributions.

    The replacement ``_fit`` only returns a distribution built from an
    observed value when that value occurs at least ``n_avg`` times in the
    data. In every other case it returns ``cls.default_distribution()``.

    Parameters
    ----------
    cls:
        Constant distribution class to patch. The replacement ``_fit`` calls
        ``cls(value)`` and ``cls.default_distribution()``.

    Returns
    -------
    The same class object, with its ``_fit`` attribute replaced.
    """

    def _fit(values: "pl.Series", n_avg=11):
        """Fit on ``values``; use a value only if it occurs >= ``n_avg`` times."""
        n_values = values.len()

        # Nothing to count: row 0 of the value counts below would not exist.
        if n_values == 0:
            return cls.default_distribution()

        # if unique, just get that value if it occurs at least n_avg times.
        # This has to be a boolean `and`: `&` binds tighter than the
        # comparisons, which turned the test into the chained comparison
        # `n_unique == (1 & len) >= n_avg` that cannot hold for n_avg > 1.
        if values.n_unique() == 1 and n_values >= n_avg:
            return cls(values.unique()[0])

        # otherwise get most common value; with sort=True the most frequent
        # entry is expected in row 0 (column 0: value, column 1: its count)
        val_counts = values.value_counts(sort=True)
        value = val_counts[0, 0]
        count = val_counts[0, 1]

        if count >= n_avg:
            return cls(value)

        # no single value is frequent enough to be used
        return cls.default_distribution()

    setattr(cls, "_fit", _fit)
    return cls


@metadist_disclosure()
@disclosure_constant
class DisclosureConstant(ConstantDistribution):
    """Disclosure controlled ConstantDistribution.

    The ``disclosure_constant`` decorator replaces ``_fit`` so that an
    observed value is only used when it occurs at least ``n_avg`` times
    (default 11) in the data; otherwise the default distribution is returned.
    """

@metadist_disclosure()
@disclosure_constant
class DisclosureDiscreteConstant(DiscreteConstantDistribution):
    """Disclosure controlled DiscreteConstantDistribution.

    The ``disclosure_constant`` decorator replaces ``_fit`` so that an
    observed value is only used when it occurs at least ``n_avg`` times
    (default 11) in the data; otherwise the default distribution is returned.
    """

@metadist_disclosure()
@disclosure_constant
class DisclosureStringConstant(StringConstantDistribution):
    """Disclosure controlled StringConstantDistribution.

    The ``disclosure_constant`` decorator replaces ``_fit`` so that an
    observed value is only used when it occurs at least ``n_avg`` times
    (default 11) in the data; otherwise the default distribution is returned.
    """

@metadist_disclosure()
@disclosure_constant
class DisclosureDateTimeConstant(DateTimeConstantDistribution):
    """Disclosure controlled DateTimeConstantDistribution.

    The ``disclosure_constant`` decorator replaces ``_fit`` so that an
    observed value is only used when it occurs at least ``n_avg`` times
    (default 11) in the data; otherwise the default distribution is returned.
    """

@metadist_disclosure()
@disclosure_constant
class DisclosureTimeConstant(TimeConstantDistribution):
    """Disclosure controlled TimeConstantDistribution.

    The ``disclosure_constant`` decorator replaces ``_fit`` so that an
    observed value is only used when it occurs at least ``n_avg`` times
    (default 11) in the data; otherwise the default distribution is returned.
    """

@metadist_disclosure()
@disclosure_constant
class DisclosureDateConstant(DateConstantDistribution):
    """Disclosure controlled DateConstantDistribution.

    The ``disclosure_constant`` decorator replaces ``_fit`` so that an
    observed value is only used when it occurs at least ``n_avg`` times
    (default 11) in the data; otherwise the default distribution is returned.
    """

38 changes: 32 additions & 6 deletions metasyncontrib/disclosure/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,31 @@
from __future__ import annotations
from metasyn.provider import BaseDistributionProvider

from metasyncontrib.disclosure.continuous import DisclosureUniform, DisclosureTruncatedNormal
from metasyncontrib.disclosure.continuous import (
DisclosureUniform,
DisclosureTruncatedNormal,
)
from metasyncontrib.disclosure.continuous import DisclosureNormal
from metasyncontrib.disclosure.continuous import DisclosureLogNormal
from metasyncontrib.disclosure.continuous import DisclosureExponential
from metasyncontrib.disclosure.discrete import DisclosureDiscreteUniform, DisclosureUniqueKey
from metasyncontrib.disclosure.discrete import (
DisclosureDiscreteUniform,
DisclosureUniqueKey,
)
from metasyncontrib.disclosure.discrete import DisclosurePoisson
from metasyncontrib.disclosure.string import DisclosureFaker, DisclosureUniqueFaker
from metasyncontrib.disclosure.categorical import DisclosureMultinoulli
from metasyncontrib.disclosure.datetime import DisclosureDate
from metasyncontrib.disclosure.datetime import DisclosureDateTime
from metasyncontrib.disclosure.datetime import DisclosureTime
from metasyncontrib.disclosure.constant import (
DisclosureConstant,
DisclosureDiscreteConstant,
DisclosureStringConstant,
DisclosureDateTimeConstant,
DisclosureTimeConstant,
DisclosureDateConstant,
)


class DisclosureProvider(BaseDistributionProvider):
Expand All @@ -26,12 +40,24 @@ class DisclosureProvider(BaseDistributionProvider):
name = "metasyn-disclosure"
version = "1.0"
distributions = [
DisclosureUniform, DisclosureTruncatedNormal, DisclosureNormal,
DisclosureLogNormal, DisclosureExponential,
DisclosureDiscreteUniform, DisclosureUniqueKey, DisclosurePoisson,
DisclosureUniform,
DisclosureTruncatedNormal,
DisclosureNormal,
DisclosureLogNormal,
DisclosureExponential,
DisclosureDiscreteUniform,
DisclosureUniqueKey,
DisclosurePoisson,
DisclosureMultinoulli,
DisclosureFaker, DisclosureUniqueFaker,
DisclosureFaker,
DisclosureUniqueFaker,
DisclosureDate,
DisclosureTime,
DisclosureDateTime,
DisclosureConstant,
DisclosureDiscreteConstant,
DisclosureStringConstant,
DisclosureDateTimeConstant,
DisclosureTimeConstant,
DisclosureDateConstant,
]
37 changes: 37 additions & 0 deletions tests/test_constant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pytest import mark

from metasyn.distribution.constant import (
ConstantDistribution,
DiscreteConstantDistribution,
StringConstantDistribution,
DateTimeConstantDistribution,
TimeConstantDistribution,
DateConstantDistribution,
)

from metasyncontrib.disclosure.constant import (
DisclosureConstant,
DisclosureDiscreteConstant,
DisclosureStringConstant,
DisclosureDateTimeConstant,
DisclosureTimeConstant,
DisclosureDateConstant,
)

@mark.parametrize(
    "dist_builtin, dist_disclosure, value",
    [
        (ConstantDistribution, DisclosureConstant, 8.0),
        (DiscreteConstantDistribution, DisclosureDiscreteConstant, 8),
        (StringConstantDistribution, DisclosureStringConstant, "Secretvalue"),
        (DateTimeConstantDistribution, DisclosureDateTimeConstant, "2024-02-23T12:08:38+00:00"),
        (TimeConstantDistribution, DisclosureTimeConstant, "12:08:38"),
        (DateConstantDistribution, DisclosureDateConstant, "2024-02-23"),
    ],
)
def test_constant(dist_builtin, dist_disclosure, value):
    """Check that a constant is only recovered when it occurs often enough.

    21 draws are taken from the built-in constant distribution. Fitting the
    disclosure variant with ``n_avg=22`` must not return the constant, while
    fitting with ``n_avg=11`` must return it.
    """
    source = dist_builtin(value)
    draws = [source.draw() for _ in range(21)]

    # Threshold above the number of draws: the value must not come back.
    strict_fit = dist_disclosure.fit(draws, n_avg=22)
    assert strict_fit._param_dict().get("value") != value

    # Threshold below the number of draws: the value is expected back.
    lenient_fit = dist_disclosure.fit(draws, n_avg=11)
    assert lenient_fit._param_dict().get("value") == value
4 changes: 2 additions & 2 deletions tests/test_other_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ def test_datetime(class_norm, class_disc):
series = pl.Series([dist_norm.draw() for _ in range(100)])
dist_norm = class_norm.fit(series)
dist_disc = class_disc.fit(series)
assert dist_norm.start < dist_disc.start
assert dist_norm.end > dist_disc.end
assert dist_norm.lower < dist_disc.lower
assert dist_norm.upper > dist_disc.upper
if not isinstance(dist_norm, DateUniformDistribution):
assert dist_norm.precision == dist_disc.precision

Expand Down

0 comments on commit 1b664eb

Please sign in to comment.