diff --git a/ibis/expr/types/core.py b/ibis/expr/types/core.py index 1ab39d9d0725..a7fd748af5fd 100644 --- a/ibis/expr/types/core.py +++ b/ibis/expr/types/core.py @@ -737,3 +737,13 @@ def _binop(op_class: type[ops.Binary], left: ir.Value, right: ir.Value) -> ir.Va return NotImplemented else: return node.to_expr() + + +def _is_null_literal(value: Any) -> bool: + """Detect whether `value` will be treated by ibis as a null literal.""" + if value is None: + return True + if isinstance(value, Expr): + op = value.op() + return isinstance(op, ops.Literal) and op.value is None + return False diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 4a554e17c273..cf83acec709f 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -13,7 +13,7 @@ from ibis.common.deferred import Deferred, _, deferrable from ibis.common.grounds import Singleton from ibis.expr.rewrites import rewrite_window_input -from ibis.expr.types.core import Expr, _binop, _FixedTextJupyterMixin +from ibis.expr.types.core import Expr, _binop, _FixedTextJupyterMixin, _is_null_literal from ibis.expr.types.pretty import to_rich from ibis.util import deprecated, warn_deprecated @@ -1160,13 +1160,13 @@ def __hash__(self) -> int: return super().__hash__() def __eq__(self, other: Value) -> ir.BooleanValue: - if other is None: - return _binop(ops.IdenticalTo, self, other) + if _is_null_literal(other): + return self.isnull() return _binop(ops.Equals, self, other) def __ne__(self, other: Value) -> ir.BooleanValue: - if other is None: - return ~self.__eq__(other) + if _is_null_literal(other): + return self.notnull() return _binop(ops.NotEquals, self, other) def __ge__(self, other: Value) -> ir.BooleanValue: diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 5ded2434bfbf..88ac5d4884c1 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -350,6 +350,20 @@ def test_notnull(table): assert isinstance(expr.op(), ops.NotNull) +@pytest.mark.parametrize( + "value", + [ + param(lambda: None, id="none"), + param(lambda: ibis.NA, id="NA"), + param(lambda: ibis.literal(None, type="int32"), id="typed-null"), + ], +) +def test_null_eq_and_ne(table, value): + other = value() + assert (table.a == other).equals(table.a.isnull()) + assert (table.a != other).equals(table.a.notnull()) + + @pytest.mark.parametrize("column", ["e", "f"], ids=["float32", "double"]) def test_isnan_isinf_column(table, column): expr = table[column].isnan()