Skip to content

Commit

Permalink
GH718 Part 2
Browse files Browse the repository at this point in the history
  • Loading branch information
loicdiridollou committed Nov 13, 2024
1 parent 7e0ba97 commit 4332ecf
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 34 deletions.
24 changes: 10 additions & 14 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,7 @@ from pandas.core.generic import NDFrame
from pandas.core.groupby.generic import SeriesGroupBy
from pandas.core.groupby.groupby import BaseGroupBy
from pandas.core.indexers import BaseIndexer
from pandas.core.indexes.accessors import (
CombinedDatetimelikeProperties,
PeriodProperties,
TimedeltaProperties,
TimestampProperties,
)
from pandas.core.indexes.accessors import CombinedDatetimelikeProperties
from pandas.core.indexes.base import Index
from pandas.core.indexes.category import CategoricalIndex
from pandas.core.indexes.datetimes import DatetimeIndex
Expand Down Expand Up @@ -1508,6 +1503,11 @@ class Series(IndexOpsMixin[S1], NDFrame):
# just failed to generate these so I couldn't match
# them up.
@overload
def __add__(
self: Series[Timestamp],
other: TimedeltaSeries | np.timedelta64 | timedelta | BaseOffset,
) -> Series[Timestamp]: ...
@overload
def __add__(self, other: S1 | Self) -> Self: ...
@overload
def __add__(
Expand Down Expand Up @@ -1561,6 +1561,10 @@ class Series(IndexOpsMixin[S1], NDFrame):
def __radd__(self, other: S1 | Series[S1]) -> Self: ...
@overload
def __radd__(self, other: num | _str | _ListLike | Series) -> Series: ...
@overload
def __radd__(
self: Series[Timedelta], other: datetime | Timestamp | Series[Timestamp]
) -> Series[Timestamp]: ...
# ignore needed for mypy as we want different results based on the arguments
@overload # type: ignore[override]
def __rand__( # pyright: ignore[reportOverlappingOverload]
Expand Down Expand Up @@ -2081,9 +2085,6 @@ class Series(IndexOpsMixin[S1], NDFrame):
) -> Self: ...

class TimestampSeries(Series[Timestamp]):
@property
def dt(self) -> TimestampProperties: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def __add__(self, other: TimedeltaSeries | np.timedelta64 | timedelta | BaseOffset) -> TimestampSeries: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def __radd__(self, other: TimedeltaSeries | np.timedelta64 | timedelta) -> TimestampSeries: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
@overload # type: ignore[override]
def __sub__(
Expand Down Expand Up @@ -2137,7 +2138,6 @@ class TimedeltaSeries(Series[Timedelta]):
def __add__( # pyright: ignore[reportIncompatibleMethodOverride]
self, other: timedelta | Timedelta | np.timedelta64
) -> TimedeltaSeries: ...
def __radd__(self, other: datetime | Timestamp | TimestampSeries) -> TimestampSeries: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def __mul__( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
self, other: num | Sequence[num] | Series[int] | Series[float]
) -> TimedeltaSeries: ...
Expand Down Expand Up @@ -2193,8 +2193,6 @@ class TimedeltaSeries(Series[Timedelta]):
| Sequence[timedelta]
),
) -> Series[int]: ...
@property
def dt(self) -> TimedeltaProperties: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def mean( # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
self,
axis: AxisIndex | None = ...,
Expand Down Expand Up @@ -2223,8 +2221,6 @@ class TimedeltaSeries(Series[Timedelta]):
def diff(self, periods: int = ...) -> TimedeltaSeries: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]

class PeriodSeries(Series[Period]):
@property
def dt(self) -> PeriodProperties: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def __sub__(self, other: PeriodSeries) -> OffsetSeries: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]
def diff(self, periods: int = ...) -> OffsetSeries: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride]

Expand Down
4 changes: 2 additions & 2 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3032,7 +3032,7 @@ def test_timedeltaseries_operators() -> None:
pd.Timedelta,
)
check(
assert_type(datetime.datetime.now() + series, TimestampSeries),
assert_type(datetime.datetime.now() + series, "pd.Series[pd.Timestamp]"),
pd.Series,
pd.Timestamp,
)
Expand All @@ -3046,7 +3046,7 @@ def test_timedeltaseries_operators() -> None:
def test_timestamp_series() -> None:
series = pd.Series([pd.Timestamp(2024, 4, 4)])
check(
assert_type(series + YearEnd(0), TimestampSeries),
assert_type(series + YearEnd(0), "pd.Series[pd.Timestamp]"),
TimestampSeries,
pd.Timestamp,
)
Expand Down
37 changes: 19 additions & 18 deletions tests/test_timefuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ def test_fail_on_adding_two_timestamps() -> None:
s1 = pd.Series(pd.to_datetime(["2022-05-01", "2022-06-01"]))
s2 = pd.Series(pd.to_datetime(["2022-05-15", "2022-06-15"]))
if TYPE_CHECKING_INVALID_USAGE:
ssum: pd.Series = s1 + s2 # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
ssum: pd.Series = s1 + s2
ts = pd.Timestamp("2022-06-30")
tsum: pd.Series = s1 + ts # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
tsum: pd.Series = s1 + ts


