]> git.ipfire.org Git - thirdparty/openssl.git/commitdiff
Adds DTLS 1.3 ACK message functionality
authorFrederik Wedel-Heinen <frederik.wedel-heinen@dencrypt.dk>
Thu, 22 Feb 2024 07:09:38 +0000 (08:09 +0100)
committerTomas Mraz <tomas@openssl.org>
Thu, 2 Oct 2025 12:48:16 +0000 (14:48 +0200)
Reviewed-by: Matt Caswell <matt@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/25119)

40 files changed:
apps/lib/s_cb.c
doc/designs/dtlsv1_3/dtlsv1_3-main.md
engines/e_ossltest.c
include/internal/packet.h
include/internal/recordmethod.h
include/internal/statem.h
include/openssl/ssl.h.in
include/openssl/ssl3.h
ssl/d1_lib.c
ssl/pqueue.c
ssl/quic/quic_tls.c
ssl/record/methods/dtls_meth.c
ssl/record/methods/ktls_meth.c
ssl/record/methods/recmethod_local.h
ssl/record/methods/ssl3_meth.c
ssl/record/methods/tls13_meth.c
ssl/record/methods/tls1_meth.c
ssl/record/methods/tls_common.c
ssl/record/methods/tls_multib.c
ssl/record/rec_layer_d1.c
ssl/record/record.h
ssl/ssl_local.h
ssl/ssl_stat.c
ssl/statem/statem.c
ssl/statem/statem_clnt.c
ssl/statem/statem_dtls.c
ssl/statem/statem_local.h
ssl/statem/statem_srvr.c
ssl/tls13_enc.c
test/recipes/70-test_dtls13ack.t [new file with mode: 0644]
test/recipes/70-test_sslcbcpadding.t
test/recipes/70-test_sslrecords.t
test/recipes/70-test_tls13hrr.t
test/sslapitest.c
test/tls13encryptiontest.c
test/tls13secretstest.c
util/perl/TLSProxy/Message.pm
util/perl/TLSProxy/Proxy.pm
util/perl/TLSProxy/Record.pm
util/perl/TLSProxy/RecordNumber.pm [new file with mode: 0644]

index 9635c0d11fcaf955b7a8761a3852f1c75f14b301..0071e9e6aae2cabda1f934196b10b9b95873dca9 100644 (file)
@@ -682,7 +682,8 @@ void msg_cb(int write_p, int version, int content_type, const void *buf,
         switch (content_type) {
         case SSL3_RT_CHANGE_CIPHER_SPEC:
             /* type 20 */
-            str_content_type = ", ChangeCipherSpec";
+            if (version != DTLS1_3_VERSION)
+                str_content_type = ", ChangeCipherSpec";
             break;
         case SSL3_RT_ALERT:
             /* type 21 */
@@ -711,6 +712,11 @@ void msg_cb(int write_p, int version, int content_type, const void *buf,
             /* type 23 */
             str_content_type = ", ApplicationData";
             break;
+        case SSL3_RT_ACK:
+            /* type 26 */
+            if (version == DTLS1_3_VERSION)
+                str_content_type = ", ACK";
+            break;
         case SSL3_RT_HEADER:
             /* type 256 */
             str_content_type = ", RecordHeader";
@@ -720,6 +726,10 @@ void msg_cb(int write_p, int version, int content_type, const void *buf,
             str_content_type = ", InnerContent";
             break;
         default:
+            break;
+        }
+
+        if (str_content_type[0] == '\0') {
             BIO_snprintf(tmpbuf, sizeof(tmpbuf)-1, ", Unknown (content_type=%d)", content_type);
             str_content_type = tmpbuf;
         }
index a97b6e273ab175a2df287a2db9f7afa3bd892e63..665f62bdca872c8ac4bfca06c16b83eed6bf0931 100644 (file)
@@ -75,6 +75,56 @@ for DTLSv1.3 connections.
 The DTLSv1.3 implementation does not include the message sequence number,
 fragment offset and fragment length as is the case with previous versions of DTLS.
 
+#### DTLS ACK records (RFC9147 Section 7)
+
+ACKs are sent for KeyUpdates, NewSessionTicket, Certificate (client),
+CompressedCertificate (Client), CertificateVerify (client) and Finish (client).
+
+Notes on RFC9147 Section 7.1:
+
+* The implementation does not offer any logic to determine that there is disruption
+  when receiving messages which means it will not send ACKs for the example given
+  in RFC9147 Figure 12.
+* ACKs are always sent immediately after receiving a full message to be ACKed.
+* If the implementation does not receive an ACK for all fragments of a flight,
+  then the full flight will be retransmitted.
+* Empty ACKs are never sent.
+* The implementation does not explicitly prohibit receiving unencrypted ACKs. The
+  implementation will only ACK records of epoch > 0 so all ACKs sent by the
+  implementation will be encrypted.
+* The implementation only accepts ACKs after DTLSv1.3 has been negotiated. ACKs
+  that are received when DTLSv1.3 has not been negotiated is handled with a fatal
+  alert as any other unexpected message. ACKs that are received before version
+  negotiation are dropped.
+* The implementation ignores ACKs received for messages other than KeyUpdates,
+  NewSessionTicket, Certificate (client), CertificateVerify (client) and Finish (client).
+
+Missing functionality:
+
+There's need for a lot more corner case testing:
+
+* Correct handling of ACKs during KeyUpdate and SessionTicket updates.
+* Currently only retransmission after a missing ACK of client Finish message is
+  tested.
+* TLSProxy does not support testing post handshake message ACK testing. Such
+  testing probably needs to be performed by another framework.
+* This comment also forms a great test case:
+  <https://github.com/openssl/openssl/pull/25119#discussion_r1871643459>
+
+### Known issues
+
+#### Dropped records handling
+
+The implementation is only partially able to handle dropped records. For example
+`test_dtls13ack` has a disabled test case that fails when compressed certificates
+are sent. It seems like the implementation is not able to properly handle the
+case were the last flight of the client is dropped if it contains a client cert.
+In that case it should retransmit the CompressedCertificate and CertificateVerify
+messages in epoch 2, but it chooses to do it in epoch 3.
+
+There's a need to setup a test that checks dropped records in several scenarios
+and configurations in order to properly fix and verify.
+
 Implementation progress
 -----------------------
 
@@ -88,12 +138,10 @@ A summary of larger work items that needs to be addressed.
 Notice that some of the requirements mentioned in [List of DTLSv1.3 requirements](#list-of-dtls-13-requirements)
 is not covered by these workitems and must be implemented separately.
 
-| Summary                                             | #PR    |
-|-----------------------------------------------------|--------|
-| ACK messages                                        | #25119 |
-| Use HelloRetryRequest instead of HelloVerifyRequest | #22985 |
-| EndOfEarlyData message                              | -      |
-| DTLSv1.3 Fuzzer                                     | -      |
+| Summary                | #PR    |
+|------------------------|--------|
+| EndOfEarlyData message | -      |
+| DTLSv1.3 Fuzzer        | -      |
 
 ### Changes from DTLS 1.2 and/or TLS 1.3
 
@@ -139,10 +187,6 @@ random value:
 > the EndOfEarlyData message is omitted both from the wire and the handshake
 > transcript
 
-#### ACK messages
-
-See section 7 and 8 of RFC 9147.
-
 ### List of DTLSv1.3 requirements
 
 Here's a list of requirements from RFC 9147 together with their implementation status
index e31d82d0a5b1288ef1d8e9a0e2d1d4661aa5f6b7..7f861d81eeb77e7f0f57c6747f8b0ecc3d5f1141 100644 (file)
@@ -245,17 +245,22 @@ static int ossltest_ciphers(ENGINE *, const EVP_CIPHER **,
                             const int **, int);
 
 static int ossltest_cipher_nids[] = {
-    NID_aes_128_cbc, NID_aes_128_gcm,
+    NID_aes_128_cbc, NID_aes_128_ecb, NID_aes_128_gcm,
     NID_aes_128_cbc_hmac_sha1, 0
 };
 
 /* AES128 */
 
-static int ossltest_aes128_init_key(EVP_CIPHER_CTX *ctx,
-                                    const unsigned char *key,
-                                    const unsigned char *iv, int enc);
+static int ossltest_aes128_cbc_init_key(EVP_CIPHER_CTX *ctx,
+                                        const unsigned char *key,
+                                        const unsigned char *iv, int enc);
 static int ossltest_aes128_cbc_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
                                       const unsigned char *in, size_t inl);
+static int ossltest_aes128_ecb_init_key(EVP_CIPHER_CTX *ctx,
+                                        const unsigned char *key,
+                                        const unsigned char *iv, int enc);
+static int ossltest_aes128_ecb_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
+                                      const unsigned char *in, size_t inl);
 static int ossltest_aes128_gcm_init_key(EVP_CIPHER_CTX *ctx,
                                         const unsigned char *key,
                                         const unsigned char *iv, int enc);
@@ -291,7 +296,7 @@ static const EVP_CIPHER *ossltest_aes_128_cbc(void)
                                           EVP_CIPH_FLAG_DEFAULT_ASN1
                                           | EVP_CIPH_CBC_MODE)
             || !EVP_CIPHER_meth_set_init(_hidden_aes_128_cbc,
-                                         ossltest_aes128_init_key)
+                                         ossltest_aes128_cbc_init_key)
             || !EVP_CIPHER_meth_set_do_cipher(_hidden_aes_128_cbc,
                                               ossltest_aes128_cbc_cipher)
             || !EVP_CIPHER_meth_set_impl_ctx_size(_hidden_aes_128_cbc,
@@ -302,6 +307,26 @@ static const EVP_CIPHER *ossltest_aes_128_cbc(void)
     return _hidden_aes_128_cbc;
 }
 
+static EVP_CIPHER *_hidden_aes_128_ecb = NULL;
+static const EVP_CIPHER *ossltest_aes_128_ecb(void)
+{
+    if (_hidden_aes_128_ecb == NULL
+        && ((_hidden_aes_128_ecb = EVP_CIPHER_meth_new(NID_aes_128_ecb,
+                                                       16 /* block size */,
+                                                       16 /* key len */)) == NULL
+            || !EVP_CIPHER_meth_set_iv_length(_hidden_aes_128_ecb, 0)
+            || !EVP_CIPHER_meth_set_flags(_hidden_aes_128_ecb,
+                                          EVP_CIPH_FLAG_DEFAULT_ASN1 | EVP_CIPH_ECB_MODE)
+            || !EVP_CIPHER_meth_set_init(_hidden_aes_128_ecb, ossltest_aes128_ecb_init_key)
+            || !EVP_CIPHER_meth_set_do_cipher(_hidden_aes_128_ecb, ossltest_aes128_ecb_cipher)
+            || !EVP_CIPHER_meth_set_impl_ctx_size(_hidden_aes_128_ecb,
+                                                  EVP_CIPHER_impl_ctx_size(EVP_aes_128_ecb())))) {
+        EVP_CIPHER_meth_free(_hidden_aes_128_ecb);
+        _hidden_aes_128_ecb = NULL;
+    }
+    return _hidden_aes_128_ecb;
+}
+
 static EVP_CIPHER *_hidden_aes_128_gcm = NULL;
 
 #define AES_GCM_FLAGS   (EVP_CIPH_FLAG_DEFAULT_ASN1 \
@@ -366,9 +391,11 @@ static const EVP_CIPHER *ossltest_aes_128_cbc_hmac_sha1(void)
 static void destroy_ciphers(void)
 {
     EVP_CIPHER_meth_free(_hidden_aes_128_cbc);
+    EVP_CIPHER_meth_free(_hidden_aes_128_ecb);
     EVP_CIPHER_meth_free(_hidden_aes_128_gcm);
     EVP_CIPHER_meth_free(_hidden_aes_128_cbc_hmac_sha1);
     _hidden_aes_128_cbc = NULL;
+    _hidden_aes_128_ecb = NULL;
     _hidden_aes_128_gcm = NULL;
     _hidden_aes_128_cbc_hmac_sha1 = NULL;
 }
@@ -537,6 +564,9 @@ static int ossltest_ciphers(ENGINE *e, const EVP_CIPHER **cipher,
     case NID_aes_128_cbc:
         *cipher = ossltest_aes_128_cbc();
         break;
+    case NID_aes_128_ecb:
+        *cipher = ossltest_aes_128_ecb();
+        break;
     case NID_aes_128_gcm:
         *cipher = ossltest_aes_128_gcm();
         break;
@@ -686,16 +716,8 @@ static int digest_sha512_final(EVP_MD_CTX *ctx, unsigned char *md)
 /*
  * AES128 Implementation
  */
-
-static int ossltest_aes128_init_key(EVP_CIPHER_CTX *ctx,
-                                    const unsigned char *key,
-                                    const unsigned char *iv, int enc)
-{
-    return EVP_CIPHER_meth_get_init(EVP_aes_128_cbc()) (ctx, key, iv, enc);
-}
-
-static int ossltest_aes128_cbc_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
-                                      const unsigned char *in, size_t inl)
+static int ossltest_cipher(EVP_CIPHER_CTX *ctx, const EVP_CIPHER *cipher,
+                           unsigned char *out, const unsigned char *in, size_t inl)
 {
     unsigned char *tmpbuf;
     int ret;
@@ -711,16 +733,43 @@ static int ossltest_aes128_cbc_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
         memcpy(tmpbuf, in, inl);
 
     /* Go through the motions of encrypting it */
-    ret = EVP_CIPHER_meth_get_do_cipher(EVP_aes_128_cbc())(ctx, out, in, inl);
+    ret = EVP_CIPHER_meth_get_do_cipher(cipher)(ctx, out, in, inl);
 
     /* Throw it all away and just use the plaintext as the output */
     if (tmpbuf != NULL)
         memcpy(out, tmpbuf, inl);
+
     OPENSSL_free(tmpbuf);
 
     return ret;
 }
 
+static int ossltest_aes128_cbc_init_key(EVP_CIPHER_CTX *ctx,
+                                        const unsigned char *key,
+                                        const unsigned char *iv, int enc)
+{
+    return EVP_CIPHER_meth_get_init(EVP_aes_128_cbc()) (ctx, key, iv, enc);
+}
+
+static int ossltest_aes128_cbc_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
+                                      const unsigned char *in, size_t inl)
+{
+    return ossltest_cipher(ctx, EVP_aes_128_cbc(), out, in, inl);
+}
+
+static int ossltest_aes128_ecb_init_key(EVP_CIPHER_CTX *ctx,
+                                        const unsigned char *key,
+                                        const unsigned char *iv, int enc)
+{
+    return EVP_CIPHER_meth_get_init(EVP_aes_128_ecb()) (ctx, key, iv, enc);
+}
+
+static int ossltest_aes128_ecb_cipher(EVP_CIPHER_CTX *ctx, unsigned char *out,
+                                      const unsigned char *in, size_t inl)
+{
+    return ossltest_cipher(ctx, EVP_aes_128_ecb(), out, in, inl);
+}
+
 static int ossltest_aes128_gcm_init_key(EVP_CIPHER_CTX *ctx,
                                         const unsigned char *key,
                                         const unsigned char *iv, int enc)
index 2850fc5c5a1db39b642f634483104010c6e60a9c..9ea21ac9e3b18cae739fd5f07a0ef4bca9056a46 100644 (file)
@@ -228,6 +228,37 @@ __owur static ossl_inline int PACKET_peek_net_4(const PACKET *pkt,
     return 1;
 }
 
+/*
+ * Peek ahead at 6 bytes in network order from |pkt| and store the value in
+ * |*data|
+ */
+__owur static ossl_inline int PACKET_peek_net_6(const PACKET *pkt,
+                                                uint64_t *data)
+{
+    if (PACKET_remaining(pkt) < 6)
+        return 0;
+
+    *data = ((uint64_t)(*(pkt->curr))) << 40;
+    *data |= ((uint64_t)(*(pkt->curr + 1))) << 32;
+    *data |= ((uint64_t)(*(pkt->curr + 2))) << 24;
+    *data |= ((uint64_t)(*(pkt->curr + 3))) << 16;
+    *data |= ((uint64_t)(*(pkt->curr + 4))) << 8;
+    *data |= *(pkt->curr + 5);
+
+    return 1;
+}
+
+/* Get 6 bytes in network order from |pkt| and store the value in |*data| */
+__owur static ossl_inline int PACKET_get_net_6(PACKET *pkt, uint64_t *data)
+{
+    if (!PACKET_peek_net_6(pkt, data))
+        return 0;
+
+    packet_forward(pkt, 6);
+
+    return 1;
+}
+
 /*
  * Peek ahead at 8 bytes in network order from |pkt| and store the value in
  * |*data|
@@ -913,6 +944,8 @@ int WPACKET_put_bytes__(WPACKET *pkt, uint64_t val, size_t bytes);
     WPACKET_put_bytes__((pkt), (val), 3)
 #define WPACKET_put_bytes_u32(pkt, val) \
     WPACKET_put_bytes__((pkt), (val), 4)
+#define WPACKET_put_bytes_u48(pkt, val) \
+    WPACKET_put_bytes__((pkt), (val), 6)
 #define WPACKET_put_bytes_u64(pkt, val) \
     WPACKET_put_bytes__((pkt), (val), 8)
 
index 7078c331776d632ead51dc5c85d4c0571068c576..3a335355d1b5c7463972bf81d05781f6156f0f3c 100644 (file)
@@ -57,6 +57,8 @@ typedef struct ossl_record_layer_st OSSL_RECORD_LAYER;
 struct ossl_record_template_st {
     unsigned char type;
     unsigned int version;
+    uint64_t sequence_number;
+    uint64_t epoch;
     const unsigned char *buf;
     size_t buflen;
 };
@@ -227,8 +229,8 @@ struct ossl_record_method_st {
      * multiple records in one go and buffer them.
      */
     int (*read_record)(OSSL_RECORD_LAYER *rl, void **rechandle, int *rversion,
-                      uint8_t *type, const unsigned char **data, size_t *datalen,
-                      uint16_t *epoch, unsigned char *seq_num);
+                       uint8_t *type, const unsigned char **data, size_t *datalen,
+                       uint16_t *epoch, uint64_t *seq_num);
     /*
      * Release length bytes from a buffer associated with a record previously
      * read with read_record. Once all the bytes from a record are released, the
index 261d7967cc9a6a0de15e934e955c55750b4731d4..a1633de9cd7fff50098cd85d3120fea1cd7658b8 100644 (file)
@@ -106,6 +106,8 @@ struct ossl_statem_st {
     OSSL_HANDSHAKE_STATE hand_state;
     /* The handshake state requested by an API call (e.g. HelloRequest) */
     OSSL_HANDSHAKE_STATE request_state;
+    /* The handshake state waiting for acknowledge */
+    OSSL_HANDSHAKE_STATE deferred_ack_state;
     int in_init;
     int read_state_first_init;
     /* true when we are actually in SSL_accept() or SSL_connect() */
index c3144f66458f4120e4789cec45356177e6ff6183..1162ea32bb36779df66b41990d19dde0379da085 100644 (file)
@@ -1066,7 +1066,11 @@ typedef enum {
     TLS_ST_EARLY_DATA,
     TLS_ST_PENDING_EARLY_DATA_END,
     TLS_ST_CW_END_OF_EARLY_DATA,
-    TLS_ST_SR_END_OF_EARLY_DATA
+    TLS_ST_SR_END_OF_EARLY_DATA,
+    TLS_ST_CR_ACK,
+    TLS_ST_CW_ACK,
+    TLS_ST_SR_ACK,
+    TLS_ST_SW_ACK
 } OSSL_HANDSHAKE_STATE;
 
 /*
index b8dada1b6b157a2f5802a112e60b82787133ae06..2e93dcf212ade1cbc45c61986af68f05eeaa44eb 100644 (file)
@@ -220,6 +220,7 @@ extern "C" {
 # define SSL3_RT_ALERT                   21
 # define SSL3_RT_HANDSHAKE               22
 # define SSL3_RT_APPLICATION_DATA        23
+# define SSL3_RT_ACK                     26 /* RFC 9147 */
 
 /* Pseudo content types to indicate additional parameters */
 # define TLS1_RT_CRYPTO                  0x1000
