Skip to content

Commit

Permalink
feat: add StreamWriter and StreamReader
Browse files Browse the repository at this point in the history
  • Loading branch information
SpringMT committed Apr 3, 2024
1 parent 5395b01 commit 5624ae4
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 40 deletions.
60 changes: 46 additions & 14 deletions ext/zstdruby/streaming_compress.c
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ rb_streaming_compress_compress(VALUE obj, VALUE src)

struct streaming_compress_t* sc;
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);

const char* output_data = RSTRING_PTR(sc->buf);
VALUE result = rb_str_new(0, 0);
while (input.pos < input.size) {
Expand All @@ -139,27 +140,54 @@ rb_streaming_compress_compress(VALUE obj, VALUE src)
}

static VALUE
rb_streaming_compress_addstr(VALUE obj, VALUE src)
rb_streaming_compress_write(int argc, VALUE *argv, VALUE obj)
{
StringValue(src);
const char* input_data = RSTRING_PTR(src);
size_t input_size = RSTRING_LEN(src);
ZSTD_inBuffer input = { input_data, input_size, 0 };

size_t total = 0;
VALUE result = rb_str_new(0, 0);
struct streaming_compress_t* sc;
TypedData_Get_Struct(obj, struct streaming_compress_t, &streaming_compress_type, sc);
const char* output_data = RSTRING_PTR(sc->buf);

while (input.pos < input.size) {
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };
size_t const result = ZSTD_compressStream2(sc->ctx, &output, &input, ZSTD_e_continue);
if (ZSTD_isError(result)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(result));
ZSTD_outBuffer output = { (void*)output_data, sc->buf_size, 0 };

while (argc-- > 0) {
VALUE str = *argv++;
StringValue(str);
const char* input_data = RSTRING_PTR(str);
size_t input_size = RSTRING_LEN(str);
ZSTD_inBuffer input = { input_data, input_size, 0 };

while (input.pos < input.size) {
size_t const ret = ZSTD_compressStream2(sc->ctx, &output, &input, ZSTD_e_continue);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
}
total += RSTRING_LEN(str);
}
}
return obj;
return SIZET2NUM(total);
}

/*
* Document-method: <<
* Same as IO.
*/
#define rb_streaming_compress_addstr rb_io_addstr
/*
* Document-method: printf
* Same as IO.
*/
#define rb_streaming_compress_printf rb_io_printf
/*
* Document-method: print
* Same as IO.
*/
#define rb_streaming_compress_print rb_io_print
/*
* Document-method: puts
* Same as IO.
*/
#define rb_streaming_compress_puts rb_io_puts

static VALUE
rb_streaming_compress_flush(VALUE obj)
{
Expand All @@ -186,12 +214,16 @@ zstd_ruby_streaming_compress_init(void)
rb_define_alloc_func(cStreamingCompress, rb_streaming_compress_allocate);
rb_define_method(cStreamingCompress, "initialize", rb_streaming_compress_initialize, -1);
rb_define_method(cStreamingCompress, "compress", rb_streaming_compress_compress, 1);
rb_define_method(cStreamingCompress, "write", rb_streaming_compress_write, -1);
rb_define_method(cStreamingCompress, "<<", rb_streaming_compress_addstr, 1);
rb_define_method(cStreamingCompress, "printf", rb_streaming_compress_printf, -1);
rb_define_method(cStreamingCompress, "print", rb_streaming_compress_print, -1);
rb_define_method(cStreamingCompress, "puts", rb_streaming_compress_puts, -1);

rb_define_method(cStreamingCompress, "flush", rb_streaming_compress_flush, 0);
rb_define_method(cStreamingCompress, "finish", rb_streaming_compress_finish, 0);

rb_define_const(cStreamingCompress, "CONTINUE", INT2FIX(ZSTD_e_continue));
rb_define_const(cStreamingCompress, "FLUSH", INT2FIX(ZSTD_e_flush));
rb_define_const(cStreamingCompress, "END", INT2FIX(ZSTD_e_end));
}

26 changes: 1 addition & 25 deletions ext/zstdruby/streaming_decompress.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,35 +101,13 @@ rb_streaming_decompress_decompress(VALUE obj, VALUE src)
ZSTD_outBuffer output = { (void*)output_data, sd->buf_size, 0 };
size_t const ret = ZSTD_decompressStream(sd->ctx, &output, &input);
if (ZSTD_isError(ret)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(ret));
rb_raise(rb_eRuntimeError, "decompress error error code: %s", ZSTD_getErrorName(ret));
}
rb_str_cat(result, output.dst, output.pos);
}
return result;
}

