]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
mpm: run engines as few times as possible
authorVictor Julien <victor@inliniac.net>
Mon, 10 Apr 2017 18:42:25 +0000 (20:42 +0200)
committerVictor Julien <victor@inliniac.net>
Fri, 21 Apr 2017 16:58:01 +0000 (18:58 +0200)
In various scenarios buffers would be checked my MPM more than
once. This was because the buffers would be inspected for a
certain progress value or higher.

For example, for each packet in a file upload, the engine would
not just rerun the 'http client body' MPM on the new data, it
would also rerun the method, uri, headers, cookie, etc MPMs.

This was obviously inefficent, so this patch changes the logic.

The patch only runs the MPM engines when the progress is exactly
the intended progress. If the progress is beyond the desired
value, it is run once. A tracker is added to the app layer API,
where the completed MPMs are tracked.

Implemented for HTTP, TLS and SSH.

src/app-layer-htp.c
src/app-layer-htp.h
src/app-layer-parser.c
src/app-layer-parser.h
src/app-layer-ssh.c
src/app-layer-ssh.h
src/app-layer-ssl.c
src/app-layer-ssl.h
src/detect-engine-prefilter.c

index 3588a9e88b3147936cf64f5d8984552f2aa4995a..18662870a761b61d7d8c0783bca673367faad9b1 100644 (file)
@@ -2718,6 +2718,28 @@ static int HTPSetTxDetectState(void *alstate, void *vtx, DetectEngineState *s)
     return 0;
 }
 
