diff options
Diffstat (limited to 'src/serialization.cpp')
-rw-r--r-- | src/serialization.cpp | 135 |
1 files changed, 129 insertions, 6 deletions
diff --git a/src/serialization.cpp b/src/serialization.cpp index 310604f54..b6ce3b37f 100644 --- a/src/serialization.cpp +++ b/src/serialization.cpp @@ -21,7 +21,8 @@ with this program; if not, write to the Free Software Foundation, Inc., #include "util/serialize.h" -#include "zlib.h" +#include <zlib.h> +#include <zstd.h> /* report a zlib or i/o error */ void zerr(int ret) @@ -197,27 +198,133 @@ void decompressZlib(std::istream &is, std::ostream &os, size_t limit) inflateEnd(&z); } -void compress(const SharedBuffer<u8> &data, std::ostream &os, u8 version) +struct ZSTD_Deleter { + void operator() (ZSTD_CStream* cstream) { + ZSTD_freeCStream(cstream); + } + + void operator() (ZSTD_DStream* dstream) { + ZSTD_freeDStream(dstream); + } +}; + +void compressZstd(const u8 *data, size_t data_size, std::ostream &os, int level) +{ + // reusing the context is recommended for performance + // it will destroyed when the thread ends + thread_local std::unique_ptr<ZSTD_CStream, ZSTD_Deleter> stream(ZSTD_createCStream()); + + ZSTD_initCStream(stream.get(), level); + + const size_t bufsize = 16384; + char output_buffer[bufsize]; + + ZSTD_inBuffer input = { data, data_size, 0 }; + ZSTD_outBuffer output = { output_buffer, bufsize, 0 }; + + while (input.pos < input.size) { + size_t ret = ZSTD_compressStream(stream.get(), &output, &input); + if (ZSTD_isError(ret)) { + dstream << ZSTD_getErrorName(ret) << std::endl; + throw SerializationError("compressZstd: failed"); + } + if (output.pos) { + os.write(output_buffer, output.pos); + output.pos = 0; + } + } + + size_t ret; + do { + ret = ZSTD_endStream(stream.get(), &output); + if (ZSTD_isError(ret)) { + dstream << ZSTD_getErrorName(ret) << std::endl; + throw SerializationError("compressZstd: failed"); + } + if (output.pos) { + os.write(output_buffer, output.pos); + output.pos = 0; + } + } while (ret != 0); + +} + +void compressZstd(const std::string &data, std::ostream &os, int level) { + compressZstd((u8*)data.c_str(), data.size(), os, level); +} + +void decompressZstd(std::istream &is, std::ostream &os) +{ + // reusing the context is recommended for performance + // it will destroyed when the thread ends + thread_local std::unique_ptr<ZSTD_DStream, ZSTD_Deleter> stream(ZSTD_createDStream()); + + ZSTD_initDStream(stream.get()); + + const size_t bufsize = 16384; + char output_buffer[bufsize]; + char input_buffer[bufsize]; + + ZSTD_outBuffer output = { output_buffer, bufsize, 0 }; + ZSTD_inBuffer input = { input_buffer, 0, 0 }; + size_t ret; + do + { + if (input.size == input.pos) { + is.read(input_buffer, bufsize); + input.size = is.gcount(); + input.pos = 0; + } + + ret = ZSTD_decompressStream(stream.get(), &output, &input); + if (ZSTD_isError(ret)) { + dstream << ZSTD_getErrorName(ret) << std::endl; + throw SerializationError("decompressZstd: failed"); + } + if (output.pos) { + os.write(output_buffer, output.pos); + output.pos = 0; + } + } while (ret != 0); + + // Unget all the data that ZSTD_decompressStream didn't take + is.clear(); // Just in case EOF is set + for (u32 i = 0; i < input.size - input.pos; i++) { + is.unget(); + if (is.fail() || is.bad()) + throw SerializationError("decompressZstd: unget failed"); + } +} + +void compress(u8 *data, u32 size, std::ostream &os, u8 version, int level) +{ + if(version >= 29) + { + // map the zlib levels [0,9] to [1,10]. -1 becomes 0 which indicates the default (currently 3) + compressZstd(data, size, os, level + 1); + return; + } + if(version >= 11) { - compressZlib(*data ,data.getSize(), os); + compressZlib(data, size, os, level); return; } - if(data.getSize() == 0) + if(size == 0) return; // Write length (u32) u8 tmp[4]; - writeU32(tmp, data.getSize()); + writeU32(tmp, size); os.write((char*)tmp, 4); // We will be writing 8-bit pairs of more_count and byte u8 more_count = 0; u8 current_byte = data[0]; - for(u32 i=1; i<data.getSize(); i++) + for(u32 i=1; i<size; i++) { if( data[i] != current_byte @@ -240,8 +347,24 @@ void compress(const SharedBuffer<u8> &data, std::ostream &os, u8 version) os.write((char*)¤t_byte, 1); } +void compress(const SharedBuffer<u8> &data, std::ostream &os, u8 version, int level) +{ + compress(*data, data.getSize(), os, version, level); +} + +void compress(const std::string &data, std::ostream &os, u8 version, int level) +{ + compress((u8*)data.c_str(), data.size(), os, version, level); +} + void decompress(std::istream &is, std::ostream &os, u8 version) { + if(version >= 29) + { + decompressZstd(is, os); + return; + } + if(version >= 11) { decompressZlib(is, os); |