]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
pgsql/parser: always use fn for parsing PDU length
authorJuliana Fajardini <jufajardini@gmail.com>
Mon, 17 Feb 2025 22:13:50 +0000 (19:13 -0300)
committerVictor Julien <victor@inliniac.net>
Wed, 19 Feb 2025 08:21:37 +0000 (09:21 +0100)
Some inner parsers were using it, some weren't. Better to standardize
this. Also take the time to avoid magic numbers for representing the
expected lengths for pgsql PDUs.
Also throwing PgsqlParseError and allowing for incomplete results.

Related to
Task #5566
Bug #5524

rust/src/pgsql/parser.rs

index c7c158761c05d8ea74da753db23f6861861b613b..94ddc676b950fe8055782c4aa1bb09d42db5b030 100644 (file)
@@ -55,9 +55,22 @@ impl<I> ParseError<I> for PgsqlParseError<I> {
     }
 }
 
-fn parse_length(i: &[u8]) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
+fn parse_gte_length(i: &[u8], expected_length: u32) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
     let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| {
-        x >= PGSQL_LENGTH_FIELD
+        x >= expected_length
+    })(i);
+    match res {
+        Ok(result) => Ok((result.0, result.1)),
+        Err(nom7::Err::Incomplete(needed)) => Err(Err::Incomplete(needed)),
+        Err(_) => Err(Err::Error(PgsqlParseError::InvalidLength)),
+    }
+}
+
+fn parse_exact_length(
+    i: &[u8], expected_length: u32,
+) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
+    let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| {
+        x == expected_length
     })(i);
     match res {
         Ok(result) => Ok((result.0, result.1)),
@@ -612,7 +625,7 @@ fn parse_sasl_initial_response_payload(
 
 pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
-    let (i, length) = parse_length(i)?;
+    let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
     let (i, payload) = map_parser(
         take(length - PGSQL_LENGTH_FIELD),
         parse_sasl_initial_response_payload,
@@ -631,7 +644,7 @@ pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, P
 
 pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
-    let (i, length) = parse_length(i)?;
+    let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
     let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
     let resp = PgsqlFEMessage::SASLResponse(RegularPacket {
         identifier,
@@ -698,7 +711,7 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, Pg
 // Password can be encrypted or in cleartext
 pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
-    let (i, length) = parse_length(i)?;
+    let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
     let (i, password) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
     Ok((
         i,
@@ -712,7 +725,7 @@ pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlP
 
 fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?;
-    let (i, length) = parse_length(i)?;
+    let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
     let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
     Ok((
         i,
@@ -735,7 +748,7 @@ fn parse_cancel_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseEr
 
 fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?;
-    let (i, length) = parse_length(i)?;
+    let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?;
     Ok((
         i,
         PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }),
@@ -751,7 +764,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError
         b'X' => parse_terminate_message(i)?,
         _ => {
             let (i, identifier) = be_u8(i)?;
-            let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?;
+            let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
             let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
             let unknown = PgsqlFEMessage::UnknownMessageType(RegularPacket {
                 identifier,
@@ -766,7 +779,7 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError
 
 fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], PgsqlBEMessage, PgsqlParseError<&'a [u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?;
-    let (i, length) = verify(be_u32, |&x| x >= 8)(i)?;
+    let (i, length) = parse_gte_length(i, 8)?;
     let (i, auth_type) = be_u32(i)?;
     let (i, message) = map_parser(take(length - 8), |b: &'a [u8]| {
         match auth_type {
@@ -849,7 +862,7 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq
 
 fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?;
-    let (i, length) = parse_length(i)?;
+    let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
     let (i, param) = map_parser(
         take(length - PGSQL_LENGTH_FIELD),
         pgsql_parse_generic_parameter,
@@ -874,7 +887,7 @@ pub fn parse_ssl_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse
 
 fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'K')(i)?;
-    let (i, length) = verify(be_u32, |&x| x == 12)(i)?;
+    let (i, length) = parse_exact_length(i, 12)?;
     let (i, pid) = be_u32(i)?;
     let (i, secret_key) = be_u32(i)?;
     Ok((
@@ -890,7 +903,7 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pg
 
 fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?;
-    let (i, length) = parse_length(i)?;
+    let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
     let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?;
     Ok((
         i,
@@ -904,7 +917,7 @@ fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParse
 
 fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?;
-    let (i, length) = verify(be_u32, |&x| x == 5)(i)?;
+    let (i, length) = parse_exact_length(i, 5)?;
     let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?;
     Ok((
         i,
@@ -941,7 +954,7 @@ fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField, PgsqlParseError<&[u8]>>
 
 pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'T')(i)?;
-    let (i, length) = verify(be_u32, |&x| x > 6)(i)?;
+    let (i, length) = parse_gte_length(i, 7)?;
     let (i, field_count) = be_u16(i)?;
     let (i, fields) = map_parser(
         take(length - 6),
@@ -992,7 +1005,7 @@ fn add_up_data_size(columns: Vec<ColumnFieldValue>) -> u64 {
 // Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages
 pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'D')(i)?;
-    let (i, length) = verify(be_u32, |&x| x >= 6)(i)?;
+    let (i, length) = parse_gte_length(i, 7)?;
     let (i, field_count) = be_u16(i)?;
     // 6 here is for skipping length + field_count
     let (i, rows) = map_parser(
@@ -1109,7 +1122,7 @@ pub fn parse_error_notice_fields(
 
 fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?;
-    let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
+    let (i, length) = parse_gte_length(i, 11)?;
     let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
         parse_error_notice_fields(b, true)
     })(i)?;
@@ -1126,7 +1139,7 @@ fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlP
 
 fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?;
-    let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
+    let (i, length) = parse_gte_length(i, 11)?;
     let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
         parse_error_notice_fields(b, false)
     })(i)?;
@@ -1143,7 +1156,7 @@ fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pgsql
 fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?;
     // length (u32) + pid (u32) + at least one byte, for we have two str fields
-    let (i, length) = verify(be_u32, |&x| x > 9)(i)?;
+    let (i, length) = parse_gte_length(i, 10)?;
     let (i, data) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
         let (b, pid) = be_u32(b)?;
         let (b, channel_name) = take_until_and_consume(b"\x00")(b)?;
@@ -1175,7 +1188,7 @@ pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlPar
         b'D' => parse_consolidated_data_row(i)?,
         _ => {
             let (i, identifier) = be_u8(i)?;
-            let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
+            let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
             let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
             let unknown = PgsqlBEMessage::UnknownMessageType(RegularPacket {
                 identifier,