+static uint64_t HTPGetTxMpmIDs(void *vtx)
+{
+    htp_tx_t *tx = (htp_tx_t *)vtx;
+    HtpTxUserData *tx_ud = htp_tx_get_user_data(tx);
+    return tx_ud ? tx_ud->mpm_ids : 0;
+}
+
+static int HTPSetTxMpmIDs(void *vtx, uint64_t mpm_ids)
+{
+    htp_tx_t *tx = (htp_tx_t *)vtx;
+    HtpTxUserData *tx_ud = htp_tx_get_user_data(tx);
+    if (tx_ud == NULL) {
+        tx_ud = HTPMalloc(sizeof(*tx_ud));
+        if (unlikely(tx_ud == NULL))
+            return -ENOMEM;
+        memset(tx_ud, 0, sizeof(*tx_ud));
+        htp_tx_set_user_data(tx, tx_ud);
+    }
+    tx_ud->mpm_ids = mpm_ids;
+    return 0;
+}
+
 static int HTPRegisterPatternsForProtocolDetection(void)
 {
     char *methods[] = { "GET", "PUT", "POST", "HEAD", "TRACE", "OPTIONS",
@@ -2806,6 +2828,8 @@ void RegisterHTPParsers(void)
         AppLayerParserRegisterDetectStateFuncs(IPPROTO_TCP, ALPROTO_HTTP,
                                                HTPStateHasTxDetectState,
                                                HTPGetTxDetectState, HTPSetTxDetectState);
+        AppLayerParserRegisterMpmIDsFuncs(IPPROTO_TCP, ALPROTO_HTTP,
+                                               HTPGetTxMpmIDs, HTPSetTxMpmIDs);
 
         AppLayerParserRegisterParser(IPPROTO_TCP, ALPROTO_HTTP, STREAM_TOSERVER,
                                      HTPHandleRequestData);
index 0e3503c3e8ee6d2076ba0d66c3a1b714f7681e2b..b8f4f29b7045bb682330f14165259a9aeacc0277 100644 (file)
@@ -188,6 +188,9 @@ typedef struct HtpBody_ {
 /** Now the Body Chunks will be stored per transaction, at
   * the tx user data */
 typedef struct HtpTxUserData_ {
+    /** flags to track which mpm has run */
+    uint64_t mpm_ids;
+
     /* Body of the request (if any) */
     uint8_t request_body_init;
     uint8_t response_body_init;
@@ -228,7 +231,6 @@ typedef struct HtpTxUserData_ {
 } HtpTxUserData;
 
 typedef struct HtpState_ {
-
     /* Connection parser structure for each connection */
     htp_connp_t *connp;
     /* Connection structure for each connection */
index e69997d71fcd06d7109924405823c1148eb6eaae..44a03368cb36f57761abdbba1c8a30c6dec5ae7a 100644 (file)
@@ -116,6 +116,9 @@ typedef struct AppLayerParserProtoCtx_
     DetectEngineState *(*GetTxDetectState)(void *tx);
     int (*SetTxDetectState)(void *alstate, void *tx, DetectEngineState *);
 
+    uint64_t (*GetTxMpmIDs)(void *tx);
+    int (*SetTxMpmIDs)(void *tx, uint64_t);
+
     /* each app-layer has its own value */
     uint32_t stream_depth;
 
@@ -537,6 +540,18 @@ void AppLayerParserRegisterDetectStateFuncs(uint8_t ipproto, AppProto alproto,
     SCReturn;
 }
 
+void AppLayerParserRegisterMpmIDsFuncs(uint8_t ipproto, AppProto alproto,
+        uint64_t(*GetTxMpmIDs)(void *tx),
+        int (*SetTxMpmIDs)(void *tx, uint64_t))
+{
+    SCEnter();
+
+    alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].GetTxMpmIDs = GetTxMpmIDs;
+    alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].SetTxMpmIDs = SetTxMpmIDs;
+
+    SCReturn;
+}
+
 /***** Get and transaction functions *****/
 
 void *AppLayerParserGetProtocolParserLocalStorage(uint8_t ipproto, AppProto alproto)
@@ -929,6 +944,24 @@ int AppLayerParserSetTxDetectState(uint8_t ipproto, AppProto alproto,
     SCReturnInt(r);
 }
 
+uint64_t AppLayerParserGetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx)
+{
+    if (alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].GetTxMpmIDs != NULL) {
+        return alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].GetTxMpmIDs(tx);
+    }
+
+    return 0ULL;
+}
+
+int AppLayerParserSetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx, uint64_t mpm_ids)
+{
+    int r = 0;
+    if (alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].SetTxMpmIDs != NULL) {
+        r = alp_ctx.ctxs[FlowGetProtoMapping(ipproto)][alproto].SetTxMpmIDs(tx, mpm_ids);
+    }
+    SCReturnInt(r);
+}
+
 /***** General *****/
 
 int AppLayerParserParse(ThreadVars *tv, AppLayerParserThreadCtx *alp_tctx, Flow *f, AppProto alproto,
index c53cf781291385a9be1bbd3c8b450fcb194f98ac..e887e0008b0c9377a8f7563f7015c57258b61cae 100644 (file)
@@ -153,6 +153,9 @@ void AppLayerParserRegisterDetectStateFuncs(uint8_t ipproto, AppProto alproto,
 void AppLayerParserRegisterGetStreamDepth(uint8_t ipproto,
                                           AppProto alproto,
                                           uint32_t (*GetStreamDepth)(void));
+void AppLayerParserRegisterMpmIDsFuncs(uint8_t ipproto, AppProto alproto,
+        uint64_t (*GetTxMpmIDs)(void *tx),
+        int (*SetTxMpmIDs)(void *tx, uint64_t));
 
 /***** Get and transaction functions *****/
 
@@ -195,6 +198,9 @@ int AppLayerParserHasTxDetectState(uint8_t ipproto, AppProto alproto, void *alst
 DetectEngineState *AppLayerParserGetTxDetectState(uint8_t ipproto, AppProto alproto, void *tx);
 int AppLayerParserSetTxDetectState(uint8_t ipproto, AppProto alproto, void *alstate, void *tx, DetectEngineState *s);
 
+uint64_t AppLayerParserGetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx);
+int AppLayerParserSetTxMpmIDs(uint8_t ipproto, AppProto alproto, void *tx, uint64_t);
+
 /***** General *****/
 
 int AppLayerParserParse(ThreadVars *tv, AppLayerParserThreadCtx *tctx, Flow *f, AppProto alproto,
index c4e9c118dbcaaef91602ab6a05a41a7a579ca1de..b7a5b2be43ae5c11b4f0fbcf25209dee6377aa34 100644 (file)
@@ -557,6 +557,19 @@ static int SSHGetTxLogged(void *state, void *tx, uint32_t logger)
     return 0;
 }
 
+static uint64_t SSHGetTxMpmIDs(void *vtx)
+{
+    SshState *ssh_state = (SshState *)vtx;
+    return ssh_state->mpm_ids;
+}
+
+static int SSHSetTxMpmIDs(void *vtx, uint64_t mpm_ids)
+{
+    SshState *ssh_state = (SshState *)vtx;
+    ssh_state->mpm_ids = mpm_ids;
+    return 0;
+}
+
 static int SSHGetAlstateProgressCompletionStatus(uint8_t direction)
 {
     return SSH_STATE_FINISHED;
@@ -632,6 +645,8 @@ void RegisterSSHParsers(void)
         AppLayerParserRegisterGetStateProgressFunc(IPPROTO_TCP, ALPROTO_SSH, SSHGetAlstateProgress);
 
         AppLayerParserRegisterLoggerFuncs(IPPROTO_TCP, ALPROTO_SSH, SSHGetTxLogged, SSHSetTxLogged);
+        AppLayerParserRegisterMpmIDsFuncs(IPPROTO_TCP, ALPROTO_SSH,
+                SSHGetTxMpmIDs, SSHSetTxMpmIDs);
 
         AppLayerParserRegisterGetStateProgressCompletionStatus(ALPROTO_SSH,
                                                                SSHGetAlstateProgressCompletionStatus);
index 7fa368812b95c8eb5871808945a337ca45a62cd0..d5b6d3a8a54d8e4cbff246aeff41231a5f3e0589 100644 (file)
@@ -76,6 +76,9 @@ typedef struct SshState_ {
     /* specifies which loggers are done logging */
     uint32_t logged;
 
+    /* bit flags of mpms that have already run */
+    uint64_t mpm_ids;
+
     DetectEngineState *de_state;
 } SshState;
 
index 47c3e9ef6025d3bf8a19a271161d96d3eb2a1db8..0215eed77d63501c8e3f980c34fed615f8fb1f15 100644 (file)
@@ -245,6 +245,19 @@ int SSLGetAlstateProgress(void *tx, uint8_t direction)
     return TLS_STATE_IN_PROGRESS;
 }
 
+static uint64_t SSLGetTxMpmIDs(void *vtx)
+{
+    SSLState *ssl_state = (SSLState *)vtx;
+    return ssl_state->mpm_ids;
+}
+
+static int SSLSetTxMpmIDs(void *vtx, uint64_t mpm_ids)
+{
+    SSLState *ssl_state = (SSLState *)vtx;
+    ssl_state->mpm_ids = mpm_ids;
+    return 0;
+}
+
 static int TLSDecodeHandshakeHello(SSLState *ssl_state, uint8_t *input,
                                    uint32_t input_len)
 {
@@ -1832,6 +1845,8 @@ void RegisterSSLParsers(void)
         AppLayerParserRegisterGetStateProgressFunc(IPPROTO_TCP, ALPROTO_TLS, SSLGetAlstateProgress);
 
         AppLayerParserRegisterLoggerFuncs(IPPROTO_TCP, ALPROTO_TLS, SSLGetTxLogged, SSLSetTxLogged);
+        AppLayerParserRegisterMpmIDsFuncs(IPPROTO_TCP, ALPROTO_TLS,
+                SSLGetTxMpmIDs, SSLSetTxMpmIDs);
 
         AppLayerParserRegisterGetStateProgressCompletionStatus(ALPROTO_TLS,
                                                                SSLGetAlstateProgressCompletionStatus);
index 1748803afb94106a142d3b7a1c7440a3774ed34f..d48835cc131890f4889284c8d2faa953f3a2f539 100644 (file)
@@ -191,6 +191,9 @@ typedef struct SSLState_ {
     /* specifies which loggers are done logging */
     uint32_t logged;
 
+    /* MPM/prefilter Id's */
+    uint64_t mpm_ids;
+
     /* there might be a better place to store this*/
     uint16_t hb_record_len;
 
index 5a1580fc05e22947b48c40fe287a19576c77983d..ed6ce400548875c9e448f49a625b1c9e4b060681 100644 (file)
@@ -118,6 +118,7 @@ static inline void PrefilterTx(DetectEngineThreadCtx *det_ctx,
         if (tx == NULL)
             continue;
 
+        uint64_t mpm_ids = AppLayerParserGetTxMpmIDs(ipproto, alproto, tx);
         const int tx_progress = AppLayerParserGetStateProgress(ipproto, alproto, tx, flags);
         SCLogDebug("tx %p progress %d", tx, tx_progress);
 
@@ -127,16 +128,30 @@ static inline void PrefilterTx(DetectEngineThreadCtx *det_ctx,
                 goto next;
             if (engine->tx_min_progress > tx_progress)
                 goto next;
+            if (tx_progress > engine->tx_min_progress) {
+                if (mpm_ids & (1<<(engine->gid))) {
+                    goto next;
+                }
+            }
 
             PROFILING_PREFILTER_START(p);
             engine->cb.PrefilterTx(det_ctx, engine->pectx,
                     p, p->flow, tx, idx, flags);
             PROFILING_PREFILTER_END(p, engine->gid);
+
+            if (tx_progress > engine->tx_min_progress) {
+                mpm_ids |= (1<<(engine->gid));
+            }
         next:
             if (engine->is_last)
                 break;
             engine++;
         } while (1);
+
+        if (mpm_ids != 0) {
+            //SCLogNotice("tx %p Mpm IDs: %"PRIx64, tx, mpm_ids);
+            AppLayerParserSetTxMpmIDs(ipproto, alproto, tx, mpm_ids);
+        }
     }
 }