From: Juliana Fajardini Date: Thu, 30 Nov 2023 21:55:13 +0000 (-0300) Subject: pgsql: extract length validation into function X-Git-Tag: suricata-8.0.0-beta1~1911 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7fa8bbfe43f396215238e7d8a2b7ce94a22560bc;p=thirdparty%2Fsuricata.git pgsql: extract length validation into function This is called so many times that it seems to make sense that we use a function for this. --- diff --git a/rust/src/pgsql/parser.rs b/rust/src/pgsql/parser.rs index bed3682bb4..1cfa19da17 100644 --- a/rust/src/pgsql/parser.rs +++ b/rust/src/pgsql/parser.rs @@ -37,6 +37,10 @@ pub const PGSQL_DUMMY_PROTO_MAJOR: u16 = 1234; // 0x04d2 pub const PGSQL_DUMMY_PROTO_MINOR_SSL: u16 = 5679; //0x162f pub const _PGSQL_DUMMY_PROTO_MINOR_GSSAPI: u16 = 5680; // 0x1630 +fn parse_length(i: &[u8]) -> IResult<&[u8], u32> { + verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i) +} + #[derive(Debug, PartialEq, Eq)] pub enum PgsqlParameters { // startup parameters @@ -564,7 +568,7 @@ fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenti pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_length(i)?; let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload)(i)?; Ok((i, PgsqlFEMessage::SASLInitialResponse( SASLInitialResponsePacket { @@ -578,7 +582,7 @@ pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_length(i)?; let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; let resp = PgsqlFEMessage::SASLResponse( RegularPacket { @@ -638,7 +642,7 @@ pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { // Password can be encrypted or in cleartext pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; - let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_length(i)?; let (i, password) = map_parser( take(length - PGSQL_LENGTH_FIELD), take_until1("\x00") @@ -653,7 +657,7 @@ pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?; - let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_length(i)?; let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; Ok((i, PgsqlFEMessage::SimpleQuery(RegularPacket { identifier, @@ -664,7 +668,7 @@ fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?; - let (i, length) = verify(be_u32, |&x| x == PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_length(i)?; Ok((i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }))) } @@ -772,7 +776,7 @@ fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], Pgsq fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?; - let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_length(i)?; let (i, param) = map_parser(take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter)(i)?; Ok((i, PgsqlBEMessage::ParameterStatus(ParameterStatusMessage { identifier, @@ -803,7 +807,7 @@ fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?; - let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, length) = parse_length(i)?; let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?; Ok((i, PgsqlBEMessage::CommandComplete(RegularPacket { identifier,