Skip to content

Commit

Permalink
Did more work on arrowArray visitor WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
Giorgi Lomia committed Oct 4, 2021
1 parent eb13e98 commit 25583b4
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 98 deletions.
8 changes: 8 additions & 0 deletions libsupport/include/katana/ArrowVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,14 @@ using AcceptAllArrowTypes = std::tuple<
arrow::LargeStringType, arrow::StructType, arrow::ListType,
arrow::LargeListType, arrow::NullType>;

using AcceptAllFlatTypes = std::tuple<
arrow::Int8Type, arrow::UInt8Type, arrow::Int16Type, arrow::UInt16Type,
arrow::Int32Type, arrow::UInt32Type, arrow::Int64Type, arrow::UInt64Type,
arrow::FloatType, arrow::DoubleType, arrow::FloatType, arrow::DoubleType,
arrow::BooleanType, arrow::Date32Type, arrow::Date64Type, arrow::Time32Type,
arrow::Time64Type, arrow::TimestampType, arrow::StringType,
arrow::LargeStringType, arrow::NullType>;

template <typename... Args>
using tuple_cat_t = decltype(std::tuple_cat(std::declval<Args>()...));

Expand Down
178 changes: 80 additions & 98 deletions tools/graph-stats/graph-memory-stats.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
#include <string>
#include <unordered_map>

#include <arrow/array/builder_base.h>
#include <arrow/type_traits.h>

#include "katana/ArrowVisitor.h"
Expand All @@ -43,137 +42,122 @@ using map_string_element = std::unordered_map<std::string, std::string>;
using memory_map = std::unordered_map<
std::string, std::variant<map_element, map_string_element>>;

inline std::shared_ptr<arrow::DataType>
GetArrowType(const arrow::Scalar& scalar) {
return scalar.type;
}

inline std::shared_ptr<arrow::DataType>
GetArrowType(const arrow::Array& array) {
return array.type();
}

inline std::shared_ptr<arrow::DataType>
GetArrowType(const arrow::ArrayBuilder* builder) {
return builder->type();
}

