Skip to content

Commit

Permalink
Handle a corner case.
Browse files Browse the repository at this point in the history
  • Loading branch information
Paul Prescod committed Dec 24, 2022
1 parent 68fbd48 commit a4b5bad
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 33 deletions.
66 changes: 38 additions & 28 deletions snowfakery/standard_plugins/_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ def __init__(self, repeat, iterable):


def parts(
total: int,
min_: int = 1,
max_: Optional[int] = None,
requested_step: float = 1,
user_total: int,
user_min: int = 1,
user_max: Optional[int] = None,
user_step: float = 1,
rand: Optional[Random] = None,
) -> List[Union[int, float]]:
"""Split a number into a randomized set of 'pieces'.
Expand All @@ -64,22 +64,24 @@ def parts(
will be inconsistent with them. e.g. if `total` is not a multiple
of `step`.
"""
max_ = max_ or total
max_ = user_max or user_total
rand = rand or Random()

if requested_step < 1:
if user_step < 1:
allowed_steps = [0.01, 0.5, 0.1, 0.20, 0.25, 0.50]
assert (
requested_step in allowed_steps
), f"`step` must be one of {', '.join(str(f) for f in allowed_steps)}, not {requested_step}"
user_step in allowed_steps
), f"`step` must be one of {', '.join(str(f) for f in allowed_steps)}, not {user_step}"
# multiply up into the integer range so we don't need to do float math
total = int(total / requested_step)
total = int(user_total / user_step)
step = 1
min_ = int(min_ / requested_step)
max_ = int(max_ / requested_step)
min_ = int(user_min / user_step)
max_ = int(max_ / user_step)
else:
step = int(requested_step)
assert step == requested_step, f"`step` should be an integer, not {step}"
step = int(user_step)
min_ = user_min
total = user_total
assert step == user_step, f"`step` should be an integer, not {step}"

pieces = []

Expand All @@ -88,47 +90,55 @@ def parts(
smallest = max(min_, step)
if remaining < smallest:
# mutates pieces
handle_last_bit(pieces, rand, remaining, min_, max_)
success = handle_last_bit(pieces, rand, remaining, min_, max_)
# our constraints must have been impossible to fulfill
assert (
success
), f"No way to match all constraints: total: {user_total}, min: {user_min}, max: {user_max}, step: {user_step}"

else:
pieces.append(generate_piece(pieces, rand, smallest, remaining, max_, step))
pieces.append(generate_piece(rand, smallest, remaining, max_, step))

assert sum(pieces) == total, pieces
assert 0 not in pieces, pieces

if requested_step != step:
pieces = [round(p * requested_step, 2) for p in pieces]
if user_step != step:
pieces = [round(p * user_step, 2) for p in pieces]
return pieces


def handle_last_bit(
pieces: List[int], rand: Random, remaining: int, min_: int, max_: int
):
) -> bool:
"""If the piece is big enough, add it.
Otherwise, try to add it to another piece."""

if remaining > min_:
pos = rand.randint(0, len(pieces))
pieces.insert(pos, remaining)
return
return True

# try to add it to some other piece
for i, val in enumerate(pieces):
if val + remaining <= max_:
pieces[i] += remaining
remaining = 0
return
return True

# just insert it despite it being too small...our
# constraints must have been impossible to fulfill
if remaining:
pos = rand.randint(0, len(pieces))
pieces.insert(pos, remaining)
# No other piece has enough room...so
# split it up among several other pieces
for i, val in enumerate(pieces):
chunk = min(max_ - pieces[i], remaining)
remaining -= chunk
pieces[i] = max_
assert remaining >= 0
if remaining == 0:
return True

return False


def generate_piece(
pieces: List[int], rand: Random, smallest: int, remaining: int, max_: int, step: int
):
def generate_piece(rand: Random, smallest: int, remaining: int, max_: int, step: int):
part = rand.randint(smallest, min(remaining, max_))
round_up = part + step - (part % step)
if round_up <= min(remaining, max_) and rand.randint(0, 1):
Expand Down
32 changes: 27 additions & 5 deletions tests/tests_math_partition.py → tests/test_math_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from snowfakery.data_gen_exceptions import DataGenError

REPS = 1
SEEDS = [randint(0, 2 ** 32) for r in range(REPS)]


@pytest.mark.parametrize("seed", [randint(0, 2 ** 32) for r in range(REPS)])
class TestSummation:
@pytest.mark.parametrize("seed", SEEDS)
class TestMathPartition:
def test_example(self, generated_rows, seed):
generate_data(
"examples/math_partition/math_partition_simple.recipe.yml", seed=seed
Expand All @@ -19,8 +20,11 @@ def test_example(self, generated_rows, seed):
c["Amount__c"] for c in children
), (parents, children)

def test_example_pennies(self, generated_rows, seed):
generate_data("examples/math_partition/sum_pennies.recipe.yml", seed=seed)
regression_seeds = [824956277]

@pytest.mark.parametrize("seed2", regression_seeds + SEEDS)
def test_example_pennies(self, generated_rows, seed, seed2):
generate_data("examples/math_partition/sum_pennies.recipe.yml", seed=seed2)
objs = generated_rows.table_values("Values")
assert round(sum(p["Amount"] for p in objs)) == 100, sum(
p["Amount"] for p in objs
Expand All @@ -31,7 +35,7 @@ def test_example_pennies_param(self, generated_rows, seed, step: int):
generate_data(
"examples/math_partition/sum_pennies_param.recipe.yml",
user_options={"step": step},
seed=1,
seed=seed,
)
objs = generated_rows.table_values("Values")
assert round(sum(p["Amount"] for p in objs)) == 100, sum(
Expand Down Expand Up @@ -151,3 +155,21 @@ def test_bad_step(self, generated_rows, seed):
"""
with pytest.raises(DataGenError, match="step.*0.3"):
generate_data(StringIO(yaml), seed=seed)

def test_inconsistent_constraints(self, generated_rows, seed):
yaml = """
- plugin: snowfakery.standard_plugins.Math
- object: Obj
for_each:
var: child_value
value:
Math.random_partition:
total: 10
min: 8
max: 8
step: 5
fields:
Amount: ${{child_value}}
"""
with pytest.raises(DataGenError, match="constraints"):
generate_data(StringIO(yaml), seed=seed)

0 comments on commit a4b5bad

Please sign in to comment.