Skip to content

Commit

Permalink
JSON tokenizer memory optimizations (#16978)
Browse files Browse the repository at this point in the history
The full push-down automaton that tokenizes the input JSON string, as well as the bracket-brace FST, over-estimates the total buffer size required for the translated output and indices. This PR splits the `transduce` calls for both FSTs into two invocations. The first invocation estimates the size of the translated buffer and the translated indices, and the second invocation performs the DFA run.

Authors:
  - Shruti Shivakumar (https://github.com/shrshi)
  - Karthikeyan (https://github.com/karthikeyann)

Approvers:
  - Karthikeyan (https://github.com/karthikeyann)
  - Basit Ayantunde (https://github.com/lamarrr)

URL: #16978
  • Loading branch information
shrshi authored Oct 23, 2024
1 parent 27c0c9d commit cff1296
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions cpp/src/io/json/nested_json_gpu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1448,10 +1448,6 @@ void get_stack_context(device_span<SymbolT const> json_in,
// Number of stack operations in the input (i.e., number of '{', '}', '[', ']' outside of quotes)
cudf::detail::device_scalar<SymbolOffsetT> d_num_stack_ops(stream);

// Sequence of stack symbols and their position in the original input (sparse representation)
rmm::device_uvector<StackSymbolT> stack_ops{json_in.size(), stream};
rmm::device_uvector<SymbolOffsetT> stack_op_indices{json_in.size(), stream};

// Prepare finite-state transducer that only selects '{', '}', '[', ']' outside of quotes
constexpr auto max_translation_table_size =
to_stack_op::NUM_SYMBOL_GROUPS * to_stack_op::TT_NUM_STATES;
Expand All @@ -1468,11 +1464,26 @@ void get_stack_context(device_span<SymbolT const> json_in,

// "Search" for relevant occurrence of brackets and braces that indicate the beginning/end
// of structs/lists
// Run FST to estimate the sizes of translated buffers
json_to_stack_ops_fst.Transduce(json_in.begin(),
static_cast<SymbolOffsetT>(json_in.size()),
thrust::make_discard_iterator(),
thrust::make_discard_iterator(),
d_num_stack_ops.data(),
to_stack_op::start_state,
stream);

auto stack_ops_bufsize = d_num_stack_ops.value(stream);
// Sequence of stack symbols and their position in the original input (sparse representation)
rmm::device_uvector<StackSymbolT> stack_ops{stack_ops_bufsize, stream};
rmm::device_uvector<SymbolOffsetT> stack_op_indices{stack_ops_bufsize, stream};

// Run bracket-brace FST to retrieve starting positions of structs and lists
json_to_stack_ops_fst.Transduce(json_in.begin(),
static_cast<SymbolOffsetT>(json_in.size()),
stack_ops.data(),
stack_op_indices.data(),
d_num_stack_ops.data(),
thrust::make_discard_iterator(),
to_stack_op::start_state,
stream);

Expand Down Expand Up @@ -1508,6 +1519,7 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> pr
device_span<SymbolOffsetT const> token_indices,
rmm::cuda_stream_view stream)
{
CUDF_FUNC_RANGE();
// Instantiate FST for post-processing the token stream to remove all tokens that belong to an
// invalid JSON line
token_filter::UnwrapTokenFromSymbolOp sgid_op{};
Expand Down Expand Up @@ -1643,21 +1655,28 @@ std::pair<rmm::device_uvector<PdaTokenT>, rmm::device_uvector<SymbolOffsetT>> ge
// see a JSON-line delimiter as the very first item
SymbolOffsetT const delimiter_offset =
(format == tokenizer_pda::json_format_cfg_t::JSON_LINES_RECOVER ? 1 : 0);
rmm::device_uvector<PdaTokenT> tokens{max_token_out_count + delimiter_offset, stream, mr};
rmm::device_uvector<SymbolOffsetT> tokens_indices{
max_token_out_count + delimiter_offset, stream, mr};

// Run FST to estimate the size of output buffers
json_to_tokens_fst.Transduce(zip_in,
static_cast<SymbolOffsetT>(json_in.size()),
tokens.data() + delimiter_offset,
tokens_indices.data() + delimiter_offset,
thrust::make_discard_iterator(),
thrust::make_discard_iterator(),
num_written_tokens.data(),
tokenizer_pda::start_state,
stream);

auto const num_total_tokens = num_written_tokens.value(stream) + delimiter_offset;
tokens.resize(num_total_tokens, stream);
tokens_indices.resize(num_total_tokens, stream);
rmm::device_uvector<PdaTokenT> tokens{num_total_tokens, stream, mr};
rmm::device_uvector<SymbolOffsetT> tokens_indices{num_total_tokens, stream, mr};

// Run FST to translate the input JSON string into tokens and indices at which they occur
json_to_tokens_fst.Transduce(zip_in,
static_cast<SymbolOffsetT>(json_in.size()),
tokens.data() + delimiter_offset,
tokens_indices.data() + delimiter_offset,
thrust::make_discard_iterator(),
tokenizer_pda::start_state,
stream);

if (delimiter_offset == 1) {
tokens.set_element(0, token_t::LineEnd, stream);
Expand Down

0 comments on commit cff1296

Please sign in to comment.