diff --git a/src/main/cpp/src/datetime_truncate.cu b/src/main/cpp/src/datetime_truncate.cu index 8eb5d092a..3444a6b12 100644 --- a/src/main/cpp/src/datetime_truncate.cu +++ b/src/main/cpp/src/datetime_truncate.cu @@ -40,9 +40,13 @@ namespace detail { namespace { /** - * @brief Mark the date/time component to truncate. + * @brief The date/time format to perform truncation. + * + * The format must match the descriptions in the Apache Spark documentation: + * - https://spark.apache.org/docs/latest/api/sql/index.html#trunc + * - https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc */ -enum class truncate_component : uint8_t { +enum class truncation_format : uint8_t { YEAR, YYYY, YY, @@ -63,10 +67,10 @@ enum class truncate_component : uint8_t { __device__ char to_upper(unsigned char const c) { return ('a' <= c && c <= 'z') ? c ^ 0x20 : c; } -// Parse the component to truncate from a string. -__device__ truncate_component parse_component(cudf::string_view const format) +// Parse the truncation format from a string. +__device__ truncation_format parse_format(cudf::string_view const format) { - // This must be kept in sync with the `truncate_component` enum. + // This must be kept in sync with the `truncation_format` enum. char const* components[] = {"YEAR", "YYYY", "YY", @@ -83,11 +87,9 @@ __device__ truncate_component parse_component(cudf::string_view const format) "MILLISECOND", "MICROSECOND"}; // Manually calculate sizes of the strings since `strlen` is not available in device code. - cudf::size_type const comp_sizes[] = {4, 4, 2, 7, 5, 2, 3, 4, 3, 2, 4, 6, 6, 11, 11}; - - auto constexpr num_components = std::size(components); + cudf::size_type constexpr comp_sizes[] = {4, 4, 2, 7, 5, 2, 3, 4, 3, 2, 4, 6, 6, 11, 11}; + auto constexpr num_components = std::size(components); - // auto const num_components = sizeof(components) / sizeof(components[0]); auto const* in_ptr = reinterpret_cast(format.data()); auto const in_len = format.size_bytes(); for (std::size_t comp_idx = 0; comp_idx < num_components; ++comp_idx) { @@ -100,21 +102,23 @@ __device__ truncate_component parse_component(cudf::string_view const format) break; } } - if (equal) { return static_cast(comp_idx); } + if (equal) { return static_cast(comp_idx); } } - return truncate_component::INVALID; + return truncation_format::INVALID; } +// Truncate the given month to the first month of the quarter. __device__ inline uint32_t trunc_quarter_month(uint32_t month) { auto const zero_based_month = month - 1u; return (zero_based_month / 3u) * 3u + 1u; } +// Truncate the given day to the previous Monday. __device__ inline cuda::std::chrono::sys_days trunc_to_monday( cuda::std::chrono::sys_days const days_since_epoch) { - // Have to define our constant as `cuda::std::chrono::Monday` is not available in device code. + // Define our `MONDAY` constant as `cuda::std::chrono::Monday` is not available in device code. // [0, 6] => [Sun, Sat] auto constexpr MONDAY = cuda::std::chrono::weekday{1}; auto const weekday = cuda::std::chrono::weekday{days_since_epoch}; @@ -132,22 +136,22 @@ template __device__ inline thrust::optional trunc_date( cuda::std::chrono::sys_days const days_since_epoch, cuda::std::chrono::year_month_day const ymd, - truncate_component const trunc_comp) + truncation_format const trunc_comp) { using namespace cuda::std::chrono; switch (trunc_comp) { - case truncate_component::YEAR: - case truncate_component::YYYY: - case truncate_component::YY: + case truncation_format::YEAR: + case truncation_format::YYYY: + case truncation_format::YY: return Timestamp{sys_days{year_month_day{ymd.year(), month{1}, day{1}}}}; - case truncate_component::QUARTER: + case truncation_format::QUARTER: return Timestamp{sys_days{year_month_day{ ymd.year(), month{trunc_quarter_month(static_cast(ymd.month()))}, day{1}}}}; - case truncate_component::MONTH: - case truncate_component::MM: - case truncate_component::MON: + case truncation_format::MONTH: + case truncation_format::MM: + case truncation_format::MON: return Timestamp{sys_days{year_month_day{ymd.year(), ymd.month(), day{1}}}}; - case truncate_component::WEEK: return Timestamp{trunc_to_monday(days_since_epoch)}; + case truncation_format::WEEK: return Timestamp{trunc_to_monday(days_since_epoch)}; default: return thrust::nullopt; } } @@ -166,8 +170,8 @@ struct truncate_date_fn { } auto const fmt = format.element(format_idx); - auto const trunc_comp = parse_component(fmt); - if (trunc_comp == truncate_component::INVALID) { return {Timestamp{}, false}; } + auto const trunc_comp = parse_format(fmt); + if (trunc_comp == truncation_format::INVALID) { return {Timestamp{}, false}; } using namespace cuda::std::chrono; auto const ts = datetime.element(datetime_idx); @@ -193,16 +197,15 @@ struct truncate_timestamp_fn { } auto const fmt = format.element(format_idx); - auto const trunc_comp = parse_component(fmt); - if (trunc_comp == truncate_component::INVALID) { return {Timestamp{}, false}; } + auto const trunc_comp = parse_format(fmt); + if (trunc_comp == truncation_format::INVALID) { return {Timestamp{}, false}; } - using namespace cuda::std::chrono; auto const ts = datetime.element(datetime_idx); - // No truncation needed for microsecond timestamps. - if (trunc_comp == truncate_component::MICROSECOND) { return {ts, true}; } + if (trunc_comp == truncation_format::MICROSECOND) { return {ts, true}; } // The components that are common for both date and timestamp: year, quarter, month, week. + using namespace cuda::std::chrono; auto const days_since_epoch = floor(ts); auto const ymd = year_month_day(days_since_epoch); if (auto const try_trunc_date = trunc_date(days_since_epoch, ymd, trunc_comp); @@ -214,17 +217,17 @@ struct truncate_timestamp_fn { if (time_since_midnight.count() < 0) { time_since_midnight += days(1); } switch (trunc_comp) { - case truncate_component::DAY: - case truncate_component::DD: return {Timestamp{sys_days{ymd}}, true}; - case truncate_component::HOUR: + case truncation_format::DAY: + case truncation_format::DD: return {Timestamp{sys_days{ymd}}, true}; + case truncation_format::HOUR: return {Timestamp{sys_days{ymd} + floor(time_since_midnight)}, true}; - case truncate_component::MINUTE: + case truncation_format::MINUTE: return {Timestamp{sys_days{ymd} + floor(time_since_midnight)}, true}; - case truncate_component::SECOND: + case truncation_format::SECOND: return {Timestamp{sys_days{ymd} + floor(time_since_midnight)}, true}; - case truncate_component::MILLISECOND: + case truncation_format::MILLISECOND: return {Timestamp{sys_days{ymd} + floor(time_since_midnight)}, true}; - default: CUDF_UNREACHABLE("Unhandled truncating component."); + default: CUDF_UNREACHABLE("Unhandled truncation format."); } } }; @@ -238,6 +241,8 @@ std::unique_ptr truncate(cudf::column_view const& datetime, CUDF_EXPECTS( type_id == cudf::type_id::TIMESTAMP_DAYS || type_id == cudf::type_id::TIMESTAMP_MICROSECONDS, "The input must be either day or microsecond timestamps."); + CUDF_EXPECTS(format.type().id() == cudf::type_id::STRING, + "The format column must be of string type."); auto const size = std::max(datetime.size(), format.size()); if (datetime.size() == 0 || format.size() == 0 || datetime.size() == datetime.null_count() ||