Skip to content

Commit

Permalink
REF: Localizer class to de-duplicate tzconversion code (pandas-dev#46397)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Apr 18, 2022
1 parent 6d16567 commit 9797c89
Showing 1 changed file with 109 additions and 139 deletions.
248 changes: 109 additions & 139 deletions pandas/_libs/tslibs/vectorized.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ from .np_datetime cimport (
)
from .offsets cimport BaseOffset
from .period cimport get_period_ordinal
from .timestamps cimport (
create_timestamp_from_ts,
normalize_i8_stamp,
)
from .timestamps cimport create_timestamp_from_ts
from .timezones cimport (
get_dst_info,
is_tzlocal,
Expand All @@ -47,6 +44,54 @@ from .tzconversion cimport (
localize_tzinfo_api,
)


# Empty int64 memoryview shared by all Localizer instances: Localizer.__cinit__
# assigns it to `deltas` before branching on the timezone kind, so the
# attribute holds a memoryview even on the paths that never call get_dst_info.
cdef const int64_t[::1] _deltas_placeholder = np.array([], dtype=np.int64)


@cython.freelist(16)
@cython.internal
@cython.final
cdef class Localizer:
    """
    Classify a tzinfo once and hold what is needed to shift UTC i8 values.

    After construction exactly one of ``use_utc``, ``use_tzlocal``,
    ``use_fixed`` or ``use_dst`` is True.  ``use_pytz`` can only be True
    together with ``use_dst``.
    """
    cdef:
        # The timezone this instance was built for (may be None).
        tzinfo tz
        # Which conversion strategy applies to `tz`.
        bint use_utc, use_fixed, use_tzlocal, use_dst, use_pytz
        # Transition array from get_dst_info; only assigned when neither
        # use_utc nor use_tzlocal applies.
        ndarray trans
        # trans.shape[0], or -1 when get_dst_info was never called.
        Py_ssize_t ntrans
        # Deltas from get_dst_info, or the empty module-level placeholder.
        const int64_t[::1] deltas
        # deltas[0] in the use_fixed case, -1 otherwise.
        int64_t delta

    @cython.initializedcheck(False)
    @cython.boundscheck(False)
    def __cinit__(self, tzinfo tz):
        self.tz = tz

        # Begin with every strategy flag off and placeholder values in the
        # numeric/memoryview fields; the branches below overwrite only the
        # ones relevant to this timezone.
        self.use_utc = False
        self.use_tzlocal = False
        self.use_fixed = False
        self.use_dst = False
        self.use_pytz = False
        self.ntrans = -1
        self.delta = -1
        self.deltas = _deltas_placeholder

        # UTC (or no timezone at all).
        if is_utc(tz) or tz is None:
            self.use_utc = True
            return

        # tzlocal and zoneinfo timezones share the use_tzlocal strategy.
        if is_tzlocal(tz) or is_zoneinfo(tz):
            self.use_tzlocal = True
            return

        # Everything else is described by the data from get_dst_info.
        transitions, offsets, kind = get_dst_info(tz)
        self.trans = transitions
        self.ntrans = transitions.shape[0]
        self.deltas = offsets

        if kind == "pytz" or kind == "dateutil":
            self.use_dst = True
            self.use_pytz = kind == "pytz"
        else:
            # static/fixed; in this case we know that len(delta) == 1
            self.use_fixed = True
            self.delta = offsets[0]


# -------------------------------------------------------------------------


Expand Down Expand Up @@ -87,19 +132,14 @@ def ints_to_pydatetime(
ndarray[object] of type specified by box
"""
cdef:
Py_ssize_t i, ntrans = -1, n = stamps.shape[0]
ndarray[int64_t] trans
int64_t[::1] deltas
Localizer info = Localizer(tz)
int64_t utc_val, local_val
Py_ssize_t pos, i, n = stamps.shape[0]
int64_t* tdata = NULL
intp_t pos
int64_t utc_val, local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False
str typ

npy_datetimestruct dts
tzinfo new_tz
ndarray[object] result = np.empty(n, dtype=object)
bint use_pytz = False
bint use_date = False, use_time = False, use_ts = False, use_pydt = False

if box == "date":
Expand All @@ -116,20 +156,8 @@ def ints_to_pydatetime(
"box must be one of 'datetime', 'date', 'time' or 'timestamp'"
)

if is_utc(tz) or tz is None:
use_utc = True
elif is_tzlocal(tz) or is_zoneinfo(tz):
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(delta) == 1
use_fixed = True
delta = deltas[0]
else:
tdata = <int64_t*>cnp.PyArray_DATA(trans)
use_pytz = typ == "pytz"
if info.use_dst:
tdata = <int64_t*>cnp.PyArray_DATA(info.trans)

for i in range(n):
utc_val = stamps[i]
Expand All @@ -139,17 +167,17 @@ def ints_to_pydatetime(
result[i] = <object>NaT
continue

if use_utc:
if info.use_utc:
local_val = utc_val
elif use_tzlocal:
elif info.use_tzlocal:
local_val = utc_val + localize_tzinfo_api(utc_val, tz)
elif use_fixed:
local_val = utc_val + delta
elif info.use_fixed:
local_val = utc_val + info.delta
else:
pos = bisect_right_i8(tdata, utc_val, ntrans) - 1
local_val = utc_val + deltas[pos]
pos = bisect_right_i8(tdata, utc_val, info.ntrans) - 1
local_val = utc_val + info.deltas[pos]

if use_pytz:
if info.use_pytz:
# find right representation of dst etc in pytz timezone
new_tz = tz._tzinfos[tz._transition_info[pos]]

Expand Down Expand Up @@ -191,46 +219,31 @@ cdef inline c_Resolution _reso_stamp(npy_datetimestruct *dts):
@cython.boundscheck(False)
def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:
cdef:
Py_ssize_t i, ntrans = -1, n = stamps.shape[0]
ndarray[int64_t] trans
int64_t[::1] deltas
Localizer info = Localizer(tz)
int64_t utc_val, local_val
Py_ssize_t pos, i, n = stamps.shape[0]
int64_t* tdata = NULL
intp_t pos
int64_t utc_val, local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False
str typ

npy_datetimestruct dts
c_Resolution reso = c_Resolution.RESO_DAY, curr_reso

if is_utc(tz) or tz is None:
use_utc = True
elif is_tzlocal(tz) or is_zoneinfo(tz):
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(delta) == 1
use_fixed = True
delta = deltas[0]
else:
tdata = <int64_t*>cnp.PyArray_DATA(trans)
if info.use_dst:
tdata = <int64_t*>cnp.PyArray_DATA(info.trans)

for i in range(n):
utc_val = stamps[i]
if utc_val == NPY_NAT:
continue

if use_utc:
if info.use_utc:
local_val = utc_val
elif use_tzlocal:
elif info.use_tzlocal:
local_val = utc_val + localize_tzinfo_api(utc_val, tz)
elif use_fixed:
local_val = utc_val + delta
elif info.use_fixed:
local_val = utc_val + info.delta
else:
pos = bisect_right_i8(tdata, utc_val, ntrans) - 1
local_val = utc_val + deltas[pos]
pos = bisect_right_i8(tdata, utc_val, info.ntrans) - 1
local_val = utc_val + info.deltas[pos]

dt64_to_dtstruct(local_val, &dts)
curr_reso = _reso_stamp(&dts)
Expand All @@ -242,6 +255,8 @@ def get_resolution(const int64_t[:] stamps, tzinfo tz=None) -> Resolution:

# -------------------------------------------------------------------------


@cython.cdivision(False)
@cython.wraparound(False)
@cython.boundscheck(False)
cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo tz):
Expand All @@ -260,48 +275,33 @@ cpdef ndarray[int64_t] normalize_i8_timestamps(const int64_t[:] stamps, tzinfo t
result : int64 ndarray of converted of normalized nanosecond timestamps
"""
cdef:
Py_ssize_t i, ntrans = -1, n = stamps.shape[0]
ndarray[int64_t] trans
int64_t[::1] deltas
Localizer info = Localizer(tz)
int64_t utc_val, local_val
Py_ssize_t pos, i, n = stamps.shape[0]
int64_t* tdata = NULL
intp_t pos
int64_t utc_val, local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False
str typ

int64_t[::1] result = np.empty(n, dtype=np.int64)

if is_utc(tz) or tz is None:
use_utc = True
elif is_tzlocal(tz) or is_zoneinfo(tz):
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(delta) == 1
use_fixed = True
delta = deltas[0]
else:
tdata = <int64_t*>cnp.PyArray_DATA(trans)
if info.use_dst:
tdata = <int64_t*>cnp.PyArray_DATA(info.trans)

for i in range(n):
utc_val = stamps[i]
if utc_val == NPY_NAT:
result[i] = NPY_NAT
continue

if use_utc:
if info.use_utc:
local_val = utc_val
elif use_tzlocal:
elif info.use_tzlocal:
local_val = utc_val + localize_tzinfo_api(utc_val, tz)
elif use_fixed:
local_val = utc_val + delta
elif info.use_fixed:
local_val = utc_val + info.delta
else:
pos = bisect_right_i8(tdata, utc_val, ntrans) - 1
local_val = utc_val + deltas[pos]
pos = bisect_right_i8(tdata, utc_val, info.ntrans) - 1
local_val = utc_val + info.deltas[pos]

result[i] = normalize_i8_stamp(local_val)
result[i] = local_val - (local_val % DAY_NANOS)

return result.base # `.base` to access underlying ndarray

Expand All @@ -324,40 +324,25 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
is_normalized : bool True if all stamps are normalized
"""
cdef:
Py_ssize_t i, ntrans = -1, n = stamps.shape[0]
ndarray[int64_t] trans
int64_t[::1] deltas
Localizer info = Localizer(tz)
int64_t utc_val, local_val
Py_ssize_t pos, i, n = stamps.shape[0]
int64_t* tdata = NULL
intp_t pos
int64_t utc_val, local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False
str typ

if is_utc(tz) or tz is None:
use_utc = True
elif is_tzlocal(tz) or is_zoneinfo(tz):
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(delta) == 1
use_fixed = True
delta = deltas[0]
else:
tdata = <int64_t*>cnp.PyArray_DATA(trans)

if info.use_dst:
tdata = <int64_t*>cnp.PyArray_DATA(info.trans)

for i in range(n):
utc_val = stamps[i]
if use_utc:
if info.use_utc:
local_val = utc_val
elif use_tzlocal:
elif info.use_tzlocal:
local_val = utc_val + localize_tzinfo_api(utc_val, tz)
elif use_fixed:
local_val = utc_val + delta
elif info.use_fixed:
local_val = utc_val + info.delta
else:
pos = bisect_right_i8(tdata, utc_val, ntrans) - 1
local_val = utc_val + deltas[pos]
pos = bisect_right_i8(tdata, utc_val, info.ntrans) - 1
local_val = utc_val + info.deltas[pos]

if local_val % DAY_NANOS != 0:
return False
Expand All @@ -372,47 +357,32 @@ def is_date_array_normalized(const int64_t[:] stamps, tzinfo tz=None) -> bool:
@cython.boundscheck(False)
def dt64arr_to_periodarr(const int64_t[:] stamps, int freq, tzinfo tz):
cdef:
Py_ssize_t i, ntrans = -1, n = stamps.shape[0]
ndarray[int64_t] trans
int64_t[::1] deltas
Localizer info = Localizer(tz)
int64_t utc_val, local_val
Py_ssize_t pos, i, n = stamps.shape[0]
int64_t* tdata = NULL
intp_t pos
int64_t utc_val, local_val, delta = NPY_NAT
bint use_utc = False, use_tzlocal = False, use_fixed = False
str typ

npy_datetimestruct dts
int64_t[::1] result = np.empty(n, dtype=np.int64)

if is_utc(tz) or tz is None:
use_utc = True
elif is_tzlocal(tz) or is_zoneinfo(tz):
use_tzlocal = True
else:
trans, deltas, typ = get_dst_info(tz)
ntrans = trans.shape[0]
if typ not in ["pytz", "dateutil"]:
# static/fixed; in this case we know that len(delta) == 1
use_fixed = True
delta = deltas[0]
else:
tdata = <int64_t*>cnp.PyArray_DATA(trans)
if info.use_dst:
tdata = <int64_t*>cnp.PyArray_DATA(info.trans)

for i in range(n):
utc_val = stamps[i]
if utc_val == NPY_NAT:
result[i] = NPY_NAT
continue

if use_utc:
if info.use_utc:
local_val = utc_val
elif use_tzlocal:
elif info.use_tzlocal:
local_val = utc_val + localize_tzinfo_api(utc_val, tz)
elif use_fixed:
local_val = utc_val + delta
elif info.use_fixed:
local_val = utc_val + info.delta
else:
pos = bisect_right_i8(tdata, utc_val, ntrans) - 1
local_val = utc_val + deltas[pos]
pos = bisect_right_i8(tdata, utc_val, info.ntrans) - 1
local_val = utc_val + info.deltas[pos]

dt64_to_dtstruct(local_val, &dts)
result[i] = get_period_ordinal(&dts, freq)
Expand Down

0 comments on commit 9797c89

Please sign in to comment.