From 5637d8fecd58b8c0f42775a6ac16b37e9452b0e1 Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Mon, 19 Aug 2024 15:29:45 -0500 Subject: [PATCH] feat(exasol): implement `cov`/`corr` --- ibis/backends/sql/compilers/exasol.py | 15 ++++++++++++++- ibis/backends/tests/test_aggregation.py | 4 ++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py index 21a9e1ce0d0c7..8b3bd7fb713c8 100644 --- a/ibis/backends/sql/compilers/exasol.py +++ b/ibis/backends/sql/compilers/exasol.py @@ -44,7 +44,6 @@ class ExasolCompiler(SQLGlotCompiler): ops.ArrayUnion, ops.ArrayZip, ops.BitwiseNot, - ops.Covariance, ops.CumeDist, ops.DateAdd, ops.DateSub, @@ -120,6 +119,20 @@ def visit_NonNullLiteral(self, op, *, value, dtype): def visit_Date(self, op, *, arg): return self.cast(arg, dt.date) + def visit_Correlation(self, op, *, left, right, how, where): + if how == "sample": + raise com.UnsupportedOperationError( + "Exasol only implements `pop` correlation coefficient" + ) + + if (left_type := op.left.dtype).is_boolean(): + left = self.cast(left, dt.Int32(nullable=left_type.nullable)) + + if (right_type := op.right.dtype).is_boolean(): + right = self.cast(right, dt.Int32(nullable=right_type.nullable)) + + return self.agg.corr(left, right, where=where) + def visit_GroupConcat(self, op, *, arg, sep, where, order_by): if where is not None: arg = self.if_(where, arg, NULL) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 40e2a5ca2501e..cdaa72d5fb6e1 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1013,7 +1013,7 @@ def test_quantile( raises=com.OperationNotDefinedError, ), pytest.mark.notyet( - ["postgres", "duckdb", "snowflake", "risingwave"], + ["postgres", "duckdb", "snowflake", "risingwave", "exasol"], raises=com.UnsupportedOperationError, reason="backend only implements population correlation coefficient", ), @@ -1114,7 +1114,7 @@ def test_quantile( ), ], ) -@pytest.mark.notimpl(["mssql", "exasol"], raises=com.OperationNotDefinedError) +@pytest.mark.notimpl(["mssql"], raises=com.OperationNotDefinedError) def test_corr_cov( con, batting,