From: Pierre Chifflier Date: Fri, 21 Jan 2022 12:37:54 +0000 (+0100) Subject: rust/pgsql: convert parsers to nom7 functions X-Git-Tag: suricata-7.0.0-beta1~994 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=ce9efc34c7958c96b580228e8ae48baf1137266f;p=thirdparty%2Fsuricata.git rust/pgsql: convert parsers to nom7 functions --- diff --git a/rust/src/pgsql/parser.rs b/rust/src/pgsql/parser.rs index a84d9535e1..5e041a7d68 100644 --- a/rust/src/pgsql/parser.rs +++ b/rust/src/pgsql/parser.rs @@ -19,11 +19,17 @@ //! PostgreSQL nom parsers -use nom::character::streaming::alphanumeric1; -use nom::combinator::{peek, rest}; -use nom::number::streaming::{be_i16, be_i32}; -use nom::number::streaming::{be_u16, be_u32, be_u8}; -use nom::IResult; +use crate::common::nom7::take_until_and_consume; +use nom7::branch::alt; +use nom7::bytes::streaming::{tag, tag_no_case, take, take_until, take_until1}; +use nom7::character::streaming::{alphanumeric1, char}; +use nom7::combinator::{all_consuming, cond, eof, map_parser, opt, peek, rest, verify}; +use nom7::error::{make_error, ErrorKind}; +use nom7::multi::{many1, many_m_n, many_till}; +use nom7::number::streaming::{be_i16, be_i32}; +use nom7::number::streaming::{be_u16, be_u32, be_u8}; +use nom7::sequence::{terminated, tuple}; +use nom7::{Err, IResult}; pub const PGSQL_LENGTH_FIELD: u32 = 4; @@ -518,33 +524,27 @@ impl From for PgsqlErrorNoticeFieldType { } } -named!( - parse_user_param, - do_parse!( - param_name: tag_no_case!("user") - >> tag!("\x00") - >> param_value: take_until1!("\x00") - >> tag!("\x00") - >> (PgsqlParameter { - name: PgsqlParameters::from(param_name), - value: param_value.to_vec(), - }) - ) -); - -named!( - parse_database_param, - do_parse!( - param_name: tag_no_case!("database") - >> tag!("\x00") - >> param_value: take_until1!("\x00") - >> tag!("\x00") - >> (PgsqlParameter { - name: PgsqlParameters::from(param_name), - value: param_value.to_vec(), - }) - ) -); +fn parse_user_param(i: &[u8]) -> IResult<&[u8], PgsqlParameter> { + let (i, param_name) = tag_no_case("user")(i)?; + let (i, _) = tag("\x00")(i)?; + let (i, param_value) = take_until1("\x00")(i)?; + let (i, _) = tag("\x00")(i)?; + Ok((i, PgsqlParameter { + name: PgsqlParameters::from(param_name), + value: param_value.to_vec(), + })) +} + +fn parse_database_param(i: &[u8]) -> IResult<&[u8], PgsqlParameter> { + let (i, param_name) = tag_no_case("database")(i)?; + let (i, _) = tag("\x00")(i)?; + let (i, param_value) = take_until1("\x00")(i)?; + let (i, _) = tag("\x00")(i)?; + Ok((i, PgsqlParameter { + name: PgsqlParameters::from(param_name), + value: param_value.to_vec(), + })) +} // Currently the set of parameters that could trigger a ParameterStatus message is fixed: // server_version @@ -562,102 +562,97 @@ named!( // standard_conforming_strings // (source: PostgreSQL documentation) // We may be interested, then, in controling this, somehow, to prevent weird things? -named!( - pgsql_parse_generic_parameter, - do_parse!( - param_name: take_until1!("\x00") - >> tag!("\x00") - >> param_value: take_until1!("\x00") - >> tag!("\x00") - >> (PgsqlParameter { - name: PgsqlParameters::from(param_name), - value: param_value.to_vec(), - }) - ) -); - -named!(pub pgsql_parse_startup_parameters, -do_parse!( - user: call!(parse_user_param) - >> database: opt!(parse_database_param) - >> optional: opt!(terminated!(many1!(pgsql_parse_generic_parameter), tag!("\x00"))) - >> (PgsqlStartupParameters{ - user, - database, - optional_params: optional, - }) -)); - -named!( - parse_sasl_initial_response_payload<(SASLAuthenticationMechanism, u32, Vec)>, - do_parse!( - sasl_mechanism: call!(parse_sasl_mechanism) - >> param_length: be_u32 - // From RFC 5802 - the client-first-message will always start w/ - // 'n', 'y' or 'p', otherwise it's invalid, I think we should check that, at some point - >> param: terminated!(take!(param_length), eof!()) - >> ((sasl_mechanism, param_length, param.to_vec())) - ) -); - -named!(pub parse_sasl_initial_response, -do_parse!( - identifier: verify!(be_u8, |&x| x == b'p') - >> length: verify!(be_u32, |&x| x > PGSQL_LENGTH_FIELD) - >> payload: flat_map!(take!(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload) - >> (PgsqlFEMessage::SASLInitialResponse( - SASLInitialResponsePacket { - identifier, - length, - auth_mechanism: payload.0, - param_length: payload.1, - sasl_param: payload.2, +fn pgsql_parse_generic_parameter(i: &[u8]) -> IResult<&[u8], PgsqlParameter> { + let (i, param_name) = take_until1("\x00")(i)?; + let (i, _) = tag("\x00")(i)?; + let (i, param_value) = take_until1("\x00")(i)?; + let (i, _) = tag("\x00")(i)?; + Ok((i, PgsqlParameter { + name: PgsqlParameters::from(param_name), + value: param_value.to_vec(), })) -)); - -named!(pub parse_sasl_response, -do_parse!( - identifier: verify!(be_u8, |&x| x == b'p') - >> length: verify!(be_u32, |&x| x > PGSQL_LENGTH_FIELD) - >> payload: take!(length - PGSQL_LENGTH_FIELD) - >> (PgsqlFEMessage::SASLResponse( - RegularPacket { - identifier, - length, - payload: payload.to_vec(), +} + +pub fn pgsql_parse_startup_parameters(i: &[u8]) -> IResult<&[u8], PgsqlStartupParameters> { + let (i, user) = parse_user_param(i)?; + let (i, database) = opt(parse_database_param)(i)?; + let (i, optional) = opt(terminated(many1(pgsql_parse_generic_parameter), tag("\x00")))(i)?; + Ok((i, PgsqlStartupParameters{ + user, + database, + optional_params: optional, })) -)); - -named!(pub pgsql_parse_startup_packet, -do_parse!( - len: verify!(be_u32, |&x| x >= 8) - >> proto_major: peek!(be_u16) - >> message: flat_map!(take!(len - PGSQL_LENGTH_FIELD), - switch!(value!(proto_major), - 1 | 2 | 3 => do_parse!( - proto_major: be_u16 - >> proto_minor: be_u16 - >> params: call!(pgsql_parse_startup_parameters) - >> (PgsqlFEMessage::StartupMessage(StartupPacket{ - length: len, - proto_major, - proto_minor, - params}))) | - PGSQL_DUMMY_PROTO_MAJOR => do_parse!( - proto_major: be_u16 - >> proto_minor: exact!(be_u16) - >> _message: switch!(value!(proto_minor), - PGSQL_DUMMY_PROTO_MINOR_SSL => tuple!( - value!(len), - value!(proto_major), - value!(proto_minor))) - >> (PgsqlFEMessage::SSLRequest(DummyStartupPacket{ - length: len, - proto_major, - proto_minor}))) - )) - >> (message) -)); +} + +fn parse_sasl_initial_response_payload(i: &[u8]) -> IResult<&[u8], (SASLAuthenticationMechanism, u32, Vec)> { + let (i, sasl_mechanism) = parse_sasl_mechanism(i)?; + let (i, param_length) = be_u32(i)?; + // From RFC 5802 - the client-first-message will always start w/ + // 'n', 'y' or 'p', otherwise it's invalid, I think we should check that, at some point + let (i, param) = terminated(take(param_length), eof)(i)?; + Ok((i, (sasl_mechanism, param_length, param.to_vec()))) +} + +pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; + let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), parse_sasl_initial_response_payload)(i)?; + Ok((i, PgsqlFEMessage::SASLInitialResponse( + SASLInitialResponsePacket { + identifier, + length, + auth_mechanism: payload.0, + param_length: payload.1, + sasl_param: payload.2, + }))) +} + +pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; + let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?; + let resp = PgsqlFEMessage::SASLResponse( + RegularPacket { + identifier, + length, + payload: payload.to_vec(), + }); + Ok((i, resp)) +} + +pub fn pgsql_parse_startup_packet(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { + let (i, len) = verify(be_u32, |&x| x >= 8)(i)?; + let (i, proto_major) = peek(be_u16)(i)?; + let (i, b) = take(len - PGSQL_LENGTH_FIELD)(i)?; + let (_, message) = + match proto_major { + 1 | 2 | 3 => { + let (b, proto_major) = be_u16(b)?; + let (b, proto_minor) = be_u16(b)?; + let (b, params) = pgsql_parse_startup_parameters(b)?; + (b, PgsqlFEMessage::StartupMessage(StartupPacket{ + length: len, + proto_major, + proto_minor, + params})) + }, + PGSQL_DUMMY_PROTO_MAJOR => { + let (b, proto_major) = be_u16(b)?; + let (b, proto_minor) = all_consuming(be_u16)(b)?; + let _message = match proto_minor { + PGSQL_DUMMY_PROTO_MINOR_SSL => (len, proto_major, proto_minor), + _ => return Err(Err::Error(make_error(b, ErrorKind::Switch))), + }; + + (b, PgsqlFEMessage::SSLRequest(DummyStartupPacket{ + length: len, + proto_major, + proto_minor})) + } + _ => return Err(Err::Error(make_error(b, ErrorKind::Switch))), + }; + Ok((i, message)) +} // TODO Decide if it's a good idea to offer GSS encryption support right now, as the documentation seems to have conflicting information... // If we do: @@ -672,261 +667,233 @@ do_parse!( // Source: https://www.postgresql.org/docs/13/protocol-flow.html#id-1.10.5.7.11, GSSAPI Session Encryption // Password can be encrypted or in cleartext -named!(pub parse_password_message, -do_parse!( - identifier: verify!(be_u8, |&x| x == b'p') - >> length: verify!(be_u32, |&x| x >= PGSQL_LENGTH_FIELD) - >> password: flat_map!(take!(length - PGSQL_LENGTH_FIELD), take_until1!("\x00")) - >> (PgsqlFEMessage::PasswordMessage( +pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?; + let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?; + let (i, password) = map_parser( + take(length - PGSQL_LENGTH_FIELD), + take_until1("\x00") + )(i)?; + Ok((i, PgsqlFEMessage::PasswordMessage( RegularPacket{ identifier, length, payload: password.to_vec(), - })) -)); - -named!( - parse_simple_query, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'Q') - >> length: verify!(be_u32, |&x| x > PGSQL_LENGTH_FIELD) - >> query: flat_map!(take!(length - PGSQL_LENGTH_FIELD), take_until1!("\x00")) - >> (PgsqlFEMessage::SimpleQuery(RegularPacket { - identifier, - length, - payload: query.to_vec(), - })) - ) -); - -named!( - parse_terminate_message, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'X') - >> length: verify!(be_u32, |&x| x == PGSQL_LENGTH_FIELD) - >> (PgsqlFEMessage::Terminate(TerminationMessage { identifier, length })) - ) -); + }))) +} + +fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?; + let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?; + Ok((i, PgsqlFEMessage::SimpleQuery(RegularPacket { + identifier, + length, + payload: query.to_vec(), + }))) +} + +fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?; + let (i, length) = verify(be_u32, |&x| x == PGSQL_LENGTH_FIELD)(i)?; + Ok((i, PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }))) +} // Messages that begin with 'p' but are not password ones are not parsed here -named!(pub parse_request, -do_parse!( - tag: peek!(be_u8) - >> message: switch!(value!(tag), - b'\0' => call!(pgsql_parse_startup_packet) | - b'Q' => dbg_dmp!(call!(parse_simple_query)) | - b'X' => dbg_dmp!(call!(parse_terminate_message))) - >> (message) -)); - -named!( - pgsql_parse_authentication_message, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'R') - >> length: verify!(be_u32, |&x| x >= 8) - >> auth_type: be_u32 - >> payload: peek!(rest) - >> message: - flat_map!( - take!(length - 8), - switch!(value!(auth_type), - 0 => value!(PgsqlBEMessage::AuthenticationOk( +pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage> { + let (i, tag) = peek(be_u8)(i)?; + let (i, message) = match tag { + b'\0' => pgsql_parse_startup_packet(i)?, + b'Q' => parse_simple_query(i)?, + b'X' => parse_terminate_message(i)?, + _ => return Err(Err::Error(make_error(i, ErrorKind::Switch))), + }; + Ok((i, message)) +} + +fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?; + let (i, length) = verify(be_u32, |&x| x >= 8)(i)?; + let (i, auth_type) = be_u32(i)?; + let (i, payload) = peek(rest)(i)?; + let (i, message) = map_parser( + take(length - 8), + |b: &'a [u8]| { + match auth_type { + 0 => Ok((b, PgsqlBEMessage::AuthenticationOk( + AuthenticationMessage { + identifier, + length, + auth_type, + payload: payload.to_vec(), + }))), + 3 => Ok((b, PgsqlBEMessage::AuthenticationCleartextPassword( + AuthenticationMessage { + identifier, + length, + auth_type, + payload: payload.to_vec(), + }))), + 5 => { + let (b, salt) = all_consuming(take(4_usize))(b)?; + Ok((b, PgsqlBEMessage::AuthenticationMD5Password( AuthenticationMessage { identifier, length, auth_type, - payload: payload.to_vec(), - })) | - 3 => value!(PgsqlBEMessage::AuthenticationCleartextPassword( - AuthenticationMessage { + payload: salt.to_vec(), + }))) + } + 9 => Ok((b, PgsqlBEMessage::AuthenticationSSPI( + AuthenticationMessage { + identifier, + length, + auth_type, + payload: payload.to_vec(), + }))), + // TODO - For SASL, should we parse specific details of the challenge itself? (as seen in: https://github.com/launchbadge/sqlx/blob/master/sqlx-core/src/postgres/message/authentication.rs ) + 10 => { + let (b, auth_mechanisms) = parse_sasl_mechanisms(b)?; + Ok((b, PgsqlBEMessage::AuthenticationSASL( + AuthenticationSASLMechanismMessage { identifier, length, auth_type, - payload: payload.to_vec(), - })) | - 5 => do_parse!( - salt: exact!(take!(4)) - >> (PgsqlBEMessage::AuthenticationMD5Password( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: salt.to_vec(), - })) - ) | - 9 => value!(PgsqlBEMessage::AuthenticationSSPI( - AuthenticationMessage { - identifier, - length, - auth_type, - payload: payload.to_vec(), - })) | - // TODO - For SASL, should we parse specific details of the challenge itself? (as seen in: https://github.com/launchbadge/sqlx/blob/master/sqlx-core/src/postgres/message/authentication.rs ) - 10 => do_parse!( - auth_mechanisms: call!(parse_sasl_mechanisms) - >> (PgsqlBEMessage::AuthenticationSASL( - AuthenticationSASLMechanismMessage { - identifier, - length, - auth_type, - auth_mechanisms, - })) - ) | - 11 => do_parse!( - sasl_challenge: rest - >> (PgsqlBEMessage::AuthenticationSASLContinue( + auth_mechanisms, + }))) + } + 11 => { + let (b, sasl_challenge) = rest(i)?; + Ok((b, PgsqlBEMessage::AuthenticationSASLContinue( AuthenticationMessage { identifier, length, auth_type, payload: sasl_challenge.to_vec(), - })) - ) | - 12 => do_parse!( - signature: rest - >> (PgsqlBEMessage::AuthenticationSASLFinal( + }))) + }, + 12 => { + let (i, signature) = rest(i)?; + Ok((i, PgsqlBEMessage::AuthenticationSASLFinal( AuthenticationMessage { identifier, length, auth_type, payload: signature.to_vec(), } - )) - ) - // TODO add other authentication messages - ) - ) - >> (message) - ) -); - -named!( - parse_parameter_status_message, - dbg_dmp!(do_parse!( - identifier: verify!(be_u8, |&x| x == b'S') - >> length: verify!(be_u32, |&x| x >= PGSQL_LENGTH_FIELD) - >> param: - flat_map!( - take!(length - PGSQL_LENGTH_FIELD), - pgsql_parse_generic_parameter - ) - >> (PgsqlBEMessage::ParameterStatus(ParameterStatusMessage { - identifier, - length, - param, - })) - )) -); - -named!(pub parse_ssl_response, -do_parse!( - tag: alt!(char!('N') | char!('S')) - >> (PgsqlBEMessage::SSLResponse( - SSLResponseMessage::from(tag)) - ) -)); - -named!( - parse_backend_key_data_message, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'K') - >> length: verify!(be_u32, |&x| x == 12) - >> pid: be_u32 - >> secret_key: be_u32 - >> (PgsqlBEMessage::BackendKeyData(BackendKeyDataMessage { - identifier, - length, - backend_pid: pid, - secret_key, - })) - ) -); - -named!( - parse_command_complete, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'C') - >> length: verify!(be_u32, |&x| x > PGSQL_LENGTH_FIELD) - >> payload: flat_map!(take!(length - PGSQL_LENGTH_FIELD), take_until!("\x00")) - >> (PgsqlBEMessage::CommandComplete(RegularPacket { - identifier, - length, - payload: payload.to_vec(), - })) - ) -); - -named!( - parse_ready_for_query, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'Z') - >> length: verify!(be_u32, |&x| x == 5) - >> status: verify!(be_u8, |&x| x == b'I' || x == b'T' || x == b'E') - >> (PgsqlBEMessage::ReadyForQuery(ReadyForQueryMessage { - identifier, - length, - transaction_status: status, - })) - ) -); - -named!( - parse_row_field, - do_parse!( - field_name: take_until1!("\x00") - >> tag!("\x00") - >> table_oid: be_u32 - >> column_index: be_u16 - >> data_type_oid: be_u32 - >> data_type_size: be_i16 - >> type_modifier: be_i32 - >> format_code: be_u16 - >> (RowField { - field_name: field_name.to_vec(), - table_oid, - column_index, - data_type_oid, - data_type_size, - type_modifier, - format_code, - }) - ) -); - -named!(pub parse_row_description, -do_parse!( - identifier: dbg_dmp!(verify!(be_u8, |&x| x == b'T')) - >> length: verify!(be_u32, |&x| x > 6) - >> field_count: dbg_dmp!(be_u16) - >> fields: flat_map!( - take!(length - 6), - many_m_n!(0, field_count.into(), - call!(parse_row_field))) - >> (PgsqlBEMessage::RowDescription( - RowDescriptionMessage { - identifier, - length, - field_count, - fields, + ))) + } + // TODO add other authentication messages + _ => return Err(Err::Error(make_error(i, ErrorKind::Switch))), + } + } + )(i)?; + Ok((i, message)) +} + +fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?; + let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?; + let (i, param) = map_parser(take(length - PGSQL_LENGTH_FIELD), pgsql_parse_generic_parameter)(i)?; + Ok((i, PgsqlBEMessage::ParameterStatus(ParameterStatusMessage { + identifier, + length, + param, + }))) +} + +pub fn parse_ssl_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, tag) = alt((char('N'), char('S')))(i)?; + Ok((i, PgsqlBEMessage::SSLResponse( + SSLResponseMessage::from(tag)) + )) +} + +fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'K')(i)?; + let (i, length) = verify(be_u32, |&x| x == 12)(i)?; + let (i, pid) = be_u32(i)?; + let (i, secret_key) = be_u32(i)?; + Ok((i, PgsqlBEMessage::BackendKeyData(BackendKeyDataMessage { + identifier, + length, + backend_pid: pid, + secret_key, + }))) +} + +fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?; + let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?; + let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?; + Ok((i, PgsqlBEMessage::CommandComplete(RegularPacket { + identifier, + length, + payload: payload.to_vec(), + }))) +} + +fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?; + let (i, length) = verify(be_u32, |&x| x == 5)(i)?; + let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?; + Ok((i, PgsqlBEMessage::ReadyForQuery(ReadyForQueryMessage { + identifier, + length, + transaction_status: status, + }))) +} + +fn parse_row_field(i: &[u8]) -> IResult<&[u8], RowField> { + let (i, field_name) = take_until1("\x00")(i)?; + let (i, _) = tag("\x00")(i)?; + let (i, table_oid) = be_u32(i)?; + let (i, column_index) = be_u16(i)?; + let (i, data_type_oid) = be_u32(i)?; + let (i, data_type_size) = be_i16(i)?; + let (i, type_modifier) = be_i32(i)?; + let (i, format_code) = be_u16(i)?; + Ok((i, RowField { + field_name: field_name.to_vec(), + table_oid, + column_index, + data_type_oid, + data_type_size, + type_modifier, + format_code, })) -)); - -named!( - parse_data_row_value, - do_parse!( - value_length: be_i32 - >> value: cond!(value_length >= 0, take!(value_length)) - >> (ColumnFieldValue { - value_length, - value: { - match value { - Some(data) => data.to_vec(), - None => [].to_vec(), - } - }, - }) - ) -); +} + +pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'T')(i)?; + let (i, length) = verify(be_u32, |&x| x > 6)(i)?; + let (i, field_count) = be_u16(i)?; + let (i, fields) = map_parser( + take(length - 6), + many_m_n(0, field_count.into(), parse_row_field) + )(i)?; + Ok((i, PgsqlBEMessage::RowDescription( + RowDescriptionMessage { + identifier, + length, + field_count, + fields, + }))) +} + +fn parse_data_row_value(i: &[u8]) -> IResult<&[u8], ColumnFieldValue> { + let (i, value_length) = be_i32(i)?; + let (i, value) = cond(value_length >= 0, take(value_length as usize))(i)?; + Ok((i, ColumnFieldValue { + value_length, + value: { + match value { + Some(data) => data.to_vec(), + None => [].to_vec(), + } + }, + })) +} /// For each column, add up the data size. Return the total fn add_up_data_size(columns: Vec) -> u64 { @@ -943,79 +910,75 @@ fn add_up_data_size(columns: Vec) -> u64 { // Currently, we don't store the actual DataRow messages, as those could easily become a burden, memory-wise // We use ConsolidatedDataRow to store info we still want to log: message size. // Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages -named!(pub parse_consolidated_data_row, -do_parse!( - identifier: verify!(be_u8, |&x| x == b'D') - >> length: verify!(be_u32, |&x| x >= 6) - >> field_count: be_u16 +pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'D')(i)?; + let (i, length) = verify(be_u32, |&x| x >= 6)(i)?; + let (i, field_count) = be_u16(i)?; // 6 here is for skipping length + field_count - >> rows: flat_map!(take!(length - 6), many_m_n!(0, field_count.into(), call!(parse_data_row_value))) - >> (PgsqlBEMessage::ConsolidatedDataRow( - ConsolidatedDataRowPacket { - identifier, - length, - row_cnt: 1, - data_size: add_up_data_size(rows), - } - )) -)); - -named!( - parse_sasl_mechanism, - do_parse!( - mechanism: - alt!( - terminated!(tag!("SCRAM-SHA-256-PLUS"), tag!("\x00")) => { |_| SASLAuthenticationMechanism::ScramSha256Plus} | - terminated!(tag!("SCRAM-SHA-256"), tag!("\x00")) => { |_| SASLAuthenticationMechanism::ScramSha256} - ) - >> (mechanism) - ) -); - -named!( - parse_sasl_mechanisms>, - terminated!(many1!(parse_sasl_mechanism), tag!("\x00")) -); - -named!(pub parse_error_response_code, -do_parse!( - _field_type: char!('C') - >> field_value: flat_map!(take!(6), call!(alphanumeric1)) - >> (PgsqlErrorNoticeMessageField{ + let (i, rows) = map_parser(take(length - 6), many_m_n(0, field_count.into(), parse_data_row_value))(i)?; + Ok((i, PgsqlBEMessage::ConsolidatedDataRow( + ConsolidatedDataRowPacket { + identifier, + length, + row_cnt: 1, + data_size: add_up_data_size(rows), + } + ))) +} + +fn parse_sasl_mechanism(i: &[u8]) -> IResult<&[u8], SASLAuthenticationMechanism> { + let res: IResult<_, _, ()> = terminated(tag("SCRAM-SHA-256-PLUS"), tag("\x00"))(i); + if let Ok((i, _)) = res { + return Ok((i, SASLAuthenticationMechanism::ScramSha256Plus)); + } + let res: IResult<_, _, ()> = terminated(tag("SCRAM-SHA-256"), tag("\x00"))(i); + if let Ok((i, _)) = res { + return Ok((i, SASLAuthenticationMechanism::ScramSha256)); + } + return Err(Err::Error(make_error(i, ErrorKind::Alt))); +} + +fn parse_sasl_mechanisms(i: &[u8]) -> IResult<&[u8], Vec> { + terminated(many1(parse_sasl_mechanism), tag("\x00"))(i) +} + +pub fn parse_error_response_code(i: &[u8]) -> IResult<&[u8], PgsqlErrorNoticeMessageField> { + let (i, _field_type) = char('C')(i)?; + let (i, field_value) = map_parser(take(6_usize), alphanumeric1)(i)?; + Ok((i, PgsqlErrorNoticeMessageField{ field_type: PgsqlErrorNoticeFieldType::CodeSqlStateCode, field_value: field_value.to_vec(), - }) -)); + })) +} // Parse an error response with non-localizeable severity message. // Possible values: ERROR, FATAL, or PANIC -named!(pub parse_error_response_severity, -do_parse!( - field_type: char!('V') - >> field_value: alt!(tag!("ERROR") | tag!("FATAL") | tag!("PANIC")) - >> tag!("\x00") - >> (PgsqlErrorNoticeMessageField{ - field_type: PgsqlErrorNoticeFieldType::from(field_type), - field_value: field_value.to_vec(), - }) -)); +pub fn parse_error_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNoticeMessageField> { + let (i, field_type) = char('V')(i)?; + let (i, field_value) = alt((tag("ERROR"), tag("FATAL"), tag("PANIC")))(i)?; + let (i, _) = tag("\x00")(i)?; + Ok((i, PgsqlErrorNoticeMessageField{ + field_type: PgsqlErrorNoticeFieldType::from(field_type), + field_value: field_value.to_vec(), + })) +} // The non-localizable version of Severity field has different values, // in case of a notice: 'WARNING', 'NOTICE', 'DEBUG', 'INFO' or 'LOG' -named!(pub parse_notice_response_severity, -do_parse!( - field_type: char!('V') - >> field_value: alt!(tag!("WARNING") - | tag!("NOTICE") - | tag!("DEBUG") - | tag!("INFO") - | tag!("LOG")) - >> tag!("\x00") - >> (PgsqlErrorNoticeMessageField{ - field_type: PgsqlErrorNoticeFieldType::from(field_type), - field_value: field_value.to_vec(), - }) -)); +pub fn parse_notice_response_severity(i: &[u8]) -> IResult<&[u8], PgsqlErrorNoticeMessageField> { + let (i, field_type) = char('V')(i)?; + let (i, field_value) = alt(( + tag("WARNING"), + tag("NOTICE"), + tag("DEBUG"), + tag("INFO"), + tag("LOG")))(i)?; + let (i, _) = tag("\x00")(i)?; + Ok((i, PgsqlErrorNoticeMessageField{ + field_type: PgsqlErrorNoticeFieldType::from(field_type), + field_value: field_value.to_vec(), + })) +} pub fn parse_error_response_field( i: &[u8], is_err_msg: bool, @@ -1032,8 +995,8 @@ pub fn parse_error_response_field( b'C' => parse_error_response_code(i)?, _ => { let (i, field_type) = be_u8(i)?; - let (i, field_value) = nom::take_until!(i, "\x00")?; - let (i, _just_tag) = tag!(i, "\x00")?; + let (i, field_value) = take_until("\x00")(i)?; + let (i, _just_tag) = tag("\x00")(i)?; let message = PgsqlErrorNoticeMessageField { field_type: PgsqlErrorNoticeFieldType::from(field_type), field_value: field_value.to_vec(), @@ -1044,95 +1007,90 @@ pub fn parse_error_response_field( Ok((i, data)) } -named_args!(pub parse_error_notice_fields(is_err_msg: bool)>, -do_parse!( - data: many_till!(call!(parse_error_response_field, is_err_msg), tag!("\x00")) - >> (data.0) -)); - -named!( - pgsql_parse_error_response, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'E') - >> length: verify!(be_u32, |&x| x > 10) - >> message_body: - flat_map!( - take!(length - PGSQL_LENGTH_FIELD), - call!(parse_error_notice_fields, true) - ) - >> (PgsqlBEMessage::ErrorResponse(ErrorNoticeMessage { - identifier, - length, - message_body, - })) - ) -); - -named!( - pgsql_parse_notice_response, - dbg_dmp!(do_parse!( - identifier: verify!(be_u8, |&x| x == b'N') - >> length: verify!(be_u32, |&x| x > 10) - >> message_body: - flat_map!( - take!(length - PGSQL_LENGTH_FIELD), - call!(parse_error_notice_fields, false) - ) - >> (PgsqlBEMessage::NoticeResponse(ErrorNoticeMessage { - identifier, - length, - message_body, - })) - )) -); - -named!( - parse_notification_response, - do_parse!( - identifier: verify!(be_u8, |&x| x == b'A') - // length (u32) + pid (u32) + at least one byte, for we have two str fields - >> length: verify!(be_u32, |&x| x > 9) - >> data: flat_map!(take!(length - PGSQL_LENGTH_FIELD), - do_parse!( - pid: be_u32 - >> channel_name: take_until_and_consume!("\x00") - >> payload: take_until_and_consume!("\x00") - >> ((pid, channel_name, payload)) - )) - >> (PgsqlBEMessage::NotificationResponse(NotificationResponse{ - identifier, - length, - pid: data.0, - channel_name: data.1.to_vec(), - payload: data.2.to_vec(), - })) - ) -); - -named!(pub pgsql_parse_response, -do_parse!( - pseudo_header: peek!(tuple!(be_u8, be_u32)) - >> message: flat_map!(take!(pseudo_header.1 + 1), switch!(value!(pseudo_header.0), - b'E' => call!(pgsql_parse_error_response) | - b'K' => call!(parse_backend_key_data_message) | - b'N' => call!(pgsql_parse_notice_response) | - b'R' => call!(pgsql_parse_authentication_message) | - b'S' => call!(parse_parameter_status_message) | - b'C' => call!(parse_command_complete) | - b'Z' => call!(parse_ready_for_query) | - b'T' => call!(parse_row_description) | - b'A' => call!(parse_notification_response) | - b'D' => call!(parse_consolidated_data_row) - // _ => {} // TODO add an unknown message type here? - )) - >> (message) -)); +pub fn parse_error_notice_fields(i: &[u8], is_err_msg: bool) -> IResult<&[u8], Vec> { + let (i, data) = many_till(|b| parse_error_response_field(b, is_err_msg), tag("\x00"))(i)?; + Ok((i, data.0)) +} + +fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?; + let (i, length) = verify(be_u32, |&x| x > 10)(i)?; + let (i, message_body) = map_parser( + take(length - PGSQL_LENGTH_FIELD), + |b| parse_error_notice_fields(b, true) + )(i)?; + + Ok((i, PgsqlBEMessage::ErrorResponse(ErrorNoticeMessage { + identifier, + length, + message_body, + }))) +} + +fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?; + let (i, length) = verify(be_u32, |&x| x > 10)(i)?; + let (i, message_body) = map_parser( + take(length - PGSQL_LENGTH_FIELD), + |b| parse_error_notice_fields(b, false) + )(i)?; + Ok((i, PgsqlBEMessage::NoticeResponse(ErrorNoticeMessage { + identifier, + length, + message_body, + }))) +} + +fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?; + // length (u32) + pid (u32) + at least one byte, for we have two str fields + let (i, length) = verify(be_u32, |&x| x > 9)(i)?; + let (i, data) = map_parser( + take(length - PGSQL_LENGTH_FIELD), + |b| { + let (b, pid) = be_u32(b)?; + let (b, channel_name) = take_until_and_consume(b"\x00")(b)?; + let (b, payload) = take_until_and_consume(b"\x00")(b)?; + Ok((b, (pid, channel_name, payload))) + })(i)?; + let msg = PgsqlBEMessage::NotificationResponse(NotificationResponse{ + identifier, + length, + pid: data.0, + channel_name: data.1.to_vec(), + payload: data.2.to_vec(), + }); + Ok((i, msg)) +} + +pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage> { + let (i, pseudo_header) = peek(tuple((be_u8, be_u32)))(i)?; + let (i, message) = map_parser( + take(pseudo_header.1 + 1), + |b| { + match pseudo_header.0 { + b'E' => pgsql_parse_error_response(b), + b'K' => parse_backend_key_data_message(b), + b'N' => pgsql_parse_notice_response(b), + b'R' => pgsql_parse_authentication_message(b), + b'S' => parse_parameter_status_message(b), + b'C' => parse_command_complete(b), + b'Z' => parse_ready_for_query(b), + b'T' => parse_row_description(b), + b'A' => parse_notification_response(b), + b'D' => parse_consolidated_data_row(b), + // _ => {} // TODO add an unknown message type here? + _ => return Err(Err::Error(make_error(i, ErrorKind::Switch))), + } + })(i)?; + Ok((i, message)) +} #[cfg(test)] mod tests { use super::*; - use nom::Needed::Size; + use nom7::Needed; #[test] fn test_parse_request() { @@ -1174,10 +1132,10 @@ mod tests { // there should be nothing left assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Result should not be an error: {:?}.", err); + Err(Err::Error(err)) => { + panic!("Result should not be an error: {:?}.", err.code); } - Err(nom::Err::Incomplete(_)) => { + Err(Err::Incomplete(_)) => { panic!("Result should not have been incomplete."); } _ => { @@ -1217,10 +1175,10 @@ mod tests { assert_eq!(message, expected_result); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?}", needed); } _ => { @@ -1255,10 +1213,10 @@ mod tests { assert_eq!(message, expected_result); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?}", needed); } _ => { @@ -1450,13 +1408,13 @@ mod tests { Ok((_remainder, message)) => { assert_eq!(message, ok_res); } - Err(nom::Err::Error((remainder, err))) => { + Err(Err::Error(err)) => { panic!( "Shouldn't be err {:?}, expected Ok(_). Remainder is: {:?} ", - err, remainder + err.code, err.input ); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be incomplete {:?}, expected Ok(_)", needed); } _ => panic!("Unexpected behavior, expected Ok(_)"), @@ -1560,11 +1518,11 @@ mod tests { match result { Ok((_remainder, _message)) => panic!("Result should not be ok, but incomplete."), - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { - assert_eq!(needed, Size(6)); + Err(Err::Incomplete(needed)) => { + assert_eq!(needed, Needed::new(2)); } _ => panic!("Unexpected behavior."), } @@ -1631,10 +1589,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?}", needed); } _ => { @@ -1644,9 +1602,9 @@ mod tests { let result_incomplete = pgsql_parse_response(&buf[0..22]); match result_incomplete { - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { // parser first tries to take whole message (length + identifier = 151), but buffer is incomplete - assert_eq!(needed, Size(151)); + assert_eq!(needed, Needed::new(129)); } _ => { panic!("Unexpected behavior. Should be incomplete."); @@ -1727,10 +1685,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?} ", needed); } _ => { @@ -1761,10 +1719,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?}", needed); } _ => { @@ -1796,10 +1754,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?}", needed); } _ => { @@ -1810,10 +1768,10 @@ mod tests { let incomplete_result = pgsql_parse_response(&buf[0..27]); match incomplete_result { Ok((_remainder, _message)) => panic!("Should not be Ok(_), expected Incomplete!"), - Err(nom::Err::Error((_remainder, err))) => { - panic!("Should not be error {:?}, expected Incomplete!", err) + Err(Err::Error(err)) => { + panic!("Should not be error {:?}, expected Incomplete!", err.code) } - Err(nom::Err::Incomplete(needed)) => assert_eq!(needed, Size(43)), + Err(Err::Incomplete(needed)) => assert_eq!(needed, Needed::new(16)), _ => panic!("Unexpected behavior, expected Incomplete."), } } @@ -1846,10 +1804,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error {:?} expected Ok(_)", err) + Err(Err::Error(err)) => { + panic!("Shouldn't be error {:?} expected Ok(_)", err.code) } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("shouldn't be incomplete {:?}, expected Ok(_)", needed) } _ => panic!("Unexpected behavior, expected Ok(_)"), @@ -1858,11 +1816,11 @@ mod tests { let result_incomplete = pgsql_parse_response(&buf[0..31]); match result_incomplete { Ok((_remainder, _message)) => panic!("Should not be Ok(_), expected Incomplete!"), - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error {:?} expected Incomplete!", err) + Err(Err::Error(err)) => { + panic!("Shouldn't be error {:?} expected Incomplete!", err.code) } - Err(nom::Err::Incomplete(needed)) => { - assert_eq!(needed, Size(93)); + Err(Err::Incomplete(needed)) => { + assert_eq!(needed, Needed::new(62)); } _ => panic!("Unexpected behavior, expected Ok(_)"), } @@ -1890,10 +1848,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error {:?}, expected Ok(_)", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error {:?}, expected Ok(_)", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Shouldn't be incomplete {:?}, expected OK(_)", needed); } _ => panic!("Unexpected behavior, expected Ok(_)"), @@ -1901,8 +1859,8 @@ mod tests { let result_incomplete = pgsql_parse_response(&buf[0..34]); match result_incomplete { - Err(nom::Err::Incomplete(needed)) => { - assert_eq!(needed, Size(55)); + Err(Err::Incomplete(needed)) => { + assert_eq!(needed, Needed::new(21)); } _ => panic!("Unexpected behavior, expected incomplete."), } @@ -1916,10 +1874,10 @@ mod tests { ]; let result_err = pgsql_parse_response(bad_buf); match result_err { - Err(nom::Err::Error((_remainder, err))) => { - assert_eq!(err, nom::error::ErrorKind::Switch); + Err(Err::Error(err)) => { + assert_eq!(err.code, ErrorKind::Switch); } - Err(nom::Err::Incomplete(_)) => { + Err(Err::Incomplete(_)) => { panic!("Unexpected Incomplete, should be ErrorKind::Switch"); } _ => panic!("Unexpected behavior, expected Error"), @@ -1952,10 +1910,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error {:?}, expected Ok(_)", err) + Err(Err::Error(err)) => { + panic!("Shouldn't be error {:?}, expected Ok(_)", err.code) } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Shouldn't be incomplete: {:?}, expected Ok(_)", needed) } _ => panic!("Unexpected behavior, expected Ok(_)"), @@ -1985,10 +1943,10 @@ mod tests { Ok((_remainder, message)) => { assert_eq!(message, ok_res); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?} expected Ok(_)", err) + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?} expected Ok(_)", err.code) } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Shouldn't be incomplete: {:?}, expected Ok(_)", needed) } _ => panic!("Unexpected behavior, should be Ok(_)"), @@ -2050,10 +2008,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?}", needed); } _ => { @@ -2075,10 +2033,10 @@ mod tests { assert_eq!(remainder.len(), 0); assert_eq!(message, ok_res); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + panic!("Shouldn't be error: {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be incomplete. Needed {:?}", needed); } _ => { @@ -2107,10 +2065,10 @@ mod tests { assert_eq!(message, ok_res); assert_eq!(remainder.len(), 0); } - Err(nom::Err::Error((_remainder, err))) => { - panic!("Should not be error {:?}", err); + Err(Err::Error(err)) => { + panic!("Should not be error {:?}", err.code); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be incomplete. Needed: {:?}", needed); } _ => { @@ -2184,12 +2142,12 @@ mod tests { assert_eq!(response, ok_res); assert!(rem.is_empty()); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!("Should not be Incomplete! Needed: {:?}", needed); } - Err(nom::Err::Error((rem, err))) => { - println!("Remainder is: {:?}", rem); - panic!("Shouldn't be error: {:?}", err); + Err(Err::Error(err)) => { + println!("Remainder is: {:?}", err.input); + panic!("Shouldn't be error: {:?}", err.code); } _ => { panic!("Unexpected behavior"); @@ -2228,15 +2186,15 @@ mod tests { assert_eq!(ok_res, message); assert!(rem.is_empty()); } - Err(nom::Err::Incomplete(needed)) => { + Err(Err::Incomplete(needed)) => { panic!( "Shouldn't be Incomplete! Expected Ok(). Needed: {:?}", needed ); } - Err(nom::Err::Error((rem, err))) => { - println!("Unparsed slice: {:?}", rem); - panic!("Shouldn't be Error: {:?}, expected Ok()", err); + Err(Err::Error(err)) => { + println!("Unparsed slice: {:?}", err.input); + panic!("Shouldn't be Error: {:?}, expected Ok()", err.code); } _ => { panic!("Unexpected behavior, should be Ok()"); diff --git a/rust/src/pgsql/pgsql.rs b/rust/src/pgsql/pgsql.rs index 9ea7d40742..ead1ec0c05 100644 --- a/rust/src/pgsql/pgsql.rs +++ b/rust/src/pgsql/pgsql.rs @@ -23,7 +23,7 @@ use super::parser::{self, ConsolidatedDataRowPacket, PgsqlBEMessage, PgsqlFEMess use crate::applayer::*; use crate::conf::*; use crate::core::{AppProto, Flow, ALPROTO_FAILED, ALPROTO_UNKNOWN, IPPROTO_TCP}; -use nom; +use nom7::{Err, IResult}; use std; use std::ffi::CString; @@ -255,7 +255,7 @@ impl PgsqlState { fn state_based_req_parsing( state: PgsqlStateProgress, input: &[u8], - ) -> Result<(&[u8], parser::PgsqlFEMessage), nom::Err<(&[u8], nom::error::ErrorKind)>> { + ) -> IResult<&[u8], parser::PgsqlFEMessage> { match state { PgsqlStateProgress::SASLAuthenticationReceived => { parser::parse_sasl_initial_response(input) @@ -314,7 +314,7 @@ impl PgsqlState { return AppLayerResult::ok(); }; } - Err(nom::Err::Incomplete(_needed)) => { + Err(Err::Incomplete(_needed)) => { let consumed = input.len() - start.len(); let needed_estimation = start.len() + 1; SCLogDebug!( @@ -399,7 +399,7 @@ impl PgsqlState { fn state_based_resp_parsing( state: PgsqlStateProgress, input: &[u8], - ) -> Result<(&[u8], parser::PgsqlBEMessage), nom::Err<(&[u8], nom::error::ErrorKind)>> { + ) -> IResult<&[u8], parser::PgsqlBEMessage> { if state == PgsqlStateProgress::SSLRequestReceived { parser::parse_ssl_response(input) } else { @@ -462,7 +462,7 @@ impl PgsqlState { return AppLayerResult::ok(); }; } - Err(nom::Err::Incomplete(_needed)) => { + Err(Err::Incomplete(_needed)) => { let consumed = input.len() - start.len(); let needed_estimation = start.len() + 1; SCLogDebug!( @@ -548,7 +548,7 @@ pub unsafe extern "C" fn rs_pgsql_probing_parser_tc( Ok((_, _response)) => { return ALPROTO_PGSQL; } - Err(nom::Err::Incomplete(_)) => { + Err(Err::Incomplete(_)) => { return ALPROTO_UNKNOWN; } Err(_) => {