From 4d6c0022e700db3c7584ad5b58bbda69090ecf0d Mon Sep 17 00:00:00 2001 From: Martin Mares Date: Mon, 4 Jun 2007 00:58:52 +0200 Subject: [PATCH] Implemented the connection logic. --- submit/Makefile | 3 + submit/connect.c | 18 ++ submit/submitd.c | 488 ++++++++++++++++++++++++++++++++++++++++------- submit/test.pl | 35 ++++ 4 files changed, 478 insertions(+), 66 deletions(-) create mode 100755 submit/test.pl diff --git a/submit/Makefile b/submit/Makefile index ccac81a..847f15e 100644 --- a/submit/Makefile +++ b/submit/Makefile @@ -4,6 +4,9 @@ TLSLF:=$(shell libgnutls-config --libs) CFLAGS=-O2 -Iinclude -g -Wall -W -Wno-parentheses -Wstrict-prototypes -Wmissing-prototypes -Wundef -Wredundant-decls -std=gnu99 $(TLSCF) LDFLAGS=$(TLSLF) +CC=gcc-4.1.1 +CFLAGS+=-Wno-pointer-sign -Wdisabled-optimization -Wno-missing-field-initializers + all: submitd connect submitd: submitd.o lib/libucw.a lib/libsh.a diff --git a/submit/connect.c b/submit/connect.c index 7f5709e..d0c83ff 100644 --- a/submit/connect.c +++ b/submit/connect.c @@ -121,6 +121,24 @@ int main(int argc UNUSED, char **argv UNUSED) if (connect(sk, (struct sockaddr *) &sa, sizeof(sa)) < 0) die("Cannot connect: %m"); + log(L_INFO, "Waiting for initial message"); + byte msg[256]; + int i = 0; + do + { + if (i >= (int)sizeof(msg)) + die("Response too long"); + int c = read(sk, msg+i, sizeof(msg)-i); + if (c <= 0) + die("Connection broken"); + i += c; + } + while (msg[i-1] != '\n'); + msg[i-1] = 0; + if (msg[0] != '+') + die("%s", msg); + log(L_INFO, "%s", msg); + gnutls_session_t s; gnutls_init(&s, GNUTLS_CLIENT); gnutls_set_default_priority(s); diff --git a/submit/submitd.c b/submit/submitd.c index 09bf581..a58ae69 100644 --- a/submit/submitd.c +++ b/submit/submitd.c @@ -4,24 +4,149 @@ * (c) 2007 Martin Mares */ -#define LOCAL_DEBUG +/* + * FIXME: + * - competition timeout & per-contestant exceptions + */ + +#undef LOCAL_DEBUG #include "lib/lib.h" +#include "lib/conf.h" +#include "lib/getopt.h" +#include "lib/ipaccess.h" +#include "lib/fastbuf.h" #include +#include #include +#include +#include +#include #include +#include #include #include #include #include -static int port = 8888; +/*** CONFIGURATION ***/ + +static uns port = 8888; +static uns dh_bits = 1024; +static uns max_conn = 10; +static uns session_timeout; +static byte *ca_cert_name = "?"; +static byte *server_cert_name = "?"; +static byte *server_key_name = "?"; +static clist access_rules; + +struct access_rule { + cnode n; + struct ip_addrmask addrmask; + uns allow_admin; + uns plain_text; + uns max_conn; +}; + +static struct cf_section access_conf = { + CF_TYPE(struct access_rule), + CF_ITEMS { + CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type), + CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)), + CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)), + CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)), + CF_END + } +}; + +static byte * +config_init(void) +{ + clist_init(&access_rules); + return NULL; +} + +static struct cf_section submitd_conf = { + CF_INIT(config_init), + CF_ITEMS { + CF_UNS("Port", &port), + CF_UNS("DHBits", &dh_bits), + CF_UNS("MaxConn", &max_conn), + CF_UNS("SessionTimeout", &session_timeout), + CF_STRING("CACert", &ca_cert_name), + CF_STRING("ServerCert", &server_cert_name), + CF_STRING("ServerKey", &server_key_name), + CF_LIST("Access", &access_rules, &access_conf), + CF_END + } +}; + +/*** CONNECTIONS ***/ + +struct conn { + cnode n; + u32 ip; // Used by the main loop connection logic + pid_t pid; + uns id; + struct access_rule *rule; // Rule matched by this connection + int sk; // Client socket + gnutls_session_t tls; // TLS session + struct fastbuf rx_fb, tx_fb; // Fastbufs for communication with the client +}; + +static clist connections; +static uns last_conn_id; +static uns num_conn; + +static void +conn_init(void) +{ + clist_init(&connections); +} + +static struct conn * +conn_new(void) +{ + struct conn *c = xmalloc(sizeof(*c)); + c->id = ++last_conn_id; + clist_add_tail(&connections, &c->n); + num_conn++; + return c; +} + +static void +conn_free(struct conn *c) +{ + clist_remove(&c->n); + num_conn--; + xfree(c); +} + +static struct access_rule * +lookup_rule(u32 ip) +{ + CLIST_FOR_EACH(struct access_rule *, r, access_rules) + if (ip_addrmask_match(&r->addrmask, ip)) + return r; + return NULL; +} + +static uns +conn_count(u32 ip) +{ + uns cnt = 0; + CLIST_FOR_EACH(struct conn *, c, connections) + if (c->ip == ip) + cnt++; + return cnt; +} + +/*** TLS ***/ static gnutls_certificate_credentials_t cert_cred; static gnutls_dh_params_t dh_params; -#define DH_BITS 1024 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err)) static void @@ -29,22 +154,20 @@ tls_init(void) { int err; - log(L_INFO, "Initializing TLS"); gnutls_global_init(); err = gnutls_certificate_allocate_credentials(&cert_cred); TLS_CHECK(gnutls_certificate_allocate_credentials); - err = gnutls_certificate_set_x509_trust_file(cert_cred, "ca-cert.pem", GNUTLS_X509_FMT_PEM); + err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM); if (!err) die("No CA certificate found"); if (err < 0) die("Unable to load X509 trust file: %s", gnutls_strerror(err)); - err = gnutls_certificate_set_x509_key_file(cert_cred, "server-cert.pem", "server-key.pem", GNUTLS_X509_FMT_PEM); + err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM); if (err < 0) die("Unable to load X509 key file: %s", gnutls_strerror(err)); - log(L_INFO, "Setting up DH parameters"); err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init); - err = gnutls_dh_params_generate2(dh_params, DH_BITS); TLS_CHECK(gnutls_dh_params_generate2); + err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2); gnutls_certificate_set_dh_params(cert_cred, dh_params); } @@ -58,7 +181,7 @@ tls_new_session(int sk) err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority); // FIXME gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred); gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST); - gnutls_dh_set_prime_bits(s, DH_BITS); + gnutls_dh_set_prime_bits(s, dh_bits); gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk); return s; } @@ -133,15 +256,192 @@ tls_log_params(gnutls_session_t s) proto, kx, cert, comp, cipher, mac); } -int main(int argc UNUSED, char **argv UNUSED) +/*** SOCKET FASTBUFS ***/ + +static void NONRET +client_error(char *msg, ...) { - tls_init(); + va_list args; + va_start(args, msg); + vlog_msg(L_ERROR_R, msg, args); + exit(0); +} - int sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); - if (sk < 0) +static int +sk_fb_refill(struct fastbuf *f) +{ + struct conn *c = SKIP_BACK(struct conn, rx_fb, f); + int cnt = read(c->sk, f->buffer, f->bufend - f->buffer); + if (cnt < 0) + client_error("Read error: %m"); + f->bptr = f->buffer; + f->bstop = f->buffer + cnt; + return cnt; +} + +static void +sk_fb_spout(struct fastbuf *f) +{ + struct conn *c = SKIP_BACK(struct conn, tx_fb, f); + int len = f->bptr - f->buffer; + if (!len) + return; + int cnt = careful_write(c->sk, f->buffer, len); + if (cnt <= 0) + client_error("Write error"); + f->bptr = f->buffer; +} + +static void +init_sk_fastbufs(struct conn *c) +{ + struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb; + + rf->buffer = xmalloc(1024); + rf->bufend = rf->buffer + 1024; + rf->bptr = rf->bstop = rf->buffer; + rf->name = "socket"; + rf->refill = sk_fb_refill; + + tf->buffer = xmalloc(1024); + tf->bufend = tf->buffer + 1024; + tf->bptr = tf->bstop = tf->buffer; + tf->name = rf->name; + tf->spout = sk_fb_spout; +} + +static int +tls_fb_refill(struct fastbuf *f) +{ + struct conn *c = SKIP_BACK(struct conn, rx_fb, f); + DBG("TLS: Refill"); + int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer); + DBG("TLS: Received %d bytes", cnt); + if (cnt < 0) + client_error("TLS read error: %s", gnutls_strerror(cnt)); + f->bptr = f->buffer; + f->bstop = f->buffer + cnt; + return cnt; +} + +static void +tls_fb_spout(struct fastbuf *f) +{ + struct conn *c = SKIP_BACK(struct conn, tx_fb, f); + int len = f->bptr - f->buffer; + if (!len) + return; + int cnt = gnutls_record_send(c->tls, f->buffer, len); + DBG("TLS: Sent %d bytes", cnt); + if (cnt <= 0) + client_error("TLS write error: %s", gnutls_strerror(cnt)); + f->bptr = f->buffer; +} + +static void +init_tls_fastbufs(struct conn *c) +{ + struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb; + + ASSERT(rf->buffer && tf->buffer); // Already set up for the plaintext connection + rf->refill = tls_fb_refill; + tf->spout = tls_fb_spout; +} + +/*** CLIENT LOOP (runs in a child process) ***/ + +static void +sigalrm_handler(int sig UNUSED) +{ + // We do not try to do any gracious shutdown to avoid races + client_error("Timed out"); +} + +static void +client_loop(struct conn *c) +{ + log_pid = c->id; + init_sk_fastbufs(c); + + signal(SIGPIPE, SIG_IGN); + struct sigaction sa = { + .sa_handler = sigalrm_handler + }; + if (sigaction(SIGALRM, &sa, NULL) < 0) + die("Cannot setup SIGALRM handler: %m"); + + if (c->rule->plain_text) + { + bputsn(&c->tx_fb, "+OK"); + bflush(&c->tx_fb); + } + else + { + bputsn(&c->tx_fb, "+TLS"); + bflush(&c->tx_fb); + c->tls = tls_new_session(c->sk); + int err = gnutls_handshake(c->tls); + if (err < 0) + client_error("TLS handshake failed: %s", gnutls_strerror(err)); + tls_log_params(c->tls); + const char *cert_err = tls_verify_cert(c->tls); + if (cert_err) + client_error("TLS certificate failure: %s", cert_err); + init_tls_fastbufs(c); + } + + for (;;) + { + alarm(session_timeout); + byte buf[1024]; + if (!bgets(&c->rx_fb, buf, sizeof(buf))) + break; + bputsn(&c->tx_fb, buf); + bflush(&c->tx_fb); + } + + if (c->tls) + gnutls_bye(c->tls, GNUTLS_SHUT_WR); + close(c->sk); + if (c->tls) + gnutls_deinit(c->tls); +} + +/*** MAIN LOOP ***/ + +static void +sigchld_handler(int sig UNUSED) +{ + /* We do not need to do anything, just interrupt the accept syscall */ +} + +static void +reap_child(pid_t pid, int status) +{ + byte msg[EXIT_STATUS_MSG_SIZE]; + if (format_exit_status(msg, status)) + log(L_ERROR, "Child %d %s", (int)pid, msg); + + CLIST_FOR_EACH(struct conn *, c, connections) + if (c->pid == pid) + { + log(L_INFO, "Connection %d closed", c->id); + conn_free(c); + return; + } + log(L_ERROR, "Cannot find connection for child process %d", (int)pid); +} + +static int listen_sk; + +static void +sk_init(void) +{ + listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP); + if (listen_sk < 0) die("socket: %m"); int one = 1; - if (setsockopt(sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) + if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0) die("setsockopt(SO_REUSEADDR): %m"); struct sockaddr_in sa; @@ -149,63 +449,119 @@ int main(int argc UNUSED, char **argv UNUSED) sa.sin_family = AF_INET; sa.sin_addr.s_addr = INADDR_ANY; sa.sin_port = htons(port); - if (bind(sk, (struct sockaddr *) &sa, sizeof(sa)) < 0) + if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0) die("Cannot bind to port %d: %m", port); - if (listen(sk, 1024) < 0) + if (listen(listen_sk, 1024) < 0) die("Cannot listen on port %d: %m", port); - log(L_INFO, "Listening on port %d", port); +} - for (;;) +static void +sk_accept(void) +{ + struct sockaddr_in sa; + int salen = sizeof(sa); + int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen); + if (sk < 0) { - struct sockaddr_in sa2; - int sa2len = sizeof(sa2); - int sk2 = accept(sk, (struct sockaddr *) &sa2, &sa2len); - if (sk2 < 0) - die("accept: %m"); - - byte ipbuf[INET_ADDRSTRLEN]; - inet_ntop(AF_INET, &sa2.sin_addr, ipbuf, sizeof(ipbuf)); - log(L_INFO, "Connection from %s port %d", ipbuf, ntohs(sa2.sin_port)); - - gnutls_session_t sess = tls_new_session(sk2); - int err = gnutls_handshake(sess); - if (err < 0) - { - log(L_ERROR_R, "Handshake failed: %s", gnutls_strerror(err)); - goto shut; - } - tls_log_params(sess); + if (errno == EINTR) + return; + die("accept: %m"); + } - const char *cert_err = tls_verify_cert(sess); - if (cert_err) - { - log(L_ERROR_R, "Certificate verification failed: %s", cert_err); - goto shut; - } - - for (;;) - { - byte buf[1024]; - int ret = gnutls_record_recv(sess, buf, sizeof(buf)); - if (ret < 0) - { - log(L_ERROR_R, "Connection broken: %s", gnutls_strerror(ret)); - break; - } - if (!ret) - { - log(L_INFO, "Client closed connection"); - break; - } - log(L_DEBUG, "Received %d bytes", ret); - gnutls_record_send(sess, buf, ret); - } - - gnutls_bye(sess, GNUTLS_SHUT_WR); -shut: - close(sk2); - gnutls_deinit(sess); + byte ipbuf[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf)); + u32 addr = ntohl(sa.sin_addr.s_addr); + uns port = ntohs(sa.sin_port); + char *err; + + struct access_rule *rule = lookup_rule(addr); + if (!rule) + { + err = "Unauthorized"; + goto reject; + } + + if (num_conn >= max_conn) + { + err = "Too many connections"; + goto reject; + } + + if (conn_count(addr) >= rule->max_conn) + { + err = "Too many connections from this address"; + goto reject; + } + + struct conn *c = conn_new(); + log(L_INFO, "Connection from %s:%d (id %d, %s, %s)", + ipbuf, port, c->id, + (rule->plain_text ? "plain-text" : "TLS"), + (rule->allow_admin ? "admin" : "user")); + c->ip = addr; + c->sk = sk; + c->rule = rule; + + c->pid = fork(); + if (c->pid < 0) + { + conn_free(c); + err = "Server overloaded"; + log(L_ERROR, "Fork failed: %m"); + goto reject2; + } + if (!c->pid) + { + close(listen_sk); + client_loop(c); + exit(0); } + close(sk); + return; + +reject: + log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err); +reject2: ; + // Write an error message to the socket, but do not allow it to slow us down + struct linger ling = { .l_onoff=0 }; + if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0) + log(L_ERROR, "Cannot set SO_LINGER: %m"); + write(sk, "-", 1); + write(sk, err, strlen(err)); + write(sk, "\n", 1); + close(sk); +} - return 0; +int main(int argc, char **argv) +{ + setproctitle_init(argc, argv); + cf_def_file = "config"; + cf_declare_section("SubmitD", &submitd_conf, 0); + + int opt; + if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0) + die("This program has no options"); + + log(L_INFO, "Initializing TLS"); + tls_init(); + + conn_init(); + sk_init(); + log(L_INFO, "Listening on port %d", port); + + struct sigaction sa = { + .sa_handler = sigchld_handler + }; + if (sigaction(SIGCHLD, &sa, NULL) < 0) + die("Cannot setup SIGCHLD handler: %m"); + + for (;;) + { + int status; + pid_t pid = waitpid(-1, &status, WNOHANG); + if (pid > 0) + reap_child(pid, status); + else + sk_accept(); + } } diff --git a/submit/test.pl b/submit/test.pl new file mode 100755 index 0000000..9948a70 --- /dev/null +++ b/submit/test.pl @@ -0,0 +1,35 @@ +#!/usr/bin/perl + +use strict; +use warnings; + +use IO::Socket::INET; +use IO::Socket::SSL; # qw(debug3); + +my $sk = new IO::Socket::INET( +# PeerAddr => "nikam.ms.mff.cuni.cz:443", + PeerAddr => "localhost:8888", + Proto => "tcp", +) or die "Cannot connect to server: $!"; + +my $z = <$sk>; +defined $z or die "Server failed to send welcome message\n"; +$z =~ /^\+/ or die "Server reported error: $z"; +print $z; + +if ($z =~ /TLS/) { + $sk = IO::Socket::SSL->start_SSL( + $sk, + SSL_version => 'TLSv1', + SSL_use_cert => 1, + SSL_key_file => "client-key.pem", + SSL_cert_file => "client-cert.pem", + SSL_ca_file => "ca-cert.pem", + SSL_verify_mode => 3, + ) or die "Cannot establish TLS connection: " . IO::Socket::SSL::errstr() . "\n"; +} + +print $sk "Hello, world!\n"; +my $y = <$sk>; +print $y; +close $sk; -- 2.39.2