Skip to content

Commit

Permalink
add back datetime arithmetic, ignore mssql strcmp
Browse files Browse the repository at this point in the history
  • Loading branch information
finn-rudolph committed Nov 14, 2024
1 parent 9e2ec89 commit 18f2d53
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 8 deletions.
11 changes: 10 additions & 1 deletion src/pydiverse/transform/_internal/ops/ops/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from pydiverse.transform._internal.ops.signature import Signature
from pydiverse.transform._internal.tree.types import (
NUMERIC,
Date,
Datetime,
Decimal,
Duration,
Float,
Int,
String,
Expand All @@ -14,10 +17,16 @@
"__add__",
*(Signature(dtype, dtype, return_type=dtype) for dtype in NUMERIC),
Signature(String(), String(), return_type=String()),
Signature(Duration(), Duration(), return_type=Duration()),
)

sub = Operator(
"__sub__", *(Signature(dtype, dtype, return_type=dtype) for dtype in NUMERIC)
"__sub__",
*(Signature(dtype, dtype, return_type=dtype) for dtype in NUMERIC),
Signature(Datetime(), Datetime(), return_type=Duration()),
Signature(Date(), Date(), return_type=Duration()),
Signature(Datetime(), Date(), return_type=Duration()),
Signature(Date(), Datetime(), return_type=Duration()),
)

mul = Operator(
Expand Down
38 changes: 38 additions & 0 deletions src/pydiverse/transform/_internal/tree/col_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ def __add__(self: ColExpr[Decimal], rhs: ColExpr[Decimal]) -> ColExpr[Decimal]:
@overload
def __add__(self: ColExpr[String], rhs: ColExpr[String]) -> ColExpr[String]: ...

@overload
def __add__(
self: ColExpr[Duration], rhs: ColExpr[Duration]
) -> ColExpr[Duration]: ...

def __add__(self: ColExpr, rhs: ColExpr) -> ColExpr:
return ColFn(ops.add, self, rhs)

Expand All @@ -144,6 +149,11 @@ def __radd__(self: ColExpr[Decimal], rhs: ColExpr[Decimal]) -> ColExpr[Decimal]:
@overload
def __radd__(self: ColExpr[String], rhs: ColExpr[String]) -> ColExpr[String]: ...

@overload
def __radd__(
self: ColExpr[Duration], rhs: ColExpr[Duration]
) -> ColExpr[Duration]: ...

def __radd__(self: ColExpr, rhs: ColExpr) -> ColExpr:
return ColFn(ops.add, rhs, self)

Expand Down Expand Up @@ -591,6 +601,20 @@ def __sub__(self: ColExpr[Float], rhs: ColExpr[Float]) -> ColExpr[Float]: ...
@overload
def __sub__(self: ColExpr[Decimal], rhs: ColExpr[Decimal]) -> ColExpr[Decimal]: ...

@overload
def __sub__(
self: ColExpr[Datetime], rhs: ColExpr[Datetime]
) -> ColExpr[Duration]: ...

@overload
def __sub__(self: ColExpr[Date], rhs: ColExpr[Date]) -> ColExpr[Duration]: ...

@overload
def __sub__(self: ColExpr[Datetime], rhs: ColExpr[Date]) -> ColExpr[Duration]: ...

@overload
def __sub__(self: ColExpr[Date], rhs: ColExpr[Datetime]) -> ColExpr[Duration]: ...

def __sub__(self: ColExpr, rhs: ColExpr) -> ColExpr:
return ColFn(ops.sub, self, rhs)

Expand All @@ -603,6 +627,20 @@ def __rsub__(self: ColExpr[Float], rhs: ColExpr[Float]) -> ColExpr[Float]: ...
@overload
def __rsub__(self: ColExpr[Decimal], rhs: ColExpr[Decimal]) -> ColExpr[Decimal]: ...

@overload
def __rsub__(
self: ColExpr[Datetime], rhs: ColExpr[Datetime]
) -> ColExpr[Duration]: ...

@overload
def __rsub__(self: ColExpr[Date], rhs: ColExpr[Date]) -> ColExpr[Duration]: ...

@overload
def __rsub__(self: ColExpr[Datetime], rhs: ColExpr[Date]) -> ColExpr[Duration]: ...

@overload
def __rsub__(self: ColExpr[Date], rhs: ColExpr[Datetime]) -> ColExpr[Duration]: ...

def __rsub__(self: ColExpr, rhs: ColExpr) -> ColExpr:
return ColFn(ops.sub, rhs, self)

Expand Down
14 changes: 7 additions & 7 deletions tests/test_backend_equivalence/test_ops/test_case_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,13 @@ def test_summarize_case(df4):
C.col1,
)
>> summarize(
# x=C.col2.max().map(
# {
# 0: C.col1.min(),
# 1: C.col2.mean() + 0.5,
# 2: 2,
# }
# ),
x=C.col2.max().map(
{
0: C.col1.min(),
1: C.col2.mean() + 0.5,
2: 2,
}
),
y=pdt.when(C.col2.max() > 2)
.then(1)
.when(C.col2.max() < 2)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_backend_equivalence/test_ops/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pydiverse.transform import C
from pydiverse.transform._internal.pipe.verbs import mutate
from pydiverse.transform._internal.tree.col_expr import LiteralCol
from tests.fixtures.backend import skip_backends
from tests.util import assert_result_equal


Expand All @@ -29,6 +30,7 @@ def test_row_number(df4):
)


@skip_backends("mssql")
def test_min(df4):
assert_result_equal(
df4,
Expand All @@ -46,6 +48,7 @@ def test_min(df4):
)


@skip_backends("mssql")
def test_max(df4):
assert_result_equal(
df4,
Expand Down

0 comments on commit 18f2d53

Please sign in to comment.