]> git.ipfire.org Git - thirdparty/zstd.git/commitdiff
zlibWrapper: improved memory deallocation in case of error
authorinikep <inikep@gmail.com>
Mon, 13 Jun 2016 10:00:46 +0000 (12:00 +0200)
committerinikep <inikep@gmail.com>
Mon, 13 Jun 2016 10:00:46 +0000 (12:00 +0200)
zlibWrapper/zstd_zlibwrapper.c

index c2efe380e8375a2ae2c8c1aa8cc33a60fd1ae962..06667b61088dc1cb053369930f2895c592b88f3a 100644 (file)
@@ -344,10 +344,10 @@ ZEXTERN int ZEXPORT z_inflateInit_ OF((z_streamp strm,
 {
     ZWRAP_DCtx* zwd = ZWRAP_createDCtx(strm);
     LOG_WRAPPER("- inflateInit\n");
-    if (zwd == NULL) return Z_MEM_ERROR;
+    if (zwd == NULL) { strm->state = NULL; return Z_MEM_ERROR; }
 
     zwd->version = zwd->customMem.customAlloc(zwd->customMem.opaque, strlen(version) + 1);
-    if (zwd->version == NULL) { ZWRAP_freeDCtx(zwd); return Z_MEM_ERROR; }
+    if (zwd->version == NULL) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return Z_MEM_ERROR; }
     strcpy(zwd->version, version);
 
     zwd->stream_size = stream_size;
@@ -372,8 +372,6 @@ ZEXTERN int ZEXPORT z_inflateInit2_ OF((z_streamp strm, int  windowBits,
 }
 
 
-
-
 ZEXTERN int ZEXPORT z_inflateSetDictionary OF((z_streamp strm,
                                              const Bytef *dictionary,
                                              uInt  dictLength))
@@ -382,9 +380,11 @@ ZEXTERN int ZEXPORT z_inflateSetDictionary OF((z_streamp strm,
         return inflateSetDictionary(strm, dictionary, dictLength);
 
     LOG_WRAPPER("- inflateSetDictionary\n");
-    {   ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state;
-        size_t errorCode = ZBUFF_decompressInitDictionary(zwd->zbd, dictionary, dictLength);
-        if (ZSTD_isError(errorCode)) return Z_MEM_ERROR; 
+    {   size_t errorCode;
+        ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state;
+        if (strm->state == NULL) return Z_MEM_ERROR;
+        errorCode = ZBUFF_decompressInitDictionary(zwd->zbd, dictionary, dictLength);
+        if (ZSTD_isError(errorCode)) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return Z_MEM_ERROR; }
         
         if (strm->total_in == ZSTD_frameHeaderSize_min) {
             size_t dstCapacity = 0;
@@ -393,6 +393,7 @@ ZEXTERN int ZEXPORT z_inflateSetDictionary OF((z_streamp strm,
             LOG_WRAPPER("ZBUFF_decompressContinue3 errorCode=%d srcSize=%d dstCapacity=%d\n", (int)errorCode, (int)srcSize, (int)dstCapacity);
             if (dstCapacity > 0 || ZSTD_isError(errorCode)) {
                 LOG_WRAPPER("ERROR: ZBUFF_decompressContinue %s\n", ZSTD_getErrorName(errorCode));
+                ZWRAP_freeDCtx(zwd); strm->state = NULL;
                 return Z_MEM_ERROR;
             }
         }
@@ -410,6 +411,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush))
     if (strm->avail_in > 0) {
         size_t errorCode, dstCapacity, srcSize;
         ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state;
+        if (strm->state == NULL) return Z_MEM_ERROR;
         LOG_WRAPPER("inflate avail_in=%d avail_out=%d total_in=%d total_out=%d\n", (int)strm->avail_in, (int)strm->avail_out, (int)strm->total_in, (int)strm->total_out);
         if (strm->total_in < ZWRAP_HEADERSIZE)
         {
@@ -432,7 +434,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush))
                 else
                     errorCode = inflateInit_(strm, zwd->version, zwd->stream_size);
                 LOG_WRAPPER("ZLIB inflateInit errorCode=%d\n", (int)errorCode);
-                if (errorCode != Z_OK) return errorCode;
+                if (errorCode != Z_OK) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return errorCode; }
 
                 /* inflate header */
                 strm->next_in = (unsigned char*)zwd->headerBuf;
@@ -440,8 +442,8 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush))
                 strm->avail_out = 0;
                 errorCode = inflate(strm, Z_NO_FLUSH);
                 LOG_WRAPPER("ZLIB inflate errorCode=%d strm->avail_in=%d\n", (int)errorCode, (int)strm->avail_in);
-                if (errorCode != Z_OK) return errorCode;
-                if (strm->avail_in > 0) return Z_MEM_ERROR;
+                if (errorCode != Z_OK) { ZWRAP_freeDCtx(zwd); strm->state = NULL; return errorCode; }
+                if (strm->avail_in > 0) goto error;
                 
                 strm->next_in = strm2.next_in;
                 strm->avail_in = strm2.avail_in;
@@ -450,17 +452,17 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush))
 
                 strm->reserved = 0; /* mark as zlib stream */
                 errorCode = ZWRAP_freeDCtx(zwd);
-                if (ZSTD_isError(errorCode)) return Z_MEM_ERROR;
+                if (ZSTD_isError(errorCode)) goto error;
 
                 if (flush == Z_INFLATE_SYNC) return inflateSync(strm);
                 return inflate(strm, flush);
             }
 
             zwd->zbd = ZBUFF_createDCtx_advanced(zwd->customMem);
