Skip to content

Commit

Permalink
Merge pull request #83 from SpringMT/fix/compression-crash
Browse files Browse the repository at this point in the history
fix: compression crash
  • Loading branch information
SpringMT authored Apr 16, 2024
2 parents aeb79df + 8dce322 commit dc8c929
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 21 deletions.
12 changes: 6 additions & 6 deletions ext/zstdruby/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,29 +44,29 @@ static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VAL
}
}

struct compress_params {
struct stream_compress_params {
ZSTD_CCtx* ctx;
ZSTD_outBuffer* output;
ZSTD_inBuffer* input;
ZSTD_EndDirective endOp;
size_t ret;
};

static void* compress_wrapper(void* args)
static void* stream_compress_wrapper(void* args)
{
struct compress_params* params = args;
struct stream_compress_params* params = args;
params->ret = ZSTD_compressStream2(params->ctx, params->output, params->input, params->endOp);
return NULL;
}

static size_t zstd_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp, bool gvl)
static size_t zstd_stream_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp, bool gvl)
{
#ifdef HAVE_RUBY_THREAD_H
if (gvl) {
return ZSTD_compressStream2(ctx, output, input, endOp);
} else {
struct compress_params params = { ctx, output, input, endOp };
rb_thread_call_without_gvl(compress_wrapper, &params, NULL, NULL);
struct stream_compress_params params = { ctx, output, input, endOp };
rb_thread_call_without_gvl(stream_compress_wrapper, &params, NULL, NULL);
return params.ret;
}
#else
Expand Down
6 changes: 3 additions & 3 deletions ext/zstdruby/streaming_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ no_compress(struct streaming_compress_t* sc, ZSTD_EndDirective endOp)
do {
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };

size_t const ret = zstd_compress(sc->ctx, &output, &input, endOp, false);
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "flush error error code: %s", ZSTD_getErrorName(ret));
}
Expand All @@ -130,7 +130,7 @@ rb_streaming_compress_compress(VALUE obj, VALUE src)
VALUE result = rb_str_new(0, 0);
while (input.pos < input.size) {
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
}
Expand All @@ -157,7 +157,7 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)

while (input.pos < input.size) {
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
}
Expand Down
21 changes: 9 additions & 12 deletions ext/zstdruby/zstdruby.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,19 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
StringValue(input_value);
char* input_data = RSTRING_PTR(input_value);
size_t input_size = RSTRING_LEN(input_value);
ZSTD_inBuffer input = { input_data, input_size, 0 };
// ZSTD_compressBound causes SEGV under multi-thread
size_t max_compressed_size = ZSTD_compressBound(input_size);
VALUE buf = rb_str_new(NULL, max_compressed_size);
char* output_data = RSTRING_PTR(buf);
ZSTD_outBuffer output = { (void*)output_data, max_compressed_size, 0 };

size_t const ret = zstd_compress(ctx, &output, &input, ZSTD_e_end, true);
size_t const max_compressed_size = ZSTD_compressBound(input_size);
VALUE output = rb_str_new(NULL, max_compressed_size);
const char* output_data = RSTRING_PTR(output);

size_t const ret = ZSTD_compress2(ctx,(void*)output_data, max_compressed_size, (void*)input_data, input_size);
if (ZSTD_isError(ret)) {
ZSTD_freeCCtx(ctx);
rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(ret));
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
}
VALUE result = rb_str_new(0, 0);
rb_str_cat(result, output.dst, output.pos);
rb_str_resize(output, ret);

ZSTD_freeCCtx(ctx);
return result;
return output;
}

static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self)
Expand Down
6 changes: 6 additions & 0 deletions spec/zstd-ruby_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@
expect(compressed_default.length).to be < compressed_with_arg.length
end

it 'should compress large bytes' do
large_string = Random.bytes(1<<17 + 15)
compressed = Zstd.compress(large_string)
expect(Zstd.decompress(compressed)).to eq(large_string)
end

it 'should raise exception with unsupported object' do
expect { Zstd.compress(Object.new) }.to raise_error(TypeError)
end
Expand Down

0 comments on commit dc8c929

Please sign in to comment.