diff --git a/ext/zstdruby/common.h b/ext/zstdruby/common.h index 5971242..7eebe36 100644 --- a/ext/zstdruby/common.h +++ b/ext/zstdruby/common.h @@ -44,7 +44,7 @@ static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VAL } } -struct compress_params { +struct stream_compress_params { ZSTD_CCtx* ctx; ZSTD_outBuffer* output; ZSTD_inBuffer* input; @@ -52,21 +52,21 @@ struct compress_params { size_t ret; }; -static void* compress_wrapper(void* args) +static void* stream_compress_wrapper(void* args) { - struct compress_params* params = args; + struct stream_compress_params* params = args; params->ret = ZSTD_compressStream2(params->ctx, params->output, params->input, params->endOp); return NULL; } -static size_t zstd_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp, bool gvl) +static size_t zstd_stream_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp, bool gvl) { #ifdef HAVE_RUBY_THREAD_H if (gvl) { return ZSTD_compressStream2(ctx, output, input, endOp); } else { - struct compress_params params = { ctx, output, input, endOp }; - rb_thread_call_without_gvl(compress_wrapper, ¶ms, NULL, NULL); + struct stream_compress_params params = { ctx, output, input, endOp }; + rb_thread_call_without_gvl(stream_compress_wrapper, ¶ms, NULL, NULL); return params.ret; } #else diff --git a/ext/zstdruby/streaming_compress.c b/ext/zstdruby/streaming_compress.c index 6628510..af01a3a 100644 --- a/ext/zstdruby/streaming_compress.c +++ b/ext/zstdruby/streaming_compress.c @@ -106,7 +106,7 @@ no_compress(struct streaming_compress_t* sc, ZSTD_EndDirective endOp) do { ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 }; - size_t const ret = zstd_compress(sc->ctx, &output, &input, endOp, false); + size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, endOp, false); if (ZSTD_isError(ret)) { rb_raise(rb_eRuntimeError, "flush error error code: %s", ZSTD_getErrorName(ret)); } @@ -130,7 +130,7 @@ rb_streaming_compress_compress(VALUE obj, VALUE src) VALUE result = rb_str_new(0, 0); while (input.pos < input.size) { ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 }; - size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false); + size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false); if (ZSTD_isError(ret)) { rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret)); } @@ -157,7 +157,7 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj) while (input.pos < input.size) { ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 }; - size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false); + size_t const ret = zstd_stream_compress(sc->ctx, &output, &input, ZSTD_e_continue, false); if (ZSTD_isError(ret)) { rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret)); } diff --git a/ext/zstdruby/zstdruby.c b/ext/zstdruby/zstdruby.c index 3b5051b..b6b4d8f 100644 --- a/ext/zstdruby/zstdruby.c +++ b/ext/zstdruby/zstdruby.c @@ -25,22 +25,19 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self) StringValue(input_value); char* input_data = RSTRING_PTR(input_value); size_t input_size = RSTRING_LEN(input_value); - ZSTD_inBuffer input = { input_data, input_size, 0 }; - // ZSTD_compressBound causes SEGV under multi-thread - size_t max_compressed_size = ZSTD_compressBound(input_size); - VALUE buf = rb_str_new(NULL, max_compressed_size); - char* output_data = RSTRING_PTR(buf); - ZSTD_outBuffer output = { (void*)output_data, max_compressed_size, 0 }; - size_t const ret = zstd_compress(ctx, &output, &input, ZSTD_e_end, true); + size_t const max_compressed_size = ZSTD_compressBound(input_size); + VALUE output = rb_str_new(NULL, max_compressed_size); + const char* output_data = RSTRING_PTR(output); + + size_t const ret = ZSTD_compress2(ctx,(void*)output_data, max_compressed_size, (void*)input_data, input_size); if (ZSTD_isError(ret)) { - ZSTD_freeCCtx(ctx); - rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(ret)); + rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret)); } - VALUE result = rb_str_new(0, 0); - rb_str_cat(result, output.dst, output.pos); + rb_str_resize(output, ret); + ZSTD_freeCCtx(ctx); - return result; + return output; } static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self) diff --git a/spec/zstd-ruby_spec.rb b/spec/zstd-ruby_spec.rb index 5569a7b..8f80f65 100644 --- a/spec/zstd-ruby_spec.rb +++ b/spec/zstd-ruby_spec.rb @@ -40,6 +40,12 @@ expect(compressed_default.length).to be < compressed_with_arg.length end + it 'should compress large bytes' do + large_string = Random.bytes(1<<17 + 15) + compressed = Zstd.compress(large_string) + expect(Zstd.decompress(compressed)).to eq(large_string) + end + it 'should raise exception with unsupported object' do expect { Zstd.compress(Object.new) }.to raise_error(TypeError) end