}
}
-fn parse_length(i: &[u8]) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
+fn parse_gte_length(i: &[u8], expected_length: u32) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| {
- x >= PGSQL_LENGTH_FIELD
+ x >= expected_length
+ })(i);
+ match res {
+ Ok(result) => Ok((result.0, result.1)),
+ Err(nom7::Err::Incomplete(needed)) => Err(Err::Incomplete(needed)),
+ Err(_) => Err(Err::Error(PgsqlParseError::InvalidLength)),
+ }
+}
+
+fn parse_exact_length(
+ i: &[u8], expected_length: u32,
+) -> IResult<&[u8], u32, PgsqlParseError<&[u8]>> {
+ let res = verify(be_u32::<&[u8], nom7::error::Error<_>>, |&x| {
+ x == expected_length
})(i);
match res {
Ok(result) => Ok((result.0, result.1)),
pub fn parse_sasl_initial_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
- let (i, length) = parse_length(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
parse_sasl_initial_response_payload,
pub fn parse_sasl_response(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
- let (i, length) = parse_length(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let resp = PgsqlFEMessage::SASLResponse(RegularPacket {
identifier,
// Password can be encrypted or in cleartext
pub fn parse_password_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'p')(i)?;
- let (i, length) = parse_length(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, password) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
Ok((
i,
fn parse_simple_query(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'Q')(i)?;
- let (i, length) = parse_length(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, query) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until1("\x00"))(i)?;
Ok((
i,
fn parse_terminate_message(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'X')(i)?;
- let (i, length) = parse_length(i)?;
+ let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?;
Ok((
i,
PgsqlFEMessage::Terminate(TerminationMessage { identifier, length }),
b'X' => parse_terminate_message(i)?,
_ => {
let (i, identifier) = be_u8(i)?;
- let (i, length) = verify(be_u32, |&x| x >= PGSQL_LENGTH_FIELD)(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let unknown = PgsqlFEMessage::UnknownMessageType(RegularPacket {
identifier,
fn pgsql_parse_authentication_message<'a>(i: &'a [u8]) -> IResult<&'a [u8], PgsqlBEMessage, PgsqlParseError<&'a [u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'R')(i)?;
- let (i, length) = verify(be_u32, |&x| x >= 8)(i)?;
+ let (i, length) = parse_gte_length(i, 8)?;
let (i, auth_type) = be_u32(i)?;
let (i, message) = map_parser(take(length - 8), |b: &'a [u8]| {
match auth_type {
fn parse_parameter_status_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'S')(i)?;
- let (i, length) = parse_length(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, param) = map_parser(
take(length - PGSQL_LENGTH_FIELD),
pgsql_parse_generic_parameter,
fn parse_backend_key_data_message(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'K')(i)?;
- let (i, length) = verify(be_u32, |&x| x == 12)(i)?;
+ let (i, length) = parse_exact_length(i, 12)?;
let (i, pid) = be_u32(i)?;
let (i, secret_key) = be_u32(i)?;
Ok((
fn parse_command_complete(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'C')(i)?;
- let (i, length) = parse_length(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = map_parser(take(length - PGSQL_LENGTH_FIELD), take_until("\x00"))(i)?;
Ok((
i,
fn parse_ready_for_query(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'Z')(i)?;
- let (i, length) = verify(be_u32, |&x| x == 5)(i)?;
+ let (i, length) = parse_exact_length(i, 5)?;
let (i, status) = verify(be_u8, |&x| x == b'I' || x == b'T' || x == b'E')(i)?;
Ok((
i,
pub fn parse_row_description(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'T')(i)?;
- let (i, length) = verify(be_u32, |&x| x > 6)(i)?;
+ let (i, length) = parse_gte_length(i, 7)?;
let (i, field_count) = be_u16(i)?;
let (i, fields) = map_parser(
take(length - 6),
// Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages
pub fn parse_consolidated_data_row(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'D')(i)?;
- let (i, length) = verify(be_u32, |&x| x >= 6)(i)?;
+ let (i, length) = parse_gte_length(i, 7)?;
let (i, field_count) = be_u16(i)?;
// 6 here is for skipping length + field_count
let (i, rows) = map_parser(
fn pgsql_parse_error_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'E')(i)?;
- let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
+ let (i, length) = parse_gte_length(i, 11)?;
let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
parse_error_notice_fields(b, true)
})(i)?;
fn pgsql_parse_notice_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'N')(i)?;
- let (i, length) = verify(be_u32, |&x| x > 10)(i)?;
+ let (i, length) = parse_gte_length(i, 11)?;
let (i, message_body) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
parse_error_notice_fields(b, false)
})(i)?;
fn parse_notification_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
let (i, identifier) = verify(be_u8, |&x| x == b'A')(i)?;
// length (u32) + pid (u32) + at least one byte, for we have two str fields
- let (i, length) = verify(be_u32, |&x| x > 9)(i)?;
+ let (i, length) = parse_gte_length(i, 10)?;
let (i, data) = map_parser(take(length - PGSQL_LENGTH_FIELD), |b| {
let (b, pid) = be_u32(b)?;
let (b, channel_name) = take_until_and_consume(b"\x00")(b)?;
b'D' => parse_consolidated_data_row(i)?,
_ => {
let (i, identifier) = be_u8(i)?;
- let (i, length) = verify(be_u32, |&x| x > PGSQL_LENGTH_FIELD)(i)?;
+ let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
let (i, payload) = take(length - PGSQL_LENGTH_FIELD)(i)?;
let unknown = PgsqlBEMessage::UnknownMessageType(RegularPacket {
identifier,