* (c) 2007 Martin Mares <mj@ucw.cz>
*/
-#define LOCAL_DEBUG
+/*
+ * FIXME:
+ * - competition timeout & per-contestant exceptions
+ */
+
+#undef LOCAL_DEBUG
#include "lib/lib.h"
+#include "lib/conf.h"
+#include "lib/getopt.h"
+#include "lib/ipaccess.h"
+#include "lib/fastbuf.h"
#include <string.h>
+#include <stdlib.h>
#include <unistd.h>
+#include <signal.h>
+#include <errno.h>
+#include <sys/types.h>
#include <sys/socket.h>
+#include <sys/wait.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
-static int port = 8888;
+/*** CONFIGURATION ***/
+
+static uns port = 8888;
+static uns dh_bits = 1024;
+static uns max_conn = 10;
+static uns session_timeout;
+static byte *ca_cert_name = "?";
+static byte *server_cert_name = "?";
+static byte *server_key_name = "?";
+static clist access_rules;
+
+struct access_rule {
+ cnode n;
+ struct ip_addrmask addrmask;
+ uns allow_admin;
+ uns plain_text;
+ uns max_conn;
+};
+
+static struct cf_section access_conf = {
+ CF_TYPE(struct access_rule),
+ CF_ITEMS {
+ CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
+ CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
+ CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
+ CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
+ CF_END
+ }
+};
+
+static byte *
+config_init(void)
+{
+ clist_init(&access_rules);
+ return NULL;
+}
+
+static struct cf_section submitd_conf = {
+ CF_INIT(config_init),
+ CF_ITEMS {
+ CF_UNS("Port", &port),
+ CF_UNS("DHBits", &dh_bits),
+ CF_UNS("MaxConn", &max_conn),
+ CF_UNS("SessionTimeout", &session_timeout),
+ CF_STRING("CACert", &ca_cert_name),
+ CF_STRING("ServerCert", &server_cert_name),
+ CF_STRING("ServerKey", &server_key_name),
+ CF_LIST("Access", &access_rules, &access_conf),
+ CF_END
+ }
+};
+
+/*** CONNECTIONS ***/
+
+struct conn {
+ cnode n;
+ u32 ip; // Used by the main loop connection logic
+ pid_t pid;
+ uns id;
+ struct access_rule *rule; // Rule matched by this connection
+ int sk; // Client socket
+ gnutls_session_t tls; // TLS session
+ struct fastbuf rx_fb, tx_fb; // Fastbufs for communication with the client
+};
+
+static clist connections;
+static uns last_conn_id;
+static uns num_conn;
+
+static void
+conn_init(void)
+{
+ clist_init(&connections);
+}
+
+static struct conn *
+conn_new(void)
+{
+ struct conn *c = xmalloc(sizeof(*c));
+ c->id = ++last_conn_id;
+ clist_add_tail(&connections, &c->n);
+ num_conn++;
+ return c;
+}
+
+static void
+conn_free(struct conn *c)
+{
+ clist_remove(&c->n);
+ num_conn--;
+ xfree(c);
+}
+
+static struct access_rule *
+lookup_rule(u32 ip)
+{
+ CLIST_FOR_EACH(struct access_rule *, r, access_rules)
+ if (ip_addrmask_match(&r->addrmask, ip))
+ return r;
+ return NULL;
+}
+
+static uns
+conn_count(u32 ip)
+{
+ uns cnt = 0;
+ CLIST_FOR_EACH(struct conn *, c, connections)
+ if (c->ip == ip)
+ cnt++;
+ return cnt;
+}
+
+/*** TLS ***/
static gnutls_certificate_credentials_t cert_cred;
static gnutls_dh_params_t dh_params;
-#define DH_BITS 1024
#define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
static void
{
int err;
- log(L_INFO, "Initializing TLS");
gnutls_global_init();
err = gnutls_certificate_allocate_credentials(&cert_cred);
TLS_CHECK(gnutls_certificate_allocate_credentials);
- err = gnutls_certificate_set_x509_trust_file(cert_cred, "ca-cert.pem", GNUTLS_X509_FMT_PEM);
+ err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
if (!err)
die("No CA certificate found");
if (err < 0)
die("Unable to load X509 trust file: %s", gnutls_strerror(err));
- err = gnutls_certificate_set_x509_key_file(cert_cred, "server-cert.pem", "server-key.pem", GNUTLS_X509_FMT_PEM);
+ err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
if (err < 0)
die("Unable to load X509 key file: %s", gnutls_strerror(err));
- log(L_INFO, "Setting up DH parameters");
err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
- err = gnutls_dh_params_generate2(dh_params, DH_BITS); TLS_CHECK(gnutls_dh_params_generate2);
+ err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
gnutls_certificate_set_dh_params(cert_cred, dh_params);
}
err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority); // FIXME
gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
- gnutls_dh_set_prime_bits(s, DH_BITS);
+ gnutls_dh_set_prime_bits(s, dh_bits);
gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
return s;
}
proto, kx, cert, comp, cipher, mac);
}
-int main(int argc UNUSED, char **argv UNUSED)
+/*** SOCKET FASTBUFS ***/
+
+static void NONRET
+client_error(char *msg, ...)
{
- tls_init();
+ va_list args;
+ va_start(args, msg);
+ vlog_msg(L_ERROR_R, msg, args);
+ exit(0);
+}
- int sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
- if (sk < 0)
+static int
+sk_fb_refill(struct fastbuf *f)
+{
+ struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
+ int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
+ if (cnt < 0)
+ client_error("Read error: %m");
+ f->bptr = f->buffer;
+ f->bstop = f->buffer + cnt;
+ return cnt;
+}
+
+static void
+sk_fb_spout(struct fastbuf *f)
+{
+ struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
+ int len = f->bptr - f->buffer;
+ if (!len)
+ return;
+ int cnt = careful_write(c->sk, f->buffer, len);
+ if (cnt <= 0)
+ client_error("Write error");
+ f->bptr = f->buffer;
+}
+
+static void
+init_sk_fastbufs(struct conn *c)
+{
+ struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
+
+ rf->buffer = xmalloc(1024);
+ rf->bufend = rf->buffer + 1024;
+ rf->bptr = rf->bstop = rf->buffer;
+ rf->name = "socket";
+ rf->refill = sk_fb_refill;
+
+ tf->buffer = xmalloc(1024);
+ tf->bufend = tf->buffer + 1024;
+ tf->bptr = tf->bstop = tf->buffer;
+ tf->name = rf->name;
+ tf->spout = sk_fb_spout;
+}
+
+static int
+tls_fb_refill(struct fastbuf *f)
+{
+ struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
+ DBG("TLS: Refill");
+ int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
+ DBG("TLS: Received %d bytes", cnt);
+ if (cnt < 0)
+ client_error("TLS read error: %s", gnutls_strerror(cnt));
+ f->bptr = f->buffer;
+ f->bstop = f->buffer + cnt;
+ return cnt;
+}
+
+static void
+tls_fb_spout(struct fastbuf *f)
+{
+ struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
+ int len = f->bptr - f->buffer;
+ if (!len)
+ return;
+ int cnt = gnutls_record_send(c->tls, f->buffer, len);
+ DBG("TLS: Sent %d bytes", cnt);
+ if (cnt <= 0)
+ client_error("TLS write error: %s", gnutls_strerror(cnt));
+ f->bptr = f->buffer;
+}
+
+static void
+init_tls_fastbufs(struct conn *c)
+{
+ struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
+
+ ASSERT(rf->buffer && tf->buffer); // Already set up for the plaintext connection
+ rf->refill = tls_fb_refill;
+ tf->spout = tls_fb_spout;
+}
+
+/*** CLIENT LOOP (runs in a child process) ***/
+
+static void
+sigalrm_handler(int sig UNUSED)
+{
+ // We do not try to do any gracious shutdown to avoid races
+ client_error("Timed out");
+}
+
+static void
+client_loop(struct conn *c)
+{
+ log_pid = c->id;
+ init_sk_fastbufs(c);
+
+ signal(SIGPIPE, SIG_IGN);
+ struct sigaction sa = {
+ .sa_handler = sigalrm_handler
+ };
+ if (sigaction(SIGALRM, &sa, NULL) < 0)
+ die("Cannot setup SIGALRM handler: %m");
+
+ if (c->rule->plain_text)
+ {
+ bputsn(&c->tx_fb, "+OK");
+ bflush(&c->tx_fb);
+ }
+ else
+ {
+ bputsn(&c->tx_fb, "+TLS");
+ bflush(&c->tx_fb);
+ c->tls = tls_new_session(c->sk);
+ int err = gnutls_handshake(c->tls);
+ if (err < 0)
+ client_error("TLS handshake failed: %s", gnutls_strerror(err));
+ tls_log_params(c->tls);
+ const char *cert_err = tls_verify_cert(c->tls);
+ if (cert_err)
+ client_error("TLS certificate failure: %s", cert_err);
+ init_tls_fastbufs(c);
+ }
+
+ for (;;)
+ {
+ alarm(session_timeout);
+ byte buf[1024];
+ if (!bgets(&c->rx_fb, buf, sizeof(buf)))
+ break;
+ bputsn(&c->tx_fb, buf);
+ bflush(&c->tx_fb);
+ }
+
+ if (c->tls)
+ gnutls_bye(c->tls, GNUTLS_SHUT_WR);
+ close(c->sk);
+ if (c->tls)
+ gnutls_deinit(c->tls);
+}
+
+/*** MAIN LOOP ***/
+
+static void
+sigchld_handler(int sig UNUSED)
+{
+ /* We do not need to do anything, just interrupt the accept syscall */
+}
+
+static void
+reap_child(pid_t pid, int status)
+{
+ byte msg[EXIT_STATUS_MSG_SIZE];
+ if (format_exit_status(msg, status))
+ log(L_ERROR, "Child %d %s", (int)pid, msg);
+
+ CLIST_FOR_EACH(struct conn *, c, connections)
+ if (c->pid == pid)
+ {
+ log(L_INFO, "Connection %d closed", c->id);
+ conn_free(c);
+ return;
+ }
+ log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
+}
+
+static int listen_sk;
+
+static void
+sk_init(void)
+{
+ listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+ if (listen_sk < 0)
die("socket: %m");
int one = 1;
- if (setsockopt(sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
+ if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
die("setsockopt(SO_REUSEADDR): %m");
struct sockaddr_in sa;
sa.sin_family = AF_INET;
sa.sin_addr.s_addr = INADDR_ANY;
sa.sin_port = htons(port);
- if (bind(sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
+ if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
die("Cannot bind to port %d: %m", port);
- if (listen(sk, 1024) < 0)
+ if (listen(listen_sk, 1024) < 0)
die("Cannot listen on port %d: %m", port);
- log(L_INFO, "Listening on port %d", port);
+}
- for (;;)
+static void
+sk_accept(void)
+{
+ struct sockaddr_in sa;
+ int salen = sizeof(sa);
+ int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
+ if (sk < 0)
{
- struct sockaddr_in sa2;
- int sa2len = sizeof(sa2);
- int sk2 = accept(sk, (struct sockaddr *) &sa2, &sa2len);
- if (sk2 < 0)
- die("accept: %m");
-
- byte ipbuf[INET_ADDRSTRLEN];
- inet_ntop(AF_INET, &sa2.sin_addr, ipbuf, sizeof(ipbuf));
- log(L_INFO, "Connection from %s port %d", ipbuf, ntohs(sa2.sin_port));
-
- gnutls_session_t sess = tls_new_session(sk2);
- int err = gnutls_handshake(sess);
- if (err < 0)
- {
- log(L_ERROR_R, "Handshake failed: %s", gnutls_strerror(err));
- goto shut;
- }
- tls_log_params(sess);
+ if (errno == EINTR)
+ return;
+ die("accept: %m");
+ }
- const char *cert_err = tls_verify_cert(sess);
- if (cert_err)
- {
- log(L_ERROR_R, "Certificate verification failed: %s", cert_err);
- goto shut;
- }
-
- for (;;)
- {
- byte buf[1024];
- int ret = gnutls_record_recv(sess, buf, sizeof(buf));
- if (ret < 0)
- {
- log(L_ERROR_R, "Connection broken: %s", gnutls_strerror(ret));
- break;
- }
- if (!ret)
- {
- log(L_INFO, "Client closed connection");
- break;
- }
- log(L_DEBUG, "Received %d bytes", ret);
- gnutls_record_send(sess, buf, ret);
- }
-
- gnutls_bye(sess, GNUTLS_SHUT_WR);
-shut:
- close(sk2);
- gnutls_deinit(sess);
+ byte ipbuf[INET_ADDRSTRLEN];
+ inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
+ u32 addr = ntohl(sa.sin_addr.s_addr);
+ uns port = ntohs(sa.sin_port);
+ char *err;
+
+ struct access_rule *rule = lookup_rule(addr);
+ if (!rule)
+ {
+ err = "Unauthorized";
+ goto reject;
+ }
+
+ if (num_conn >= max_conn)
+ {
+ err = "Too many connections";
+ goto reject;
+ }
+
+ if (conn_count(addr) >= rule->max_conn)
+ {
+ err = "Too many connections from this address";
+ goto reject;
+ }
+
+ struct conn *c = conn_new();
+ log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
+ ipbuf, port, c->id,
+ (rule->plain_text ? "plain-text" : "TLS"),
+ (rule->allow_admin ? "admin" : "user"));
+ c->ip = addr;
+ c->sk = sk;
+ c->rule = rule;
+
+ c->pid = fork();
+ if (c->pid < 0)
+ {
+ conn_free(c);
+ err = "Server overloaded";
+ log(L_ERROR, "Fork failed: %m");
+ goto reject2;
+ }
+ if (!c->pid)
+ {
+ close(listen_sk);
+ client_loop(c);
+ exit(0);
}
+ close(sk);
+ return;
+
+reject:
+ log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
+reject2: ;
+ // Write an error message to the socket, but do not allow it to slow us down
+ struct linger ling = { .l_onoff=0 };
+ if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
+ log(L_ERROR, "Cannot set SO_LINGER: %m");
+ write(sk, "-", 1);
+ write(sk, err, strlen(err));
+ write(sk, "\n", 1);
+ close(sk);
+}
- return 0;
+int main(int argc, char **argv)
+{
+ setproctitle_init(argc, argv);
+ cf_def_file = "config";
+ cf_declare_section("SubmitD", &submitd_conf, 0);
+
+ int opt;
+ if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
+ die("This program has no options");
+
+ log(L_INFO, "Initializing TLS");
+ tls_init();
+
+ conn_init();
+ sk_init();
+ log(L_INFO, "Listening on port %d", port);
+
+ struct sigaction sa = {
+ .sa_handler = sigchld_handler
+ };
+ if (sigaction(SIGCHLD, &sa, NULL) < 0)
+ die("Cannot setup SIGCHLD handler: %m");
+
+ for (;;)
+ {
+ int status;
+ pid_t pid = waitpid(-1, &status, WNOHANG);
+ if (pid > 0)
+ reap_child(pid, status);
+ else
+ sk_accept();
+ }
}