Skip to content

Commit

Permalink
Merge pull request #79 from SpringMT/feature/unlock-gvl
Browse files Browse the repository at this point in the history
Feature/unlock gvl for streaming compression/decompression
  • Loading branch information
SpringMT authored Apr 13, 2024
2 parents 788f4f5 + 04d74fc commit cc225b4
Show file tree
Hide file tree
Showing 13 changed files with 142 additions and 58 deletions.
13 changes: 11 additions & 2 deletions .github/workflows/ruby.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@

name: Ruby

on: [push, pull_request]
on:
push:
branches:
- main
paths-ignore:
- 'README.md'
pull_request:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
test:

runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
require 'benchmark/ips'
$LOAD_PATH.unshift '../lib'
require 'json'
require 'objspace'
require 'zstd-ruby'
require 'thread'

Expand All @@ -19,7 +16,11 @@
THREADS.times.map {
Thread.new {
while str = queue.pop
Zstd.compress(str)
stream = Zstd::StreamingCompress.new
stream << str
res = stream.flush
stream << str
res << stream.finish
end
}
}.each(&:join)
25 changes: 25 additions & 0 deletions benchmarks/multi_thread_streaming_decomporess.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
$LOAD_PATH.unshift '../lib'
require 'zstd-ruby'
require 'thread'

# Multi-threaded streaming-decompression benchmark.
#
# Reads ./samples/<ARGV[0]>, compresses it once with Zstd.compress, and then
# lets THREADS worker threads drain a shared queue holding GUESSES copies of
# the compressed data. Each job is decompressed twice through a fresh
# Zstd::StreamingDecompress instance.
#
# Environment:
#   GUESSES - number of jobs pushed onto the queue (default 1000)
#   THREADS - number of worker threads (default 1)
GUESSES = ENV.fetch('GUESSES', 1000).to_i
THREADS = ENV.fetch('THREADS', 1).to_i

p GUESSES: GUESSES, THREADS: THREADS

sample_path = "./samples/#{ARGV[0]}"
compressed = Zstd.compress(File.read(sample_path))

jobs = Queue.new
GUESSES.times { jobs.push(compressed) }
# One nil per worker: popping it makes that worker's loop condition falsy.
THREADS.times { jobs.push(nil) }

workers = THREADS.times.map do
  Thread.new do
    while (frame = jobs.pop)
      decoder = Zstd::StreamingDecompress.new
      2.times { decoder.decompress(frame) }
    end
  end
end
workers.each(&:join)
2 changes: 1 addition & 1 deletion benchmarks/zstd_compress_memory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

sample_file_name = ARGV[0]

json_data = JSON.parse(IO.read("./samples/#{sample_file_name}"), symbolize_names: true)
json_data = JSON.parse(File.read("./samples/#{sample_file_name}"), symbolize_names: true)
json_string = json_data.to_json

i = 0
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/zstd_decompress_memory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
p "#{ObjectSpace.memsize_of_all/1000} #{ObjectSpace.count_objects} #{`ps -o rss= -p #{Process.pid}`.to_i}"

sample_file_name = ARGV[0]
json_data = JSON.parse(IO.read("./samples/#{sample_file_name}"), symbolize_names: true)
json_data = JSON.parse(File.read("./samples/#{sample_file_name}"), symbolize_names: true)
json_string = json_data.to_json

i = 0
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/zstd_streaming_compress_memory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

sample_file_name = ARGV[0]

json_string = IO.read("./samples/#{sample_file_name}")
json_string = File.read("./samples/#{sample_file_name}")

i = 0
start_time = Time.now
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/zstd_streaming_decompress_memory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

sample_file_name = ARGV[0]

cstr = IO.read("./results/#{sample_file_name}.zstd")
cstr = File.read("./results/#{sample_file_name}.zstd")
i = 0
start_time = Time.now
while true do
Expand Down
2 changes: 1 addition & 1 deletion examples/sinatra/Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: ../..
specs:
zstd-ruby (1.5.6.1)
zstd-ruby (1.5.6.2)

GEM
remote: https://rubygems.org/
Expand Down
68 changes: 63 additions & 5 deletions ext/zstdruby/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#define ZSTD_RUBY_H 1

#include <ruby.h>
#ifdef HAVE_RUBY_THREAD_H
#include <ruby/thread.h>
#endif
#include <stdbool.h>
#include "./libzstd/zstd.h"

static int convert_compression_level(VALUE compression_level_value)
Expand All @@ -12,11 +16,6 @@ static int convert_compression_level(VALUE compression_level_value)
return NUM2INT(compression_level_value);
}