struct Visitor : public katana::ArrowVisitor {
using ResultType = katana::Result<int64_t>;
using AcceptTypes = std::tuple<
arrow::Int8Type, arrow::UInt8Type, arrow::Int16Type, arrow::UInt16Type,
arrow::Int32Type, arrow::UInt32Type, arrow::Int64Type, arrow::UInt64Type,
arrow::FloatType, arrow::DoubleType, arrow::FloatType, arrow::DoubleType,
arrow::BooleanType, arrow::Date32Type, arrow::Date64Type,
arrow::Time32Type, arrow::Time64Type, arrow::TimestampType,
arrow::StringType, arrow::LargeStringType, arrow::StructType,
arrow::NullType>;

template <typename ArrowType, typename ScalarType>
using AcceptTypes = std::tuple<katana::AcceptAllFlatTypes>;

template <typename ArrowType, typename ArrayType>
arrow::enable_if_null<ArrowType, ResultType> Call(const ArrayType& scalars) {
std::cout << scalars.total_values_length() << "\n";
return 0;
}

template <typename ArrowType, typename ArrayType>
std::enable_if_t<
arrow::is_number_type<ArrowType>::value ||
arrow::is_boolean_type<ArrowType>::value ||
arrow::is_temporal_type<ArrowType>::value,
ResultType>
Call(const ScalarType& scalar) {
return scalar.value;
Call(const ArrayType& scalars) {
// ResultType width = 0;
std::cout << scalars.total_values_length() << "\n";
return 0;
}

template <typename ArrowType, typename ScalarType>
template <typename ArrowType, typename ArrayType>
arrow::enable_if_string_like<ArrowType, ResultType> Call(
const ScalarType& scalar) {
const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
auto res = (arrow::util::string_view)(*typed_scalar->value);
// TODO (giorgi): make this KATANA_CHECKED
// if (!res.ok()) {
// return KATANA_ERROR(
// katana::ErrorCode::ArrowError, "arrow builder failed append: {}",
// res);
// }
return res;
const ArrayType& scalars) {
std::cout << scalars.total_values_length() << "\n";

return 0;
}

ResultType AcceptFailed(const arrow::Scalar& scalar) {
template <typename Param>
ResultType AcceptFailed(Param&& param) {
return KATANA_ERROR(
katana::ErrorCode::ArrowError, "no matching type {}",
scalar.type->name());
"Instant functions do not accept {}", GetArrowType(param)->ToString());
}
};

// struct ToArrayVisitor : public katana::ArrowVisitor {
// // Internal data and constructor
// const std::shared_ptr<arrow::Array> scalars;
// ToArrayVisitor(const std::shared_ptr<arrow::Array> input) : scalars(input) {}

// using ResultType = katana::Result<std::shared_ptr<arrow::Array>>;

// struct Visitor : public katana::ArrowVisitor {
// const std::shared_ptr<arrow::Scalar>& scalar;
// Visitor(const std::shared_ptr<arrow::Scalar>& input) : scalar(input) {}
// using ResultType = katana::Result<int64_t>;
// using AcceptTypes = std::tuple<katana::AcceptAllArrowTypes>;

// template <typename ArrowType, typename BuilderType>
// arrow::enable_if_null<ArrowType, ResultType> Call(BuilderType* builder) {
// return KATANA_CHECKED(builder->Finish());
// template <typename ArrowType, typename WidthType>
// arrow::enable_if_null<ArrowType, ResultType> Call(
// const WidthType& width_tracker) {
// width_tracker = 0;
// return width_tracker;
// }

// template <typename ArrowType, typename BuilderType>
// template <typename ArrowType, typename WidthType>
// std::enable_if_t<
// arrow::is_number_type<ArrowType>::value ||
// arrow::is_boolean_type<ArrowType>::value ||
// arrow::is_temporal_type<ArrowType>::value,
// ResultType>
// Call(BuilderType* builder) {
// Call(const WidthType& width_tracker) {
// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;

// KATANA_CHECKED(builder->Reserve(scalars->length()));
// for (auto j = 0; j < scalars->length(); j++) {
// auto scalar = *scalars->GetScalar(j);
// if (scalar != nullptr && scalar->is_valid) {
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
// builder->UnsafeAppend(typed_scalar->value);
// } else {
// builder->UnsafeAppendNull();
// }
// if (scalar != nullptr && scalar->is_valid) {
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
// return typed_scalar->value;
// } else {
// return KATANA_ERROR(
// katana::ErrorCode::ArrowError, "arrow visitor failed to read: NULL");
// }
// return KATANA_CHECKED(builder->Finish());
// }

// template <typename ArrowType, typename BuilderType>
// template <typename ArrowType, typename WidthType>
// arrow::enable_if_string_like<ArrowType, ResultType> Call(
// BuilderType* builder) {
// const WidthType& width_tracker) {
// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
// // same as above, but with string_view and Append instead of UnsafeAppend
// for (auto j = 0; j < scalars->length(); j++) {
// auto scalar = *scalars->GetScalar(j);
// if (scalar != nullptr && scalar->is_valid) {
// // ->value->ToString() works, scalar->ToString() yields "..."
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
// if (auto res = builder->Append(
// (arrow::util::string_view)(*typed_scalar->value));
// !res.ok()) {
// return KATANA_ERROR(
// katana::ErrorCode::ArrowError, "arrow builder failed append: {}",
// res);
// }
// } else {
// if (auto res = builder->AppendNull(); !res.ok()) {
// return KATANA_ERROR(
// katana::ErrorCode::ArrowError,
// "arrow builder failed append null: {}", res);
// }
// }
// if (scalar != nullptr && scalar->is_valid) {
// // ->value->ToString() works, scalar->ToString() yields "..."
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
// auto res = (arrow::util::string_view)(*typed_scalar->value);
// return res;
// } else {
// return KATANA_ERROR(
// katana::ErrorCode::ArrowError, "arrow visitor failed to read: NULL");
// }
// return KATANA_CHECKED(builder->Finish());
// }

// template <typename ArrowType, typename BuilderType>
// template <typename ArrowType, typename WidthType>
// std::enable_if_t<
// arrow::is_list_type<ArrowType>::value ||
// arrow::is_struct_type<ArrowType>::value,
// ResultType>
// Call(BuilderType* builder) {
// Call(const WidthType& width_tracker) {
// using ScalarType = typename arrow::TypeTraits<ArrowType>::ScalarType;
// // use a visitor to traverse more complex types
// katana::AppendScalarToBuilder visitor(builder);
// for (auto j = 0; j < scalars->length(); j++) {
// auto scalar = *scalars->GetScalar(j);
// if (scalar != nullptr && scalar->is_valid) {
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
// KATANA_CHECKED(visitor.Call<ArrowType>(*typed_scalar));
// } else {
// KATANA_CHECKED(builder->AppendNull());
// }
// Visitor visitor(scalar);
// if (scalar != nullptr && scalar->is_valid) {
// const ScalarType* typed_scalar = static_cast<ScalarType*>(scalar.get());
// KATANA_CHECKED(visitor.Call<ArrowType>(*typed_scalar));
// }
// return KATANA_CHECKED(builder->Finish());
// }

// ResultType AcceptFailed(const arrow::ArrayBuilder* builder) {
// ResultType AcceptFailed(const arrow::Scalar& scalar) {
// return KATANA_ERROR(
// katana::ErrorCode::ArrowError, "no matching type {}",
// builder->type()->name());
// scalar.type->name());
// }
// };

Expand Down Expand Up @@ -202,19 +186,17 @@ PrintStringMapping(const std::unordered_map<std::string, std::string>& u) {
std::cout << "\n";
}

katana::Result<std::shared_ptr<arrow::Array>>
void
RunVisit(const std::shared_ptr<arrow::Array> scalars) {
Visitor v;
int64_t total = 0;
for (auto j = 0; j < scalars->length(); j++) {
auto s = *scalars->GetScalar(j);
auto res = katana::VisitArrow(v, *s);
KATANA_LOG_VASSERT(res, "unexpected errror {}", res.error());
total += res.value();
}
Visitor v;
arrow::Array* arr = scalars.get();
auto res = katana::VisitArrow(v, *arr);
KATANA_LOG_VASSERT(res, "unexpected errror {}", res.error());
total += res.value();

KATANA_LOG_VASSERT(
total == scalars->length(), "{} != {}", total, scalars->length());
// KATANA_LOG_VASSERT(
// total == scalars->length(), "{} != {}", total, scalars->length());
}

void
Expand Down Expand Up @@ -258,7 +240,7 @@ GatherMemoryAllocation(
alloc_size = 0;
prop_size = 0;
auto bit_width = arrow::bit_width(dtype->id());
auto visited_arr = RunVisit(prop_field);
RunVisit(prop_field);
for (auto j = 0; j < prop_field->length(); j++) {
if (prop_field->IsValid(j)) {
auto scal_ptr = *prop_field->GetScalar(j);
Expand Down

0 comments on commit 25583b4

Please sign in to comment.