]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[pzstd] Reuse ZSTD_{C,D}Stream
authorNick Terrell <terrelln@fb.com>
Thu, 13 Oct 2016 00:23:38 +0000 (17:23 -0700)
committerNick Terrell <terrelln@fb.com>
Fri, 14 Oct 2016 22:26:55 +0000 (15:26 -0700)
contrib/pzstd/Makefile
contrib/pzstd/Pzstd.cpp
contrib/pzstd/Pzstd.h
contrib/pzstd/utils/ResourcePool.h [new file with mode: 0644]
contrib/pzstd/utils/test/ResourcePoolTest.cpp [new file with mode: 0644]

index b998c1fff68dc8f872845cb3eef0f4d4fccb7fec..4f63887d324560f3689c59ed885063e06c8c8933 100644 (file)
@@ -87,6 +87,7 @@ default: all
 check:
        $(TESTPROG) ./utils/test/BufferTest$(EXT) $(TESTFLAGS)
        $(TESTPROG) ./utils/test/RangeTest$(EXT) $(TESTFLAGS)
+       $(TESTPROG) ./utils/test/ResourcePoolTest$(EXT) $(TESTFLAGS)
        $(TESTPROG) ./utils/test/ScopeGuardTest$(EXT) $(TESTFLAGS)
        $(TESTPROG) ./utils/test/ThreadPoolTest$(EXT) $(TESTFLAGS)
        $(TESTPROG) ./utils/test/WorkQueueTest$(EXT) $(TESTFLAGS)
