Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Nghia Truong <[email protected]>
  • Loading branch information
ttnghia committed Dec 7, 2024
1 parent 2fd0282 commit fbc127b
Showing 1 changed file with 40 additions and 35 deletions.
75 changes: 40 additions & 35 deletions src/main/cpp/src/datetime_truncate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ namespace detail {

namespace {
/**
* @brief Mark the date/time component to truncate.
* @brief The date/time format to perform truncation.
*
* The format must match the descriptions in the Apache Spark documentation:
* - https://spark.apache.org/docs/latest/api/sql/index.html#trunc
* - https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc
*/
enum class truncate_component : uint8_t {
enum class truncation_format : uint8_t {
YEAR,
YYYY,
YY,
Expand All @@ -63,10 +67,10 @@ enum class truncate_component : uint8_t {

// Convert an ASCII lowercase letter to its uppercase counterpart.
// Any byte outside the range ['a', 'z'] is returned unchanged.
__device__ char to_upper(unsigned char const c)
{
  if (c < 'a' || 'z' < c) { return static_cast<char>(c); }
  // Uppercase and lowercase ASCII letters differ only in the 0x20 bit.
  return static_cast<char>(c ^ 0x20);
}

// Parse the component to truncate from a string.
__device__ truncate_component parse_component(cudf::string_view const format)
// Parse the truncation format from a string.
__device__ truncation_format parse_format(cudf::string_view const format)
{
// This must be kept in sync with the `truncate_component` enum.
// This must be kept in sync with the `truncation_format` enum.
char const* components[] = {"YEAR",
"YYYY",
"YY",
Expand All @@ -83,11 +87,9 @@ __device__ truncate_component parse_component(cudf::string_view const format)
"MILLISECOND",
"MICROSECOND"};
// Manually calculate sizes of the strings since `strlen` is not available in device code.
cudf::size_type const comp_sizes[] = {4, 4, 2, 7, 5, 2, 3, 4, 3, 2, 4, 6, 6, 11, 11};

auto constexpr num_components = std::size(components);
cudf::size_type constexpr comp_sizes[] = {4, 4, 2, 7, 5, 2, 3, 4, 3, 2, 4, 6, 6, 11, 11};
auto constexpr num_components = std::size(components);

// auto const num_components = sizeof(components) / sizeof(components[0]);
auto const* in_ptr = reinterpret_cast<unsigned char const*>(format.data());
auto const in_len = format.size_bytes();
for (std::size_t comp_idx = 0; comp_idx < num_components; ++comp_idx) {
Expand All @@ -100,21 +102,23 @@ __device__ truncate_component parse_component(cudf::string_view const format)
break;
}
}
if (equal) { return static_cast<truncate_component>(comp_idx); }
if (equal) { return static_cast<truncation_format>(comp_idx); }
}
return truncate_component::INVALID;
return truncation_format::INVALID;
}

// Map a one-based month number to the first month of the quarter that contains it
// (e.g. 1, 2, 3 -> 1 and 10, 11, 12 -> 10).
__device__ inline uint32_t trunc_quarter_month(uint32_t month)
{
  // Position of `month` inside its quarter: 0, 1 or 2.
  auto const offset_in_quarter = (month - 1u) % 3u;
  return month - offset_in_quarter;
}

// Truncate the given day to the previous Monday.
__device__ inline cuda::std::chrono::sys_days trunc_to_monday(
cuda::std::chrono::sys_days const days_since_epoch)
{
// Have to define our constant as `cuda::std::chrono::Monday` is not available in device code.
// Define our own `MONDAY` constant because `cuda::std::chrono::Monday` is not available in device code.
// [0, 6] => [Sun, Sat]
auto constexpr MONDAY = cuda::std::chrono::weekday{1};
auto const weekday = cuda::std::chrono::weekday{days_since_epoch};
Expand All @@ -132,22 +136,22 @@ template <typename Timestamp>
__device__ inline thrust::optional<Timestamp> trunc_date(
cuda::std::chrono::sys_days const days_since_epoch,
cuda::std::chrono::year_month_day const ymd,
truncate_component const trunc_comp)
truncation_format const trunc_comp)
{
using namespace cuda::std::chrono;
switch (trunc_comp) {
case truncate_component::YEAR:
case truncate_component::YYYY:
case truncate_component::YY:
case truncation_format::YEAR:
case truncation_format::YYYY:
case truncation_format::YY:
return Timestamp{sys_days{year_month_day{ymd.year(), month{1}, day{1}}}};
case truncate_component::QUARTER:
case truncation_format::QUARTER:
return Timestamp{sys_days{year_month_day{
ymd.year(), month{trunc_quarter_month(static_cast<uint32_t>(ymd.month()))}, day{1}}}};
case truncate_component::MONTH:
case truncate_component::MM:
case truncate_component::MON:
case truncation_format::MONTH:
case truncation_format::MM:
case truncation_format::MON:
return Timestamp{sys_days{year_month_day{ymd.year(), ymd.month(), day{1}}}};
case truncate_component::WEEK: return Timestamp{trunc_to_monday(days_since_epoch)};
case truncation_format::WEEK: return Timestamp{trunc_to_monday(days_since_epoch)};
default: return thrust::nullopt;
}
}
Expand All @@ -166,8 +170,8 @@ struct truncate_date_fn {
}

auto const fmt = format.element<cudf::string_view>(format_idx);
auto const trunc_comp = parse_component(fmt);
if (trunc_comp == truncate_component::INVALID) { return {Timestamp{}, false}; }
auto const trunc_comp = parse_format(fmt);
if (trunc_comp == truncation_format::INVALID) { return {Timestamp{}, false}; }

using namespace cuda::std::chrono;
auto const ts = datetime.element<Timestamp>(datetime_idx);
Expand All @@ -193,16 +197,15 @@ struct truncate_timestamp_fn {
}

auto const fmt = format.element<cudf::string_view>(format_idx);
auto const trunc_comp = parse_component(fmt);
if (trunc_comp == truncate_component::INVALID) { return {Timestamp{}, false}; }
auto const trunc_comp = parse_format(fmt);
if (trunc_comp == truncation_format::INVALID) { return {Timestamp{}, false}; }

using namespace cuda::std::chrono;
auto const ts = datetime.element<Timestamp>(datetime_idx);

// No truncation needed for microsecond timestamps.
if (trunc_comp == truncate_component::MICROSECOND) { return {ts, true}; }
if (trunc_comp == truncation_format::MICROSECOND) { return {ts, true}; }

// The components that are common for both date and timestamp: year, quarter, month, week.
using namespace cuda::std::chrono;
auto const days_since_epoch = floor<days>(ts);
auto const ymd = year_month_day(days_since_epoch);
if (auto const try_trunc_date = trunc_date<Timestamp>(days_since_epoch, ymd, trunc_comp);
Expand All @@ -214,17 +217,17 @@ struct truncate_timestamp_fn {
if (time_since_midnight.count() < 0) { time_since_midnight += days(1); }

switch (trunc_comp) {
case truncate_component::DAY:
case truncate_component::DD: return {Timestamp{sys_days{ymd}}, true};
case truncate_component::HOUR:
case truncation_format::DAY:
case truncation_format::DD: return {Timestamp{sys_days{ymd}}, true};
case truncation_format::HOUR:
return {Timestamp{sys_days{ymd} + floor<hours>(time_since_midnight)}, true};
case truncate_component::MINUTE:
case truncation_format::MINUTE:
return {Timestamp{sys_days{ymd} + floor<minutes>(time_since_midnight)}, true};
case truncate_component::SECOND:
case truncation_format::SECOND:
return {Timestamp{sys_days{ymd} + floor<seconds>(time_since_midnight)}, true};
case truncate_component::MILLISECOND:
case truncation_format::MILLISECOND:
return {Timestamp{sys_days{ymd} + floor<milliseconds>(time_since_midnight)}, true};
default: CUDF_UNREACHABLE("Unhandled truncating component.");
default: CUDF_UNREACHABLE("Unhandled truncation format.");
}
}
};
Expand All @@ -238,6 +241,8 @@ std::unique_ptr<cudf::column> truncate(cudf::column_view const& datetime,
CUDF_EXPECTS(
type_id == cudf::type_id::TIMESTAMP_DAYS || type_id == cudf::type_id::TIMESTAMP_MICROSECONDS,
"The input must be either day or microsecond timestamps.");
CUDF_EXPECTS(format.type().id() == cudf::type_id::STRING,
"The format column must be of string type.");

auto const size = std::max(datetime.size(), format.size());
if (datetime.size() == 0 || format.size() == 0 || datetime.size() == datetime.null_count() ||
Expand Down

0 comments on commit fbc127b

Please sign in to comment.