From: Alberto Leiva Popper Date: Fri, 20 Oct 2023 22:18:39 +0000 (-0600) Subject: RTR server maintenance X-Git-Tag: 1.6.0~40 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=eea8e4b4ed18019efeed72315a43c4b41f4fd0db;p=thirdparty%2FFORT-validator.git RTR server maintenance Someone reported a security vulnerability in the server, but the details are muddy, and clarifications have not arrived yet. I haven't been able to reproduce it, but the review did yield room for improvement: 1. Buffer request bytes better The old code seemed to assume each socket read consumed exactly one (nonempty) TCP packet, and each such packet contained exactly one PDU. I'm scratching my head at this, but I guess for most intents and purposes, this assumption is not as lunatic as it seems. Benign RTR PDUs are very small, and it doesn't make sense for a request packet to contain multiple of them. Error Reports aside, it doesn't even make sense for the client to send multiple PDUs in quick succession at all. Regardless, I'm flushing that assumption down the toilet: - If read() yields multiple PDUs, queue and handle them in sequence. Although as I'm writing this I'm realizing that queuing PDUs is a dumb idea, because Serial Queries and Reset Queries are alternate means to achieve the same goal. If the client sent a new request, it's most likely given up on the old one. Plus, queuing PDUs brings additional complexity and risks. I'm going to have to change this in the next commit. - If a read() yields a fragmented PDU, buffer and prepend it to the next successful read. This will probably never happen, but it's nice to handle it properly anyway. 2. Drop unused PDU parsers An RTR server only needs to handle PDU types Serial Query, Reset Query and Error Report. Fort also had dead code meant for the other PDU types. I'm guessing they were intended for the Error Report internal PDU field, but it turns out that's also unused. 3. Improve PDU validation Since Serial Queries and Reset Queries are supposed to have constant length, Fort was often ignoring the PDU header length field. Fort now punishes incorrect lengths more aggressively. --- diff --git a/src/Makefile.am b/src/Makefile.am index 19db79a7..5aae5ae4 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -93,12 +93,11 @@ fort_SOURCES += rrdp/rrdp_parser.h rrdp/rrdp_parser.c fort_SOURCES += rsync/rsync.h rsync/rsync.c +fort_SOURCES += rtr/pdu_stream.c rtr/pdu_stream.h fort_SOURCES += rtr/err_pdu.c rtr/err_pdu.h fort_SOURCES += rtr/pdu_handler.c rtr/pdu_handler.h fort_SOURCES += rtr/pdu_sender.c rtr/pdu_sender.h -fort_SOURCES += rtr/pdu_serializer.c rtr/pdu_serializer.h fort_SOURCES += rtr/pdu.c rtr/pdu.h -fort_SOURCES += rtr/primitive_reader.c rtr/primitive_reader.h fort_SOURCES += rtr/primitive_writer.c rtr/primitive_writer.h fort_SOURCES += rtr/rtr.c rtr/rtr.h diff --git a/src/notify.c b/src/notify.c index 0c7bd398..b4f71876 100644 --- a/src/notify.c +++ b/src/notify.c @@ -6,11 +6,11 @@ #include "rtr/db/vrps.h" static int -send_notify(struct rtr_client const *client, void *arg) +send_notify(int fd, int rtr_version, void *arg) { serial_t *serial = arg; - send_serial_notify_pdu(client->fd, client->rtr_version, *serial); + send_serial_notify_pdu(fd, rtr_version, *serial); /* Errors already logged, do not interrupt notify to other clients */ return 0; diff --git a/src/rtr/err_pdu.c b/src/rtr/err_pdu.c index 2a27f7b4..bb2d5ae1 100644 --- a/src/rtr/err_pdu.c +++ b/src/rtr/err_pdu.c @@ -14,36 +14,30 @@ typedef enum rtr_error_code { ERR_PDU_UNSUP_PDU_TYPE = 5, ERR_PDU_WITHDRAWAL_UNKNOWN = 6, ERR_PDU_DUPLICATE_ANNOUNCE = 7, - /* RTRv1 only, so not used yet. */ ERR_PDU_UNEXPECTED_PROTO_VERSION = 8, } rtr_error_code_t; static int err_pdu_send(int fd, uint8_t version, rtr_error_code_t code, - struct rtr_request const *request, char const *message_const) + struct rtr_buffer const *request, char const *msg_const) { - char *message; + char *msg; - /* - * This function must always return error so callers can interrupt - * themselves easily. - * But note that not all callers should use this feature. - */ - - /* Need a clone to remove the const. */ - message = (message_const != NULL) ? pstrdup(message_const) : NULL; - send_error_report_pdu(fd, version, code, request, message); - free(message); + if ((request == NULL) || (request->bytes[1] != PDU_TYPE_ERROR_REPORT)) { + /* Need a clone to remove the const. */ + msg = (msg_const != NULL) ? pstrdup(msg_const) : NULL; + send_error_report_pdu(fd, version, code, request, msg); + free(msg); + } - return -EINVAL; + return -EINVAL; /* For propagation */ } int err_pdu_send_corrupt_data(int fd, uint8_t version, - struct rtr_request const *request, char const *message) + struct rtr_buffer const *request, char const *msg) { - return err_pdu_send(fd, version, ERR_PDU_CORRUPT_DATA, request, - message); + return err_pdu_send(fd, version, ERR_PDU_CORRUPT_DATA, request, msg); } /* @@ -70,57 +64,39 @@ err_pdu_send_no_data_available(int fd, uint8_t version) int err_pdu_send_invalid_request(int fd, uint8_t version, - struct rtr_request const *request, char const *message) + struct rtr_buffer const *request, char const *msg) { - return err_pdu_send(fd, version, ERR_PDU_INVALID_REQUEST, request, - message); + return err_pdu_send(fd, version, ERR_PDU_INVALID_REQUEST, request, msg); } -/* Caution: @header is supposed to be in serialized form. */ int err_pdu_send_invalid_request_truncated(int fd, uint8_t version, - unsigned char *header, char const *message) + struct rtr_buffer const *request, char const *msg) { - struct rtr_request request = { - .bytes = header, - .bytes_len = RTRPDU_HDR_LEN, - .pdu = NULL, - }; - return err_pdu_send_invalid_request(fd, version, &request, message); + return err_pdu_send_invalid_request(fd, version, request, msg); } int err_pdu_send_unsupported_proto_version(int fd, uint8_t version, - unsigned char *header, - char const *message) + struct rtr_buffer const *request, char const *msg) { - struct rtr_request request = { - .bytes = header, - .bytes_len = RTRPDU_HDR_LEN, - .pdu = NULL, - }; - return err_pdu_send(fd, version, ERR_PDU_UNSUP_PROTO_VERSION, &request, - message); + return err_pdu_send(fd, version, ERR_PDU_UNSUP_PROTO_VERSION, request, + msg); } int err_pdu_send_unsupported_pdu_type(int fd, uint8_t version, - struct rtr_request const *request) + struct rtr_buffer const *request) { return err_pdu_send(fd, version, ERR_PDU_UNSUP_PDU_TYPE, request, NULL); } int err_pdu_send_unexpected_proto_version(int fd, uint8_t version, - unsigned char *header, char const *message) + struct rtr_buffer const *request, char const *msg) { - struct rtr_request request = { - .bytes = header, - .bytes_len = RTRPDU_HDR_LEN, - .pdu = NULL, - }; - return err_pdu_send(fd, version, ERR_PDU_UNEXPECTED_PROTO_VERSION, &request, - message); + return err_pdu_send(fd, version, ERR_PDU_UNEXPECTED_PROTO_VERSION, + request, msg); } char const * diff --git a/src/rtr/err_pdu.h b/src/rtr/err_pdu.h index 89323c1d..29d1f33e 100644 --- a/src/rtr/err_pdu.h +++ b/src/rtr/err_pdu.h @@ -8,19 +8,22 @@ * Mainly, this is for the sake of making it easier to see whether the error is * supposed to contain a message and/or the original PDU or not. */ -int err_pdu_send_corrupt_data(int, uint8_t, struct rtr_request const *, - char const *); -int err_pdu_send_internal_error(int, uint8_t); -int err_pdu_send_no_data_available(int, uint8_t); -int err_pdu_send_invalid_request(int, uint8_t, struct rtr_request const *, - char const *); -int err_pdu_send_invalid_request_truncated(int, uint8_t, unsigned char *, - char const *); -int err_pdu_send_unsupported_proto_version(int, uint8_t, unsigned char *, - char const *); -int err_pdu_send_unsupported_pdu_type(int, uint8_t, struct rtr_request const *); -int err_pdu_send_unexpected_proto_version(int, uint8_t, unsigned char *, - char const *); +int err_pdu_send_corrupt_data( + int, uint8_t, struct rtr_buffer const *, char const *); +int err_pdu_send_internal_error( + int, uint8_t); +int err_pdu_send_no_data_available( + int, uint8_t); +int err_pdu_send_invalid_request( + int, uint8_t, struct rtr_buffer const *, char const *); +int err_pdu_send_invalid_request_truncated( + int, uint8_t, struct rtr_buffer const *, char const *); +int err_pdu_send_unsupported_proto_version( + int, uint8_t, struct rtr_buffer const *, char const *); +int err_pdu_send_unsupported_pdu_type( + int, uint8_t, struct rtr_buffer const *); +int err_pdu_send_unexpected_proto_version( + int, uint8_t, struct rtr_buffer const *, char const *); char const *err_pdu_to_string(uint16_t); diff --git a/src/rtr/pdu.c b/src/rtr/pdu.c index 5d26c628..69d0113a 100644 --- a/src/rtr/pdu.c +++ b/src/rtr/pdu.c @@ -8,6 +8,7 @@ #include "types/address.h" #include "rtr/err_pdu.h" #include "rtr/pdu_handler.h" +#include "rtr/pdu_sender.h" char const * pdutype2str(enum pdu_type type) @@ -38,307 +39,3 @@ pdutype2str(enum pdu_type type) return "unknown PDU"; } -static int -pdu_header_from_reader(struct pdu_reader *reader, struct pdu_header *header) -{ - int error; - - error = read_int8(reader, &header->protocol_version); - if (error) - return error; - error = read_int8(reader, &header->pdu_type); - if (error) - return error; - error = read_int16(reader, &header->m.session_id); - if (error) - return error; - return read_int32(reader, &header->length); -} - -static int -validate_rtr_version(struct rtr_client *client, struct pdu_header *header, - unsigned char *hdr_bytes) -{ - if (client->rtr_version != -1) { - if (header->protocol_version == client->rtr_version) - return 0; - - /* Don't send error on a rcvd error! */ - if (header->pdu_type == PDU_TYPE_ERROR_REPORT) - return -EINVAL; - - switch (client->rtr_version) { - case RTR_V1: - /* Rcvd version is valid, but unexpected */ - if (header->protocol_version == RTR_V0) - return err_pdu_send_unexpected_proto_version( - client->fd, client->rtr_version, hdr_bytes, - "RTR version 0 was expected"); - /* Send common error */ - case RTR_V0: - return err_pdu_send_unsupported_proto_version( - client->fd, client->rtr_version, hdr_bytes, - "RTR version received is unknown."); - default: - pr_crit("Unknown RTR version %u", client->rtr_version); - } - } - - /* Unsigned and incremental values, so compare against major version */ - if (header->protocol_version > RTR_V1) - /* ...and send error with min version */ - return (header->pdu_type != PDU_TYPE_ERROR_REPORT) - ? err_pdu_send_unsupported_proto_version(client->fd, RTR_V0, - hdr_bytes, "RTR version received is unknown.") - : -EINVAL; - - client->rtr_version = header->protocol_version; - return 0; -} - -/* Do not use this macro before @header has been initialized, obviously. */ -#define RESPOND_ERROR(report_cb) \ - ((header.pdu_type != PDU_TYPE_ERROR_REPORT) ? (report_cb) : -EINVAL); - -/* - * Reads the next PDU from @reader. Returns the PDU in @request, and its - * metadata in @metadata. - */ -int -pdu_load(struct pdu_reader *reader, struct rtr_client *client, - struct rtr_request *request, struct pdu_metadata const **metadata) -{ - struct pdu_header header; - struct pdu_metadata const *meta; - int error; - - if (reader->size == 0) { - pr_op_debug("Client packet contains no more PDUs."); - return ENOENT; - } - - request->bytes = reader->buffer; - request->bytes_len = RTRPDU_HDR_LEN; - - error = pdu_header_from_reader(reader, &header); - if (error) - /* No error response because the PDU might have been an error */ - return error; - - pr_op_debug("PDU '%s' received from client '%s'", - pdutype2str(header.pdu_type), client->addr); - - error = validate_rtr_version(client, &header, request->bytes); - if (error) - return error; /* Error response PDU already sent */ - - /* - * DO NOT USE THE err_pdu_* functions directly. Wrap them with - * RESPOND_ERROR() INSTEAD. - */ - - if (header.length < RTRPDU_HDR_LEN) - return RESPOND_ERROR(err_pdu_send_invalid_request_truncated( - client->fd, client->rtr_version, request->bytes, - "Invalid header length. (< 8 bytes)")); - - /* - * Error messages can be quite large. - * But they're probably not legitimate, so drop 'em. - * 512 is like a 5-paragraph error message, so it's probably enough. - * Most error messages are bound to be two phrases tops. - * (Warning: I'm assuming english tho.) - */ - if (header.length > 512) - return RESPOND_ERROR(err_pdu_send_invalid_request_truncated( - client->fd, client->rtr_version, request->bytes, - "PDU is too large. (> 512 bytes)")); - - request->bytes_len = header.length; - - /* Deserialize the PDU. */ - meta = pdu_get_metadata(header.pdu_type); - if (!meta) - return RESPOND_ERROR(err_pdu_send_unsupported_pdu_type( - client->fd, client->rtr_version, request)); - - request->pdu = pmalloc(meta->length); - - error = meta->from_stream(&header, reader, request->pdu); - if (error) { - /* Communication interrupted; no error PDU. */ - free(request->pdu); - return error; - } - - /* Happy path. */ - *metadata = meta; - return 0; -} - -static int -serial_notify_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct serial_notify_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return read_int32(reader, &pdu->serial_number); -} - -static int -serial_query_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct serial_query_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return read_int32(reader, &pdu->serial_number); -} - -static int -reset_query_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct reset_query_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return 0; -} - -static int -cache_response_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct cache_response_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return 0; -} - -static int -ipv4_prefix_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct ipv4_prefix_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return read_int8(reader, &pdu->flags) - || read_int8(reader, &pdu->prefix_length) - || read_int8(reader, &pdu->max_length) - || read_int8(reader, &pdu->zero) - || read_in_addr(reader, &pdu->ipv4_prefix) - || read_int32(reader, &pdu->asn); -} - -static int -ipv6_prefix_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct ipv6_prefix_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return read_int8(reader, &pdu->flags) - || read_int8(reader, &pdu->prefix_length) - || read_int8(reader, &pdu->max_length) - || read_int8(reader, &pdu->zero) - || read_in6_addr(reader, &pdu->ipv6_prefix) - || read_int32(reader, &pdu->asn); -} - -static int -end_of_data_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct end_of_data_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return read_int32(reader, &pdu->serial_number); -} - -static int -cache_reset_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct cache_reset_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return 0; -} - -static int -router_key_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct router_key_pdu *pdu = pdu_void; - memcpy(&pdu->header, header, sizeof(*header)); - return 0; -} - -static int -error_report_from_stream(struct pdu_header *header, struct pdu_reader *reader, - void *pdu_void) -{ - struct error_report_pdu *pdu = pdu_void; - int error; - - memcpy(&pdu->header, header, sizeof(*header)); - - error = read_int32(reader, &pdu->error_pdu_length); - if (error) - return error; - error = read_bytes(reader, pdu->erroneous_pdu, pdu->error_pdu_length); - if (error) - return error; - error = read_int32(reader, &pdu->error_message_length); - if (error) - return error; - return read_string(reader, pdu->error_message_length, - &pdu->error_message); -} - -static void -error_report_destroy(void *pdu_void) -{ - struct error_report_pdu *pdu = pdu_void; - free(pdu->error_message); - free(pdu); -} - -#define DEFINE_METADATA(name, dtor) \ - static struct pdu_metadata const name ## _meta = { \ - .length = sizeof(struct name ## _pdu), \ - .from_stream = name ## _from_stream, \ - .handle = handle_ ## name ## _pdu, \ - .destructor = dtor, \ - } - -DEFINE_METADATA(serial_notify, free); -DEFINE_METADATA(serial_query, free); /* handle_serial_query_pdu */ -DEFINE_METADATA(reset_query, free); -DEFINE_METADATA(cache_response, free); -DEFINE_METADATA(ipv4_prefix, free); -DEFINE_METADATA(ipv6_prefix, free); -DEFINE_METADATA(end_of_data, free); -DEFINE_METADATA(cache_reset, free); -DEFINE_METADATA(router_key, free); -DEFINE_METADATA(error_report, error_report_destroy); - -static struct pdu_metadata const *const pdu_metadatas[] = { - /* 0 */ &serial_notify_meta, - /* 1 */ &serial_query_meta, - /* 2 */ &reset_query_meta, - /* 3 */ &cache_response_meta, - /* 4 */ &ipv4_prefix_meta, - /* 5 */ NULL, - /* 6 */ &ipv6_prefix_meta, - /* 7 */ &end_of_data_meta, - /* 8 */ &cache_reset_meta, - /* 9 */ &router_key_meta, - /* 10 */ &error_report_meta, -}; - -struct pdu_metadata const * -pdu_get_metadata(uint8_t type) -{ - return (ARRAY_LEN(pdu_metadatas) <= type) ? NULL : pdu_metadatas[type]; -} - -struct pdu_header * -pdu_get_header(void *pdu) -{ - /* The header is by definition the first field of every PDU. */ - return pdu; -} diff --git a/src/rtr/pdu.h b/src/rtr/pdu.h index eb93e897..b93947e5 100644 --- a/src/rtr/pdu.h +++ b/src/rtr/pdu.h @@ -3,33 +3,29 @@ #include "common.h" #include "types/router_key.h" -#include "rtr/primitive_reader.h" #include "rtr/rtr.h" -#define RTR_V0 0 -#define RTR_V1 1 - -/** A request from an RTR client. */ -struct rtr_request { - /** Raw bytes. */ - unsigned char *bytes; - /** Length of @bytes. */ - size_t bytes_len; - /** Deserialized PDU. One of the *_pdu struct below. */ - void *pdu; +enum rtr_version { + RTR_V0 = 0, + RTR_V1 = 1, +}; + +struct rtr_buffer { + unsigned char *bytes; /* Raw bytes */ + size_t bytes_len; /* Length of @bytes */ }; enum pdu_type { - PDU_TYPE_SERIAL_NOTIFY = 0, - PDU_TYPE_SERIAL_QUERY = 1, - PDU_TYPE_RESET_QUERY = 2, - PDU_TYPE_CACHE_RESPONSE = 3, - PDU_TYPE_IPV4_PREFIX = 4, - PDU_TYPE_IPV6_PREFIX = 6, - PDU_TYPE_END_OF_DATA = 7, - PDU_TYPE_CACHE_RESET = 8, - PDU_TYPE_ROUTER_KEY = 9, - PDU_TYPE_ERROR_REPORT = 10, + PDU_TYPE_SERIAL_NOTIFY = 0, + PDU_TYPE_SERIAL_QUERY = 1, + PDU_TYPE_RESET_QUERY = 2, + PDU_TYPE_CACHE_RESPONSE = 3, + PDU_TYPE_IPV4_PREFIX = 4, + PDU_TYPE_IPV6_PREFIX = 6, + PDU_TYPE_END_OF_DATA = 7, + PDU_TYPE_CACHE_RESET = 8, + PDU_TYPE_ROUTER_KEY = 9, + PDU_TYPE_ERROR_REPORT = 10, }; char const *pdutype2str(enum pdu_type); @@ -41,35 +37,40 @@ char const *pdutype2str(enum pdu_type); */ /* Header length field is always 64 bits long */ -#define RTRPDU_HDR_LEN 8 - -#define RTRPDU_SERIAL_NOTIFY_LEN 12 -#define RTRPDU_CACHE_RESPONSE_LEN 8 -#define RTRPDU_IPV4_PREFIX_LEN 20 -#define RTRPDU_IPV6_PREFIX_LEN 32 -#define RTRPDU_END_OF_DATA_V0_LEN 12 -#define RTRPDU_END_OF_DATA_V1_LEN 24 -#define RTRPDU_CACHE_RESET_LEN 8 -#define RTRPDU_ROUTER_KEY_LEN 123 +#define RTR_HDR_LEN 8u + +/* Please remember to update the MAX_LENs if you modify this list. */ +#define RTRPDU_SERIAL_NOTIFY_LEN 12u +#define RTRPDU_SERIAL_QUERY_LEN 12u +#define RTRPDU_RESET_QUERY_LEN 8u +#define RTRPDU_CACHE_RESPONSE_LEN 8u +#define RTRPDU_IPV4_PREFIX_LEN 20u +#define RTRPDU_IPV6_PREFIX_LEN 32u +#define RTRPDU_END_OF_DATA_V0_LEN 12u +#define RTRPDU_END_OF_DATA_V1_LEN 24u +#define RTRPDU_CACHE_RESET_LEN 8u +#define RTRPDU_ROUTER_KEY_LEN 123u +/* See rtrpdu_error_report_len() for the missing one. */ + +/* Except for Error Report PDUs. */ +#define RTRPDU_MAX_LEN RTRPDU_ROUTER_KEY_LEN +/* + * The length field is 32 bits. Error PDUs don't need to be that large. + * 1024 is arbitrary. + */ +#define RTRPDU_ERROR_REPORT_MAX_LEN 1024u -/* Ignores Error Report PDUs, which is fine. */ -#define RTRPDU_MAX_LEN RTRPDU_IPV6_PREFIX_LEN -#define RTRPDU_ERR_MAX_LEN 256 +#define RTRPDU_MAX_LEN2 RTRPDU_ERROR_REPORT_MAX_LEN struct pdu_header { - uint8_t protocol_version; - uint8_t pdu_type; + enum rtr_version version; + enum pdu_type type; union { - uint16_t session_id; - uint16_t reserved; - uint16_t error_code; + uint16_t session_id; + uint16_t reserved; + uint16_t error_code; } m; /* Note: "m" stands for "meh." I have no idea what to call this. */ - uint32_t length; -}; - -struct serial_notify_pdu { - struct pdu_header header; - uint32_t serial_number; + uint32_t length; }; struct serial_query_pdu { @@ -81,81 +82,22 @@ struct reset_query_pdu { struct pdu_header header; }; -struct cache_response_pdu { - struct pdu_header header; -}; - -struct ipv4_prefix_pdu { - struct pdu_header header; - uint8_t flags; - uint8_t prefix_length; - uint8_t max_length; - uint8_t zero; - struct in_addr ipv4_prefix; - uint32_t asn; -}; - -struct ipv6_prefix_pdu { - struct pdu_header header; - uint8_t flags; - uint8_t prefix_length; - uint8_t max_length; - uint8_t zero; - struct in6_addr ipv6_prefix; - uint32_t asn; -}; - -struct end_of_data_pdu { - struct pdu_header header; - uint32_t serial_number; - uint32_t refresh_interval; - uint32_t retry_interval; - uint32_t expire_interval; -}; - -struct cache_reset_pdu { - struct pdu_header header; -}; - -struct router_key_pdu { - struct pdu_header header; - unsigned char ski[RK_SKI_LEN]; - size_t ski_len; - uint32_t asn; - unsigned char spki[RK_SPKI_LEN]; - size_t spki_len; -}; - struct error_report_pdu { struct pdu_header header; - uint32_t error_pdu_length; - unsigned char erroneous_pdu[RTRPDU_ERR_MAX_LEN]; - uint32_t error_message_length; - rtr_char *error_message; -}; - -struct pdu_metadata { - size_t length; - /** - * Builds the PDU from @header, and the bytes remaining in the reader. - * - * Caller assumes that from_stream functions are only allowed to fail - * on programming errors. (Because failure results in an internal error - * response.) - */ - int (*from_stream)(struct pdu_header *, struct pdu_reader *, void *); - /** - * Handlers must return 0 to maintain the connection, nonzero to close - * the socket. - * Also, they are supposed to send error PDUs on discretion. - */ - int (*handle)(int, struct rtr_request const *); - void (*destructor)(void *); + uint32_t errpdu_len; + unsigned char errpdu[RTRPDU_MAX_LEN]; + uint32_t errmsg_len; + char *errmsg; }; -int pdu_load(struct pdu_reader *, struct rtr_client *, struct rtr_request *, - struct pdu_metadata const **); -struct pdu_metadata const *pdu_get_metadata(uint8_t); -struct pdu_header *pdu_get_header(void *); +static inline size_t +rtrpdu_error_report_len(uint32_t errpdu_len, uint32_t errmsg_len) +{ + return RTR_HDR_LEN + + 4 /* Length of Encapsulated PDU field */ + + errpdu_len + + 4 /* Length of Error Text field */ + + errmsg_len; +} #endif /* RTR_PDU_H_ */ diff --git a/src/rtr/pdu_handler.c b/src/rtr/pdu_handler.c index f7576d9b..e9277b18 100644 --- a/src/rtr/pdu_handler.c +++ b/src/rtr/pdu_handler.c @@ -1,24 +1,10 @@ #include "rtr/pdu_handler.h" #include - -#include "rtr/err_pdu.h" #include "log.h" -#include "rtr/pdu.h" +#include "rtr/err_pdu.h" #include "rtr/pdu_sender.h" -#include "rtr/db/vrps.h" - -#define WARN_UNEXPECTED_PDU(name, fd, request, pdu_name) \ - struct name##_pdu *pdu = request->pdu; \ - return err_pdu_send_invalid_request(fd, \ - pdu->header.protocol_version, \ - request, "Clients are not supposed to send " pdu_name " PDUs."); - -int -handle_serial_notify_pdu(int fd, struct rtr_request const *request) -{ - WARN_UNEXPECTED_PDU(serial_notify, fd, request, "Serial Notify"); -} +#include "rtr/pdu_stream.h" struct send_delta_args { int fd; @@ -70,13 +56,18 @@ send_delta_rk(struct delta_router_key const *delta, void *arg) } int -handle_serial_query_pdu(int fd, struct rtr_request const *request) +handle_serial_query_pdu(struct rtr_request *request, struct rtr_pdu *pdu) { - struct serial_query_pdu *query = request->pdu; + struct serial_query_pdu *sq; struct send_delta_args args; serial_t final_serial; int error; + sq = &pdu->obj.sq; + args.fd = request->fd; + args.rtr_version = sq->header.version; + args.cache_response_sent = false; + /* * RFC 6810 and 8210: * "If [...] either the router or the cache finds that the value of the @@ -84,14 +75,9 @@ handle_serial_query_pdu(int fd, struct rtr_request const *request) * the mismatch MUST immediately terminate the session with an Error * Report PDU with code 0 ("Corrupt Data")" */ - args.rtr_version = query->header.protocol_version; - if (query->header.m.session_id != - get_current_session_id(args.rtr_version)) - return err_pdu_send_corrupt_data(fd, args.rtr_version, request, - "Session ID doesn't match."); - - args.fd = fd; - args.cache_response_sent = false; + if (sq->header.m.session_id != get_current_session_id(args.rtr_version)) + return err_pdu_send_corrupt_data(args.fd, args.rtr_version, + &pdu->raw, "Session ID doesn't match."); /* * For the record, there are two reasons why we want to work on a @@ -102,7 +88,7 @@ handle_serial_query_pdu(int fd, struct rtr_request const *request) * PDUs, to minimize writer stagnation. */ - error = vrps_foreach_delta_since(query->serial_number, &final_serial, + error = vrps_foreach_delta_since(sq->serial_number, &final_serial, send_delta_vrp, send_delta_rk, &args); switch (error) { case 0: @@ -113,16 +99,18 @@ handle_serial_query_pdu(int fd, struct rtr_request const *request) * and programming errors. Best avoid error PDUs. */ if (!args.cache_response_sent) { - error = send_cache_response_pdu(fd, args.rtr_version); + error = send_cache_response_pdu(args.fd, + args.rtr_version); if (error) return error; } - return send_end_of_data_pdu(fd, args.rtr_version, final_serial); + return send_end_of_data_pdu(args.fd, args.rtr_version, + final_serial); case -EAGAIN: /* Database still under construction */ - return err_pdu_send_no_data_available(fd, args.rtr_version); + return err_pdu_send_no_data_available(args.fd, args.rtr_version); case -ESRCH: /* Invalid serial */ /* https://tools.ietf.org/html/rfc6810#section-6.3 */ - return send_cache_reset_pdu(fd, args.rtr_version); + return send_cache_reset_pdu(args.fd, args.rtr_version); case -ENOMEM: /* Memory allocation failure */ enomem_panic(); case EAGAIN: /* Too many threads */ @@ -133,7 +121,7 @@ handle_serial_query_pdu(int fd, struct rtr_request const *request) break; } - return err_pdu_send_internal_error(fd, args.rtr_version); + return err_pdu_send_internal_error(args.fd, args.rtr_version); } struct base_roa_args { @@ -176,25 +164,24 @@ send_base_router_key(struct router_key const *key, void *arg) } int -handle_reset_query_pdu(int fd, struct rtr_request const *request) +handle_reset_query_pdu(struct rtr_request *request, struct rtr_pdu *pdu) { - struct reset_query_pdu *pdu = request->pdu; struct base_roa_args args; serial_t current_serial; int error; args.started = false; - args.fd = fd; - args.version = pdu->header.protocol_version; + args.fd = request->fd; + args.version = pdu->obj.hdr.version; error = get_last_serial_number(¤t_serial); switch (error) { case 0: break; case -EAGAIN: - return err_pdu_send_no_data_available(fd, args.version); + return err_pdu_send_no_data_available(args.fd, args.version); default: - err_pdu_send_internal_error(fd, args.version); + err_pdu_send_internal_error(args.fd, args.version); return error; } @@ -213,73 +200,39 @@ handle_reset_query_pdu(int fd, struct rtr_request const *request) /* Assure that cache response is (or was) sent */ if (args.started) break; - error = send_cache_response_pdu(fd, args.version); + error = send_cache_response_pdu(args.fd, args.version); if (error) return error; break; case -EAGAIN: - return err_pdu_send_no_data_available(fd, args.version); + return err_pdu_send_no_data_available(args.fd, args.version); case EAGAIN: - err_pdu_send_internal_error(fd, args.version); + err_pdu_send_internal_error(args.fd, args.version); return error; default: /* Any other error must stop sending more PDUs */ return error; } - return send_end_of_data_pdu(fd, args.version, current_serial); + return send_end_of_data_pdu(args.fd, args.version, current_serial); } int -handle_cache_response_pdu(int fd, struct rtr_request const *request) +handle_error_report_pdu(struct rtr_request *request, struct rtr_pdu *pdu) { - WARN_UNEXPECTED_PDU(cache_response, fd, request, "Cache Response"); -} - -int -handle_ipv4_prefix_pdu(int fd, struct rtr_request const *request) -{ - WARN_UNEXPECTED_PDU(ipv4_prefix, fd, request, "IPv4 Prefix"); -} - -int -handle_ipv6_prefix_pdu(int fd, struct rtr_request const *request) -{ - WARN_UNEXPECTED_PDU(ipv6_prefix, fd, request, "IPv6 Prefix"); -} - -int -handle_end_of_data_pdu(int fd, struct rtr_request const *request) -{ - WARN_UNEXPECTED_PDU(end_of_data, fd, request, "End of Data"); -} - -int -handle_cache_reset_pdu(int fd, struct rtr_request const *request) -{ - WARN_UNEXPECTED_PDU(cache_reset, fd, request, "Cache Reset"); -} - -int -handle_router_key_pdu(int fd, struct rtr_request const *request) -{ - WARN_UNEXPECTED_PDU(router_key, fd, request, "Router Key"); -} - -int -handle_error_report_pdu(int fd, struct rtr_request const *request) -{ - struct error_report_pdu *received = request->pdu; + struct error_report_pdu *er; char const *error_name; - error_name = err_pdu_to_string(received->header.m.error_code); + er = &pdu->obj.er; + error_name = err_pdu_to_string(er->header.m.error_code); - if (received->error_message != NULL) - pr_op_info("Client responded with error PDU '%s' ('%s'). Closing socket.", - error_name, received->error_message); - else - pr_op_info("Client responded with error PDU '%s'. Closing socket.", - error_name); + if (er->errmsg != NULL) { + pr_op_info("RTR client %s responded with error PDU '%s' ('%s'). Closing socket.", + request->client_addr, error_name, er->errmsg); + } else { + pr_op_info("RTR client %s responded with error PDU '%s'. Closing socket.", + request->client_addr, error_name); + } return -EINVAL; } diff --git a/src/rtr/pdu_handler.h b/src/rtr/pdu_handler.h index c45b8f55..83b39b01 100644 --- a/src/rtr/pdu_handler.h +++ b/src/rtr/pdu_handler.h @@ -1,17 +1,10 @@ -#ifndef RTR_PDU_HANDLER_H_ -#define RTR_PDU_HANDLER_H_ +#ifndef SRC_RTR_PDU_HANDLER_H_ +#define SRC_RTR_PDU_HANDLER_H_ -#include "rtr/pdu.h" +#include "rtr/pdu_stream.h" -int handle_serial_notify_pdu(int, struct rtr_request const *); -int handle_serial_query_pdu(int, struct rtr_request const *); -int handle_reset_query_pdu(int, struct rtr_request const *); -int handle_cache_response_pdu(int, struct rtr_request const *); -int handle_ipv4_prefix_pdu(int, struct rtr_request const *); -int handle_ipv6_prefix_pdu(int, struct rtr_request const *); -int handle_end_of_data_pdu(int, struct rtr_request const *); -int handle_cache_reset_pdu(int, struct rtr_request const *); -int handle_router_key_pdu(int, struct rtr_request const *); -int handle_error_report_pdu(int, struct rtr_request const *); +int handle_serial_query_pdu(struct rtr_request *, struct rtr_pdu *); +int handle_reset_query_pdu(struct rtr_request *, struct rtr_pdu *); +int handle_error_report_pdu(struct rtr_request *, struct rtr_pdu *); -#endif /* RTR_PDU_HANDLER_H_ */ +#endif /* SRC_RTR_PDU_HANDLER_H_ */ diff --git a/src/rtr/pdu_sender.c b/src/rtr/pdu_sender.c index 563a10e3..9c9424dc 100644 --- a/src/rtr/pdu_sender.c +++ b/src/rtr/pdu_sender.c @@ -8,19 +8,18 @@ #include "common.h" #include "config.h" #include "log.h" -#include "rtr/pdu_serializer.h" #include "rtr/db/vrps.h" +#include "rtr/primitive_writer.h" -/* - * Set all the header values, EXCEPT length field. - */ -static void -set_header_values(struct pdu_header *header, uint8_t version, uint8_t type, - uint16_t reserved) +static unsigned char * +serialize_hdr(unsigned char *buf, uint8_t version, uint8_t type, + uint16_t m, uint32_t length) { - header->protocol_version = version; - header->pdu_type = type; - header->m.reserved = reserved; + buf = write_uint8(buf, version); + buf = write_uint8(buf, type); + buf = write_uint16(buf, m); + buf = write_uint32(buf, length); + return buf; } static int @@ -58,130 +57,80 @@ send_response(int fd, uint8_t pdu_type, unsigned char *data, size_t data_len) int send_serial_notify_pdu(int fd, uint8_t version, serial_t start_serial) { - struct serial_notify_pdu pdu; + static const uint8_t type = PDU_TYPE_SERIAL_NOTIFY; + static const uint32_t len = RTRPDU_SERIAL_NOTIFY_LEN; unsigned char data[RTRPDU_SERIAL_NOTIFY_LEN]; - size_t len; - - set_header_values(&pdu.header, version, PDU_TYPE_SERIAL_NOTIFY, - get_current_session_id(version)); - - pdu.serial_number = start_serial; - pdu.header.length = RTRPDU_SERIAL_NOTIFY_LEN; + unsigned char *buf; - len = serialize_serial_notify_pdu(&pdu, data); - if (len != RTRPDU_SERIAL_NOTIFY_LEN) - pr_crit("Serialized Serial Notify is %zu bytes.", len); + buf = serialize_hdr(data, version, type, + get_current_session_id(version), len); + buf = write_uint32(buf, start_serial); - return send_response(fd, pdu.header.pdu_type, data, len); + return send_response(fd, type, data, len); } int send_cache_reset_pdu(int fd, uint8_t version) { - struct cache_reset_pdu pdu; + static const uint8_t type = PDU_TYPE_CACHE_RESET; + static const uint32_t len = RTRPDU_CACHE_RESET_LEN; unsigned char data[RTRPDU_CACHE_RESET_LEN]; - size_t len; - - /* This PDU has only the header */ - set_header_values(&pdu.header, version, PDU_TYPE_CACHE_RESET, 0); - pdu.header.length = RTRPDU_CACHE_RESET_LEN; - len = serialize_cache_reset_pdu(&pdu, data); - if (len != RTRPDU_CACHE_RESET_LEN) - pr_crit("Serialized Cache Reset is %zu bytes.", len); + serialize_hdr(data, version, type, 0, len); - return send_response(fd, pdu.header.pdu_type, data, len); + return send_response(fd, type, data, len); } int send_cache_response_pdu(int fd, uint8_t version) { - struct cache_response_pdu pdu; + static const uint8_t type = PDU_TYPE_CACHE_RESPONSE; + static const uint32_t len = RTRPDU_CACHE_RESPONSE_LEN; unsigned char data[RTRPDU_CACHE_RESPONSE_LEN]; - size_t len; - - /* This PDU has only the header */ - set_header_values(&pdu.header, version, PDU_TYPE_CACHE_RESPONSE, - get_current_session_id(version)); - pdu.header.length = RTRPDU_CACHE_RESPONSE_LEN; - - len = serialize_cache_response_pdu(&pdu, data); - if (len != RTRPDU_CACHE_RESPONSE_LEN) - pr_crit("Serialized Cache Response is %zu bytes.", len); - - return send_response(fd, pdu.header.pdu_type, data, len); -} -static void -pr_debug_prefix4(struct ipv4_prefix_pdu *pdu) -{ - char buffer[INET_ADDRSTRLEN]; + serialize_hdr(data, version, type, get_current_session_id(version), len); - pr_op_debug("Encoded prefix %s/%u into a PDU.", - addr2str4(&pdu->ipv4_prefix, buffer), pdu->prefix_length); + return send_response(fd, type, data, len); } static int send_ipv4_prefix_pdu(int fd, uint8_t version, struct vrp const *vrp, uint8_t flags) { - struct ipv4_prefix_pdu pdu; + static const uint8_t type = PDU_TYPE_IPV4_PREFIX; + static const uint32_t len = RTRPDU_IPV4_PREFIX_LEN; unsigned char data[RTRPDU_IPV4_PREFIX_LEN]; - size_t len; - - set_header_values(&pdu.header, version, PDU_TYPE_IPV4_PREFIX, 0); - pdu.header.length = RTRPDU_IPV4_PREFIX_LEN; - - pdu.flags = flags; - pdu.prefix_length = vrp->prefix_length; - pdu.max_length = vrp->max_prefix_length; - pdu.zero = 0; - pdu.ipv4_prefix = vrp->prefix.v4; - pdu.asn = vrp->asn; - - len = serialize_ipv4_prefix_pdu(&pdu, data); - if (len != RTRPDU_IPV4_PREFIX_LEN) - pr_crit("Serialized IPv4 Prefix is %zu bytes.", len); - if (log_op_enabled(LOG_DEBUG)) - pr_debug_prefix4(&pdu); + unsigned char *buf; - return send_response(fd, pdu.header.pdu_type, data, len); -} - -static void -pr_debug_prefix6(struct ipv6_prefix_pdu *pdu) -{ - char buffer[INET6_ADDRSTRLEN]; + buf = serialize_hdr(data, version, type, 0, len); + buf = write_uint8(buf, flags); + buf = write_uint8(buf, vrp->prefix_length); + buf = write_uint8(buf, vrp->max_prefix_length); + buf = write_uint8(buf, 0); + buf = write_in_addr(buf, vrp->prefix.v4); + buf = write_uint32(buf, vrp->asn); - pr_op_debug("Encoded prefix %s/%u into a PDU.", - addr2str6(&pdu->ipv6_prefix, buffer), pdu->prefix_length); + return send_response(fd, type, data, len); } static int send_ipv6_prefix_pdu(int fd, uint8_t version, struct vrp const *vrp, uint8_t flags) { - struct ipv6_prefix_pdu pdu; + static const uint8_t type = PDU_TYPE_IPV6_PREFIX; + static const uint32_t len = RTRPDU_IPV6_PREFIX_LEN; unsigned char data[RTRPDU_IPV6_PREFIX_LEN]; - size_t len; - - set_header_values(&pdu.header, version, PDU_TYPE_IPV6_PREFIX, 0); - pdu.header.length = RTRPDU_IPV6_PREFIX_LEN; + unsigned char *buf; - pdu.flags = flags; - pdu.prefix_length = vrp->prefix_length; - pdu.max_length = vrp->max_prefix_length; - pdu.zero = 0; - pdu.ipv6_prefix = vrp->prefix.v6; - pdu.asn = vrp->asn; + buf = serialize_hdr(data, version, PDU_TYPE_IPV6_PREFIX, 0, len); + buf = write_uint8(buf, flags); + buf = write_uint8(buf, vrp->prefix_length); + buf = write_uint8(buf, vrp->max_prefix_length); + buf = write_uint8(buf, 0); + buf = write_in6_addr(buf, &vrp->prefix.v6); + buf = write_uint32(buf, vrp->asn); - len = serialize_ipv6_prefix_pdu(&pdu, data); - if (len != RTRPDU_IPV6_PREFIX_LEN) - pr_crit("Serialized IPv6 Prefix is %zu bytes.", len); - if (log_op_enabled(LOG_DEBUG)) - pr_debug_prefix6(&pdu); - - return send_response(fd, pdu.header.pdu_type, data, len); + return send_response(fd, type, data, len); } int @@ -201,33 +150,22 @@ int send_router_key_pdu(int fd, uint8_t version, struct router_key const *router_key, uint8_t flags) { - struct router_key_pdu pdu; + static const uint8_t type = PDU_TYPE_ROUTER_KEY; + static const uint32_t len = RTRPDU_ROUTER_KEY_LEN; unsigned char data[RTRPDU_ROUTER_KEY_LEN]; - size_t len; - uint16_t reserved; + unsigned char *buf; - /* Sanity check: this can't be sent on RTRv0 */ if (version == RTR_V0) return 0; - reserved = 0; - /* Set the flags at the first 8 bits of reserved field */ - reserved += (flags << 8); - set_header_values(&pdu.header, version, PDU_TYPE_ROUTER_KEY, reserved); - pdu.header.length = RTRPDU_ROUTER_KEY_LEN; - - memcpy(pdu.ski, router_key->ski, RK_SKI_LEN); - pdu.ski_len = RK_SKI_LEN; - pdu.asn = router_key->as; - memcpy(pdu.spki, router_key->spk, RK_SPKI_LEN); - pdu.spki_len = RK_SPKI_LEN; + buf = serialize_hdr(data, version, type, flags << 8, len); + memcpy(buf, router_key->ski, sizeof(router_key->ski)); + buf += sizeof(router_key->ski); + buf = write_uint32(buf, router_key->as); + memcpy(buf, router_key->spk, sizeof(router_key->spk)); + buf += sizeof(router_key->spk); - len = serialize_router_key_pdu(&pdu, data); - if (len != RTRPDU_ROUTER_KEY_LEN) - pr_crit("Serialized Router Key PDU is %zu bytes, not the expected %u.", - len, pdu.header.length); - - return send_response(fd, pdu.header.pdu_type, data, len); + return send_response(fd, type, data, len); } #define MAX(a, b) ((a > b) ? a : b) @@ -235,70 +173,74 @@ send_router_key_pdu(int fd, uint8_t version, int send_end_of_data_pdu(int fd, uint8_t version, serial_t end_serial) { - struct end_of_data_pdu pdu; - unsigned char data[MAX( - RTRPDU_END_OF_DATA_V1_LEN, RTRPDU_END_OF_DATA_V0_LEN - )]; - size_t len; - - set_header_values(&pdu.header, version, PDU_TYPE_END_OF_DATA, - get_current_session_id(version)); - - pdu.serial_number = end_serial; + static const uint8_t type = PDU_TYPE_ROUTER_KEY; + unsigned char data[ + MAX(RTRPDU_END_OF_DATA_V1_LEN, RTRPDU_END_OF_DATA_V0_LEN) + ]; + unsigned char *buf; + uint32_t len; + + len = (version == RTR_V1) + ? RTRPDU_END_OF_DATA_V1_LEN + : RTRPDU_END_OF_DATA_V0_LEN; + buf = serialize_hdr(data, version, type, + get_current_session_id(version), len); + + buf = write_uint32(buf, end_serial); if (version == RTR_V1) { - pdu.header.length = RTRPDU_END_OF_DATA_V1_LEN; - pdu.refresh_interval = config_get_interval_refresh(); - pdu.retry_interval = config_get_interval_retry(); - pdu.expire_interval = config_get_interval_expire(); - } else { - pdu.header.length = RTRPDU_END_OF_DATA_V0_LEN; + buf = write_uint32(buf, config_get_interval_refresh()); + buf = write_uint32(buf, config_get_interval_retry()); + buf = write_uint32(buf, config_get_interval_expire()); } - len = serialize_end_of_data_pdu(&pdu, data); - if (len != pdu.header.length) - pr_crit("Serialized End of Data is %zu bytes.", len); + return send_response(fd, type, data, len); +} + +static size_t +compute_error_pdu_len(struct rtr_buffer const *request) +{ + unsigned int result; - return send_response(fd, pdu.header.pdu_type, data, len); + if (request == NULL || request->bytes_len < RTR_HDR_LEN) + return 0; + + result = (((unsigned int)(request->bytes[4])) << 24) + | (((unsigned int)(request->bytes[5])) << 16) + | (((unsigned int)(request->bytes[6])) << 8) + | (((unsigned int)(request->bytes[7])) ); + + return (result <= RTRPDU_MAX_LEN) ? result : RTRPDU_MAX_LEN; } int send_error_report_pdu(int fd, uint8_t version, uint16_t code, - struct rtr_request const *request, char *message) + struct rtr_buffer const *request, char *message) { - struct error_report_pdu pdu; - unsigned char *data; + static const uint8_t type = PDU_TYPE_ERROR_REPORT; + unsigned char *data, *buf; + size_t error_pdu_len; + size_t error_msg_len; size_t len; int error; - set_header_values(&pdu.header, version, PDU_TYPE_ERROR_REPORT, code); + error_pdu_len = compute_error_pdu_len(request); + error_msg_len = (message != NULL) ? strlen(message) : 0; + len = rtrpdu_error_report_len(error_pdu_len, error_msg_len); + data = pmalloc(len); - if (request != NULL) { - pdu.error_pdu_length = (request->bytes_len > RTRPDU_MAX_LEN) - ? RTRPDU_MAX_LEN - : request->bytes_len; - memcpy(pdu.erroneous_pdu, request->bytes, pdu.error_pdu_length); - } else { - pdu.error_pdu_length = 0; + buf = serialize_hdr(data, version, type, 0, len); + buf = write_uint32(buf, error_pdu_len); + if (error_pdu_len > 0) { + memcpy(buf, request->bytes, error_pdu_len); + buf += error_pdu_len; + } + buf = write_uint32(buf, error_msg_len); + if (error_msg_len > 0) { + memcpy(buf, message, error_msg_len); + buf += error_msg_len; } - pdu.error_message_length = (message != NULL) ? strlen(message) : 0; - pdu.error_message = message; - - pdu.header.length = RTRPDU_HDR_LEN - + 4 /* Length of Encapsulated PDU field */ - + pdu.error_pdu_length - + 4 /* Length of Error Text field */ - + pdu.error_message_length; - - data = pmalloc(pdu.header.length); - - len = serialize_error_report_pdu(&pdu, data); - if (len != pdu.header.length) - pr_crit("Serialized Error Report PDU is %zu bytes, not the expected %u.", - len, pdu.header.length); - - error = send_response(fd, pdu.header.pdu_type, data, len); - + error = send_response(fd, type, data, len); free(data); return error; } diff --git a/src/rtr/pdu_sender.h b/src/rtr/pdu_sender.h index ad91531d..cecb95d7 100644 --- a/src/rtr/pdu_sender.h +++ b/src/rtr/pdu_sender.h @@ -11,7 +11,7 @@ int send_cache_response_pdu(int, uint8_t); int send_prefix_pdu(int, uint8_t, struct vrp const *, uint8_t); int send_router_key_pdu(int, uint8_t, struct router_key const *, uint8_t); int send_end_of_data_pdu(int, uint8_t, serial_t); -int send_error_report_pdu(int, uint8_t, uint16_t, struct rtr_request const *, +int send_error_report_pdu(int, uint8_t, uint16_t, struct rtr_buffer const *, char *); #endif /* SRC_RTR_PDU_SENDER_H_ */ diff --git a/src/rtr/pdu_serializer.c b/src/rtr/pdu_serializer.c deleted file mode 100644 index 9eb40efc..00000000 --- a/src/rtr/pdu_serializer.c +++ /dev/null @@ -1,161 +0,0 @@ -#include "rtr/pdu_serializer.h" - -#include "rtr/primitive_writer.h" - -static size_t -serialize_pdu_header(struct pdu_header const *header, uint16_t union_value, - unsigned char *buf) -{ - unsigned char *ptr; - - ptr = buf; - ptr = write_int8(ptr, header->protocol_version); - ptr = write_int8(ptr, header->pdu_type); - ptr = write_int16(ptr, union_value); - ptr = write_int32(ptr, header->length); - - return ptr - buf; -} - -size_t -serialize_serial_notify_pdu(struct serial_notify_pdu *pdu, unsigned char *buf) -{ - size_t head_size; - unsigned char *ptr; - - head_size = serialize_pdu_header(&pdu->header, pdu->header.m.session_id, - buf); - - ptr = buf + head_size; - ptr = write_int32(ptr, pdu->serial_number); - - return ptr - buf; -} - -size_t -serialize_cache_response_pdu(struct cache_response_pdu *pdu, - unsigned char *buf) -{ - /* No payload to serialize */ - return serialize_pdu_header(&pdu->header, pdu->header.m.session_id, - buf); -} - -size_t -serialize_ipv4_prefix_pdu(struct ipv4_prefix_pdu *pdu, unsigned char *buf) -{ - size_t head_size; - unsigned char *ptr; - - head_size = serialize_pdu_header(&pdu->header, pdu->header.m.reserved, - buf); - - ptr = buf + head_size; - ptr = write_int8(ptr, pdu->flags); - ptr = write_int8(ptr, pdu->prefix_length); - ptr = write_int8(ptr, pdu->max_length); - ptr = write_int8(ptr, pdu->zero); - ptr = write_in_addr(ptr, pdu->ipv4_prefix); - ptr = write_int32(ptr, pdu->asn); - - return ptr - buf; -} - -size_t -serialize_ipv6_prefix_pdu(struct ipv6_prefix_pdu *pdu, unsigned char *buf) -{ - size_t head_size; - unsigned char *ptr; - - head_size = serialize_pdu_header(&pdu->header, pdu->header.m.reserved, - buf); - - ptr = buf + head_size; - ptr = write_int8(ptr, pdu->flags); - ptr = write_int8(ptr, pdu->prefix_length); - ptr = write_int8(ptr, pdu->max_length); - ptr = write_int8(ptr, pdu->zero); - ptr = write_in6_addr(ptr, pdu->ipv6_prefix); - ptr = write_int32(ptr, pdu->asn); - - return ptr - buf; -} - -size_t -serialize_end_of_data_pdu(struct end_of_data_pdu const *pdu, unsigned char *buf) -{ - size_t head_size; - unsigned char *ptr; - - head_size = serialize_pdu_header(&pdu->header, pdu->header.m.session_id, - buf); - - ptr = buf + head_size; - ptr = write_int32(ptr, pdu->serial_number); - if (pdu->header.protocol_version == RTR_V1) { - ptr = write_int32(ptr, pdu->refresh_interval); - ptr = write_int32(ptr, pdu->retry_interval); - ptr = write_int32(ptr, pdu->expire_interval); - } - - return ptr - buf; -} - -size_t -serialize_cache_reset_pdu(struct cache_reset_pdu *pdu, unsigned char *buf) -{ - /* No payload to serialize */ - return serialize_pdu_header(&pdu->header, pdu->header.m.reserved, buf); -} - -/* - * Don't forget to use 'header->reserved' to set flags - */ -size_t -serialize_router_key_pdu(struct router_key_pdu *pdu, unsigned char *buf) -{ - size_t head_size; - unsigned char *ptr; - int i; - - if (pdu->header.protocol_version == RTR_V0) - return 0; - - head_size = serialize_pdu_header(&pdu->header, pdu->header.m.reserved, - buf); - - ptr = buf + head_size; - - for (i = 0; i < pdu->ski_len; i++) - ptr = write_int8(ptr, pdu->ski[i]); - - ptr = write_int32(ptr, pdu->asn); - - for (i = 0; i < pdu->spki_len; i++) - ptr = write_int8(ptr, pdu->spki[i]); - - return ptr - buf; -} - -size_t -serialize_error_report_pdu(struct error_report_pdu *pdu, unsigned char *buf) -{ - unsigned char *ptr; - - ptr = buf; - ptr += serialize_pdu_header(&pdu->header, pdu->header.m.error_code, buf); - - ptr = write_int32(ptr, pdu->error_pdu_length); - if (pdu->error_pdu_length > 0) { - memcpy(ptr, pdu->erroneous_pdu, pdu->error_pdu_length); - ptr += pdu->error_pdu_length; - } - - ptr = write_int32(ptr, pdu->error_message_length); - if (pdu->error_message_length > 0) { - memcpy(ptr, pdu->error_message, pdu->error_message_length); - ptr += pdu->error_message_length; - } - - return ptr - buf; -} diff --git a/src/rtr/pdu_serializer.h b/src/rtr/pdu_serializer.h deleted file mode 100644 index 6b157a5d..00000000 --- a/src/rtr/pdu_serializer.h +++ /dev/null @@ -1,18 +0,0 @@ -#ifndef SRC_RTR_PDU_SERIALIZER_H_ -#define SRC_RTR_PDU_SERIALIZER_H_ - -#include "rtr/pdu.h" - -size_t serialize_serial_notify_pdu(struct serial_notify_pdu *, - unsigned char *); -size_t serialize_cache_response_pdu(struct cache_response_pdu *, - unsigned char *); -size_t serialize_ipv4_prefix_pdu(struct ipv4_prefix_pdu *, unsigned char *); -size_t serialize_ipv6_prefix_pdu(struct ipv6_prefix_pdu *, unsigned char *); -size_t serialize_end_of_data_pdu(struct end_of_data_pdu const *, - unsigned char *); -size_t serialize_cache_reset_pdu(struct cache_reset_pdu *, unsigned char *); -size_t serialize_router_key_pdu(struct router_key_pdu *, unsigned char *); -size_t serialize_error_report_pdu(struct error_report_pdu *, unsigned char *); - -#endif /* SRC_RTR_PDU_SERIALIZER_H_ */ diff --git a/src/rtr/pdu_stream.c b/src/rtr/pdu_stream.c new file mode 100644 index 00000000..b2fb6090 --- /dev/null +++ b/src/rtr/pdu_stream.c @@ -0,0 +1,549 @@ +#include "rtr/pdu_stream.h" + +#include +#include + +#include "log.h" +#include "alloc.h" +#include "rtr/pdu.h" +#include "rtr/err_pdu.h" + +enum buffer_state { + /* We've read all available bytes for now. */ + BS_WOULD_BLOCK, + /* "End of Stream." We've read all available bytes, ever. */ + BS_EOS, + /* read() still has more data to yield (but buffer is full for now). */ + BS_KEEP_READING, + /* Communication broken. */ + BS_ERROR, +}; + +struct pdu_stream { + int fd; + char addr[INET6_ADDRSTRLEN]; /* Printable address of the client. */ + int rtr_version; /* -1: unset; > 0: version number */ + + unsigned char buffer[RTRPDU_MAX_LEN2]; + + /* buffer's active bytes */ + unsigned char *start; + unsigned char *end; +}; + +struct pdu_stream *pdustream_create(int fd, char const *addr) +{ + struct pdu_stream *result; + + result = pmalloc(sizeof(struct pdu_stream)); + result->fd = fd; + strcpy(result->addr, addr); + result->rtr_version = -1; + result->start = result->buffer; + result->end = result->buffer; + + return result; +} + +void +pdustream_destroy(struct pdu_stream **_stream) +{ + struct pdu_stream *stream = *_stream; + close(stream->fd); + free(stream); +} + +static size_t +get_length(struct pdu_stream *stream) +{ + return stream->end - stream->start; +} + +/* + * Will read whatever's in the stream without blocking, but not more than + * RTRPDU_MAX_LEN2 bytes. + * + * It might read more than one PDU into the buffer, and extremely unlikely, + * the last PDU might be incomplete (even if it's the only one). + * + * Returns + * - true: success. + * - false: oh noes; close socket. + */ +static enum buffer_state +update_buffer(struct pdu_stream *in /* "in"put stream */) +{ + ssize_t consumed; + int error; + + /* Move leftover bytes to the beginning */ + if (in->buffer != in->start) { + if (in->start != in->end) + memmove(in->buffer, in->start, get_length(in)); + in->end -= in->start - in->buffer; + in->start = in->buffer; + } + + for (; in->end < in->start + RTRPDU_MAX_LEN2; in->end += consumed) { + consumed = read(in->fd, in->end, RTRPDU_MAX_LEN2 - get_length(in)); + if (consumed == -1) { + error = errno; + if (error == EAGAIN || error == EWOULDBLOCK) + return BS_WOULD_BLOCK; + + pr_op_err("Client socket read interrupted: %s", + strerror(error)); + return BS_ERROR; + } + + if (consumed == 0) { + pr_op_debug("Client closed the socket."); + return BS_EOS; + } + } + + /* + * We might or might not have read everything, but we have at least one + * big PDU that either lengths exactly RTRPDU_MAX_LEN2, or is too big + * for us to to allow it. + */ + return BS_KEEP_READING; +} + +static uint16_t +read_uint16(unsigned char *buffer) +{ + return (((uint16_t)buffer[0]) << 8) + | (((uint16_t)buffer[1]) ); +} + +static uint32_t +read_uint32(unsigned char *buffer) +{ + return (((uint32_t)buffer[0]) << 24) + | (((uint32_t)buffer[1]) << 16) + | (((uint32_t)buffer[2]) << 8) + | (((uint32_t)buffer[3]) ); +} + +#define EINVALID_UTF8 -0xFFFF + +/* + * Returns the length (in octets) of the UTF-8 code point that starts with + * octet @first_octet. + */ +static int +get_octets(unsigned char first_octet) +{ + if ((first_octet & 0x80) == 0) + return 1; + if ((first_octet >> 5) == 6) /* 0b110 */ + return 2; + if ((first_octet >> 4) == 14) /* 0b1110 */ + return 3; + if ((first_octet >> 3) == 30) /* 0b11110 */ + return 4; + return EINVALID_UTF8; +} + +/* This is just a cast. The barebones version is too cluttered. */ +#define UCHAR(c) ((unsigned char *)c) + +/* + * This also sanitizes the string, BTW. + * (Because it overrides the first invalid character with the null chara. + * The rest is silently ignored.) + */ +static void +place_null_character(char *str, size_t len) +{ + char *null_chara_pos; + char *cursor; + int octet; + int octets; + + /* + * This could be optimized by noticing that all byte continuations in + * UTF-8 start with 0b10. This means that we could start from the end + * of the string and move left until we find a valid character. + * But if we do that, we'd lose the sanitization. So this is better + * methinks. + */ + + null_chara_pos = str; + cursor = str; + + while (cursor < str + len) { + octets = get_octets(*UCHAR(cursor)); + if (octets == EINVALID_UTF8) + break; + cursor++; + + for (octet = 1; octet < octets; octet++) { + /* Memory ends in the middle of this code point? */ + if (cursor >= str + len) + goto end; + /* All continuation octets must begin with 0b10. */ + if ((*(UCHAR(cursor)) >> 6) != 2 /* 0b10 */) + goto end; + cursor++; + } + + null_chara_pos = cursor; + } + +end: + *null_chara_pos = '\0'; +} + +static char * +read_string(struct pdu_stream *stream, uint32_t len) +{ + char *string; + + if (len == 0) + return NULL; + + string = pmalloc(len + 1); + memcpy(string, stream->start, len); + place_null_character(string, len); + + return string; +} + +static void +read_hdr(struct pdu_stream *stream, struct pdu_header *header) +{ + header->version = stream->start[0]; + header->type = stream->start[1]; + header->m.reserved = read_uint16(stream->start + 2); + header->length = read_uint32(stream->start + 4); +} + +static int +validate_rtr_version(struct pdu_stream *stream, struct pdu_header *header, + struct rtr_buffer *request) +{ + switch (stream->rtr_version) { + case RTR_V1: + switch (header->version) { + case RTR_V0: + goto unexpected; + case RTR_V1: + return 0; + default: + goto unsupported; + } + + case RTR_V0: + switch (header->version) { + case RTR_V0: + return 0; + case RTR_V1: + goto unexpected; + default: + goto unsupported; + } + + case -1: + switch (header->version) { + case RTR_V0: + case RTR_V1: + stream->rtr_version = header->version; + return 0; + default: + goto unsupported; + } + } + + pr_crit("Unknown RTR version %u", stream->rtr_version); + +unsupported: + return err_pdu_send_unsupported_proto_version( + stream->fd, stream->rtr_version, request, + "The maximum supported RTR version is 1." + ); + +unexpected: + return err_pdu_send_unexpected_proto_version( + stream->fd, stream->rtr_version, request, + "The RTR version does not match the one we negotiated during the handshake." + ); +} + +static int +load_serial_query(struct pdu_stream *stream, struct pdu_header *hdr, + struct rtr_pdu *result) +{ + if (hdr->length != RTRPDU_SERIAL_QUERY_LEN) { + return err_pdu_send_invalid_request( + stream->fd, stream->rtr_version, &result->raw, + "Expected length 12 for Serial Query PDUs." + ); + } + if (get_length(stream) < RTRPDU_SERIAL_QUERY_LEN) + return EAGAIN; + + pr_op_debug("Received a Serial Query from %s.", stream->addr); + + memcpy(&result->obj.sq.header, hdr, sizeof(*hdr)); + stream->start += RTR_HDR_LEN; + result->obj.sq.serial_number = read_uint32(stream->start); + stream->start += 4; + + return 0; +} + +static int +load_reset_query(struct pdu_stream *stream, struct pdu_header *hdr, + struct rtr_pdu *result) +{ + if (hdr->length != RTRPDU_RESET_QUERY_LEN) { + return err_pdu_send_invalid_request( + stream->fd, stream->rtr_version, &result->raw, + "Expected length 8 for Reset Query PDUs." + ); + } + if (get_length(stream) < RTRPDU_RESET_QUERY_LEN) + return EAGAIN; + + pr_op_debug("Received a Reset Query from %s.", stream->addr); + + memcpy(&result->obj.rq.header, hdr, sizeof(*hdr)); + stream->start += RTR_HDR_LEN; + + return 0; +} + +static int +load_error_report(struct pdu_stream *stream, struct pdu_header *hdr, + struct rtr_pdu *result) +{ + struct error_report_pdu *pdu; + int error; + + if (hdr->length > RTRPDU_ERROR_REPORT_MAX_LEN) { + return pr_op_err( + "RTR client %s sent a large Error Report PDU (%u bytes). This looks broken, so I'm dropping the connection.", + stream->addr, hdr->length + ); + } + + pr_op_debug("Received an Error Report from %s.", stream->addr); + + pdu = &result->obj.er; + + /* Header */ + memcpy(&pdu->header, hdr, sizeof(*hdr)); + stream->start += RTR_HDR_LEN; + + /* Error PDU length */ + if (get_length(stream) < 4) { + error = EAGAIN; + goto revert_hdr; + } + pdu->errpdu_len = read_uint32(stream->start); + stream->start += 4; + if (pdu->errpdu_len > RTRPDU_MAX_LEN) { + /* + * We truncate PDUs larger than RTRPDU_MAX_LEN, so we couldn't + * have sent this PDU. Looks like someone is messing with us. + */ + error = pr_op_err( + "RTR client %s sent an Error Report PDU containing a large error PDU (%u bytes). This looks broken/insecure; I'm dropping the connection.", + stream->addr, pdu->errpdu_len + ); + goto revert_errpdu_len; + } + + /* Error PDU */ + if (get_length(stream) < pdu->errpdu_len) { + error = EAGAIN; + goto revert_errpdu_len; + } + + memcpy(pdu->errpdu, stream->start, pdu->errpdu_len); + stream->start += pdu->errpdu_len; + + /* Error msg length */ + if (get_length(stream) < 4) { + error = EAGAIN; + goto revert_errpdu; + } + pdu->errmsg_len = read_uint32(stream->start); + stream->start += 4; + if (hdr->length != rtrpdu_error_report_len(pdu->errpdu_len, pdu->errmsg_len)) { + error = pr_op_err( + "RTR client %s sent a malformed Error Report PDU; header length is %u, but effective length is %u + %u + %u + %u + %u.", + stream->addr, hdr->length, + RTR_HDR_LEN, 4, pdu->errpdu_len, 4, pdu->errmsg_len + ); + goto revert_errmsg_len; + } + + /* Error msg */ + pdu->errmsg = read_string(stream, pdu->errmsg_len); + stream->start += pdu->errmsg_len; + + return 0; + +revert_errmsg_len: + stream->start -= 4; +revert_errpdu: + stream->start -= pdu->errpdu_len; +revert_errpdu_len: + stream->start -= 4; +revert_hdr: + stream->start -= RTR_HDR_LEN; + return error; +} + +/* + * Returns: + * == 0: Success; at least zero PDUs read. + * != 0: Communication broken; close the connection. + */ +int +pdustream_next(struct pdu_stream *stream, struct rtr_request **_result) +{ + enum buffer_state state; + struct pdu_header hdr; + struct rtr_request *result; + struct rtr_pdu *pdu; + size_t remainder; + int error; + + result = pmalloc(sizeof(struct rtr_request)); + result->fd = stream->fd; + strcpy(result->client_addr, stream->addr); + STAILQ_INIT(&result->pdus); + result->eos = false; + + pdu = NULL; + +again: + state = update_buffer(stream); + if (state == BS_ERROR) { + error = EINVAL; + goto fail; + } + + while (stream->start < stream->end) { + remainder = get_length(stream); + + /* Read header. */ + if (remainder < RTR_HDR_LEN) + break; + read_hdr(stream, &hdr); + + /* Init raw PDU; Needed early because of error responses. */ + pdu = pzalloc(sizeof(struct rtr_pdu)); + pdu->raw.bytes_len = (hdr.length <= remainder) + ? hdr.length : remainder; + pdu->raw.bytes = pmalloc(pdu->raw.bytes_len); + memcpy(pdu->raw.bytes, stream->start, pdu->raw.bytes_len); + + /* Validate length; Needs raw. */ + if (hdr.length > RTRPDU_MAX_LEN2) { + error = err_pdu_send_invalid_request( + stream->fd, + (stream->rtr_version != -1) + ? stream->rtr_version + : hdr.version, + &pdu->raw, + "PDU is too large." + ); + goto fail; + } + + if (remainder < hdr.length) { + free(pdu->raw.bytes); + free(pdu); + break; + } + + /* Validate version; Needs raw. */ + error = validate_rtr_version(stream, &hdr, &pdu->raw); + if (error) + goto fail; + + switch (hdr.type) { + case PDU_TYPE_SERIAL_QUERY: + error = load_serial_query(stream, &hdr, pdu); + break; + case PDU_TYPE_RESET_QUERY: + error = load_reset_query(stream, &hdr, pdu); + break; + case PDU_TYPE_ERROR_REPORT: + error = load_error_report(stream, &hdr, pdu); + break; + default: + err_pdu_send_unsupported_pdu_type(stream->fd, + stream->rtr_version, &pdu->raw); + error = ENOTSUP; + } + + if (error) + goto fail; + + STAILQ_INSERT_TAIL(&result->pdus, pdu, hook); + } + + *_result = result; + + switch (state) { + case BS_WOULD_BLOCK: + result->eos = false; + return 0; + case BS_EOS: + result->eos = true; + return 0; + case BS_KEEP_READING: + goto again; + default: + error = EINVAL; + } + +fail: + if (pdu != NULL) { + free(pdu->raw.bytes); + free(pdu); + } + rtreq_destroy(result); + return error; +} + +int +pdustream_fd(struct pdu_stream *stream) +{ + return stream->fd; +} + +char const * +pdustream_addr(struct pdu_stream *stream) +{ + return stream->addr; +} + +int +pdustream_version(struct pdu_stream *stream) +{ + return stream->rtr_version; +} + +void +rtreq_destroy(struct rtr_request *request) +{ + struct rtr_pdu *pdu; + + while (!STAILQ_EMPTY(&request->pdus)) { + pdu = STAILQ_FIRST(&request->pdus); + STAILQ_REMOVE_HEAD(&request->pdus, hook); + + if (pdu->obj.hdr.type == PDU_TYPE_ERROR_REPORT) + free(pdu->obj.er.errmsg); + free(pdu->raw.bytes); + free(pdu); + } +} + diff --git a/src/rtr/pdu_stream.h b/src/rtr/pdu_stream.h new file mode 100644 index 00000000..db66ea23 --- /dev/null +++ b/src/rtr/pdu_stream.h @@ -0,0 +1,53 @@ +#ifndef SRC_RTR_PDU_STREAM_H_ +#define SRC_RTR_PDU_STREAM_H_ + +#include + +#include "rtr/pdu.h" +#include "rtr/rtr.h" +#include "data_structure/array_list.h" + +struct pdu_stream; /* It's an *input* stream. */ + +struct rtr_pdu { + /* Deserialized version */ + union { + struct pdu_header hdr; + struct serial_query_pdu sq; + struct reset_query_pdu rq; + struct error_report_pdu er; + } obj; + + /* + * Serialized version. + * Can be truncated; use for responding errors only. + */ + struct rtr_buffer raw; + + STAILQ_ENTRY(rtr_pdu) hook; +}; + +struct rtr_request { + int fd; + char client_addr[INET6_ADDRSTRLEN]; + + /* + * It's not sensible for a request to contain multiple PDUs, + * but I don't know how much buffering the underlying socket has. + */ + STAILQ_HEAD(, rtr_pdu) pdus; + + bool eos; /* end of stream */ +}; + +struct pdu_stream *pdustream_create(int, char const *); +void pdustream_destroy(struct pdu_stream **); + +int pdustream_next(struct pdu_stream *, struct rtr_request **); +int pdustream_fd(struct pdu_stream *); +char const *pdustream_addr(struct pdu_stream *); +int pdustream_version(struct pdu_stream *); + +void rtreq_destroy(struct rtr_request *); + +#endif /* SRC_RTR_PDU_STREAM_H_ */ diff --git a/src/rtr/primitive_reader.c b/src/rtr/primitive_reader.c deleted file mode 100644 index 2ab92d1f..00000000 --- a/src/rtr/primitive_reader.c +++ /dev/null @@ -1,209 +0,0 @@ -#include "rtr/primitive_reader.h" - -#include - -#include "alloc.h" -#include "log.h" - -static int get_octets(unsigned char); -static void place_null_character(rtr_char *, size_t); - -/** - * BTW: I think it's best not to use sizeof for @size, because it risks - * including padding. - */ -void -pdu_reader_init(struct pdu_reader *reader, unsigned char *buffer, size_t size) -{ - reader->buffer = buffer; - reader->size = size; -} - -static int -insufficient_bytes(void) -{ - pr_op_debug("Attempted to read past the end of a PDU Reader."); - return -EPIPE; -} - -int -read_int8(struct pdu_reader *reader, uint8_t *result) -{ - if (reader->size < 1) - return insufficient_bytes(); - - *result = reader->buffer[0]; - reader->buffer++; - reader->size--; - return 0; -} - -/** Big Endian. */ -int -read_int16(struct pdu_reader *reader, uint16_t *result) -{ - if (reader->size < 2) - return insufficient_bytes(); - - *result = (((uint16_t)reader->buffer[0]) << 8) - | (((uint16_t)reader->buffer[1]) ); - reader->buffer += 2; - reader->size -= 2; - return 0; -} - -/** Big Endian. */ -int -read_int32(struct pdu_reader *reader, uint32_t *result) -{ - if (reader->size < 4) - return insufficient_bytes(); - - *result = (((uint32_t)reader->buffer[0]) << 24) - | (((uint32_t)reader->buffer[1]) << 16) - | (((uint32_t)reader->buffer[2]) << 8) - | (((uint32_t)reader->buffer[3]) ); - reader->buffer += 4; - reader->size -= 4; - return 0; -} - -int -read_in_addr(struct pdu_reader *reader, struct in_addr *result) -{ - return read_int32(reader, &result->s_addr); -} - -int -read_in6_addr(struct pdu_reader *reader, struct in6_addr *result) -{ - unsigned int i; - int error; - - for (i = 0; i < 16; i++) { - error = read_int8(reader, &result->s6_addr[i]); - if (error) - return error; - } - - return 0; -} - -#define EINVALID_UTF8 -0xFFFF - -/* - * Returns the length (in octets) of the UTF-8 code point that starts with - * octet @first_octet. - */ -static int -get_octets(unsigned char first_octet) -{ - if ((first_octet & 0x80) == 0) - return 1; - if ((first_octet >> 5) == 6) /* 0b110 */ - return 2; - if ((first_octet >> 4) == 14) /* 0b1110 */ - return 3; - if ((first_octet >> 3) == 30) /* 0b11110 */ - return 4; - return EINVALID_UTF8; -} - -/* This is just a cast. The barebones version is too cluttered. */ -#define UCHAR(c) ((unsigned char *)c) - -/* - * This also sanitizes the string, BTW. - * (Because it overrides the first invalid character with the null chara. - * The rest is silently ignored.) - */ -static void -place_null_character(rtr_char *str, size_t len) -{ - rtr_char *null_chara_pos; - rtr_char *cursor; - int octet; - int octets; - - /* - * This could be optimized by noticing that all byte continuations in - * UTF-8 start with 0b10. This means that we could start from the end - * of the string and move left until we find a valid character. - * But if we do that, we'd lose the sanitization. So this is better - * methinks. - */ - - null_chara_pos = str; - cursor = str; - - while (cursor < str + len) { - octets = get_octets(*UCHAR(cursor)); - if (octets == EINVALID_UTF8) - break; - cursor++; - - for (octet = 1; octet < octets; octet++) { - /* Memory ends in the middle of this code point? */ - if (cursor >= str + len) - goto end; - /* All continuation octets must begin with 0b10. */ - if ((*(UCHAR(cursor)) >> 6) != 2 /* 0b10 */) - goto end; - cursor++; - } - - null_chara_pos = cursor; - } - -end: - *null_chara_pos = '\0'; -} - -/* - * Reads an RTR string from the file descriptor @fd. Returns the string as a - * normal UTF-8 C string (NULL-terminated). - * - * Will consume the entire string from the stream, but @result can be - * truncated. This is because RTR strings are technically allowed to be 4 GBs - * long. - * - * The result is allocated in the heap. It will length 4096 characters at most. - * (Including the NULL chara.) - */ -int -read_string(struct pdu_reader *reader, uint32_t string_len, rtr_char **result) -{ - /* Actual string length claimed by the PDU, in octets. */ - rtr_char *string; - - if (reader->size < string_len) - return pr_op_err("Erroneous PDU's error message is larger than its slot in the PDU."); - - /* - * Ok. Since the PDU size is already sanitized, string_len is guaranteed - * to be relatively small now. - */ - - string = pmalloc(string_len + 1); /* Include NULL chara. */ - - memcpy(string, reader->buffer, string_len); - reader->buffer += string_len; - reader->size -= string_len; - - place_null_character(string, string_len); - - *result = string; - return 0; -} - -int -read_bytes(struct pdu_reader *reader, unsigned char *result, size_t num) -{ - if (reader->size < num) - return insufficient_bytes(); - - memcpy(result, reader->buffer, num); - reader->buffer += num; - reader->size -= num; - return 0; -} diff --git a/src/rtr/primitive_reader.h b/src/rtr/primitive_reader.h deleted file mode 100644 index 6f5c5948..00000000 --- a/src/rtr/primitive_reader.h +++ /dev/null @@ -1,34 +0,0 @@ -#ifndef RTR_PRIMITIVE_READER_H_ -#define RTR_PRIMITIVE_READER_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "common.h" - -typedef char rtr_char; - -struct pdu_reader { - unsigned char *buffer; - size_t size; -}; - -void pdu_reader_init(struct pdu_reader *, unsigned char *, size_t size); - -int read_int8(struct pdu_reader *, uint8_t *); -int read_int16(struct pdu_reader *, uint16_t *); -int read_int32(struct pdu_reader *, uint32_t *); -int read_in_addr(struct pdu_reader *, struct in_addr *); -int read_in6_addr(struct pdu_reader *, struct in6_addr *); -int read_string(struct pdu_reader *, uint32_t, rtr_char **); -int read_bytes(struct pdu_reader *, unsigned char *, size_t); - -#endif /* RTR_PRIMITIVE_READER_H_ */ diff --git a/src/rtr/primitive_writer.c b/src/rtr/primitive_writer.c index 4f5a8601..a82cc71c 100644 --- a/src/rtr/primitive_writer.c +++ b/src/rtr/primitive_writer.c @@ -1,7 +1,7 @@ #include "rtr/primitive_writer.h" unsigned char * -write_int8(unsigned char *buf, uint8_t value) +write_uint8(unsigned char *buf, uint8_t value) { buf[0] = value; return buf + 1; @@ -9,7 +9,7 @@ write_int8(unsigned char *buf, uint8_t value) /** Big Endian. */ unsigned char * -write_int16(unsigned char *buf, uint16_t value) +write_uint16(unsigned char *buf, uint16_t value) { buf[0] = value >> 8; buf[1] = value; @@ -18,7 +18,7 @@ write_int16(unsigned char *buf, uint16_t value) /** Big Endian. */ unsigned char * -write_int32(unsigned char *buf, uint32_t value) +write_uint32(unsigned char *buf, uint32_t value) { buf[0] = value >> 24; buf[1] = value >> 16; @@ -30,15 +30,14 @@ write_int32(unsigned char *buf, uint32_t value) unsigned char * write_in_addr(unsigned char *buf, struct in_addr value) { - return write_int32(buf, ntohl(value.s_addr)); + return write_uint32(buf, ntohl(value.s_addr)); } unsigned char * -write_in6_addr(unsigned char *buf, struct in6_addr value) +write_in6_addr(unsigned char *buf, struct in6_addr const *value) { int i; for (i = 0; i < 16; i++) - buf = write_int8(buf, value.s6_addr[i]); - + buf = write_uint8(buf, value->s6_addr[i]); return buf; } diff --git a/src/rtr/primitive_writer.h b/src/rtr/primitive_writer.h index 6b8d713b..f1b541ad 100644 --- a/src/rtr/primitive_writer.h +++ b/src/rtr/primitive_writer.h @@ -6,10 +6,10 @@ #include #include -unsigned char *write_int8(unsigned char *, uint8_t); -unsigned char *write_int16(unsigned char *, uint16_t); -unsigned char *write_int32(unsigned char *, uint32_t); +unsigned char *write_uint8(unsigned char *, uint8_t); +unsigned char *write_uint16(unsigned char *, uint16_t); +unsigned char *write_uint32(unsigned char *, uint32_t); unsigned char *write_in_addr(unsigned char *, struct in_addr); -unsigned char *write_in6_addr(unsigned char *, struct in6_addr); +unsigned char *write_in6_addr(unsigned char *, struct in6_addr const *); #endif /* RTR_PRIMITIVE_WRITER_H_ */ diff --git a/src/rtr/rtr.c b/src/rtr/rtr.c index 41130dda..908b9cf0 100644 --- a/src/rtr/rtr.c +++ b/src/rtr/rtr.c @@ -8,14 +8,23 @@ #include "config.h" #include "types/address.h" #include "data_structure/array_list.h" +#include "rtr/err_pdu.h" #include "rtr/pdu.h" +#include "rtr/pdu_handler.h" +#include "rtr/pdu_stream.h" #include "thread/thread_pool.h" +struct rtr_server { + int fd; + /* Printable address to which the server was bound. */ + char *addr; +}; + static pthread_t server_thread; static volatile bool stop_server_thread; STATIC_ARRAY_LIST(server_arraylist, struct rtr_server) -STATIC_ARRAY_LIST(client_arraylist, struct rtr_client) +STATIC_ARRAY_LIST(client_arraylist, struct pdu_stream *) static struct server_arraylist servers; static struct client_arraylist clients; @@ -23,14 +32,6 @@ static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER; static struct thread_pool *request_handlers; -#define REQUEST_BUFFER_LEN 1024 - -struct client_request { - struct rtr_client *client; - unsigned char buffer[REQUEST_BUFFER_LEN]; - size_t nread; -}; - enum poll_verdict { PV_CONTINUE, PV_RETRY, /* Pause for a while, then continue */ @@ -45,20 +46,11 @@ cleanup_server(struct rtr_server *server) free(server->addr); } -static void -cleanup_client(struct rtr_client *client) -{ - if (client->fd != -1) { - shutdown(client->fd, SHUT_RDWR); - close(client->fd); - } -} - static void destroy_db(void) { server_arraylist_cleanup(&servers, cleanup_server); - client_arraylist_cleanup(&clients, cleanup_client); + client_arraylist_cleanup(&clients, pdustream_destroy); } /* @@ -325,19 +317,31 @@ init_server_fds(void) static void handle_client_request(void *arg) { - struct client_request *crequest = arg; - struct pdu_reader reader; - struct rtr_request rrequest; - struct pdu_metadata const *meta; + struct rtr_request *request = arg; + struct rtr_pdu *pdu; - pdu_reader_init(&reader, crequest->buffer, crequest->nread); - - while (pdu_load(&reader, crequest->client, &rrequest, &meta) == 0) { - meta->handle(crequest->client->fd, &rrequest); - meta->destructor(rrequest.pdu); + STAILQ_FOREACH(pdu, &request->pdus, hook) { + switch (pdu->obj.hdr.type) { + case PDU_TYPE_SERIAL_QUERY: + handle_serial_query_pdu(request, pdu); + break; + case PDU_TYPE_RESET_QUERY: + handle_reset_query_pdu(request, pdu); + break; + case PDU_TYPE_ERROR_REPORT: + handle_error_report_pdu(request, pdu); + break; + default: + /* Should have been catched during constructor */ + pr_crit("Unexpected PDU type: %u", pdu->obj.hdr.type); + } } - free(crequest); + if (request->eos) + /* Wake poller to close the socket. Read side already shut. */ + shutdown(request->fd, SHUT_WR); + + rtreq_destroy(request); } static void @@ -406,89 +410,49 @@ accept_new_client(struct pollfd const *server_fd) { struct sockaddr_storage client_addr; socklen_t sizeof_client_addr; - struct rtr_client client; + int fd; + char addr[INET6_ADDRSTRLEN]; + struct pdu_stream *client; enum accept_verdict result; sizeof_client_addr = sizeof(client_addr); - /* Accept the connection */ - client.fd = accept(server_fd->fd, (struct sockaddr *) &client_addr, + fd = accept(server_fd->fd, (struct sockaddr *) &client_addr, &sizeof_client_addr); - result = handle_accept_result(client.fd, errno); + result = handle_accept_result(fd, errno); if (result != AV_SUCCESS) return result; - if (set_nonblock(client.fd) != 0) { - close(client.fd); + if (set_nonblock(fd) != 0) { + close(fd); return AV_CLIENT_ERROR; } - client.rtr_version = -1; - sockaddr2str(&client_addr, client.addr); + sockaddr2str(&client_addr, addr); + client = pdustream_create(fd, addr); + client_arraylist_add(&clients, &client); - pr_op_info("Client accepted [FD: %d]: %s", client.fd, client.addr); + pr_op_info("Client accepted [FD: %d]: %s", fd, addr); return AV_SUCCESS; } -/* - * true: success. - * false: oh noes; close socket. - */ static bool -read_until_block(int fd, struct client_request *request) +__handle_client_request(struct pdu_stream *stream) { - ssize_t read_result; - size_t offset; - int error; + struct rtr_request *request; + bool eos; - request->nread = 0; - - for (offset = 0; offset < REQUEST_BUFFER_LEN; offset += read_result) { - read_result = read(fd, &request->buffer[offset], - REQUEST_BUFFER_LEN - offset); - if (read_result == -1) { - error = errno; - if (error == EAGAIN || error == EWOULDBLOCK) - return true; /* Ok, we have the full packet. */ - - pr_op_err("Client socket read interrupted: %s", - strerror(error)); - return false; - } - - if (read_result == 0) { - if (offset == 0) { - pr_op_debug("Client closed the socket."); - return false; - } - - return true; /* Ok, we have the last packet. */ - } - - request->nread += read_result; - } - - pr_op_warn("Peer's request is too big (>= %u bytes). Peer does not look like an RTR client; closing connection.", - REQUEST_BUFFER_LEN); - return false; -} - -static bool -__handle_client_request(struct rtr_client *client) -{ - struct client_request *request; - - request = pmalloc(sizeof(struct client_request)); + if (pdustream_next(stream, &request) != 0) + return false; - request->client = client; - if (!read_until_block(client->fd, request)) { + if (STAILQ_EMPTY(&request->pdus)) { + eos = request->eos; free(request); - return false; + return !eos; } - pr_op_debug("Client sent %zu bytes.", request->nread); thread_pool_push(request_handlers, "RTR request", handle_client_request, request); return true; @@ -512,7 +476,7 @@ delete_dead_clients(void) unsigned int dst; for (src = 0, dst = 0; src < clients.len; src++) { - if (clients.array[src].fd != -1) { + if (clients.array[src] != NULL) { clients.array[dst] = clients.array[src]; dst++; } @@ -526,7 +490,7 @@ apply_pollfds(struct pollfd *pollfds, unsigned int nclients) { struct pollfd *pfd; struct rtr_server *server; - struct rtr_client *client; + struct pdu_stream *client; unsigned int i; for (i = 0; i < servers.len; i++) { @@ -544,14 +508,14 @@ apply_pollfds(struct pollfd *pollfds, unsigned int nclients) for (i = 0; i < nclients; i++) { pfd = &pollfds[servers.len + i]; - client = &clients.array[i]; + client = clients.array[i]; /* PR_DEBUG_MSG("pfd:%d client:%d", pfd->fd, client->fd); */ - if ((pfd->fd == -1) && (client->fd != -1)) { - close(client->fd); - client->fd = -1; - print_poll_failure(pfd, "Client", client->addr); + if ((pfd->fd == -1) && (pdustream_fd(client) != -1)) { + pdustream_destroy(&client); + clients.array[i] = NULL; + print_poll_failure(pfd, "Client", pdustream_addr(client)); } } @@ -572,7 +536,7 @@ fddb_poll(void) ARRAYLIST_FOREACH_IDX(&servers, i) init_pollfd(&pollfds[i], servers.array[i].fd); ARRAYLIST_FOREACH_IDX(&clients, i) - init_pollfd(&pollfds[servers.len + i], clients.array[i].fd); + init_pollfd(&pollfds[servers.len + i], pdustream_fd(clients.array[i])); error = poll(pollfds, servers.len + clients.len, 1000); @@ -635,13 +599,13 @@ fddb_poll(void) /* PR_DEBUG_MSG("Client %u: fd:%d revents:%x", i, fd->fd, fd->revents); */ - if (fd->fd == -1) - continue; +// if (fd->fd == -1) +// continue; if (fd->revents & (POLLHUP | POLLERR | POLLNVAL)) { fd->fd = -1; } else if (fd->revents & POLLIN) { - if (!__handle_client_request(&clients.array[i])) + if (!__handle_client_request(clients.array[i])) fd->fd = -1; } } @@ -728,14 +692,16 @@ void rtr_stop(void) int rtr_foreach_client(rtr_foreach_client_cb cb, void *arg) { - struct rtr_client *client; + struct pdu_stream **client; + int fd; int error = 0; mutex_lock(&lock); ARRAYLIST_FOREACH(&clients, client) { - if (client->fd != -1) { - error = cb(client, arg); + fd = pdustream_fd(*client); + if (fd != -1) { + error = cb(fd, pdustream_version(*client), arg); if (error) break; } diff --git a/src/rtr/rtr.h b/src/rtr/rtr.h index 8575bb8d..6e38187b 100644 --- a/src/rtr/rtr.h +++ b/src/rtr/rtr.h @@ -1,25 +1,10 @@ #ifndef RTR_RTR_H_ #define RTR_RTR_H_ -#include -#include - -struct rtr_server { - int fd; - /* Printable address to which the server was bound. */ - char *addr; -}; - -struct rtr_client { - int fd; - char addr[INET6_ADDRSTRLEN]; /* Printable address of the client. */ - int rtr_version; /* -1: unset; > 0: version number */ -}; - int rtr_start(void); void rtr_stop(void); -typedef int (*rtr_foreach_client_cb)(struct rtr_client const *, void *arg); +typedef int (*rtr_foreach_client_cb)(int, int, void *); int rtr_foreach_client(rtr_foreach_client_cb, void *); #endif /* RTR_RTR_H_ */ diff --git a/test/Makefile.am b/test/Makefile.am index cffea9d2..40cb1e8c 100644 --- a/test/Makefile.am +++ b/test/Makefile.am @@ -28,9 +28,8 @@ check_PROGRAMS += db_table.test check_PROGRAMS += deltas_array.test check_PROGRAMS += line_file.test check_PROGRAMS += pb.test -check_PROGRAMS += pdu.test check_PROGRAMS += pdu_handler.test -check_PROGRAMS += primitive_reader.test +check_PROGRAMS += pdu_stream.test check_PROGRAMS += rrdp_objects.test check_PROGRAMS += serial.test check_PROGRAMS += tal.test @@ -60,14 +59,11 @@ line_file_test_LDADD = ${MY_LDADD} pb_test_SOURCES = data_structure/path_builder_test.c pb_test_LDADD = ${MY_LDADD} -pdu_test_SOURCES = rtr/pdu_test.c -pdu_test_LDADD = ${MY_LDADD} - pdu_handler_test_SOURCES = rtr/pdu_handler_test.c pdu_handler_test_LDADD = ${MY_LDADD} ${JANSSON_LIBS} -primitive_reader_test_SOURCES = rtr/primitive_reader_test.c -primitive_reader_test_LDADD = ${MY_LDADD} +pdu_stream_test_SOURCES = rtr/pdu_stream_test.c +pdu_stream_test_LDADD = ${MY_LDADD} ${JANSSON_LIBS} rrdp_objects_test_SOURCES = rrdp_objects_test.c rrdp_objects_test_LDADD = ${MY_LDADD} ${JANSSON_LIBS} ${XML2_LIBS} diff --git a/test/rtr/db/rtr_db_mock.c b/test/rtr/db/rtr_db_mock.c index b621bf1d..825cf073 100644 --- a/test/rtr/db/rtr_db_mock.c +++ b/test/rtr/db/rtr_db_mock.c @@ -68,15 +68,15 @@ __handle_router_key(unsigned char const *ski, struct asn_range const *range, unsigned char const *spk, void *arg) { uint64_t as; - int error = 0; + int error; for (as = range->min; as <= range->max; as++) { error = rtrhandler_handle_router_key(arg, ski, as, spk); if (error) - break; + return error; } - return error; + return 0; } int diff --git a/test/rtr/pdu_handler_test.c b/test/rtr/pdu_handler_test.c index 2215192a..48b9c921 100644 --- a/test/rtr/pdu_handler_test.c +++ b/test/rtr/pdu_handler_test.c @@ -11,7 +11,6 @@ #include "types/serial.c" #include "types/vrp.c" #include "rtr/pdu_handler.c" -#include "rtr/primitive_writer.c" #include "rtr/err_pdu.c" #include "rtr/db/delta.c" #include "rtr/db/deltas_array.c" @@ -22,8 +21,22 @@ /* Mocks */ -MOCK_ABORT_INT(read_int32, struct pdu_reader *reader, uint32_t *result) -MOCK_ABORT_INT(read_int8, struct pdu_reader *reader, uint8_t *result) +struct rtr_buffer const * +pdustream_last_pdu_raw(struct pdu_stream *s) +{ + static unsigned char bytes[] = { + /* header */ + 1, 1, 7, 8, 0, 0, 0, 12, + /* serial number */ + 14, 15, 16, 17, + }; + static struct rtr_buffer buf = { + .bytes = bytes, + .bytes_len = sizeof(bytes), + }; + + return &buf; +} MOCK_INT(cache_prepare, 0, void) @@ -92,33 +105,48 @@ init_db_full(void) } static void -init_reset_query(struct rtr_request *request, struct reset_query_pdu *query) +init_request(struct rtr_request *request, struct rtr_pdu *pdu) +{ + request->fd = 0; + strcpy(request->client_addr, "192.0.2.1"); + STAILQ_INIT(&request->pdus); + if (pdu != NULL) + STAILQ_INSERT_TAIL(&request->pdus, pdu, hook); + request->eos = false; +} + +static void +init_reset_query(struct rtr_pdu *pdu) { - request->pdu = query; - request->bytes_len = 0; - query->header.protocol_version = RTR_V1; - query->header.pdu_type = PDU_TYPE_RESET_QUERY; - query->header.m.reserved = 0; - query->header.length = 8; + static unsigned char raw[] = { 1, 2, 0, 0, 0, 0, 0, 8 }; + + pdu->obj.rq.header.version = RTR_V1; + pdu->obj.rq.header.type = PDU_TYPE_RESET_QUERY; + pdu->obj.rq.header.m.reserved = 0; + pdu->obj.rq.header.length = 8; + pdu->raw.bytes = raw; + pdu->raw.bytes_len = sizeof(raw); + memset(&pdu->hook, 0, sizeof(pdu->hook)); } static void -init_serial_query(struct rtr_request *request, struct serial_query_pdu *query, - uint32_t serial) +init_serial_query(struct rtr_pdu *pdu, uint32_t serial) { - request->pdu = query; - request->bytes_len = 0; - query->header.protocol_version = RTR_V1; - query->header.pdu_type = PDU_TYPE_SERIAL_QUERY; - query->header.m.session_id = get_current_session_id(RTR_V1); - query->header.length = 12; - query->serial_number = serial; + static unsigned char raw[] = { 1, 1, 0, 0, 0, 0, 0, 12, 0, 0, 0, 0 }; + + pdu->obj.sq.header.version = RTR_V1; + pdu->obj.sq.header.type = PDU_TYPE_SERIAL_QUERY; + pdu->obj.sq.header.m.session_id = get_current_session_id(RTR_V1); + pdu->obj.sq.header.length = 12; + pdu->obj.sq.serial_number = serial; + pdu->raw.bytes = raw; + pdu->raw.bytes_len = sizeof(raw); + memset(&pdu->hook, 0, sizeof(pdu->hook)); } /* Mocks */ MOCK_UINT(config_get_deltas_lifetime, 5, void) -MOCK_INT(clients_set_rtr_version, 0, int f, uint8_t v) int clients_get_rtr_version_set(int fd, bool *is_set, uint8_t *rtr_version) @@ -212,7 +240,7 @@ send_end_of_data_pdu(int fd, uint8_t version, serial_t end_serial) int send_error_report_pdu(int fd, uint8_t version, uint16_t code, - struct rtr_request const *request, char *message) + struct rtr_buffer const *request, char *message) { pr_op_info(" Server sent Error Report %u: '%s'", code, message); ck_assert_int_eq(pop_expected_pdu(), PDU_TYPE_ERROR_REPORT); @@ -225,15 +253,14 @@ send_error_report_pdu(int fd, uint8_t version, uint16_t code, START_TEST(test_start_or_restart) { struct rtr_request request; - struct reset_query_pdu client_pdu; + struct rtr_pdu client_pdu; pr_op_info("-- Start or Restart --"); - /* Prepare DB */ + /* Init */ init_db_full(); - - /* Init client request */ - init_reset_query(&request, &client_pdu); + init_reset_query(&client_pdu); + init_request(&request, &client_pdu); /* Define expected server response */ expected_pdu_add(PDU_TYPE_CACHE_RESPONSE); @@ -243,7 +270,7 @@ START_TEST(test_start_or_restart) expected_pdu_add(PDU_TYPE_END_OF_DATA); /* Run and validate */ - ck_assert_int_eq(0, handle_reset_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_reset_query_pdu(&request, &client_pdu)); ck_assert_uint_eq(false, has_expected_pdus()); /* Clean up */ @@ -255,26 +282,25 @@ END_TEST START_TEST(test_typical_exchange) { struct rtr_request request; - struct serial_query_pdu client_pdu; + struct rtr_pdu client_pdu; pr_op_info("-- Typical Exchange --"); - /* Prepare DB */ + /* Init */ init_db_full(); - - /* From serial 0: Init client request */ - init_serial_query(&request, &client_pdu, 0); + init_serial_query(&client_pdu, 0); + init_request(&request, &client_pdu); /* From serial 0: Define expected server response */ /* Server doesn't have serial 0. */ expected_pdu_add(PDU_TYPE_CACHE_RESET); /* From serial 0: Run and validate */ - ck_assert_int_eq(0, handle_serial_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_serial_query_pdu(&request, &client_pdu)); ck_assert_uint_eq(false, has_expected_pdus()); /* From serial 1: Init client request */ - init_serial_query(&request, &client_pdu, 1); + init_serial_query(&client_pdu, 1); /* From serial 1: Define expected server response */ expected_pdu_add(PDU_TYPE_CACHE_RESPONSE); @@ -287,11 +313,11 @@ START_TEST(test_typical_exchange) expected_pdu_add(PDU_TYPE_END_OF_DATA); /* From serial 1: Run and validate */ - ck_assert_int_eq(0, handle_serial_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_serial_query_pdu(&request, &client_pdu)); ck_assert_uint_eq(false, has_expected_pdus()); /* From serial 2: Init client request */ - init_serial_query(&request, &client_pdu, 2); + init_serial_query(&client_pdu, 2); /* From serial 2: Define expected server response */ expected_pdu_add(PDU_TYPE_CACHE_RESPONSE); @@ -301,18 +327,18 @@ START_TEST(test_typical_exchange) expected_pdu_add(PDU_TYPE_END_OF_DATA); /* From serial 2: Run and validate */ - ck_assert_int_eq(0, handle_serial_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_serial_query_pdu(&request, &client_pdu)); ck_assert_uint_eq(false, has_expected_pdus()); /* From serial 3: Init client request */ - init_serial_query(&request, &client_pdu, 3); + init_serial_query(&client_pdu, 3); /* From serial 3: Define expected server response */ expected_pdu_add(PDU_TYPE_CACHE_RESPONSE); expected_pdu_add(PDU_TYPE_END_OF_DATA); /* From serial 3: Run and validate */ - ck_assert_int_eq(0, handle_serial_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_serial_query_pdu(&request, &client_pdu)); ck_assert_uint_eq(false, has_expected_pdus()); /* Clean up */ @@ -324,21 +350,20 @@ END_TEST START_TEST(test_no_incremental_update_available) { struct rtr_request request; - struct serial_query_pdu serial_query; + struct rtr_pdu client_pdu; pr_op_info("-- No Incremental Update Available --"); - /* Prepare DB */ + /* Init */ init_db_full(); - - /* Init client request */ - init_serial_query(&request, &serial_query, 10000); + init_serial_query(&client_pdu, 10000); + init_request(&request, &client_pdu); /* Define expected server response */ expected_pdu_add(PDU_TYPE_CACHE_RESET); /* Run and validate */ - ck_assert_int_eq(0, handle_serial_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_serial_query_pdu(&request, &client_pdu)); ck_assert_uint_eq(false, has_expected_pdus()); /* The Reset Query is already tested in start_or_restart. */ @@ -352,33 +377,38 @@ END_TEST START_TEST(test_cache_has_no_data_available) { struct rtr_request request; - struct serial_query_pdu serial_query; - struct reset_query_pdu reset_query; + struct rtr_pdu serial_query; + struct rtr_pdu reset_query; pr_op_info("-- Cache Has No Data Available --"); - /* Prepare DB */ + /* Init */ ck_assert_int_eq(0, vrps_init()); + init_request(&request, NULL); /* Serial Query: Init client request */ - init_serial_query(&request, &serial_query, 0); + init_serial_query(&serial_query, 0); + STAILQ_INSERT_TAIL(&request.pdus, &serial_query, hook); /* Serial Query: Define expected server response */ expected_pdu_add(PDU_TYPE_ERROR_REPORT); /* Serial Query: Run and validate */ - ck_assert_int_eq(0, handle_serial_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_serial_query_pdu(&request, &serial_query)); ck_assert_uint_eq(false, has_expected_pdus()); + STAILQ_REMOVE_HEAD(&request.pdus, hook); /* Reset Query: Init client request */ - init_reset_query(&request, &reset_query); + init_reset_query(&reset_query); + STAILQ_INSERT_TAIL(&request.pdus, &reset_query, hook); /* Reset Query: Define expected server response */ expected_pdu_add(PDU_TYPE_ERROR_REPORT); /* Reset Query: Run and validate */ - ck_assert_int_eq(0, handle_reset_query_pdu(0, &request)); + ck_assert_int_eq(0, handle_reset_query_pdu(&request, &reset_query)); ck_assert_uint_eq(false, has_expected_pdus()); + STAILQ_REMOVE_HEAD(&request.pdus, hook); /* Clean up */ vrps_destroy(); @@ -388,22 +418,21 @@ END_TEST START_TEST(test_bad_session_id) { struct rtr_request request; - struct serial_query_pdu client_pdu; + struct rtr_pdu client_pdu; pr_op_info("-- Bad Session ID --"); - /* Prepare DB */ + /* Init */ init_db_full(); - - /* From serial 0: Init client request */ - init_serial_query(&request, &client_pdu, 0); - client_pdu.header.m.session_id++; + init_serial_query(&client_pdu, 0); + client_pdu.obj.sq.header.m.session_id++; + init_request(&request, &client_pdu); /* From serial 0: Define expected server response */ expected_pdu_add(PDU_TYPE_ERROR_REPORT); /* From serial 0: Run and validate */ - ck_assert_int_eq(-EINVAL, handle_serial_query_pdu(0, &request)); + ck_assert_int_eq(-EINVAL, handle_serial_query_pdu(&request, &client_pdu)); ck_assert_uint_eq(false, has_expected_pdus()); /* Clean up */ @@ -411,21 +440,6 @@ START_TEST(test_bad_session_id) } END_TEST -size_t -serialize_serial_query_pdu(struct serial_query_pdu *pdu, unsigned char *buf) -{ - unsigned char *ptr; - - ptr = buf; - ptr = write_int8(ptr, pdu->header.protocol_version); - ptr = write_int8(ptr, pdu->header.pdu_type); - ptr = write_int16(ptr, pdu->header.m.session_id); - ptr = write_int32(ptr, pdu->header.length); - ptr = write_int32(ptr, pdu->serial_number); - - return ptr - buf; -} - Suite *pdu_suite(void) { Suite *suite; diff --git a/test/rtr/pdu_stream_test.c b/test/rtr/pdu_stream_test.c new file mode 100644 index 00000000..a3b52713 --- /dev/null +++ b/test/rtr/pdu_stream_test.c @@ -0,0 +1,521 @@ +#include +#include +#include + +#include "alloc.c" +#include "mock.c" +#include "rtr/pdu_stream.c" + +/* Mocks */ + +MOCK_ABORT_INT(err_pdu_send_invalid_request, int fd, uint8_t version, + struct rtr_buffer const *request, char const *msg) +MOCK_ABORT_INT(err_pdu_send_unsupported_proto_version, int fd, uint8_t version, + struct rtr_buffer const *request, char const *msg) +MOCK_ABORT_INT(err_pdu_send_unsupported_pdu_type, int fd, uint8_t version, + struct rtr_buffer const *request) +MOCK_ABORT_INT(err_pdu_send_unexpected_proto_version, int fd, uint8_t version, + struct rtr_buffer const *request, char const *msg) + +/* End of mocks */ + +static void +setup_pipes(int *pipes) +{ + int fl; + + ck_assert_int_eq(0, pipe(pipes)); + fl = fcntl(pipes[0], F_GETFL); + ck_assert_int_ne(-1, fl); + ck_assert_int_eq(0, fcntl(pipes[0], F_SETFL, fl | O_NONBLOCK)); +} + +static struct pdu_stream * +create_stream(unsigned char const *buf, size_t bufsize) +{ + struct pdu_stream *result = pdustream_create(-1, "192.0.2.1"); + memcpy(result->buffer, buf, bufsize); + result->end = result->buffer + bufsize; + return result; +} + +static struct pdu_stream * +create_stream_fd(unsigned char *data, size_t datalen, int rtr_version) +{ + struct pdu_stream *result; + int pipes[2]; + + setup_pipes(pipes); + ck_assert_int_eq(datalen, write(pipes[1], data, datalen)); + close(pipes[1]); + + result = pdustream_create(pipes[0], "192.0.2.1"); + result->rtr_version = rtr_version; + return result; +} + +static void +assert_pdu_count(unsigned int expected, struct rtr_request *request) +{ + struct rtr_pdu *pdu; + unsigned int npdu; + + npdu = 0; + STAILQ_FOREACH(pdu, &request->pdus, hook) + npdu++; + ck_assert_uint_eq(expected, npdu); +} + +START_TEST(test_pdu_header_from_stream) +{ + unsigned char input[] = { 0, 1, 2, 3, 4, 5, 6, 7 }; + struct pdu_stream *stream; + struct pdu_header hdr; + + stream = create_stream(input, sizeof(input)); + + read_hdr(stream, &hdr); + ck_assert_uint_eq(hdr.version, 0); + ck_assert_uint_eq(hdr.type, 1); + ck_assert_uint_eq(hdr.m.reserved, 0x0203); + ck_assert_uint_eq(hdr.length, 0x04050607); + + free(stream); +} +END_TEST + +START_TEST(test_serial_query_from_stream) +{ + unsigned char input[] = { + /* header */ + 1, 1, 7, 8, 0, 0, 0, 12, + /* serial number */ + 14, 15, 16, 17, + }; + struct pdu_stream *stream; + struct rtr_request *request; + struct rtr_pdu *pdu; + struct serial_query_pdu *sq; + + stream = create_stream_fd(input, sizeof(input), RTR_V1); + ck_assert_int_eq(0, pdustream_next(stream, &request)); + + ck_assert_int_eq(stream->fd, request->fd); + ck_assert_str_eq(stream->addr, request->client_addr); + ck_assert_uint_eq(1, request->eos); + assert_pdu_count(1, request); + + pdu = STAILQ_FIRST(&request->pdus); + sq = &pdu->obj.sq; + + ck_assert_uint_eq(sq->header.version, RTR_V1); + ck_assert_uint_eq(sq->header.type, PDU_TYPE_SERIAL_QUERY); + ck_assert_uint_eq(sq->header.m.reserved, 0x0708); + ck_assert_uint_eq(sq->header.length, 12); + ck_assert_uint_eq(sq->serial_number, 0x0e0f1011); + + rtreq_destroy(request); + pdustream_destroy(&stream); +} +END_TEST + +START_TEST(test_reset_query_from_stream) +{ + unsigned char input[] = { + /* Header */ 0, 2, 12, 13, 0, 0, 0, 8, + /* Garbage */ 18, 19, + }; + struct pdu_stream *stream; + struct rtr_request *request; + struct rtr_pdu *pdu; + struct reset_query_pdu *rq; + + stream = create_stream_fd(input, sizeof(input), RTR_V0); + ck_assert_int_eq(0, pdustream_next(stream, &request)); + + ck_assert_int_eq(stream->fd, request->fd); + ck_assert_str_eq(stream->addr, request->client_addr); + ck_assert_uint_eq(1, request->eos); + assert_pdu_count(1, request); + + pdu = STAILQ_FIRST(&request->pdus); + rq = &pdu->obj.rq; + + ck_assert_uint_eq(rq->header.version, RTR_V0); + ck_assert_uint_eq(rq->header.type, PDU_TYPE_RESET_QUERY); + ck_assert_uint_eq(rq->header.m.reserved, 0x0c0d); + ck_assert_uint_eq(rq->header.length, 8); + + ck_assert_uint_eq(8, pdu->raw.bytes_len); + ck_assert(memcmp(input, pdu->raw.bytes, 8) == 0); + + rtreq_destroy(request); + pdustream_destroy(&stream); +} +END_TEST + +START_TEST(test_error_report_from_stream) +{ + unsigned char input[] = { + /* header */ + 1, 10, 22, 23, 0, 0, 0, 33, + /* Sub-pdu length */ + 0, 0, 0, 12, + /* Sub-pdu with header*/ + 1, 0, 2, 3, 0, 0, 0, 12, 1, 2, 3, 4, + /* Error msg length */ + 0, 0, 0, 5, + /* Error msg */ + 'h', 'e', 'l', 'l', 'o', + /* Garbage */ + 1, 2, 3, 4, + }; + struct pdu_stream *stream; + struct rtr_request *request; + struct rtr_pdu *pdu; + struct error_report_pdu *er; + + stream = create_stream_fd(input, sizeof(input), RTR_V1); + ck_assert_int_eq(0, pdustream_next(stream, &request)); + + ck_assert_int_eq(stream->fd, request->fd); + ck_assert_str_eq(stream->addr, request->client_addr); + ck_assert_uint_eq(1, request->eos); + assert_pdu_count(1, request); + + pdu = STAILQ_FIRST(&request->pdus); + er = &pdu->obj.er; + + ck_assert_uint_eq(er->header.version, RTR_V1); + ck_assert_uint_eq(er->header.type, PDU_TYPE_ERROR_REPORT); + ck_assert_uint_eq(er->header.m.reserved, 0x1617); + ck_assert_uint_eq(er->header.length, 33); + ck_assert_uint_eq(er->errpdu_len, 12); + ck_assert_uint_eq(er->errpdu[0], 1); + ck_assert_uint_eq(er->errpdu[1], 0); + ck_assert_uint_eq(er->errpdu[2], 2); + ck_assert_uint_eq(er->errpdu[3], 3); + ck_assert_uint_eq(er->errpdu[4], 0); + ck_assert_uint_eq(er->errpdu[5], 0); + ck_assert_uint_eq(er->errpdu[6], 0); + ck_assert_uint_eq(er->errpdu[7], 12); + ck_assert_uint_eq(er->errpdu[8], 1); + ck_assert_uint_eq(er->errpdu[9], 2); + ck_assert_uint_eq(er->errpdu[10], 3); + ck_assert_uint_eq(er->errpdu[11], 4); + ck_assert_uint_eq(er->errmsg_len, 5); + ck_assert_str_eq(er->errmsg, "hello"); + + ck_assert_uint_eq(33, pdu->raw.bytes_len); + ck_assert(memcmp(input, pdu->raw.bytes, 33) == 0); + + rtreq_destroy(request); + pdustream_destroy(&stream); +} +END_TEST + +#define ASSERT_RQ(_rq, _version, _type, _reserved, _length) \ + ck_assert_uint_eq(_rq.header.version, _version); \ + ck_assert_uint_eq(_rq.header.type, _type); \ + ck_assert_uint_eq(_rq.header.m.reserved, _reserved); \ + ck_assert_uint_eq(_rq.header.length, _length); + +#define ASSERT_SQ(_sq, _version, _type, _reserved, _length, _serial) \ + ck_assert_uint_eq(_sq.header.version, _version); \ + ck_assert_uint_eq(_sq.header.type, _type); \ + ck_assert_uint_eq(_sq.header.m.reserved, _reserved); \ + ck_assert_uint_eq(_sq.header.length, _length); \ + ck_assert_uint_eq(_sq.serial_number, _serial); + +START_TEST(test_multiple_pdus) +{ + unsigned char input1[] = { + /* reset query */ 1, 2, 0, 0, 0, 0, 0, 8, + /* serial query */ 1, 1, 0, 0, 0, 0, 0, 12, 1, 2, 3, 4, + /* reset query */ 1, 2, 3, 4, 0, 0, 0, 8, + /* reset query start */ 1, 2, 3, 4, + }; + unsigned char input2[] = { + /* reset query end */ 0, 0, 0, 8, + /* reset query */ 1, 2, 6, 7, 0, 0, 0, 8, + }; + struct pdu_stream *stream; + struct rtr_request *request; + struct rtr_pdu *pdu; + int pipes[2]; + + setup_pipes(pipes); + + stream = pdustream_create(pipes[0], "192.0.2.1"); + + /* Input 1 */ + + ck_assert_int_eq(32, write(pipes[1], input1, sizeof(input1))); + ck_assert_int_eq(0, pdustream_next(stream, &request)); + + ck_assert_int_eq(stream->fd, request->fd); + ck_assert_str_eq(stream->addr, request->client_addr); + ck_assert_uint_eq(0, request->eos); + assert_pdu_count(3, request); + + pdu = STAILQ_FIRST(&request->pdus); + ASSERT_RQ(pdu->obj.rq, RTR_V1, PDU_TYPE_RESET_QUERY, 0, 8); + ck_assert_uint_eq(8, pdu->raw.bytes_len); + ck_assert(memcmp(input1 + 0, pdu->raw.bytes, 8) == 0); + + pdu = STAILQ_NEXT(pdu, hook); + ASSERT_SQ(pdu->obj.sq, RTR_V1, PDU_TYPE_SERIAL_QUERY, 0, 12, 0x1020304); + ck_assert_uint_eq(12, pdu->raw.bytes_len); + ck_assert(memcmp(input1 + 8, pdu->raw.bytes, 12) == 0); + + pdu = STAILQ_NEXT(pdu, hook); + ASSERT_RQ(pdu->obj.rq, RTR_V1, PDU_TYPE_RESET_QUERY, 0x304, 8); + ck_assert_uint_eq(8, pdu->raw.bytes_len); + ck_assert(memcmp(input1 + 20, pdu->raw.bytes, 8) == 0); + + rtreq_destroy(request); + + /* Input 2 */ + + ck_assert_int_eq(12, write(pipes[1], input2, sizeof(input2))); + ck_assert_int_eq(0, pdustream_next(stream, &request)); + + ck_assert_int_eq(stream->fd, request->fd); + ck_assert_str_eq(stream->addr, request->client_addr); + ck_assert_uint_eq(0, request->eos); + assert_pdu_count(2, request); + + pdu = STAILQ_FIRST(&request->pdus); + ASSERT_RQ(pdu->obj.rq, RTR_V1, PDU_TYPE_RESET_QUERY, 0x304, 8); + ck_assert_uint_eq(8, pdu->raw.bytes_len); + ck_assert(memcmp(input1 + 28, &pdu->raw.bytes[0], 4) == 0); + ck_assert(memcmp(input2 + 0, &pdu->raw.bytes[4], 4) == 0); + + pdu = STAILQ_NEXT(pdu, hook); + ASSERT_RQ(pdu->obj.rq, RTR_V1, PDU_TYPE_RESET_QUERY, 0x607, 8); + ck_assert_uint_eq(8, pdu->raw.bytes_len); + ck_assert(memcmp(input2 + 4, pdu->raw.bytes, 8) == 0); + + rtreq_destroy(request); + + /* Input 3 */ + + close(pipes[1]); + ck_assert_int_eq(0, pdustream_next(stream, &request)); + + ck_assert_int_eq(stream->fd, request->fd); + ck_assert_str_eq(stream->addr, request->client_addr); + ck_assert_uint_eq(1, request->eos); + assert_pdu_count(0, request); + + rtreq_destroy(request); + + /* Clean up */ + + pdustream_destroy(&stream); +} +END_TEST + +START_TEST(test_interrupted) +{ + unsigned char input[] = { 0, 1 }; + struct pdu_stream *stream; + struct rtr_request *request; + + stream = create_stream_fd(input, sizeof(input), RTR_V1); + ck_assert_int_eq(0, pdustream_next(stream, &request)); + + ck_assert_int_eq(stream->fd, request->fd); + ck_assert_str_eq(stream->addr, request->client_addr); + ck_assert_uint_eq(1, request->eos); + assert_pdu_count(0, request); + + rtreq_destroy(request); + pdustream_destroy(&stream); +} +END_TEST + +static void +test_read_string_success(unsigned char *input, size_t length, char *expected) +{ + struct pdu_stream *stream; + char *actual; + + stream = create_stream(input, length); + + actual = read_string(stream, length); + ck_assert_pstr_eq(expected, actual); + + free(actual); + free(stream); +} + +START_TEST(read_string_ascii) +{ + unsigned char input[] = { 'a', 'b', 'c', 'd' }; + test_read_string_success(input, sizeof(input), "abcd"); +} +END_TEST + +START_TEST(read_string_unicode) +{ + unsigned char input0[] = { 's', 'a', 'n', 'd', 0xc3, 0xad, 'a' }; + test_read_string_success(input0, sizeof(input0), "sandía"); + + unsigned char input1[] = { 0xe1, 0x88, 0x90, 0xe1, 0x89, 0xa5, 0xe1, + 0x88, 0x90, 0xe1, 0x89, 0xa5 }; + test_read_string_success(input1, sizeof(input1), "ሐብሐብ"); + + unsigned char input2[] = { 0xd8, 0xa7, 0xd9, 0x84, 0xd8, 0xa8, 0xd8, + 0xb7, 0xd9, 0x8a, 0xd8, 0xae }; + test_read_string_success(input2, sizeof(input2), "البطيخ"); + + unsigned char input3[] = { + 0xd5, 0xb1, 0xd5, 0xb4, 0xd5, 0xa5, 0xd6, 0x80, 0xd5, 0xb8, 0xd6, + 0x82, 0xd5, 0xaf, 0x20, 0xd0, 0xba, 0xd0, 0xb0, 0xd0, 0xb2, 0xd1, + 0x83, 0xd0, 0xbd }; + test_read_string_success(input3, sizeof(input3), "ձմերուկ кавун"); + + unsigned char input4[] = { + 0xe0, 0xa6, 0xa4, 0xe0, 0xa6, 0xb0, 0xe0, 0xa6, 0xae, 0xe0, 0xa7, + 0x81, 0xe0, 0xa6, 0x9c, 0x20, 0xd0, 0xb4, 0xd0, 0xb8, 0xd0, 0xbd, + 0xd1, 0x8f, 0x20, 0xe8, 0xa5, 0xbf, 0xe7, 0x93, 0x9c, 0x20, 0xf0, + 0x9f, 0x8d, 0x89 }; + test_read_string_success(input4, sizeof(input4), "তরমুজ диня 西瓜 🍉"); +} +END_TEST + +START_TEST(read_string_empty) +{ + unsigned char input[] = { 0, 0, 0, 0 }; + test_read_string_success(input, sizeof(input), ""); +} +END_TEST + +struct thread_param { + int fd; + uint32_t msg_size; + int err; +}; + +/* + * Sends @full_string_length characters to the fd, validates the parsed string + * contains the first @return_length characters. + */ +START_TEST(read_string_max) +{ + static char const *STR = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 52 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 104 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 156 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 208 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 260 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 312 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 364 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 416 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 468 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 520 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 572 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 624 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 676 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 728 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 780 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 832 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 884 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 936 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" /* 988 */ + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJ"; /* 1024 */ + struct pdu_stream *stream; + char *result_string; + + stream = create_stream((unsigned char *)STR, RTRPDU_MAX_LEN2); + + result_string = read_string(stream, RTRPDU_MAX_LEN2); + ck_assert_int_eq(0, strcmp(STR, result_string)); + + free(result_string); + free(stream); +} +END_TEST + +START_TEST(read_string_null) +{ + test_read_string_success(NULL, 0, NULL); +} +END_TEST + +START_TEST(read_string_unicode_mix) +{ + /* One octet failure */ + unsigned char input0[] = { 'a', 0x80, 'z' }; + test_read_string_success(input0, sizeof(input0), "a"); + + /* Two octets success */ + unsigned char input1[] = { 'a', 0xdf, 0x9a, 'z' }; + test_read_string_success(input1, sizeof(input1), "aߚz"); + /* Two octets failure */ + unsigned char input2[] = { 'a', 0xdf, 0xda, 'z' }; + test_read_string_success(input2, sizeof(input2), "a"); + + /* Three characters success */ + unsigned char input3[] = { 'a', 0xe2, 0x82, 0xac, 'z' }; + test_read_string_success(input3, sizeof(input3), "a€z"); + /* Three characters failure */ + unsigned char input4[] = { 'a', 0xe2, 0x82, 0x2c, 'z' }; + test_read_string_success(input4, sizeof(input4), "a"); + + /* Four characters success */ + unsigned char i5[] = { 'a', 0xf0, 0x90, 0x86, 0x97, 'z' }; + test_read_string_success(i5, sizeof(i5), "a𐆗z"); + /* Four characters failure */ + unsigned char i6[] = { 'a', 0xf0, 0x90, 0x90, 0x17, 'z' }; + test_read_string_success(i6, sizeof(i6), "a"); +} +END_TEST + +Suite *pdu_suite(void) +{ + Suite *suite; + TCase *core, *errors, *string; + + core = tcase_create("Core"); + tcase_add_test(core, test_pdu_header_from_stream); + tcase_add_test(core, test_serial_query_from_stream); + tcase_add_test(core, test_reset_query_from_stream); + tcase_add_test(core, test_error_report_from_stream); + tcase_add_test(core, test_multiple_pdus); + + errors = tcase_create("Errors"); + tcase_add_test(errors, test_interrupted); + /* FIXME (RTR) test more errors */ + + string = tcase_create("String"); + tcase_add_test(string, read_string_ascii); + tcase_add_test(string, read_string_unicode); + tcase_add_test(string, read_string_empty); + tcase_add_test(string, read_string_max); + tcase_add_test(string, read_string_null); + tcase_add_test(string, read_string_unicode_mix); + + suite = suite_create("PDU stream"); + suite_add_tcase(suite, core); + suite_add_tcase(suite, errors); + suite_add_tcase(suite, string); + return suite; +} + +int main(void) +{ + Suite *suite; + SRunner *runner; + int tests_failed; + + suite = pdu_suite(); + + runner = srunner_create(suite); + srunner_run_all(runner, CK_NORMAL); + tests_failed = srunner_ntests_failed(runner); + srunner_free(runner); + + return (tests_failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/rtr/pdu_test.c b/test/rtr/pdu_test.c deleted file mode 100644 index fd622643..00000000 --- a/test/rtr/pdu_test.c +++ /dev/null @@ -1,292 +0,0 @@ -#include -#include -#include - -#include "alloc.c" -#include "common.c" -#include "mock.c" -#include "rtr/err_pdu.c" -#include "rtr/pdu.c" -#include "rtr/primitive_reader.c" -#include "rtr/db/rtr_db_mock.c" - -/* Mocks */ - -MOCK(get_current_session_id, uint16_t, 12345, uint8_t v) -MOCK_INT(clients_set_rtr_version, 0, int f, uint8_t v) - -int -clients_get_rtr_version_set(int fd, bool *is_set, uint8_t *rtr_version) -{ - (*is_set) = true; - (*rtr_version) = RTR_V0; - return 0; -} - -#define MOCK_HANDLER(n) MOCK(n, int, 0, int f, struct rtr_request const *r) - -MOCK_HANDLER(handle_serial_notify_pdu) -MOCK_HANDLER(handle_serial_query_pdu) -MOCK_HANDLER(handle_reset_query_pdu) -MOCK_HANDLER(handle_cache_response_pdu) -MOCK_HANDLER(handle_ipv4_prefix_pdu) -MOCK_HANDLER(handle_ipv6_prefix_pdu) -MOCK_HANDLER(handle_end_of_data_pdu) -MOCK_HANDLER(handle_cache_reset_pdu) -MOCK_HANDLER(handle_router_key_pdu) -MOCK_HANDLER(handle_error_report_pdu) - -int -send_error_report_pdu(int fd, uint8_t version, uint16_t code, - struct rtr_request const *request, char *message) -{ - pr_op_info(" Server sent Error Report %u: '%s'", code, - /* gcc is complaining about logging NULL messages. WTF */ - (message != NULL) ? message : ""); - return 0; -} - -MOCK_INT(rtrhandler_handle_roa_v4, 0, struct db_table *table, uint32_t asn, - struct ipv4_prefix const *prefix4, uint8_t max_length) -MOCK_INT(rtrhandler_handle_roa_v6, 0, struct db_table *table, uint32_t asn, - struct ipv6_prefix const *prefix6, uint8_t max_length) -MOCK_INT(rtrhandler_handle_router_key, 0, struct db_table *table, - unsigned char const *ski, uint32_t as, unsigned char const *spk) - -/* End of mocks */ - -/* -* Used to be a wrapper for `buffer2fd()`, but that's no longer necessary. -* -* Converts the @buffer buffer into PDU @obj, using the @cb function. -* Also takes care of the header validation. -*/ -#define BUFFER2FD(buffer, cb, obj) { \ - struct pdu_header header; \ - struct pdu_reader reader; \ - \ - pdu_reader_init(&reader, buffer, sizeof(buffer)); \ - init_pdu_header(&header); \ - ck_assert_int_eq(0, cb(&header, &reader, obj)); \ - assert_pdu_header(&(obj)->header); \ -} - -static void -init_pdu_header(struct pdu_header *header) -{ - header->protocol_version = RTR_V0; - header->pdu_type = 22; - header->m.reserved = get_current_session_id(RTR_V0); - header->length = 0x00000020; -} - -static void -assert_pdu_header(struct pdu_header *header) -{ - ck_assert_uint_eq(header->protocol_version, 0); - ck_assert_uint_eq(header->pdu_type, 22); - ck_assert_uint_eq(header->m.reserved, get_current_session_id(RTR_V0)); - ck_assert_uint_eq(header->length, 0x00000020); -} - -START_TEST(test_pdu_header_from_stream) -{ - unsigned char input[] = { 0, 1, 2, 3, 4, 5, 6, 7 }; - struct pdu_reader reader; - struct pdu_header header; - - pdu_reader_init(&reader, input, ARRAY_LEN(input)); - /* Read the header into its buffer. */ - ck_assert_int_eq(0, pdu_header_from_reader(&reader, &header)); - - ck_assert_uint_eq(header.protocol_version, 0); - ck_assert_uint_eq(header.pdu_type, 1); - ck_assert_uint_eq(header.m.reserved, 0x0203); - ck_assert_uint_eq(header.length, 0x04050607); -} -END_TEST - -START_TEST(test_serial_notify_from_stream) -{ - unsigned char input[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }; - struct serial_notify_pdu pdu; - - BUFFER2FD(input, serial_notify_from_stream, &pdu); - ck_assert_uint_eq(pdu.serial_number, 0x010203); -} -END_TEST - -START_TEST(test_serial_query_from_stream) -{ - unsigned char input[] = { 13, 14, 15, 16, 17 }; - struct serial_query_pdu pdu; - - BUFFER2FD(input, serial_query_from_stream, &pdu); - ck_assert_uint_eq(pdu.serial_number, 0x0d0e0f10); -} -END_TEST - -START_TEST(test_reset_query_from_stream) -{ - unsigned char input[] = { 18, 19 }; - struct reset_query_pdu pdu; - - BUFFER2FD(input, reset_query_from_stream, &pdu); -} -END_TEST - -START_TEST(test_cache_response_from_stream) -{ - unsigned char input[] = { 18, 19 }; - struct cache_response_pdu pdu; - - BUFFER2FD(input, cache_response_from_stream, &pdu); -} -END_TEST - -START_TEST(test_ipv4_prefix_from_stream) -{ - unsigned char input[] = { 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, - 29, 30, 31, 32 }; - struct ipv4_prefix_pdu pdu; - - BUFFER2FD(input, ipv4_prefix_from_stream, &pdu); - ck_assert_uint_eq(pdu.flags, 18); - ck_assert_uint_eq(pdu.prefix_length, 19); - ck_assert_uint_eq(pdu.max_length, 20); - ck_assert_uint_eq(pdu.zero, 21); - ck_assert_uint_eq(pdu.ipv4_prefix.s_addr, 0x16171819); - ck_assert_uint_eq(pdu.asn, 0x1a1b1c1d); -} -END_TEST - -START_TEST(test_ipv6_prefix_from_stream) -{ - unsigned char input[] = { 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, - 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, - 58, 59, 60 }; - struct ipv6_prefix_pdu pdu; - struct in6_addr tmp; - - BUFFER2FD(input, ipv6_prefix_from_stream, &pdu); - ck_assert_uint_eq(pdu.flags, 33); - ck_assert_uint_eq(pdu.prefix_length, 34); - ck_assert_uint_eq(pdu.max_length, 35); - ck_assert_uint_eq(pdu.zero, 36); - in6_addr_init(&tmp, 0x25262728, 0x292a2b2c, 0x2d2e2f30, 0x31323334); - ck_assert(addr6_equals(&tmp, &pdu.ipv6_prefix)); - ck_assert_uint_eq(pdu.asn, 0x35363738); -} -END_TEST - -START_TEST(test_end_of_data_from_stream) -{ - unsigned char input[] = { 61, 62, 63, 64 }; - struct end_of_data_pdu pdu; - - BUFFER2FD(input, end_of_data_from_stream, &pdu); - ck_assert_uint_eq(pdu.serial_number, 0x3d3e3f40); -} -END_TEST - -START_TEST(test_cache_reset_from_stream) -{ - unsigned char input[] = { 65, 66, 67 }; - struct cache_reset_pdu pdu; - - BUFFER2FD(input, cache_reset_from_stream, &pdu); -} -END_TEST - -START_TEST(test_error_report_from_stream) -{ - unsigned char input[] = { - /* Sub-pdu length */ - 0, 0, 0, 12, - /* Sub-pdu with header*/ - 1, 0, 2, 3, 0, 0, 0, 12, 1, 2, 3, 4, - /* Error msg length */ - 0, 0, 0, 5, - /* Error msg */ - 'h', 'e', 'l', 'l', 'o', - /* Garbage */ - 1, 2, 3, 4, - }; - struct error_report_pdu pdu; - struct serial_notify_pdu sub_pdu; - struct pdu_header sub_pdu_header; - struct pdu_reader reader; - - BUFFER2FD(input, error_report_from_stream, &pdu); - - /* Get the erroneous PDU as a serial notify */ - pdu_reader_init(&reader, pdu.erroneous_pdu, pdu.error_pdu_length); - - ck_assert_int_eq(0, pdu_header_from_reader(&reader, &sub_pdu_header)); - ck_assert_int_eq(0, serial_notify_from_stream(&sub_pdu_header, &reader, - &sub_pdu)); - - ck_assert_uint_eq(sub_pdu.header.protocol_version, 1); - ck_assert_uint_eq(sub_pdu.header.pdu_type, 0); - ck_assert_uint_eq(sub_pdu.header.m.reserved, 0x0203); - ck_assert_uint_eq(sub_pdu.header.length, 12); - ck_assert_uint_eq(sub_pdu.serial_number, 0x01020304); - ck_assert_str_eq(pdu.error_message, "hello"); - - free(pdu.error_message); -} -END_TEST - -START_TEST(test_interrupted) -{ - unsigned char input[] = { 0, 1 }; - struct pdu_reader reader; - struct pdu_header header; - - pdu_reader_init(&reader, input, ARRAY_LEN(input)); - ck_assert_int_eq(-EPIPE, pdu_header_from_reader(&reader, &header)); -} -END_TEST - -Suite *pdu_suite(void) -{ - Suite *suite; - TCase *core, *errors; - - core = tcase_create("Core"); - tcase_add_test(core, test_pdu_header_from_stream); - tcase_add_test(core, test_serial_notify_from_stream); - tcase_add_test(core, test_serial_notify_from_stream); - tcase_add_test(core, test_serial_query_from_stream); - tcase_add_test(core, test_reset_query_from_stream); - tcase_add_test(core, test_cache_response_from_stream); - tcase_add_test(core, test_ipv4_prefix_from_stream); - tcase_add_test(core, test_ipv6_prefix_from_stream); - tcase_add_test(core, test_end_of_data_from_stream); - tcase_add_test(core, test_cache_reset_from_stream); - tcase_add_test(core, test_error_report_from_stream); - - errors = tcase_create("Errors"); - tcase_add_test(errors, test_interrupted); - - suite = suite_create("PDU"); - suite_add_tcase(suite, core); - suite_add_tcase(suite, errors); - return suite; -} - -int main(void) -{ - Suite *suite; - SRunner *runner; - int tests_failed; - - suite = pdu_suite(); - - runner = srunner_create(suite); - srunner_run_all(runner, CK_NORMAL); - tests_failed = srunner_ntests_failed(runner); - srunner_free(runner); - - return (tests_failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE; -} diff --git a/test/rtr/primitive_reader_test.c b/test/rtr/primitive_reader_test.c deleted file mode 100644 index d757fdc9..00000000 --- a/test/rtr/primitive_reader_test.c +++ /dev/null @@ -1,256 +0,0 @@ -#include -#include -#include -#include -#include -#include - -#include "alloc.c" -#include "mock.c" -#include "rtr/primitive_reader.c" - -/* - * Wrapper for `read_string()`, for easy testing. - */ -static int -__read_string(unsigned char *input, size_t size, rtr_char **result) -{ - struct pdu_reader reader; - pdu_reader_init(&reader, input, size); - return read_string(&reader, size & 0xFFFF, result); -} - -static void -test_read_string_success(unsigned char *input, size_t length, - rtr_char *expected) -{ - rtr_char *actual; - int err; - - err = __read_string(input, length, &actual); - ck_assert_int_eq(0, err); - if (!err) { - ck_assert_str_eq(expected, actual); - free(actual); - } -} - -START_TEST(read_string_ascii) -{ - unsigned char input[] = { 'a', 'b', 'c', 'd' }; - test_read_string_success(input, sizeof(input), "abcd"); -} -END_TEST - -START_TEST(read_string_unicode) -{ - unsigned char input0[] = { 's', 'a', 'n', 'd', 0xc3, 0xad, 'a' }; - test_read_string_success(input0, sizeof(input0), "sandía"); - - unsigned char input1[] = { 0xe1, 0x88, 0x90, 0xe1, 0x89, 0xa5, 0xe1, - 0x88, 0x90, 0xe1, 0x89, 0xa5 }; - test_read_string_success(input1, sizeof(input1), "ሐብሐብ"); - - unsigned char input2[] = { 0xd8, 0xa7, 0xd9, 0x84, 0xd8, 0xa8, 0xd8, - 0xb7, 0xd9, 0x8a, 0xd8, 0xae }; - test_read_string_success(input2, sizeof(input2), "البطيخ"); - - unsigned char input3[] = { - 0xd5, 0xb1, 0xd5, 0xb4, 0xd5, 0xa5, 0xd6, 0x80, 0xd5, 0xb8, 0xd6, - 0x82, 0xd5, 0xaf, 0x20, 0xd0, 0xba, 0xd0, 0xb0, 0xd0, 0xb2, 0xd1, - 0x83, 0xd0, 0xbd }; - test_read_string_success(input3, sizeof(input3), "ձմերուկ кавун"); - - unsigned char input4[] = { - 0xe0, 0xa6, 0xa4, 0xe0, 0xa6, 0xb0, 0xe0, 0xa6, 0xae, 0xe0, 0xa7, - 0x81, 0xe0, 0xa6, 0x9c, 0x20, 0xd0, 0xb4, 0xd0, 0xb8, 0xd0, 0xbd, - 0xd1, 0x8f, 0x20, 0xe8, 0xa5, 0xbf, 0xe7, 0x93, 0x9c, 0x20, 0xf0, - 0x9f, 0x8d, 0x89 }; - test_read_string_success(input4, sizeof(input4), "তরমুজ диня 西瓜 🍉"); -} -END_TEST - -START_TEST(read_string_empty) -{ - unsigned char input[] = { 0, 0, 0, 0 }; - test_read_string_success(input, sizeof(input), ""); -} -END_TEST - -struct thread_param { - int fd; - uint32_t msg_size; - int err; -}; - -#define WRITER_PATTERN "abcdefghijklmnopqrstuvwxyz0123456789" - -/* - * Checks that the string @str is made up of @expected_len characters composed - * of the @WRITER_PATTERN pattern repeatedly. - */ -static void -validate_massive_string(uint32_t expected_len, rtr_char *str) -{ - size_t actual_len; - rtr_char *pattern; - size_t pattern_len; - rtr_char *cursor; - rtr_char *end; - - actual_len = strlen(str); - if (expected_len != actual_len) { - free(str); - ck_abort_msg("Expected length %u != Actual length %zu", - expected_len, actual_len); - } - - pattern = WRITER_PATTERN; - pattern_len = strlen(pattern); - end = str + expected_len; - for (cursor = str; cursor + pattern_len < end; cursor += pattern_len) { - if (strncmp(pattern, cursor, pattern_len) != 0) { - free(str); - ck_abort_msg("String does not match expected pattern"); - } - } - - if (strncmp(pattern, cursor, strlen(cursor)) != 0) { - free(str); - ck_abort_msg("String end does not match expected pattern"); - } - - free(str); - /* Success */ -} - -/* - * Sends @full_string_length characters to the fd, validates the parsed string - * contains the first @return_length characters. - */ -static void -test_massive_string(uint32_t return_length, uint32_t full_string_length) -{ - unsigned char *buffer; - rtr_char *pattern; - size_t pattern_len; - - size_t written; - size_t w; - - struct pdu_reader reader; - rtr_char *result_string; - - buffer = malloc(full_string_length); - if (buffer == NULL) - ck_abort_msg("Out of memory."); - - pattern = WRITER_PATTERN; - pattern_len = strlen(pattern); - for (written = 0; written < full_string_length; written += w) { - w = (full_string_length - written > pattern_len) - ? pattern_len - : (full_string_length - written); - memcpy(&buffer[written], pattern, w); - } - - pdu_reader_init(&reader, buffer, full_string_length); - ck_assert_int_eq(0, read_string(&reader, full_string_length, - &result_string)); - - validate_massive_string(return_length, result_string); - - free(buffer); -} - -START_TEST(read_string_massive) -{ - test_massive_string(2000, 2000); - test_massive_string(4000, 4000); - test_massive_string(4094, 4094); - test_massive_string(4095, 4095); - test_massive_string(4096, 4096); - test_massive_string(4097, 4097); - test_massive_string(8000, 8000); - test_massive_string(16000, 16000); -} -END_TEST - -START_TEST(read_string_null) -{ - test_read_string_success(NULL, 0, ""); -} -END_TEST - -START_TEST(read_string_unicode_mix) -{ - /* One octet failure */ - unsigned char input0[] = { 'a', 0x80, 'z' }; - test_read_string_success(input0, sizeof(input0), "a"); - - /* Two octets success */ - unsigned char input1[] = { 'a', 0xdf, 0x9a, 'z' }; - test_read_string_success(input1, sizeof(input1), "aߚz"); - /* Two octets failure */ - unsigned char input2[] = { 'a', 0xdf, 0xda, 'z' }; - test_read_string_success(input2, sizeof(input2), "a"); - - /* Three characters success */ - unsigned char input3[] = { 'a', 0xe2, 0x82, 0xac, 'z' }; - test_read_string_success(input3, sizeof(input3), "a€z"); - /* Three characters failure */ - unsigned char input4[] = { 'a', 0xe2, 0x82, 0x2c, 'z' }; - test_read_string_success(input4, sizeof(input4), "a"); - - /* Four characters success */ - unsigned char i5[] = { 'a', 0xf0, 0x90, 0x86, 0x97, 'z' }; - test_read_string_success(i5, sizeof(i5), "a𐆗z"); - /* Four characters failure */ - unsigned char i6[] = { 'a', 0xf0, 0x90, 0x90, 0x17, 'z' }; - test_read_string_success(i6, sizeof(i6), "a"); -} -END_TEST - -Suite *read_string_suite(void) -{ - Suite *suite; - TCase *core, *limits, *errors; - - core = tcase_create("Core"); - tcase_add_test(core, read_string_ascii); - tcase_add_test(core, read_string_unicode); - - limits = tcase_create("Limits"); - tcase_add_test(limits, read_string_empty); - tcase_add_test(limits, read_string_massive); - - errors = tcase_create("Errors"); - tcase_add_test(errors, read_string_null); - tcase_add_test(errors, read_string_unicode_mix); - - suite = suite_create("read_string()"); - suite_add_tcase(suite, core); - suite_add_tcase(suite, limits); - suite_add_tcase(suite, errors); - return suite; -} - -int main(void) -{ - Suite *suite; - SRunner *runner; - int tests_failed; - - /* - * This is it. We won't test the other functions because they are - * already reasonably manhandled in the PDU units. - */ - suite = read_string_suite(); - - runner = srunner_create(suite); - srunner_run_all(runner, CK_NORMAL); - tests_failed = srunner_ntests_failed(runner); - srunner_free(runner); - - return (tests_failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE; -}