static size_t zstd_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp)
{
return ZSTD_compressStream2(ctx, output, input, endOp);
}

static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VALUE kwargs)
{
ID kwargs_keys[2];
Expand Down Expand Up @@ -45,6 +44,36 @@ static void set_compress_params(ZSTD_CCtx* const ctx, VALUE level_from_args, VAL
}
}

/*
 * Argument/result bundle for compress_wrapper(), which receives it through a
 * single void* parameter. The first four fields are forwarded unchanged to
 * ZSTD_compressStream2(); `ret` carries that call's return value back to
 * zstd_compress().
 */
struct compress_params {
ZSTD_CCtx* ctx; /* compression context passed to ZSTD_compressStream2() */
ZSTD_outBuffer* output; /* destination buffer descriptor */
ZSTD_inBuffer* input; /* source buffer descriptor */
ZSTD_EndDirective endOp; /* end directive forwarded to ZSTD_compressStream2() */
size_t ret; /* written by compress_wrapper(): return value of ZSTD_compressStream2() */
};

/*
 * Adapter with the void* (*)(void*) shape expected by
 * rb_thread_call_without_gvl().
 *
 * `args` must point to a struct compress_params. The bundled arguments are
 * passed to ZSTD_compressStream2() and its return value is stored in the
 * bundle's `ret` field. The function's own return value is always NULL.
 */
static void* compress_wrapper(void* args)
{
    struct compress_params* const call = args;
    size_t const status = ZSTD_compressStream2(call->ctx, call->output, call->input, call->endOp);

    call->ret = status;
    return NULL;
}

/*
 * Run one ZSTD_compressStream2() step.
 *
 * ctx, output, input, endOp: forwarded unchanged to ZSTD_compressStream2().
 * gvl: when true, ZSTD_compressStream2() is called directly on the current
 *      thread. When false and HAVE_RUBY_THREAD_H is defined, the call is
 *      routed through rb_thread_call_without_gvl() via compress_wrapper().
 *      Without HAVE_RUBY_THREAD_H the flag has no effect and the call is
 *      always direct.
 *
 * Returns whatever ZSTD_compressStream2() returned; callers check it with
 * ZSTD_isError().
 */
static size_t zstd_compress(ZSTD_CCtx* const ctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, ZSTD_EndDirective endOp, bool gvl)
{
#ifdef HAVE_RUBY_THREAD_H
    if (!gvl) {
        /* `ret` is not listed, so it starts out zero-initialized. */
        struct compress_params params = {
            .ctx = ctx,
            .output = output,
            .input = input,
            .endOp = endOp
        };
        /* No unblock function is registered (last two arguments are NULL). */
        rb_thread_call_without_gvl(compress_wrapper, &params, NULL, NULL);
        return params.ret;
    }
#endif
    return ZSTD_compressStream2(ctx, output, input, endOp);
}

static void set_decompress_params(ZSTD_DCtx* const dctx, VALUE kwargs)
{
ID kwargs_keys[1];
Expand All @@ -63,4 +92,33 @@ static void set_decompress_params(ZSTD_DCtx* const dctx, VALUE kwargs)
}
}

/*
 * Argument/result bundle for decompress_wrapper(), which receives it through
 * a single void* parameter. The first three fields are forwarded unchanged to
 * ZSTD_decompressStream(); `ret` carries that call's return value back to
 * zstd_decompress().
 */
struct decompress_params {
ZSTD_DCtx* dctx; /* decompression context passed to ZSTD_decompressStream() */
ZSTD_outBuffer* output; /* destination buffer descriptor */
ZSTD_inBuffer* input; /* source buffer descriptor */
size_t ret; /* written by decompress_wrapper(): return value of ZSTD_decompressStream() */
};

/*
 * Adapter with the void* (*)(void*) shape expected by
 * rb_thread_call_without_gvl().
 *
 * `args` must point to a struct decompress_params. The bundled arguments are
 * passed to ZSTD_decompressStream() and its return value is stored in the
 * bundle's `ret` field. The function's own return value is always NULL.
 */
static void* decompress_wrapper(void* args)
{
    struct decompress_params* const call = args;
    size_t const status = ZSTD_decompressStream(call->dctx, call->output, call->input);

    call->ret = status;
    return NULL;
}

/*
 * Run one ZSTD_decompressStream() step.
 *
 * dctx, output, input: forwarded unchanged to ZSTD_decompressStream().
 * gvl: when true, ZSTD_decompressStream() is called directly on the current
 *      thread. When false and HAVE_RUBY_THREAD_H is defined, the call is
 *      routed through rb_thread_call_without_gvl() via decompress_wrapper().
 *      Without HAVE_RUBY_THREAD_H the flag has no effect and the call is
 *      always direct.
 *
 * Returns whatever ZSTD_decompressStream() returned; callers check it with
 * ZSTD_isError().
 */
