#define TLS_RSA_WITH_AES_128_CBC_SHA 0x002f
#define TLS_RSA_WITH_AES_256_CBC_SHA 0x0035
+/* TLS extension types */
+#define TLS_SERVER_NAME 0
+#define TLS_SERVER_NAME_HOST_NAME 0
+
/** TLS RX state machine state */
enum tls_rx_state {
TLS_RX_HEADER = 0,
/** Reference counter */
struct refcnt refcnt;
+ /** Server name */
+ const char *name;
/** Plaintext stream */
struct interface plainstream;
/** Ciphertext stream */
void *rx_data;
};
-extern int add_tls ( struct interface *xfer,
+extern int add_tls ( struct interface *xfer, const char *name,
struct interface **next );
#endif /* _IPXE_TLS_H */
int http_open_filter ( struct interface *xfer, struct uri *uri,
unsigned int default_port,
int ( * filter ) ( struct interface *xfer,
+ const char *name,
struct interface **next ) ) {
struct http_request *http;
struct sockaddr_tcpip server;
server.st_port = htons ( uri_port ( http->uri, default_port ) );
socket = &http->socket;
if ( filter ) {
- if ( ( rc = filter ( socket, &socket ) ) != 0 )
+ if ( ( rc = filter ( socket, uri->host, &socket ) ) != 0 )
goto err;
}
if ( ( rc = xfer_open_named_socket ( socket, SOCK_STREAM,
uint16_t cipher_suites[2];
uint8_t compression_methods_len;
uint8_t compression_methods[1];
+ uint16_t extensions_len;
+ struct {
+ uint16_t server_name_type;
+ uint16_t server_name_len;
+ struct {
+ uint16_t len;
+ struct {
+ uint8_t type;
+ uint16_t len;
+ uint8_t name[ strlen ( tls->name ) ];
+ } __attribute__ (( packed )) list[1];
+ } __attribute__ (( packed )) server_name;
+ } __attribute__ (( packed )) extensions;
} __attribute__ (( packed )) hello;
memset ( &hello, 0, sizeof ( hello ) );
hello.cipher_suites[0] = htons ( TLS_RSA_WITH_AES_128_CBC_SHA );
hello.cipher_suites[1] = htons ( TLS_RSA_WITH_AES_256_CBC_SHA );
hello.compression_methods_len = sizeof ( hello.compression_methods );
+ hello.extensions_len = htons ( sizeof ( hello.extensions ) );
+ hello.extensions.server_name_type = htons ( TLS_SERVER_NAME );
+ hello.extensions.server_name_len
+ = htons ( sizeof ( hello.extensions.server_name ) );
+ hello.extensions.server_name.len
+ = htons ( sizeof ( hello.extensions.server_name.list ) );
+ hello.extensions.server_name.list[0].type = TLS_SERVER_NAME_HOST_NAME;
+ hello.extensions.server_name.list[0].len
+ = htons ( sizeof ( hello.extensions.server_name.list[0].name ));
+ memcpy ( hello.extensions.server_name.list[0].name, tls->name,
+ sizeof ( hello.extensions.server_name.list[0].name ) );
return tls_send_handshake ( tls, &hello, sizeof ( hello ) );
}
int rc;
/* Sanity check */
- if ( end != ( data + len ) ) {
- DBGC ( tls, "TLS %p received overlength Server Hello\n", tls );
+ if ( end > ( data + len ) ) {
+ DBGC ( tls, "TLS %p received underlength Server Hello\n", tls );
DBGC_HD ( tls, data, len );
return -EINVAL;
}
******************************************************************************
*/
-int add_tls ( struct interface *xfer, struct interface **next ) {
+int add_tls ( struct interface *xfer, const char *name,
+ struct interface **next ) {
struct tls_session *tls;
int rc;
}
memset ( tls, 0, sizeof ( *tls ) );
ref_init ( &tls->refcnt, free_tls );
+ tls->name = name;
intf_init ( &tls->plainstream, &tls_plainstream_desc, &tls->refcnt );
intf_init ( &tls->cipherstream, &tls_cipherstream_desc, &tls->refcnt );
tls->version = TLS_VERSION_TLS_1_1;