static VALUE
rb_streaming_decompress_addstr(VALUE obj, VALUE src)
{
StringValue(src);
const char* input_data = RSTRING_PTR(src);
size_t input_size = RSTRING_LEN(src);
ZSTD_inBuffer input = { input_data, input_size, 0 };

struct streaming_decompress_t* sd;
TypedData_Get_Struct(obj, struct streaming_decompress_t, &streaming_decompress_type, sd);
const char* output_data = RSTRING_PTR(sd->buf);

while (input.pos < input.size) {
ZSTD_outBuffer output = { (void*)output_data, sd->buf_size, 0 };
size_t const result = ZSTD_decompressStream(sd->ctx, &output, &input);
if (ZSTD_isError(result)) {
rb_raise(rb_eRuntimeError, "compress error error code: %s", ZSTD_getErrorName(result));
}
}
return obj;
}

extern VALUE rb_mZstd, cStreamingDecompress;
void
zstd_ruby_streaming_decompress_init(void)
Expand All @@ -138,6 +116,4 @@ zstd_ruby_streaming_decompress_init(void)
rb_define_alloc_func(cStreamingDecompress, rb_streaming_decompress_allocate);
rb_define_method(cStreamingDecompress, "initialize", rb_streaming_decompress_initialize, 0);
rb_define_method(cStreamingDecompress, "decompress", rb_streaming_decompress_decompress, 1);
rb_define_method(cStreamingDecompress, "<<", rb_streaming_decompress_addstr, 1);
}

22 changes: 22 additions & 0 deletions lib/zstd-ruby/stream_reader.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
module Zstd
# @todo Exprimental
class StreamReader
def initialize(io)
@io = io
@stream = Zstd::StreamingDecompress.new
end

def read(length)
if @io.eof?
raise StandardError, "EOF"
end
data = @io.read(length)
@stream.decompress(data)
end

def close
@io.write(@stream.finish)
@io.close
end
end
end
23 changes: 23 additions & 0 deletions lib/zstd-ruby/stream_writer.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
module Zstd
# @todo Exprimental
class StreamWriter
def initialize(io, level: nil)
@io = io
@stream = Zstd::StreamingCompress.new(level)
end

def write(*data)
@stream.write(*data)
@io.write(@stream.flush)
end

def finish
@io.write(@stream.finish)
end

def close
@io.write(@stream.finish)
@io.close
end
end
end
23 changes: 23 additions & 0 deletions spec/zstd-ruby-stream_reader_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
require "spec_helper"
require 'zstd-ruby'
require 'zstd-ruby/stream_writer'
require 'zstd-ruby/stream_reader'
require 'pry'

RSpec.describe Zstd::StreamReader do
describe 'read' do
it 'shoud work' do
io = StringIO.new
writer = Zstd::StreamWriter.new(io)
writer.write("abc")
writer.write("def")
writer.finish
io.rewind

reader = Zstd::StreamReader.new(io)
expect(reader.read(10)).to eq('a')
expect(reader.read(10)).to eq('bcdef')
expect(reader.read(10)).to eq('')
end
end
end
17 changes: 17 additions & 0 deletions spec/zstd-ruby-stream_writer_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
require "spec_helper"
require 'zstd-ruby'
require 'zstd-ruby/stream_writer'

RSpec.describe Zstd::StreamWriter do
describe 'write' do
it 'shoud work' do
io = StringIO.new
stream = Zstd::StreamWriter.new(io)
stream.write("abc")
stream.write("def")
stream.finish
io.rewind
expect(Zstd.decompress(io.read)).to eq('abcdef')
end
end
end
1 change: 0 additions & 1 deletion spec/zstd-ruby-streaming-compress_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,3 @@
end
end
end

7 changes: 7 additions & 0 deletions spec/zstd-ruby_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def to_str
expect(decompressed).to eq('')
end

it 'should work for non-ascii string' do
compressed = Zstd.compress('あああ')
expect(compressed.bytesize).to eq(18)
decompressed = Zstd.decompress(compressed)
expect(decompressed.force_encoding('UTF-8')).to eq('あああ')
end

it 'should raise exception with unsupported object' do
expect { Zstd.decompress(Object.new) }.to raise_error(TypeError)
end
Expand Down
1 change: 1 addition & 0 deletions zstd-ruby.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ Gem::Specification.new do |spec|
spec.add_development_dependency "rake", "~> 13.0"
spec.add_development_dependency "rake-compiler", '~> 1'
spec.add_development_dependency "rspec", "~> 3.0"
spec.add_development_dependency "pry"
end

0 comments on commit 5624ae4

Please sign in to comment.