Skip to content

Commit 8881603

Browse files
committed
fix: respect both multiple_of and minimum/maximum constraints
Previously, `generate_constrained_number()` would potentially generate invalid numbers when `multiple_of` is not None and exactly one of either `minimum` or `maximum` is not None, since it would just return `multiple_of` without respecting the upper or lower bound. This significantly changes the implementation of the code to correctly handle this code. The `generate_constrained_number()` method has been completely removed, being replaced with a `generate_constrained_multiple_of()` function. A major difference between the old function and the new function is that the new one does not accept a `method` argument for generating random numbers. This is because in the new function, we always use `create_random_integer()`, since the problem reduces to generating a random integer multiplier. The high-level algorithm behind `generate_constrained_multiple_of()` is that we need to constrain the random integer generator to generate numbers such that when they are multiplied with `multiple_of`, they still fit within the original bounds constraints. This simplify involves dividing the original bounds by `multiple_of`, with some special handling for negative `multiple_of` numbers as well as carefully chosen rounding behavior. We also need to make some changes to other functions. `get_increment()` needs to take an additional argument for the actual value that the increment is for. This is because floating-point numbers can't use a static increment or else it might get rounded away if the numbers are too large. Python fortunately provides a `math.ulp()` function for computing this for a given float value, so we make use of that function. We still use the original `float_info.epsilon` constant as a lower bound on the increment, though, since in the case that the value is too close to zero, we still need to make sure that the increment doesn't disappear when used against other numbers. Finally, we rename and modify `passes_pydantic_multiple_validator()` to `is_almost_multiple_of()`, modifying its implementation to defer the casting of values to `float()` to minimize rounding errors. This specifically affects Decimal numbers, where casting to float too early causes too much loss of precision. A significant number of changes were made to the tests as well, since the original tests missed the bug being fixed here. Each of the integer, floating-point, and decimal tests has been updated to assert that the result is actually within the minimum and maximum constraints. In addition, we remove some unnecessary sorting of the randomly generated test input values, since this was unnecessarily constraining `multiple_of` to be greater than or less than the minimum and maximum values. This was causing a lot of the scenarios involving negative values to be skipped. Lastly, the floating-point and decimal tests need additional constraints to avoid unrealistic extreme values from hitting precision issues. This was done by adding a number of constraints on the number of significant digits in the input numbers and on the relative magnitudes of the input numbers.
1 parent c4e3d91 commit 8881603

File tree

5 files changed

+172
-130
lines changed

5 files changed

+172
-130
lines changed

polyfactory/value_generators/constrained_numbers.py

+62-51
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from __future__ import annotations
22

33
from decimal import Decimal
4+
from math import ceil, floor, ulp
45
from sys import float_info
5-
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
6+
from typing import TYPE_CHECKING, Protocol, TypeVar, cast
67

