--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include <atomic>
+#include <cassert>
+#include <stdexcept>
+#include <string>
+#include <utility>
+
+namespace pzstd {
+
+/**
+ * Coordinates graceful shutdown of the pzstd pipeline.
+ *
+ * Holds at most one error message behind an atomic flag. Of several possibly
+ * concurrent `setError()` calls, only the first one stores its message.
+ *
+ * NOTE(review): `message_` is written after `error_` has already flipped and
+ * is not otherwise synchronized, so `getError()` is presumably meant to be
+ * called only once all threads that may report errors are done -- confirm
+ * against the callers.
+ */
+class ErrorHolder {
+  std::atomic<bool> error_;
+  std::string message_;
+
+ public:
+  ErrorHolder() : error_(false) {}
+
+  /// Returns true iff an error has been set and not yet taken by getError().
+  bool hasError() noexcept {
+    return error_.load();
+  }
+
+  /// Records `message` as the error, unless an error is already recorded.
+  void setError(std::string message) noexcept {
+    // Given multiple possibly concurrent calls, exactly one will ever succeed.
+    bool expected = false;
+    if (error_.compare_exchange_strong(expected, true)) {
+      message_ = std::move(message);
+    }
+  }
+
+  /**
+   * Records `message` as the error when `predicate` is false.
+   *
+   * @returns true iff no error is recorded after the call (this also covers
+   *          errors recorded earlier, not only this predicate).
+   */
+  bool check(bool predicate, std::string message) noexcept {
+    if (!predicate) {
+      setError(std::move(message));
+    }
+    return !hasError();
+  }
+
+  /// Clears the error flag and hands the stored message to the caller.
+  std::string getError() noexcept {
+    error_.store(false);
+    return std::move(message_);
+  }
+
+  ~ErrorHolder() {
+    // A pending error must be consumed with getError() before destruction.
+    // Destructors are implicitly noexcept, so throwing here could never be
+    // caught and would always end in std::terminate(); assert instead.
+    assert(!hasError());
+  }
+};
+}
--- /dev/null
+# ##########################################################################
+# Copyright (c) 2016-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+# ##########################################################################
+
+# Locations of the zstd library and command line program sources
+ZSTDDIR = ../../lib
+PROGDIR = ../../programs
+
+CPPFLAGS = -I$(ZSTDDIR) -I$(ZSTDDIR)/common -I$(ZSTDDIR)/dictBuilder -I$(PROGDIR) -I.
+CFLAGS ?= -O3
+# These flags are passed to $(CXX). The C-only warnings
+# -Wdeclaration-after-statement and -Wstrict-prototypes are left out because
+# g++ reports them as "valid for C/ObjC but not for C++" on every compile.
+CFLAGS += -Wall -Wextra -Wcast-qual -Wcast-align -Wstrict-aliasing=1 \
+ -Wswitch-enum -Wundef \
+ -std=c++11
+CFLAGS += $(MOREFLAGS)
+FLAGS = $(CPPFLAGS) $(CFLAGS) $(LDFLAGS)
+
+
+# Sources used only as prerequisites of libzstd.a; the library itself is
+# built by the Makefile in $(ZSTDDIR).
+# NOTE(review): only huf_decompress.c is listed for decompression -- confirm
+# whether other decompress sources should trigger a rebuild as well.
+ZSTDCOMMON_FILES := $(ZSTDDIR)/common/*.c
+ZSTDCOMP_FILES := $(ZSTDDIR)/compress/zstd_compress.c $(ZSTDDIR)/compress/fse_compress.c $(ZSTDDIR)/compress/huf_compress.c
+ZSTDDECOMP_FILES := $(ZSTDDIR)/decompress/huf_decompress.c
+ZSTD_FILES := $(ZSTDDECOMP_FILES) $(ZSTDCOMMON_FILES) $(ZSTDCOMP_FILES)
+
+
+# Define *.exe as extension for Windows systems
+ifneq (,$(filter Windows%,$(OS)))
+EXT =.exe
+else
+EXT =
+endif
+
+.PHONY: default all test clean
+
+# Both the default goal and `all` build the pzstd binary
+default: pzstd
+
+all: pzstd
+
+
+# Build libzstd in the zstd library directory and copy it here
+libzstd.a: $(ZSTD_FILES)
+ $(MAKE) -C $(ZSTDDIR) libzstd
+ @cp $(ZSTDDIR)/libzstd.a .
+
+
+# One object file per translation unit, with hand-written header prerequisites
+Pzstd.o: Pzstd.h Pzstd.cpp ErrorHolder.h utils/*.h
+ $(CXX) $(FLAGS) -c Pzstd.cpp -o $@
+
+SkippableFrame.o: SkippableFrame.h SkippableFrame.cpp utils/*.h
+ $(CXX) $(FLAGS) -c SkippableFrame.cpp -o $@
+
+Options.o: Options.h Options.cpp
+ $(CXX) $(FLAGS) -c Options.cpp -o $@
+
+main.o: main.cpp *.h utils/*.h
+ $(CXX) $(FLAGS) -c main.cpp -o $@
+
+# Link the command line tool; the output name gets $(EXT) appended
+pzstd: libzstd.a Pzstd.o SkippableFrame.o Options.o main.o
+ $(CXX) $(FLAGS) $^ -o $@$(EXT)
+
+# Build the library and objects, then run the test targets of the
+# utils/test and test sub-directories
+test: libzstd.a Pzstd.o Options.o SkippableFrame.o
+ $(MAKE) -C utils/test test
+ $(MAKE) -C test test
+
+# Clean the zstd library directory, both test directories, and local outputs
+clean:
+ $(MAKE) -C $(ZSTDDIR) clean
+ $(MAKE) -C utils/test clean
+ $(MAKE) -C test clean
+ @$(RM) libzstd.a *.o pzstd$(EXT)
+ @echo Cleaning completed
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "Options.h"
+
+#include <cstdio>
+
+namespace pzstd {
+
+namespace {
+/// Parses the leading run of decimal digits in `arg`; yields 0 if there is
+/// none. Parsing stops at the first non-digit character.
+unsigned parseUnsigned(const char* arg) {
+  unsigned value = 0;
+  for (const char* cursor = arg; *cursor >= '0' && *cursor <= '9'; ++cursor) {
+    value = value * 10 + (*cursor - '0');
+  }
+  return value;
+}
+
+// Suffix added to (compression) or stripped from (decompression) the input
+// file name when no output file is given.
+const std::string zstdExtension = ".zst";
+// Compression level used when no -# argument is given.
+constexpr unsigned defaultCompressionLevel = 3;
+// Highest compression level accepted without -u/--ultra.
+constexpr unsigned maxNonUltraCompressionLevel = 19;
+
+/// Prints the command line help text to stderr.
+void usage() {
+  std::fprintf(stderr, "Usage:\n");
+  std::fprintf(stderr, "\tpzstd [args] FILE\n");
+  std::fprintf(stderr, "Parallel ZSTD options:\n");
+  std::fprintf(stderr, "\t-n/--num-threads #: Number of threads to spawn\n");
+  std::fprintf(stderr, "\t-p/--pzstd-headers: Write pzstd headers to enable parallel decompression\n");
+
+  std::fprintf(stderr, "ZSTD options:\n");
+  // The level constants are unsigned, so they are printed with %u.
+  std::fprintf(stderr, "\t-u/--ultra : enable levels beyond %u, up to %i (requires more memory)\n", maxNonUltraCompressionLevel, ZSTD_maxCLevel());
+  std::fprintf(stderr, "\t-h/--help : display help and exit\n");
+  std::fprintf(stderr, "\t-V/--version : display version number and exit\n");
+  std::fprintf(stderr, "\t-d/--decompress : decompression\n");
+  std::fprintf(stderr, "\t-f/--force : overwrite output\n");
+  std::fprintf(stderr, "\t-o/--output file : result stored into `file`\n");
+  std::fprintf(stderr, "\t-c/--stdout : write output to standard output\n");
+  std::fprintf(stderr, "\t-# : # compression level (1-%u, default:%u)\n", maxNonUltraCompressionLevel, defaultCompressionLevel);
+}
+} // anonymous namespace
+
+// Default options: compress at the default level, no overwrite, no pzstd
+// headers. numThreads starts at 0, which parse() rejects unless -n is given.
+Options::Options()
+ : numThreads(0),
+ maxWindowLog(23),
+ compressionLevel(defaultCompressionLevel),
+ decompress(false),
+ overwrite(false),
+ pzstdHeaders(false) {}
+
+/**
+ * Fills in the options from the command line.
+ *
+ * @param argc Number of entries in `argv`
+ * @param argv Command line arguments; `argv[0]` is skipped
+ * @returns true iff the arguments are valid and pzstd should run. false is
+ *          also returned after printing the help or version text.
+ */
+bool Options::parse(int argc, const char** argv) {
+  bool ultra = false;
+  for (int i = 1; i < argc; ++i) {
+    const char* arg = argv[i];
+    // Arguments with a short option
+    char option = 0;
+    if (!std::strcmp(arg, "--num-threads")) {
+      option = 'n';
+    } else if (!std::strcmp(arg, "--pzstd-headers")) {
+      option = 'p';
+    } else if (!std::strcmp(arg, "--ultra")) {
+      option = 'u';
+    } else if (!std::strcmp(arg, "--version")) {
+      option = 'V';
+    } else if (!std::strcmp(arg, "--help")) {
+      option = 'h';
+    } else if (!std::strcmp(arg, "--decompress")) {
+      option = 'd';
+    } else if (!std::strcmp(arg, "--force")) {
+      option = 'f';
+    } else if (!std::strcmp(arg, "--output")) {
+      option = 'o';
+    } else if (!std::strcmp(arg, "--stdout")) {
+      option = 'c';
+    } else if (arg[0] == '-' && arg[1] != 0) {
+      // Parse the compression level or short option
+      if (arg[1] >= '0' && arg[1] <= '9') {
+        compressionLevel = parseUnsigned(arg + 1);
+        continue;
+      }
+      option = arg[1];
+    } else if (inputFile.empty()) {
+      // First positional argument (including a lone "-") is the input file
+      inputFile = arg;
+      continue;
+    } else {
+      std::fprintf(stderr, "Invalid argument: %s.\n", arg);
+      return false;
+    }
+
+    switch (option) {
+      case 'n':
+        if (++i == argc) {
+          std::fprintf(stderr, "Invalid argument: -n requires an argument.\n");
+          return false;
+        }
+        numThreads = parseUnsigned(argv[i]);
+        if (numThreads == 0) {
+          std::fprintf(stderr, "Invalid argument: # of threads must be > 0.\n");
+          // Reject right away instead of continuing with an invalid value.
+          return false;
+        }
+        break;
+      case 'p':
+        pzstdHeaders = true;
+        break;
+      case 'u':
+        ultra = true;
+        // 0 disables the window log cap in determineParameters()
+        maxWindowLog = 0;
+        break;
+      case 'V':
+        std::fprintf(stderr, "ZSTD version: %s.\n", ZSTD_VERSION_STRING);
+        return false;
+      case 'h':
+        usage();
+        return false;
+      case 'd':
+        decompress = true;
+        break;
+      case 'f':
+        overwrite = true;
+        break;
+      case 'o':
+        if (++i == argc) {
+          std::fprintf(stderr, "Invalid argument: -o requires an argument.\n");
+          return false;
+        }
+        outputFile = argv[i];
+        break;
+      case 'c':
+        outputFile = "-";
+        break;
+      default:
+        std::fprintf(stderr, "Invalid argument: %s.\n", arg);
+        return false;
+    }
+  }
+  // Determine input file if not specified
+  if (inputFile.empty()) {
+    inputFile = "-";
+  }
+  // Determine output file if not specified
+  if (outputFile.empty()) {
+    if (inputFile == "-") {
+      std::fprintf(
+          stderr,
+          "Invalid arguments: Reading from stdin, but -o not provided.\n");
+      return false;
+    }
+    // Attempt to add/remove zstd extension from the input file
+    if (decompress) {
+      int stemSize = inputFile.size() - zstdExtension.size();
+      if (stemSize > 0 && inputFile.substr(stemSize) == zstdExtension) {
+        outputFile = inputFile.substr(0, stemSize);
+      } else {
+        std::fprintf(
+            stderr, "Invalid argument: Unable to determine output file.\n");
+        return false;
+      }
+    } else {
+      outputFile = inputFile + zstdExtension;
+    }
+  }
+  // Check compression level
+  {
+    unsigned maxCLevel = ultra ? ZSTD_maxCLevel() : maxNonUltraCompressionLevel;
+    if (compressionLevel > maxCLevel) {
+      std::fprintf(
+          stderr, "Invalid compression level %u.\n", compressionLevel);
+      // An out-of-range level is an error, not just a diagnostic.
+      return false;
+    }
+  }
+  // Check that numThreads is set
+  if (numThreads == 0) {
+    std::fprintf(stderr, "Invalid arguments: # of threads not specified.\n");
+    return false;
+  }
+  return true;
+}
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#define ZSTD_STATIC_LINKING_ONLY
+#include "zstd.h"
+#undef ZSTD_STATIC_LINKING_ONLY
+
+#include <cstdint>
+#include <string>
+
+namespace pzstd {
+
+/// Settings for one pzstd run, filled in by parse() or given directly.
+struct Options {
+  unsigned numThreads;
+  // Upper bound applied to the zstd window log; 0 means no bound.
+  unsigned maxWindowLog;
+  unsigned compressionLevel;
+  bool decompress;
+  std::string inputFile;
+  std::string outputFile;
+  bool overwrite;
+  bool pzstdHeaders;
+
+  Options();
+  Options(
+      unsigned numThreads,
+      unsigned maxWindowLog,
+      unsigned compressionLevel,
+      bool decompress,
+      const std::string& inputFile,
+      const std::string& outputFile,
+      bool overwrite,
+      bool pzstdHeaders)
+      : numThreads(numThreads),
+        maxWindowLog(maxWindowLog),
+        compressionLevel(compressionLevel),
+        decompress(decompress),
+        inputFile(inputFile),
+        outputFile(outputFile),
+        overwrite(overwrite),
+        pzstdHeaders(pzstdHeaders) {}
+
+  /// Fills the options in from the command line; returns false on failure.
+  bool parse(int argc, const char** argv);
+
+  /// Returns the zstd parameters for `compressionLevel`, with the window log
+  /// limited to `maxWindowLog` when that limit is non-zero.
+  ZSTD_parameters determineParameters() const {
+    ZSTD_parameters zparams = ZSTD_getParams(compressionLevel, 0, 0);
+    const bool uncapped = (maxWindowLog == 0);
+    if (uncapped || zparams.cParams.windowLog <= maxWindowLog) {
+      return zparams;
+    }
+    zparams.cParams.windowLog = maxWindowLog;
+    zparams.cParams = ZSTD_adjustCParams(zparams.cParams, 0, 0);
+    return zparams;
+  }
+};
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "Pzstd.h"
+#include "SkippableFrame.h"
+#include "utils/FileSystem.h"
+#include "utils/Range.h"
+#include "utils/ScopeGuard.h"
+#include "utils/ThreadPool.h"
+#include "utils/WorkQueue.h"
+
+#include <cstddef>
+#include <cstdio>
+#include <memory>
+#include <string>
+
+namespace pzstd {
+
+namespace {
+// Platform specific path of the null device. pzstdMain() skips its
+// "output file exists" check when the output file equals this path.
+#ifdef _WIN32
+const std::string nullOutput = "nul";
+#else
+const std::string nullOutput = "/dev/null";
+#endif
+}
+
+using std::size_t;
+
+/**
+ * Opens the input and output given by `options`, schedules the job that
+ * reads the input and starts the (de)compression jobs, and writes the
+ * results to the output.
+ *
+ * @param options      The pzstd options to use
+ * @param errorHolder  Receives the first error; checked to stop early
+ * @returns The number of bytes written, or 0 if a file could not be opened
+ *          or the output exists and overwriting was not requested
+ */
+size_t pzstdMain(const Options& options, ErrorHolder& errorHolder) {
+  // Open the input file and attempt to determine its size
+  FILE* inputFd = stdin;
+  size_t inputSize = 0;
+  if (options.inputFile != "-") {
+    inputFd = std::fopen(options.inputFile.c_str(), "rb");
+    if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) {
+      return 0;
+    }
+    std::error_code ec;
+    inputSize = file_size(options.inputFile, ec);
+    if (ec) {
+      // Size is only a hint; treat it as unknown on failure
+      inputSize = 0;
+    }
+  }
+  auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); });
+
+  // Check if the output file exists and then open it
+  FILE* outputFd = stdout;
+  if (options.outputFile != "-") {
+    if (!options.overwrite && options.outputFile != nullOutput) {
+      // Probe for an existing file. The probe handle is closed right away
+      // so it is not leaked when we bail out below.
+      FILE* existingFd = std::fopen(options.outputFile.c_str(), "rb");
+      const bool outputExists = existingFd != nullptr;
+      if (outputExists) {
+        std::fclose(existingFd);
+      }
+      if (!errorHolder.check(!outputExists, "Output file exists")) {
+        return 0;
+      }
+    }
+    outputFd = std::fopen(options.outputFile.c_str(), "wb");
+    if (!errorHolder.check(
+            outputFd != nullptr, "Failed to open output file")) {
+      return 0;
+    }
+  }
+  auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); });
+
+  // WorkQueue outlives ThreadPool so in the case of error we are certain
+  // we don't accidently try to call push() on it after it is destroyed.
+  WorkQueue<std::shared_ptr<BufferWorkQueue>> outs;
+  size_t bytesWritten = 0;
+  {
+    // Initialize the thread pool with numThreads
+    ThreadPool executor(options.numThreads);
+    if (!options.decompress) {
+      // Add a job that reads the input and starts all the compression jobs
+      executor.add(
+          [&errorHolder, &outs, &executor, inputFd, inputSize, &options] {
+            asyncCompressChunks(
+                errorHolder,
+                outs,
+                executor,
+                inputFd,
+                inputSize,
+                options.numThreads,
+                options.determineParameters());
+          });
+      // Start writing
+      bytesWritten =
+          writeFile(errorHolder, outs, outputFd, options.pzstdHeaders);
+    } else {
+      // Add a job that reads the input and starts all the decompression jobs
+      executor.add([&errorHolder, &outs, &executor, inputFd] {
+        asyncDecompressFrames(errorHolder, outs, executor, inputFd);
+      });
+      // Start writing
+      bytesWritten = writeFile(
+          errorHolder, outs, outputFd, /* writeSkippableFrames */ false);
+    }
+  }
+  return bytesWritten;
+}
+
+/// Builds a `ZSTD_inBuffer` covering all of `buffer`, with position 0.
+static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) {
+  ZSTD_inBuffer view{buffer.data(), buffer.size(), 0};
+  return view;
+}
+
+/**
+ * Moves both `buffer` and `inBuffer` forward past the bytes already
+ * consumed, i.e. by `inBuffer.pos`, and resets `inBuffer.pos` to 0.
+ */
+void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) {
+  auto consumed = inBuffer.pos;
+  auto start = static_cast<const unsigned char*>(inBuffer.src);
+  inBuffer.src = start + consumed;
+  inBuffer.size -= consumed;
+  inBuffer.pos = 0;
+  return buffer.advance(consumed);
+}
+
+/// Builds a `ZSTD_outBuffer` covering all of `buffer`, with position 0.
+static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) {
+  ZSTD_outBuffer view{buffer.data(), buffer.size(), 0};
+  return view;
+}
+
+/**
+ * Moves `outBuffer` forward past the bytes already written, i.e. by
+ * `outBuffer.pos`, resets `outBuffer.pos` to 0, and returns the result of
+ * splitting `buffer` at that offset.
+ */
+Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) {
+  auto produced = outBuffer.pos;
+  auto start = static_cast<unsigned char*>(outBuffer.dst);
+  outBuffer.dst = start + produced;
+  outBuffer.size -= produced;
+  outBuffer.pos = 0;
+  return buffer.splitAt(produced);
+}
+
+/**
+ * Stream chunks of input from `in`, compress it, and stream it out to `out`.
+ *
+ * @param errorHolder Used to report errors and check if an error occurred
+ * @param in Queue that we `pop()` input buffers from
+ * @param out Queue that we `push()` compressed output buffers to
+ * @param maxInputSize An upper bound on the size of the input
+ * @param parameters The zstd parameters to use for compression
+ */
+static void compress(
+ ErrorHolder& errorHolder,
+ std::shared_ptr<BufferWorkQueue> in,
+ std::shared_ptr<BufferWorkQueue> out,
+ size_t maxInputSize,
+ ZSTD_parameters parameters) {
+ // Presumably calls out->finish() when `guard` leaves scope, i.e. on every
+ // return path below -- confirm against utils/ScopeGuard.h.
+ auto guard = makeScopeGuard([&] { out->finish(); });
+ // Initialize the CCtx
+ std::unique_ptr<ZSTD_CStream, size_t (&)(ZSTD_CStream*)> ctx(
+ ZSTD_createCStream(), ZSTD_freeCStream);
+ if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
+ return;
+ }
+ {
+ auto err = ZSTD_initCStream_advanced(ctx.get(), nullptr, 0, parameters, 0);
+ if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
+ return;
+ }
+ }
+
+ // Allocate space for the result
+ // One output buffer sized for the whole chunk; pieces are split off the
+ // front of it as they are produced.
+ auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize));
+ auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
+ {
+ Buffer inBuffer;
+ // Read a buffer in from the input queue
+ while (in->pop(inBuffer) && !errorHolder.hasError()) {
+ auto zstdInBuffer = makeZstdInBuffer(inBuffer);
+ // Compress the whole buffer and send it to the output queue
+ while (!inBuffer.empty() && !errorHolder.hasError()) {
+ if (!errorHolder.check(
+ !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
+ return;
+ }
+ // Compress
+ auto err =
+ ZSTD_compressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
+ if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
+ return;
+ }
+ // Split the compressed data off outBuffer and pass to the output queue
+ out->push(split(outBuffer, zstdOutBuffer));
+ // Forget about the data we already compressed
+ advance(inBuffer, zstdInBuffer);
+ }
+ }
+ }
+ // Write the epilog
+ // Repeat until ZSTD_endStream() reports nothing left to flush.
+ size_t bytesLeft;
+ do {
+ if (!errorHolder.check(
+ !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
+ return;
+ }
+ bytesLeft = ZSTD_endStream(ctx.get(), &zstdOutBuffer);
+ if (!errorHolder.check(
+ !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) {
+ return;
+ }
+ out->push(split(outBuffer, zstdOutBuffer));
+ } while (bytesLeft != 0 && !errorHolder.hasError());
+}
+
+/**
+ * Picks the amount of input that goes into each independently compressed
+ * frame.
+ *
+ * @param size The size of the source if known, 0 otherwise
+ * @param numThreads The number of threads available to run compression jobs on
+ * @param params The zstd parameters to be used for compression
+ * @returns Four times the window size, or `size / numThreads` when the size
+ *          is known and that quotient is non-zero and smaller
+ */
+static size_t
+calculateStep(size_t size, size_t numThreads, const ZSTD_parameters& params) {
+  const size_t windowStep = 1ul << (params.cParams.windowLog + 2);
+  // Unknown input size: nothing to balance against
+  if (size == 0) {
+    return windowStep;
+  }
+  // With a known size, a smaller step may spread the work more evenly
+  const size_t evenStep = size / numThreads;
+  if (evenStep == 0) {
+    return windowStep;
+  }
+  return evenStep < windowStep ? evenStep : windowStep;
+}
+
+namespace {
+// State of the input after a read: more data may follow (Continue), end of
+// file was reached (Done), or reading failed (Error).
+enum class FileStatus { Continue, Done, Error };
+} // anonymous namespace
+
+/**
+ * Reads up to `size` bytes from `fd`, at most `chunkSize` bytes per read,
+ * and pushes every piece that was read onto `queue`.
+ * Stops early on EOF or on a read error.
+ *
+ * @returns Done on EOF, Error on a read error or a read of 0 bytes without
+ *          EOF, and Continue once all `size` bytes were read.
+ */
+static FileStatus
+readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd) {
+  Buffer remaining(size);
+  for (;;) {
+    if (remaining.empty()) {
+      return FileStatus::Continue;
+    }
+    auto wanted = std::min(chunkSize, remaining.size());
+    auto received = std::fread(remaining.data(), 1, wanted, fd);
+    // Hand over exactly what was read, even if that is nothing
+    queue.push(remaining.splitAt(received));
+    if (std::feof(fd)) {
+      return FileStatus::Done;
+    }
+    if (std::ferror(fd) || received == 0) {
+      return FileStatus::Error;
+    }
+  }
+}
+
+// Reads `fd` in pieces of `step` bytes (see calculateStep()), starts one
+// compression job per piece on `executor`, and pushes each job's output
+// queue onto `chunks` in input order. Reports "Error reading input" through
+// `errorHolder` if reading fails.
+void asyncCompressChunks(
+ ErrorHolder& errorHolder,
+ WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
+ ThreadPool& executor,
+ FILE* fd,
+ size_t size,
+ size_t numThreads,
+ ZSTD_parameters params) {
+ auto chunksGuard = makeScopeGuard([&] { chunks.finish(); });
+
+ // Break the input up into chunks of size `step` and compress each chunk
+ // independently.
+ size_t step = calculateStep(size, numThreads, params);
+ auto status = FileStatus::Continue;
+ while (status == FileStatus::Continue && !errorHolder.hasError()) {
+ // Make a new input queue that we will put the chunk's input data into.
+ auto in = std::make_shared<BufferWorkQueue>();
+ auto inGuard = makeScopeGuard([&] { in->finish(); });
+ // Make a new output queue that compress will put the compressed data into.
+ auto out = std::make_shared<BufferWorkQueue>();
+ // Start compression in the thread pool
+ // The job captures `in` and `out` by value, so it holds its own
+ // references to both queues.
+ executor.add([&errorHolder, in, out, step, params] {
+ return compress(
+ errorHolder, std::move(in), std::move(out), step, params);
+ });
+ // Pass the output queue to the writer thread.
+ chunks.push(std::move(out));
+ // Fill the input queue for the compression job we just started
+ status = readData(*in, ZSTD_CStreamInSize(), step, fd);
+ }
+ errorHolder.check(status != FileStatus::Error, "Error reading input");
+}
+
+/**
+ * Decompress a frame, whose data is streamed into `in`, and stream the output
+ * to `out`.
+ *
+ * @param errorHolder Used to report errors and check if an error occurred
+ * @param in Queue that we `pop()` input buffers from. It contains
+ * exactly one compressed frame.
+ * @param out Queue that we `push()` decompressed output buffers to
+ */
+static void decompress(
+ ErrorHolder& errorHolder,
+ std::shared_ptr<BufferWorkQueue> in,
+ std::shared_ptr<BufferWorkQueue> out) {
+ auto guard = makeScopeGuard([&] { out->finish(); });
+ // Initialize the DCtx
+ std::unique_ptr<ZSTD_DStream, size_t (&)(ZSTD_DStream*)> ctx(
+ ZSTD_createDStream(), ZSTD_freeDStream);
+ if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
+ return;
+ }
+ {
+ auto err = ZSTD_initDStream(ctx.get());
+ if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
+ return;
+ }
+ }
+
+ const size_t outSize = ZSTD_DStreamOutSize();
+ Buffer inBuffer;
+ // Result of the most recent ZSTD_decompressStream() call; stays 0 if no
+ // input is ever decompressed.
+ size_t returnCode = 0;
+ // Read a buffer in from the input queue
+ while (in->pop(inBuffer) && !errorHolder.hasError()) {
+ auto zstdInBuffer = makeZstdInBuffer(inBuffer);
+ // Decompress the whole buffer and send it to the output queue
+ while (!inBuffer.empty() && !errorHolder.hasError()) {
+ // Allocate a buffer with at least outSize bytes.
+ Buffer outBuffer(outSize);
+ auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
+ // Decompress
+ returnCode =
+ ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
+ if (!errorHolder.check(
+ !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
+ return;
+ }
+ // Pass the buffer with the decompressed data to the output queue
+ out->push(split(outBuffer, zstdOutBuffer));
+ // Advance past the input we already read
+ advance(inBuffer, zstdInBuffer);
+ if (returnCode == 0) {
+ // The frame is over, prepare to (maybe) start a new frame
+ // NOTE(review): the result of this re-initialization is not checked.
+ ZSTD_initDStream(ctx.get());
+ }
+ }
+ }
+ // Any value above 1 after the input ran out is reported as an error.
+ if (!errorHolder.check(returnCode <= 1, "Incomplete block")) {
+ return;
+ }
+ // We've given ZSTD_decompressStream all of our data, but there may still
+ // be data to read.
+ while (returnCode == 1) {
+ // Allocate a buffer with at least outSize bytes.
+ Buffer outBuffer(outSize);
+ auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
+ // Pass in no input.
+ ZSTD_inBuffer zstdInBuffer{nullptr, 0, 0};
+ // Decompress
+ returnCode =
+ ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
+ if (!errorHolder.check(
+ !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
+ return;
+ }
+ // Pass the buffer with the decompressed data to the output queue
+ out->push(split(outBuffer, zstdOutBuffer));
+ }
+}
+
+// Reads `fd`, splits it into frames using the pzstd skippable frame headers
+// when they are present, starts one decompression job per frame on
+// `executor`, and pushes each job's output queue onto `frames` in input
+// order. Without a recognized header, the rest of the input goes to a single
+// decompression job. Reports "Error reading input" through `errorHolder` if
+// reading fails.
+void asyncDecompressFrames(
+    ErrorHolder& errorHolder,
+    WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
+    ThreadPool& executor,
+    FILE* fd) {
+  auto framesGuard = makeScopeGuard([&] { frames.finish(); });
+  // Split the source up into its component frames.
+  // If we find our recognized skippable frame we know the next frames size
+  // which means that we can decompress each standard frame in independently.
+  // Otherwise, we will decompress using only one decompression task.
+  const size_t chunkSize = ZSTD_DStreamInSize();
+  auto status = FileStatus::Continue;
+  while (status == FileStatus::Continue && !errorHolder.hasError()) {
+    // Make a new input queue that we will put the frames's bytes into.
+    auto in = std::make_shared<BufferWorkQueue>();
+    auto inGuard = makeScopeGuard([&] { in->finish(); });
+    // Make a output queue that decompress will put the decompressed data into
+    auto out = std::make_shared<BufferWorkQueue>();
+
+    size_t frameSize;
+    {
+      // Calculate the size of the next frame.
+      // frameSize is 0 if the frame info can't be decoded.
+      Buffer buffer(SkippableFrame::kSize);
+      auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd);
+      // Refresh the status from the stream, otherwise the check below can
+      // never fire and an empty job is scheduled once the input is exhausted.
+      if (std::feof(fd)) {
+        status = FileStatus::Done;
+      } else if (std::ferror(fd)) {
+        status = FileStatus::Error;
+      }
+      if (bytesRead == 0 && status != FileStatus::Continue) {
+        // Nothing left to read (or reading failed): no frame to decompress
+        break;
+      }
+      buffer.subtract(buffer.size() - bytesRead);
+      frameSize = SkippableFrame::tryRead(buffer.range());
+      in->push(std::move(buffer));
+    }
+    // Start decompression in the thread pool
+    executor.add([&errorHolder, in, out] {
+      return decompress(errorHolder, std::move(in), std::move(out));
+    });
+    // Pass the output queue to the writer thread
+    frames.push(std::move(out));
+    if (frameSize == 0) {
+      // We hit a non SkippableFrame ==> not compressed by pzstd or corrupted
+      // Pass the rest of the source to this decompression task
+      while (status == FileStatus::Continue && !errorHolder.hasError()) {
+        status = readData(*in, chunkSize, chunkSize, fd);
+      }
+      break;
+    }
+    // Fill the input queue for the decompression job we just started
+    status = readData(*in, chunkSize, frameSize, fd);
+  }
+  errorHolder.check(status != FileStatus::Error, "Error reading input");
+}
+
+/// Writes all of `data` to `fd`, retrying after partial writes.
+/// Returns false as soon as `fd` reports an error, true otherwise.
+static bool writeData(ByteRange data, FILE* fd) {
+  for (; !data.empty();) {
+    auto written = std::fwrite(data.begin(), 1, data.size(), fd);
+    data.advance(written);
+    if (std::ferror(fd)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Pops the per-job output queues from `outs` in order and writes their
+// buffers to `outputFd`, optionally preceded by a skippable frame that holds
+// the size reported by the queue. Stops at the first write failure or once
+// `errorHolder` has an error. Returns the number of bytes written so far.
+size_t writeFile(
+ ErrorHolder& errorHolder,
+ WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
+ FILE* outputFd,
+ bool writeSkippableFrames) {
+ size_t bytesWritten = 0;
+ std::shared_ptr<BufferWorkQueue> out;
+ // Grab the output queue for each decompression job (in order).
+ while (outs.pop(out) && !errorHolder.hasError()) {
+ if (writeSkippableFrames) {
+ // If we are compressing and want to write skippable frames we can't
+ // start writing before compression is done because we need to know the
+ // compressed size.
+ // Wait for the compressed size to be available and write skippable frame
+ // NOTE(review): out->size() is converted to the 32-bit frame size taken
+ // by SkippableFrame -- confirm frames cannot exceed that range.
+ SkippableFrame frame(out->size());
+ if (!writeData(frame.data(), outputFd)) {
+ errorHolder.setError("Failed to write output");
+ return bytesWritten;
+ }
+ bytesWritten += frame.kSize;
+ }
+ // For each chunk of the frame: Pop it from the queue and write it
+ Buffer buffer;
+ while (out->pop(buffer) && !errorHolder.hasError()) {
+ if (!writeData(buffer.range(), outputFd)) {
+ errorHolder.setError("Failed to write output");
+ return bytesWritten;
+ }
+ bytesWritten += buffer.size();
+ }
+ }
+ return bytesWritten;
+}
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include "ErrorHolder.h"
+#include "Options.h"
+#include "utils/Buffer.h"
+#include "utils/Range.h"
+#include "utils/ThreadPool.h"
+#include "utils/WorkQueue.h"
+#define ZSTD_STATIC_LINKING_ONLY
+#include "zstd.h"
+#undef ZSTD_STATIC_LINKING_ONLY
+
+#include <cstddef>
+#include <memory>
+
+namespace pzstd {
+/**
+ * Runs pzstd with `options` and returns the number of bytes written.
+ * An error occurred if `errorHolder.hasError()`.
+ *
+ * @param options The pzstd options to use for (de)compression
+ * @param errorHolder Used to report errors and coordinate early shutdown
+ * if an error occurred
+ * @returns The number of bytes written.
+ */
+std::size_t pzstdMain(const Options& options, ErrorHolder& errorHolder);
+
+/**
+ * Streams input from `fd`, breaks input up into chunks, and compresses each
+ * chunk independently. Output of each chunk gets streamed to a queue, and
+ * the output queues get put into `chunks` in order.
+ *
+ * @param errorHolder Used to report errors and coordinate early shutdown
+ * @param chunks Each compression job's output queue gets `pushed()` here
+ * as soon as it is available
+ * @param executor The thread pool to run compression jobs in
+ * @param fd The input file descriptor
+ * @param size The size of the input file if known, 0 otherwise
+ * @param numThreads The number of threads in the thread pool
+ * @param parameters The zstd parameters to use for compression
+ */
+void asyncCompressChunks(
+ ErrorHolder& errorHolder,
+ WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
+ ThreadPool& executor,
+ FILE* fd,
+ std::size_t size,
+ std::size_t numThreads,
+ ZSTD_parameters parameters);
+
+/**
+ * Streams input from `fd`. If pzstd headers are available it breaks the input
+ * up into independent frames. It sends each frame to an independent
+ * decompression job. Output of each frame gets streamed to a queue, and
+ * the output queues get put into `frames` in order.
+ *
+ * @param errorHolder Used to report errors and coordinate early shutdown
+ * @param frames Each decompression job's output queue gets `pushed()` here
+ * as soon as it is available
+ * @param executor The thread pool to run decompression jobs in
+ * @param fd The input file descriptor
+ */
+void asyncDecompressFrames(
+ ErrorHolder& errorHolder,
+ WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
+ ThreadPool& executor,
+ FILE* fd);
+
+/**
+ * Streams input in from each queue in `outs` in order, and writes the data to
+ * `outputFd`.
+ *
+ * @param errorHolder Used to report errors and coordinate early exit
+ * @param outs A queue of output queues, one for each
+ * (de)compression job.
+ * @param outputFd The file descriptor to write to
+ * @param writeSkippableFrames Should we write pzstd headers?
+ * @returns The number of bytes written
+ */
+std::size_t writeFile(
+ ErrorHolder& errorHolder,
+ WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
+ FILE* outputFd,
+ bool writeSkippableFrames);
+}
--- /dev/null
+# Parallel Zstandard (PZstandard)
+
+Parallel Zstandard provides Zstandard format compatible compression and decompression that is able to utilize multiple cores.
+It breaks the input up into equal-sized chunks and compresses each chunk independently into a Zstandard frame.
+It then concatenates the frames together to produce the final compressed output.
+Optionally, with the `-p` option, PZstandard will write a 12-byte header before each frame. The header is a skippable frame in the Zstandard format, and it tells PZstandard the size of the next compressed frame.
+When `-p` is specified for compression, PZstandard can decompress the output in parallel.
+
+## Usage
+
+Basic usage
+
+ pzstd input-file -o output-file -n num-threads [ -p ] -# # Compression
+ pzstd -d input-file -o output-file -n num-threads # Decompression
+
+PZstandard also supports piping and FIFO pipes
+
+ cat input-file | pzstd -n num-threads [ -p ] -# -c > /dev/null
+
+For more options
+
+ pzstd --help
+
+## Benchmarks
+
+As a reference, PZstandard and Pigz were compared on an Intel Core i7 @ 3.1 GHz, each using 4 threads, with the [Silesia compression corpus](http://sun.aei.polsl.pl/~sdeor/index.php?page=silesia).
+
+Compression Speed vs Ratio with 4 Threads | Decompression Speed with 4 Threads
+------------------------------------------|-----------------------------------
+ | 
+
+The test procedure was to run each of the following commands 2 times for each compression level, and take the minimum time.
+
+ time ./pzstd -# -n 4 -p -c silesia.tar > silesia.tar.zst
+ time ./pzstd -d -n 4 -c silesia.tar.zst > /dev/null
+
+ time pigz -# -p 4 -k -c silesia.tar > silesia.tar.gz
+ time pigz -d -p 4 -k -c silesia.tar.gz > /dev/null
+
+PZstandard was tested using compression levels 1-19, and Pigz was tested using compression levels 1-9.
+Pigz cannot do parallel decompression; it simply does each of reading, decompression, and writing on separate threads.
+
+## Tests
+
+Tests require that you have [gtest](https://github.com/google/googletest) installed.
+Modify `GTEST_INC` and `GTEST_LIB` in `test/Makefile` and `utils/test/Makefile` to work for your install of gtest.
+Then run `make test` in the `contrib/pzstd` directory.
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "SkippableFrame.h"
+#include "common/mem.h"
+#include "utils/Range.h"
+
+#include <cstdio>
+
+using namespace pzstd;
+
+// Serializes the frame into data_ as three little endian 32-bit fields:
+// magic number, size of the frame contents, and the size of the next frame.
+SkippableFrame::SkippableFrame(std::uint32_t size) : frameSize_(size) {
+  std::uint8_t* const out = data_.data();
+  MEM_writeLE32(out, kSkippableFrameMagicNumber);
+  MEM_writeLE32(out + 4, kFrameContentsSize);
+  MEM_writeLE32(out + 8, frameSize_);
+}
+
+// Returns the next frame's size stored in `bytes`, or 0 when `bytes` is too
+// short or does not start with the expected magic number and contents size.
+/* static */ std::size_t SkippableFrame::tryRead(ByteRange bytes) {
+  if (bytes.size() < SkippableFrame::kSize) {
+    return 0;
+  }
+  auto header = bytes.begin();
+  if (MEM_readLE32(header) != kSkippableFrameMagicNumber) {
+    return 0;
+  }
+  if (MEM_readLE32(header + 4) != kFrameContentsSize) {
+    return 0;
+  }
+  return MEM_readLE32(header + 8);
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include "utils/Range.h"
+
+#include <array>
+#include <cstddef>
+#include <cstdint>
+#include <cstdio>
+
+namespace pzstd {
+/**
+ * We put a skippable frame before each frame.
+ * It contains a skippable frame magic number, the size of the skippable
+ * frame's contents (always 4), and the size of the next frame.
+ * Each skippable frame is exactly 12 bytes in little endian format.
+ * The first 8 bytes are for compatibility with the ZSTD format.
+ * If we have N threads, the output will look like
+ *
+ * [0x184D2A50|4|size1] [frame1 of size size1]
+ * [0x184D2A50|4|size2] [frame2 of size size2]
+ * ...
+ * [0x184D2A50|4|sizeN] [frameN of size sizeN]
+ *
+ * Each sizeX is 4 bytes.
+ *
+ * These skippable frames should allow us to skip through the compressed file
+ * and only load at most N pages.
+ */
+class SkippableFrame {
+ public:
+ // Total serialized size in bytes: three 32-bit words.
+ static constexpr std::size_t kSize = 12;
+
+ private:
+ // Size of the frame that follows this header, as passed to the constructor.
+ std::uint32_t frameSize_;
+ // Serialized little-endian header; filled in by the constructor.
+ std::array<std::uint8_t, kSize> data_;
+ // First word of every header written or accepted by this class.
+ static constexpr std::uint32_t kSkippableFrameMagicNumber = 0x184D2A50;
+ // Second word: number of bytes after the first 8 header bytes.
+ // Could be improved if the size fits in less bytes
+ static constexpr std::uint32_t kFrameContentsSize = kSize - 8;
+
+ public:
+ // Write the skippable frame to data_ in LE format.
+ // `size` is the size of the next frame and is stored as the third word.
+ explicit SkippableFrame(std::uint32_t size);
+
+ // Read the skippable frame from bytes in LE format.
+ // Returns the stored next-frame size, or 0 if `bytes` is shorter than kSize
+ // or its first two words do not match the expected magic number and size.
+ static std::size_t tryRead(ByteRange bytes);
+
+ // Read-only view of the kSize serialized header bytes.
+ ByteRange data() const {
+ return {data_.data(), data_.size()};
+ }
+
+ // Size of the next frame.
+ std::size_t frameSize() const {
+ return frameSize_;
+ }
+};
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "ErrorHolder.h"
+#include "Options.h"
+#include "Pzstd.h"
+#include "utils/FileSystem.h"
+#include "utils/Range.h"
+#include "utils/ScopeGuard.h"
+#include "utils/ThreadPool.h"
+#include "utils/WorkQueue.h"
+
+#include <chrono>
+#include <cstdio>
+#include <cstdlib>
+
+using namespace pzstd;
+
+namespace {
+// Prints how many milliseconds it was in scope for upon destruction.
+// Used for rough estimates of how long things took.
+//
+// The value is right-aligned in a 6 character column and followed by " | ".
+// Durations shorter than 1 ms are reported as 1.
+struct BenchmarkTimer {
+  // A monotonic clock: wall-clock adjustments must not skew (or make
+  // negative) the measured interval.
+  using Clock = std::chrono::steady_clock;
+  Clock::time_point start;
+  FILE* fd;
+
+  // Starts timing immediately; `fd` is the stream the result is printed to.
+  explicit BenchmarkTimer(FILE* fd = stdout) : start(Clock::now()), fd(fd) {}
+
+  ~BenchmarkTimer() {
+    const auto elapsedMs =
+        std::chrono::duration_cast<std::chrono::milliseconds>(
+            Clock::now() - start)
+            .count();
+    // Clamp to at least 1 so the column never shows 0.
+    const size_t ticks = elapsedMs < 1 ? size_t{1} : static_cast<size_t>(elapsedMs);
+    // One space of padding per missing digit, up to 6 characters wide.
+    for (auto tmp = ticks; tmp < 100000; tmp *= 10) {
+      std::fprintf(fd, " ");
+    }
+    std::fprintf(fd, "%zu | ", ticks);
+  }
+};
+}
+
+// Code I used for benchmarking
+
+// Runs one compression or decompression pass described by `options` and
+// prints one row of benchmark output to stdout:
+//   <level or " d"> | <numThreads> | <elapsed ms> | <bytes written>
+// Aborts the process if either file cannot be opened or the pass reports an
+// error through the ErrorHolder.
+void testMain(const Options& options) {
+ // First column: zero-padded compression level, or " d" when decompressing.
+ if (!options.decompress) {
+ if (options.compressionLevel < 10) {
+ std::printf("0");
+ }
+ std::printf("%u | ", options.compressionLevel);
+ } else {
+ std::printf(" d | ");
+ }
+ // Second column: zero-padded thread count.
+ if (options.numThreads < 10) {
+ std::printf("0");
+ }
+ std::printf("%u | ", options.numThreads);
+
+ FILE* inputFd = std::fopen(options.inputFile.c_str(), "rb");
+ if (inputFd == nullptr) {
+ std::abort();
+ }
+ // The input size is passed on to the compressor; 0 is used when it cannot
+ // be determined.
+ size_t inputSize = 0;
+ if (inputFd != stdin) {
+ std::error_code ec;
+ inputSize = file_size(options.inputFile, ec);
+ if (ec) {
+ inputSize = 0;
+ }
+ }
+ FILE* outputFd = std::fopen(options.outputFile.c_str(), "wb");
+ if (outputFd == nullptr) {
+ std::abort();
+ }
+ // Close both files when this function returns.
+ auto guard = makeScopeGuard([&] {
+ std::fclose(inputFd);
+ std::fclose(outputFd);
+ });
+
+ WorkQueue<std::shared_ptr<BufferWorkQueue>> outs;
+ ErrorHolder errorHolder;
+ size_t bytesWritten;
+ {
+ ThreadPool executor(options.numThreads);
+ // Declared after `executor`, so it is destroyed (and prints the elapsed
+ // time) before the pool's destructor runs.
+ BenchmarkTimer timeIncludingClose;
+ if (!options.decompress) {
+ // The producer runs on the pool; this thread drains `outs` into outputFd.
+ executor.add(
+ [&errorHolder, &outs, &executor, inputFd, inputSize, &options] {
+ asyncCompressChunks(
+ errorHolder,
+ outs,
+ executor,
+ inputFd,
+ inputSize,
+ options.numThreads,
+ options.determineParameters());
+ });
+ bytesWritten = writeFile(errorHolder, outs, outputFd, true);
+ } else {
+ executor.add([&errorHolder, &outs, &executor, inputFd] {
+ asyncDecompressFrames(errorHolder, outs, executor, inputFd);
+ });
+ bytesWritten = writeFile(
+ errorHolder, outs, outputFd, /* writeSkippableFrames */ false);
+ }
+ }
+ // getError() also clears the error flag held by errorHolder.
+ if (errorHolder.hasError()) {
+ std::fprintf(stderr, "Error: %s.\n", errorHolder.getError().c_str());
+ std::abort();
+ }
+ // Last column: number of bytes written to the output file.
+ std::printf("%zu\n", bytesWritten);
+}
+
+// Benchmark driver. Usage: <program> <input file> <compressed output file>.
+// For every combination of compression level and thread count in
+// {1, 2, 4, 8, 16}, compresses argv[1] into argv[2] and then decompresses
+// argv[2] into "<argv[1]>.d". The whole sweep is executed twice.
+int main(int argc, const char** argv) {
+  if (argc < 3) {
+    return 1;
+  }
+  const std::string rawFile = argv[1];
+  const std::string compressedFile = argv[2];
+  const std::string roundTripFile = rawFile + ".d";
+
+  Options options(0, 23, 0, false, "", "", true, true);
+  for (size_t pass = 0; pass < 2; ++pass) {
+    for (size_t level = 1; level <= 16; level *= 2) {
+      for (size_t threads = 1; threads <= 16; threads *= 2) {
+        options.numThreads = threads;
+        options.compressionLevel = level;
+
+        // Compression pass.
+        options.decompress = false;
+        options.inputFile = rawFile;
+        options.outputFile = compressedFile;
+        testMain(options);
+
+        // Decompression pass over the file just produced.
+        options.decompress = true;
+        options.inputFile = compressedFile;
+        options.outputFile = roundTripFile;
+        testMain(options);
+
+        std::fflush(stdout);
+      }
+    }
+  }
+  return 0;
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "ErrorHolder.h"
+#include "Options.h"
+#include "Pzstd.h"
+#include "utils/FileSystem.h"
+#include "utils/Range.h"
+#include "utils/ScopeGuard.h"
+#include "utils/ThreadPool.h"
+#include "utils/WorkQueue.h"
+
+using namespace pzstd;
+
+// Entry point of the pzstd command line tool.
+// Returns 1 when the arguments cannot be parsed or the run reports an error
+// (which is printed to stderr), and 0 otherwise.
+int main(int argc, const char** argv) {
+  Options options;
+  if (!options.parse(argc, argv)) {
+    return 1;
+  }
+
+  ErrorHolder errorHolder;
+  pzstdMain(options, errorHolder);
+  if (!errorHolder.hasError()) {
+    return 0;
+  }
+
+  // getError() hands over the message and clears the error flag.
+  std::fprintf(stderr, "Error: %s.\n", errorHolder.getError().c_str());
+  return 1;
+}
--- /dev/null
+# ##########################################################################
+# Copyright (c) 2016-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+# ##########################################################################
+
+# Set GTEST_INC and GTEST_LIB to work with your install of gtest
+GTEST_INC ?= -isystem googletest/googletest/include
+GTEST_LIB ?= -L googletest/build/googlemock/gtest
+
+# Define *.exe as extension for Windows systems
+ifneq (,$(filter Windows%,$(OS)))
+EXT =.exe
+else
+EXT =
+endif
+
+PZSTDDIR = ..
+PROGDIR = ../../../programs
+ZSTDDIR = ../../../lib
+
+CPPFLAGS = -I$(PZSTDDIR) $(GTEST_INC) $(GTEST_LIB) -I$(ZSTDDIR)/common -I$(PROGDIR)
+
+CFLAGS ?= -O3
+CFLAGS += -std=c++11
+CFLAGS += $(MOREFLAGS)
+FLAGS = $(CPPFLAGS) $(CFLAGS) $(LDFLAGS)
+
+# Objects and libraries every test links against. Order matters for
+# single-pass linkers: an archive or -l library must come after the objects
+# that reference its symbols, so libzstd.a and gtest go last.
+PZSTD_OBJS = $(PZSTDDIR)/Pzstd.o $(PZSTDDIR)/SkippableFrame.o $(PZSTDDIR)/Options.o
+TEST_LIBS = $(PZSTDDIR)/libzstd.a -lgtest -lgtest_main
+
+datagen.o: $(PROGDIR)/datagen.*
+	$(CXX) $(FLAGS) $(PROGDIR)/datagen.c -c -o $@
+
+%: %.cpp *.h datagen.o
+	$(CXX) $(FLAGS) $@.cpp datagen.o $(PZSTD_OBJS) $(TEST_LIBS) -o $@$(EXT)
+
+.PHONY: test clean
+
+test: OptionsTest PzstdTest RoundTripTest
+	@./OptionsTest$(EXT)
+	@./PzstdTest$(EXT)
+	@./RoundTripTest$(EXT)
+
+# Remove the binaries under the same names they were built with.
+clean:
+	@rm -f datagen.o OptionsTest$(EXT) PzstdTest$(EXT) RoundTripTest$(EXT)
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "Options.h"
+
+#include <gtest/gtest.h>
+#include <array>
+
+using namespace pzstd;
+
+namespace pzstd {
+// Field-by-field equality for Options, needed by EXPECT_EQ in the tests below.
+bool operator==(const Options& lhs, const Options& rhs) {
+  if (lhs.numThreads != rhs.numThreads) {
+    return false;
+  }
+  if (lhs.maxWindowLog != rhs.maxWindowLog) {
+    return false;
+  }
+  if (lhs.compressionLevel != rhs.compressionLevel) {
+    return false;
+  }
+  if (lhs.decompress != rhs.decompress) {
+    return false;
+  }
+  if (lhs.inputFile != rhs.inputFile) {
+    return false;
+  }
+  if (lhs.outputFile != rhs.outputFile) {
+    return false;
+  }
+  if (lhs.overwrite != rhs.overwrite) {
+    return false;
+  }
+  return lhs.pzstdHeaders == rhs.pzstdHeaders;
+}
+}
+
+// Each case parses a valid command line and compares the result with the
+// expected Options, written as
+// {numThreads, maxWindowLog, compressionLevel, decompress, inputFile,
+//  outputFile, overwrite, pzstdHeaders} (the order used by operator== above;
+// presumably also the member/constructor order -- confirm against Options.h).
+TEST(Options, ValidInputs) {
+  {
+    Options options;
+    std::array<const char*, 6> args = {
+        {nullptr, "--num-threads", "5", "-o", "-", "-f"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {5, 23, 3, false, "-", "-", true, false};
+    EXPECT_EQ(expected, options);
+  }
+  {
+    Options options;
+    std::array<const char*, 6> args = {
+        {nullptr, "-n", "1", "input", "-19", "-p"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {1, 23, 19, false, "input", "input.zst", false, true};
+    EXPECT_EQ(expected, options);
+  }
+  {
+    Options options;
+    std::array<const char*, 10> args = {{nullptr,
+                                         "--ultra",
+                                         "-22",
+                                         "-n",
+                                         "1",
+                                         "--output",
+                                         "x",
+                                         "-d",
+                                         "x.zst",
+                                         "-f"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {1, 0, 22, true, "x.zst", "x", true, false};
+    EXPECT_EQ(expected, options);
+  }
+  {
+    Options options;
+    std::array<const char*, 6> args = {{nullptr,
+                                        "--num-threads",
+                                        "100",
+                                        "hello.zst",
+                                        "--decompress",
+                                        "--force"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {100, 23, 3, true, "hello.zst", "hello", true, false};
+    EXPECT_EQ(expected, options);
+  }
+  {
+    Options options;
+    std::array<const char*, 5> args = {{nullptr, "-", "-n", "1", "-c"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {1, 23, 3, false, "-", "-", false, false};
+    EXPECT_EQ(expected, options);
+  }
+  {
+    Options options;
+    std::array<const char*, 5> args = {{nullptr, "-", "-n", "1", "--stdout"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {1, 23, 3, false, "-", "-", false, false};
+    EXPECT_EQ(expected, options);
+  }
+  {
+    Options options;
+    std::array<const char*, 10> args = {{nullptr,
+                                         "-n",
+                                         "1",
+                                         "-",
+                                         "-5",
+                                         "-o",
+                                         "-",
+                                         "-u",
+                                         "-d",
+                                         "--pzstd-headers"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {1, 0, 5, true, "-", "-", false, true};
+    // This comparison was missing: `expected` was built but never checked.
+    EXPECT_EQ(expected, options);
+  }
+  {
+    Options options;
+    std::array<const char*, 6> args = {
+        {nullptr, "silesia.tar", "-o", "silesia.tar.pzstd", "-n", "2"}};
+    EXPECT_TRUE(options.parse(args.size(), args.data()));
+    Options expected = {
+        2, 23, 3, false, "silesia.tar", "silesia.tar.pzstd", false, false};
+    // This comparison was missing: `expected` was built but never checked.
+    EXPECT_EQ(expected, options);
+  }
+}
+
+// Command lines whose thread count is absent or unusable must be rejected.
+TEST(Options, BadNumThreads) {
+  {
+    // No "-n" argument at all.
+    std::array<const char*, 3> argv = {{nullptr, "-o", "-"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+  {
+    // "-n 0".
+    std::array<const char*, 5> argv = {{nullptr, "-n", "0", "-o", "-"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+  {
+    // "-n" immediately followed by another flag instead of a number.
+    std::array<const char*, 4> argv = {{nullptr, "-n", "-o", "-"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+}
+
+// Command lines with an unacceptable compression level must be rejected.
+TEST(Options, BadCompressionLevel) {
+  {
+    // Level 20 without the ultra flag.
+    std::array<const char*, 3> argv = {{nullptr, "x", "-20"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+  {
+    // Level 23 with "-u".
+    std::array<const char*, 4> argv = {{nullptr, "x", "-u", "-23"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+}
+
+// A command line containing the flag "-x" must be rejected.
+TEST(Options, InvalidOption) {
+  std::array<const char*, 3> argv = {{nullptr, "x", "-x"}};
+  Options parsed;
+  EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+}
+
+// Command lines from which no usable output file follows must be rejected.
+TEST(Options, BadOutputFile) {
+  {
+    // Decompressing an input named "notzst" with no explicit output.
+    std::array<const char*, 5> argv = {{nullptr, "notzst", "-d", "-n", "1"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+  {
+    // Neither an input nor an output file is given.
+    std::array<const char*, 3> argv = {{nullptr, "-n", "1"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+  {
+    // Input is "-" and no output is given.
+    std::array<const char*, 4> argv = {{nullptr, "-", "-n", "1"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+}
+
+// "-h" and "-V" on their own make parse() return false.
+TEST(Options, Extras) {
+  {
+    std::array<const char*, 2> argv = {{nullptr, "-h"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+  {
+    std::array<const char*, 2> argv = {{nullptr, "-V"}};
+    Options parsed;
+    EXPECT_FALSE(parsed.parse(argv.size(), argv.data()));
+  }
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "datagen.h"
+#include "Pzstd.h"
+#include "test/RoundTrip.h"
+#include "utils/ScopeGuard.h"
+
+#include <gtest/gtest.h>
+#include <cstddef>
+#include <cstdio>
+#include <memory>
+
+using namespace std;
+using namespace pzstd;
+
+// Round-trips every input length from 1 to 1027 bytes through pzstd, with and
+// without pzstd headers, using 1, 2 and 4 threads at levels 1 and 8.
+TEST(Pzstd, SmallSizes) {
+  for (unsigned len = 1; len < 1028; ++len) {
+    std::string inputFile = std::tmpnam(nullptr);
+    // Removes the input file unless a failure below dismisses this guard.
+    auto guard = makeScopeGuard([&] { std::remove(inputFile.c_str()); });
+    {
+      static uint8_t buf[1028];
+      RDG_genBuffer(buf, len, 0.5, 0.0, 42);
+      auto fd = std::fopen(inputFile.c_str(), "wb");
+      // Fail the test instead of handing a null stream to fwrite/fclose.
+      ASSERT_TRUE(fd != nullptr);
+      auto written = std::fwrite(buf, 1, len, fd);
+      std::fclose(fd);
+      ASSERT_EQ(written, len);
+    }
+    for (unsigned headers = 0; headers <= 1; ++headers) {
+      for (unsigned numThreads = 1; numThreads <= 4; numThreads *= 2) {
+        for (unsigned level = 1; level <= 8; level *= 8) {
+          // Runs only when the assertion below fails: keeps the input file
+          // and prints the parameters needed to reproduce the failure.
+          auto errorGuard = makeScopeGuard([&] {
+            guard.dismiss();
+            std::fprintf(stderr, "file: %s\n", inputFile.c_str());
+            std::fprintf(stderr, "pzstd headers: %u\n", headers);
+            std::fprintf(stderr, "# threads: %u\n", numThreads);
+            std::fprintf(stderr, "compression level: %u\n", level);
+          });
+          Options options;
+          options.pzstdHeaders = headers;
+          options.overwrite = true;
+          options.inputFile = inputFile;
+          options.numThreads = numThreads;
+          options.compressionLevel = level;
+          ASSERT_TRUE(roundTrip(options));
+          errorGuard.dismiss();
+        }
+      }
+    }
+  }
+}
+
+// Round-trips inputs of 2^20 .. 2^24 bytes (doubling each step) through
+// pzstd, with and without pzstd headers, using 1, 4 and 16 threads at
+// levels 1, 2 and 4.
+TEST(Pzstd, LargeSizes) {
+  for (unsigned len = 1 << 20; len <= (1 << 24); len *= 2) {
+    std::string inputFile = std::tmpnam(nullptr);
+    // Removes the input file unless a failure below dismisses this guard.
+    auto guard = makeScopeGuard([&] { std::remove(inputFile.c_str()); });
+    {
+      std::unique_ptr<uint8_t[]> buf(new uint8_t[len]);
+      RDG_genBuffer(buf.get(), len, 0.5, 0.0, 42);
+      auto fd = std::fopen(inputFile.c_str(), "wb");
+      // Fail the test instead of handing a null stream to fwrite/fclose.
+      ASSERT_TRUE(fd != nullptr);
+      auto written = std::fwrite(buf.get(), 1, len, fd);
+      std::fclose(fd);
+      ASSERT_EQ(written, len);
+    }
+    for (unsigned headers = 0; headers <= 1; ++headers) {
+      for (unsigned numThreads = 1; numThreads <= 16; numThreads *= 4) {
+        for (unsigned level = 1; level <= 4; level *= 2) {
+          // Runs only when the assertion below fails: keeps the input file
+          // and prints the parameters needed to reproduce the failure.
+          auto errorGuard = makeScopeGuard([&] {
+            guard.dismiss();
+            std::fprintf(stderr, "file: %s\n", inputFile.c_str());
+            std::fprintf(stderr, "pzstd headers: %u\n", headers);
+            std::fprintf(stderr, "# threads: %u\n", numThreads);
+            std::fprintf(stderr, "compression level: %u\n", level);
+          });
+          Options options;
+          options.pzstdHeaders = headers;
+          options.overwrite = true;
+          options.inputFile = inputFile;
+          options.numThreads = numThreads;
+          options.compressionLevel = level;
+          ASSERT_TRUE(roundTrip(options));
+          errorGuard.dismiss();
+        }
+      }
+    }
+  }
+}
+
+// Round-trips 10000 identical bytes ('a') with one thread at level 1 and
+// without pzstd headers.
+TEST(Pzstd, ExtremelyCompressible) {
+  std::string inputFile = std::tmpnam(nullptr);
+  auto guard = makeScopeGuard([&] { std::remove(inputFile.c_str()); });
+  {
+    // A std::string filled with 'a' replaces the manual buffer + memset, so
+    // the test no longer depends on <cstring> being pulled in indirectly.
+    const std::string payload(10000, 'a');
+    auto fd = std::fopen(inputFile.c_str(), "wb");
+    // Fail the test instead of handing a null stream to fwrite/fclose.
+    ASSERT_TRUE(fd != nullptr);
+    auto written = std::fwrite(payload.data(), 1, payload.size(), fd);
+    std::fclose(fd);
+    // Both operands are size_t: no signed/unsigned comparison.
+    ASSERT_EQ(written, payload.size());
+  }
+  Options options;
+  options.pzstdHeaders = false;
+  options.overwrite = true;
+  options.inputFile = inputFile;
+  options.numThreads = 1;
+  options.compressionLevel = 1;
+  ASSERT_TRUE(roundTrip(options));
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include "Options.h"
+#include "Pzstd.h"
+#include "utils/ScopeGuard.h"
+
+#include <cstdio>
+#include <string>
+#include <cstdint>
+#include <memory>
+
+namespace pzstd {
+
+/**
+ * Compares two files byte for byte.
+ *
+ * @param source        Path of the reference file.
+ * @param decompressed  Path of the file to compare against `source`.
+ * @returns True iff both files could be opened and read without error and
+ *          have identical contents and length.
+ */
+inline bool check(std::string source, std::string decompressed) {
+  // unique_ptr only invokes the deleter on non-null handles, so a failed
+  // fopen never reaches fclose.
+  using FilePtr = std::unique_ptr<std::FILE, int (*)(std::FILE*)>;
+  FilePtr sFd(std::fopen(source.c_str(), "rb"), &std::fclose);
+  FilePtr dFd(std::fopen(decompressed.c_str(), "rb"), &std::fclose);
+  if (!sFd || !dFd) {
+    return false;
+  }
+
+  constexpr std::size_t kChunk = 1024;
+  std::unique_ptr<std::uint8_t[]> sBuf(new std::uint8_t[kChunk]);
+  std::unique_ptr<std::uint8_t[]> dBuf(new std::uint8_t[kChunk]);
+
+  std::size_t sRead = 0;
+  std::size_t dRead = 0;
+  do {
+    sRead = std::fread(sBuf.get(), 1, kChunk, sFd.get());
+    dRead = std::fread(dBuf.get(), 1, kChunk, dFd.get());
+    if (std::ferror(sFd.get()) || std::ferror(dFd.get())) {
+      return false;
+    }
+    // Different chunk lengths mean the files differ in size.
+    if (sRead != dRead) {
+      return false;
+    }
+    for (std::size_t i = 0; i < sRead; ++i) {
+      if (sBuf[i] != dBuf[i]) {
+        return false;
+      }
+    }
+  } while (sRead == kChunk);
+
+  // A short read that is not end-of-file on both sides is a mismatch.
+  return std::feof(sFd.get()) && std::feof(dFd.get());
+}
+
+/**
+ * Compresses `options.inputFile` to a temporary file, decompresses that to a
+ * second temporary file and compares the result with the original input.
+ * Both temporary files are removed before returning.
+ *
+ * Note that `options` is modified: on return its input/output files and
+ * `decompress` flag reflect the last pass that was run.
+ *
+ * @returns True iff both passes finish without error and the decompressed
+ *          data equals the original input.
+ */
+inline bool roundTrip(Options& options) {
+  const std::string original = options.inputFile;
+  const std::string compressedFile = std::tmpnam(nullptr);
+  const std::string decompressedFile = std::tmpnam(nullptr);
+  auto cleanup = makeScopeGuard([&] {
+    std::remove(compressedFile.c_str());
+    std::remove(decompressedFile.c_str());
+  });
+
+  // Runs pzstd once with the current options. A pending error is fetched
+  // with getError(), which clears it, before the holder is destroyed.
+  auto runPass = [&options]() -> bool {
+    ErrorHolder errorHolder;
+    pzstdMain(options, errorHolder);
+    if (errorHolder.hasError()) {
+      errorHolder.getError();
+      return false;
+    }
+    return true;
+  };
+
+  options.outputFile = compressedFile;
+  options.decompress = false;
+  if (!runPass()) {
+    return false;
+  }
+
+  options.decompress = true;
+  options.inputFile = compressedFile;
+  options.outputFile = decompressedFile;
+  if (!runPass()) {
+    return false;
+  }
+
+  return check(original, decompressedFile);
+}
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "datagen.h"
+#include "Options.h"
+#include "test/RoundTrip.h"
+#include "utils/ScopeGuard.h"
+
+#include <cstddef>
+#include <cstdio>
+#include <cstdlib>
+#include <memory>
+#include <random>
+
+using namespace std;
+using namespace pzstd;
+
+namespace {
+// Generates `size` bytes with RDG_genBuffer(matchProba, litProba, seed),
+// writes them to a fresh temporary file and returns that file's name.
+// Aborts if the file cannot be opened or not all bytes can be written.
+string
+writeData(size_t size, double matchProba, double litProba, unsigned seed) {
+  std::unique_ptr<uint8_t[]> buf(new uint8_t[size]);
+  RDG_genBuffer(buf.get(), size, matchProba, litProba, seed);
+  string file = tmpnam(nullptr);
+  auto fd = std::fopen(file.c_str(), "wb");
+  // Without this check a failed open would pass a null stream to
+  // fwrite/fclose.
+  if (fd == nullptr) {
+    std::abort();
+  }
+  auto guard = makeScopeGuard([&] { std::fclose(fd); });
+  auto bytesWritten = std::fwrite(buf.get(), 1, size, fd);
+  if (bytesWritten != size) {
+    std::abort();
+  }
+  return file;
+}
+
+// Writes a randomly generated input file and returns its name.
+// The size, both probabilities and the data seed are all drawn from `gen`.
+template <typename Generator>
+string generateInputFile(Generator& gen) {
+  // Use inputs ranging from 1 Byte to 2^16 Bytes
+  std::uniform_int_distribution<size_t> size{1, 1 << 16};
+  std::uniform_real_distribution<> prob{0, 1};
+  // Draw each value in its own statement: the evaluation order of function
+  // arguments is unspecified, so drawing them inline would make the data
+  // produced from a given generator state depend on the compiler.
+  const size_t inputSize = size(gen);
+  const double matchProba = prob(gen);
+  const double litProba = prob(gen);
+  const unsigned seed = gen();
+  return writeData(inputSize, matchProba, litProba, seed);
+}
+
+// Builds Options for compressing `inputFile` with overwrite enabled and a
+// random configuration drawn from `gen`: pzstd headers on with probability
+// 0.75, 1-32 threads and compression level 1-10.
+template <typename Generator>
+Options generateOptions(Generator& gen, const string& inputFile) {
+  std::bernoulli_distribution headerDist{0.75};
+  std::uniform_int_distribution<unsigned> threadDist{1, 32};
+  std::uniform_int_distribution<unsigned> levelDist{1, 10};
+
+  Options result;
+  result.inputFile = inputFile;
+  result.overwrite = true;
+  // The draws stay in this order so the generator is consumed identically.
+  result.pzstdHeaders = headerDist(gen);
+  result.numThreads = threadDist(gen);
+  result.compressionLevel = levelDist(gen);
+  return result;
+}
+}
+
+// Randomized round-trip test: generates 10000 random input files and
+// round-trips each one with 10 random option sets. On the first failure the
+// failing configuration is printed, the input file is kept, and 1 is
+// returned; otherwise returns 0.
+int main(int argc, char** argv) {
+  (void)argc;
+  (void)argv;
+  std::mt19937 gen(std::random_device{}());
+
+  // Terminates the "\r"-updated progress line on every exit path.
+  auto newlineGuard = makeScopeGuard([] { std::fprintf(stderr, "\n"); });
+  for (unsigned i = 0; i < 10000; ++i) {
+    if (i % 100 == 0) {
+      std::fprintf(stderr, "Progress: %u%%\r", i / 100);
+    }
+    auto inputFile = generateInputFile(gen);
+    auto inputGuard = makeScopeGuard([&] { std::remove(inputFile.c_str()); });
+    // Distinct index name: the original inner loop shadowed the outer `i`.
+    for (unsigned attempt = 0; attempt < 10; ++attempt) {
+      auto options = generateOptions(gen, inputFile);
+      if (!roundTrip(options)) {
+        // Keep the failing input so the printed file name can actually be
+        // used to reproduce the problem.
+        inputGuard.dismiss();
+        std::fprintf(stderr, "numThreads: %u\n", options.numThreads);
+        std::fprintf(stderr, "level: %u\n", options.compressionLevel);
+        std::fprintf(stderr, "decompress? %u\n", (unsigned)options.decompress);
+        std::fprintf(
+            stderr, "pzstd headers? %u\n", (unsigned)options.pzstdHeaders);
+        std::fprintf(stderr, "file: %s\n", inputFile.c_str());
+        return 1;
+      }
+    }
+  }
+  return 0;
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include "utils/Range.h"
+
+#include <array>
+#include <cstddef>
+#include <memory>
+
+namespace pzstd {
+
+/**
+ * A `Buffer` pairs a reference-counted block of bytes with the sub-range of
+ * that block it is responsible for.
+ * One block can be allocated up front, filled chunk by chunk, and the chunks
+ * broken off as separate `Buffer`s that share the same storage.
+ * The storage is released when the last `Buffer` referencing it is destroyed.
+ * `Buffer`s can be moved but not copied.
+ */
+class Buffer {
+  std::shared_ptr<unsigned char> buffer_;
+  MutableByteRange range_;
+
+ public:
+  /// Construct an empty buffer that owns no data.
+  explicit Buffer() {}
+
+  /// Construct a `Buffer` backed by a freshly allocated block of `size` bytes.
+  explicit Buffer(std::size_t size)
+      : buffer_(
+            new unsigned char[size],
+            std::default_delete<unsigned char[]>()),
+        range_(buffer_.get(), buffer_.get() + size) {}
+
+  /// Construct a `Buffer` that shares `buffer` and covers the range `data`.
+  explicit Buffer(std::shared_ptr<unsigned char> buffer, MutableByteRange data)
+      : buffer_(buffer), range_(data) {}
+
+  Buffer(Buffer&&) = default;
+  Buffer& operator=(Buffer&&) & = default;
+
+  /**
+   * Breaks the first `n` bytes off into their own `Buffer`.
+   * Afterwards this `Buffer` covers [begin + n, end) and the returned one
+   * covers [begin, begin + n); both share the same underlying storage.
+   *
+   * @param n The offset to split at.
+   * @returns A buffer that owns the data [begin, begin + n).
+   */
+  Buffer splitAt(std::size_t n) {
+    Buffer head(buffer_, range_.subpiece(0, n));
+    range_.advance(n);
+    return head;
+  }
+
+  /// Drops the first `n` bytes: the range becomes [begin + n, end).
+  void advance(std::size_t n) {
+    range_.advance(n);
+  }
+
+  /// Drops the last `n` bytes: the range becomes [begin, end - n).
+  void subtract(std::size_t n) {
+    range_.subtract(n);
+  }
+
+  /// Read-only view of the bytes covered by this `Buffer`.
+  ByteRange range() const {
+    return range_;
+  }
+  /// Mutable view of the bytes covered by this `Buffer`.
+  MutableByteRange range() {
+    return range_;
+  }
+
+  /// Pointer to the first covered byte (read-only).
+  const unsigned char* data() const {
+    return range_.data();
+  }
+
+  /// Pointer to the first covered byte.
+  unsigned char* data() {
+    return range_.data();
+  }
+
+  /// Number of bytes covered.
+  std::size_t size() const {
+    return range_.size();
+  }
+
+  /// True iff no bytes are covered.
+  bool empty() const {
+    return range_.empty();
+  }
+};
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include "utils/Range.h"
+
+#include <sys/stat.h>
+#include <cstdint>
+#include <system_error>
+
+// A small subset of `std::filesystem`.
+// `std::filesystem` should be a drop in replacement.
+// See http://en.cppreference.com/w/cpp/filesystem for documentation.
+
+namespace pzstd {
+
+using file_status = struct stat;
+
+/// http://en.cppreference.com/w/cpp/filesystem/status
+///
+/// Calls `stat()` on `path`. On success `ec` is cleared; on failure `ec` is
+/// set from `errno` and the returned structure is all zeros.
+/// NOTE(review): `path.data()` is passed straight to `stat()`, so `path`
+/// must point at a NUL-terminated string.
+inline file_status status(StringPiece path, std::error_code& ec) noexcept {
+  // Value-initialize so callers never observe indeterminate fields when
+  // stat() fails (e.g. is_regular_file(status(path, ec))).
+  file_status result{};
+  if (stat(path.data(), &result)) {
+    ec.assign(errno, std::generic_category());
+  } else {
+    ec.clear();
+  }
+  return result;
+}
+
+/// http://en.cppreference.com/w/cpp/filesystem/is_regular_file
+///
+/// True iff `status` describes a regular file.
+inline bool is_regular_file(file_status status) noexcept {
+  return S_ISREG(status.st_mode);
+}
+
+/// http://en.cppreference.com/w/cpp/filesystem/is_regular_file
+///
+/// True iff `path` can be stat'ed and is a regular file. When the lookup
+/// fails, `ec` holds the error and the result is false; the status
+/// structure is not inspected in that case.
+inline bool is_regular_file(StringPiece path, std::error_code& ec) noexcept {
+  const auto info = status(path, ec);
+  return !ec && is_regular_file(info);
+}
+
+/// http://en.cppreference.com/w/cpp/filesystem/file_size
+///
+/// Returns the size in bytes of the regular file at `path` and clears `ec`.
+/// If the file cannot be stat'ed, `ec` carries the error; if it is not a
+/// regular file, `ec` is set to ENOTSUP. In both failure cases the return
+/// value is -1 converted to `std::uintmax_t`.
+inline std::uintmax_t file_size(
+    StringPiece path,
+    std::error_code& ec) noexcept {
+  const auto info = status(path, ec);
+  if (!ec) {
+    if (is_regular_file(info)) {
+      ec.clear();
+      return info.st_size;
+    }
+    ec.assign(ENOTSUP, std::generic_category());
+  }
+  return -1;
+}
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+/**
+ * Compiler hints to indicate the fast path of an "if" branch: whether
+ * the if condition is likely to be true or false.
+ *
+ * @author Tudor Bosman (tudorb@fb.com)
+ */
+
+#pragma once
+
+// Drop any earlier definitions so the ones below always take effect.
+#undef LIKELY
+#undef UNLIKELY
+
+// With GCC-compatible compilers (version 4 or later) the macros wrap
+// __builtin_expect, with 1 as the expected value for LIKELY and 0 for
+// UNLIKELY. Elsewhere they expand to the bare expression.
+#if defined(__GNUC__) && __GNUC__ >= 4
+#define LIKELY(x) (__builtin_expect((x), 1))
+#define UNLIKELY(x) (__builtin_expect((x), 0))
+#else
+#define LIKELY(x) (x)
+#define UNLIKELY(x) (x)
+#endif
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+
+/**
+ * A subset of `folly/Range.h`.
+ * All code is copied verbatim, modulo formatting.
+ */
+#pragma once
+
+#include "utils/Likely.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <cstring>
+#include <iterator>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+
+namespace pzstd {
+
+namespace detail {
+/*
+ * SFINAE helper for constructors that should only exist for character
+ * pointers.
+ * IsCharPointer<T>::type is defined for T = const char* and T = char*.
+ * IsCharPointer<T>::const_type is defined for T = const char* only.
+ * For every other T neither member exists.
+ */
+template <class T>
+struct IsCharPointer {};
+
+template <>
+struct IsCharPointer<char*> {
+  using type = int;
+};
+
+template <>
+struct IsCharPointer<const char*> {
+  using const_type = int;
+  using type = int;
+};
+
+} // namespace detail
+
+/**
+ * A non-owning view of the half-open range [b_, e_) described by two
+ * iterators. The range never allocates or frees what it refers to.
+ */
+template <typename Iter>
+class Range {
+ Iter b_; // first element of the range
+ Iter e_; // one past the last element of the range
+
+ public:
+ using size_type = std::size_t;
+ using iterator = Iter;
+ using const_iterator = Iter;
+ using value_type = typename std::remove_reference<
+ typename std::iterator_traits<Iter>::reference>::type;
+ using reference = typename std::iterator_traits<Iter>::reference;
+
+ /// Constructs a range from value-initialized iterators.
+ constexpr Range() : b_(), e_() {}
+ /// Constructs the range [begin, end).
+ constexpr Range(Iter begin, Iter end) : b_(begin), e_(end) {}
+
+ /// Constructs the range [begin, begin + size).
+ constexpr Range(Iter begin, size_type size) : b_(begin), e_(begin + size) {}
+
+ /// Only for `char*` / `const char*` ranges: views the NUL-terminated string
+ /// `str`, excluding the terminator (length is taken with std::strlen).
+ template <class T = Iter, typename detail::IsCharPointer<T>::type = 0>
+ /* implicit */ Range(Iter str) : b_(str), e_(str + std::strlen(str)) {}
+
+ /// Only for `const char*` ranges: views the contents of `str`.
+ template <class T = Iter, typename detail::IsCharPointer<T>::const_type = 0>
+ /* implicit */ Range(const std::string& str)
+ : b_(str.data()), e_(b_ + str.size()) {}
+
+ // Allow implicit conversion from Range<From> to Range<To> if From is
+ // implicitly convertible to To.
+ template <
+ class OtherIter,
+ typename std::enable_if<
+ (!std::is_same<Iter, OtherIter>::value &&
+ std::is_convertible<OtherIter, Iter>::value),
+ int>::type = 0>
+ constexpr /* implicit */ Range(const Range<OtherIter>& other)
+ : b_(other.begin()), e_(other.end()) {}
+
+ Range(const Range&) = default;
+ Range(Range&&) = default;
+
+ // Assignment is restricted to lvalues by the `&` ref-qualifier.
+ Range& operator=(const Range&) & = default;
+ Range& operator=(Range&&) & = default;
+
+ /// Number of elements in the range.
+ constexpr size_type size() const {
+ return e_ - b_;
+ }
+ /// True iff the range contains no elements.
+ bool empty() const {
+ return b_ == e_;
+ }
+ /// Iterator to the first element; same as begin().
+ Iter data() const {
+ return b_;
+ }
+ Iter begin() const {
+ return b_;
+ }
+ Iter end() const {
+ return e_;
+ }
+
+ /// Removes the first `n` elements from the range.
+ /// Throws std::out_of_range if `n` exceeds size().
+ void advance(size_type n) {
+ if (UNLIKELY(n > size())) {
+ throw std::out_of_range("index out of range");
+ }
+ b_ += n;
+ }
+
+ /// Removes the last `n` elements from the range.
+ /// Throws std::out_of_range if `n` exceeds size().
+ void subtract(size_type n) {
+ if (UNLIKELY(n > size())) {
+ throw std::out_of_range("index out of range");
+ }
+ e_ -= n;
+ }
+
+ /// Returns the sub-range starting at offset `first` with at most `length`
+ /// elements; `length` is clamped to the elements available after `first`.
+ /// Throws std::out_of_range if `first` exceeds size().
+ Range subpiece(size_type first, size_type length = std::string::npos) const {
+ if (UNLIKELY(first > size())) {
+ throw std::out_of_range("index out of range");
+ }
+
+ return Range(b_ + first, std::min(length, size() - first));
+ }
+};
+
+// Convenience aliases for the pointer ranges used by pzstd.
+using ByteRange = Range<const unsigned char*>;
+using MutableByteRange = Range<unsigned char*>;
+using StringPiece = Range<const char*>;
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include <utility>
+
+namespace pzstd {
+
+/**
+ * Dismissable scope guard.
+ * `Function` must be callable and take no parameters.
+ * Unless `dismiss()` is called, the callable is executed upon destruction of
+ * `ScopeGuard`.
+ *
+ * A guard can be moved but not copied. Moving transfers the responsibility
+ * for running the callable to the new guard, so it runs at most once.
+ *
+ * Example:
+ *
+ * auto guard = makeScopeGuard([&] { cleanup(); });
+ */
+template <typename Function>
+class ScopeGuard {
+  Function function;
+  bool dismissed;
+
+ public:
+  explicit ScopeGuard(Function&& function)
+      : function(std::move(function)), dismissed(false) {}
+
+  // Without this constructor a move falls back to the implicit copy, and
+  // both the source and the copy would run the callable. That matters when
+  // the return in makeScopeGuard() is not elided (elision is optional
+  // before C++17) or when a guard is moved explicitly.
+  ScopeGuard(ScopeGuard&& other)
+      : function(std::move(other.function)), dismissed(other.dismissed) {
+    other.dismissed = true;
+  }
+
+  ScopeGuard(const ScopeGuard&) = delete;
+  ScopeGuard& operator=(const ScopeGuard&) = delete;
+  ScopeGuard& operator=(ScopeGuard&&) = delete;
+
+  /// Prevents the callable from running when this guard is destroyed.
+  void dismiss() {
+    dismissed = true;
+  }
+
+  ~ScopeGuard() noexcept {
+    if (!dismissed) {
+      function();
+    }
+  }
+};
+
+/// Creates a scope guard from `function`.
+template <typename Function>
+ScopeGuard<Function> makeScopeGuard(Function&& function) {
+  return ScopeGuard<Function>(std::forward<Function>(function));
+}
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include "utils/WorkQueue.h"
+
+#include <cstddef>
+#include <functional>
+#include <thread>
+#include <vector>
+
+namespace pzstd {
+/// A fixed-size pool of worker threads that execute queued tasks in FIFO
+/// order.
+class ThreadPool {
+  std::vector<std::thread> threads_;
+
+  WorkQueue<std::function<void()>> tasks_;
+
+ public:
+  /// Starts `numThreads` workers; each keeps popping tasks from the queue and
+  /// running them until `pop()` reports that no more work will arrive.
+  explicit ThreadPool(std::size_t numThreads) {
+    threads_.reserve(numThreads);
+    while (threads_.size() < numThreads) {
+      threads_.emplace_back([this] {
+        for (std::function<void()> job; tasks_.pop(job);) {
+          job();
+        }
+      });
+    }
+  }
+
+  /// Marks the queue as finished and waits for every worker to exit.
+  ~ThreadPool() {
+    tasks_.finish();
+    for (std::thread& worker : threads_) {
+      worker.join();
+    }
+  }
+
+  /**
+   * Queues `task` for execution by one of the workers. Because the task is
+   * stored in a `std::function<>`, it must be copyable: a lambda passed here
+   * must not capture move-only types (like `std::unique_ptr`).
+   *
+   * @param task The task to execute.
+   */
+  void add(std::function<void()> task) {
+    tasks_.push(std::move(task));
+  }
+};
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include "utils/Buffer.h"
+
+#include <atomic>
+#include <cassert>
+#include <condition_variable>
+#include <cstddef>
+#include <functional>
+#include <mutex>
+#include <queue>
+
+namespace pzstd {
+
+/**
+ * Unbounded thread-safe work queue.
+ *
+ * Producers call `push()`, consumers call `pop()`, and `finish()` announces
+ * that no further items will be pushed. All state is guarded by one mutex and
+ * every waiter (in `pop()` and `waitUntilFinished()`) blocks on the same
+ * condition variable.
+ */
+template <typename T>
+class WorkQueue {
+ // Protects all member variable access
+ std::mutex mutex_;
+ std::condition_variable cv_;
+
+ std::queue<T> queue_;
+ // Set by `finish()`; once true, `push()` rejects new items.
+ bool done_;
+
+ public:
+ /// Constructs an empty work queue.
+ WorkQueue() : done_(false) {}
+
+ /**
+ * Push an item onto the work queue. Notify a single thread that work is
+ * available. If `finish()` has been called, do nothing and return false.
+ *
+ * @param item Item to push onto the queue.
+ * @returns True upon success, false if `finish()` has been called. An
+ * item was pushed iff `push()` returns true.
+ */
+ bool push(T item) {
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ if (done_) {
+ return false;
+ }
+ queue_.push(std::move(item));
+ }
+ // The lock is released before notifying.
+ cv_.notify_one();
+ return true;
+ }
+
+ /**
+ * Attempts to pop an item off the work queue. It will block until data is
+ * available or `finish()` has been called.
+ *
+ * @param[out] item If `pop` returns `true`, it contains the popped item.
+ * If `pop` returns `false`, it is unmodified.
+ * @returns True upon success. False if the queue is empty and
+ * `finish()` has been called.
+ */
+ bool pop(T& item) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ // Re-check the condition after every wake up.
+ while (queue_.empty() && !done_) {
+ cv_.wait(lock);
+ }
+ if (queue_.empty()) {
+ assert(done_);
+ return false;
+ }
+ item = std::move(queue_.front());
+ queue_.pop();
+ return true;
+ }
+
+ /**
+ * Promise that `push()` won't be called again, so once the queue is empty
+ * there will never be any more work. Wakes up every waiting thread.
+ * Must be called at most once (checked by an assertion).
+ */
+ void finish() {
+ {
+ std::lock_guard<std::mutex> lock(mutex_);
+ assert(!done_);
+ done_ = true;
+ }
+ cv_.notify_all();
+ }
+
+ /// Blocks until `finish()` has been called (but the queue may not be empty).
+ void waitUntilFinished() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ while (!done_) {
+ cv_.wait(lock);
+ // If we were woken by a push, we need to wake a thread waiting on pop().
+ // (`push()` only notifies one waiter, and this thread may have been it.)
+ if (!done_) {
+ lock.unlock();
+ cv_.notify_one();
+ lock.lock();
+ }
+ }
+ }
+};
+
+/// Work queue for `Buffer`s that knows the total number of bytes in the queue.
+class BufferWorkQueue {
+  WorkQueue<Buffer> queue_;
+  std::atomic<std::size_t> size_;
+
+ public:
+  /// Constructs an empty queue holding zero bytes.
+  BufferWorkQueue() : size_(0) {}
+
+  /**
+   * Pushes `buffer` onto the queue and adds its size to the byte count.
+   * If `finish()` has already been called the buffer is dropped and the byte
+   * count is left unchanged.
+   *
+   * @param buffer The buffer to enqueue.
+   */
+  void push(Buffer buffer) {
+    const std::size_t bytes = buffer.size();
+    // Count the bytes before the buffer becomes visible to `pop()`, so that a
+    // concurrent pop can never subtract them before they were added.
+    size_.fetch_add(bytes);
+    if (!queue_.push(std::move(buffer))) {
+      // The queue rejected the buffer: undo the accounting.
+      size_.fetch_sub(bytes);
+    }
+  }
+
+  /**
+   * Pops a buffer off the queue, blocking like `WorkQueue::pop()`.
+   *
+   * @param[out] buffer Receives the popped buffer when `true` is returned.
+   * @returns True if a buffer was popped, in which case its size has been
+   *          subtracted from the byte count.
+   */
+  bool pop(Buffer& buffer) {
+    bool result = queue_.pop(buffer);
+    if (result) {
+      size_.fetch_sub(buffer.size());
+    }
+    return result;
+  }
+
+  /// Promises that `push()` won't be called again.
+  void finish() {
+    queue_.finish();
+  }
+
+  /**
+   * Blocks until `finish()` has been called.
+   *
+   * @returns The total number of bytes of all the `Buffer`s currently in the
+   *          queue.
+   */
+  std::size_t size() {
+    queue_.waitUntilFinished();
+    return size_.load();
+  }
+};
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "utils/Buffer.h"
+#include "utils/Range.h"
+
+#include <gtest/gtest.h>
+#include <memory>
+
+using namespace pzstd;
+
+namespace {
+/// Frees an array allocated with `new[]`; used below as a `shared_ptr`
+/// deleter.
+void deleter(const unsigned char* buf) {
+  delete[] buf;
+}
+} // namespace
+
+// Checks the default, sized, move constructed and move assigned buffers.
+TEST(Buffer, Constructors) {
+  Buffer empty;
+  EXPECT_TRUE(empty.empty());
+  EXPECT_EQ(0, empty.size());
+
+  Buffer sized(5);
+  EXPECT_FALSE(sized.empty());
+  EXPECT_EQ(5, sized.size());
+
+  // Inspect the destination of each move, not the moved-from source.
+  Buffer moved(std::move(sized));
+  EXPECT_FALSE(moved.empty());
+  EXPECT_EQ(5, moved.size());
+
+  Buffer assigned;
+  assigned = std::move(moved);
+  EXPECT_FALSE(assigned.empty());
+  EXPECT_EQ(5, assigned.size());
+}
+
+// Tracks `buf.use_count()` to check how many references to the shared
+// allocation are held while buffers are moved and split.
+TEST(Buffer, BufferManagement) {
+ std::shared_ptr<unsigned char> buf(new unsigned char[10], deleter);
+ {
+ // Acquiring the allocation adds one reference; moves must not add more.
+ Buffer acquired(buf, MutableByteRange(buf.get(), buf.get() + 10));
+ EXPECT_EQ(2, buf.use_count());
+ Buffer moved(std::move(acquired));
+ EXPECT_EQ(2, buf.use_count());
+ Buffer assigned;
+ assigned = std::move(moved);
+ EXPECT_EQ(2, buf.use_count());
+
+ // The split-off buffer is expected to hold a reference of its own.
+ Buffer split = assigned.splitAt(5);
+ EXPECT_EQ(3, buf.use_count());
+
+ // Shrinking the ranges is expected to leave the reference count alone.
+ split.advance(1);
+ assigned.subtract(1);
+ EXPECT_EQ(3, buf.use_count());
+ }
+ // Once every Buffer is gone only `buf` itself refers to the allocation.
+ EXPECT_EQ(1, buf.use_count());
+}
+
+// Exercises `splitAt()`, `advance()` and `subtract()` on a buffer whose bytes
+// hold the values 0 through 9.
+TEST(Buffer, Modifiers) {
+  Buffer buf(10);
+  unsigned char next = 0;
+  for (auto& byte : buf.range()) {
+    byte = next++;
+  }
+
+  // The first two bytes go to `head`; `buf` keeps the remaining eight.
+  auto head = buf.splitAt(2);
+  ASSERT_EQ(2, head.size());
+  EXPECT_EQ(0, *head.data());
+  ASSERT_EQ(8, buf.size());
+  EXPECT_EQ(2, *buf.data());
+
+  // Dropping two bytes from the front leaves the back untouched.
+  buf.advance(2);
+  EXPECT_EQ(4, *buf.data());
+  EXPECT_EQ(9, *(buf.range().end() - 1));
+
+  // Dropping two bytes from the back.
+  buf.subtract(2);
+  EXPECT_EQ(7, *(buf.range().end() - 1));
+  EXPECT_EQ(4, buf.size());
+}
--- /dev/null
+# ##########################################################################
+# Copyright (c) 2016-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant
+# of patent rights can be found in the PATENTS file in the same directory.
+# ##########################################################################
+
+# Include and library search paths for googletest. `?=` only assigns them when
+# they are not already set, so both can be overridden by the caller.
+GTEST_INC ?= -isystem googletest/googletest/include
+GTEST_LIB ?= -L googletest/build/googlemock/gtest
+
+# Define *.exe as extension for Windows systems
+ifneq (,$(filter Windows%,$(OS)))
+EXT =.exe
+else
+EXT =
+endif
+
+# Directory added to the include path; the tests include "utils/..." headers.
+PZSTDDIR = ../..
+
+# FLAGS collects everything passed to $(CXX) by the build rule below.
+# MOREFLAGS is appended to CFLAGS as an extra hook for the caller.
+CPPFLAGS = -I$(PZSTDDIR) $(GTEST_INC) $(GTEST_LIB)
+CFLAGS ?= -O3
+CFLAGS += -std=c++11
+CFLAGS += $(MOREFLAGS)
+FLAGS = $(CPPFLAGS) $(CFLAGS) $(LDFLAGS)
+
+# Builds a test binary from the source file of the same name.
+# The gtest libraries are listed after the source file: the linker processes
+# its inputs in command-line order, so a library named before the code that
+# uses it may not be searched for the symbols that code needs.
+%: %.cpp
+	$(CXX) $(FLAGS) $^ -o $@$(EXT) -lgtest -lgtest_main
+
+# `test` and `clean` do not name files.
+.PHONY: test clean
+
+# Builds every test binary, then runs them one after another.
+# The leading `@` keeps make from echoing each command.
+test: BufferTest RangeTest ScopeGuardTest ThreadPoolTest WorkQueueTest
+ @./BufferTest$(EXT)
+ @./RangeTest$(EXT)
+ @./ScopeGuardTest$(EXT)
+ @./ThreadPoolTest$(EXT)
+ @./WorkQueueTest$(EXT)
+
+# Removes the test binaries. The build rule writes them as <name>$(EXT), so
+# the same suffix has to be used here.
+clean:
+	@rm -f BufferTest$(EXT) RangeTest$(EXT) ScopeGuardTest$(EXT) ThreadPoolTest$(EXT) WorkQueueTest$(EXT)
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "utils/Range.h"
+
+#include <gtest/gtest.h>
+#include <string>
+
+using namespace pzstd;
+
+// Range is directly copied from folly.
+// Just some sanity tests to make sure everything seems to work.
+
+// Builds pieces over "hello" through each constructor and checks the size and
+// the first and last characters of every one.
+TEST(Range, Constructors) {
+  StringPiece nothing;
+  EXPECT_TRUE(nothing.empty());
+  EXPECT_EQ(0, nothing.size());
+
+  std::string text = "hello";
+
+  // From a pair of iterators.
+  {
+    Range<std::string::const_iterator> fromIterators(text.begin(), text.end());
+    EXPECT_EQ(5, fromIterators.size());
+    EXPECT_EQ('h', *fromIterators.data());
+    EXPECT_EQ('o', *(fromIterators.end() - 1));
+  }
+
+  // From a pointer and a length.
+  {
+    StringPiece fromPointer(text.data(), text.size());
+    EXPECT_EQ(5, fromPointer.size());
+    EXPECT_EQ('h', *fromPointer.data());
+    EXPECT_EQ('o', *(fromPointer.end() - 1));
+  }
+
+  // From a `std::string`.
+  {
+    StringPiece fromString(text);
+    EXPECT_EQ(5, fromString.size());
+    EXPECT_EQ('h', *fromString.data());
+    EXPECT_EQ('o', *(fromString.end() - 1));
+  }
+
+  // From a C string.
+  {
+    StringPiece fromCString(text.c_str());
+    EXPECT_EQ(5, fromCString.size());
+    EXPECT_EQ('h', *fromCString.data());
+    EXPECT_EQ('o', *(fromCString.end() - 1));
+  }
+}
+
+// Checks `subpiece()`, `subtract()` and `advance()` on "hello world", and that
+// modifying copies leaves the original piece unchanged.
+TEST(Range, Modifiers) {
+  StringPiece range("hello world");
+  ASSERT_EQ(11, range.size());
+
+  // First five characters via subpiece().
+  {
+    auto front = range.subpiece(0, 5);
+    EXPECT_EQ(5, front.size());
+    EXPECT_EQ('h', *front.data());
+    EXPECT_EQ('o', *(front.end() - 1));
+  }
+  // First five characters by removing six from the end of a copy.
+  {
+    auto front = range;
+    front.subtract(6);
+    EXPECT_EQ(5, front.size());
+    EXPECT_EQ('h', *front.data());
+    EXPECT_EQ('o', *(front.end() - 1));
+  }
+  // Last five characters by removing six from the start of a copy.
+  {
+    auto back = range;
+    back.advance(6);
+    EXPECT_EQ(5, back.size());
+    EXPECT_EQ('w', *back.data());
+    EXPECT_EQ('d', *(back.end() - 1));
+  }
+
+  std::string whole = "hello world";
+  EXPECT_EQ(whole, std::string(range.begin(), range.end()));
+  EXPECT_EQ(whole, std::string(range.data(), range.size()));
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "utils/ScopeGuard.h"
+
+#include <gtest/gtest.h>
+
+using namespace pzstd;
+
+// The callable fails the test, so it must not run once the guard is dismissed.
+TEST(ScopeGuard, Dismiss) {
+  {
+    auto dismissedGuard = makeScopeGuard([&] { EXPECT_TRUE(false); });
+    dismissedGuard.dismiss();
+  }
+}
+
+// A guard that is not dismissed runs its callable when it leaves scope.
+TEST(ScopeGuard, Executes) {
+  bool ran = false;
+  {
+    auto armedGuard = makeScopeGuard([&] { ran = true; });
+  }
+  EXPECT_TRUE(ran);
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "utils/ThreadPool.h"
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <thread>
+#include <vector>
+
+using namespace pzstd;
+
+// With a single worker, tasks must run in the order they were added.
+TEST(ThreadPool, Ordering) {
+  std::vector<int> results;
+
+  {
+    // Leaving this scope destroys the pool, which joins the worker.
+    ThreadPool executor(1);
+    for (int i = 0; i < 100; ++i) {
+      executor.add([ &results, i ] { results.push_back(i); });
+    }
+  }
+
+  // Stop here if a task was lost: indexing below would read out of bounds.
+  ASSERT_EQ(100u, results.size());
+  for (int i = 0; i < 100; ++i) {
+    EXPECT_EQ(i, results[i]);
+  }
+}
+
+// Every task added before the pool is destroyed must have run by the time the
+// destructor returns.
+TEST(ThreadPool, AllJobsFinished) {
+ std::atomic<unsigned> numFinished{0};
+ std::atomic<bool> start{false};
+ {
+ ThreadPool executor(5);
+ for (int i = 0; i < 1000; ++i) {
+ executor.add([ &numFinished, &start ] {
+ // No task completes until all 1000 have been added and `start` is set.
+ while (!start.load()) {
+ // spin
+ }
+ ++numFinished;
+ });
+ }
+ start.store(true);
+ }
+ EXPECT_EQ(1000, numFinished.load());
+}
+
+// A task added while the pool is being destroyed must never run.
+TEST(ThreadPool, AddJobWhileJoining) {
+ std::atomic<bool> done{false};
+ {
+ ThreadPool executor(1);
+ executor.add([&executor, &done] {
+ // Wait for `done`, which is set right before `executor` leaves scope.
+ while (!done.load()) {
+ std::this_thread::yield();
+ }
+ // Sleep for a second to be sure that we are joining
+ std::this_thread::sleep_for(std::chrono::seconds(1));
+ // The nested task fails the test if it is ever executed.
+ executor.add([] {
+ EXPECT_TRUE(false);
+ });
+ });
+ done.store(true);
+ }
+}
--- /dev/null
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "utils/Buffer.h"
+#include "utils/WorkQueue.h"
+
+#include <gtest/gtest.h>
+#include <mutex>
+#include <thread>
+#include <vector>
+
+using namespace pzstd;
+
+namespace {
+/**
+ * Consumer used by the multi-threaded tests. It pops values until the queue
+ * reports that it is drained and finished, and stores every popped value at
+ * its own index in `results` while holding `mutex`.
+ */
+struct Popper {
+  WorkQueue<int>* queue;
+  int* results;
+  std::mutex* mutex;
+
+  void operator()() {
+    for (int value; queue->pop(value);) {
+      std::lock_guard<std::mutex> guard(*mutex);
+      results[value] = value;
+    }
+  }
+};
+} // namespace
+
+// Pushes and pops on one thread: FIFO order, then draining after finish().
+TEST(WorkQueue, SingleThreaded) {
+  WorkQueue<int> queue;
+  int popped;
+
+  // A single item comes straight back out.
+  queue.push(5);
+  EXPECT_TRUE(queue.pop(popped));
+  EXPECT_EQ(5, popped);
+
+  // Two items come out in the order they went in.
+  queue.push(1);
+  queue.push(2);
+  EXPECT_TRUE(queue.pop(popped));
+  EXPECT_EQ(1, popped);
+  EXPECT_TRUE(queue.pop(popped));
+  EXPECT_EQ(2, popped);
+
+  // Items queued before finish() can still be popped; after that pop() fails.
+  queue.push(1);
+  queue.push(2);
+  queue.finish();
+  EXPECT_TRUE(queue.pop(popped));
+  EXPECT_EQ(1, popped);
+  EXPECT_TRUE(queue.pop(popped));
+  EXPECT_EQ(2, popped);
+  EXPECT_FALSE(queue.pop(popped));
+
+  // finish() was already called, so this must return.
+  queue.waitUntilFinished();
+}
+
+// Single producer, single consumer: the consumer must see 0..max-1 in order.
+TEST(WorkQueue, SPSC) {
+ WorkQueue<int> queue;
+ const int max = 100;
+
+ // The first ten items are queued before the consumer thread exists.
+ for (int i = 0; i < 10; ++i) {
+ queue.push(i);
+ }
+
+ std::thread thread([ &queue, max ] {
+ int result;
+ for (int i = 0;; ++i) {
+ // pop() only fails once the queue is finished and drained, at which
+ // point exactly `max` items must have been seen.
+ if (!queue.pop(result)) {
+ EXPECT_EQ(i, max);
+ break;
+ }
+ EXPECT_EQ(i, result);
+ }
+ });
+
+ std::this_thread::yield();
+ for (int i = 10; i < max; ++i) {
+ queue.push(i);
+ }
+ queue.finish();
+
+ thread.join();
+}
+
+// Single producer, 100 consumers: every pushed value must be popped.
+TEST(WorkQueue, SPMC) {
+ WorkQueue<int> queue;
+ // -1 marks a slot whose value has not been popped.
+ std::vector<int> results(10000, -1);
+ std::mutex mutex;
+ std::vector<std::thread> threads;
+ for (int i = 0; i < 100; ++i) {
+ threads.emplace_back(Popper{&queue, results.data(), &mutex});
+ }
+
+ for (int i = 0; i < 10000; ++i) {
+ queue.push(i);
+ }
+ queue.finish();
+
+ for (auto& thread : threads) {
+ thread.join();
+ }
+
+ // Each Popper stored every value it popped at that value's index.
+ for (int i = 0; i < 10000; ++i) {
+ EXPECT_EQ(i, results[i]);
+ }
+}
+
+// 10 producers and 100 consumers: every pushed value must be popped.
+TEST(WorkQueue, MPMC) {
+  WorkQueue<int> queue;
+  std::vector<int> results(10000, -1);
+  std::mutex mutex;
+
+  std::vector<std::thread> consumers;
+  for (int c = 0; c < 100; ++c) {
+    consumers.emplace_back(Popper{&queue, results.data(), &mutex});
+  }
+
+  // Each producer pushes its own block of 1000 consecutive values.
+  std::vector<std::thread> producers;
+  for (int p = 0; p < 10; ++p) {
+    auto first = p * 1000;
+    auto last = (p + 1) * 1000;
+    producers.emplace_back([ &queue, first, last ] {
+      for (int value = first; value < last; ++value) {
+        queue.push(value);
+      }
+    });
+  }
+
+  // finish() may only be called after the last push.
+  for (auto& producer : producers) {
+    producer.join();
+  }
+  queue.finish();
+
+  for (auto& consumer : consumers) {
+    consumer.join();
+  }
+
+  for (int i = 0; i < 10000; ++i) {
+    EXPECT_EQ(i, results[i]);
+  }
+}
+
+// `size()` must report the sum of the sizes of the buffers still queued.
+// finish() is called before every size() because size() blocks until then.
+TEST(BufferWorkQueue, SizeCalculatedCorrectly) {
+ {
+ // Empty queue.
+ BufferWorkQueue queue;
+ queue.finish();
+ EXPECT_EQ(0, queue.size());
+ }
+ {
+ // One buffer.
+ BufferWorkQueue queue;
+ queue.push(Buffer(10));
+ queue.finish();
+ EXPECT_EQ(10, queue.size());
+ }
+ {
+ // Two buffers.
+ BufferWorkQueue queue;
+ queue.push(Buffer(10));
+ queue.push(Buffer(5));
+ queue.finish();
+ EXPECT_EQ(15, queue.size());
+ }
+ {
+ // Popping the first buffer (10 bytes) leaves only the second one counted.
+ BufferWorkQueue queue;
+ queue.push(Buffer(10));
+ queue.push(Buffer(5));
+ queue.finish();
+ Buffer buffer;
+ queue.pop(buffer);
+ EXPECT_EQ(5, queue.size());
+ }
+}