summaryrefslogtreecommitdiffstats
path: root/contrib/openbsm/bin/auditdistd/proto_tls.c
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/openbsm/bin/auditdistd/proto_tls.c')
-rw-r--r--contrib/openbsm/bin/auditdistd/proto_tls.c1076
1 files changed, 1076 insertions, 0 deletions
diff --git a/contrib/openbsm/bin/auditdistd/proto_tls.c b/contrib/openbsm/bin/auditdistd/proto_tls.c
new file mode 100644
index 0000000..faeb3d8
--- /dev/null
+++ b/contrib/openbsm/bin/auditdistd/proto_tls.c
@@ -0,0 +1,1076 @@
+/*-
+ * Copyright (c) 2011 The FreeBSD Foundation
+ * All rights reserved.
+ *
+ * This software was developed by Pawel Jakub Dawidek under sponsorship from
+ * the FreeBSD Foundation.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ * 1. Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
+ * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
+ * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+ * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
+ * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
+ * SUCH DAMAGE.
+ *
+ * $P4: //depot/projects/trustedbsd/openbsm/bin/auditdistd/proto_tls.c#2 $
+ */
+
+#include <config/config.h>
+
+#include <sys/param.h> /* MAXHOSTNAMELEN */
+#include <sys/socket.h>
+
+#include <arpa/inet.h>
+
+#include <netinet/in.h>
+#include <netinet/tcp.h>
+
+#include <errno.h>
+#include <fcntl.h>
+#include <netdb.h>
+#include <signal.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h>
+#include <unistd.h>
+
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+#include <compat/compat.h>
+#ifndef HAVE_CLOSEFROM
+#include <compat/closefrom.h>
+#endif
+#ifndef HAVE_STRLCPY
+#include <compat/strlcpy.h>
+#endif
+
+#include "pjdlog.h"
+#include "proto_impl.h"
+#include "sandbox.h"
+#include "subr.h"
+
+#define TLS_CTX_MAGIC 0x715c7
+struct tls_ctx {
+ int tls_magic;
+ struct proto_conn *tls_sock;
+ struct proto_conn *tls_tcp;
+ char tls_laddr[256];
+ char tls_raddr[256];
+ int tls_side;
+#define TLS_SIDE_CLIENT 0
+#define TLS_SIDE_SERVER_LISTEN 1
+#define TLS_SIDE_SERVER_WORK 2
+ bool tls_wait_called;
+};
+
+#define TLS_DEFAULT_TIMEOUT 30
+
+static int tls_connect_wait(void *ctx, int timeout);
+static void tls_close(void *ctx);
+
+static void
+block(int fd)
+{
+ int flags;
+
+ flags = fcntl(fd, F_GETFL);
+ if (flags == -1)
+ pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
+ flags &= ~O_NONBLOCK;
+ if (fcntl(fd, F_SETFL, flags) == -1)
+ pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
+}
+
+static void
+nonblock(int fd)
+{
+ int flags;
+
+ flags = fcntl(fd, F_GETFL);
+ if (flags == -1)
+ pjdlog_exit(EX_TEMPFAIL, "fcntl(F_GETFL) failed");
+ flags |= O_NONBLOCK;
+ if (fcntl(fd, F_SETFL, flags) == -1)
+ pjdlog_exit(EX_TEMPFAIL, "fcntl(F_SETFL) failed");
+}
+
+static int
+wait_for_fd(int fd, int timeout)
+{
+ struct timeval tv;
+ fd_set fdset;
+ int error, ret;
+
+ error = 0;
+
+ for (;;) {
+ FD_ZERO(&fdset);
+ FD_SET(fd, &fdset);
+
+ tv.tv_sec = timeout;
+ tv.tv_usec = 0;
+
+ ret = select(fd + 1, NULL, &fdset, NULL,
+ timeout == -1 ? NULL : &tv);
+ if (ret == 0) {
+ error = ETIMEDOUT;
+ break;
+ } else if (ret == -1) {
+ if (errno == EINTR)
+ continue;
+ error = errno;
+ break;
+ }
+ PJDLOG_ASSERT(ret > 0);
+ PJDLOG_ASSERT(FD_ISSET(fd, &fdset));
+ break;
+ }
+
+ return (error);
+}
+
+static void
+ssl_log_errors(void)
+{
+ unsigned long error;
+
+ while ((error = ERR_get_error()) != 0)
+ pjdlog_error("SSL error: %s", ERR_error_string(error, NULL));
+}
+
+static int
+ssl_check_error(SSL *ssl, int ret)
+{
+ int error;
+
+ error = SSL_get_error(ssl, ret);
+
+ switch (error) {
+ case SSL_ERROR_NONE:
+ return (0);
+ case SSL_ERROR_WANT_READ:
+ pjdlog_debug(2, "SSL_ERROR_WANT_READ");
+ return (-1);
+ case SSL_ERROR_WANT_WRITE:
+ pjdlog_debug(2, "SSL_ERROR_WANT_WRITE");
+ return (-1);
+ case SSL_ERROR_ZERO_RETURN:
+ pjdlog_exitx(EX_OK, "Connection closed.");
+ case SSL_ERROR_SYSCALL:
+ ssl_log_errors();
+ pjdlog_exitx(EX_TEMPFAIL, "SSL I/O error.");
+ case SSL_ERROR_SSL:
+ ssl_log_errors();
+ pjdlog_exitx(EX_TEMPFAIL, "SSL protocol error.");
+ default:
+ ssl_log_errors();
+ pjdlog_exitx(EX_TEMPFAIL, "Unknown SSL error (%d).", error);
+ }
+}
+
+static void
+tcp_recv_ssl_send(int recvfd, SSL *sendssl)
+{
+ static unsigned char buf[65536];
+ ssize_t tcpdone;
+ int sendfd, ssldone;
+
+ sendfd = SSL_get_fd(sendssl);
+ PJDLOG_ASSERT(sendfd >= 0);
+ pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
+ for (;;) {
+ tcpdone = recv(recvfd, buf, sizeof(buf), 0);
+ pjdlog_debug(2, "%s: recv() returned %zd", __func__, tcpdone);
+ if (tcpdone == 0) {
+ pjdlog_debug(1, "Connection terminated.");
+ exit(0);
+ } else if (tcpdone == -1) {
+ if (errno == EINTR)
+ continue;
+ else if (errno == EAGAIN)
+ break;
+ pjdlog_exit(EX_TEMPFAIL, "recv() failed");
+ }
+ for (;;) {
+ ssldone = SSL_write(sendssl, buf, (int)tcpdone);
+ pjdlog_debug(2, "%s: send() returned %d", __func__,
+ ssldone);
+ if (ssl_check_error(sendssl, ssldone) == -1) {
+ (void)wait_for_fd(sendfd, -1);
+ continue;
+ }
+ PJDLOG_ASSERT(ssldone == tcpdone);
+ break;
+ }
+ }
+ pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
+}
+
+static void
+ssl_recv_tcp_send(SSL *recvssl, int sendfd)
+{
+ static unsigned char buf[65536];
+ unsigned char *ptr;
+ ssize_t tcpdone;
+ size_t todo;
+ int recvfd, ssldone;
+
+ recvfd = SSL_get_fd(recvssl);
+ PJDLOG_ASSERT(recvfd >= 0);
+ pjdlog_debug(2, "%s: start %d -> %d", __func__, recvfd, sendfd);
+ for (;;) {
+ ssldone = SSL_read(recvssl, buf, sizeof(buf));
+ pjdlog_debug(2, "%s: SSL_read() returned %d", __func__,
+ ssldone);
+ if (ssl_check_error(recvssl, ssldone) == -1)
+ break;
+ todo = (size_t)ssldone;
+ ptr = buf;
+ do {
+ tcpdone = send(sendfd, ptr, todo, MSG_NOSIGNAL);
+ pjdlog_debug(2, "%s: send() returned %zd", __func__,
+ tcpdone);
+ if (tcpdone == 0) {
+ pjdlog_debug(1, "Connection terminated.");
+ exit(0);
+ } else if (tcpdone == -1) {
+ if (errno == EINTR || errno == ENOBUFS)
+ continue;
+ if (errno == EAGAIN) {
+ (void)wait_for_fd(sendfd, -1);
+ continue;
+ }
+ pjdlog_exit(EX_TEMPFAIL, "send() failed");
+ }
+ todo -= tcpdone;
+ ptr += tcpdone;
+ } while (todo > 0);
+ }
+ pjdlog_debug(2, "%s: done %d -> %d", __func__, recvfd, sendfd);
+}
+
+static void
+tls_loop(int sockfd, SSL *tcpssl)
+{
+ fd_set fds;
+ int maxfd, tcpfd;
+
+ tcpfd = SSL_get_fd(tcpssl);
+ PJDLOG_ASSERT(tcpfd >= 0);
+
+ for (;;) {
+ FD_ZERO(&fds);
+ FD_SET(sockfd, &fds);
+ FD_SET(tcpfd, &fds);
+ maxfd = MAX(sockfd, tcpfd);
+
+ PJDLOG_ASSERT(maxfd + 1 <= (int)FD_SETSIZE);
+ if (select(maxfd + 1, &fds, NULL, NULL, NULL) == -1) {
+ if (errno == EINTR)
+ continue;
+ pjdlog_exit(EX_TEMPFAIL, "select() failed");
+ }
+ if (FD_ISSET(sockfd, &fds))
+ tcp_recv_ssl_send(sockfd, tcpssl);
+ if (FD_ISSET(tcpfd, &fds))
+ ssl_recv_tcp_send(tcpssl, sockfd);
+ }
+}
+
+static void
+tls_certificate_verify(SSL *ssl, const char *fingerprint)
+{
+ unsigned char md[EVP_MAX_MD_SIZE];
+ char mdstr[sizeof("SHA256=") - 1 + EVP_MAX_MD_SIZE * 3];
+ char *mdstrp;
+ unsigned int i, mdsize;
+ X509 *cert;
+
+ if (fingerprint[0] == '\0') {
+ pjdlog_debug(1, "No fingerprint verification requested.");
+ return;
+ }
+
+ cert = SSL_get_peer_certificate(ssl);
+ if (cert == NULL)
+ pjdlog_exitx(EX_TEMPFAIL, "No peer certificate received.");
+
+ if (X509_digest(cert, EVP_sha256(), md, &mdsize) != 1)
+ pjdlog_exitx(EX_TEMPFAIL, "X509_digest() failed.");
+ PJDLOG_ASSERT(mdsize <= EVP_MAX_MD_SIZE);
+
+ X509_free(cert);
+
+ (void)strlcpy(mdstr, "SHA256=", sizeof(mdstr));
+ mdstrp = mdstr + strlen(mdstr);
+ for (i = 0; i < mdsize; i++) {
+ PJDLOG_VERIFY(mdstrp + 3 <= mdstr + sizeof(mdstr));
+ (void)sprintf(mdstrp, "%02hhX:", md[i]);
+ mdstrp += 3;
+ }
+ /* Clear last colon. */
+ mdstrp[-1] = '\0';
+ if (strcasecmp(mdstr, fingerprint) != 0) {
+ pjdlog_exitx(EX_NOPERM,
+ "Finger print doesn't match. Received \"%s\", expected \"%s\"",
+ mdstr, fingerprint);
+ }
+}
+
+static void
+tls_exec_client(const char *user, int startfd, const char *srcaddr,
+ const char *dstaddr, const char *fingerprint, const char *defport,
+ int timeout, int debuglevel)
+{
+ struct proto_conn *tcp;
+ char *saddr, *daddr;
+ SSL_CTX *sslctx;
+ SSL *ssl;
+ long ret;
+ int sockfd, tcpfd;
+ uint8_t connected;
+
+ pjdlog_debug_set(debuglevel);
+ pjdlog_prefix_set("[TLS sandbox] (client) ");
+#ifdef HAVE_SETPROCTITLE
+ setproctitle("[TLS sandbox] (client) ");
+#endif
+ proto_set("tcp:port", defport);
+
+ sockfd = startfd;
+
+ /* Change tls:// to tcp://. */
+ if (srcaddr == NULL) {
+ saddr = NULL;
+ } else {
+ saddr = strdup(srcaddr);
+ if (saddr == NULL)
+ pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
+ bcopy("tcp://", saddr, 6);
+ }
+ daddr = strdup(dstaddr);
+ if (daddr == NULL)
+ pjdlog_exitx(EX_TEMPFAIL, "Unable to allocate memory.");
+ bcopy("tcp://", daddr, 6);
+
+ /* Establish TCP connection. */
+ if (proto_connect(saddr, daddr, timeout, &tcp) == -1)
+ exit(EX_TEMPFAIL);
+
+ SSL_load_error_strings();
+ SSL_library_init();
+
+ /*
+ * TODO: On FreeBSD we could move this below sandbox() once libc and
+ * libcrypto use sysctl kern.arandom to obtain random data
+ * instead of /dev/urandom and friends.
+ */
+ sslctx = SSL_CTX_new(TLSv1_client_method());
+ if (sslctx == NULL)
+ pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
+
+ if (sandbox(user, true, "proto_tls client: %s", dstaddr) != 0)
+ pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS client.");
+ pjdlog_debug(1, "Privileges successfully dropped.");
+
+ SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
+
+ /* Load CA certs. */
+ /* TODO */
+ //SSL_CTX_load_verify_locations(sslctx, cacerts_file, NULL);
+
+ ssl = SSL_new(sslctx);
+ if (ssl == NULL)
+ pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
+
+ tcpfd = proto_descriptor(tcp);
+
+ block(tcpfd);
+
+ if (SSL_set_fd(ssl, tcpfd) != 1)
+ pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
+
+ ret = SSL_connect(ssl);
+ ssl_check_error(ssl, (int)ret);
+
+ nonblock(sockfd);
+ nonblock(tcpfd);
+
+ tls_certificate_verify(ssl, fingerprint);
+
+ /*
+ * The following byte is send to make proto_connect_wait() to work.
+ */
+ connected = 1;
+ for (;;) {
+ switch (send(sockfd, &connected, sizeof(connected), 0)) {
+ case -1:
+ if (errno == EINTR || errno == ENOBUFS)
+ continue;
+ if (errno == EAGAIN) {
+ (void)wait_for_fd(sockfd, -1);
+ continue;
+ }
+ pjdlog_exit(EX_TEMPFAIL, "send() failed");
+ case 0:
+ pjdlog_debug(1, "Connection terminated.");
+ exit(0);
+ case 1:
+ break;
+ }
+ break;
+ }
+
+ tls_loop(sockfd, ssl);
+}
+
+static void
+tls_call_exec_client(struct proto_conn *sock, const char *srcaddr,
+ const char *dstaddr, int timeout)
+{
+ char *timeoutstr, *startfdstr, *debugstr;
+ int startfd;
+
+ /* Declare that we are receiver. */
+ proto_recv(sock, NULL, 0);
+
+ if (pjdlog_mode_get() == PJDLOG_MODE_STD)
+ startfd = 3;
+ else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
+ startfd = 0;
+
+ if (proto_descriptor(sock) != startfd) {
+ /* Move socketpair descriptor to descriptor number startfd. */
+ if (dup2(proto_descriptor(sock), startfd) == -1)
+ pjdlog_exit(EX_OSERR, "dup2() failed");
+ proto_close(sock);
+ } else {
+ /*
+ * The FD_CLOEXEC is cleared by dup2(2), so when we not
+ * call it, we have to clear it by hand in case it is set.
+ */
+ if (fcntl(startfd, F_SETFD, 0) == -1)
+ pjdlog_exit(EX_OSERR, "fcntl() failed");
+ }
+
+ closefrom(startfd + 1);
+
+ if (asprintf(&startfdstr, "%d", startfd) == -1)
+ pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
+ if (timeout == -1)
+ timeout = TLS_DEFAULT_TIMEOUT;
+ if (asprintf(&timeoutstr, "%d", timeout) == -1)
+ pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
+ if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
+ pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
+
+ execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
+ proto_get("user"), "client", startfdstr,
+ srcaddr == NULL ? "" : srcaddr, dstaddr,
+ proto_get("tls:fingerprint"), proto_get("tcp:port"), timeoutstr,
+ debugstr, NULL);
+ pjdlog_exit(EX_SOFTWARE, "execl() failed");
+}
+
+static int
+tls_connect(const char *srcaddr, const char *dstaddr, int timeout, void **ctxp)
+{
+ struct tls_ctx *tlsctx;
+ struct proto_conn *sock;
+ pid_t pid;
+ int error;
+
+ PJDLOG_ASSERT(srcaddr == NULL || srcaddr[0] != '\0');
+ PJDLOG_ASSERT(dstaddr != NULL);
+ PJDLOG_ASSERT(timeout >= -1);
+ PJDLOG_ASSERT(ctxp != NULL);
+
+ if (strncmp(dstaddr, "tls://", 6) != 0)
+ return (-1);
+ if (srcaddr != NULL && strncmp(srcaddr, "tls://", 6) != 0)
+ return (-1);
+
+ if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
+ return (errno);
+
+#if 0
+ /*
+ * We use rfork() with the following flags to disable SIGCHLD
+ * delivery upon the sandbox process exit.
+ */
+ pid = rfork(RFFDG | RFPROC | RFTSIGZMB | RFTSIGFLAGS(0));
+#else
+ /*
+ * We don't use rfork() to be able to log information about sandbox
+ * process exiting.
+ */
+ pid = fork();
+#endif
+ switch (pid) {
+ case -1:
+ /* Failure. */
+ error = errno;
+ proto_close(sock);
+ return (error);
+ case 0:
+ /* Child. */
+ pjdlog_prefix_set("[TLS sandbox] (client) ");
+#ifdef HAVE_SETPROCTITLE
+ setproctitle("[TLS sandbox] (client) ");
+#endif
+ tls_call_exec_client(sock, srcaddr, dstaddr, timeout);
+ /* NOTREACHED */
+ default:
+ /* Parent. */
+ tlsctx = calloc(1, sizeof(*tlsctx));
+ if (tlsctx == NULL) {
+ error = errno;
+ proto_close(sock);
+ (void)kill(pid, SIGKILL);
+ return (error);
+ }
+ proto_send(sock, NULL, 0);
+ tlsctx->tls_sock = sock;
+ tlsctx->tls_tcp = NULL;
+ tlsctx->tls_side = TLS_SIDE_CLIENT;
+ tlsctx->tls_wait_called = false;
+ tlsctx->tls_magic = TLS_CTX_MAGIC;
+ if (timeout >= 0) {
+ error = tls_connect_wait(tlsctx, timeout);
+ if (error != 0) {
+ (void)kill(pid, SIGKILL);
+ tls_close(tlsctx);
+ return (error);
+ }
+ }
+ *ctxp = tlsctx;
+ return (0);
+ }
+}
+
+static int
+tls_connect_wait(void *ctx, int timeout)
+{
+ struct tls_ctx *tlsctx = ctx;
+ int error, sockfd;
+ uint8_t connected;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+ PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT);
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+ PJDLOG_ASSERT(!tlsctx->tls_wait_called);
+ PJDLOG_ASSERT(timeout >= 0);
+
+ sockfd = proto_descriptor(tlsctx->tls_sock);
+ error = wait_for_fd(sockfd, timeout);
+ if (error != 0)
+ return (error);
+
+ for (;;) {
+ switch (recv(sockfd, &connected, sizeof(connected),
+ MSG_WAITALL)) {
+ case -1:
+ if (errno == EINTR || errno == ENOBUFS)
+ continue;
+ error = errno;
+ break;
+ case 0:
+ pjdlog_debug(1, "Connection terminated.");
+ error = ENOTCONN;
+ break;
+ case 1:
+ tlsctx->tls_wait_called = true;
+ break;
+ }
+ break;
+ }
+
+ return (error);
+}
+
+static int
+tls_server(const char *lstaddr, void **ctxp)
+{
+ struct proto_conn *tcp;
+ struct tls_ctx *tlsctx;
+ char *laddr;
+ int error;
+
+ if (strncmp(lstaddr, "tls://", 6) != 0)
+ return (-1);
+
+ tlsctx = malloc(sizeof(*tlsctx));
+ if (tlsctx == NULL) {
+ pjdlog_warning("Unable to allocate memory.");
+ return (ENOMEM);
+ }
+
+ laddr = strdup(lstaddr);
+ if (laddr == NULL) {
+ free(tlsctx);
+ pjdlog_warning("Unable to allocate memory.");
+ return (ENOMEM);
+ }
+ bcopy("tcp://", laddr, 6);
+
+ if (proto_server(laddr, &tcp) == -1) {
+ error = errno;
+ free(tlsctx);
+ free(laddr);
+ return (error);
+ }
+ free(laddr);
+
+ tlsctx->tls_sock = NULL;
+ tlsctx->tls_tcp = tcp;
+ tlsctx->tls_side = TLS_SIDE_SERVER_LISTEN;
+ tlsctx->tls_wait_called = true;
+ tlsctx->tls_magic = TLS_CTX_MAGIC;
+ *ctxp = tlsctx;
+
+ return (0);
+}
+
+static void
+tls_exec_server(const char *user, int startfd, const char *privkey,
+ const char *cert, int debuglevel)
+{
+ SSL_CTX *sslctx;
+ SSL *ssl;
+ int sockfd, tcpfd, ret;
+
+ pjdlog_debug_set(debuglevel);
+ pjdlog_prefix_set("[TLS sandbox] (server) ");
+#ifdef HAVE_SETPROCTITLE
+ setproctitle("[TLS sandbox] (server) ");
+#endif
+
+ sockfd = startfd;
+ tcpfd = startfd + 1;
+
+ SSL_load_error_strings();
+ SSL_library_init();
+
+ sslctx = SSL_CTX_new(TLSv1_server_method());
+ if (sslctx == NULL)
+ pjdlog_exitx(EX_TEMPFAIL, "SSL_CTX_new() failed.");
+
+ SSL_CTX_set_options(sslctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
+
+ ssl = SSL_new(sslctx);
+ if (ssl == NULL)
+ pjdlog_exitx(EX_TEMPFAIL, "SSL_new() failed.");
+
+ if (SSL_use_RSAPrivateKey_file(ssl, privkey, SSL_FILETYPE_PEM) != 1) {
+ ssl_log_errors();
+ pjdlog_exitx(EX_CONFIG,
+ "SSL_use_RSAPrivateKey_file(%s) failed.", privkey);
+ }
+
+ if (SSL_use_certificate_file(ssl, cert, SSL_FILETYPE_PEM) != 1) {
+ ssl_log_errors();
+ pjdlog_exitx(EX_CONFIG, "SSL_use_certificate_file(%s) failed.",
+ cert);
+ }
+
+ if (sandbox(user, true, "proto_tls server") != 0)
+ pjdlog_exitx(EX_CONFIG, "Unable to sandbox TLS server.");
+ pjdlog_debug(1, "Privileges successfully dropped.");
+
+ nonblock(sockfd);
+ nonblock(tcpfd);
+
+ if (SSL_set_fd(ssl, tcpfd) != 1)
+ pjdlog_exitx(EX_TEMPFAIL, "SSL_set_fd() failed.");
+
+ ret = SSL_accept(ssl);
+ ssl_check_error(ssl, ret);
+
+ tls_loop(sockfd, ssl);
+}
+
+static void
+tls_call_exec_server(struct proto_conn *sock, struct proto_conn *tcp)
+{
+ int startfd, sockfd, tcpfd, safefd;
+ char *startfdstr, *debugstr;
+
+ if (pjdlog_mode_get() == PJDLOG_MODE_STD)
+ startfd = 3;
+ else /* if (pjdlog_mode_get() == PJDLOG_MODE_SYSLOG) */
+ startfd = 0;
+
+ /* Declare that we are receiver. */
+ proto_send(sock, NULL, 0);
+
+ sockfd = proto_descriptor(sock);
+ tcpfd = proto_descriptor(tcp);
+
+ safefd = MAX(sockfd, tcpfd);
+ safefd = MAX(safefd, startfd);
+ safefd++;
+
+ /* Move sockfd and tcpfd to safe numbers first. */
+ if (dup2(sockfd, safefd) == -1)
+ pjdlog_exit(EX_OSERR, "dup2() failed");
+ proto_close(sock);
+ sockfd = safefd;
+ if (dup2(tcpfd, safefd + 1) == -1)
+ pjdlog_exit(EX_OSERR, "dup2() failed");
+ proto_close(tcp);
+ tcpfd = safefd + 1;
+
+ /* Move socketpair descriptor to descriptor number startfd. */
+ if (dup2(sockfd, startfd) == -1)
+ pjdlog_exit(EX_OSERR, "dup2() failed");
+ (void)close(sockfd);
+ /* Move tcp descriptor to descriptor number startfd + 1. */
+ if (dup2(tcpfd, startfd + 1) == -1)
+ pjdlog_exit(EX_OSERR, "dup2() failed");
+ (void)close(tcpfd);
+
+ closefrom(startfd + 2);
+
+ /*
+ * Even if FD_CLOEXEC was set on descriptors before dup2(), it should
+ * have been cleared on dup2(), but better be safe than sorry.
+ */
+ if (fcntl(startfd, F_SETFD, 0) == -1)
+ pjdlog_exit(EX_OSERR, "fcntl() failed");
+ if (fcntl(startfd + 1, F_SETFD, 0) == -1)
+ pjdlog_exit(EX_OSERR, "fcntl() failed");
+
+ if (asprintf(&startfdstr, "%d", startfd) == -1)
+ pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
+ if (asprintf(&debugstr, "%d", pjdlog_debug_get()) == -1)
+ pjdlog_exit(EX_TEMPFAIL, "asprintf() failed");
+
+ execl(proto_get("execpath"), proto_get("execpath"), "proto", "tls",
+ proto_get("user"), "server", startfdstr, proto_get("tls:keyfile"),
+ proto_get("tls:certfile"), debugstr, NULL);
+ pjdlog_exit(EX_SOFTWARE, "execl() failed");
+}
+
+static int
+tls_accept(void *ctx, void **newctxp)
+{
+ struct tls_ctx *tlsctx = ctx;
+ struct tls_ctx *newtlsctx;
+ struct proto_conn *sock, *tcp;
+ pid_t pid;
+ int error;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+ PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_SERVER_LISTEN);
+
+ if (proto_connect(NULL, "socketpair://", -1, &sock) == -1)
+ return (errno);
+
+ /* Accept TCP connection. */
+ if (proto_accept(tlsctx->tls_tcp, &tcp) == -1) {
+ error = errno;
+ proto_close(sock);
+ return (error);
+ }
+
+ pid = fork();
+ switch (pid) {
+ case -1:
+ /* Failure. */
+ error = errno;
+ proto_close(sock);
+ return (error);
+ case 0:
+ /* Child. */
+ pjdlog_prefix_set("[TLS sandbox] (server) ");
+#ifdef HAVE_SETPROCTITLE
+ setproctitle("[TLS sandbox] (server) ");
+#endif
+ /* Close listen socket. */
+ proto_close(tlsctx->tls_tcp);
+ tls_call_exec_server(sock, tcp);
+ /* NOTREACHED */
+ PJDLOG_ABORT("Unreachable.");
+ default:
+ /* Parent. */
+ newtlsctx = calloc(1, sizeof(*tlsctx));
+ if (newtlsctx == NULL) {
+ error = errno;
+ proto_close(sock);
+ proto_close(tcp);
+ (void)kill(pid, SIGKILL);
+ return (error);
+ }
+ proto_local_address(tcp, newtlsctx->tls_laddr,
+ sizeof(newtlsctx->tls_laddr));
+ PJDLOG_ASSERT(strncmp(newtlsctx->tls_laddr, "tcp://", 6) == 0);
+ bcopy("tls://", newtlsctx->tls_laddr, 6);
+ *strrchr(newtlsctx->tls_laddr, ':') = '\0';
+ proto_remote_address(tcp, newtlsctx->tls_raddr,
+ sizeof(newtlsctx->tls_raddr));
+ PJDLOG_ASSERT(strncmp(newtlsctx->tls_raddr, "tcp://", 6) == 0);
+ bcopy("tls://", newtlsctx->tls_raddr, 6);
+ *strrchr(newtlsctx->tls_raddr, ':') = '\0';
+ proto_close(tcp);
+ proto_recv(sock, NULL, 0);
+ newtlsctx->tls_sock = sock;
+ newtlsctx->tls_tcp = NULL;
+ newtlsctx->tls_wait_called = true;
+ newtlsctx->tls_side = TLS_SIDE_SERVER_WORK;
+ newtlsctx->tls_magic = TLS_CTX_MAGIC;
+ *newctxp = newtlsctx;
+ return (0);
+ }
+}
+
+static int
+tls_wrap(int fd, bool client, void **ctxp)
+{
+ struct tls_ctx *tlsctx;
+ struct proto_conn *sock;
+ int error;
+
+ tlsctx = calloc(1, sizeof(*tlsctx));
+ if (tlsctx == NULL)
+ return (errno);
+
+ if (proto_wrap("socketpair", client, fd, &sock) == -1) {
+ error = errno;
+ free(tlsctx);
+ return (error);
+ }
+
+ tlsctx->tls_sock = sock;
+ tlsctx->tls_tcp = NULL;
+ tlsctx->tls_wait_called = (client ? false : true);
+ tlsctx->tls_side = (client ? TLS_SIDE_CLIENT : TLS_SIDE_SERVER_WORK);
+ tlsctx->tls_magic = TLS_CTX_MAGIC;
+ *ctxp = tlsctx;
+
+ return (0);
+}
+
+static int
+tls_send(void *ctx, const unsigned char *data, size_t size, int fd)
+{
+ struct tls_ctx *tlsctx = ctx;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+ PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
+ tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_wait_called);
+ PJDLOG_ASSERT(fd == -1);
+
+ if (proto_send(tlsctx->tls_sock, data, size) == -1)
+ return (errno);
+
+ return (0);
+}
+
+static int
+tls_recv(void *ctx, unsigned char *data, size_t size, int *fdp)
+{
+ struct tls_ctx *tlsctx = ctx;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+ PJDLOG_ASSERT(tlsctx->tls_side == TLS_SIDE_CLIENT ||
+ tlsctx->tls_side == TLS_SIDE_SERVER_WORK);
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_wait_called);
+ PJDLOG_ASSERT(fdp == NULL);
+
+ if (proto_recv(tlsctx->tls_sock, data, size) == -1)
+ return (errno);
+
+ return (0);
+}
+
+static int
+tls_descriptor(const void *ctx)
+{
+ const struct tls_ctx *tlsctx = ctx;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+
+ switch (tlsctx->tls_side) {
+ case TLS_SIDE_CLIENT:
+ case TLS_SIDE_SERVER_WORK:
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+
+ return (proto_descriptor(tlsctx->tls_sock));
+ case TLS_SIDE_SERVER_LISTEN:
+ PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
+
+ return (proto_descriptor(tlsctx->tls_tcp));
+ default:
+ PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
+ }
+}
+
+static bool
+tcp_address_match(const void *ctx, const char *addr)
+{
+ const struct tls_ctx *tlsctx = ctx;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+
+ return (strcmp(tlsctx->tls_raddr, addr) == 0);
+}
+
+static void
+tls_local_address(const void *ctx, char *addr, size_t size)
+{
+ const struct tls_ctx *tlsctx = ctx;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+ PJDLOG_ASSERT(tlsctx->tls_wait_called);
+
+ switch (tlsctx->tls_side) {
+ case TLS_SIDE_CLIENT:
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+
+ PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
+ break;
+ case TLS_SIDE_SERVER_WORK:
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+
+ PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_laddr, size) < size);
+ break;
+ case TLS_SIDE_SERVER_LISTEN:
+ PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
+
+ proto_local_address(tlsctx->tls_tcp, addr, size);
+ PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
+ /* Replace tcp:// prefix with tls:// */
+ bcopy("tls://", addr, 6);
+ break;
+ default:
+ PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
+ }
+}
+
+static void
+tls_remote_address(const void *ctx, char *addr, size_t size)
+{
+ const struct tls_ctx *tlsctx = ctx;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+ PJDLOG_ASSERT(tlsctx->tls_wait_called);
+
+ switch (tlsctx->tls_side) {
+ case TLS_SIDE_CLIENT:
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+
+ PJDLOG_VERIFY(strlcpy(addr, "tls://N/A", size) < size);
+ break;
+ case TLS_SIDE_SERVER_WORK:
+ PJDLOG_ASSERT(tlsctx->tls_sock != NULL);
+
+ PJDLOG_VERIFY(strlcpy(addr, tlsctx->tls_raddr, size) < size);
+ break;
+ case TLS_SIDE_SERVER_LISTEN:
+ PJDLOG_ASSERT(tlsctx->tls_tcp != NULL);
+
+ proto_remote_address(tlsctx->tls_tcp, addr, size);
+ PJDLOG_ASSERT(strncmp(addr, "tcp://", 6) == 0);
+ /* Replace tcp:// prefix with tls:// */
+ bcopy("tls://", addr, 6);
+ break;
+ default:
+ PJDLOG_ABORT("Invalid side (%d).", tlsctx->tls_side);
+ }
+}
+
+static void
+tls_close(void *ctx)
+{
+ struct tls_ctx *tlsctx = ctx;
+
+ PJDLOG_ASSERT(tlsctx != NULL);
+ PJDLOG_ASSERT(tlsctx->tls_magic == TLS_CTX_MAGIC);
+
+ if (tlsctx->tls_sock != NULL) {
+ proto_close(tlsctx->tls_sock);
+ tlsctx->tls_sock = NULL;
+ }
+ if (tlsctx->tls_tcp != NULL) {
+ proto_close(tlsctx->tls_tcp);
+ tlsctx->tls_tcp = NULL;
+ }
+ tlsctx->tls_side = 0;
+ tlsctx->tls_magic = 0;
+ free(tlsctx);
+}
+
+static int
+tls_exec(int argc, char *argv[])
+{
+
+ PJDLOG_ASSERT(argc > 3);
+ PJDLOG_ASSERT(strcmp(argv[0], "tls") == 0);
+
+ pjdlog_init(atoi(argv[3]) == 0 ? PJDLOG_MODE_SYSLOG : PJDLOG_MODE_STD);
+
+ if (strcmp(argv[2], "client") == 0) {
+ if (argc != 10)
+ return (EINVAL);
+ tls_exec_client(argv[1], atoi(argv[3]),
+ argv[4][0] == '\0' ? NULL : argv[4], argv[5], argv[6],
+ argv[7], atoi(argv[8]), atoi(argv[9]));
+ } else if (strcmp(argv[2], "server") == 0) {
+ if (argc != 7)
+ return (EINVAL);
+ tls_exec_server(argv[1], atoi(argv[3]), argv[4], argv[5],
+ atoi(argv[6]));
+ }
+ return (EINVAL);
+}
+
+static struct proto tls_proto = {
+ .prt_name = "tls",
+ .prt_connect = tls_connect,
+ .prt_connect_wait = tls_connect_wait,
+ .prt_server = tls_server,
+ .prt_accept = tls_accept,
+ .prt_wrap = tls_wrap,
+ .prt_send = tls_send,
+ .prt_recv = tls_recv,
+ .prt_descriptor = tls_descriptor,
+ .prt_address_match = tcp_address_match,
+ .prt_local_address = tls_local_address,
+ .prt_remote_address = tls_remote_address,
+ .prt_close = tls_close,
+ .prt_exec = tls_exec
+};
+
+static __constructor void
+tls_ctor(void)
+{
+
+ proto_register(&tls_proto, false);
+}
OpenPOWER on IntegriCloud