index 70c0515b3854da0e45c8a017fb431b9160a9c4ef..6b3d27b3eb94b1c202e7cb45bddce9b9c8a3a79f 100644 (file)
@@ -173,9 +173,9 @@ static FILE *openOutputFile(const Options &options,
 
 int pzstdMain(const Options &options) {
   int returnCode = 0;
+  SharedState state(options.decompress, options.determineParameters());
   for (const auto& input : options.inputFiles) {
-    // Setup the error holder
-    SharedState state;
+    // Setup the shared state
     auto printErrorGuard = makeScopeGuard([&] {
       if (state.errorHolder.hasError()) {
         returnCode = 1;
@@ -271,24 +271,21 @@ Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) {
  * @param in           Queue that we `pop()` input buffers from
  * @param out          Queue that we `push()` compressed output buffers to
  * @param maxInputSize An upper bound on the size of the input
- * @param parameters   The zstd parameters to use for compression
  */
 static void compress(
     SharedState& state,
     std::shared_ptr<BufferWorkQueue> in,
     std::shared_ptr<BufferWorkQueue> out,
-    size_t maxInputSize,
-    ZSTD_parameters parameters) {
+    size_t maxInputSize) {
   auto& errorHolder = state.errorHolder;
   auto guard = makeScopeGuard([&] { out->finish(); });
   // Initialize the CCtx
-  std::unique_ptr<ZSTD_CStream, size_t (*)(ZSTD_CStream*)> ctx(
-      ZSTD_createCStream(), ZSTD_freeCStream);
+  auto ctx = state.cStreamPool->get();
   if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
     return;
   }
   {
-    auto err = ZSTD_initCStream_advanced(ctx.get(), nullptr, 0, parameters, 0);
+    auto err = ZSTD_resetCStream(ctx.get(), 0);
     if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
       return;
     }
@@ -416,9 +413,9 @@ std::uint64_t asyncCompressChunks(
     // Make a new output queue that compress will put the compressed data into.
     auto out = std::make_shared<BufferWorkQueue>();
     // Start compression in the thread pool
-    executor.add([&state, in, out, step, params] {
+    executor.add([&state, in, out, step] {
       return compress(
-          state, std::move(in), std::move(out), step, params);
+          state, std::move(in), std::move(out), step);
     });
     // Pass the output queue to the writer thread.
     chunks.push(std::move(out));
@@ -445,13 +442,12 @@ static void decompress(
   auto& errorHolder = state.errorHolder;
   auto guard = makeScopeGuard([&] { out->finish(); });
   // Initialize the DCtx
-  std::unique_ptr<ZSTD_DStream, size_t (*)(ZSTD_DStream*)> ctx(
-      ZSTD_createDStream(), ZSTD_freeDStream);
+  auto ctx = state.dStreamPool->get();
   if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
     return;
   }
   {
-    auto err = ZSTD_initDStream(ctx.get());
+    auto err = ZSTD_resetDStream(ctx.get());
     if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
       return;
     }
index 469c20cd45e15574aff7bdd6c0bf8a88fffe1bf4..b02fe7b19bb3a9a8ffb822814a5d4987d8f88379 100644 (file)
@@ -35,7 +35,46 @@ int pzstdMain(const Options& options);
 
 class SharedState {
  public:
+  SharedState(bool decompress, ZSTD_parameters parameters) {
+    if (!decompress) {
+      cStreamPool.reset(new ResourcePool<ZSTD_CStream>{
+          [parameters]() -> ZSTD_CStream* {
+            auto zcs = ZSTD_createCStream();
+            if (zcs) {
+              auto err = ZSTD_initCStream_advanced(
+                  zcs, nullptr, 0, parameters, 0);
+              if (ZSTD_isError(err)) {
+                ZSTD_freeCStream(zcs);
+                return nullptr;
+              }
+            }
+            return zcs;
+          },
+          [](ZSTD_CStream *zcs) {
+            ZSTD_freeCStream(zcs);
+          }});
+    } else {
+      dStreamPool.reset(new ResourcePool<ZSTD_DStream>{
+          []() -> ZSTD_DStream* {
+            auto zds = ZSTD_createDStream();
+            if (zds) {
+              auto err = ZSTD_initDStream(zds);
+              if (ZSTD_isError(err)) {
+                ZSTD_freeDStream(zds);
+                return nullptr;
+              }
+            }
+            return zds;
+          },
+          [](ZSTD_DStream *zds) {
+            ZSTD_freeDStream(zds);
+          }});
+    }
+  }
+
   ErrorHolder errorHolder;
+  std::unique_ptr<ResourcePool<ZSTD_CStream>> cStreamPool;
+  std::unique_ptr<ResourcePool<ZSTD_DStream>> dStreamPool;
 };
 
 /**
diff --git a/contrib/pzstd/utils/ResourcePool.h b/contrib/pzstd/utils/ResourcePool.h
new file mode 100644 (file)
index 0000000..ed01130
--- /dev/null
@@ -0,0 +1,96 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#pragma once
+
+#include <cassert>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+namespace pzstd {
+
+/**
+ * An unbounded pool of resources.
+ * A `ResourcePool<T>` requires a factory function that takes allocates `T*` and
+ * a free function that frees a `T*`.
+ * Calling `ResourcePool::get()` will give you a new `ResourcePool::UniquePtr`
+ * to a `T`, and when it goes out of scope the resource will be returned to the
+ * pool.
+ * The `ResourcePool<T>` *must* survive longer than any resources it hands out.
+ * Remember that `ResourcePool<T>` hands out mutable `T`s, so make sure to clean
+ * up the resource before or after every use.
+ */
+template <typename T>
+class ResourcePool {
+ public:
+  class Deleter;
+  using Factory = std::function<T*()>;
+  using Free = std::function<void(T*)>;
+  using UniquePtr = std::unique_ptr<T, Deleter>;
+
+ private:
+  std::mutex mutex_;
+  Factory factory_;
+  Free free_;
+  std::vector<T*> resources_;
+  unsigned inUse_;
+
+ public:
+  /**
+   * Creates a `ResourcePool`.
+   *
+   * @param factory  The function to use to create new resources.
+   * @param free     The function to use to free resources created by `factory`.
+   */
+  ResourcePool(Factory factory, Free free)
+      : factory_(std::move(factory)), free_(std::move(free)), inUse_(0) {}
+
+  /**
+   * @returns  A unique pointer to a resource.  The resource is null iff
+   *           there are no avaiable resources and `factory()` returns null.
+   */
+  UniquePtr get() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (!resources_.empty()) {
+      UniquePtr resource{resources_.back(), Deleter{*this}};
+      resources_.pop_back();
+      ++inUse_;
+      return resource;
+    }
+    UniquePtr resource{factory_(), Deleter{*this}};
+    ++inUse_;
+    return resource;
+  }
+
+  ~ResourcePool() noexcept {
+    assert(inUse_ == 0);
+    for (const auto resource : resources_) {
+      free_(resource);
+    }
+  }
+
+  class Deleter {
+    ResourcePool *pool_;
+  public:
+    explicit Deleter(ResourcePool &pool) : pool_(&pool) {}
+
+    void operator() (T *resource) {
+      std::lock_guard<std::mutex> lock(pool_->mutex_);
+      // Make sure we don't put null resources into the pool
+      if (resource) {
+        pool_->resources_.push_back(resource);
+      }
+      assert(pool_->inUse_ > 0);
+      --pool_->inUse_;
+    }
+  };
+};
+
+}
diff --git a/contrib/pzstd/utils/test/ResourcePoolTest.cpp b/contrib/pzstd/utils/test/ResourcePoolTest.cpp
new file mode 100644 (file)
index 0000000..a6a86b3
--- /dev/null
@@ -0,0 +1,72 @@
+/**
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree. An additional grant
+ * of patent rights can be found in the PATENTS file in the same directory.
+ */
+#include "utils/ResourcePool.h"
+
+#include <gtest/gtest.h>
+#include <atomic>
+#include <thread>
+
+using namespace pzstd;
+
+TEST(ResourcePool, FullTest) {
+  unsigned numCreated = 0;
+  unsigned numDeleted = 0;
+  {
+    ResourcePool<int> pool(
+      [&numCreated] { ++numCreated; return new int{5}; },
+      [&numDeleted](int *x) { ++numDeleted; delete x; });
+
+    {
+      auto i = pool.get();
+      EXPECT_EQ(5, *i);
+      *i = 6;
+    }
+    {
+      auto i = pool.get();
+      EXPECT_EQ(6, *i);
+      auto j = pool.get();
+      EXPECT_EQ(5, *j);
+      *j = 7;
+    }
+    {
+      auto i = pool.get();
+      EXPECT_EQ(6, *i);
+      auto j = pool.get();
+      EXPECT_EQ(7, *j);
+    }
+  }
+  EXPECT_EQ(2, numCreated);
+  EXPECT_EQ(numCreated, numDeleted);
+}
+
+TEST(ResourcePool, ThreadSafe) {
+  std::atomic<unsigned> numCreated{0};
+  std::atomic<unsigned> numDeleted{0};
+  {
+    ResourcePool<int> pool(
+      [&numCreated] { ++numCreated; return new int{0}; },
+      [&numDeleted](int *x) { ++numDeleted; delete x; });
+    auto push = [&pool] {
+      for (int i = 0; i < 100; ++i) {
+        auto x = pool.get();
+        ++*x;
+      }
+    };
+    std::thread t1{push};
+    std::thread t2{push};
+    t1.join();
+    t2.join();
+
+    auto x = pool.get();
+    auto y = pool.get();
+    EXPECT_EQ(200, *x + *y);
+  }
+  EXPECT_GE(2, numCreated);
+  EXPECT_EQ(numCreated, numDeleted);
+}