]> git.ipfire.org Git - people/ms/suricata.git/blobdiff - src/tests/fuzz/fuzz_applayerprotodetectgetproto.c
fuzz: restrict flags passed to AppLayerProtoDetectGetProto
[people/ms/suricata.git] / src / tests / fuzz / fuzz_applayerprotodetectgetproto.c
index 27b271924e5cd88e621279b7255ea08a29e0ae30..e9df4db8ebdcffa76c0e31ee1657b25c7cc9f14d 100644 (file)
@@ -45,7 +45,7 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
         alpd_tctx = AppLayerProtoDetectGetCtxThread();
     }
 
-    f = UTHBuildFlow(AF_INET, "1.2.3.4", "5.6.7.8", (data[2] << 8) | data[3], (data[4] << 8) | data[5]);
+    f = TestHelperBuildFlow(AF_INET, "1.2.3.4", "5.6.7.8", (data[2] << 8) | data[3], (data[4] << 8) | data[5]);
     if (f == NULL) {
         return 0;
     }
@@ -54,22 +54,35 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
     f->protoctx = &ssn;
     f->protomap = FlowGetProtoMapping(f->proto);
 
-    alproto = AppLayerProtoDetectGetProto(alpd_tctx, f, data+HEADER_LEN, size-HEADER_LEN, f->proto, data[0], &reverse);
-    if (alproto != ALPROTO_UNKNOWN && alproto != ALPROTO_FAILED && f->proto == IPPROTO_TCP) {
-        /* If we find a valid protocol :
+    uint8_t flags = STREAM_TOCLIENT;
+    if (data[0] & STREAM_TOSERVER) {
+        flags = STREAM_TOSERVER;
+    }
+    alproto = AppLayerProtoDetectGetProto(
+            alpd_tctx, f, data + HEADER_LEN, size - HEADER_LEN, f->proto, flags, &reverse);
+    if (alproto != ALPROTO_UNKNOWN && alproto != ALPROTO_FAILED && f->proto == IPPROTO_TCP &&
+            (data[0] & STREAM_MIDSTREAM) == 0) {
+        /* If we find a valid protocol at the start of a stream :
          * check that with smaller input
          * we find the same protocol or ALPROTO_UNKNOWN.
          * Otherwise, we have evasion with TCP splitting
          */
         for (size_t i = 0; i < size-HEADER_LEN && i < PROTO_DETECT_MAX_LEN; i++) {
-            alproto2 = AppLayerProtoDetectGetProto(alpd_tctx, f, data+HEADER_LEN, i, f->proto, data[0], &reverse);
+            alproto2 = AppLayerProtoDetectGetProto(
+                    alpd_tctx, f, data + HEADER_LEN, i, f->proto, flags, &reverse);
             if (alproto2 != ALPROTO_UNKNOWN && alproto2 != alproto) {
-                printf("Assertion failure : With input length %"PRIuMAX", found %s instead of %s\n", (uintmax_t) i, AppProtoToString(alproto2), AppProtoToString(alproto));
+                printf("Failed with input length %" PRIuMAX " versus %" PRIuMAX
+                       ", found %s instead of %s\n",
+                        (uintmax_t)i, (uintmax_t)size - HEADER_LEN, AppProtoToString(alproto2),
+                        AppProtoToString(alproto));
+                printf("Assertion failure: %s-%s\n", AppProtoToString(alproto2),
+                        AppProtoToString(alproto));
+                fflush(stdout);
                 abort();
             }
         }
     }
-    UTHFreeFlow(f);
+    FlowFree(f);
 
     return 0;
 }