]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
Make offload API compatible with static CCtx (#3854)
authorElliot Gorokhovsky <embg@fb.com>
Thu, 28 Dec 2023 19:48:46 +0000 (14:48 -0500)
committerGitHub <noreply@github.com>
Thu, 28 Dec 2023 19:48:46 +0000 (14:48 -0500)
* Add ZSTD_CCtxParams_registerSequenceProducer() to public API

* add unit test

* add docs to zstd.h

* nits

* Add ZSTDLIB_STATIC_API prefix

* Add asserts

lib/compress/zstd_compress.c
lib/zstd.h
tests/zstreamtest.c

index cdd763ff6cfa53b91d2081b8629232ef32f505f5..55415c7e3fbc1958e3cbb88d6067a8a781824b5f 100644 (file)
@@ -7084,14 +7084,27 @@ ZSTD_parameters ZSTD_getParams(int compressionLevel, unsigned long long srcSizeH
 }
 
 void ZSTD_registerSequenceProducer(
-    ZSTD_CCtx* zc, void* extSeqProdState,
+    ZSTD_CCtx* zc,
+    void* extSeqProdState,
     ZSTD_sequenceProducer_F extSeqProdFunc
 ) {
+    assert(zc != NULL);
+    ZSTD_CCtxParams_registerSequenceProducer(
+        &zc->requestedParams, extSeqProdState, extSeqProdFunc
+    );
+}
+
+void ZSTD_CCtxParams_registerSequenceProducer(
+  ZSTD_CCtx_params* params,
+  void* extSeqProdState,
+  ZSTD_sequenceProducer_F extSeqProdFunc
+) {
+    assert(params != NULL);
     if (extSeqProdFunc != NULL) {
-        zc->requestedParams.extSeqProdFunc = extSeqProdFunc;
-        zc->requestedParams.extSeqProdState = extSeqProdState;
+        params->extSeqProdFunc = extSeqProdFunc;
+        params->extSeqProdState = extSeqProdState;
     } else {
-        zc->requestedParams.extSeqProdFunc = NULL;
-        zc->requestedParams.extSeqProdState = NULL;
+        params->extSeqProdFunc = NULL;
+        params->extSeqProdState = NULL;
     }
 }
index 61f81db0f25e82e54ac174298ba072967b9e2c18..841269305205de57af498937cea65439f219a439 100644 (file)
@@ -1665,9 +1665,6 @@ ZSTDLIB_API unsigned ZSTD_isSkippableFrame(const void* buffer, size_t size);
  *
  *  Note : only single-threaded compression is supported.
  *  ZSTD_estimateCCtxSize_usingCCtxParams() will return an error code if ZSTD_c_nbWorkers is >= 1.
- *
- *  Note 2 : ZSTD_estimateCCtxSize* functions are not compatible with the Block-Level Sequence Producer API at this time.
- *  Size estimates assume that no external sequence producer is registered.
  */
 ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize(int maxCompressionLevel);
 ZSTDLIB_STATIC_API size_t ZSTD_estimateCCtxSize_usingCParams(ZSTD_compressionParameters cParams);
@@ -2824,6 +2821,22 @@ ZSTD_registerSequenceProducer(
   ZSTD_sequenceProducer_F sequenceProducer
 );
 
+/*! ZSTD_CCtxParams_registerSequenceProducer() :
+ * Same as ZSTD_registerSequenceProducer(), but operates on ZSTD_CCtx_params.
+ * This is used for accurate size estimation with ZSTD_estimateCCtxSize_usingCCtxParams(),
+ * which is needed when creating a ZSTD_CCtx with ZSTD_initStaticCCtx().
+ *
+ * If you are using the external sequence producer API in a scenario where ZSTD_initStaticCCtx()
+ * is required, then this function is for you. Otherwise, you probably don't need it.
+ *
+ * See tests/zstreamtest.c for example usage. */
+ZSTDLIB_STATIC_API void
+ZSTD_CCtxParams_registerSequenceProducer(
+  ZSTD_CCtx_params* params,
+  void* sequenceProducerState,
+  ZSTD_sequenceProducer_F sequenceProducer
+);
+
 
 /*********************************************************************
 *  Buffer-less and synchronous inner streaming functions (DEPRECATED)
index 04f1f8b0e9c9801f4b022d4e0feeafea89503431..82aaf3db50c6eaf940cc1eaeba41c0375bc98b82 100644 (file)
@@ -1920,7 +1920,7 @@ static int basicUnitTests(U32 seed, double compressibility, int bigTests)
     DISPLAYLEVEL(3, "test%3i : Block-Level External Sequence Producer API: ", testNb++);
     {
         size_t const dstBufSize = ZSTD_compressBound(CNBufferSize);
-        BYTE* const dstBuf = (BYTE*)malloc(ZSTD_compressBound(dstBufSize));
+        BYTE* const dstBuf = (BYTE*)malloc(dstBufSize);
         size_t const checkBufSize = CNBufferSize;
         BYTE* const checkBuf = (BYTE*)malloc(checkBufSize);
         int enableFallback;
@@ -2356,6 +2356,58 @@ static int basicUnitTests(U32 seed, double compressibility, int bigTests)
     }
     DISPLAYLEVEL(3, "OK \n");
 
+    DISPLAYLEVEL(3, "test%3i : Testing external sequence producer with static CCtx: ", testNb++);
+    {
+        size_t const dstBufSize = ZSTD_compressBound(CNBufferSize);
+        BYTE* const dstBuf = (BYTE*)malloc(dstBufSize);
+        size_t const checkBufSize = CNBufferSize;
+        BYTE* const checkBuf = (BYTE*)malloc(checkBufSize);
+        ZSTD_CCtx_params* params = ZSTD_createCCtxParams();
+        ZSTD_CCtx* staticCCtx;
+        void* cctxBuf;
+        EMF_testCase seqProdState;
+
+        CHECK_Z(ZSTD_CCtxParams_setParameter(params, ZSTD_c_validateSequences, 1));
+        CHECK_Z(ZSTD_CCtxParams_setParameter(params, ZSTD_c_enableSeqProducerFallback, 0));
+        ZSTD_CCtxParams_registerSequenceProducer(params, &seqProdState, zstreamSequenceProducer);
+
+        {
+            size_t const cctxSize = ZSTD_estimateCCtxSize_usingCCtxParams(params);
+            cctxBuf = malloc(cctxSize);
+            staticCCtx = ZSTD_initStaticCCtx(cctxBuf, cctxSize);
+            ZSTD_CCtx_setParametersUsingCCtxParams(staticCCtx, params);
+        }
+
+        // Check that compression with external sequence producer succeeds when expected
+        seqProdState = EMF_LOTS_OF_SEQS;
+        {
+            size_t dResult;
+            size_t const cResult = ZSTD_compress2(staticCCtx, dstBuf, dstBufSize, CNBuffer, CNBufferSize);
+            CHECK(ZSTD_isError(cResult), "EMF: Compression error: %s", ZSTD_getErrorName(cResult));
+            dResult = ZSTD_decompress(checkBuf, checkBufSize, dstBuf, cResult);
+            CHECK(ZSTD_isError(dResult), "EMF: Decompression error: %s", ZSTD_getErrorName(dResult));
+            CHECK(dResult != CNBufferSize, "EMF: Corruption!");
+            CHECK(memcmp(CNBuffer, checkBuf, CNBufferSize) != 0, "EMF: Corruption!");
+        }
+
+        // Check that compression with external sequence producer fails when expected
+        seqProdState = EMF_BIG_ERROR;
+        {
+            size_t const cResult = ZSTD_compress2(staticCCtx, dstBuf, dstBufSize, CNBuffer, CNBufferSize);
+            CHECK(!ZSTD_isError(cResult), "EMF: Should have raised an error!");
+            CHECK(
+                ZSTD_getErrorCode(cResult) != ZSTD_error_sequenceProducer_failed,
+                "EMF: Wrong error code: %s", ZSTD_getErrorName(cResult)
+            );
+        }
+
+        free(dstBuf);
+        free(checkBuf);
+        free(cctxBuf);
+        ZSTD_freeCCtxParams(params);
+    }
+    DISPLAYLEVEL(3, "OK \n");
+
 _end:
     FUZ_freeDictionary(dictionary);
     ZSTD_freeCStream(zc);