]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
ssh: handle fragmented banner
authorVictor Julien <victor@inliniac.net>
Sat, 1 Mar 2014 15:50:07 +0000 (16:50 +0100)
committerVictor Julien <victor@inliniac.net>
Mon, 3 Mar 2014 16:34:57 +0000 (17:34 +0100)
Cleanups.

src/app-layer-ssh.c
src/app-layer-ssh.h

index 535f1c970a3cc1540952e7c7cdfb5389844573b3..dd58def5408f4abda1f1a2b2ec6d1c09bacffc53 100644 (file)
@@ -61,7 +61,7 @@
  *  \param  input       Pointer the received input data
  *  \param  input_len   Length in bytes of the received data
  */
-static int SSHParseVersion(SshState *state, SshHeader *header, const uint8_t *input, uint32_t input_len)
+static int SSHParseBanner(SshState *state, SshHeader *header, const uint8_t *input, uint32_t input_len)
 {
     const uint8_t *line_ptr = input;
     uint32_t line_len = input_len;
@@ -71,7 +71,7 @@ static int SSHParseVersion(SshState *state, SshHeader *header, const uint8_t *in
         SCReturnInt(-1);
     }
     if (line_len > 255) {
-        SCLogInfo("Invalid version string, it should be less than 255 characters including <CR><NL>");
+        SCLogDebug("Invalid version string, it should be less than 255 characters including <CR><NL>");
         SCReturnInt(-1);
     }
 
@@ -98,7 +98,7 @@ static int SSHParseVersion(SshState *state, SshHeader *header, const uint8_t *in
     line_ptr += proto_ver_len + 1;
     line_len -= proto_ver_len + 1;
     if (line_len < 1) {
-        SCLogInfo("No software version specified (weird)");
+        SCLogDebug("No software version specified (weird)");
         header->flags |= SSH_FLAG_VERSION_PARSED;
         /* Return the remaining length */
         SCReturnInt(0);
@@ -129,16 +129,17 @@ static int SSHParseVersion(SshState *state, SshHeader *header, const uint8_t *in
     SCReturnInt(len);
 }
 
-static int SSHParseClientRecordHeader(SshState *state, SshHeader *header, uint8_t *input, uint32_t input_len)
+static int SSHParseRecordHeader(SshState *state, SshHeader *header, uint8_t *input, uint32_t input_len)
 {
+#ifdef DEBUG
     BUG_ON(input_len != 6);
-
+#endif
     /* input and input_len now point past initial line */
     uint32_t pkt_len = 0;
     int r = ByteExtractUint32(&pkt_len, BYTE_BIG_ENDIAN,
             4, input);
     if (r != 4) {
-        SCLogInfo("xtract 4 bytes failed %d", r);
+        SCLogDebug("xtract 4 bytes failed %d", r);
         SCReturnInt(-1);
     }
     if (pkt_len < 2) {
@@ -146,7 +147,7 @@ static int SSHParseClientRecordHeader(SshState *state, SshHeader *header, uint8_
     }
 
     header->pkt_len = pkt_len;
-    SCLogInfo("pkt len: %"PRIu32, pkt_len);
+    SCLogDebug("pkt len: %"PRIu32, pkt_len);
 
     input += 4;
     input_len -= 4;
@@ -156,22 +157,22 @@ static int SSHParseClientRecordHeader(SshState *state, SshHeader *header, uint8_
     input += 1;
     input_len -= 1;
 
-    SCLogInfo("padding: %u", header->padding_len);
+    SCLogDebug("padding: %u", header->padding_len);
 
     header->msg_code = *input;
 
     input += 1;
     input_len -= 1;
 
-    SCLogInfo("msg code: %u", header->msg_code);
+    SCLogDebug("msg code: %u", header->msg_code);
 
     if (header->msg_code == SSH_MSG_NEWKEYS) {
         /* done */
-        SCLogInfo("done");
+        SCLogDebug("done");
         state->flags |= SSH_FLAG_PARSER_DONE;
     } else {
         /* not yet done */
-        SCLogInfo("not done");
+        SCLogDebug("not done");
     }
     SCReturnInt(0);
 }
@@ -197,15 +198,15 @@ static int SSHParseRecord(SshState *state, SshHeader *header, uint8_t *input, ui
 
     SCLogDebug("state %p, input %p,input_len %" PRIu32,
                state, input, input_len);
-    PrintRawDataFp(stdout, input, input_len);
+    //PrintRawDataFp(stdout, input, input_len);
 
     if (!(header->flags & SSH_FLAG_VERSION_PARSED)) {
-        ret = SSHParseVersion(state, header, input, input_len);
+        ret = SSHParseBanner(state, header, input, input_len);
         if (ret < 0) {
             SCLogDebug("Invalid version string");
             SCReturnInt(-1);
         } else if (header->flags & SSH_FLAG_VERSION_PARSED) {
-            SCLogInfo("Version string parsed, remaining length %d", ret);
+            SCLogDebug("Version string parsed, remaining length %d", ret);
             input += input_len - ret;
             input_len -= (input_len - ret);
             ret = 0;
@@ -214,7 +215,7 @@ static int SSHParseRecord(SshState *state, SshHeader *header, uint8_t *input, ui
             while (u < input_len && (input[u] == '\r' || input[u] == '\n')) {
                 u++;
             }
-            SCLogInfo("skipping %u EOL bytes", u);
+            SCLogDebug("skipping %u EOL bytes", u);
             input += u;
             input_len -= u;
 
@@ -232,19 +233,19 @@ static int SSHParseRecord(SshState *state, SshHeader *header, uint8_t *input, ui
     }
 
     /* skip bytes from the current record if we have to */
-    if (state->cli_hdr.record_left > 0) {
-        SCLogInfo("skipping bytes part of the current record");
-        if (state->cli_hdr.record_left > input_len) {
-            state->cli_hdr.record_left -= input_len;
-            SCLogInfo("all input skipped, %u left in record", state->cli_hdr.record_left);
+    if (header->record_left > 0) {
+        SCLogDebug("skipping bytes part of the current record");
+        if (header->record_left > input_len) {
+            header->record_left -= input_len;
+            SCLogDebug("all input skipped, %u left in record", header->record_left);
             SCReturnInt(0);
         } else {
-            input_len -= state->cli_hdr.record_left;
-            input += state->cli_hdr.record_left;
-            state->cli_hdr.record_left = 0;
+            input_len -= header->record_left;
+            input += header->record_left;
+            header->record_left = 0;
 
             if (input_len == 0) {
-                SCLogInfo("all input skipped");
+                SCLogDebug("all input skipped");
                 SCReturnInt(0);
             }
         }
@@ -252,33 +253,31 @@ static int SSHParseRecord(SshState *state, SshHeader *header, uint8_t *input, ui
 
 again:
     /* input is too small, even when combined with stored bytes */
-    if (state->cli_hdr.buf_offset + input_len < 6) {
-        memcpy(state->cli_hdr.buf + state->cli_hdr.buf_offset, input, input_len);
-        state->cli_hdr.buf_offset += input_len;
-
-        //PrintRawDataFp(stdout, state->cli_hdr.buf, state->cli_hdr.buf_offset);
+    if (header->buf_offset + input_len < 6) {
+        memcpy(header->buf + header->buf_offset, input, input_len);
+        header->buf_offset += input_len;
         SCReturnInt(0);
 
     /* we have enough bytes to parse 6 bytes, lets see if we have
      * previously stored some */
-    } else if (state->cli_hdr.buf_offset > 0) {
-        uint8_t needed = 6 - state->cli_hdr.buf_offset;
+    } else if (header->buf_offset > 0) {
+        uint8_t needed = 6 - header->buf_offset;
 
-        SCLogInfo("parse stored");
-        memcpy(state->cli_hdr.buf + state->cli_hdr.buf_offset, input, needed);
-        state->cli_hdr.buf_offset = 6;
+        SCLogDebug("parse stored");
+        memcpy(header->buf + header->buf_offset, input, needed);
+        header->buf_offset = 6;
 
         // parse the 6
-        if (SSHParseClientRecordHeader(state, header, state->cli_hdr.buf, 6) < 0)
+        if (SSHParseRecordHeader(state, header, header->buf, 6) < 0)
             SCReturnInt(-1);
-        state->cli_hdr.buf_offset = 0;
+        header->buf_offset = 0;
 
-        uint32_t record_left = state->cli_hdr.pkt_len - 2;
+        uint32_t record_left = header->pkt_len - 2;
         input_len -= needed;
         input += needed;
 
         if (record_left > input_len) {
-            state->cli_hdr.record_left = record_left - input_len;
+            header->record_left = record_left - input_len;
         } else {
             input_len -= record_left;
             if (input_len == 0)
@@ -293,19 +292,18 @@ again:
 
     /* nothing stored, lets parse this directly */
     } else {
-        // parse the 6
-        SCLogInfo("parse direct");
-        PrintRawDataFp(stdout, input, input_len);
-        if (SSHParseClientRecordHeader(state, header, input, 6) < 0)
+        SCLogDebug("parse direct");
+        //PrintRawDataFp(stdout, input, input_len);
+        if (SSHParseRecordHeader(state, header, input, 6) < 0)
             SCReturnInt(-1);
 
-        uint32_t record_left = state->cli_hdr.pkt_len - 2;
-        SCLogInfo("record left %u", record_left);
+        uint32_t record_left = header->pkt_len - 2;
+        SCLogDebug("record left %u", record_left);
         input_len -= 6;
         input += 6;
 
         if (record_left > input_len) {
-            state->cli_hdr.record_left = record_left - input_len;
+            header->record_left = record_left - input_len;
         } else {
             input_len -= record_left;
             if (input_len == 0)
@@ -343,42 +341,83 @@ static int EnoughData(uint8_t *input, uint32_t input_len)
     return FALSE;
 }
 
-static int SSHParseRequest(Flow *f, void *state, AppLayerParserState *pstate,
-                           uint8_t *input, uint32_t input_len,
-                           void *local_data)
+static int SSHParseData(SshState *state, SshHeader *header,
+                        uint8_t *input, uint32_t input_len)
 {
-    SshState *ssh_state = (SshState *)state;
+    /* we're looking for the banner */
+    if (!(header->flags & SSH_FLAG_VERSION_PARSED))
+    {
+        int banner_eol = EnoughData(input, input_len);
+
+        /* fast track normal case: no buffering */
+        if (header->banner_buffer == NULL && banner_eol)
+        {
+            SCLogDebug("enough data, parse now");
+            // parse now
+            int r = SSHParseRecord(state, header, input, input_len);
+            SCReturnInt(r);
+
+        /* banner EOL with existing buffer present. Time for magic. */
+        } else if (banner_eol) {
+            SCLogDebug("banner EOL with existing buffer");
+
+            uint32_t tocopy = 256 - header->banner_len;
+            if (tocopy > input_len)
+                tocopy = input_len;
+
+            memcpy(header->banner_buffer + header->banner_len, input, tocopy);
+            header->banner_len += tocopy;
+
+            int r = SSHParseRecord(state, header,
+                    header->banner_buffer, header->banner_len);
+            SCReturnInt(r);
+
+        /* no banner EOL, so we need to buffer */
+        } else if (!banner_eol) {
+            if (header->banner_buffer == NULL) {
+                header->banner_buffer = SCMalloc(256);
+                if (header->banner_buffer == NULL)
+                    SCReturnInt(-1);
+            }
 
-    if (ssh_state->cli_hdr.flags & SSH_FLAG_VERSION_PARSED || EnoughData(input, input_len) == TRUE) {
-        SCLogInfo("enough data, parse now");
-        // parse now
-        int r = SSHParseRecord(ssh_state, &ssh_state->cli_hdr, input, input_len);
-        SCReturnInt(r);
+            uint32_t tocopy = 256 - header->banner_len;
+            if (tocopy > input_len)
+                tocopy = input_len;
+
+            memcpy(header->banner_buffer + header->banner_len, input, tocopy);
+            header->banner_len += tocopy;
+        }
+
+    /* we have a banner, the rest is just records */
     } else {
-        // buffer
+        int r = SSHParseRecord(state, header, input, input_len);
+        SCReturnInt(r);
     }
 
-    PrintRawDataFp(stdout, input, input_len);
+    //PrintRawDataFp(stdout, input, input_len);
     return 0;
 }
 
+static int SSHParseRequest(Flow *f, void *state, AppLayerParserState *pstate,
+                           uint8_t *input, uint32_t input_len,
+                           void *local_data)
+{
+    SshState *ssh_state = (SshState *)state;
+    SshHeader *ssh_header = &ssh_state->cli_hdr;
+
+    int r = SSHParseData(ssh_state, ssh_header, input, input_len);
+    SCReturnInt(r);
+}
+
 static int SSHParseResponse(Flow *f, void *state, AppLayerParserState *pstate,
                             uint8_t *input, uint32_t input_len,
                             void *local_data)
 {
     SshState *ssh_state = (SshState *)state;
+    SshHeader *ssh_header = &ssh_state->srv_hdr;
 
-    if (ssh_state->srv_hdr.flags & SSH_FLAG_VERSION_PARSED || EnoughData(input, input_len) == TRUE) {
-        SCLogInfo("enough data, parse now");
-        // parse now
-        int r = SSHParseRecord(ssh_state, &ssh_state->srv_hdr, input, input_len);
-        SCReturnInt(r);
-    } else {
-        // buffer
-    }
-
-    PrintRawDataFp(stdout, input, input_len);
-    return 0;
+    int r = SSHParseData(ssh_state, ssh_header, input, input_len);
+    SCReturnInt(r);
 }
 
 /** \brief Function to allocates the SSH state memory
@@ -402,10 +441,15 @@ static void SSHStateFree(void *state)
         SCFree(s->cli_hdr.proto_version);
     if (s->cli_hdr.software_version != NULL)
         SCFree(s->cli_hdr.software_version);
+    if (s->cli_hdr.banner_buffer != NULL)
+        SCFree(s->cli_hdr.banner_buffer);
+
     if (s->srv_hdr.proto_version != NULL)
         SCFree(s->srv_hdr.proto_version);
     if (s->srv_hdr.software_version != NULL)
         SCFree(s->srv_hdr.software_version);
+    if (s->srv_hdr.banner_buffer != NULL)
+        SCFree(s->srv_hdr.banner_buffer);
 
     SCFree(s);
 }
@@ -1719,9 +1763,9 @@ static int SSHParserTest17(void) {
     uint32_t sshlen1 = sizeof(sshbuf1) - 1;
     uint8_t sshbuf2[] = "2.0-MySSHClient-0.5.1\r\n";
     uint32_t sshlen2 = sizeof(sshbuf2) - 1;
-    uint8_t sshbuf3[] = { 0x00, 0x00, 0x00, 0x03,0x01, 17, 0x00};
+    uint8_t sshbuf3[] = { 0x00, 0x00, 0x00, 0x03, 0x01, 17, 0x00};
     uint32_t sshlen3 = sizeof(sshbuf3);
-    uint8_t sshbuf4[] = { 0x00, 0x00, 0x00, 0x03,0x01, 21, 0x00};
+    uint8_t sshbuf4[] = { 0x00, 0x00, 0x00, 0x03, 0x01, 21, 0x00};
     uint32_t sshlen4 = sizeof(sshbuf4);
     TcpSession ssn;
     AppLayerParserThreadCtx *alp_tctx = AppLayerParserThreadCtxAlloc();
index 05a7f49ebcf71660cfc5d1883c5446cdc26f0221..95ae8384a56eaf920bbe6fbbe1be16d1b4b6b74b 100644 (file)
 #ifndef __APP_LAYER_SSH_H__
 #define __APP_LAYER_SSH_H__
 
-#define SSH_FLAG_SERVER_CHANGE_CIPHER_SPEC   0x01    /**< Flag to indicate that
-                                                     server will now on sends
-                                                     encrypted msgs. */
-#define SSH_FLAG_CLIENT_CHANGE_CIPHER_SPEC   0x02    /**< Flag to indicate that
-                                                     client will now on sends
-                                                     encrypted msgs. */
-
-#define SSH_FLAG_VERSION_PARSED       0x08
+/* header flag */
+#define SSH_FLAG_VERSION_PARSED              0x01
 
 /* This flags indicate that the rest of the communication
  * must be ciphered, so the parsing finish here */
 #define SSH_FLAG_PARSER_DONE                 0x04
 
 /* MSG_CODE */
-#define SSH_MSG_NEWKEYS             21
+#define SSH_MSG_NEWKEYS                      21
 
 /** From SSH-TRANSP rfc
 
@@ -64,6 +58,8 @@ typedef struct SshHeader_ {
     uint32_t record_left;
     uint8_t *proto_version;
     uint8_t *software_version;
+    uint8_t *banner_buffer;
+    uint16_t banner_len;
 } SshHeader;
 
 /** structure to store the SSH state values */