]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
[huf] Reduce stack usage of HUF_readDTableX2 by ~460 bytes
authorNick Terrell <terrelln@fb.com>
Fri, 5 Mar 2021 20:36:50 +0000 (12:36 -0800)
committerNick Terrell <terrelln@fb.com>
Fri, 5 Mar 2021 20:39:46 +0000 (12:39 -0800)
* Use `HUF_readStats_wksp()`
* Use workspace in `HUF_fillDTableX2*()`
* Clean up workspace usage to use a workspace struct

lib/decompress/huf_decompress.c

index 7699c920374ad2d9e775b019e95ce8e84146d33d..4481484f7a7c7cf1f78f64b76f284deec9395ca4 100644 (file)
@@ -528,13 +528,15 @@ typedef rankValCol_t rankVal_t[HUF_TABLELOG_MAX];
 static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 sizeLog, const U32 consumed,
                            const U32* rankValOrigin, const int minWeight,
                            const sortedSymbol_t* sortedSymbols, const U32 sortedListSize,
-                           U32 nbBitsBaseline, U16 baseSeq)
+                           U32 nbBitsBaseline, U16 baseSeq, U32* wksp, size_t wkspSize)
 {
     HUF_DEltX2 DElt;
-    U32 rankVal[HUF_TABLELOG_MAX + 1];
+    U32* rankVal = wksp;
 
+    assert(wkspSize >= HUF_TABLELOG_MAX + 1);
+    (void)wkspSize;
     /* get pre-calculated rankVal */
-    ZSTD_memcpy(rankVal, rankValOrigin, sizeof(rankVal));
+    ZSTD_memcpy(rankVal, rankValOrigin, sizeof(U32) * (HUF_TABLELOG_MAX + 1));
 
     /* fill skipped values */
     if (minWeight>1) {
@@ -569,14 +571,18 @@ static void HUF_fillDTableX2Level2(HUF_DEltX2* DTable, U32 sizeLog, const U32 co
 static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog,
                            const sortedSymbol_t* sortedList, const U32 sortedListSize,
                            const U32* rankStart, rankVal_t rankValOrigin, const U32 maxWeight,
-                           const U32 nbBitsBaseline)
+                           const U32 nbBitsBaseline, U32* wksp, size_t wkspSize)
 {
-    U32 rankVal[HUF_TABLELOG_MAX + 1];
+    U32* rankVal = wksp;
     const int scaleLog = nbBitsBaseline - targetLog;   /* note : targetLog >= srcLog, hence scaleLog <= 1 */
     const U32 minBits  = nbBitsBaseline - maxWeight;
     U32 s;
 
-    ZSTD_memcpy(rankVal, rankValOrigin, sizeof(rankVal));
+    assert(wkspSize >= HUF_TABLELOG_MAX + 1);
+    wksp += HUF_TABLELOG_MAX + 1;
+    wkspSize -= HUF_TABLELOG_MAX + 1;
+
+    ZSTD_memcpy(rankVal, rankValOrigin, sizeof(U32) * (HUF_TABLELOG_MAX + 1));
 
     /* fill DTable */
     for (s=0; s<sortedListSize; s++) {
@@ -594,7 +600,7 @@ static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog,
             HUF_fillDTableX2Level2(DTable+start, targetLog-nbBits, nbBits,
                            rankValOrigin[nbBits], minWeight,
                            sortedList+sortedRank, sortedListSize-sortedRank,
-                           nbBitsBaseline, symbol);
+                           nbBitsBaseline, symbol, wksp, wkspSize);
         } else {
             HUF_DEltX2 DElt;
             MEM_writeLE16(&(DElt.sequence), symbol);
@@ -608,6 +614,15 @@ static void HUF_fillDTableX2(HUF_DEltX2* DTable, const U32 targetLog,
     }
 }
 
+typedef struct {
+    rankValCol_t rankVal[HUF_TABLELOG_MAX];
+    U32 rankStats[HUF_TABLELOG_MAX + 1];
+    U32 rankStart0[HUF_TABLELOG_MAX + 2];
+    sortedSymbol_t sortedSymbol[HUF_SYMBOLVALUE_MAX + 1];
+    BYTE weightList[HUF_SYMBOLVALUE_MAX + 1];
+    U32 wksp[HUF_READ_STATS_WORKSPACE_SIZE_U32];
+} HUF_ReadDTableX2_Workspace;
+
 size_t HUF_readDTableX2_wksp(HUF_DTable* DTable,
                        const void* src, size_t srcSize,
                              void* workSpace, size_t wkspSize)
@@ -620,47 +635,32 @@ size_t HUF_readDTableX2_wksp(HUF_DTable* DTable,
     HUF_DEltX2* const dt = (HUF_DEltX2*)dtPtr;
     U32 *rankStart;
 
-    rankValCol_t* rankVal;
-    U32* rankStats;
-    U32* rankStart0;
-    sortedSymbol_t* sortedSymbol;
-    BYTE* weightList;
-    size_t spaceUsed32 = 0;
-
-    rankVal = (rankValCol_t *)((U32 *)workSpace + spaceUsed32);
-    spaceUsed32 += (sizeof(rankValCol_t) * HUF_TABLELOG_MAX) >> 2;
-    rankStats = (U32 *)workSpace + spaceUsed32;
-    spaceUsed32 += HUF_TABLELOG_MAX + 1;
-    rankStart0 = (U32 *)workSpace + spaceUsed32;
-    spaceUsed32 += HUF_TABLELOG_MAX + 2;
-    sortedSymbol = (sortedSymbol_t *)workSpace + (spaceUsed32 * sizeof(U32)) / sizeof(sortedSymbol_t);
-    spaceUsed32 += HUF_ALIGN(sizeof(sortedSymbol_t) * (HUF_SYMBOLVALUE_MAX + 1), sizeof(U32)) >> 2;
-    weightList = (BYTE *)((U32 *)workSpace + spaceUsed32);
-    spaceUsed32 += HUF_ALIGN(HUF_SYMBOLVALUE_MAX + 1, sizeof(U32)) >> 2;
-
-    if ((spaceUsed32 << 2) > wkspSize) return ERROR(tableLog_tooLarge);
-
-    rankStart = rankStart0 + 1;
-    ZSTD_memset(rankStats, 0, sizeof(U32) * (2 * HUF_TABLELOG_MAX + 2 + 1));
+    HUF_ReadDTableX2_Workspace* wksp = (HUF_ReadDTableX2_Workspace*)workSpace;
+
+    if (sizeof(*wksp) > wkspSize) return ERROR(tableLog_tooLarge);
+
+    rankStart = wksp->rankStart0 + 1;
+    ZSTD_memset(wksp->rankStats, 0, sizeof(wksp->rankStats));
+    ZSTD_memset(wksp->rankStart0, 0, sizeof(wksp->rankStart0));
 
     DEBUG_STATIC_ASSERT(sizeof(HUF_DEltX2) == sizeof(HUF_DTable));   /* if compiler fails here, assertion is wrong */
     if (maxTableLog > HUF_TABLELOG_MAX) return ERROR(tableLog_tooLarge);
     /* ZSTD_memset(weightList, 0, sizeof(weightList)); */  /* is not necessary, even though some analyzer complain ... */
 
-    iSize = HUF_readStats(weightList, HUF_SYMBOLVALUE_MAX + 1, rankStats, &nbSymbols, &tableLog, src, srcSize);
+    iSize = HUF_readStats_wksp(wksp->weightList, HUF_SYMBOLVALUE_MAX + 1, wksp->rankStats, &nbSymbols, &tableLog, src, srcSize, wksp->wksp, sizeof(wksp->wksp), /* bmi2 */ 0);
     if (HUF_isError(iSize)) return iSize;
 
     /* check result */
     if (tableLog > maxTableLog) return ERROR(tableLog_tooLarge);   /* DTable can't fit code depth */
 
     /* find maxWeight */
-    for (maxW = tableLog; rankStats[maxW]==0; maxW--) {}  /* necessarily finds a solution before 0 */
+    for (maxW = tableLog; wksp->rankStats[maxW]==0; maxW--) {}  /* necessarily finds a solution before 0 */
 
     /* Get start index of each weight */
     {   U32 w, nextRankStart = 0;
         for (w=1; w<maxW+1; w++) {
             U32 curr = nextRankStart;
-            nextRankStart += rankStats[w];
+            nextRankStart += wksp->rankStats[w];
             rankStart[w] = curr;
         }
         rankStart[0] = nextRankStart;   /* put all 0w symbols at the end of sorted list*/
@@ -670,37 +670,38 @@ size_t HUF_readDTableX2_wksp(HUF_DTable* DTable,
     /* sort symbols by weight */
     {   U32 s;
         for (s=0; s<nbSymbols; s++) {
-            U32 const w = weightList[s];
+            U32 const w = wksp->weightList[s];
             U32 const r = rankStart[w]++;
-            sortedSymbol[r].symbol = (BYTE)s;
-            sortedSymbol[r].weight = (BYTE)w;
+            wksp->sortedSymbol[r].symbol = (BYTE)s;
+            wksp->sortedSymbol[r].weight = (BYTE)w;
         }
         rankStart[0] = 0;   /* forget 0w symbols; this is beginning of weight(1) */
     }
 
     /* Build rankVal */
-    {   U32* const rankVal0 = rankVal[0];
+    {   U32* const rankVal0 = wksp->rankVal[0];
         {   int const rescale = (maxTableLog-tableLog) - 1;   /* tableLog <= maxTableLog */
             U32 nextRankVal = 0;
             U32 w;
             for (w=1; w<maxW+1; w++) {
                 U32 curr = nextRankVal;
-                nextRankVal += rankStats[w] << (w+rescale);
+                nextRankVal += wksp->rankStats[w] << (w+rescale);
                 rankVal0[w] = curr;
         }   }
         {   U32 const minBits = tableLog+1 - maxW;
             U32 consumed;
             for (consumed = minBits; consumed < maxTableLog - minBits + 1; consumed++) {
-                U32* const rankValPtr = rankVal[consumed];
+                U32* const rankValPtr = wksp->rankVal[consumed];
                 U32 w;
                 for (w = 1; w < maxW+1; w++) {
                     rankValPtr[w] = rankVal0[w] >> consumed;
     }   }   }   }
 
     HUF_fillDTableX2(dt, maxTableLog,
-                   sortedSymbol, sizeOfSort,
-                   rankStart0, rankVal, maxW,
-                   tableLog+1);
+                   wksp->sortedSymbol, sizeOfSort,
+                   wksp->rankStart0, wksp->rankVal, maxW,
+                   tableLog+1,
+                   wksp->wksp, sizeof(wksp->wksp) / sizeof(U32));
 
     dtd.tableLog = (BYTE)maxTableLog;
     dtd.tableType = 1;
@@ -1225,7 +1226,7 @@ size_t HUF_decompress1X1 (void* dst, size_t dstSize, const void* cSrc, size_t cS
     HUF_CREATE_STATIC_DTABLEX1(DTable, HUF_TABLELOG_MAX);
     return HUF_decompress1X1_DCtx (DTable, dst, dstSize, cSrc, cSrcSize);
 }
-#endif 
+#endif
 
 #ifndef HUF_FORCE_DECOMPRESS_X1
 size_t HUF_readDTableX2(HUF_DTable* DTable, const void* src, size_t srcSize)