Skip to content

Commit bde2512

Browse files
authored
🐛 fix deserialization of pendulum durations (#296)
1 parent 990912c commit bde2512

File tree

2 files changed

+104
-4
lines changed

2 files changed

+104
-4
lines changed

pydantic_extra_types/pendulum_dt.py

+79-4
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,65 @@ def __get_pydantic_core_schema__(cls, source: type[Any], handler: GetCoreSchemaH
186186
Returns:
187187
A Pydantic CoreSchema with the Duration validation.
188188
"""
189-
return core_schema.no_info_wrap_validator_function(cls._validate, core_schema.timedelta_schema())
189+
return core_schema.no_info_wrap_validator_function(
190+
cls._validate,
191+
core_schema.timedelta_schema(),
192+
serialization=core_schema.plain_serializer_function_ser_schema(
193+
lambda instance: instance.to_iso8601_string()
194+
),
195+
)
196+
197+
def to_iso8601_string(self) -> str:
198+
"""
199+
Convert a Duration object to an ISO 8601 string.
200+
201+
In addition to the standard ISO 8601 format, this method also supports the representation of fractions of a second and negative durations.
202+
203+
Args:
204+
duration (Duration): The Duration object.
205+
206+
Returns:
207+
str: The ISO 8601 string representation of the duration.
208+
"""
209+
# Extracting components from the Duration object
210+
years = self.years
211+
months = self.months
212+
days = self._days
213+
hours = self.hours
214+
minutes = self.minutes
215+
seconds = self.remaining_seconds
216+
milliseconds = self.microseconds // 1000
217+
microseconds = self.microseconds % 1000
218+
219+
# Constructing the ISO 8601 duration string
220+
iso_duration = 'P'
221+
if years or months or days:
222+
if years:
223+
iso_duration += f'{years}Y'
224+
if months:
225+
iso_duration += f'{months}M'
226+
if days:
227+
iso_duration += f'{days}D'
228+
229+
if hours or minutes or seconds or milliseconds or microseconds:
230+
iso_duration += 'T'
231+
if hours:
232+
iso_duration += f'{hours}H'
233+
if minutes:
234+
iso_duration += f'{minutes}M'
235+
if seconds or milliseconds or microseconds:
236+
iso_duration += f'{seconds}'
237+
if milliseconds or microseconds:
238+
iso_duration += f'.{milliseconds:03d}'
239+
if microseconds:
240+
iso_duration += f'{microseconds:03d}'
241+
iso_duration += 'S'
242+
243+
# Prefix with '-' if the duration is negative
244+
if self.total_seconds() < 0:
245+
iso_duration = '-' + iso_duration
246+
247+
return iso_duration
190248

191249
@classmethod
192250
def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler) -> Duration:
@@ -219,10 +277,27 @@ def _validate(cls, value: Any, handler: core_schema.ValidatorFunctionWrapHandler
219277
microseconds=value.microseconds,
220278
)
221279

280+
assert isinstance(value, str)
222281
try:
223-
parsed = parse(value, exact=True)
224-
if not isinstance(parsed, timedelta):
282+
# https://github.com/python-pendulum/pendulum/issues/532
283+
if value.startswith('-'):
284+
parsed = parse(value.lstrip('-'), exact=True)
285+
else:
286+
parsed = parse(value, exact=True)
287+
if not isinstance(parsed, _Duration):
225288
raise ValueError(f'value is not a valid duration it is a {type(parsed)}')
226-
return Duration(seconds=parsed.total_seconds())
289+
if value.startswith('-'):
290+
parsed = -parsed
291+
292+
return Duration(
293+
years=parsed.years,
294+
months=parsed.months,
295+
weeks=parsed.weeks,
296+
days=parsed.remaining_days,
297+
hours=parsed.hours,
298+
minutes=parsed.minutes,
299+
seconds=parsed.remaining_seconds,
300+
microseconds=parsed.microseconds,
301+
)
227302
except Exception as exc:
228303
raise PydanticCustomError('value_error', 'value is not a valid duration') from exc

tests/test_pendulum_dt.py

+25
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,31 @@ def test_pendulum_duration_from_serialized(delta_t_str):
283283
assert isinstance(model.delta_t, pendulum.Duration)
284284

285285

286+
@pytest.mark.parametrize(
287+
'duration',
288+
[
289+
(Duration(months=1)),
290+
(Duration(weeks=1)),
291+
(Duration(milliseconds=1)),
292+
(Duration(microseconds=1)),
293+
(Duration(days=1)),
294+
(Duration(hours=1)),
295+
(Duration(minutes=1)),
296+
(Duration(seconds=1)),
297+
(Duration(months=2, days=5)),
298+
(Duration(weeks=3, hours=12)),
299+
(Duration(days=10, minutes=30)),
300+
(Duration(weeks=1, days=2, hours=3)),
301+
(Duration(seconds=30, milliseconds=500)),
302+
],
303+
)
304+
def test_pendulum_duration_serialization_roundtrip(duration):
305+
adapter = TypeAdapter(Duration)
306+
serialized = adapter.dump_python(duration)
307+
deserialized = TypeAdapter.validate_python(adapter, serialized)
308+
assert deserialized == duration
309+
310+
286311
def get_invalid_dt_common():
287312
return [
288313
None,

0 commit comments

Comments
 (0)