static size_t zstd_decompress(ZSTD_DCtx* const dctx, ZSTD_outBuffer* output, ZSTD_inBuffer* input, bool gvl)
{
#ifdef HAVE_RUBY_THREAD_H
    if (!gvl) {
        /* `ret` is not listed, so it starts out zero-initialized. */
        struct decompress_params params = {
            .dctx = dctx,
            .output = output,
            .input = input
        };
        /* No unblock function is registered (last two arguments are NULL). */
        rb_thread_call_without_gvl(decompress_wrapper, &params, NULL, NULL);
        return params.ret;
    }
#endif
    return ZSTD_decompressStream(dctx, output, input);
}

#endif /* ZSTD_RUBY_H */
2 changes: 1 addition & 1 deletion ext/zstdruby/extconf.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

have_func('rb_gc_mark_movable')

$CFLAGS = '-I. -O3 -std=c99 -DZSTD_STATIC_LINKING_ONLY'
$CFLAGS = '-I. -O3 -std=c99 -DZSTD_STATIC_LINKING_ONLY -DZSTD_MULTITHREAD -pthread -DDEBUGLEVEL=0'
$CPPFLAGS += " -fdeclspec" if CONFIG['CXX'] =~ /clang/

Dir.chdir File.expand_path('..', __FILE__) do
Expand Down
6 changes: 3 additions & 3 deletions ext/zstdruby/streaming_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ no_compress(struct streaming_compress_t* sc, ZSTD_EndDirective endOp)
do {
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };

size_t const ret = zstd_compress(sc->ctx, &output, &input, endOp);
size_t const ret = zstd_compress(sc->ctx, &output, &input, endOp, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "flush error error code: %s", ZSTD_getErrorName(ret));
}
Expand All @@ -130,7 +130,7 @@ rb_streaming_compress_compress(VALUE obj, VALUE src)
VALUE result = rb_str_new(0, 0);
while (input.pos < input.size) {
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue);
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
}
Expand All @@ -157,7 +157,7 @@ rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)

while (input.pos < input.size) {
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue);
size_t const ret = zstd_compress(sc->ctx, &output, &input, ZSTD_e_continue, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
}
Expand Down
2 changes: 1 addition & 1 deletion ext/zstdruby/streaming_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ rb_streaming_decompress_decompress(VALUE obj, VALUE src)
VALUE result = rb_str_new(0, 0);
while (input.pos < input.size) {
ZSTD_outBuffer output = { (void*)output_data, sd->buf_size, 0 };
size_t const ret = ZSTD_decompressStream(sd->dctx, &output, &input);
size_t const ret = zstd_decompress(sd->dctx, &output, &input, false);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "decompress error error code: %s", ZSTD_getErrorName(ret));
}
Expand Down
65 changes: 28 additions & 37 deletions ext/zstdruby/zstdruby.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ static VALUE rb_compress(int argc, VALUE *argv, VALUE self)
char* input_data = RSTRING_PTR(input_value);
size_t input_size = RSTRING_LEN(input_value);
ZSTD_inBuffer input = { input_data, input_size, 0 };
// ZSTD_compressBound causes SEGV under multi-thread
size_t max_compressed_size = ZSTD_compressBound(input_size);
VALUE buf = rb_str_new(NULL, max_compressed_size);
char* output_data = RSTRING_PTR(buf);
ZSTD_outBuffer output = { (void*)output_data, max_compressed_size, 0 };

size_t const ret = zstd_compress(ctx, &output, &input, ZSTD_e_end);
size_t const ret = zstd_compress(ctx, &output, &input, ZSTD_e_end, true);
if (ZSTD_isError(ret)) {
ZSTD_freeCCtx(ctx);
rb_raise(rb_eRuntimeError, "%s: %s", "compress failed", ZSTD_getErrorName(ret));
Expand Down Expand Up @@ -87,19 +88,8 @@ static VALUE rb_compress_using_dict(int argc, VALUE *argv, VALUE self)
}


