]> git.ipfire.org Git - thirdparty/suricata.git/commitdiff
pgsql: add initial support to CopyIn mode/subproto
authorJuliana Fajardini <jufajardini@oisf.net>
Tue, 29 Apr 2025 13:33:38 +0000 (10:33 -0300)
committerJuliana Fajardini <jufajardini@oisf.net>
Wed, 4 Jun 2025 18:21:32 +0000 (15:21 -0300)
This sub-protocol inspects messages sent mainly from the frontend to
the backend after a 'COPY FROM STDIN' has been processed by the
backend.

Parses new messages:
- CopyInResponse -- initiates copy-in mode/sub-protocol
- CopyData (In) -- data transfer message, from frontend to backend
- CopyDone -- signals that no more CopyData messages will be seen from
  the frontend, for the current transaction
- CopyFail -- used by the frontend to signal some failure to proceed
  with sending CopyData messages

Task #7645

doc/userguide/output/eve/eve-json-format.rst
etc/schema.json
rust/src/pgsql/logger.rs
rust/src/pgsql/parser.rs
rust/src/pgsql/pgsql.rs

index ff2f3643d9c6e74fd3db23cbef3166b8b3ca2445..0ad3a384680149105298c723ad63afd0ea2eddfc 100644 (file)
@@ -2545,6 +2545,10 @@ flow. Some of the possible request messages are:
   transaction where the query was sent.
 * "message": requests which do not have meaningful payloads are logged like this,
   where the field value is the message type
+* "copy_data_in": object. Part of the CopyIn subprotocol, consolidated data
+  resulting from a ``Copy From Stdin`` query
+* "copy_done": string. Similar to ``command_completed`` but sent after the
+  frontend finishes sending a batch of ``CopyData`` messages
 
 There are several different authentication messages possible, based on selected
 authentication method. (e.g. the SASL authentication will have a set of
@@ -2571,6 +2575,8 @@ pgsql flow. Some of the possible request messages are:
 * "data_size": in bytes. When one or many ``DataRow`` messages are parsed, the
   total size in bytes of the data returned
 * "command_completed": string. Informs the command just completed by the backend
+* "copy_in_response": object. Indicates the beginning of a CopyIn mode, shows
+  how many columns will be copied from STDIN (``copy_column_cnt`` field)
 * "copy_out_response": object. Indicates the beginning of a CopyTo mode, shows
   how many columns will be copied to STDOUT (``copy_column_cnt`` field)
 * "copy_data_out": object. Consolidated data on the CopyData sent by the backend
index 41a7911e9ddd9f1483f8582a963f769e1d10428d..78c2f685ddeb82eb2c6b73b375f93ce5a345a582 100644 (file)
                     "type": "object",
                     "additionalProperties": false,
                     "properties": {
+                        "copy_data_in": {
+                            "type": "object",
+                            "description": "CopyData message from CopyIn mode",
+                            "properties": {
+                                "data_size": {
+                                    "type": "integer",
+                                    "description": "Accumulated data size of all CopyData messages sent"
+                                },
+                                "msg_count": {
+                                    "type": "integer",
+                                    "description": "How many CopyData messages were sent (does not necessarily match number of rows from the query)"
+                                }
+                            }
+                        },
                         "message": {
                             "type": "string"
                         },
                                 }
                             }
                         },
+                        "copy_in_response": {
+                            "type": "object",
+                            "description": "Backend/server response accepting CopyIn mode",
+                            "properties": {
+                                "copy_column_count": {
+                                    "type": "integer",
+                                    "description": "Number of columns that will be copied in the CopyData message"
+                                }
+                            }
+                        },
                         "copy_out_response": {
                             "type": "object",
                             "description": "Backend/server response accepting CopyOut mode",
index a714603a85dcec1faffaaa1a50e600b09bc9ed6e..b98874d5b1e9deff7efb551805b3194690548c6e 100644 (file)
@@ -102,6 +102,11 @@ fn log_request(req: &PgsqlFEMessage, flags: u32, js: &mut JsonBuilder) -> Result
             identifier: _,
             length: _,
             payload,
+        })
+        | PgsqlFEMessage::CopyFail(RegularPacket {
+            identifier: _,
+            length: _,
+            payload,
         }) => {
             js.set_string_from_bytes(req.to_str(), payload)?;
         }