78
from polyfactory.exceptions import ParameterException
89
from polyfactory.value_generators.primitives import create_random_decimal, create_random_float, create_random_integer
@@ -99,8 +100,8 @@ def is_multiply_of_multiple_of_in_range(
99100
return False
100101

101102

102-
def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool:
103-
"""Determine whether a given value passes the pydantic multiple_of validation.
103+
def is_almost_multiple_of(value: T, multiple_of: T) -> bool:
104+
"""Determine whether a given ``value`` is a close enough to a multiple of ``multiple_of``.
104105
105106
:param value: A numeric value.
106107
:param multiple_of: Another numeric value.
@@ -110,23 +111,33 @@ def passes_pydantic_multiple_validator(value: T, multiple_of: T) -> bool:
110111
"""
111112
if multiple_of == 0:
112113
return True
113-
mod = float(value) / float(multiple_of) % 1
114-
return almost_equal_floats(mod, 0.0) or almost_equal_floats(mod, 1.0)
114+
mod = value % multiple_of
115+
return almost_equal_floats(float(mod), 0.0) or almost_equal_floats(float(abs(mod)), float(abs(multiple_of)))
115116

116117

117-
def get_increment(t_type: type[T]) -> T:
118+
def get_increment(value: T, t_type: type[T]) -> T:
118119
"""Get a small increment base to add to constrained values, i.e. lt/gt entries.
119120
120-
:param t_type: A value of type T.
121+
:param value: A value of type T.
122+
:param t_type: The type of ``value``.
121123
122124
:returns: An increment T.
123125
"""
124-
values: dict[Any, Any] = {
125-
int: 1,
126-
float: float_info.epsilon,
127-
Decimal: Decimal("0.001"),
128-
}
129-
return cast("T", values[t_type])
126+
# See https://github.com/python/mypy/issues/17045 for why the redundant casts are ignored.
127+
if t_type == int:
128+
return cast("T", 1)
129+
if t_type == float:
130+
# When ``value`` is large in magnitude, we need to choose an increment that is large enough
131+
# to not be rounded away, but when ``value`` small in magnitude, we need to prevent the
132+
# incerement from vanishing. ``float_info.epsilon`` is defined as the smallest delta that
133+
# can be represented between 1.0 and the next largest number, but it's not sufficient for
134+
# larger values. ``ulp(x)`` will return smallest delta that can be added to ``x``.
135+
return cast("T", max(ulp(value), float_info.epsilon)) # type: ignore[redundant-cast]
136+
if t_type == Decimal:
137+
return cast("T", Decimal("0.001")) # type: ignore[redundant-cast]
138+
139+
msg = f"invalid t_type: {t_type}"
140+
raise AssertionError(msg)
130141

131142

132143
def get_value_or_none(
@@ -147,14 +158,14 @@ def get_value_or_none(
147158
if ge is not None:
148159
minimum_value = ge
149160
elif gt is not None:
150-
minimum_value = gt + get_increment(t_type)
161+
minimum_value = gt + get_increment(gt, t_type)
151162
else:
152163
minimum_value = None
153164

154165
if le is not None:
155166
maximum_value = le
156167
elif lt is not None:
157-
maximum_value = lt - get_increment(t_type)
168+
maximum_value = lt - get_increment(lt, t_type)
158169
else:
159170
maximum_value = None
160171
return minimum_value, maximum_value
@@ -210,33 +221,36 @@ def get_constrained_number_range(
210221
return minimum, maximum
211222

212223

213-
def generate_constrained_number(
224+
def generate_constrained_multiple_of(
214225
random: Random,
215226
minimum: T | None,
216227
maximum: T | None,
217-
multiple_of: T | None,
218-
method: "NumberGeneratorProtocol[T]",
228+
multiple_of: T,
219229
) -> T:
220-
"""Generate a constrained number, output depends on the passed in callbacks.
230+
"""Generate a constrained multiple of ``multiple_of``.
221231
222232
:param random: An instance of random.
223233
:param minimum: A minimum value.
224234
:param maximum: A maximum value.
225235
:param multiple_of: A multiple of value.
226-
:param method: A function that generates numbers of type T.
227236
228237
:returns: A value of type T.
229238
"""
230-
if minimum is None or maximum is None:
231-
return multiple_of if multiple_of is not None else method(random=random)
232-
if multiple_of is None:
233-
return method(random=random, minimum=minimum, maximum=maximum)
234-
if multiple_of >= minimum:
235-
return multiple_of
236-
result = minimum
237-
while not passes_pydantic_multiple_validator(result, multiple_of):
238-
result = round(method(random=random, minimum=minimum, maximum=maximum) / multiple_of) * multiple_of
239-
return result
239+
240+
# Regardless of the type of ``multiple_of``, we can generate a valid multiple of it by
241+
# multiplying it with any integer, which we call a multiplier. We will randomly generate the
242+
# multiplier as a random integer, but we need to translate the original bounds, if any, to the
243+
# correct bounds on the multiplier so that the resulting product will meet the original
244+
# constraints.
245+
246+
if multiple_of < 0:
247+
minimum, maximum = maximum, minimum
248+
249+
multiplier_min = ceil(minimum / multiple_of) if minimum is not None else None
250+
multiplier_max = floor(maximum / multiple_of) if maximum is not None else None
251+
multiplier = create_random_integer(random=random, minimum=multiplier_min, maximum=multiplier_max)
252+
253+
return multiplier * multiple_of
240254

241255

242256
def handle_constrained_int(
@@ -269,13 +283,11 @@ def handle_constrained_int(
269283
multiple_of=multiple_of,
270284
random=random,
271285
)
272-
return generate_constrained_number(
273-
random=random,
274-
minimum=minimum,
275-
maximum=maximum,
276-
multiple_of=multiple_of,
277-
method=create_random_integer,
278-
)
286+
287+
if multiple_of is None:
288+
return create_random_integer(random=random, minimum=minimum, maximum=maximum)
289+
290+
return generate_constrained_multiple_of(random=random, minimum=minimum, maximum=maximum, multiple_of=multiple_of)
279291

280292

281293
def handle_constrained_float(
@@ -308,13 +320,10 @@ def handle_constrained_float(
308320
random=random,
309321
)
310322

311-
return generate_constrained_number(
312-
random=random,
313-
minimum=minimum,
314-
maximum=maximum,
315-
multiple_of=multiple_of,
316-
method=create_random_float,
317-
)
323+
if multiple_of is None:
324+
return create_random_float(random=random, minimum=minimum, maximum=maximum)
325+
326+
return generate_constrained_multiple_of(random=random, minimum=minimum, maximum=maximum, multiple_of=multiple_of)
318327

319328

320329
def validate_max_digits(
@@ -422,13 +431,15 @@ def handle_constrained_decimal(
422431
if max_digits is not None:
423432
validate_max_digits(max_digits=max_digits, minimum=minimum, decimal_places=decimal_places)
424433

425-
generated_decimal = generate_constrained_number(
426-
random=random,
427-
minimum=minimum,
428-
maximum=maximum,
429-
multiple_of=multiple_of,
430-
method=create_random_decimal,
431-
)
434+
if multiple_of is None:
435+
generated_decimal = create_random_decimal(random=random, minimum=minimum, maximum=maximum)
436+
else:
437+
generated_decimal = generate_constrained_multiple_of(
438+
random=random,
439+
minimum=minimum,
440+
maximum=maximum,
441+
multiple_of=multiple_of,
442+
)
432443

433444
if max_digits is not None or decimal_places is not None:
434445
return handle_decimal_length(

tests/constraints/test_decimal_constraints.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Optional, cast
44

55
import pytest
6-
from hypothesis import given
6+
from hypothesis import assume, given
77
from hypothesis.strategies import decimals, integers
88

99
from pydantic import BaseModel, condecimal
@@ -13,11 +13,24 @@
1313
from polyfactory.value_generators.constrained_numbers import (
1414
handle_constrained_decimal,
1515
handle_decimal_length,
16+
is_almost_multiple_of,
1617
is_multiply_of_multiple_of_in_range,
17-
passes_pydantic_multiple_validator,
1818
)
1919

2020

21+
def assume_max_digits(x: Decimal, max_digits: int) -> None:
22+
"""
23+
Signal to Hypothesis that ``x`` should have at most ``max_digits`` significant digits.
24+
25+
This is different than the ``decimals()`` strategy function's ``places`` keyword argument, which
26+
only counts the digits after the decimal point when the number is written without an exponent.
27+
28+
E.g. 12.51 has 4 significant digits but 2 decimal places.
29+
"""
30+
31+
assume(len(x.as_tuple().digits) <= max_digits)
32+
33+
2134
def test_handle_constrained_decimal_without_constraints() -> None:
2235
result = handle_constrained_decimal(
2336
random=Random(),
@@ -162,7 +175,7 @@ def test_handle_constrained_decimal_handles_multiple_of(multiple_of: Decimal) ->
162175
random=Random(),
163176
multiple_of=multiple_of,
164177
)
165-
assert passes_pydantic_multiple_validator(result, multiple_of)
178+
assert is_almost_multiple_of(result, multiple_of)
166179
else:
167180
with pytest.raises(ParameterException):
168181
handle_constrained_decimal(
@@ -185,15 +198,17 @@ def test_handle_constrained_decimal_handles_multiple_of(multiple_of: Decimal) ->
185198
max_value=1000000000,
186199
),
187200
)
188-
def test_handle_constrained_decimal_handles_multiple_of_with_lt(val1: Decimal, val2: Decimal) -> None:
189-
multiple_of, max_value = sorted([val1, val2])
201+
def test_handle_constrained_decimal_handles_multiple_of_with_lt(max_value: Decimal, multiple_of: Decimal) -> None:
190202
if multiple_of != Decimal("0"):
203+
assume_max_digits(max_value, 10)
204+
assume_max_digits(multiple_of, 10)
191205
result = handle_constrained_decimal(
192206
random=Random(),
193207
multiple_of=multiple_of,
194208
lt=max_value,
195209
)
196-
assert passes_pydantic_multiple_validator(result, multiple_of)
210+
assert result < max_value
211+
assert is_almost_multiple_of(result, multiple_of)
197212
else:
198213
with pytest.raises(ParameterException):
199214
handle_constrained_decimal(
@@ -217,15 +232,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_lt(val1: Decimal, v
217232
max_value=1000000000,
218233
),
219234
)
220-
def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, val2: Decimal) -> None:
221-
multiple_of, max_value = sorted([val1, val2])
235+
def test_handle_constrained_decimal_handles_multiple_of_with_le(max_value: Decimal, multiple_of: Decimal) -> None:
222236
if multiple_of != Decimal("0"):
237+
assume_max_digits(max_value, 10)
238+
assume_max_digits(multiple_of, 10)
223239
result = handle_constrained_decimal(
224240
random=Random(),
225241
multiple_of=multiple_of,
226242
le=max_value,
227243
)
228-
assert passes_pydantic_multiple_validator(result, multiple_of)
244+
assert result <= max_value
245+
assert is_almost_multiple_of(result, multiple_of)
229246
else:
230247
with pytest.raises(ParameterException):
231248
handle_constrained_decimal(
@@ -249,15 +266,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_le(val1: Decimal, v
249266
max_value=1000000000,
250267
),
251268
)
252-
def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, val2: Decimal) -> None:
253-
min_value, multiple_of = sorted([val1, val2])
269+
def test_handle_constrained_decimal_handles_multiple_of_with_ge(min_value: Decimal, multiple_of: Decimal) -> None:
254270
if multiple_of != Decimal("0"):
271+
assume_max_digits(min_value, 10)
272+
assume_max_digits(multiple_of, 10)
255273
result = handle_constrained_decimal(
256274
random=Random(),
257275
multiple_of=multiple_of,
258276
ge=min_value,
259277
)
260-
assert passes_pydantic_multiple_validator(result, multiple_of)
278+
assert min_value <= result
279+
assert is_almost_multiple_of(result, multiple_of)
261280
else:
262281
with pytest.raises(ParameterException):
263282
handle_constrained_decimal(
@@ -281,15 +300,17 @@ def test_handle_constrained_decimal_handles_multiple_of_with_ge(val1: Decimal, v
281300
max_value=1000000000,
282301
),
283302
)
284-
def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, val2: Decimal) -> None:
285-
min_value, multiple_of = sorted([val1, val2])
303+
def test_handle_constrained_decimal_handles_multiple_of_with_gt(min_value: Decimal, multiple_of: Decimal) -> None:
286304
if multiple_of != Decimal("0"):
305+
assume_max_digits(min_value, 10)
306+
assume_max_digits(multiple_of, 10)
287307
result = handle_constrained_decimal(
288308
random=Random(),
289309
multiple_of=multiple_of,
290310
gt=min_value,
291311
)
292-
assert passes_pydantic_multiple_validator(result, multiple_of)
312+
assert min_value < result
313+
assert is_almost_multiple_of(result, multiple_of)
293314
else:
294315
with pytest.raises(ParameterException):
295316
handle_constrained_decimal(
@@ -322,21 +343,25 @@ def test_handle_constrained_decimal_handles_multiple_of_with_gt(val1: Decimal, v
322343
def test_handle_constrained_decimal_handles_multiple_of_with_ge_and_le(
323344
val1: Decimal,
324345
val2: Decimal,
325-
val3: Decimal,
346+
multiple_of: Decimal,
326347
) -> None:
327-
min_value, multiple_of, max_value = sorted([val1, val2, val3])
348+
min_value, max_value = sorted([val1, val2])
328349
if multiple_of != Decimal("0") and is_multiply_of_multiple_of_in_range(
329350
minimum=min_value,
330351
maximum=max_value,
331352
multiple_of=multiple_of,
332353
):
354+
assume_max_digits(min_value, 10)
355+
assume_max_digits(max_value, 10)
356+
assume_max_digits(multiple_of, 10)
333357
result = handle_constrained_decimal(
334358
random=Random(),
335359
multiple_of=multiple_of,
336360
ge=min_value,
337361
le=max_value,
338362
)
339-
assert passes_pydantic_multiple_validator(result, multiple_of)
363+
assert min_value <= result <= max_value
364+
assert is_almost_multiple_of(result, multiple_of)
340365
else:
341366
with pytest.raises(ParameterException):
342367
handle_constrained_decimal(

0 commit comments

Comments
 (0)