From f00b8ee1959f354950660b0b0c784d4f48f420c6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 20 Jul 2024 11:21:46 -0400 Subject: [PATCH] Expose array sort (#764) --- python/datafusion/functions.py | 22 ++++++++++++++++++++++ python/datafusion/tests/test_functions.py | 8 ++++++++ src/functions.rs | 4 ++++ 3 files changed, 34 insertions(+) diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index ad77712ed..46d2a2f04 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -1126,6 +1126,28 @@ def list_replace_all(array: Expr, from_val: Expr, to_val: Expr) -> Expr: return array_replace_all(array, from_val, to_val) +def array_sort(array: Expr, descending: bool = False, null_first: bool = False) -> Expr: + """Sort an array. + + Args: + array: The input array to sort. + descending: If True, sorts in descending order. + null_first: If True, nulls will be returned at the beginning of the array. + """ + desc = "DESC" if descending else "ASC" + nulls_first = "NULLS FIRST" if null_first else "NULLS LAST" + return Expr( + f.array_sort( + array.expr, Expr.literal(desc).expr, Expr.literal(nulls_first).expr + ) + ) + + +def list_sort(array: Expr, descending: bool = False, null_first: bool = False) -> Expr: + """This is an alias for ``array_sort``.""" + return array_sort(array, descending=descending, null_first=null_first) + + def array_slice( array: Expr, begin: Expr, end: Expr, stride: Expr | None = None ) -> Expr: diff --git a/python/datafusion/tests/test_functions.py b/python/datafusion/tests/test_functions.py index 2384b6ab8..25d7de14a 100644 --- a/python/datafusion/tests/test_functions.py +++ b/python/datafusion/tests/test_functions.py @@ -453,6 +453,14 @@ def py_flatten(arr): lambda col: f.list_replace_all(col, literal(3.0), literal(4.0)), lambda data: [py_arr_replace(arr, 3.0, 4.0) for arr in data], ], + [ + lambda col: f.array_sort(col, descending=True, null_first=True), + lambda data: [np.sort(arr)[::-1] for arr in data], + ], + [ + lambda col: f.list_sort(col, descending=False, null_first=False), + lambda data: [np.sort(arr) for arr in data], + ], [ lambda col: f.array_slice(col, literal(2), literal(4)), lambda data: [arr[1:4] for arr in data], diff --git a/src/functions.rs b/src/functions.rs index 1e75dd5fb..74eb48a62 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -673,6 +673,8 @@ array_fn!(array_replace_n, array from to max); array_fn!(list_replace_n, array_replace_n, array from to max); array_fn!(array_replace_all, array from to); array_fn!(list_replace_all, array_replace_all, array from to); +array_fn!(array_sort, array desc null_first); +array_fn!(list_sort, array_sort, array desc null_first); array_fn!(array_intersect, first_array second_array); array_fn!(list_intersect, array_intersect, first_array second_array); array_fn!(array_union, array1 array2); @@ -936,6 +938,8 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(list_replace_n))?; m.add_wrapped(wrap_pyfunction!(array_replace_all))?; m.add_wrapped(wrap_pyfunction!(list_replace_all))?; + m.add_wrapped(wrap_pyfunction!(array_sort))?; + m.add_wrapped(wrap_pyfunction!(list_sort))?; m.add_wrapped(wrap_pyfunction!(array_slice))?; m.add_wrapped(wrap_pyfunction!(list_slice))?; m.add_wrapped(wrap_pyfunction!(flatten))?;