@@ -110,10 +115,18 @@ fn log_request(req: &PgsqlFEMessage, flags: u32, js: &mut JsonBuilder) -> Result
             js.set_uint("process_id", *pid)?;
             js.set_uint("secret_key", *backend_key)?;
         }
-        PgsqlFEMessage::Terminate(NoPayloadMessage {
+        PgsqlFEMessage::ConsolidatedCopyDataIn(ConsolidatedDataRowPacket {
             identifier: _,
-            length: _,
+            row_cnt,
+            data_size,
         }) => {
+            js.open_object(req.to_str())?;
+            js.set_uint("msg_count", *row_cnt)?;
+            js.set_uint("data_size", *data_size)?;
+            js.close()?;
+        }
+        PgsqlFEMessage::CopyDone(_)
+        | PgsqlFEMessage::Terminate(_) => {
             js.set_string("message", req.to_str())?;
         }
         PgsqlFEMessage::UnknownMessageType(RegularPacket {
@@ -220,6 +233,11 @@ fn log_response(res: &PgsqlBEMessage, jb: &mut JsonBuilder) -> Result<(), JsonEr
             identifier: _,
             length: _,
             column_cnt,
+        })
+        | PgsqlBEMessage::CopyInResponse(CopyResponse {
+            identifier: _,
+            length: _,
+            column_cnt,
         }) => {
             jb.open_object(res.to_str())?;
             jb.set_uint("copy_column_count", *column_cnt)?;
index c5745ee5b423ee29cd43c077b779630af562b19b..6201f6ad7f92263ff0a98dd3faada444fd88d933 100644 (file)
@@ -299,6 +299,7 @@ pub enum PgsqlBEMessage {
     BackendKeyData(BackendKeyDataMessage),
     CommandComplete(RegularPacket),
     CopyOutResponse(CopyResponse),
+    CopyInResponse(CopyResponse),
     ConsolidatedCopyDataOut(ConsolidatedDataRowPacket),
     CopyDone(NoPayloadMessage),
     ReadyForQuery(ReadyForQueryMessage),
@@ -328,6 +329,7 @@ impl PgsqlBEMessage {
             PgsqlBEMessage::BackendKeyData(_) => "backend_key_data",
             PgsqlBEMessage::CommandComplete(_) => "command_completed",
             PgsqlBEMessage::CopyOutResponse(_) => "copy_out_response",
+            PgsqlBEMessage::CopyInResponse(_) => "copy_in_response",
             PgsqlBEMessage::ConsolidatedCopyDataOut(_) => "copy_data_out",
             PgsqlBEMessage::CopyDone(_) => "copy_done",
             PgsqlBEMessage::ReadyForQuery(_) => "ready_for_query",
@@ -383,6 +385,9 @@ pub enum PgsqlFEMessage {
     SASLInitialResponse(SASLInitialResponsePacket),
     SASLResponse(RegularPacket),
     SimpleQuery(RegularPacket),
+    ConsolidatedCopyDataIn(ConsolidatedDataRowPacket),
+    CopyDone(NoPayloadMessage),
+    CopyFail(RegularPacket),
     CancelRequest(CancelRequestMessage),
     Terminate(NoPayloadMessage),
     UnknownMessageType(RegularPacket),
@@ -397,6 +402,9 @@ impl PgsqlFEMessage {
             PgsqlFEMessage::SASLInitialResponse(_) => "sasl_initial_response",
             PgsqlFEMessage::SASLResponse(_) => "sasl_response",
             PgsqlFEMessage::SimpleQuery(_) => "simple_query",
+            PgsqlFEMessage::ConsolidatedCopyDataIn(_) => "copy_data_in",
+            PgsqlFEMessage::CopyDone(_) => "copy_done",
+            PgsqlFEMessage::CopyFail(_) => "copy_fail",
             PgsqlFEMessage::CancelRequest(_) => "cancel_request",
             PgsqlFEMessage::Terminate(_) => "termination_message",
             PgsqlFEMessage::UnknownMessageType(_) => "unknown_message_type",
@@ -787,6 +795,9 @@ pub fn parse_request(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError
         b'\0' => pgsql_parse_startup_packet(i)?,
         b'Q' => parse_simple_query(i)?,
         b'X' => parse_terminate_message(i)?,
+        b'd' => parse_consolidated_copy_data_in(i)?,
+        b'c' => parse_copy_in_done(i)?,
+        b'f' => parse_copy_fail(i)?,
         _ => {
             let (i, identifier) = be_u8(i)?;
             let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
@@ -1049,6 +1060,22 @@ pub fn parse_copy_out_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, Pgsql
     ))
 }
 
+pub fn parse_copy_in_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
+    let (i, identifier) = verify(be_u8, |&x| x == b'G')(i)?;
+    let (i, length) = parse_gte_length(i, 8)?;
+    let (i, _format) = be_u8(i)?;
+    let (i, columns) = be_u16(i)?;
+    let (i, _formats) = many_m_n(0, columns.to_usize(), be_u16)(i)?;
+    Ok((
+        i,
+        PgsqlBEMessage::CopyInResponse(CopyResponse {
+            identifier,
+            length,
+            column_cnt: columns,
+        })
+    ))
+}
+
 pub fn parse_consolidated_copy_data_out(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'd')(i)?;
     let (i, length) = parse_gte_length(i, 5)?;
@@ -1062,7 +1089,31 @@ pub fn parse_consolidated_copy_data_out(i: &[u8]) -> IResult<&[u8], PgsqlBEMessa
     ))
 }
 
-fn parse_copy_done(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
+pub fn parse_consolidated_copy_data_in(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
+    let (i, identifier) = verify(be_u8, |&x| x == b'd')(i)?;
+    let (i, length) = parse_gte_length(i, 5)?;
+    let (i, _data) = take(length - PGSQL_LENGTH_FIELD)(i)?;
+    SCLogDebug!("data size is {:?}", _data);
+    Ok((
+        i, PgsqlFEMessage::ConsolidatedCopyDataIn(ConsolidatedDataRowPacket {
+            identifier,
+            row_cnt: 1,
+            data_size: (length - PGSQL_LENGTH_FIELD) as u64 })
+    ))
+}
+
+fn parse_copy_in_done(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
+    let (i, identifier) = verify(be_u8, |&x| x == b'c')(i)?;
+    let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?;
+    Ok((
+        i, PgsqlFEMessage::CopyDone(NoPayloadMessage {
+            identifier,
+            length
+        })
+    ))
+}
+
+fn parse_copy_out_done(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&[u8]>> {
     let (i, identifier) = verify(be_u8, |&x| x == b'c')(i)?;
     let (i, length) = parse_exact_length(i, PGSQL_LENGTH_FIELD)?;
     Ok((
@@ -1073,6 +1124,19 @@ fn parse_copy_done(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlParseError<&
     ))
 }
 
+fn parse_copy_fail(i: &[u8]) -> IResult<&[u8], PgsqlFEMessage, PgsqlParseError<&[u8]>> {
+    let (i, identifier) = verify(be_u8, |&x| x == b'f')(i)?;
+    let (i, length) = parse_gte_length(i, 5)?;
+    let (i, data) = take(length - PGSQL_LENGTH_FIELD)(i)?;
+    Ok((
+        i, PgsqlFEMessage::CopyFail(RegularPacket {
+            identifier,
+            length,
+            payload: data.to_vec(),
+        })
+    ))
+}
+
 // Currently, we don't store the actual DataRow messages, as those could easily become a burden, memory-wise
 // We use ConsolidatedDataRow to store info we still want to log: message size.
 // Later on, we calculate the number of lines the command actually returned by counting ConsolidatedDataRow messages
@@ -1267,13 +1331,14 @@ pub fn pgsql_parse_response(i: &[u8]) -> IResult<&[u8], PgsqlBEMessage, PgsqlPar
         b'R' => pgsql_parse_authentication_message(i)?,
         b'S' => parse_parameter_status_message(i)?,
         b'C' => parse_command_complete(i)?,
-        b'c' => parse_copy_done(i)?,
+        b'c' => parse_copy_out_done(i)?,
         b'Z' => parse_ready_for_query(i)?,
         b'T' => parse_row_description(i)?,
         b'A' => parse_notification_response(i)?,
         b'D' => parse_consolidated_data_row(i)?,
         b'd' => parse_consolidated_copy_data_out(i)?,
         b'H' => parse_copy_out_response(i)?,
+        b'G' => parse_copy_in_response(i)?,
         _ => {
             let (i, identifier) = be_u8(i)?;
             let (i, length) = parse_gte_length(i, PGSQL_LENGTH_FIELD)?;
index d569069350bae94303bc2d9451097e854188b2b3..52300dc80aaba74d314c44e3177b133cd1cddda5 100644 (file)
@@ -123,7 +123,11 @@ pub enum PgsqlStateProgress {
     // Related to Backend-received messages //
     CopyOutResponseReceived,
     CopyDataOutReceived,
+    CopyInResponseReceived,
+    FirstCopyDataInReceived,
+    ConsolidatingCopyDataIn,
     CopyDoneReceived,
+    CopyFailReceived,
     SSLRejectedReceived,
     // SSPIAuthenticationReceived, // TODO implement
     SASLAuthenticationReceived,
@@ -257,6 +261,7 @@ impl PgsqlState {
             || self.state_progress == PgsqlStateProgress::SSLRequestReceived
             || self.state_progress == PgsqlStateProgress::ConnectionTerminated
             || self.state_progress == PgsqlStateProgress::CancelRequestReceived
+            || self.state_progress == PgsqlStateProgress::FirstCopyDataInReceived
         {
             let tx = self.new_tx();
             self.transactions.push_back(tx);
@@ -266,13 +271,17 @@ impl PgsqlState {
         return self.transactions.back_mut();
     }
 
+    fn get_curr_state(&mut self) -> PgsqlStateProgress {
+        self.state_progress
+    }
+
     /// Define PgsqlState progression, based on the request received
     ///
     /// As PostgreSQL transactions can have multiple messages, State progression
     /// is what helps us keep track of the PgsqlTransactions - when one finished
     /// when the other starts.
     /// State isn't directly updated to avoid reference borrowing conflicts.
-    fn request_next_state(request: &PgsqlFEMessage) -> Option<PgsqlStateProgress> {
+    fn request_next_state(&mut self, request: &PgsqlFEMessage) -> Option<PgsqlStateProgress> {
         match request {
             PgsqlFEMessage::SSLRequest(_) => Some(PgsqlStateProgress::SSLRequestReceived),
             PgsqlFEMessage::StartupMessage(_) => Some(PgsqlStateProgress::StartupMessageReceived),
@@ -288,6 +297,25 @@ impl PgsqlState {
 
                 // Important to keep in mind that: "In simple Query mode, the format of retrieved values is always text, except when the given command is a FETCH from a cursor declared with the BINARY option. In that case, the retrieved values are in binary format. The format codes given in the RowDescription message tell which format is being used." (from pgsql official documentation)
             }
+            PgsqlFEMessage::ConsolidatedCopyDataIn(_) => {
+                match self.get_curr_state() {
+                    PgsqlStateProgress::CopyInResponseReceived => {
+                        return Some(PgsqlStateProgress::FirstCopyDataInReceived);
+                    }
+                    PgsqlStateProgress::FirstCopyDataInReceived
+                    | PgsqlStateProgress::ConsolidatingCopyDataIn => {
+                        // We are in CopyInResponseReceived state, and we received a CopyDataIn message
+                        // We can either be in the first CopyDataIn message or in the middle
+                        // of consolidating CopyDataIn messages
+                        return Some(PgsqlStateProgress::ConsolidatingCopyDataIn);
+                    }
+                    _ => {
+                        return None;
+                    }
+                }
+            }
+            PgsqlFEMessage::CopyDone(_) => Some(PgsqlStateProgress::CopyDoneReceived),
+            PgsqlFEMessage::CopyFail(_) => Some(PgsqlStateProgress::CopyFailReceived),
             PgsqlFEMessage::CancelRequest(_) => Some(PgsqlStateProgress::CancelRequestReceived),
             PgsqlFEMessage::Terminate(_) => {
                 SCLogDebug!("Match: Terminate message");
@@ -330,6 +358,8 @@ impl PgsqlState {
             | PgsqlStateProgress::SASLInitialResponseReceived
             | PgsqlStateProgress::SASLResponseReceived
             | PgsqlStateProgress::CancelRequestReceived
+            | PgsqlStateProgress::CopyDoneReceived
+            | PgsqlStateProgress::CopyFailReceived
             | PgsqlStateProgress::ConnectionTerminated => true,
             _ => false,
         }
@@ -364,7 +394,7 @@ impl PgsqlState {
             match PgsqlState::state_based_req_parsing(self.state_progress, start) {
                 Ok((rem, request)) => {
                     start = rem;
-                    let new_state = PgsqlState::request_next_state(&request);
+                    let new_state = self.request_next_state(&request);
 
                     if let Some(state) = new_state {
                         self.state_progress = state;
@@ -380,10 +410,31 @@ impl PgsqlState {
                     // https://samadhiweb.com/blog/2013.04.28.graphviz.postgresv3.html
                     if let Some(tx) = self.find_or_create_tx() {
                         tx.tx_data.updated_ts = true;
-                        tx.requests.push(request);
                         if let Some(state) = new_state {
-                            if Self::request_is_complete(state) {
-                                // The request is always complete at this point
+                            if state == PgsqlStateProgress::FirstCopyDataInReceived
+                            || state == PgsqlStateProgress::ConsolidatingCopyDataIn {
+                                // here we're actually only counting how many messages were received.
+                                // frontends are not forced to send one row per message
+                                if let PgsqlFEMessage::ConsolidatedCopyDataIn(msg) = request {
+                                    tx.sum_data_size(msg.data_size);
+                                    tx.incr_row_cnt();
+                                }
+                            } else if (state == PgsqlStateProgress::CopyDoneReceived || state == PgsqlStateProgress::CopyFailReceived) && tx.get_row_cnt() > 0 {
+                                let consolidated_copy_data = PgsqlFEMessage::ConsolidatedCopyDataIn(
+                                    ConsolidatedDataRowPacket {
+                                        identifier: b'd',
+                                        row_cnt: tx.get_row_cnt(),
+                                        data_size: tx.data_size, // total byte count of all copy_data messages combined
+                                    },
+                                );
+                                tx.requests.push(consolidated_copy_data);
+                                tx.requests.push(request);
+                                // reset values
+                                tx.data_row_cnt = 0;
+                                tx.data_size = 0;
+                            } else if Self::request_is_complete(state) {
+                                tx.requests.push(request);
+                                // The request is complete at this point
                                 tx.tx_req_state = PgsqlTxProgress::TxDone;
                                 if state == PgsqlStateProgress::ConnectionTerminated
                                     || state == PgsqlStateProgress::CancelRequestReceived
@@ -491,6 +542,7 @@ impl PgsqlState {
             }
             PgsqlBEMessage::RowDescription(_) => Some(PgsqlStateProgress::RowDescriptionReceived),
             PgsqlBEMessage::CopyOutResponse(_) => Some(PgsqlStateProgress::CopyOutResponseReceived),
+            PgsqlBEMessage::CopyInResponse(_) => Some(PgsqlStateProgress::CopyInResponseReceived),
             PgsqlBEMessage::ConsolidatedDataRow(msg) => {
                 // Increment tx.data_size here, since we know msg type, so that we can later on log that info
                 self.transactions.back_mut()?.sum_data_size(msg.data_size);
@@ -541,6 +593,7 @@ impl PgsqlState {
             | PgsqlStateProgress::SASLAuthenticationReceived
             | PgsqlStateProgress::SASLAuthenticationContinueReceived
             | PgsqlStateProgress::SASLAuthenticationFinalReceived
+            | PgsqlStateProgress::CopyInResponseReceived
             | PgsqlStateProgress::Finished => true,
             _ => false,
         }