-            if (zwd->zbd == NULL) { ZWRAP_freeDCtx(zwd); return Z_MEM_ERROR; }
+            if (zwd->zbd == NULL) goto error;
 
             errorCode = ZBUFF_decompressInit(zwd->zbd);
-            if (ZSTD_isError(errorCode)) return Z_MEM_ERROR;
+            if (ZSTD_isError(errorCode)) goto error;
 
             srcSize = ZWRAP_HEADERSIZE;
             dstCapacity = 0;
@@ -468,7 +470,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush))
             LOG_WRAPPER("ZBUFF_decompressContinue1 errorCode=%d srcSize=%d dstCapacity=%d\n", (int)errorCode, (int)srcSize, (int)dstCapacity);
             if (ZSTD_isError(errorCode)) {
                 LOG_WRAPPER("ERROR: ZBUFF_decompressContinue %s\n", ZSTD_getErrorName(errorCode));
-                return Z_MEM_ERROR;
+                goto error;
             }
             if (strm->avail_in == 0) return Z_OK;
         }
@@ -480,7 +482,7 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush))
         if (ZSTD_isError(errorCode)) {
             LOG_WRAPPER("ERROR: ZBUFF_decompressContinue %s\n", ZSTD_getErrorName(errorCode));
             zwd->errorCount++;
-            return (zwd->errorCount<=1) ? Z_NEED_DICT : Z_MEM_ERROR;
+            if (zwd->errorCount<=1) return Z_NEED_DICT; else goto error;
         }
         strm->next_out += dstCapacity;
         strm->total_out += dstCapacity;
@@ -489,6 +491,11 @@ ZEXTERN int ZEXPORT z_inflate OF((z_streamp strm, int flush))
         strm->next_in += srcSize;
         strm->avail_in -= srcSize;
         if (errorCode == 0) return Z_STREAM_END;
+        return Z_OK;
+error:
+        ZWRAP_freeDCtx(zwd); 
+        strm->state = NULL;
+        return Z_MEM_ERROR;
     }
     return Z_OK;
 }
@@ -503,6 +510,7 @@ ZEXTERN int ZEXPORT z_inflateEnd OF((z_streamp strm))
     LOG_WRAPPER("- inflateEnd total_in=%d total_out=%d\n", (int)(strm->total_in), (int)(strm->total_out));
     {   ZWRAP_DCtx* zwd = (ZWRAP_DCtx*) strm->state;
         size_t const errorCode = ZWRAP_freeDCtx(zwd);
+        strm->state = NULL;
         if (ZSTD_isError(errorCode)) return Z_MEM_ERROR;
     }
     return ret;