@@ -334,6 +335,9 @@ extern "C" {
 # define SSL3_MT_MESSAGE_HASH                    254
 # define DTLS1_MT_HELLO_VERIFY_REQUEST           3
 
+/* Dummy message type for handling ACK like a normal handshake message */
+# define DTLS13_MT_ACK                           0x0126
+
 /* Dummy message type for handling CCS like a normal handshake message */
 # define SSL3_MT_CHANGE_CIPHER_SPEC              0x0101
 
index 162306cc2ae09b3228610dc21b3a1172321a28bb..718ccd777f8c3ebdee775020299d841e7554c600 100644 (file)
@@ -114,7 +114,8 @@ int dtls1_new(SSL *ssl)
 static void dtls1_clear_queues(SSL_CONNECTION *s)
 {
     dtls1_clear_received_buffer(s);
-    dtls1_clear_sent_buffer(s);
+    dtls1_clear_sent_buffer(s, 0);
+    ossl_list_record_number_elem_free(&s->d1->ack_rec_num);
 }
 
 void dtls1_clear_received_buffer(SSL_CONNECTION *s)
@@ -130,15 +131,75 @@ void dtls1_clear_received_buffer(SSL_CONNECTION *s)
     }
 }
 
-void dtls1_clear_sent_buffer(SSL_CONNECTION *s)
+void ossl_list_record_number_elem_free(OSSL_LIST(record_number) *p_list)
+{
+    DTLS1_RECORD_NUMBER *p_elem;
+    DTLS1_RECORD_NUMBER *p_elem_next = NULL;
+
+    if (p_list != NULL)
+        p_elem_next = ossl_list_record_number_head(p_list);
+
+    while ((p_elem = p_elem_next) != NULL) {
+        p_elem_next = ossl_list_record_number_next(p_elem_next);
+        ossl_list_record_number_remove(p_list, p_elem);
+        OPENSSL_free(p_elem);
+    }
+}
+
+DTLS1_RECORD_NUMBER *dtls1_record_number_new(uint64_t epoch, uint64_t seqnum)
+{
+    DTLS1_RECORD_NUMBER *recnum = OPENSSL_zalloc(sizeof(*recnum));
+
+    if (recnum != NULL) {
+        recnum->epoch = epoch;
+        recnum->seqnum = seqnum;
+    }
+
+    return recnum;
+}
+
+void dtls1_acknowledge_sent_buffer(SSL_CONNECTION *s, uint16_t before_epoch)
+{
+    pitem *item = NULL;
+    piterator iter = pqueue_iterator(&s->d1->sent_messages);
+
+    while ((item = pqueue_next(&iter)) != NULL) {
+        dtls_sent_msg *sent_msg = (dtls_sent_msg *)item->data;
+        DTLS1_RECORD_NUMBER *recnum;
+        DTLS1_RECORD_NUMBER *recnum_next = ossl_list_record_number_head(&sent_msg->rec_nums);
+
+        while ((recnum = recnum_next) != NULL) {
+            recnum_next = ossl_list_record_number_next(recnum_next);
+
+            if (recnum->epoch < before_epoch) {
+                ossl_list_record_number_remove(&sent_msg->rec_nums, recnum);
+                OPENSSL_free(recnum);
+            }
+        }
+    }
+}
+
+void dtls1_clear_sent_buffer(SSL_CONNECTION *s, int keep_unacked_msgs)
 {
     pitem *item = NULL;
+    pqueue *remaining_sent_messages = pqueue_new();
     pqueue *sent_messages = &s->d1->sent_messages;
 
     while ((item = pqueue_pop(sent_messages)) != NULL) {
-        dtls_sent_msg *sent_msg = (dtls_sent_msg *)item->data;
+        dtls_sent_msg *sent_msg = (dtls_sent_msg *) item->data;
+        unsigned char msg_type = sent_msg->msg_info.msg_type;
+        unsigned char record_type = sent_msg->msg_info.record_type;
+
+        if (SSL_CONNECTION_IS_DTLS13(s)
+            && !ossl_list_record_number_is_empty(&sent_msg->rec_nums)
+            && keep_unacked_msgs) {
+            pqueue_insert(remaining_sent_messages, item);
+            continue;
+        }
 
-        if (sent_msg->record_type == SSL3_RT_CHANGE_CIPHER_SPEC
+        if (((!SSL_CONNECTION_IS_DTLS13(s) && record_type == SSL3_RT_CHANGE_CIPHER_SPEC)
+             || (SSL_CONNECTION_IS_DTLS13(s)
+                 && (msg_type == SSL3_MT_FINISHED || msg_type == SSL3_MT_KEY_UPDATE)))
             && sent_msg->saved_retransmit_state.wrlmethod != NULL
             && s->rlayer.wrl != sent_msg->saved_retransmit_state.wrl) {
             /*
@@ -151,8 +212,28 @@ void dtls1_clear_sent_buffer(SSL_CONNECTION *s)
         dtls1_sent_msg_free(sent_msg);
         pitem_free(item);
     }
+
+    if (SSL_CONNECTION_IS_DTLS13(s))
+        while ((item = pqueue_pop(remaining_sent_messages)) != NULL)
+            pqueue_insert(&s->d1->sent_messages, item);
+
+    pqueue_free(remaining_sent_messages);
 }
 
+int dtls_any_sent_messages_are_missing_acknowledge(SSL_CONNECTION *s)
+{
+    pitem *item;
+    piterator iter = pqueue_iterator(&s->d1->sent_messages);
+
+    while ((item = pqueue_next(&iter)) != NULL) {
+        dtls_sent_msg *msg = (dtls_sent_msg *)item->data;
+
+        if (!ossl_list_record_number_is_empty(&msg->rec_nums))
+            return 1;
+    }
+
+    return 0;
+}
 
 void dtls1_free(SSL *ssl)
 {
@@ -354,7 +435,7 @@ void dtls1_stop_timer(SSL_CONNECTION *s)
     s->d1->timeout_duration_us = 1000000;
     dtls1_bio_set_next_timeout(s->rbio, s->d1);
     /* Clear retransmission buffer */
-    dtls1_clear_sent_buffer(s);
+    dtls1_clear_sent_buffer(s, 0);
 }
 
 int dtls1_check_timeout_num(SSL_CONNECTION *s)
index a2a57cbf456c6118c2959cf93aaee64846f8aa5e..f02dda731ba275ecbc8574ac84ae0b54afa1393a 100644 (file)
@@ -20,6 +20,23 @@ pitem *pitem_new(unsigned char *prio64be, void *data)
     memcpy(item->priority, prio64be, sizeof(item->priority));
     item->data = data;
     item->next = NULL;
+
+    return item;
+}
+
+pitem *pitem_new_u64(uint64_t prio, void *data)
+{
+    pitem *item = OPENSSL_malloc(sizeof(*item));
+    unsigned char *p_item_prio;
+
+    if (item == NULL)
+        return NULL;
+
+    p_item_prio = item->priority;
+    l2n8(prio, p_item_prio);
+    item->data = data;
+    item->next = NULL;
+
     return item;
 }
 
@@ -116,6 +133,15 @@ pitem *pqueue_find(pqueue *pq, unsigned char *prio64be)
     return found;
 }
 
+pitem *pqueue_find_u64(pqueue *pq, uint64_t prio)
+{
+    unsigned char prio64be[8], *p_prio64be = prio64be;
+
+    l2n8(prio, p_prio64be);
+
+    return pqueue_find(pq, prio64be);
+}
+
 pitem *pqueue_iterator(pqueue *pq)
 {
     return pqueue_peek(pq);
index 3c21a3d870b5e08915974b8d7f92201fd6c45e0f..0e3a84fa748e980cb308581ca1b0a45bb2f7dae4 100644 (file)
@@ -364,7 +364,7 @@ static int quic_retry_write_records(OSSL_RECORD_LAYER *rl)
 static int quic_read_record(OSSL_RECORD_LAYER *rl, void **rechandle,
                             int *rversion, uint8_t *type, const unsigned char **data,
                             size_t *datalen, uint16_t *epoch,
-                            unsigned char *seq_num)
+                            uint64_t *seq_num)
 {
     if (rl->recread != 0 || rl->recunreleased != 0)
         return OSSL_RECORD_RETURN_FATAL;
index 8b0626255eab92ed900376501393c2b17df591bd..752a98434a06a039fb0a1437958fc7fcfc1599a5 100644 (file)
 #include "../record_local.h"
 #include "recmethod_local.h"
 
-/* mod 128 saturating subtract of two 64-bit values in big-endian order */
-static int satsub64be(const unsigned char *v1, const unsigned char *v2)
+/* mod 128 saturating subtract of two 64-bit values */
+static int satsub64(uint64_t l1, uint64_t l2)
 {
-    int64_t ret;
-    uint64_t l1, l2;
+    uint64_t max, min;
+    int sign;
 
-    n2l8(v1, l1);
-    n2l8(v2, l2);
-
-    ret = l1 - l2;
+    if (l1 > l2) {
+        max = l1;
+        min = l2;
+        sign = 1;
+    } else {
+        max = l2;
+        min = l1;
+        sign = -1;
+    }
 
-    /* We do not permit wrap-around */
-    if (l1 > l2 && ret < 0)
-        return 128;
-    else if (l2 > l1 && ret > 0)
-        return -128;
+    if (max - min > 128)
+        return sign * 128;
 
-    if (ret > 128)
-        return 128;
-    else if (ret < -128)
-        return -128;
-    else
-        return (int)ret;
+    return sign * ((int)(max - min));
 }
 
 static int dtls_record_replay_check(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
 {
     int cmp;
     unsigned int shift;
-    const unsigned char *seq = rl->sequence;
 
-    cmp = satsub64be(seq, bitmap->max_seq_num);
+    cmp = satsub64(rl->sequence, bitmap->max_seq_num);
     if (cmp > 0) {
-        ossl_tls_rl_record_set_seq_num(&rl->rrec[0], seq);
+        rl->rrec[0].seq_num = rl->sequence;
         return 1;               /* this record in new */
     }
     shift = -cmp;
@@ -54,25 +50,23 @@ static int dtls_record_replay_check(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
     else if (bitmap->map & ((uint64_t)1 << shift))
         return 0;               /* record previously received */
 
-    ossl_tls_rl_record_set_seq_num(&rl->rrec[0], seq);
+    rl->rrec[0].seq_num = rl->sequence;
     return 1;
 }
 
-static void dtls_record_bitmap_update(OSSL_RECORD_LAYER *rl,
-                                      DTLS_BITMAP *bitmap)
+static void dtls_record_bitmap_update(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
 {
     int cmp;
     unsigned int shift;
-    const unsigned char *seq = rl->sequence;
 
-    cmp = satsub64be(seq, bitmap->max_seq_num);
+    cmp = satsub64(rl->sequence, bitmap->max_seq_num);
     if (cmp > 0) {
         shift = cmp;
         if (shift < sizeof(bitmap->map) * 8)
             bitmap->map <<= shift, bitmap->map |= 1UL;
         else
             bitmap->map = 1UL;
-        memcpy(bitmap->max_seq_num, seq, SEQ_NUM_SIZE);
+        bitmap->max_seq_num = rl->sequence;
     } else {
         shift = -cmp;
         if (shift < sizeof(bitmap->map) * 8)
@@ -298,7 +292,7 @@ static int dtls_process_record(OSSL_RECORD_LAYER *rl, DTLS_BITMAP *bitmap)
 }
 
 static int dtls_rlayer_buffer_record(OSSL_RECORD_LAYER *rl, struct pqueue_st *queue,
-                                     unsigned char *priority)
+                                     uint64_t priority)
 {
     DTLS_RLAYER_RECORD_DATA *rdata;
     pitem *item;
@@ -308,7 +302,7 @@ static int dtls_rlayer_buffer_record(OSSL_RECORD_LAYER *rl, struct pqueue_st *qu
         return 0;
 
     rdata = OPENSSL_malloc(sizeof(*rdata));
-    item = pitem_new(priority, rdata);
+    item = pitem_new_u64(priority, rdata);
     if (rdata == NULL || item == NULL) {
         OPENSSL_free(rdata);
         pitem_free(item);
@@ -361,8 +355,7 @@ static int dtls_copy_rlayer_record(OSSL_RECORD_LAYER *rl, pitem *item)
     memcpy(&rl->rrec[0], &(rdata->rrec), sizeof(TLS_RL_RECORD));
 
     /* Set proper sequence number for mac calculation */
-    assert(sizeof(rl->sequence) == sizeof(rdata->rrec.seq_num));
-    memcpy(rl->sequence, rdata->rrec.seq_num, sizeof(rl->sequence));
+    rl->sequence = rdata->rrec.seq_num;
 
     return 1;
 }
@@ -405,6 +398,7 @@ int dtls_crypt_sequence_number(EVP_CIPHER_CTX *ctx, unsigned char *seq, size_t s
 
     if (!ossl_assert(inlen >= 0)
             || (size_t)inlen > sizeof(mask)
+            || !EVP_CIPHER_CTX_set_padding(ctx, 0)
             || EVP_CipherInit_ex2(ctx, NULL, NULL, iv, 1, NULL) <= 0
             || EVP_CipherUpdate(ctx, mask, &outlen, in, inlen) <= 0
             || outlen != inlen
@@ -444,8 +438,6 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
     size_t rechdrlen = 0;
     size_t recseqnumoffs = 0;
 
-    memset(recseqnum, 0, sizeof(recseqnum));
-
     rl->num_recs = 0;
     rl->curr_rec = 0;
     rl->num_released = 0;
@@ -460,6 +452,8 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
     }
 
  again:
+    memset(recseqnum, 0, sizeof(recseqnum));
+
     /* if we're renegotiating, then there may be buffered records */
     if (dtls_retrieve_rlayer_buffered_record(rl, &rl->processed_rcds)) {
         rl->num_recs = 1;
@@ -515,7 +509,7 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
         if (rl->version == DTLS1_3_VERSION
             && rr->type != SSL3_RT_ALERT
             && rr->type != SSL3_RT_HANDSHAKE
-            /* TODO(DTLSv1.3): && rr->type != SSL3_RT_ACK depends on acknowledge implementation */
+            && rr->type != SSL3_RT_ACK
             && !DTLS13_UNI_HDR_FIX_BITS_IS_SET(rr->type)) {
             /* Silently discard */
             rr->length = 0;
@@ -545,22 +539,15 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
                  * Naive approach? We expect sequence number to be filled already
                  * and then override the last bytes of the sequence number.
                  */
-                || !PACKET_copy_bytes(&dtlsrecord, recseqnum + recseqnumoffs, recseqnumlen)) {
-                rr->length = 0;
-                rl->packet_length = 0;
-                goto again;
-            }
-
-            /*
-             * rfc9147:
-             * The length field MAY be omitted by clearing the L bit, which means
-             * that the record consumes the entire rest of the datagram in the
-             * lower level transport
-             */
-            length = TLS_BUFFER_get_len(&rl->rbuf) - dtls_get_rec_header_size(rr->type);
-
-            if ((lbitisset && !PACKET_get_net_2(&dtlsrecord, &length))
-                || length == 0) {
+                || !PACKET_copy_bytes(&dtlsrecord, recseqnum + recseqnumoffs, recseqnumlen)
+                /*
+                 * rfc9147:
+                 * The length field MAY be omitted by clearing the L bit, which means
+                 * that the record consumes the entire rest of the datagram in the
+                 * lower level transport
+                 */
+                || (lbitisset ? !PACKET_get_net_2(&dtlsrecord, &length)
+                              : (length = (unsigned int)TLS_BUFFER_get_len(&rl->rbuf)) > 0)) {
                 rr->length = 0;
                 rl->packet_length = 0;
                 goto again;
@@ -670,14 +657,13 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
 
     /*
      * rfc9147:
-     * This procedure requires the ciphertext length to be at least
-     * DTLS13_CIPHERTEXT_MINSIZE (16) bytes.
+     * This procedure requires the ciphertext length to be at least 16 bytes.
      * Receivers MUST reject shorter records as if they had failed deprotection
      */
     if (DTLS13_UNI_HDR_FIX_BITS_IS_SET(rr->type)
             && rl->version == DTLS1_3_VERSION
-            && (!ossl_assert(rl->sn_enc_ctx != NULL)
-                || !ossl_assert(rl->packet_length >= rechdrlen + DTLS13_CIPHERTEXT_MINSIZE)
+            && ossl_assert(rl->sn_enc_ctx != NULL)
+            && (!ossl_assert(rl->packet_length >= rechdrlen + DTLS13_CIPHERTEXT_MINSIZE)
                 || !dtls_crypt_sequence_number(rl->sn_enc_ctx,
                                                recseqnum + recseqnumoffs,
                                                recseqnumlen,
@@ -689,8 +675,13 @@ int dtls_get_more_records(OSSL_RECORD_LAYER *rl)
         goto again;
     }
 
-    memset(rl->sequence, 0, sizeof(rl->sequence));
-    memcpy(rl->sequence + 2, recseqnum, sizeof(recseqnum));
+    /* TODO(DTLSv1.3): make recseqnum a uint64_t */
+    rl->sequence =  ((uint64_t)recseqnum[0]) << 40;
+    rl->sequence |= ((uint64_t)recseqnum[1]) << 32;
+    rl->sequence |= ((uint64_t)recseqnum[2]) << 24;
+    rl->sequence |= ((uint64_t)recseqnum[3]) << 16;
+    rl->sequence |= ((uint64_t)recseqnum[4]) << 8;
+    rl->sequence |= ((uint64_t)recseqnum[5]) << 0;
 
     /* match epochs.  NULL means the packet is dropped on the floor */
     bitmap = dtls_get_bitmap(rl, rr, &is_next_epoch);
@@ -872,9 +863,11 @@ int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl,
     size_t maxcomplen;
     int unifiedheader = rl->version == DTLS1_3_VERSION && rl->epoch > 0;
 
+    templ->sequence_number = rl->sequence;
+    templ->epoch = rl->epoch;
     *recdata = NULL;
-
     maxcomplen = templ->buflen;
+
     if (rl->compctx != NULL)
         maxcomplen += SSL3_RT_MAX_COMPRESSED_OVERHEAD;
 
@@ -887,7 +880,8 @@ int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl,
         uint8_t unifiedhdrbits = fixedbits | cbit | sbit | lbit | ebits;
 
         if (!WPACKET_put_bytes_u8(thispkt, unifiedhdrbits)
-            || !WPACKET_memcpy(thispkt, rl->sequence + 6, 2)
+            || (sbit ? !WPACKET_put_bytes_u16(thispkt, rl->sequence)
+                     : !WPACKET_put_bytes_u8(thispkt, rl->sequence))
             || !WPACKET_start_sub_packet_u16(thispkt)
             || (rl->eivlen > 0
                 && !WPACKET_allocate_bytes(thispkt, rl->eivlen, NULL))
@@ -899,8 +893,8 @@ int dtls_prepare_record_header(OSSL_RECORD_LAYER *rl,
     } else {
         if (!WPACKET_put_bytes_u8(thispkt, rectype)
             || !WPACKET_put_bytes_u16(thispkt, templ->version)
-            || !WPACKET_put_bytes_u16(thispkt, rl->epoch)
-            || !WPACKET_memcpy(thispkt, &(rl->sequence[2]), 6)
+            || !WPACKET_put_bytes_u16(thispkt, templ->epoch)
+            || !WPACKET_put_bytes_u48(thispkt, templ->sequence_number)
             || !WPACKET_start_sub_packet_u16(thispkt)
             || (rl->eivlen > 0
                 && !WPACKET_allocate_bytes(thispkt, rl->eivlen, NULL))
index c5ee2894f681c17439334d5bc3667e4192f3d678..89bccf9ff48a12a210f4054cfeb75e0a78b2fbed 100644 (file)
@@ -299,6 +299,9 @@ static int ktls_set_crypto_state(OSSL_RECORD_LAYER *rl, int level,
                                  COMP_METHOD *comp)
 {
     ktls_crypto_info_t crypto_info;
+    unsigned char recseq[SEQ_NUM_SIZE], *p_recseq = recseq;
+
+    l2n8(rl->sequence, p_recseq);
 
     /*
      * Check if we are suitable for KTLS. If not suitable we return
@@ -327,7 +330,7 @@ static int ktls_set_crypto_state(OSSL_RECORD_LAYER *rl, int level,
             return OSSL_RECORD_RETURN_NON_FATAL_ERR;
     }
 
-    if (!ktls_configure_crypto(rl->libctx, rl->version, ciph, md, rl->sequence,
+    if (!ktls_configure_crypto(rl->libctx, rl->version, ciph, md, recseq,
                                &crypto_info,
                                rl->direction == OSSL_RECORD_DIRECTION_WRITE,
                                iv, ivlen, key, keylen, mackey, mackeylen))
index 8da1c8a28243cc71f351a7a3b703a429f5b365b4..6df375a3dff03d5775b3ef9224f80a1941807c8f 100644 (file)
@@ -16,8 +16,8 @@
 typedef struct dtls_bitmap_st {
     /* Track 64 packets */
     uint64_t map;
-    /* Max record number seen so far, 64-bit value in big-endian encoding */
-    unsigned char max_seq_num[SEQ_NUM_SIZE];
+    /* Max record number seen so far */
+    uint64_t max_seq_num;
 } DTLS_BITMAP;
 
 typedef struct ssl_mac_buf_st {
@@ -75,7 +75,7 @@ typedef struct tls_rl_record_st {
     uint16_t epoch;
     /* sequence number, needed by DTLS1 */
     /* r */
-    unsigned char seq_num[SEQ_NUM_SIZE];
+    uint64_t seq_num;
 } TLS_RL_RECORD;
 
 /* Macros/functions provided by the TLS_RL_RECORD component */
@@ -272,7 +272,7 @@ struct ossl_record_layer_st {
     size_t packet_length;
 
     /* Sequence number for the next record */
-    unsigned char sequence[SEQ_NUM_SIZE];
+    uint64_t sequence;
 
     /* Alert code to be used if an error occurs */
     int alert;
@@ -407,9 +407,6 @@ void ossl_rlayer_fatal(OSSL_RECORD_LAYER *rl, int al, int reason,
                                     || (rl)->version == DTLS1_VERSION \
                                     || (rl)->version == DTLS1_2_VERSION)
 
-void ossl_tls_rl_record_set_seq_num(TLS_RL_RECORD *r,
-                                    const unsigned char *seq_num);
-
 int ossl_set_tls_provider_parameters(OSSL_RECORD_LAYER *rl,
                                      EVP_CIPHER_CTX *ctx,
                                      const EVP_CIPHER *ciph,
@@ -485,7 +482,7 @@ int tls_get_alert_code(OSSL_RECORD_LAYER *rl);
 int tls_set1_bio(OSSL_RECORD_LAYER *rl, BIO *bio);
 int tls_read_record(OSSL_RECORD_LAYER *rl, void **rechandle, int *rversion,
                     uint8_t *type, const unsigned char **data, size_t *datalen,
-                    uint16_t *epoch, unsigned char *seq_num);
+                    uint16_t *epoch, uint64_t *seq_num);
 int tls_release_record(OSSL_RECORD_LAYER *rl, void *rechandle, size_t length);
 int tls_set_protocol_version(OSSL_RECORD_LAYER *rl, int version);
 void tls_set_plain_alerts(OSSL_RECORD_LAYER *rl, int allow);
index b424cef278eb7fed253aba0263627bdd7c3b02bd..568ac06ad3d81dbe642ce8e1f60f5b9303a34ea8 100644 (file)
@@ -220,14 +220,22 @@ static const unsigned char ssl3_pad_2[48] = {
 static int ssl3_mac(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec, unsigned char *md,
                     int sending)
 {
-    unsigned char *mac_sec, *seq = rl->sequence;
+    /*
+     * npad is, at most, 48 bytes and that's with MD5:
+     *   16 + 48 + 8 (sequence bytes) + 1 + 2 = 75.
+     *
+     * With SHA-1 (the largest hash speced for SSLv3) the hash size
+     * goes up 4, but npad goes down by 8, resulting in a smaller
+     * total size.
+     */
+    unsigned char header[75];
+    WPACKET hdr;
+    size_t hdr_written;
     const EVP_MD_CTX *hash;
-    unsigned char *p, rec_char;
     size_t md_size;
     size_t npad;
-    int t;
+    int t, cbc_encrypted;
 
-    mac_sec = &(rl->mac_secret[0]);
     hash = rl->md_ctx;
 
     t = EVP_MD_CTX_get_size(hash);
@@ -238,7 +246,24 @@ static int ssl3_mac(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec, unsigned char *md
 
     if (!sending
         && EVP_CIPHER_CTX_get_mode(rl->enc_ctx) == EVP_CIPH_CBC_MODE
-        && ssl3_cbc_record_digest_supported(hash)) {
+        && ssl3_cbc_record_digest_supported(hash))
+        cbc_encrypted = 1;
+    else
+        cbc_encrypted = 0;
+
+    if (!WPACKET_init_static_len(&hdr, header, sizeof(header), 0)
+            || !WPACKET_memcpy(&hdr, rl->mac_secret, md_size)
+            || !WPACKET_memcpy(&hdr, ssl3_pad_1, npad)
+            || !WPACKET_put_bytes_u64(&hdr, rl->sequence)
+            || !WPACKET_put_bytes_u8(&hdr, rec->type)
+            || !WPACKET_put_bytes_u16(&hdr, rec->length)
+            || !WPACKET_finish(&hdr)
+            || !WPACKET_get_total_written(&hdr, &hdr_written)) {
+        WPACKET_cleanup(&hdr);
+        return 0;
+    }
+
+    if (cbc_encrypted) {
 #ifdef OPENSSL_NO_DEPRECATED_3_0
         return 0;
 #else
@@ -248,32 +273,12 @@ static int ssl3_mac(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec, unsigned char *md
          * are hashing because that gives an attacker a timing-oracle.
          */
 
-        /*-
-         * npad is, at most, 48 bytes and that's with MD5:
-         *   16 + 48 + 8 (sequence bytes) + 1 + 2 = 75.
-         *
-         * With SHA-1 (the largest hash speced for SSLv3) the hash size
-         * goes up 4, but npad goes down by 8, resulting in a smaller
-         * total size.
-         */
-        unsigned char header[75];
-        size_t j = 0;
-        memcpy(header + j, mac_sec, md_size);
-        j += md_size;
-        memcpy(header + j, ssl3_pad_1, npad);
-        j += npad;
-        memcpy(header + j, seq, 8);
-        j += 8;
-        header[j++] = rec->type;
-        header[j++] = (unsigned char)(rec->length >> 8);
-        header[j++] = (unsigned char)(rec->length & 0xff);
-
         /* Final param == is SSLv3 */
         if (ssl3_cbc_digest_record(EVP_MD_CTX_get0_md(hash),
                                    md, &md_size,
                                    header, rec->input,
                                    rec->length, rec->orig_len,
-                                   mac_sec, md_size, 1) <= 0)
+                                   rl->mac_secret, md_size, 1) <= 0)
             return 0;
 #endif
     } else {
@@ -284,19 +289,12 @@ static int ssl3_mac(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec, unsigned char *md
         if (md_ctx == NULL)
             return 0;
 
-        rec_char = rec->type;
-        p = md;
-        s2n(rec->length, p);
         if (EVP_MD_CTX_copy_ex(md_ctx, hash) <= 0
-            || EVP_DigestUpdate(md_ctx, mac_sec, md_size) <= 0
-            || EVP_DigestUpdate(md_ctx, ssl3_pad_1, npad) <= 0
-            || EVP_DigestUpdate(md_ctx, seq, 8) <= 0
-            || EVP_DigestUpdate(md_ctx, &rec_char, 1) <= 0
-            || EVP_DigestUpdate(md_ctx, md, 2) <= 0
+            || EVP_DigestUpdate(md_ctx, header, hdr_written) <= 0
             || EVP_DigestUpdate(md_ctx, rec->input, rec->length) <= 0
             || EVP_DigestFinal_ex(md_ctx, md, NULL) <= 0
             || EVP_MD_CTX_copy_ex(md_ctx, hash) <= 0
-            || EVP_DigestUpdate(md_ctx, mac_sec, md_size) <= 0
+            || EVP_DigestUpdate(md_ctx, rl->mac_secret, md_size) <= 0
             || EVP_DigestUpdate(md_ctx, ssl3_pad_2, npad) <= 0
             || EVP_DigestUpdate(md_ctx, md, md_size) <= 0
             || EVP_DigestFinal_ex(md_ctx, md, &md_size_u) <= 0) {
index b3361125c06fcbbfb911fd02dc6a41d3241fa130..3bffad0bd02375a5b3afc7240737b2a73f624330 100644 (file)
@@ -116,7 +116,7 @@ static int tls13_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
     int isdtls, sbit = 0, addlen;
     unsigned char *staticiv;
     unsigned char *nonce;
-    unsigned char *seq = rl->sequence;
+    unsigned char seq[SEQ_NUM_SIZE], *p_seq = seq;
     int lenu, lenf;
     TLS_RL_RECORD *rec = &recs[0];
     WPACKET wpkt;
@@ -134,6 +134,7 @@ static int tls13_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
     staticiv = rl->iv;
     nonce = rl->nonce;
     isdtls = rl->isdtls;
+    l2n8(rl->sequence, p_seq);
 
     if (enc_ctx == NULL && rl->mac_ctx == NULL) {
         RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
@@ -221,8 +222,8 @@ static int tls13_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
     if ((isdtls && !ossl_assert(!DTLS13_UNI_HDR_CID_BIT_IS_SET(rec->type)))
             || !WPACKET_init_static_len(&wpkt, recheader, sizeof(recheader), 0)
             || !WPACKET_put_bytes_u8(&wpkt, rec->type)
-            || (isdtls && (sbit ? !WPACKET_memcpy(&wpkt, rl->sequence + 6, 2)
-                                : !WPACKET_memcpy(&wpkt, rl->sequence + 7, 1)))
+            || (isdtls && (sbit ? !WPACKET_put_bytes_u16(&wpkt, rl->sequence)
+                                : !WPACKET_put_bytes_u8(&wpkt, rl->sequence)))
             || (!isdtls && !WPACKET_put_bytes_u16(&wpkt, rec->rec_version))
             || (addlen && !WPACKET_put_bytes_u16(&wpkt, rec->length + rl->taglen))
             || !WPACKET_get_total_written(&wpkt, &hdrlen)
index 949492163a392b688c14ec7cef0ee2f2d0974d2c..2d23a7cd0002a130acd53d06c2c26c0d6a9b418d 100644 (file)
@@ -167,6 +167,30 @@ static int tls1_set_crypto_state(OSSL_RECORD_LAYER *rl, int level,
     return OSSL_RECORD_RETURN_SUCCESS;
 }
 
+static int setup_record_header(const OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec,
+                               unsigned char *buf, size_t buflen)
+{
+    WPACKET hdr;
+    size_t hdrsize;
+
+    if (buflen < EVP_AEAD_TLS1_AAD_LEN
+            || !WPACKET_init_static_len(&hdr, buf, EVP_AEAD_TLS1_AAD_LEN, 0)
+            || (rl->isdtls && !WPACKET_put_bytes_u16(&hdr, rl->epoch))
+            || (rl->isdtls ? !WPACKET_put_bytes_u48(&hdr, rl->sequence)
+                           : !WPACKET_put_bytes_u64(&hdr, rl->sequence))
+            || !WPACKET_put_bytes_u8(&hdr, rec->type)
+            || !WPACKET_put_bytes_u16(&hdr, rl->version)
+            || !WPACKET_put_bytes_u16(&hdr, rec->length)
+            || !WPACKET_finish(&hdr)
+            || !WPACKET_get_total_written(&hdr, &hdrsize)
+            || hdrsize != EVP_AEAD_TLS1_AAD_LEN) {
+        WPACKET_cleanup(&hdr);
+        return 0;
+    }
+
+    return 1;
+}
+
 #define MAX_PADDING 256
 /*-
  * tls1_cipher encrypts/decrypts |n_recs| in |recs|. Calls RLAYERfatal on
@@ -264,29 +288,16 @@ static int tls1_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
 
         if ((EVP_CIPHER_get_flags(EVP_CIPHER_CTX_get0_cipher(ds))
                  & EVP_CIPH_FLAG_AEAD_CIPHER) != 0) {
-            unsigned char *seq;
-
-            seq = rl->sequence;
-
-            if (rl->isdtls) {
-                unsigned char dtlsseq[8], *p = dtlsseq;
+            if (!setup_record_header(rl, &recs[ctr], buf[ctr], sizeof(buf[ctr]))) {
+                RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+                return 0;
+            }
 
-                s2n(rl->epoch, p);
-                memcpy(p, &seq[2], 6);
-                memcpy(buf[ctr], dtlsseq, 8);
-            } else {
-                memcpy(buf[ctr], seq, 8);
-                if (!tls_increment_sequence_ctr(rl)) {
-                    /* RLAYERfatal already called */
-                    return 0;
-                }
+            if (!rl->isdtls && !tls_increment_sequence_ctr(rl)) {
+                /* RLAYERfatal already called */
+                return 0;
             }
 
-            buf[ctr][8] = recs[ctr].type;
-            buf[ctr][9] = (unsigned char)(rl->version >> 8);
-            buf[ctr][10] = (unsigned char)(rl->version);
-            buf[ctr][11] = (unsigned char)(recs[ctr].length >> 8);
-            buf[ctr][12] = (unsigned char)(recs[ctr].length & 0xff);
             pad = EVP_CIPHER_CTX_ctrl(ds, EVP_CTRL_AEAD_TLS1_AAD,
                                       EVP_AEAD_TLS1_AAD_LEN, buf[ctr]);
             if (pad <= 0) {
@@ -351,6 +362,9 @@ static int tls1_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
 
     if (!rl->isdtls && rl->tlstree) {
         int decrement_seq = 0;
+        unsigned char recseq[SEQ_NUM_SIZE], *p_recseq = recseq;
+
+        l2n8(rl->sequence, p_recseq);
 
         /*
          * When sending, seq is incremented after MAC calculation.
@@ -360,9 +374,7 @@ static int tls1_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
         if (sending && !rl->use_etm)
             decrement_seq = 1;
 
-        if (EVP_CIPHER_CTX_ctrl(ds, EVP_CTRL_TLSTREE, decrement_seq,
-                                rl->sequence) <= 0) {
-
+        if (EVP_CIPHER_CTX_ctrl(ds, EVP_CTRL_TLSTREE, decrement_seq, recseq) <= 0) {
             RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
             return 0;
         }
@@ -378,7 +390,7 @@ static int tls1_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
         }
 
         if (!EVP_CipherUpdate(ds, recs[0].data, &outlen, recs[0].input,
-                              (unsigned int)reclen[0]))
+                              (int)reclen[0]))
             return 0;
         recs[0].length = outlen;
 
@@ -478,7 +490,7 @@ static int tls1_cipher(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *recs,
 static int tls1_mac(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec, unsigned char *md,
                     int sending)
 {
-    unsigned char *seq = rl->sequence;
+    unsigned char seq[SEQ_NUM_SIZE], *p_seq = seq;
     EVP_MD_CTX *hash;
     size_t md_size;
     EVP_MD_CTX *hmac = NULL, *mac_ctx;
@@ -487,6 +499,7 @@ static int tls1_mac(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec, unsigned char *md
     int ret = 0;
 
     hash = rl->md_ctx;
+    l2n8(rl->sequence, p_seq);
 
     t = EVP_MD_CTX_get_size(hash);
     if (!ossl_assert(t >= 0))
@@ -508,22 +521,8 @@ static int tls1_mac(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec, unsigned char *md
             && EVP_MD_CTX_ctrl(mac_ctx, EVP_MD_CTRL_TLSTREE, 0, seq) <= 0)
         goto end;
 
-    if (rl->isdtls) {
-        unsigned char dtlsseq[8], *p = dtlsseq;
-
-        s2n(rl->epoch, p);
-        memcpy(p, &seq[2], 6);
-
-        memcpy(header, dtlsseq, 8);
-    } else {
-        memcpy(header, seq, 8);
-    }
-
-    header[8] = rec->type;
-    header[9] = (unsigned char)(rl->version >> 8);
-    header[10] = (unsigned char)(rl->version);
-    header[11] = (unsigned char)(rec->length >> 8);
-    header[12] = (unsigned char)(rec->length & 0xff);
+    if (!setup_record_header(rl, rec, header, sizeof(header)))
+        goto end;
 
     if (!sending && !rl->use_etm
         && EVP_CIPHER_CTX_get_mode(rl->enc_ctx) == EVP_CIPH_CBC_MODE
index 0725eb118aa0cfaa0c3faa9ab8c22eef5875bd8a..6976cbefd0c134bd8817efecd8f2ccc794b3ef1f 100644 (file)
@@ -39,12 +39,6 @@ static void TLS_RL_RECORD_release(TLS_RL_RECORD *r, size_t num_recs)
     }
 }
 
-void ossl_tls_rl_record_set_seq_num(TLS_RL_RECORD *r,
-                                    const unsigned char *seq_num)
-{
-    memcpy(r->seq_num, seq_num, SEQ_NUM_SIZE);
-}
-
 void ossl_rlayer_fatal(OSSL_RECORD_LAYER *rl, int al, int reason,
                        const char *fmt, ...)
 {
@@ -857,7 +851,7 @@ int tls_get_more_records(OSSL_RECORD_LAYER *rl)
             rl->curr_rec = 0;
             rl->num_released = 0;
             /* Reset the read sequence */
-            memset(rl->sequence, 0, sizeof(rl->sequence));
+            rl->sequence = 0;
             ret = 1;
             goto end;
         }
@@ -1084,7 +1078,8 @@ int tls13_common_post_process_record(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec)
 {
     if (rec->type != SSL3_RT_APPLICATION_DATA
             && rec->type != SSL3_RT_ALERT
-            && rec->type != SSL3_RT_HANDSHAKE) {
+            && rec->type != SSL3_RT_HANDSHAKE
+            && (!rl->isdtls || rec->type != SSL3_RT_ACK)) {
         RLAYERfatal(rl, SSL_AD_UNEXPECTED_MESSAGE, SSL_R_BAD_RECORD_TYPE);
         return 0;
     }
@@ -1111,7 +1106,7 @@ int tls13_common_post_process_record(OSSL_RECORD_LAYER *rl, TLS_RL_RECORD *rec)
 
 int tls_read_record(OSSL_RECORD_LAYER *rl, void **rechandle, int *rversion,
                     uint8_t *type, const unsigned char **data, size_t *datalen,
-                    uint16_t *epoch, unsigned char *seq_num)
+                    uint16_t *epoch, uint64_t *seq_num)
 {
     TLS_RL_RECORD *rec;
 
@@ -1148,7 +1143,7 @@ int tls_read_record(OSSL_RECORD_LAYER *rl, void **rechandle, int *rversion,
     *datalen = rec->length;
     if (rl->isdtls) {
         *epoch = rec->epoch;
-        memcpy(seq_num, rec->seq_num, sizeof(rec->seq_num));
+        *seq_num = rec->seq_num;
     }
 
     return OSSL_RECORD_RETURN_SUCCESS;
@@ -2125,15 +2120,8 @@ void tls_set_max_frag_len(OSSL_RECORD_LAYER *rl, size_t max_frag_len)
 
 int tls_increment_sequence_ctr(OSSL_RECORD_LAYER *rl)
 {
-    int i;
-
     /* Increment the sequence counter */
-    for (i = SEQ_NUM_SIZE; i > 0; i--) {
-        ++(rl->sequence[i - 1]);
-        if (rl->sequence[i - 1] != 0)
-            break;
-    }
-    if (i == 0) {
+    if (++rl->sequence == 0) {
         /* Sequence has wrapped */
         RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, SSL_R_SEQUENCE_CTR_WRAPPED);
         return 0;
index ead66d84fccadc8b8e82487045dde7195b4aa47e..640ee26d589620b86d1a1be471276801ff5bf92e 100644 (file)
@@ -68,9 +68,10 @@ static int tls_write_records_multiblock_int(OSSL_RECORD_LAYER *rl,
 {
 #if !defined(OPENSSL_NO_MULTIBLOCK) && EVP_CIPH_FLAG_TLS1_1_MULTIBLOCK
     size_t i;
-    size_t totlen;
+    size_t totlen, aad_written;
     TLS_BUFFER *wb;
     unsigned char aad[13];
+    WPACKET aad_pkt;
     EVP_CTRL_TLS1_1_MULTIBLOCK_PARAM mb_param;
     size_t packlen;
     int packleni;
@@ -118,13 +119,20 @@ static int tls_write_records_multiblock_int(OSSL_RECORD_LAYER *rl,
     }
     wb = &rl->wbuf[0];
 
+    if (!WPACKET_init_static_len(&aad_pkt, aad, sizeof(aad), 0)
+            || !WPACKET_put_bytes_u64(&aad_pkt, rl->sequence)
+            || !WPACKET_put_bytes_u8(&aad_pkt, templates[0].type)
+            || !WPACKET_put_bytes_u16(&aad_pkt, templates[0].version)
+            || !WPACKET_put_bytes_u16(&aad_pkt, 0)
+            || !WPACKET_get_total_written(&aad_pkt, &aad_written)
+            || aad_written != sizeof(aad)
+            || !WPACKET_finish(&aad_pkt)) {
+        WPACKET_cleanup(&aad_pkt);
+        RLAYERfatal(rl, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return -1;
+    }
+
     mb_param.interleave = (unsigned int)numtempl;
-    memcpy(aad, rl->sequence, 8);
-    aad[8] = templates[0].type;
-    aad[9] = (unsigned char)(templates[0].version >> 8);
-    aad[10] = (unsigned char)(templates[0].version);
-    aad[11] = 0;
-    aad[12] = 0;
     mb_param.out = NULL;
     mb_param.inp = aad;
     mb_param.len = totlen;
@@ -149,11 +157,7 @@ static int tls_write_records_multiblock_int(OSSL_RECORD_LAYER *rl,
         return -1;
     }
 
-    rl->sequence[7] += mb_param.interleave;
-    if (rl->sequence[7] < mb_param.interleave) {
-        int j = 6;
-        while (j >= 0 && (++rl->sequence[j--]) == 0) ;
-    }
+    rl->sequence += mb_param.interleave;
 
     wb->offset = 0;
     wb->left = packlen;
index 883f3ae9944a338ee153a6350dcf734655ed9180..031f2c67e4dc579a7216ad646a2cd890020e15ba 100644 (file)
@@ -87,7 +87,7 @@ static int dtls_buffer_record(SSL_CONNECTION *s, TLS_RECORD *rec)
         return -1;
 
     rdata = OPENSSL_malloc(sizeof(*rdata));
-    item = pitem_new(rec->seq_num, rdata);
+    item = pitem_new_u64(rec->seq_num, rdata);
     if (rdata == NULL || item == NULL) {
         OPENSSL_free(rdata);
         pitem_free(item);
@@ -258,7 +258,7 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type,
                                                       &rr->rechandle,
                                                       &rr->version, &rr->type,
                                                       &rr->data, &rr->length,
-                                                      &rr->epoch, rr->seq_num));
+                                                      &rr->epoch, &rr->seq_num));
             if (ret <= 0) {
                 ret = dtls1_read_failed(sc, ret);
                 /*
@@ -314,14 +314,20 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type,
         return 0;
     }
 
+    if (rr->type == SSL3_RT_HANDSHAKE && SSL_CONNECTION_IS_DTLS13(sc)) {
+        sc->s3.tmp.record_epoch = rr->epoch;
+        sc->s3.tmp.record_seq_num = rr->seq_num;
+    }
+
     if (type == rr->type
-        || (rr->type == SSL3_RT_CHANGE_CIPHER_SPEC
-            && type == SSL3_RT_HANDSHAKE && recvd_type != NULL
-            && !is_dtls13)) {
+            || (type == SSL3_RT_HANDSHAKE
+                && ((!is_dtls13 && recvd_type != NULL && rr->type == SSL3_RT_CHANGE_CIPHER_SPEC)
+                    || (is_dtls13 && rr->type == SSL3_RT_ACK)))) {
         /*
          * SSL3_RT_APPLICATION_DATA or
          * SSL3_RT_HANDSHAKE or
-         * SSL3_RT_CHANGE_CIPHER_SPEC
+         * SSL3_RT_CHANGE_CIPHER_SPEC or
+         * SSL3_RT_ACK
          */
         /*
          * make sure that we are not getting application data when we are
@@ -375,6 +381,7 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type,
         }
 #endif
         *readbytes = n;
+
         return 1;
     }
 
@@ -501,7 +508,7 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type,
     /*
      * Unexpected handshake message (Client Hello, or protocol violation)
      */
-    if (rr->type == SSL3_RT_HANDSHAKE && !ossl_statem_get_in_handshake(sc)) {
+    if (!ossl_statem_get_in_handshake(sc) && rr->type == SSL3_RT_HANDSHAKE) {
         unsigned char msg_type;
 
         /*
@@ -548,7 +555,10 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type,
             }
             goto start;
         }
+    }
 
+    if (!ossl_statem_get_in_handshake(sc)
+        && (rr->type == SSL3_RT_HANDSHAKE || rr->type == SSL3_RT_ACK)) {
         /*
          * To get here we must be trying to read app data but found handshake
          * data. But if we're trying to read app data, and we're not in init
@@ -604,6 +614,34 @@ int dtls1_read_bytes(SSL *s, uint8_t type, uint8_t *recvd_type,
          */
         SSLfatal(sc, SSL_AD_UNEXPECTED_MESSAGE, ERR_R_INTERNAL_ERROR);
         return -1;
+
+    case SSL3_RT_ACK:
+        switch (sc->negotiated_version) {
+        case DTLS1_3_VERSION:
+            /* ACK should have been handled if DTLSv1.3 has been negotiated. */
+            SSLfatal(sc, SSL_AD_UNEXPECTED_MESSAGE, ERR_R_INTERNAL_ERROR);
+            return -1;
+
+        case DTLS_ANY_VERSION:
+            /*
+             * This must be an ACK from a DTLSv1.3 server for a partial
+             * ClientHello. We always send the full message again if the
+             * ClientHello is not responded to with a ServerHello before the
+             * timer runs out. Drop the record.
+             */
+            if (!ssl_release_record(sc, rr, 0))
+                return -1;
+            goto start;
+
+        default:
+            /*
+             * If we receive an ACK record when we have negotiated a lower version
+             * than DTLSv1.3 then we respond with an unexpected record fatal alert.
+             */
+            SSLfatal(sc, SSL_AD_UNEXPECTED_MESSAGE, SSL_R_UNEXPECTED_RECORD);
+            return -1;
+        }
+
     case SSL3_RT_APPLICATION_DATA:
         /*
          * At this point, we were expecting handshake data, but have
@@ -689,6 +727,33 @@ int do_dtls1_write(SSL_CONNECTION *sc, uint8_t type, const unsigned char *buf,
     if (ret > 0)
         *written = len;
 
+    /*
+     * Add record number to the buffered sent message
+     */
+    if (type == SSL3_RT_HANDSHAKE && ret > 0 && SSL_CONNECTION_IS_DTLS13(sc)) {
+        pitem *item;
+        unsigned char prio[8];
+
+        dtls1_get_queue_priority(prio, sc->d1->w_msg.msg_seq, 0);
+        item = pqueue_find(&sc->d1->sent_messages, prio);
+
+        if (item == NULL)
+            return ret;
+
+        if (dtls_msg_needs_ack(sc->server, sc->d1->w_msg.msg_type)) {
+            dtls_sent_msg *sent_msg;
+            DTLS1_RECORD_NUMBER *rec_num;
+
+            sent_msg = (dtls_sent_msg *) item->data;
+            rec_num = dtls1_record_number_new(tmpl.epoch, tmpl.sequence_number);
+
+            if (rec_num == NULL)
+                return -1;
+
+            ossl_list_record_number_insert_tail(&sent_msg->rec_nums, rec_num);
+        }
+    }
+
     return ret;
 }
 
index 9bdf4f23bdda615c34bc1c4770ac3b68535637f5..db9c7495cee9b35130b873d3e5e89093ac97ad98 100644 (file)
@@ -37,7 +37,7 @@ typedef struct tls_record_st {
     /* epoch number. DTLS only */
     uint16_t epoch;
     /* sequence number. DTLS only */
-    unsigned char seq_num[SEQ_NUM_SIZE];
+    uint64_t seq_num;
 #ifndef OPENSSL_NO_SCTP
     struct bio_dgram_sctp_rcvinfo recordinfo;
 #endif
@@ -160,6 +160,7 @@ int do_dtls1_write(SSL_CONNECTION *s, uint8_t type, const unsigned char *buf,
                    size_t len, size_t *written);
 void dtls1_increment_epoch(SSL_CONNECTION *s, int rw);
 uint16_t dtls1_get_epoch(SSL_CONNECTION *s, int rw);
+uint64_t dtls1_get_record_sequence_number(SSL_CONNECTION *s);
 int ssl_release_record(SSL_CONNECTION *s, TLS_RECORD *rr, size_t length);
 
 # define HANDLE_RLAYER_READ_RETURN(s, ret) \
index 0da48aafaa7f32faf1310f0c608318d7ba918f75..6dec8d8413509111493e6866609d0ce7cbbe97a5 100644 (file)
@@ -34,6 +34,7 @@
 # include "internal/tsan_assist.h"
 # include "internal/bio.h"
 # include "internal/ktls.h"
+# include "internal/list.h"
 # include "internal/time.h"
 # include "internal/ssl.h"
 # include "internal/cryptlib.h"
@@ -1417,6 +1418,8 @@ struct ssl_connection_st {
             size_t peer_finish_md_len;
             size_t message_size;
             int message_type;
+            uint64_t record_epoch;
+            uint64_t record_seq_num;
             /* used to hold the new cipher we are going to use */
             const SSL_CIPHER *new_cipher;
             EVP_PKEY *pkey;         /* holds short lived key exchange key */
@@ -2013,6 +2016,7 @@ struct pitem_st {
 typedef struct pitem_st *piterator;
 
 pitem *pitem_new(unsigned char *prio64be, void *data);
+pitem *pitem_new_u64(uint64_t prio, void *data);
 void pitem_free(pitem *item);
 pqueue *pqueue_new(void);
 void pqueue_free(pqueue *pq);
@@ -2020,23 +2024,61 @@ pitem *pqueue_insert(pqueue *pq, pitem *item);
 pitem *pqueue_peek(pqueue *pq);
 pitem *pqueue_pop(pqueue *pq);
 pitem *pqueue_find(pqueue *pq, unsigned char *prio64be);
+pitem *pqueue_find_u64(pqueue *pq, uint64_t prio);
 pitem *pqueue_iterator(pqueue *pq);
 pitem *pqueue_next(piterator *iter);
 size_t pqueue_size(pqueue *pq);
 
 typedef struct dtls_msg_info_st {
+    unsigned char record_type;
     unsigned char msg_type;
     size_t msg_body_len;
     unsigned short msg_seq;
 } dtls_msg_info;
 
+/* rfc9147, section 4 */
+typedef struct dtls1_record_number_st DTLS1_RECORD_NUMBER;
+
+struct dtls1_record_number_st {
+    uint64_t epoch;
+    uint64_t seqnum;
+    OSSL_LIST_MEMBER(record_number, DTLS1_RECORD_NUMBER);
+};
+
+DEFINE_LIST_OF(record_number, DTLS1_RECORD_NUMBER);
+
+DTLS1_RECORD_NUMBER *dtls1_record_number_new(uint64_t epoch, uint64_t seqnum);
+
+void ossl_list_record_number_elem_free(OSSL_LIST(record_number) *p_list);
+
 typedef struct dtls_sent_msg_st {
     dtls_msg_info msg_info;
-    int record_type;
+    OSSL_LIST(record_number) rec_nums;
     unsigned char *msg_buf;
     struct dtls1_retransmit_state saved_retransmit_state;
 } dtls_sent_msg;
 
+int dtls_any_sent_messages_are_missing_acknowledge(SSL_CONNECTION *s);
+
+static ossl_inline int dtls_msg_needs_ack(int sentbyserver, unsigned char msgtype)
+{
+    switch (msgtype) {
+    case SSL3_MT_NEWSESSION_TICKET:
+    case SSL3_MT_KEY_UPDATE:
+        return 1;
+
+    case SSL3_MT_CERTIFICATE:
+    case SSL3_MT_COMPRESSED_CERTIFICATE:
+    case SSL3_MT_CERTIFICATE_VERIFY:
+    case SSL3_MT_FINISHED:
+        if (!sentbyserver)
+            return 1;
+        /* fall-through */
+    default:
+        return 0;
+    }
+}
+
 typedef struct dtls1_state_st {
     unsigned char cookie[DTLS1_COOKIE_LENGTH];
     size_t cookie_len;
@@ -2070,6 +2112,9 @@ typedef struct dtls1_state_st {
     int shutdown_received;
 # endif
 
+    /* Sequence numbers that are to be acknowledged */
+    OSSL_LIST(record_number) ack_rec_num;
+
     DTLS_timer_cb timer_cb;
 
 } DTLS1_STATE;
@@ -2820,12 +2865,13 @@ int dtls1_write_app_data_bytes(SSL *s, uint8_t type, const void *buf_,
 
 __owur int dtls1_read_failed(SSL_CONNECTION *s, int code);
 __owur int dtls1_buffer_sent_message(SSL_CONNECTION *s, int record_type);
-__owur int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq,
-                                    int *found);
-__owur int dtls1_get_queue_priority(unsigned short seq, int is_ccs);
+__owur int dtls1_retransmit_message(SSL_CONNECTION *s, dtls_sent_msg *sent_msg);
+void dtls1_get_queue_priority(unsigned char *prio64be, unsigned short seq,
+                              int record_type);
 int dtls1_retransmit_sent_messages(SSL_CONNECTION *s);
 void dtls1_clear_received_buffer(SSL_CONNECTION *s);
-void dtls1_clear_sent_buffer(SSL_CONNECTION *s);
+void dtls1_clear_sent_buffer(SSL_CONNECTION *s, int keep_unacked_msgs);
+void dtls1_acknowledge_sent_buffer(SSL_CONNECTION *s, uint16_t before_epoch);
 __owur OSSL_TIME dtls1_default_timeout(void);
 __owur int dtls1_get_timeout(const SSL_CONNECTION *s, OSSL_TIME *timeleft);
 __owur int dtls1_check_timeout_num(SSL_CONNECTION *s);
index d6ba000c65d45192e9495a1dcd00e4256c7a7f0e..367b93805a0ee0257d9bfd40dbe09e299437f8bd 100644 (file)
@@ -124,6 +124,14 @@ const char *SSL_state_string_long(const SSL *s)
         return "TLSv1.3 write end of early data";
     case TLS_ST_SR_END_OF_EARLY_DATA:
         return "TLSv1.3 read end of early data";
+    case TLS_ST_CR_ACK:
+        return "DTLSv1.3 read client ack";
+    case TLS_ST_CW_ACK:
+        return "DTLSv1.3 write client ack";
+    case TLS_ST_SR_ACK:
+        return "DTLSv1.3 read server ack";
+    case TLS_ST_SW_ACK:
+        return "DTLSv1.3 write server ack";
     default:
         return "unknown state";
     }
@@ -241,6 +249,14 @@ const char *SSL_state_string(const SSL *s)
         return "TWEOED";
     case TLS_ST_SR_END_OF_EARLY_DATA:
         return "TWEOED";
+    case TLS_ST_CR_ACK:
+        return "TRCACK";
+    case TLS_ST_CW_ACK:
+        return "TWCACK";
+    case TLS_ST_SR_ACK:
+        return "TRSACK";
+    case TLS_ST_SW_ACK:
+        return "TWSACK";
     default:
         return "UNKWN";
     }
index d62dc8060e1f93e29ff04c1170e1f40758a921b1..4c9590f5b4949dd34f71af7cfc7b058190e3006e 100644 (file)
@@ -543,8 +543,8 @@ static void init_read_state_machine(SSL_CONNECTION *s)
     st->read_state = READ_STATE_HEADER;
 }
 
-static int grow_init_buf(SSL_CONNECTION *s, size_t size) {
-
+static int grow_init_buf(SSL_CONNECTION *s, size_t size)
+{
     size_t msg_offset = (char *)s->init_msg - s->init_buf->data;
 
     if (!BUF_MEM_grow_clean(s->init_buf, size))
@@ -750,17 +750,28 @@ static SUB_STATE_RETURN read_state_machine(SSL_CONNECTION *s)
  */
 static int statem_do_write(SSL_CONNECTION *s)
 {
+    int record_type;
     OSSL_STATEM *st = &s->statem;
 
-    if (st->hand_state == TLS_ST_CW_CHANGE
-        || st->hand_state == TLS_ST_SW_CHANGE) {
-        if (SSL_CONNECTION_IS_DTLS(s))
-            return dtls1_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC);
-        else
-            return ssl3_do_write(s, SSL3_RT_CHANGE_CIPHER_SPEC);
-    } else {
+    switch (st->hand_state) {
+    case TLS_ST_CW_CHANGE:
+    case TLS_ST_SW_CHANGE:
+        record_type = SSL3_RT_CHANGE_CIPHER_SPEC;
+
+        break;
+    case TLS_ST_CW_ACK:
+    case TLS_ST_SW_ACK:
+        record_type = SSL3_RT_ACK;
+
+        break;
+    default:
         return ssl_do_write(s);
     }
+
+    if (SSL_CONNECTION_IS_DTLS(s))
+        return dtls1_do_write(s, record_type);
+    else
+        return ssl3_do_write(s, record_type);
 }
 
 /*
@@ -815,6 +826,7 @@ static SUB_STATE_RETURN write_state_machine(SSL_CONNECTION *s)
                                     CON_FUNC_RETURN (**confunc) (SSL_CONNECTION *s,
                                                                  WPACKET *pkt),
                                     int *mt);
+    int (*dtls_use_timer) (SSL_CONNECTION *s);
     void (*cb) (const SSL *ssl, int type, int val) = NULL;
     CON_FUNC_RETURN (*confunc) (SSL_CONNECTION *s, WPACKET *pkt);
     int mt;
@@ -828,11 +840,13 @@ static SUB_STATE_RETURN write_state_machine(SSL_CONNECTION *s)
         pre_work = ossl_statem_server_pre_work;
         post_work = ossl_statem_server_post_work;
         get_construct_message_f = ossl_statem_server_construct_message;
+        dtls_use_timer = ossl_statem_dtls_server_use_timer;
     } else {
         transition = ossl_statem_client_write_transition;
         pre_work = ossl_statem_client_pre_work;
         post_work = ossl_statem_client_post_work;
         get_construct_message_f = ossl_statem_client_construct_message;
+        dtls_use_timer = ossl_statem_dtls_client_use_timer;
     }
 
     while (1) {
@@ -925,9 +939,9 @@ static SUB_STATE_RETURN write_state_machine(SSL_CONNECTION *s)
             /* Fall through */
 
         case WRITE_STATE_SEND:
-            if (SSL_CONNECTION_IS_DTLS(s) && st->use_timer) {
+            if (SSL_CONNECTION_IS_DTLS(s) && dtls_use_timer(s))
                 dtls1_start_timer(s);
-            }
+
             ret = statem_do_write(s);
             if (ret <= 0) {
                 return SUB_STATE_ERROR;
index 1704b7f520cee42478480617118d17165448c23d..32d136b1a7bfed557a662b02993c80db8db77286 100644 (file)
@@ -182,7 +182,14 @@ static int ossl_statem_client13_read_transition(SSL_CONNECTION *s, int mt)
         }
         break;
 
+    case TLS_ST_CW_KEY_UPDATE:
+    case TLS_ST_CW_FINISHED:
+    case TLS_ST_CR_ACK:
     case TLS_ST_OK:
+        if (mt == DTLS13_MT_ACK) {
+            st->hand_state = TLS_ST_CR_ACK;
+            return 1;
+        }
         if (mt == SSL3_MT_NEWSESSION_TICKET) {
             st->hand_state = TLS_ST_CR_SESSION_TICKET;
             return 1;
@@ -509,9 +516,31 @@ static WRITE_TRAN ossl_statem_client13_write_transition(SSL_CONNECTION *s)
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_CR_KEY_UPDATE:
-    case TLS_ST_CW_KEY_UPDATE:
     case TLS_ST_CR_SESSION_TICKET:
+        if (SSL_CONNECTION_IS_DTLS13(s)) {
+            st->hand_state = TLS_ST_CW_ACK;
+            return WRITE_TRAN_CONTINUE;
+        }
+        /* Fall-through */
+    case TLS_ST_CW_KEY_UPDATE:
     case TLS_ST_CW_FINISHED:
+        if (SSL_CONNECTION_IS_DTLS13(s))
+            /* We wait for ACK */
+            return WRITE_TRAN_FINISHED;
+        else
+            st->hand_state = TLS_ST_OK;
+        return WRITE_TRAN_CONTINUE;
+
+    case TLS_ST_CR_ACK:
+        if (SSL_CONNECTION_IS_DTLS13(s)
+            && dtls_any_sent_messages_are_missing_acknowledge(s)) {
+            /* We wait for ACK */
+            return WRITE_TRAN_FINISHED;
+        }
+        st->hand_state = TLS_ST_OK;
+        return WRITE_TRAN_CONTINUE;
+
+    case TLS_ST_CW_ACK:
         st->hand_state = TLS_ST_OK;
         return WRITE_TRAN_CONTINUE;
 
@@ -698,6 +727,48 @@ WRITE_TRAN ossl_statem_client_write_transition(SSL_CONNECTION *s)
     }
 }
 
+static int ossl_statem_dtls_client13_use_timer(SSL_CONNECTION *s)
+{
+    OSSL_STATEM *st = &s->statem;
+
+    switch (st->hand_state) {
+    default:
+        break;
+
+    case TLS_ST_CW_ACK:
+        /* Fall through */
+
+    case TLS_ST_OK:
+        return 0;
+    }
+
+    return 1;
+}
+
+int ossl_statem_dtls_client_use_timer(SSL_CONNECTION *s)
+{
+    OSSL_STATEM *st = &s->statem;
+
+    if (SSL_CONNECTION_IS_DTLS13(s))
+        return ossl_statem_dtls_client13_use_timer(s);
+
+    switch (st->hand_state) {
+    default:
+        break;
+
+    case TLS_ST_CW_CHANGE:
+        /*
+         * We're into the last flight so we don't retransmit these
+         * messages unless we need to.
+         */
+        if (s->hit)
+            st->use_timer = 0;
+        break;
+    }
+
+    return st->use_timer;
+}
+
 /*
  * Perform any pre work that needs to be done prior to sending a message from
  * the client to the server.
@@ -741,13 +812,6 @@ WORK_STATE ossl_statem_client_pre_work(SSL_CONNECTION *s, WORK_STATE wst)
 
     case TLS_ST_CW_CHANGE:
         if (SSL_CONNECTION_IS_DTLS(s)) {
-            if (s->hit) {
-                /*
-                 * We're into the last flight so we don't retransmit these
-                 * messages unless we need to.
-                 */
-                st->use_timer = 0;
-            }
 #ifndef OPENSSL_NO_SCTP
             if (BIO_dgram_is_sctp(SSL_get_wbio(SSL_CONNECTION_GET_SSL(s)))) {
                 /* Calls SSLfatal() as required */
@@ -924,6 +988,11 @@ WORK_STATE ossl_statem_client_post_work(SSL_CONNECTION *s, WORK_STATE wst)
             return WORK_ERROR;
         }
         break;
+
+    case TLS_ST_CW_ACK:
+        if (statem_flush(s) != 1)
+            return WORK_MORE_A;
+        break;
     }
 
     return WORK_FINISHED_CONTINUE;
@@ -1008,6 +1077,11 @@ int ossl_statem_client_construct_message(SSL_CONNECTION *s,
         *confunc = tls_construct_key_update;
         *mt = SSL3_MT_KEY_UPDATE;
         break;
+
+    case TLS_ST_CW_ACK:
+        *confunc = dtls_construct_ack;
+        *mt = DTLS13_MT_ACK;
+        break;
     }
 
     return 1;
@@ -1073,6 +1147,9 @@ size_t ossl_statem_client_max_message_size(SSL_CONNECTION *s)
 
     case TLS_ST_CR_KEY_UPDATE:
         return KEY_UPDATE_MAX_LENGTH;
+
+    case TLS_ST_CR_ACK:
+        return ACK_MAX_LENGTH;
     }
 }
 
@@ -1136,6 +1213,9 @@ MSG_PROCESS_RETURN ossl_statem_client_process_message(SSL_CONNECTION *s,
 
     case TLS_ST_CR_KEY_UPDATE:
         return tls_process_key_update(s, pkt);
+
+    case TLS_ST_CR_ACK:
+        return dtls_process_ack(s, pkt);
     }
 }
 
index 5f60b37f21b58ddb9e232e90c1a328645e70dc69..81d06eb5f71d2c2b79a97c2719c91ba4ae95dd8f 100644 (file)
@@ -65,6 +65,9 @@ static dtls_sent_msg *dtls1_sent_msg_new(size_t msg_len)
 
 void dtls1_sent_msg_free(dtls_sent_msg *msg)
 {
+    if (msg != NULL)
+        ossl_list_record_number_elem_free(&msg->rec_nums);
+
     OPENSSL_free(msg);
 }
 
@@ -144,24 +147,14 @@ int dtls1_do_write(SSL_CONNECTION *s, uint8_t recordtype)
     size_t written;
     size_t curr_mtu;
     int retry = 1;
-    size_t len, overhead, used_len, msg_len = 0;
+    size_t len, overhead, used_len;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
     SSL *ussl = SSL_CONNECTION_GET_USER_SSL(s);
     uint8_t saved_payload[DTLS1_HM_HEADER_LENGTH];
-    unsigned char *data = (unsigned char *)s->init_buf->data;
-    unsigned short msg_seq = s->d1->w_msg.msg_seq;
-    unsigned char msg_type = 0;
-
-    if (recordtype == SSL3_RT_HANDSHAKE) {
-        msg_type = *data++;
-        l3n2(data, msg_len);
-    } else if (ossl_assert(recordtype == SSL3_RT_CHANGE_CIPHER_SPEC)) {
-        msg_type = SSL3_MT_CCS;
-        msg_len = 0; /* SSL3_RT_CHANGE_CIPHER_SPEC */
-    } else {
-        /* Other record types are not supported */
-        return -1;
-    }
+    /* msg_len, msg_seq, msg_type are only used for recordtype == SSL3_RT_HANDSHAKE */
+    const size_t msg_len = s->d1->w_msg.msg_body_len;
+    const unsigned short msg_seq = s->d1->w_msg.msg_seq;
+    const unsigned char msg_type = s->d1->w_msg.msg_type;
 
     if (!dtls1_query_mtu(s))
         return -1;
@@ -383,17 +376,11 @@ int dtls_get_message(SSL_CONNECTION *s, int *mt)
 
     rec_data = (unsigned char *)s->init_buf->data;
 
-    if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
-        if (s->msg_callback) {
-            s->msg_callback(0, s->version, SSL3_RT_CHANGE_CIPHER_SPEC,
-                            rec_data, 1, SSL_CONNECTION_GET_USER_SSL(s),
-                            s->msg_callback_arg);
-        }
-        /*
-         * This isn't a real handshake message so skip the processing below.
-         */
+    /*
+     * If this isn't a real handshake message skip the processing below.
+     */
+    if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC || *mt == DTLS13_MT_ACK)
         return 1;
-    }
 
     /* reconstruct message header */
     dtls1_write_hm_header(rec_data, s->s3.tmp.message_type, s->s3.tmp.message_size,
@@ -415,10 +402,28 @@ int dtls_get_message(SSL_CONNECTION *s, int *mt)
  */
 int dtls_get_message_body(SSL_CONNECTION *s, size_t *len)
 {
-    if (s->s3.tmp.message_type == SSL3_MT_CHANGE_CIPHER_SPEC) {
-        /* Nothing to be done */
+    unsigned char *msg = (unsigned char *)s->init_buf->data;
+    size_t msg_len;
+    int recordtype;
+
+    switch (s->s3.tmp.message_type) {
+    default:
+        recordtype = SSL3_RT_HANDSHAKE;
+        msg_len = s->init_num + DTLS1_HM_HEADER_LENGTH;
+
+        break;
+    case DTLS13_MT_ACK:
+        recordtype = SSL3_RT_ACK;
+        msg_len = s->init_num;
+
+        goto end;
+    case SSL3_MT_CHANGE_CIPHER_SPEC:
+        recordtype = SSL3_RT_CHANGE_CIPHER_SPEC;
+        msg_len = 1;
+
         goto end;
     }
+
     /*
      * If receiving Finished, record MAC of prior handshake messages for
      * Finished verification.
@@ -431,12 +436,11 @@ int dtls_get_message_body(SSL_CONNECTION *s, size_t *len)
     if (!tls_common_finish_mac(s))
         return 0;
 
+end:
     if (s->msg_callback)
-        s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
-                        s->init_buf->data, s->init_num + DTLS1_HM_HEADER_LENGTH,
+        s->msg_callback(0, s->version, recordtype, msg, msg_len,
                         SSL_CONNECTION_GET_USER_SSL(s), s->msg_callback_arg);
 
- end:
     *len = s->init_num;
     return 1;
 }
@@ -486,6 +490,30 @@ static int dtls1_preprocess_fragment(SSL_CONNECTION *s,
     return 1;
 }
 
+static int add_record_to_ack_list(SSL_CONNECTION *sc)
+{
+    DTLS1_RECORD_NUMBER *recnum;
+    uint64_t epoch = sc->s3.tmp.record_epoch;
+    uint64_t sequence = sc->s3.tmp.record_seq_num;
+
+    for (recnum = ossl_list_record_number_head(&sc->d1->ack_rec_num);
+         recnum != NULL;
+         recnum = ossl_list_record_number_next(recnum)) {
+        /* Is the record number already in the list? */
+        if (recnum->epoch == epoch && recnum->seqnum == sequence)
+            return 1;
+    }
+
+    recnum = dtls1_record_number_new(epoch, sequence);
+
+    if (recnum == NULL)
+        return 0;
+
+    ossl_list_record_number_insert_tail(&sc->d1->ack_rec_num, recnum);
+
+    return 1;
+}
+
 /*
  * Returns 1 if there is a buffered fragment available, 0 if not, or -1 on a
  * fatal error.
@@ -612,13 +640,12 @@ static int dtls1_reassemble_fragment(SSL_CONNECTION *s,
     hm_fragment *frag = NULL;
     pitem *item = NULL;
     int i = -1, is_complete;
-    unsigned char seq64be[8];
     size_t frag_len = msg_hdr->frag_len;
     size_t readbytes;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
 
-    if ((msg_hdr->frag_off + frag_len) > msg_hdr->msg_len ||
-        msg_hdr->msg_len > dtls1_max_handshake_message_len(s))
+    if ((msg_hdr->frag_off + frag_len) > msg_hdr->msg_len
+            || msg_hdr->msg_len > dtls1_max_handshake_message_len(s))
         goto err;
 
     if (frag_len == 0) {
@@ -626,10 +653,7 @@ static int dtls1_reassemble_fragment(SSL_CONNECTION *s,
     }
 
     /* Try to find item in queue */
-    memset(seq64be, 0, sizeof(seq64be));
-    seq64be[6] = (unsigned char)(msg_hdr->seq >> 8);
-    seq64be[7] = (unsigned char)msg_hdr->seq;
-    item = pqueue_find(&s->d1->rcvd_messages, seq64be);
+    item = pqueue_find_u64(&s->d1->rcvd_messages, msg_hdr->seq);
 
     if (item == NULL) {
         frag = dtls1_hm_fragment_new(msg_hdr->msg_len, 1);
@@ -687,14 +711,14 @@ static int dtls1_reassemble_fragment(SSL_CONNECTION *s,
         frag->reassembly = NULL;
 
     if (item == NULL) {
-        item = pitem_new(seq64be, frag);
+        item = pitem_new_u64(msg_hdr->seq, frag);
         if (item == NULL)
             goto err;
 
         item = pqueue_insert(&s->d1->rcvd_messages, item);
         /*
          * pqueue_insert fails iff a duplicate item is inserted. However,
-         * |item| cannot be a duplicate. If it were, |pqueue_find|, above,
+         * |item| cannot be a duplicate. If it were, |pqueue_find_u64|, above,
          * would have returned it and control would never have reached this
          * branch.
          */
@@ -702,6 +726,10 @@ static int dtls1_reassemble_fragment(SSL_CONNECTION *s,
             goto err;
     }
 
+    if (dtls_msg_needs_ack(!s->server, msg_hdr->type)
+            && !add_record_to_ack_list(s))
+        goto err;
+
     return DTLS1_HM_FRAGMENT_RETRY;
 
  err:
@@ -716,7 +744,6 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s,
     int i = -1;
     hm_fragment *frag = NULL;
     pitem *item = NULL;
-    unsigned char seq64be[8];
     size_t frag_len = msg_hdr->frag_len;
     size_t readbytes;
     SSL *ssl = SSL_CONNECTION_GET_SSL(s);
@@ -725,10 +752,7 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s,
         goto err;
 
     /* Try to find item in queue, to prevent duplicate entries */
-    memset(seq64be, 0, sizeof(seq64be));
-    seq64be[6] = (unsigned char)(msg_hdr->seq >> 8);
-    seq64be[7] = (unsigned char)msg_hdr->seq;
-    item = pqueue_find(&s->d1->rcvd_messages, seq64be);
+    item = pqueue_find_u64(&s->d1->rcvd_messages, msg_hdr->seq);
 
     /*
      * If we already have an entry and this one is a fragment, don't discard
@@ -782,10 +806,14 @@ static int dtls1_process_out_of_seq_message(SSL_CONNECTION *s,
                 goto err;
         }
 
-        item = pitem_new(seq64be, frag);
+        item = pitem_new_u64(msg_hdr->seq, frag);
         if (item == NULL)
             goto err;
 
+        if (dtls_msg_needs_ack(!s->server, msg_hdr->type)
+                && !add_record_to_ack_list(s))
+            goto err;
+
         item = pqueue_insert(&s->d1->rcvd_messages, item);
         /*
          * pqueue_insert fails iff a duplicate item is inserted. However,
@@ -885,6 +913,33 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
         *len = readbytes - 1;
         return 1;
     }
+    if (recvd_type == SSL3_RT_ACK) {
+        if (readbytes == DTLS1_HM_HEADER_LENGTH) {
+            const size_t first_readbytes = readbytes;
+
+            p += DTLS1_HM_HEADER_LENGTH;
+
+            i = ssl->method->ssl_read_bytes(ssl, SSL3_RT_HANDSHAKE, NULL, p,
+                                            s->init_num - DTLS1_HM_HEADER_LENGTH,
+                                            0, &readbytes);
+            readbytes += first_readbytes;
+            /*
+             * This shouldn't ever fail due to NBIO because we already checked
+             * that we have enough data in the record
+             */
+            if (i <= 0) {
+                s->rwstate = SSL_READING;
+                *len = 0;
+                return 0;
+            }
+        }
+        s->init_num = readbytes;
+        s->init_msg = s->init_buf->data;
+        s->s3.tmp.message_type = DTLS13_MT_ACK;
+        s->s3.tmp.message_size = readbytes;
+        *len = readbytes;
+        return 1;
+    }
 
     /* Handshake fails if message header is incomplete */
     if (readbytes != DTLS1_HM_HEADER_LENGTH) {
@@ -1002,6 +1057,12 @@ static int dtls_get_reassembled_message(SSL_CONNECTION *s, int *errtype,
         s->d1->next_handshake_write_seq = 0;
     }
 
+    if (dtls_msg_needs_ack(!s->server, msg_hdr.type)
+        && !add_record_to_ack_list(s)) {
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        goto f_err;
+    }
+
     /*
      * Note that s->init_num is *not* used as current offset in
      * s->init_buf->data, but as a counter summing up fragments' lengths: as
@@ -1038,6 +1099,105 @@ CON_FUNC_RETURN dtls_construct_change_cipher_spec(SSL_CONNECTION *s,
     return CON_FUNC_SUCCESS;
 }
 
+CON_FUNC_RETURN dtls_construct_ack(SSL_CONNECTION *s, WPACKET *pkt)
+{
+    DTLS1_RECORD_NUMBER *recnum;
+    DTLS1_RECORD_NUMBER *recnumnext = ossl_list_record_number_head(&s->d1->ack_rec_num);
+
+    if (!WPACKET_start_sub_packet_u16(pkt)) {
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return CON_FUNC_ERROR;
+    }
+
+    while ((recnum = recnumnext) != NULL) {
+        /*
+         * rfc9147: section 4.
+         *
+         * Record numbers are encoded as
+         *      struct {
+         *           uint64 epoch;
+         *           uint64 sequence_number;
+         *      } RecordNumber;
+         */
+
+        recnumnext = ossl_list_record_number_next(recnum);
+
+        if (recnum->epoch <= dtls1_get_epoch(s, SSL3_CC_WRITE)) {
+            /*
+             * rfc9147:
+             * During the handshake, ACK records MUST be sent with an epoch which
+             * is equal to or higher than the record which is being acknowledged
+             */
+            if (!WPACKET_put_bytes_u64(pkt, recnum->epoch)
+                || !WPACKET_put_bytes_u64(pkt, recnum->seqnum)) {
+                SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+                return CON_FUNC_ERROR;
+            }
+
+            ossl_list_record_number_remove(&s->d1->ack_rec_num, recnum);
+            OPENSSL_free(recnum);
+        }
+    }
+
+    if (!WPACKET_close(pkt)) {
+        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
+        return CON_FUNC_ERROR;
+    }
+
+    return CON_FUNC_SUCCESS;
+}
+
+MSG_PROCESS_RETURN dtls_process_ack(SSL_CONNECTION *s, PACKET *pkt)
+{
+    PACKET record_numbers;
+
+    if (!PACKET_get_length_prefixed_2(pkt, &record_numbers)) {
+        SSLfatal(s, SSL_AD_ILLEGAL_PARAMETER, SSL_R_LENGTH_TOO_LONG);
+        return MSG_PROCESS_ERROR;
+    }
+
+    while (PACKET_remaining(&record_numbers) > 0) {
+        /*
+         * rfc9147: section 4.
+         *
+         * Record numbers are encoded as
+         *      struct {
+         *           uint64 epoch;
+         *           uint64 sequence_number;
+         *      } RecordNumber;
+         */
+        pitem *item;
+        piterator iter;
+        uint64_t epoch;
+        uint64_t sequence_number;
+
+        if (!PACKET_get_net_8(&record_numbers, &epoch)
+                || !PACKET_get_net_8(&record_numbers, &sequence_number)) {
+            SSLfatal(s, SSL_AD_DECODE_ERROR, SSL_R_LENGTH_TOO_SHORT);
+            return MSG_PROCESS_ERROR;
+        }
+
+        iter = pqueue_iterator(&s->d1->sent_messages);
+
+        while ((item = pqueue_next(&iter)) != NULL) {
+            dtls_sent_msg *msg = (dtls_sent_msg *)item->data;
+            DTLS1_RECORD_NUMBER *recnum;
+            DTLS1_RECORD_NUMBER *recnum_next = ossl_list_record_number_head(&msg->rec_nums);
+
+            while ((recnum = recnum_next) != NULL) {
+                recnum_next = ossl_list_record_number_next(recnum_next);
+
+                if (recnum->epoch == epoch && recnum->seqnum == sequence_number) {
+                    ossl_list_record_number_remove(&msg->rec_nums, recnum);
+                    OPENSSL_free(recnum);
+                }
+            }
+        }
+    }
+
+    return MSG_PROCESS_FINISHED_READING;
+}
+
 #ifndef OPENSSL_NO_SCTP
 /*
  * Wait for a dry event. Should only be called at a point in the handshake
@@ -1104,7 +1264,8 @@ int dtls1_read_failed(SSL_CONNECTION *s, int code)
     return dtls1_handle_timeout(s);
 }
 
-int dtls1_get_queue_priority(unsigned short seq, int record_type)
+void dtls1_get_queue_priority(unsigned char *prio64be, unsigned short seq,
+                              int record_type)
 {
     /*
      * The index of the retransmission queue actually is the message sequence
@@ -1117,23 +1278,27 @@ int dtls1_get_queue_priority(unsigned short seq, int record_type)
      * priority queues) and fits in the unsigned short variable.
      */
     int lsb = (record_type == SSL3_RT_CHANGE_CIPHER_SPEC);
+    const uint16_t prio = seq * 2 - lsb;
 
-    return seq * 2 - lsb;
+    memset(prio64be, 0, 8);
+    prio64be[6] = (unsigned char)(prio >> 8);
+    prio64be[7] = (unsigned char)(prio);
 }
 
 int dtls1_retransmit_sent_messages(SSL_CONNECTION *s)
 {
     piterator iter = pqueue_iterator(&s->d1->sent_messages);
     pitem *item;
-    int found = 0;
 
     for (item = pqueue_next(&iter); item != NULL; item = pqueue_next(&iter)) {
-        int prio;
         dtls_sent_msg *sent_msg = (dtls_sent_msg *)item->data;
 
-        prio = dtls1_get_queue_priority(sent_msg->msg_info.msg_seq, sent_msg->record_type);
+        if (SSL_CONNECTION_IS_DTLS13(s)
+                && ossl_list_record_number_is_empty(&sent_msg->rec_nums))
+            /* rfc9147: Implementations must not retransmit acknowledged msgs */
+            continue;
 
-        if (dtls1_retransmit_message(s, (unsigned short)prio, &found) <= 0)
+        if (dtls1_retransmit_message(s, sent_msg) <= 0)
             return -1;
     }
 
@@ -1146,7 +1311,6 @@ int dtls1_buffer_sent_message(SSL_CONNECTION *s, int record_type)
     dtls_sent_msg *sent_msg;
     unsigned char seq64be[8];
     size_t headerlen;
-    int prio;
 
     /*
      * this function is called immediately after a message has been
@@ -1172,19 +1336,13 @@ int dtls1_buffer_sent_message(SSL_CONNECTION *s, int record_type)
         return 0;
     }
 
-    sent_msg->msg_info.msg_body_len = s->d1->w_msg.msg_body_len;
-    sent_msg->msg_info.msg_seq = s->d1->w_msg.msg_seq;
-    sent_msg->msg_info.msg_type = s->d1->w_msg.msg_type;
-    sent_msg->record_type = record_type;
+    memcpy(&sent_msg->msg_info, &s->d1->w_msg, sizeof(s->d1->w_msg));
 
     /* save current state */
     sent_msg->saved_retransmit_state.wrlmethod = s->rlayer.wrlmethod;
     sent_msg->saved_retransmit_state.wrl = s->rlayer.wrl;
 
-    prio = dtls1_get_queue_priority(sent_msg->msg_info.msg_seq, sent_msg->record_type);
-    memset(seq64be, 0, sizeof(seq64be));
-    seq64be[6] = (unsigned char)(prio >> 8);
-    seq64be[7] = (unsigned char)prio;
+    dtls1_get_queue_priority(seq64be, sent_msg->msg_info.msg_seq, sent_msg->msg_info.record_type);
 
     item = pitem_new(seq64be, sent_msg);
     if (item == NULL) {
@@ -1196,43 +1354,25 @@ int dtls1_buffer_sent_message(SSL_CONNECTION *s, int record_type)
     return 1;
 }
 
-int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found)
+int dtls1_retransmit_message(SSL_CONNECTION *s, dtls_sent_msg *sent_msg)
 {
     int ret;
-    /* XDTLS: for now assuming that read/writes are blocking */
-    pitem *item;
-    dtls_sent_msg *sent_msg;
     unsigned long header_length;
-    unsigned char seq64be[8];
     struct dtls1_retransmit_state saved_state;
 
-    /* XDTLS:  the requested message ought to be found, otherwise error */
-    memset(seq64be, 0, sizeof(seq64be));
-    seq64be[6] = (unsigned char)(seq >> 8);
-    seq64be[7] = (unsigned char)seq;
-
-    item = pqueue_find(&s->d1->sent_messages, seq64be);
-    if (item == NULL) {
-        SSLfatal(s, SSL_AD_INTERNAL_ERROR, ERR_R_INTERNAL_ERROR);
-        *found = 0;
-        return 0;
-    }
-
-    *found = 1;
-    sent_msg = (dtls_sent_msg *)item->data;
-
-    if (sent_msg->record_type == SSL3_RT_CHANGE_CIPHER_SPEC)
+    if (sent_msg->msg_info.record_type == SSL3_RT_CHANGE_CIPHER_SPEC)
         header_length = DTLS1_CCS_HEADER_LENGTH;
     else
         header_length = DTLS1_HM_HEADER_LENGTH;
 
+    /* Clear the record number list to be acked for retransmitted messages */
+    ossl_list_record_number_elem_free(&sent_msg->rec_nums);
+
     memcpy(s->init_buf->data, sent_msg->msg_buf,
            sent_msg->msg_info.msg_body_len + header_length);
     s->init_num = sent_msg->msg_info.msg_body_len + header_length;
 
-    s->d1->w_msg.msg_type = sent_msg->msg_info.msg_type;
-    s->d1->w_msg.msg_body_len = sent_msg->msg_info.msg_body_len;
-    s->d1->w_msg.msg_seq = sent_msg->msg_info.msg_seq;
+    memcpy(&s->d1->w_msg, &sent_msg->msg_info, sizeof(sent_msg->msg_info));
 
     /* save current state */
     saved_state.wrlmethod = s->rlayer.wrlmethod;
@@ -1250,7 +1390,7 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found)
      */
     s->rlayer.wrlmethod->set1_bio(s->rlayer.wrl, s->wbio);
 
-    ret = dtls1_do_write(s, sent_msg->record_type);
+    ret = dtls1_do_write(s, sent_msg->msg_info.record_type);
 
     /* restore current state */
     s->rlayer.wrlmethod = saved_state.wrlmethod;
@@ -1264,24 +1404,25 @@ int dtls1_retransmit_message(SSL_CONNECTION *s, unsigned short seq, int *found)
 
 int dtls1_set_handshake_header(SSL_CONNECTION *s, WPACKET *pkt, int htype)
 {
-    if (htype == SSL3_MT_CHANGE_CIPHER_SPEC) {
-        s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
+    s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
+    s->d1->w_msg.msg_seq = s->d1->handshake_write_seq;
+    s->d1->w_msg.msg_body_len = 0;
 
+    if (htype == SSL3_MT_CHANGE_CIPHER_SPEC) {
+        s->d1->w_msg.record_type = SSL3_RT_CHANGE_CIPHER_SPEC;
         s->d1->w_msg.msg_type = SSL3_MT_CCS;
-        s->d1->w_msg.msg_body_len = 0;
-        s->d1->w_msg.msg_seq = s->d1->handshake_write_seq;
 
         if (!WPACKET_put_bytes_u8(pkt, SSL3_MT_CCS))
             return 0;
+    } else if (htype == DTLS13_MT_ACK) {
+        s->d1->w_msg.record_type = SSL3_RT_ACK;
+        s->d1->w_msg.msg_type = 0;
     } else {
         size_t subpacket_offset = DTLS1_HM_HEADER_LENGTH - SSL3_HM_HEADER_LENGTH;
 
-        s->d1->handshake_write_seq = s->d1->next_handshake_write_seq;
         s->d1->next_handshake_write_seq++;
-
+        s->d1->w_msg.record_type = SSL3_RT_HANDSHAKE;
         s->d1->w_msg.msg_type = htype;
-        s->d1->w_msg.msg_body_len = 0;
-        s->d1->w_msg.msg_seq = s->d1->handshake_write_seq;
 
         /* Set the content type and 3 bytes for the message len */
         if (!WPACKET_put_bytes_u8(pkt, htype)
@@ -1299,26 +1440,22 @@ int dtls1_set_handshake_header(SSL_CONNECTION *s, WPACKET *pkt, int htype)
 int dtls1_close_construct_packet(SSL_CONNECTION *s, WPACKET *pkt, int htype)
 {
     size_t msglen;
-    int record_type;
-
-    /* Convert from possible dummy message type */
-    record_type = (htype == SSL3_MT_CHANGE_CIPHER_SPEC) ? SSL3_RT_CHANGE_CIPHER_SPEC
-                                                        : SSL3_RT_HANDSHAKE;
 
-    if ((htype != SSL3_MT_CHANGE_CIPHER_SPEC && !WPACKET_close(pkt))
+    if ((s->d1->w_msg.record_type == SSL3_RT_HANDSHAKE && !WPACKET_close(pkt))
             || !WPACKET_get_length(pkt, &msglen)
             || msglen > INT_MAX)
         return 0;
 
-    if (htype != SSL3_MT_CHANGE_CIPHER_SPEC)
+    if (s->d1->w_msg.record_type == SSL3_RT_HANDSHAKE)
         s->d1->w_msg.msg_body_len = msglen - DTLS1_HM_HEADER_LENGTH;
 
     s->init_num = msglen;
     s->init_off = 0;
 
-    if (htype != DTLS1_MT_HELLO_VERIFY_REQUEST) {
+    if (htype != DTLS1_MT_HELLO_VERIFY_REQUEST
+            && s->d1->w_msg.record_type != SSL3_RT_ACK) {
         /* Buffer the message to handle re-xmits */
-        if (!dtls1_buffer_sent_message(s, record_type))
+        if (!dtls1_buffer_sent_message(s, s->d1->w_msg.record_type))
             return 0;
     }
 
index 843de3d9354033dfaaaccf17315d62c2fab49db6..71187f55deebf1a0012e88e0792171e4e39e90c0 100644 (file)
@@ -27,6 +27,7 @@
 #define SERVER_HELLO_DONE_MAX_LENGTH    0
 #define KEY_UPDATE_MAX_LENGTH           1
 #define CCS_MAX_LENGTH                  1
+#define ACK_MAX_LENGTH                  65538
 
 /* Max ServerHello size permitted by RFC 8446 */
 #define SERVER_HELLO_MAX_LENGTH         65607
@@ -92,6 +93,7 @@ MSG_PROCESS_RETURN ossl_statem_client_process_message(SSL_CONNECTION *s,
                                                       PACKET *pkt);
 WORK_STATE ossl_statem_client_post_process_message(SSL_CONNECTION *s,
                                                    WORK_STATE wst);
+int ossl_statem_dtls_client_use_timer(SSL_CONNECTION *s);
 
 /*
  * TLS/DTLS server state machine functions
@@ -107,6 +109,7 @@ MSG_PROCESS_RETURN ossl_statem_server_process_message(SSL_CONNECTION *s,
                                                       PACKET *pkt);
 WORK_STATE ossl_statem_server_post_process_message(SSL_CONNECTION *s,
                                                    WORK_STATE wst);
+int ossl_statem_dtls_server_use_timer(SSL_CONNECTION *s);
 
 /* Functions for getting new message data */
 __owur int tls_get_message_header(SSL_CONNECTION *s, int *mt);
@@ -124,6 +127,7 @@ __owur CON_FUNC_RETURN  tls_construct_change_cipher_spec(SSL_CONNECTION *s,
                                                          WPACKET *pkt);
 __owur CON_FUNC_RETURN dtls_construct_change_cipher_spec(SSL_CONNECTION *s,
                                                          WPACKET *pkt);
+__owur CON_FUNC_RETURN dtls_construct_ack(SSL_CONNECTION *s, WPACKET *pkt);
 
 __owur CON_FUNC_RETURN tls_construct_finished(SSL_CONNECTION *s, WPACKET *pkt);
 __owur CON_FUNC_RETURN tls_construct_key_update(SSL_CONNECTION *s, WPACKET *pkt);
@@ -195,6 +199,7 @@ __owur CON_FUNC_RETURN tls_construct_next_proto(SSL_CONNECTION *s, WPACKET *pkt)
 #endif
 __owur MSG_PROCESS_RETURN tls_process_hello_req(SSL_CONNECTION *s, PACKET *pkt);
 __owur MSG_PROCESS_RETURN dtls_process_hello_verify(SSL_CONNECTION *s, PACKET *pkt);
+__owur MSG_PROCESS_RETURN dtls_process_ack(SSL_CONNECTION *s, PACKET *pkt);
 __owur CON_FUNC_RETURN tls_construct_end_of_early_data(SSL_CONNECTION *s,
                                                        WPACKET *pkt);
 
index 234ff89893fd0e99361e937c1ebfd0b17239ff02..56bcab9682a22eac967874e805d85bd2e444b573 100644 (file)
@@ -140,6 +140,15 @@ static int ossl_statem_server13_read_transition(SSL_CONNECTION *s, int mt)
         }
         break;
 
+    case TLS_ST_SR_ACK:
+    case TLS_ST_SW_KEY_UPDATE:
+    case TLS_ST_SW_SESSION_TICKET:
+        if (mt == DTLS13_MT_ACK) {
+            st->hand_state = TLS_ST_SR_ACK;
+            return 1;
+        }
+        break;
+
     case TLS_ST_OK:
         /*
          * Its never ok to start processing handshake messages in the middle of
@@ -166,6 +175,11 @@ static int ossl_statem_server13_read_transition(SSL_CONNECTION *s, int mt)
             st->hand_state = TLS_ST_SR_KEY_UPDATE;
             return 1;
         }
+
+        if (mt == DTLS13_MT_ACK) {
+            st->hand_state = TLS_ST_SR_ACK;
+            return 1;
+        }
         break;
     }
 
@@ -469,6 +483,7 @@ static int do_compressed_cert(SSL_CONNECTION *sc)
  */
 static WRITE_TRAN ossl_statem_server13_write_transition(SSL_CONNECTION *s)
 {
+    OSSL_HANDSHAKE_STATE next_state;
     OSSL_STATEM *st = &s->statem;
 
     /*
@@ -573,16 +588,35 @@ static WRITE_TRAN ossl_statem_server13_write_transition(SSL_CONNECTION *s)
              * If we're not going to renew the ticket then we just finish the
              * handshake at this point.
              */
-            st->hand_state = TLS_ST_OK;
+            if (SSL_CONNECTION_IS_DTLS13(s)) {
+                st->deferred_ack_state = TLS_ST_OK;
+                st->hand_state = TLS_ST_SW_ACK;
+            } else {
+                st->hand_state = TLS_ST_OK;
+            }
+
             return WRITE_TRAN_CONTINUE;
         }
         if (s->num_tickets > s->sent_tickets)
-            st->hand_state = TLS_ST_SW_SESSION_TICKET;
+            next_state = TLS_ST_SW_SESSION_TICKET;
         else
-            st->hand_state = TLS_ST_OK;
+            next_state = TLS_ST_OK;
+
+        if (SSL_CONNECTION_IS_DTLS13(s)) {
+            st->deferred_ack_state = next_state;
+            st->hand_state = TLS_ST_SW_ACK;
+        } else {
+            st->hand_state = next_state;
+        }
         return WRITE_TRAN_CONTINUE;
 
     case TLS_ST_SR_KEY_UPDATE:
+        if (SSL_CONNECTION_IS_DTLS13(s)) {
+            st->deferred_ack_state = TLS_ST_OK;
+            st->hand_state = TLS_ST_SW_ACK;
+            return WRITE_TRAN_CONTINUE;
+        }
+        /* Fall through */
     case TLS_ST_SW_KEY_UPDATE:
         st->hand_state = TLS_ST_OK;
         return WRITE_TRAN_CONTINUE;
@@ -592,13 +626,26 @@ static WRITE_TRAN ossl_statem_server13_write_transition(SSL_CONNECTION *s)
          * Following an initial handshake we send the number of tickets we have
          * been configured for.
          */
-        if (!SSL_IS_FIRST_HANDSHAKE(s) && s->ext.extra_tickets_expected > 0) {
-            return WRITE_TRAN_CONTINUE;
-        } else if (s->hit || s->num_tickets <= s->sent_tickets) {
+        if ((SSL_IS_FIRST_HANDSHAKE(s) || s->ext.extra_tickets_expected <= 0)
+                && (s->hit || s->num_tickets <= s->sent_tickets)) {
             /* We've written enough tickets out. */
             st->hand_state = TLS_ST_OK;
         }
         return WRITE_TRAN_CONTINUE;
+
+    case TLS_ST_SR_ACK:
+        if (SSL_CONNECTION_IS_DTLS13(s)
+            && dtls_any_sent_messages_are_missing_acknowledge(s)) {
+            /* We wait for ACK */
+            return WRITE_TRAN_FINISHED;
+        }
+        st->hand_state = TLS_ST_OK;
+        return WRITE_TRAN_CONTINUE;
+
+    case TLS_ST_SW_ACK:
+        st->hand_state = st->deferred_ack_state;
+
+        return WRITE_TRAN_CONTINUE;
     }
 }
 
@@ -742,6 +789,70 @@ WRITE_TRAN ossl_statem_server_write_transition(SSL_CONNECTION *s)
     }
 }
 
+static int ossl_statem_dtls_server13_use_timer(SSL_CONNECTION *s)
+{
+    OSSL_STATEM *st = &s->statem;
+
+    switch (st->hand_state) {
+    default:
+        break;
+
+    case TLS_ST_SW_ACK:
+        /* Fall through */
+
+    case TLS_ST_OK:
+        return 0;
+    }
+
+    return 1;
+}
+
+int ossl_statem_dtls_server_use_timer(SSL_CONNECTION *s)
+{
+    OSSL_STATEM *st = &s->statem;
+
+    if (SSL_CONNECTION_IS_DTLS13(s))
+        return ossl_statem_dtls_server13_use_timer(s);
+
+    switch (st->hand_state) {
+    default:
+        break;
+
+    case TLS_ST_SW_SESSION_TICKET:
+        /*
+         * We're into the last flight. We don't retransmit the last flight
+         * unless we need to, so we don't use the timer
+         */
+        st->use_timer = 0;
+        break;
+
+    case TLS_ST_SW_CHANGE:
+        /*
+         * We're into the last flight. We don't retransmit the last flight
+         * unless we need to, so we don't use the timer. This might have
+         * already been set to 0 if we sent a NewSessionTicket message,
+         * but we'll set it again here in case we didn't.
+         */
+        st->use_timer = 0;
+        break;
+
+    case DTLS_ST_SW_HELLO_VERIFY_REQUEST:
+        /* We don't buffer this message so don't use the timer */
+        st->use_timer = 0;
+        break;
+
+    case TLS_ST_SW_SRVR_HELLO:
+        /*
+        * Messages we write from now on should be buffered and
+        * retransmitted if necessary, so we need to use the timer now
+        */
+        st->use_timer = 1;
+        break;
+    }
+
+    return st->use_timer;
+}
+
 /*
  * Perform any pre work that needs to be done prior to sending a message from
  * the server to the client.
@@ -757,28 +868,12 @@ WORK_STATE ossl_statem_server_pre_work(SSL_CONNECTION *s, WORK_STATE wst)
         break;
 
     case TLS_ST_SW_HELLO_REQ:
-        s->shutdown = 0;
-        if (SSL_CONNECTION_IS_DTLS(s))
-            dtls1_clear_sent_buffer(s);
-        break;
-
+        /* fall-through */
     case DTLS_ST_SW_HELLO_VERIFY_REQUEST:
         s->shutdown = 0;
-        if (SSL_CONNECTION_IS_DTLS(s)) {
-            dtls1_clear_sent_buffer(s);
-            /* We don't buffer this message so don't use the timer */
-            st->use_timer = 0;
-        }
-        break;
+        if (SSL_CONNECTION_IS_DTLS(s))
+            dtls1_clear_sent_buffer(s, 0);
 
-    case TLS_ST_SW_SRVR_HELLO:
-        if (SSL_CONNECTION_IS_DTLS(s)) {
-            /*
-             * Messages we write from now on should be buffered and
-             * retransmitted if necessary, so we need to use the timer now
-             */
-            st->use_timer = 1;
-        }
         break;
 
     case TLS_ST_SW_SRVR_DONE:
@@ -802,13 +897,6 @@ WORK_STATE ossl_statem_server_pre_work(SSL_CONNECTION *s, WORK_STATE wst)
              */
             return tls_finish_handshake(s, wst, 0, 0);
         }
-        if (SSL_CONNECTION_IS_DTLS(s)) {
-            /*
-             * We're into the last flight. We don't retransmit the last flight
-             * unless we need to, so we don't use the timer
-             */
-            st->use_timer = 0;
-        }
         break;
 
     case TLS_ST_SW_CHANGE:
@@ -825,15 +913,6 @@ WORK_STATE ossl_statem_server_pre_work(SSL_CONNECTION *s, WORK_STATE wst)
             /* SSLfatal() already called */
             return WORK_ERROR;
         }
-        if (SSL_CONNECTION_IS_DTLS(s)) {
-            /*
-             * We're into the last flight. We don't retransmit the last flight
-             * unless we need to, so we don't use the timer. This might have
-             * already been set to 0 if we sent a NewSessionTicket message,
-             * but we'll set it again here in case we didn't.
-             */
-            st->use_timer = 0;
-        }
         return WORK_FINISHED_CONTINUE;
 
     case TLS_ST_EARLY_DATA:
@@ -1102,6 +1181,11 @@ WORK_STATE ossl_statem_server_post_work(SSL_CONNECTION *s, WORK_STATE wst)
             return WORK_MORE_A;
         }
         break;
+
+    case TLS_ST_SW_ACK:
+        if (statem_flush(s) != 1)
+            return WORK_MORE_A;
+        break;
     }
 
     return WORK_FINISHED_CONTINUE;
@@ -1212,6 +1296,11 @@ int ossl_statem_server_construct_message(SSL_CONNECTION *s,
         *confunc = tls_construct_key_update;
         *mt = SSL3_MT_KEY_UPDATE;
         break;
+
+    case TLS_ST_SW_ACK:
+        *confunc = dtls_construct_ack;
+        *mt = DTLS13_MT_ACK;
+        break;
     }
 
     return 1;
@@ -1279,6 +1368,9 @@ size_t ossl_statem_server_max_message_size(SSL_CONNECTION *s)
 
     case TLS_ST_SR_KEY_UPDATE:
         return KEY_UPDATE_MAX_LENGTH;
+
+    case TLS_ST_SR_ACK:
+        return ACK_MAX_LENGTH;
     }
 }
 
@@ -1330,6 +1422,8 @@ MSG_PROCESS_RETURN ossl_statem_server_process_message(SSL_CONNECTION *s,
     case TLS_ST_SR_KEY_UPDATE:
         return tls_process_key_update(s, pkt);
 
+    case TLS_ST_SR_ACK:
+        return dtls_process_ack(s, pkt);
     }
 }
 
index 778d7c0fce2b51b1b86487d4ac6a66a3b0da5ec5..d08f76d993c8fb59139e19d49e244e01311b56ed 100644 (file)
@@ -813,12 +813,6 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
                : OSSL_RECORD_PROTECTION_LEVEL_APPLICATION);
 
     if (SSL_CONNECTION_IS_DTLS(s)) {
-        /* We have moved to the next flight lets clear out old messages */
-        if (direction == OSSL_RECORD_DIRECTION_READ)
-            dtls1_clear_received_buffer(s);
-        else
-            dtls1_clear_sent_buffer(s);
-
         dtls1_increment_epoch(s, which);
 
         if (level == OSSL_RECORD_PROTECTION_LEVEL_HANDSHAKE
@@ -829,6 +823,14 @@ int tls13_change_cipher_state(SSL_CONNECTION *s, int which)
              */
             dtls1_increment_epoch(s, which);
         }
+
+        /* We have moved to the next flight lets clear out old messages */
+        if (direction == OSSL_RECORD_DIRECTION_READ) {
+            dtls1_clear_received_buffer(s);
+            dtls1_acknowledge_sent_buffer(s, dtls1_get_epoch(s, which));
+        }
+
+        dtls1_clear_sent_buffer(s, 1);
     }
 
     if (!ssl_set_new_record_layer(s, s->version, direction, level, secret,
diff --git a/test/recipes/70-test_dtls13ack.t b/test/recipes/70-test_dtls13ack.t
new file mode 100644 (file)
index 0000000..57a8770
--- /dev/null
@@ -0,0 +1,228 @@
+#! /usr/bin/env perl
+# Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+#
+# Licensed under the Apache License 2.0 (the "License").  You may not use
+# this file except in compliance with the License.  You can obtain a copy
+# in the file LICENSE in the source distribution or at
+# https://www.openssl.org/source/license.html
+
+use strict;
+use feature 'state';
+
+use OpenSSL::Test qw/:DEFAULT cmdstr srctop_file bldtop_dir/;
+use OpenSSL::Test::Utils;
+use TLSProxy::Proxy;
+use TLSProxy::Message;
+use Cwd qw(abs_path);
+
+my $test_name = "test_dtls13ack";
+setup($test_name);
+
+# TODO(DTLSv1.3): The test currently does not work as changes to Â´engines/e_ossltest.c´
+# in #25119 should be ported to the ossltest provider.
+plan skip_all => "This doesn't work properly currently";
+plan skip_all => "TLSProxy isn't usable on $^O"
+    if $^O =~ /^(VMS)$/ || $^O =~ /^(MSWin32)$/;
+
+plan skip_all => "$test_name needs the module feature enabled"
+    if disabled("module");
+
+plan skip_all => "$test_name needs the sock feature enabled"
+    if disabled("sock");
+
+plan skip_all => "DTLSProxy does not support partial messages"
+    if disabled("ec");
+
+plan skip_all => "$test_name needs DTLSv1.3 enabled"
+    if disabled("dtls1_3");
+
+$ENV{OPENSSL_MODULES} = abs_path(bldtop_dir("test"));
+
+my $proxy = TLSProxy::Proxy->new_dtls(
+    undef,
+    cmdstr(app(["openssl"]), display => 1),
+    srctop_file("apps", "server.pem"),
+    (!$ENV{HARNESS_ACTIVE} || $ENV{HARNESS_VERBOSE})
+);
+
+my $testcount = 3;
+
+plan tests => $testcount;
+
+#Test 1: Check that records are acked during an uninterrupted handshake
+$proxy->serverflags("-min_protocol DTLSv1.3 -max_protocol DTLSv1.3");
+$proxy->clientflags("-min_protocol DTLSv1.3 -max_protocol DTLSv1.3");
+TLSProxy::Message->successondata(1);
+skip "TLSProxy could not start", $testcount if !$proxy->start();
+
+my @expected = get_expected_ack_record_numbers();
+my @actual = get_actual_acked_record_numbers();
+my @missing = record_numbers_missing(\@expected, \@actual);
+my $expected_count = @expected;
+my $missing_count = @missing;
+
+ok($missing_count == 0 && $expected_count == 1,
+    "Check that all record numbers are acked");
+
+# Test 2: Check that records that are missing are not acked during a handshake
+$proxy->clear();
+my $found_first_client_finish_msg = 0;
+$proxy->serverflags("-min_protocol DTLSv1.3 -max_protocol DTLSv1.3");
+$proxy->clientflags("-min_protocol DTLSv1.3 -max_protocol DTLSv1.3");
+$proxy->filter(\&drop_first_client_finish_filter);
+TLSProxy::Message->successondata(1);
+$proxy->start();
+
+@expected = get_expected_ack_record_numbers();
+@actual = get_actual_acked_record_numbers();
+@missing = record_numbers_missing(\@expected, \@actual);
+$expected_count = @expected;
+$missing_count = @missing;
+
+ok($missing_count == 1 && $expected_count == 2,
+   "Check that all record numbers except one are acked");
+
+SKIP: {
+    skip "TODO(DTLSv1.3): This test fails because the client does not properly
+          handle when the last flight is dropped when it includes a
+          CompressedCertificate.", 1
+        if !disabled("zlib") || !disabled("zstd") || !disabled("brotli");
+    # Test 3: Check that client cert and verify messages are also acked
+    $proxy->clear();
+    $proxy->filter(undef);
+    $found_first_client_finish_msg = 0;
+    $proxy->serverflags("-min_protocol DTLSv1.3 -max_protocol DTLSv1.3 -Verify 1");
+    $proxy->clientflags("-mtu 2000 -min_protocol DTLSv1.3 -max_protocol DTLSv1.3"
+                        ." -cert ".srctop_file("apps", "server.pem"));
+    TLSProxy::Message->successondata(1);
+    $proxy->start();
+
+    @expected = get_expected_ack_record_numbers();
+    @actual = get_actual_acked_record_numbers();
+    @missing = record_numbers_missing(\@expected, \@actual);
+    $expected_count = @expected;
+    $missing_count = @missing;
+
+    ok($missing_count == 0 && $expected_count == 3,
+        "Check that all record numbers are acked");
+}
+
+sub get_expected_ack_record_numbers
+{
+    my $records = $proxy->record_list;
+    my @record_numbers = ();
+
+    foreach (@{$records}) {
+        my $record = $_;
+
+        if ($record->content_type == TLSProxy::Record::RT_HANDSHAKE
+                && $record->{sent}) {
+            my $epoch = $record->epoch;
+            my $seqnum = $record->seq;
+            my $serverissender = $record->serverissender;
+            my $recnum = TLSProxy::RecordNumber->new($epoch, $seqnum);
+
+            my @messages = TLSProxy::Message->get_messages($record);
+
+            my $record_should_be_acked = 0;
+
+            foreach (@messages) {
+                my $message = $_;
+                if (!$serverissender
+                    && ($message->mt == TLSProxy::Message::MT_FINISHED
+                        || $message->mt == TLSProxy::Message::MT_CERTIFICATE
+                        || $message->mt == TLSProxy::Message::MT_COMPRESSED_CERTIFICATE
+                        || $message->mt == TLSProxy::Message::MT_CERTIFICATE_VERIFY)
+                # TODO(DTLSv1.3): The ACK of the following messages are never processed
+                # by the proxy because s_client is closed too early send it:
+                #        || $message->mt == TLSProxy::Message::MT_KEY_UPDATE
+                #        || $message->mt == TLSProxy::Message::MT_NEW_SESSION_TICKET
+                ) {
+                    $record_should_be_acked = 1;
+                }
+            }
+
+            push(@record_numbers, $recnum) if ($record_should_be_acked == 1);
+        }
+    }
+
+    return @record_numbers;
+}
+
+sub get_actual_acked_record_numbers
+{
+    my @records = @{$proxy->record_list};
+    my @record_numbers = ();
+
+    foreach (@records) {
+        my $record = $_;
+
+        if ($record->content_type == TLSProxy::Record::RT_ACK) {
+            my $recnum_count = unpack('n', $record->decrypt_data) / 16;
+            my $ptr = 2;
+
+            for (my $idx = 0; $idx < $recnum_count; $idx++) {
+                my $epoch_lo;
+                my $epoch_hi;
+                my $msgseq_lo;
+                my $msgseq_hi;
+
+                ($epoch_hi, $epoch_lo, $msgseq_hi, $msgseq_lo)
+                    = unpack('NNNN', substr($record->decrypt_data, $ptr));
+                $ptr = $ptr + 16;
+
+                my $epoch = ($epoch_hi << 32) | $epoch_lo;
+                my $msgseq = ($msgseq_hi << 32) | $msgseq_lo;
+                my $recnum = TLSProxy::RecordNumber->new($epoch, $msgseq);
+
+                push(@record_numbers, $recnum);
+            }
+        }
+    }
+    return @record_numbers;
+}
+
+sub record_numbers_missing
+{
+    my @expected_record_numbers = @{$_[0]};
+    my @actual_record_numbers = @{$_[1]};
+    my @missing_record_numbers = ();
+
+    foreach (@expected_record_numbers)
+    {
+        my $found = 0;
+        my $expected = $_;
+
+        foreach (@actual_record_numbers) {
+            my $actual = $_;
+            if ($actual->epoch() == $expected->epoch()
+                    && $actual->seqnum() == $expected->seqnum()) {
+                $found = 1
+            }
+        }
+
+        if ($found == 0) {
+            push(@missing_record_numbers, $expected);
+        }
+    }
+
+    return @missing_record_numbers;
+}
+
+sub drop_first_client_finish_filter
+{
+    my $inproxy = shift;
+
+    foreach my $record (@{$inproxy->record_list}) {
+        next if ($record->{sent} == 1 || $record->serverissender || $found_first_client_finish_msg == 1);
+
+        my @messages = TLSProxy::Message->get_messages($record);
+        foreach my $message (@messages) {
+            if ($message->mt == TLSProxy::Message::MT_FINISHED) {
+                $record->{sent} = 1;
+                $found_first_client_finish_msg = 1;
+                last;
+            }
+        }
+    }
+}
index 83fc8ab91cd819bc9aa22da3887430f1bdb3f00b..de90e0996c521d52f107f6f7a3cc98bf67dc2bb8 100644 (file)
@@ -114,6 +114,7 @@ sub add_maximal_padding_filter
         }
 
         my $record = TLSProxy::Record->new(
+            $last_message->server,
             $proxy->flight,
             TLSProxy::Record::RT_APPLICATION_DATA,
             TLSProxy::Record::VERS_TLS_1_2,
index fe26564d506b96f7b84e5b6c34cb4e9941959c82..f8b9f6123a5df855d4bb28359bab6911e271d0ae 100644 (file)
@@ -367,6 +367,7 @@ sub add_empty_recs_filter
         my $record;
         if ($isdtls == 1) {
             $record = TLSProxy::Record->new_dtls(
+                0,
                 0,
                 $content_type,
                 TLSProxy::Record::VERS_DTLS_1_2,
@@ -381,6 +382,7 @@ sub add_empty_recs_filter
             );
         } else {
             $record = TLSProxy::Record->new(
+                0,
                 0,
                 $content_type,
                 TLSProxy::Record::VERS_TLS_1_2,
@@ -424,6 +426,7 @@ sub add_frag_alert_filter
     # Now add the alert level (Fatal) as a separate record
     $byte = pack('C', TLSProxy::Message::AL_LEVEL_FATAL);
     my $record = TLSProxy::Record->new(
+        0,
         0,
         TLSProxy::Record::RT_ALERT,
         TLSProxy::Record::VERS_TLS_1_2,
@@ -439,6 +442,7 @@ sub add_frag_alert_filter
     # And finally the description (Unexpected message) in a third record
     $byte = pack('C', TLSProxy::Message::AL_DESC_UNEXPECTED_MESSAGE);
     $record = TLSProxy::Record->new(
+        0,
         0,
         TLSProxy::Record::RT_ALERT,
         TLSProxy::Record::VERS_TLS_1_2,
@@ -471,6 +475,7 @@ sub add_sslv2_filter
                                TLSProxy::Message::AL_DESC_NO_RENEGOTIATION);
         my $alertlen = length $alert;
         $record = TLSProxy::Record->new(
+            0,
             0,
             TLSProxy::Record::RT_ALERT,
             TLSProxy::Record::VERS_TLS_1_2,
@@ -510,6 +515,7 @@ sub add_sslv2_filter
         my $chlen = length $clienthello;
 
         $record = TLSProxy::Record->new(
+            0,
             0,
             TLSProxy::Record::RT_HANDSHAKE,
             TLSProxy::Record::VERS_TLS_1_2,
@@ -550,6 +556,7 @@ sub add_sslv2_filter
 
         my $fraglen = length $frag1;
         $record = TLSProxy::Record->new(
+            0,
             0,
             TLSProxy::Record::RT_HANDSHAKE,
             TLSProxy::Record::VERS_TLS_1_2,
@@ -570,6 +577,7 @@ sub add_sslv2_filter
             $recvers = 0;
         }
         $record = TLSProxy::Record->new(
+            0,
             0,
             TLSProxy::Record::RT_HANDSHAKE,
             TLSProxy::Record::VERS_TLS_1_2,
@@ -584,6 +592,7 @@ sub add_sslv2_filter
 
         $fraglen = length $frag3;
         $record = TLSProxy::Record->new(
+            0,
             0,
             TLSProxy::Record::RT_HANDSHAKE,
             TLSProxy::Record::VERS_TLS_1_2,
@@ -603,6 +612,8 @@ sub add_unknown_record_type
 {
     my $proxy = shift;
     my $records = $proxy->record_list;
+    my $lastmessage =  @{$proxy->message_list}[-1];
+    my $isserver = $lastmessage->server;
     my $isdtls = $proxy->isdtls;
     state $added_record;
 
@@ -619,6 +630,7 @@ sub add_unknown_record_type
 
     if ($isdtls) {
         $record = TLSProxy::Record->new_dtls(
+            $isserver,
             1,
             TLSProxy::Record::RT_UNKNOWN,
             @{$records}[-1]->version(),
@@ -633,6 +645,7 @@ sub add_unknown_record_type
         );
     } else {
         $record = TLSProxy::Record->new(
+            $isserver,
             1,
             TLSProxy::Record::RT_UNKNOWN,
             @{$records}[-1]->version(),
@@ -775,6 +788,7 @@ sub not_on_record_boundary
         #KeyUpdates must end on a record boundary
 
         my $record = TLSProxy::Record->new(
+            @{$proxy->{message_list}}[-1]->server,
             1,
             TLSProxy::Record::RT_APPLICATION_DATA,
             TLSProxy::Record::VERS_TLS_1_2,
@@ -803,8 +817,10 @@ sub not_on_record_boundary
     } else {
         return if @{$proxy->{message_list}}[-1]->{mt}
                   != TLSProxy::Message::MT_FINISHED;
+        my $isserver = @{$proxy->{message_list}}[-1]->server;
 
         my $record = TLSProxy::Record->new(
+            $isserver,
             1,
             TLSProxy::Record::RT_APPLICATION_DATA,
             TLSProxy::Record::VERS_TLS_1_2,
@@ -830,6 +846,7 @@ sub not_on_record_boundary
         if ($boundary_test_type == DATA_BETWEEN_KEY_UPDATE) {
             #Now add an app data record
             $record = TLSProxy::Record->new(
+                $isserver,
                 1,
                 TLSProxy::Record::RT_APPLICATION_DATA,
                 TLSProxy::Record::VERS_TLS_1_2,
@@ -851,6 +868,7 @@ sub not_on_record_boundary
 
         #Now add the rest of the KeyUpdate message
         $record = TLSProxy::Record->new(
+            $isserver,
             1,
             TLSProxy::Record::RT_APPLICATION_DATA,
             TLSProxy::Record::VERS_TLS_1_2,
index a5ff57f02d8dcdea12f53e20d2e58f2507e01ca2..ffe4b3fc7a9c20cc63b4a99c358bab64f1a778b1 100644 (file)
@@ -227,7 +227,9 @@ sub hrr_filter
         my $dup_hrr;
 
         if ($proxy->isdtls()) {
-            $dup_hrr = TLSProxy::Record->new_dtls(3,
+            $dup_hrr = TLSProxy::Record->new_dtls(
+                1,
+                3,
                 $hrr_record->content_type(),
                 $hrr_record->version(),
                 $hrr_record->epoch(),
@@ -239,7 +241,9 @@ sub hrr_filter
                 $hrr_record->data(),
                 $hrr_record->decrypt_data());
         } else {
-            $dup_hrr = TLSProxy::Record->new(3,
+            $dup_hrr = TLSProxy::Record->new(
+                1,
+                3,
                 $hrr_record->content_type(),
                 $hrr_record->version(),
                 $hrr_record->len(),
index ecdf35c8b9e1ffa242ed2209bfee6ba4bb20d0ab..933c5d5c2975cff471df7bb9ee873a35e5c3eceb 100644 (file)
@@ -1098,14 +1098,10 @@ static int ping_pong_query(SSL *clientssl, SSL *serverssl)
     unsigned char cbuf[16000] = {0};
     unsigned char sbuf[16000];
     size_t err = 0;
-    char crec_wseq_before[SEQ_NUM_SIZE];
-    char crec_wseq_after[SEQ_NUM_SIZE];
-    char crec_rseq_before[SEQ_NUM_SIZE];
-    char crec_rseq_after[SEQ_NUM_SIZE];
-    char srec_wseq_before[SEQ_NUM_SIZE];
-    char srec_wseq_after[SEQ_NUM_SIZE];
-    char srec_rseq_before[SEQ_NUM_SIZE];
-    char srec_rseq_after[SEQ_NUM_SIZE];
+    uint64_t crec_wseq_before, crec_wseq_after;
+    uint64_t crec_rseq_before, crec_rseq_after;
+    uint64_t srec_wseq_before, srec_wseq_after;
+    uint64_t srec_rseq_before, srec_rseq_after;
     SSL_CONNECTION *clientsc, *serversc;
 
     if (!TEST_ptr(clientsc = SSL_CONNECTION_FROM_SSL_ONLY(clientssl))
@@ -1113,10 +1109,10 @@ static int ping_pong_query(SSL *clientssl, SSL *serverssl)
         goto end;
 
     cbuf[0] = count++;
-    memcpy(crec_wseq_before, &clientsc->rlayer.wrl->sequence, SEQ_NUM_SIZE);
-    memcpy(srec_wseq_before, &serversc->rlayer.wrl->sequence, SEQ_NUM_SIZE);
-    memcpy(crec_rseq_before, &clientsc->rlayer.rrl->sequence, SEQ_NUM_SIZE);
-    memcpy(srec_rseq_before, &serversc->rlayer.rrl->sequence, SEQ_NUM_SIZE);
+    crec_wseq_before = clientsc->rlayer.wrl->sequence;
+    srec_wseq_before = serversc->rlayer.wrl->sequence;
+    crec_rseq_before = clientsc->rlayer.rrl->sequence;
+    srec_rseq_before = serversc->rlayer.rrl->sequence;
 
     if (!TEST_true(SSL_write(clientssl, cbuf, sizeof(cbuf)) == sizeof(cbuf)))
         goto end;
@@ -1136,10 +1132,10 @@ static int ping_pong_query(SSL *clientssl, SSL *serverssl)
         }
     }
 
-    memcpy(crec_wseq_after, &clientsc->rlayer.wrl->sequence, SEQ_NUM_SIZE);
-    memcpy(srec_wseq_after, &serversc->rlayer.wrl->sequence, SEQ_NUM_SIZE);
-    memcpy(crec_rseq_after, &clientsc->rlayer.rrl->sequence, SEQ_NUM_SIZE);
-    memcpy(srec_rseq_after, &serversc->rlayer.rrl->sequence, SEQ_NUM_SIZE);
+    crec_wseq_after = clientsc->rlayer.wrl->sequence;
+    srec_wseq_after = serversc->rlayer.wrl->sequence;
+    crec_rseq_after = clientsc->rlayer.rrl->sequence;
+    srec_rseq_after = serversc->rlayer.rrl->sequence;
 
     /* verify the payload */
     if (!TEST_mem_eq(cbuf, sizeof(cbuf), sbuf, sizeof(sbuf)))
@@ -1150,42 +1146,34 @@ static int ping_pong_query(SSL *clientssl, SSL *serverssl)
      * OpenSSL sequences
      */
     if (!BIO_get_ktls_send(clientsc->wbio)) {
-        if (!TEST_mem_ne(crec_wseq_before, SEQ_NUM_SIZE,
-                         crec_wseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_ne(crec_wseq_before, crec_wseq_after))
             goto end;
     } else {
-        if (!TEST_mem_eq(crec_wseq_before, SEQ_NUM_SIZE,
-                         crec_wseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_eq(crec_wseq_before, crec_wseq_after))
             goto end;
     }
 
     if (!BIO_get_ktls_send(serversc->wbio)) {
-        if (!TEST_mem_ne(srec_wseq_before, SEQ_NUM_SIZE,
-                         srec_wseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_ne(srec_wseq_before, srec_wseq_after))
             goto end;
     } else {
-        if (!TEST_mem_eq(srec_wseq_before, SEQ_NUM_SIZE,
-                         srec_wseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_eq(srec_wseq_before, srec_wseq_after))
             goto end;
     }
 
     if (!BIO_get_ktls_recv(clientsc->wbio)) {
-        if (!TEST_mem_ne(crec_rseq_before, SEQ_NUM_SIZE,
-                         crec_rseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_ne(crec_rseq_before, crec_rseq_after))
             goto end;
     } else {
-        if (!TEST_mem_eq(crec_rseq_before, SEQ_NUM_SIZE,
-                         crec_rseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_eq(crec_rseq_before, crec_rseq_after))
             goto end;
     }
 
     if (!BIO_get_ktls_recv(serversc->wbio)) {
-        if (!TEST_mem_ne(srec_rseq_before, SEQ_NUM_SIZE,
-                         srec_rseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_ne(srec_rseq_before, srec_rseq_after))
             goto end;
     } else {
-        if (!TEST_mem_eq(srec_rseq_before, SEQ_NUM_SIZE,
-                         srec_rseq_after, SEQ_NUM_SIZE))
+        if (!TEST_uint64_t_eq(srec_rseq_before, srec_rseq_after))
             goto end;
     }
 
index 9cc5b167b0090866b14225dfbebc7b8b8983d6ce..012cdab54b9d132546d42644ae1d05ed3d916274 100644 (file)
@@ -240,9 +240,9 @@ static unsigned char *multihexstr2buf(const char *str[3], size_t *len)
 
 static int load_record(TLS_RL_RECORD *rec, RECORD_DATA *recd,
                        unsigned char **key, unsigned char *iv, size_t ivlen,
-                       unsigned char *seq)
+                       uint64_t *seq)
 {
-    unsigned char *pt = NULL, *sq = NULL, *ivtmp = NULL;
+    unsigned char *pt = NULL, *sq = NULL, *p_sq, *ivtmp = NULL;
     size_t ptlen;
 
     *key = OPENSSL_hexstr2buf(recd->key, NULL);
@@ -261,7 +261,8 @@ static int load_record(TLS_RL_RECORD *rec, RECORD_DATA *recd,
     rec->length = ptlen;
     memcpy(rec->data, pt, ptlen);
     OPENSSL_free(pt);
-    memcpy(seq, sq, SEQ_NUM_SIZE);
+    p_sq = sq;
+    n2l8(p_sq, *seq);
     OPENSSL_free(sq);
     memcpy(iv, ivtmp, ivlen);
     OPENSSL_free(ivtmp);
@@ -311,7 +312,7 @@ static int test_tls13_encryption(void)
     const EVP_CIPHER *ciph = EVP_aes_128_gcm();
     int ret = 0;
     size_t ivlen, ctr;
-    unsigned char seqbuf[SEQ_NUM_SIZE];
+    uint64_t recseq;
     unsigned char iv[EVP_MAX_IV_LENGTH];
     OSSL_RECORD_LAYER *rrl = NULL, *wrl = NULL;
 
@@ -326,7 +327,7 @@ static int test_tls13_encryption(void)
     for (ctr = 0; ctr < OSSL_NELEM(refdata); ctr++) {
         /* Load the record */
         ivlen = EVP_CIPHER_get_iv_length(ciph);
-        if (!load_record(&rec, &refdata[ctr], &key, iv, ivlen, seqbuf)) {
+        if (!load_record(&rec, &refdata[ctr], &key, iv, ivlen, &recseq)) {
             TEST_error("Failed loading key into EVP_CIPHER_CTX");
             goto err;
         }
@@ -342,7 +343,8 @@ static int test_tls13_encryption(void)
                           NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
                           &wrl)))
             goto err;
-        memcpy(wrl->sequence, seqbuf, sizeof(seqbuf));
+
+        wrl->sequence = recseq;
 
         /* Encrypt it */
         if (!TEST_size_t_eq(wrl->funcs->cipher(wrl, &rec, 1, 1, NULL, 0), 1)) {
@@ -366,7 +368,8 @@ static int test_tls13_encryption(void)
                           NULL, NULL, NULL, NULL, NULL, NULL, NULL, NULL,
                           &rrl)))
             goto err;
-        memcpy(rrl->sequence, seqbuf, sizeof(seqbuf));
+
+        rrl->sequence = recseq;
 
         /* Decrypt it */
         if (!TEST_int_eq(rrl->funcs->cipher(rrl, &rec, 1, 0, NULL, 0), 1)) {
index 8f539d41f1e88fae9df1c3c278bd1f73865deda7..f7fe6ffb47120899e596d86cf31d2c3368228a32 100644 (file)
@@ -250,7 +250,11 @@ void dtls1_clear_received_buffer(SSL_CONNECTION *s)
 {
 }
 
-void dtls1_clear_sent_buffer(SSL_CONNECTION *s)
+void dtls1_clear_sent_buffer(SSL_CONNECTION *s, int keep_unacked_msgs)
+{
+}
+
+void dtls1_acknowledge_sent_buffer(SSL_CONNECTION *s, uint16_t before_epoch)
 {
 }
 
index 526afc9b4aaf718323e6904b8fbba66d1c6fd226..ae8c73497002f18ff271fd08bfb79ebbed70194c 100644 (file)
@@ -9,6 +9,7 @@ use strict;
 
 package TLSProxy::Message;
 
+use TLSProxy::RecordNumber;
 use TLSProxy::Alert;
 
 use constant DTLS_MESSAGE_HEADER_LENGTH => 12;
@@ -21,6 +22,7 @@ use constant {
     MT_SERVER_HELLO => 2,
     MT_HELLO_VERIFY_REQUEST => 3,
     MT_NEW_SESSION_TICKET => 4,
+    MT_END_OF_EARLY_DATA => 5,
     MT_ENCRYPTED_EXTENSIONS => 8,
     MT_CERTIFICATE => 11,
     MT_SERVER_KEY_EXCHANGE => 12,
@@ -29,7 +31,10 @@ use constant {
     MT_CERTIFICATE_VERIFY => 15,
     MT_CLIENT_KEY_EXCHANGE => 16,
     MT_FINISHED => 20,
+    MT_CERTIFICATE_URL => 21,
     MT_CERTIFICATE_STATUS => 22,
+    MT_SUPPLEMENTAL_DATA => 23,
+    MT_KEY_UPDATE => 24,
     MT_COMPRESSED_CERTIFICATE => 25,
     MT_NEXT_PROTO => 67
 };
@@ -184,9 +189,9 @@ sub clear
 sub get_messages
 {
     my $class = shift;
-    my $serverin = shift;
     my $record = shift;
-    my $isdtls = shift;
+    my $serverin = $record->serverissender;
+    my $isdtls = $record->isdtls;
     my @messages = ();
     my $message;
 
index caf566601db604ca9e79f5d9790afcfa5dfeb9fc..99a92cc267dd23fa0e4707968d5797a0e0de3134 100644 (file)
@@ -434,7 +434,7 @@ sub clientstart
         my $pid;
         my $execcmd = $self->execute
              ." s_client -provider=p_ossltest -provider=default -propquery ?provider=p_ossltest"
-             ." -connect $self->{proxy_addr}:$self->{proxy_port}";
+             ." -state -connect $self->{proxy_addr}:$self->{proxy_port}";
         if ($self->{isdtls}) {
             $execcmd .= " -dtls -max_protocol DTLSv1.3"
                         # TLSProxy does not support message fragmentation. So
@@ -608,21 +608,17 @@ sub clientstart
 
 sub process_packet
 {
-    my ($self, $server, $packet) = @_;
-    my $len_real;
-    my $decrypt_len;
-    my $data;
-    my $recnum;
+    my ($self, $serverissender, $packet) = @_;
 
-    if ($server) {
+    if ($serverissender) {
         print "Received server packet\n";
     } else {
         print "Received client packet\n";
     }
 
-    if ($self->{direction} != $server) {
+    if ($self->{direction} != $serverissender) {
         $self->{flight} = $self->{flight} + 1;
-        $self->{direction} = $server;
+        $self->{direction} = $serverissender;
     }
 
     print "Packet length = ".length($packet)."\n";
@@ -630,11 +626,11 @@ sub process_packet
 
     #Return contains the list of record found in the packet followed by the
     #list of messages in those records and any partial message
-    my @ret = TLSProxy::Record->get_records($server, $self->flight,
-                                            $self->{partial}[$server].$packet,
+    my @ret = TLSProxy::Record->get_records($serverissender, $self->flight,
+                                            $self->{partial}[$serverissender].$packet,
                                             $self->{isdtls});
 
-    $self->{partial}[$server] = $ret[2];
+    $self->{partial}[$serverissender] = $ret[2];
     push @{$self->{record_list}}, @{$ret[0]};
     push @{$self->{message_list}}, @{$ret[1]};
 
@@ -660,7 +656,7 @@ sub process_packet
     #Reconstruct the packet
     $packet = "";
     foreach my $record (@{$self->record_list}) {
-        $packet .= $record->reconstruct_record($server);
+        $packet .= $record->reconstruct_record($serverissender);
     }
 
     print "Forwarded packet length = ".length($packet)."\n\n";
index 8f90dea4be5e3d3f015cf219b209cb6e7323d9a1..213e3874bc78062ef67c7b3e3aa2bbd9a0dc5c8e 100644 (file)
@@ -24,6 +24,7 @@ use constant {
     RT_HANDSHAKE          => 22,
     RT_ALERT              => 21,
     RT_CCS                => 20,
+    RT_ACK                => 26,
     RT_UNKNOWN            => 100,
     RT_DTLS_UNIHDR_EPOCH4 => 0x2c,
     RT_DTLS_UNIHDR_EPOCH1 => 0x2d,
@@ -36,6 +37,7 @@ my %record_type = (
     RT_HANDSHAKE, "HANDSHAKE",
     RT_ALERT, "ALERT",
     RT_CCS, "CCS",
+    RT_ACK, "ACK",
     RT_UNKNOWN, "UNKNOWN",
     RT_DTLS_UNIHDR_EPOCH4, "DTLS UNIFIED HEADER (EPOCH 4)",
     RT_DTLS_UNIHDR_EPOCH1, "DTLS UNIFIED HEADER (EPOCH 1)",
@@ -72,7 +74,7 @@ our %tls_version = (
 sub get_records
 {
     my $class = shift;
-    my $server = shift;
+    my $serverissender = shift;
     my $flight = shift;
     my $packet = shift;
     my $isdtls = shift;
@@ -82,8 +84,8 @@ sub get_records
 
     my $recnum = 1;
     while (length ($packet) > 0) {
-        print " Record $recnum ", $server ? "(server -> client)\n"
-                                          : "(client -> server)\n";
+        print " Record $recnum ", $serverissender ? "(server -> client)\n"
+                                                  : "(client -> server)\n";
         my $record_hdr_len;
         my $content_type;
         my $version;
@@ -91,7 +93,7 @@ sub get_records
         my $epoch;
         my $seq;
 
-        if ($isdtls) {
+        if ($isdtls == 1) {
             my $isunifiedhdr;
 
             $content_type = unpack('B[8]', $packet);
@@ -115,6 +117,12 @@ sub get_records
                     ($content_type, $seq, $len) = unpack('CCn', $packet);
                     $record_hdr_len = 4;
                 }
+                # Encrypted DTLS 1.3 records have encrypted sequence numbers.
+                # ossltest engine overrides ecb encryption to be a no-op.
+                # This effectively means that the sequence number encryption mask
+                # is just the 16 first bytes of the record body.
+                my $recordbody = substr($packet, $record_hdr_len, $len);
+                (my $maskhi, my $maskmi, my $masklo) = unpack('nnn', $recordbody);
                 $version = VERS_DTLS_1_2; # DTLSv1.3 headers has DTLSv1.2 in its legacy_version field
 
                 if ($eebits == "00") {
@@ -128,6 +136,7 @@ sub get_records
                 } else {
                     die("Epoch bits is not 0's or 1's: should not happen")
                 }
+                $seq ^= $maskhi;
             } else {
                 my $seqhi;
                 my $seqmi;
@@ -154,7 +163,7 @@ sub get_records
 
         print "  Content type: ".$record_type{$content_type}."\n";
         print "  Version: $tls_version{$version}\n";
-        if($isdtls) {
+        if($isdtls == 1) {
             print "  Epoch: $epoch\n";
             print "  Sequence: $seq\n";
         }
@@ -163,6 +172,7 @@ sub get_records
         my $record;
         if ($isdtls) {
             $record = TLSProxy::Record->new_dtls(
+                $serverissender,
                 $flight,
                 $content_type,
                 $version,
@@ -177,6 +187,7 @@ sub get_records
             );
         } else {
             $record = TLSProxy::Record->new(
+                $serverissender,
                 $flight,
                 $content_type,
                 $version,
@@ -192,8 +203,8 @@ sub get_records
         if ($content_type != RT_CCS
                 && (!TLSProxy::Proxy->is_tls13()
                     || $content_type != RT_ALERT)) {
-            if (($server && $server_encrypting)
-                     || (!$server && $client_encrypting)) {
+            if (($serverissender && $server_encrypting)
+                     || (!$serverissender && $client_encrypting)) {
                 if (!TLSProxy::Proxy->is_tls13() && $etm) {
                     $record->decryptETM();
                 } else {
@@ -204,6 +215,7 @@ sub get_records
                 if (TLSProxy::Proxy->is_tls13()) {
                     print "  Inner content type: "
                           .$record_type{$record->content_type()}."\n";
+                    print " Data: ".unpack("n",$record->decrypt_data)."\n";
                 }
             }
         }
@@ -211,7 +223,7 @@ sub get_records
         push @record_list, $record;
 
         #Now figure out what messages are contained within this record
-        my @messages = TLSProxy::Message->get_messages($server, $record, $isdtls);
+        my @messages = TLSProxy::Message->get_messages($record);
         push @message_list, @messages;
 
         $packet = substr($packet, $record_hdr_len + $len);
@@ -257,7 +269,8 @@ sub etm
 sub new_dtls
 {
     my $class = shift;
-    my ($flight,
+    my ($serverissender,
+        $flight,
         $content_type,
         $version,
         $epoch,
@@ -268,7 +281,8 @@ sub new_dtls
         $decrypt_len,
         $data,
         $decrypt_data) = @_;
-    return $class->init(1,
+    return $class->init($serverissender,
+        1,
         $flight,
         $content_type,
         $version,
@@ -285,7 +299,8 @@ sub new_dtls
 sub new
 {
     my $class = shift;
-    my ($flight,
+    my ($serverissender,
+        $flight,
         $content_type,
         $version,
         $len,
@@ -295,6 +310,7 @@ sub new
         $data,
         $decrypt_data) = @_;
     return $class->init(
+        $serverissender,
         0,
         $flight,
         $content_type,
@@ -312,7 +328,8 @@ sub new
 sub init
 {
     my $class = shift;
-    my ($isdtls,
+    my ($serverissender,
+        $isdtls,
         $flight,
         $content_type,
         $version,
@@ -326,6 +343,7 @@ sub init
         $decrypt_data) = @_;
 
     my $self = {
+        serverissender => $serverissender,
         isdtls => $isdtls,
         flight => $flight,
         content_type => $content_type,
@@ -443,13 +461,18 @@ sub reconstruct_record
         my $content_type = (TLSProxy::Proxy->is_tls13() && $self->encrypted)
                            ? $self->outer_content_type : $self->content_type;
         if($self->{isdtls}) {
+            my $seqhi = ($self->seq >> 32) & 0xffff;
+            my $seqmi = ($self->seq >> 16) & 0xffff;
+            my $seqlo = ($self->seq >> 0) & 0xffff;
+
             if (TLSProxy::Proxy->is_tls13() && $self->encrypted) {
+                # Mask sequence number with record body bytes. Explanation
+                # given in get_records.
+                (my $maskhi, my $maskmi, my $masklo) = unpack("nnn", $self->data);
+                $seqlo ^= $maskhi;
                 # Prepare a unified header
-                $data = pack('Cnn', $content_type, $self->seq, $self->len);
+                $data = pack('Cnn', $content_type, $seqlo, $self->len);
             } else {
-                my $seqhi = ($self->seq >> 32) & 0xffff;
-                my $seqmi = ($self->seq >> 16) & 0xffff;
-                my $seqlo = ($self->seq >> 0) & 0xffff;
                 $data = pack('Cnnnnnn', $content_type, $self->version,
                     $self->epoch, $seqhi, $seqmi, $seqlo, $self->len);
             }
@@ -465,6 +488,16 @@ sub reconstruct_record
 }
 
 #Read only accessors
+sub serverissender
+{
+    my $self = shift;
+    return $self->{serverissender};
+}
+sub isdtls
+{
+    my $self = shift;
+    return $self->{isdtls};
+}
 sub flight
 {
     my $self = shift;
diff --git a/util/perl/TLSProxy/RecordNumber.pm b/util/perl/TLSProxy/RecordNumber.pm
new file mode 100644 (file)
index 0000000..6cff7a8
--- /dev/null
@@ -0,0 +1,37 @@
+# Copyright 2024 The OpenSSL Project Authors. All Rights Reserved.
+#
+# Licensed under the Apache License 2.0 (the "License").  You may not use
+# this file except in compliance with the License.  You can obtain a copy
+# in the file LICENSE in the source distribution or at
+# https://www.openssl.org/source/license.html
+
+use strict;
+
+package TLSProxy::RecordNumber;
+
+sub new
+{
+    my $class = shift;
+    my ($epoch,
+        $seqnum) = @_;
+
+    my $self = {
+        epoch => $epoch,
+        seqnum => $seqnum
+    };
+
+    return bless $self, $class;
+}
+
+# Read only accessors
+sub epoch
+{
+    my $self = shift;
+    return $self->{epoch};
+}
+sub seqnum
+{
+    my $self = shift;
+    return $self->{seqnum};
+}
+1;