Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
JayjeetAtGithub committed Aug 12, 2024
1 parent 4080c1c commit bddda71
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions cpp/examples/tpch/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <cudf/transform.hpp>
#include <cudf/unary.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
Expand Down Expand Up @@ -536,24 +537,37 @@ void print_hardware_stats()
std::cout << std::endl;
}

cudf::io::source_info get_host_buff_data_source(std::unique_ptr<cudf::table> table,
std::vector<std::string> const& col_names)
cudf::io::source_info get_device_source(std::unique_ptr<cudf::table> table,
std::vector<std::string> const& col_names)
{
CUDF_FUNC_RANGE();
auto const stream = cudf::get_default_stream();

// Prepare the table metadata
cudf::io::table_metadata metadata;
std::vector<cudf::io::column_name_info> col_name_infos;
for (auto& col_name : col_names) {
col_name_infos.push_back(cudf::io::column_name_info(col_name));
}
metadata.schema_info = col_name_infos;
auto const table_input_metadata = cudf::io::table_input_metadata{metadata};

// Declare a host and device buffer
std::vector<char> h_buffer;
rmm::device_uvector<std::byte> d_buffer{0, stream};

// Write parquet data to host buffer
auto builder =
cudf::io::parquet_writer_options::builder(cudf::io::sink_info(&h_buffer), table->view());
builder.metadata(table_input_metadata);
auto const options = builder.build();
cudf::io::write_parquet(options);
return std::move(cudf::io::source_info(h_buffer.data(), h_buffer.size()));

// Copy host buffer to device buffer
d_buffer.resize(h_buffer.size(), stream);
CUDF_CUDA_TRY(cudaMemcpyAsync(
d_buffer.data(), h_buffer.data(), h_buffer.size(), cudaMemcpyDefault, stream.value()));
return cudf::io::source_info(d_buffer);
}

std::unordered_map<std::string, cudf::io::source_info> generate_data_sources(
Expand All @@ -579,22 +593,17 @@ std::unordered_map<std::string, cudf::io::source_info> generate_data_sources(

auto region = cudf::datagen::generate_region(cudf::get_default_stream(),
rmm::mr::get_current_device_resource());
std::cout << "X";
sources["orders"] = get_host_buff_data_source(std::move(orders), cudf::datagen::schema::ORDERS);
sources["lineitem"] =
get_host_buff_data_source(std::move(lineitem), cudf::datagen::schema::LINEITEM);
sources["part"] = get_host_buff_data_source(std::move(part), cudf::datagen::schema::PART);
sources["partsupp"] =
get_host_buff_data_source(std::move(partsupp), cudf::datagen::schema::PARTSUPP);
sources["supplier"] =
get_host_buff_data_source(std::move(supplier), cudf::datagen::schema::SUPPLIER);
sources["customer"] =
get_host_buff_data_source(std::move(customer), cudf::datagen::schema::CUSTOMER);
sources["nation"] = get_host_buff_data_source(std::move(nation), cudf::datagen::schema::NATION);
sources["region"] = get_host_buff_data_source(std::move(region), cudf::datagen::schema::REGION);
auto x = sources["region"].host_buffers().size();
std::cout << x;

// sources["orders"] = get_device_source(std::move(orders), cudf::datagen::schema::ORDERS);
sources["lineitem"] =
std::move(get_device_source(std::move(lineitem), cudf::datagen::schema::LINEITEM));
// sources["part"] = get_device_source(std::move(part), cudf::datagen::schema::PART);
// sources["partsupp"] = get_device_source(std::move(partsupp),
// cudf::datagen::schema::PARTSUPP); sources["supplier"] =
// get_device_source(std::move(supplier), cudf::datagen::schema::SUPPLIER); sources["customer"]
// = get_device_source(std::move(customer), cudf::datagen::schema::CUSTOMER); sources["nation"]
// = get_device_source(std::move(nation), cudf::datagen::schema::NATION); sources["region"] =
// get_device_source(std::move(region), cudf::datagen::schema::REGION);
} else {
sources["orders"] = cudf::io::source_info(dataset_source + "/orders.parquet");
sources["lineitem"] = cudf::io::source_info(dataset_source + "/lineitem.parquet");
Expand Down

0 comments on commit bddda71

Please sign in to comment.