free(s);
}
+/** print fixed line over the ssl connection */
+static int
+ssl_print_text(SSL* ssl, const char* text)
+{
+ int r;
+ ERR_clear_error();
+ if((r=SSL_write(ssl, text, (int)strlen(text))) <= 0) {
+ if(SSL_get_error(ssl, r) == SSL_ERROR_ZERO_RETURN) {
+ verbose(VERB_QUERY, "warning, in SSL_write, peer "
+ "closed connection");
+ return 0;
+ }
+ log_crypto_err("could not SSL_write");
+ return 0;
+ }
+ return 1;
+}
+
+/** print text over the ssl connection */
+static int
+ssl_print_vmsg(SSL* ssl, const char* format, va_list args)
+{
+ char msg[1024];
+ vsnprintf(msg, sizeof(msg), format, args);
+ return ssl_print_text(ssl, msg);
+}
+
+/** declare for printf format checking by gcc
+ * @param ssl: the SSL connection to print to. Blocking.
+ * @param format: printf style format string.
+ * @return success or false on a network failure.
+ */
+static int ssl_printf(SSL* ssl, const char* format, ...)
+ ATTR_FORMAT(printf, 2, 3);
+
+/** printf style printing to the ssl connection */
+static int ssl_printf(SSL* ssl, const char* format, ...)
+{
+ va_list args;
+ int ret;
+ va_start(args, format);
+ ret = ssl_print_vmsg(ssl, format, args);
+ va_end(args);
+ return ret;
+}
+
+/** read until \n */
+static int
+ssl_read_line(SSL* ssl, char* buf, size_t max)
+{
+ int r;
+ size_t len = 0;
+ while(len < max) {
+ ERR_clear_error();
+ if((r=SSL_read(ssl, buf+len, 1)) <= 0) {
+ if(SSL_get_error(ssl, r) == SSL_ERROR_ZERO_RETURN) {
+ buf[len] = 0;
+ return 1;
+ }
+ log_crypto_err("could not SSL_read");
+ return 0;
+ }
+ if(buf[len] == '\n') {
+ /* return string without \n */
+ buf[len] = 0;
+ return 1;
+ }
+ len++;
+ }
+ buf[max-1] = 0;
+ log_err("control line too long (%d): %s", (int)max, buf);
+ return 0;
+}
+
+/** send the OK to the control client */
+static void send_ok(SSL* ssl)
+{
+ (void)ssl_printf(ssl, "ok\n");
+}
+
+/** do the stop command */
+static void
+do_stop(struct daemon_remote* rc, SSL* ssl)
+{
+ rc->worker->need_to_exit = 1;
+ comm_base_exit(rc->worker->base);
+ send_ok(ssl);
+}
+
+/** do the reload command */
+static void
+do_reload(struct daemon_remote* rc, SSL* ssl)
+{
+ rc->worker->need_to_exit = 0;
+ comm_base_exit(rc->worker->base);
+ send_ok(ssl);
+}
+
+/** execute a remote control command */
+static void
+execute_cmd(struct daemon_remote* rc, SSL* ssl, char* cmd)
+{
+ char* p = cmd;
+ /* skip whitespace */
+ while( isspace(*p) ) p++;
+ /* compare command - check longer strings first */
+ if(strncmp(p, "stop", 4) == 0) {
+ do_stop(rc, ssl);
+ } else if(strncmp(p, "reload", 6) == 0) {
+ do_reload(rc, ssl);
+ } else {
+ (void)ssl_printf(ssl, "error unknown command '%s'\n", p);
+ }
+}
+
/** handle remote control request */
static void
handle_req(struct daemon_remote* rc, struct rc_state* s, SSL* ssl)
{
- char* msg = "ok\n";
int r;
+ char magic[5];
char buf[1024];
fd_set_block(s->c->fd);
+ /* try to read magic UBCT string */
ERR_clear_error();
- if((r=SSL_read(ssl, buf, (int)sizeof(buf)-1)) <= 0) {
+ if((r=SSL_read(ssl, magic, (int)sizeof(magic)-1)) <= 0) {
if(SSL_get_error(ssl, r) == SSL_ERROR_ZERO_RETURN)
return;
log_crypto_err("could not SSL_read");
return;
}
- buf[r] = 0;
- log_info("got '%s'", buf);
+ magic[4] = 0;
+ if( r != 4 || strcmp(magic, "UBCT") != 0) {
+ verbose(VERB_QUERY, "control connection has bad magic string");
+ return;
+ }
- ERR_clear_error();
- if((r=SSL_write(ssl, msg, (int)strlen(msg))) <= 0) {
- if(SSL_get_error(ssl, r) == SSL_ERROR_ZERO_RETURN)
- return;
- log_crypto_err("could not SSL_write");
+ /* read the command line */
+ if(!ssl_read_line(ssl, buf, sizeof(buf))) {
return;
}
+ verbose(VERB_DETAIL, "control cmd: %s", buf);
+
+ /* figure out what to do */
+ execute_cmd(rc, ssl, buf);
}
int remote_control_callback(struct comm_point* c, void* arg, int err,