]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
pgsql: extract length validation into function
authorJuliana Fajardini <jufajardini@oisf.net>
Thu, 30 Nov 2023 21:55:13 +0000 (18:55 -0300)
committerVictor Julien <victor@inliniac.net>
Fri, 15 Dec 2023 05:08:22 +0000 (06:08 +0100)
This is called so many times that it seems to make sense that we use a
function for this.

rust/src/pgsql/parser.rs

index bed3682bb43d8f54b2e654d8efbedb768698f3ad..1cfa19da17b1ebc15651bf3553f97c30b993733a 100644 (file)
@@ -37,6 +37,10 @@ pub const PGSQL_DUMMY_PROTO_MAJOR: u16 = 1234; // 0x04d2
 pub const PGSQL_DUMMY_PROTO_MINOR_SSL: u16 = 5679; //0x162f
 pub const _PGSQL_DUMMY_PROTO_MINOR_GSSAPI: u16 = 5680; // 0x1630
 
+fn parse_length(i: &[u8]) -> IResult<&[u8], u32> {
+    verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)
+}
+
 #[derive(Debug, PartialEq, Eq)]
 pub enum PgsqlParameters {
     // startup parameters
@@ -564,7 +568,7 @@ fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenti
 
 pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
     let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
-    let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
+    let (i, length) = parse_length(i)?;
     let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload)(i)?;
     Ok((i, PgsqlFEMessage::SASLInitialResponse(
                 SASLInitialResponsePacket {
@@ -578,7 +582,7 @@ pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
 
 pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
     let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
-    let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
+    let (i, length) = parse_length(i)?;
     let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
     let resp = PgsqlFEMessage::SASLResponse(
         RegularPacket {
@@ -638,7 +642,7 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
 // Password can be encrypted or in cleartext
 pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
     let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
-    let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?;
+    let (i, length) = parse_length(i)?;
     let (i, password) = map_parser(
         take(length - PGSQL_LENGTH_FIELD),
         take_until1("\x00")
@@ -653,7 +657,7 @@ pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
 
 fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
     let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?;
-    let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
+    let (i, length) = parse_length(i)?;
     let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
     Ok((i, PgsqlFEMessage::SimpleQuery(RegularPacket {
         identifier,
@@ -664,7 +668,7 @@ fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
 
 fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> {
     let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?;
-    let (i, length) = verify(be_u32, |&x| x == PGSQL_LENGTH_FIELD)(i)?;
+    let (i, length) = parse_length(i)?;
     Ok((i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length })))
 }
 
@@ -772,7 +776,7 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq
 
 fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
     let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?;
-    let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?;
+    let (i, length) = parse_length(i)?;
     let (i, param) = map_parser(take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter)(i)?;
     Ok((i, PgsqlBEMessage::ParameterStatus(ParameterStatusMessage {
         identifier,
@@ -803,7 +807,7 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
 
 fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> {
     let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?;
-    let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
+    let (i, length) = parse_length(i)?;
     let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?;
     Ok((i, PgsqlBEMessage::CommandComplete(RegularPacket {
         identifier,