def test_dtindex_tzinfo() -> None:
Expand Down Expand Up @@ -434,27 +434,27 @@ def test_series_dt_accessors() -> None:
check(assert_type(s0.dt.normalize(), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.strftime("%Y"), "pd.Series[str]"), pd.Series, str)
check(
assert_type(s0.dt.round("D", nonexistent=dt.timedelta(1)), "TimestampSeries"),
assert_type(s0.dt.round("D", nonexistent=dt.timedelta(1)), pd.Series),
pd.Series,
pd.Timestamp,
)
check(
assert_type(s0.dt.floor("D", nonexistent=dt.timedelta(1)), "TimestampSeries"),
assert_type(s0.dt.floor("D", nonexistent=dt.timedelta(1)), pd.Series),
pd.Series,
pd.Timestamp,
)
check(
assert_type(s0.dt.ceil("D", nonexistent=dt.timedelta(1)), "TimestampSeries"),
assert_type(s0.dt.ceil("D", nonexistent=dt.timedelta(1)), pd.Series),
pd.Series,
pd.Timestamp,
)
check(assert_type(s0.dt.month_name(), "pd.Series[str]"), pd.Series, str)
check(assert_type(s0.dt.day_name(), "pd.Series[str]"), pd.Series, str)
check(assert_type(s0.dt.unit, TimeUnit), str)
check(assert_type(s0.dt.as_unit("s"), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("ms"), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("us"), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("ns"), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("s"), pd.Series), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("ms"), pd.Series), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("us"), pd.Series), pd.Series, pd.Timestamp)
check(assert_type(s0.dt.as_unit("ns"), pd.Series), pd.Series, pd.Timestamp)

i1 = pd.period_range(start="2022-06-01", periods=10)

Expand All @@ -464,9 +464,9 @@ def test_series_dt_accessors() -> None:

s1 = pd.Series(i1)

check(assert_type(s1.dt.qyear, "pd.Series[int]"), pd.Series, np.integer)
check(assert_type(s1.dt.start_time, "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type(s1.dt.end_time, "TimestampSeries"), pd.Series, pd.Timestamp)
# check(assert_type(s1.dt.qyear, "pd.Series[int]"), pd.Series, np.integer)
# check(assert_type(s1.dt.start_time, "TimestampSeries"), pd.Series, pd.Timestamp)
# check(assert_type(s1.dt.end_time, "TimestampSeries"), pd.Series, pd.Timestamp)

i2 = pd.timedelta_range(start="1 day", periods=10)
check(assert_type(i2, pd.TimedeltaIndex), pd.TimedeltaIndex)
Expand All @@ -488,10 +488,10 @@ def test_series_dt_accessors() -> None:
check(assert_type(s2.dt.to_pytimedelta(), np.ndarray), np.ndarray)
check(assert_type(s2.dt.total_seconds(), "pd.Series[float]"), pd.Series, float)
check(assert_type(s2.dt.unit, TimeUnit), str)
check(assert_type(s2.dt.as_unit("s"), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("ms"), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("us"), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("ns"), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("s"), pd.Series), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("ms"), pd.Series), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("us"), pd.Series), pd.Series, pd.Timedelta)
check(assert_type(s2.dt.as_unit("ns"), pd.Series), pd.Series, pd.Timedelta)

# Checks for general Series other than TimestampSeries and TimedeltaSeries

Expand Down Expand Up @@ -1263,14 +1263,15 @@ def test_timedelta64_and_arithmatic_operator() -> None:
s4 = s1.astype(object)
check(assert_type(s4 - td1, "TimestampSeries"), pd.Series, pd.Timestamp)

s1 = cast("pd.Series[pd.Timestamp]", s1) # type: ignore[assignment]
td = np.timedelta64(1, "D")
check(assert_type((s1 - td), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type((s1 + td), "TimestampSeries"), pd.Series, pd.Timestamp)
check(assert_type((s1 + td), "pd.Series[pd.Timestamp]"), pd.Series, pd.Timestamp)
check(assert_type((s3 - td), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type((s3 + td), "TimedeltaSeries"), pd.Series, pd.Timedelta)
check(assert_type((s3 / td), "pd.Series[float]"), pd.Series, float)
if TYPE_CHECKING_INVALID_USAGE:
r1 = s1 * td # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
r1 = s1 * td # type: ignore[operator]
r2 = s1 / td # type: ignore[operator] # pyright: ignore[reportOperatorIssue]
r3 = s3 * td # type: ignore[operator] # pyright: ignore[reportOperatorIssue]

Expand Down

0 comments on commit 4332ecf

Please sign in to comment.