static VALUE decompress_buffered(const char* input_data, size_t input_size)
static VALUE decompress_buffered(ZSTD_DCtx* dctx, const char* input_data, size_t input_size)
{
ZSTD_DStream* const dstream = ZSTD_createDStream();
if (dstream == NULL) {
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDStream failed");
}

size_t initResult = ZSTD_initDStream(dstream);
if (ZSTD_isError(initResult)) {
ZSTD_freeDStream(dstream);
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_initDStream failed", ZSTD_getErrorName(initResult));
}

VALUE output_string = rb_str_new(NULL, 0);
ZSTD_outBuffer output = { NULL, 0, 0 };

Expand All @@ -109,15 +99,14 @@ static VALUE decompress_buffered(const char* input_data, size_t input_size)
rb_str_resize(output_string, output.size);
output.dst = RSTRING_PTR(output_string);

size_t readHint = ZSTD_decompressStream(dstream, &output, &input);
if (ZSTD_isError(readHint)) {
ZSTD_freeDStream(dstream);
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(readHint));
size_t ret = zstd_decompress(dctx, &output, &input, true);
if (ZSTD_isError(ret)) {
ZSTD_freeDCtx(dctx);
rb_raise(rb_eRuntimeError, "%s: %s", "ZSTD_decompressStream failed", ZSTD_getErrorName(ret));
}
}

ZSTD_freeDStream(dstream);
rb_str_resize(output_string, output.pos);
ZSTD_freeDCtx(dctx);
return output_string;
}

Expand All @@ -129,6 +118,11 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
StringValue(input_value);
char* input_data = RSTRING_PTR(input_value);
size_t input_size = RSTRING_LEN(input_value);
ZSTD_DCtx* const dctx = ZSTD_createDCtx();
if (dctx == NULL) {
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
}
set_decompress_params(dctx, kwargs);

unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
Expand All @@ -137,15 +131,9 @@ static VALUE rb_decompress(int argc, VALUE *argv, VALUE self)
// ZSTD_decompressStream may be called multiple times when ZSTD_CONTENTSIZE_UNKNOWN, causing slowness.
// Therefore, we will not standardize on ZSTD_decompressStream
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
return decompress_buffered(input_data, input_size);
return decompress_buffered(dctx, input_data, input_size);
}

ZSTD_DCtx* const dctx = ZSTD_createDCtx();
if (dctx == NULL) {
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
}
set_decompress_params(dctx, kwargs);

VALUE output = rb_str_new(NULL, uncompressed_size);
char* output_data = RSTRING_PTR(output);

Expand All @@ -167,35 +155,38 @@ static VALUE rb_decompress_using_dict(int argc, VALUE *argv, VALUE self)
StringValue(input_value);
char* input_data = RSTRING_PTR(input_value);
size_t input_size = RSTRING_LEN(input_value);
unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
}
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
return decompress_buffered(input_data, input_size);
}
VALUE output = rb_str_new(NULL, uncompressed_size);
char* output_data = RSTRING_PTR(output);

char* dict_buffer = RSTRING_PTR(dict);
size_t dict_size = RSTRING_LEN(dict);
ZSTD_DDict* const ddict = ZSTD_createDDict(dict_buffer, dict_size);
if (ddict == NULL) {
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDDict failed");
}

unsigned const expected_dict_id = ZSTD_getDictID_fromDDict(ddict);
unsigned const actual_dict_id = ZSTD_getDictID_fromFrame(input_data, input_size);
if (expected_dict_id != actual_dict_id) {
ZSTD_freeDDict(ddict);
rb_raise(rb_eRuntimeError, "%s: %s", "DictID mismatch", ZSTD_getErrorName(uncompressed_size));
rb_raise(rb_eRuntimeError, "DictID mismatch");
}

ZSTD_DCtx* const ctx = ZSTD_createDCtx();
if (ctx == NULL) {
ZSTD_freeDDict(ddict);
rb_raise(rb_eRuntimeError, "%s", "ZSTD_createDCtx failed");
}

unsigned long long const uncompressed_size = ZSTD_getFrameContentSize(input_data, input_size);
if (uncompressed_size == ZSTD_CONTENTSIZE_ERROR) {
ZSTD_freeDDict(ddict);
ZSTD_freeDCtx(ctx);
rb_raise(rb_eRuntimeError, "%s: %s", "not compressed by zstd", ZSTD_getErrorName(uncompressed_size));
}
if (uncompressed_size == ZSTD_CONTENTSIZE_UNKNOWN) {
return decompress_buffered(ctx, input_data, input_size);
}

VALUE output = rb_str_new(NULL, uncompressed_size);
char* output_data = RSTRING_PTR(output);
size_t const decompress_size = ZSTD_decompress_usingDDict(ctx, output_data, uncompressed_size, input_data, input_size, ddict);
if (ZSTD_isError(decompress_size)) {
ZSTD_freeDDict(ddict);
Expand Down

0 comments on commit cc225b4

Please sign in to comment.