]> mj.ucw.cz Git - moe.git/commitdiff
Implemented the connection logic.
authorMartin Mares <mj@ucw.cz>
Sun, 3 Jun 2007 22:58:52 +0000 (00:58 +0200)
committerMartin Mares <mj@ucw.cz>
Sun, 3 Jun 2007 22:58:52 +0000 (00:58 +0200)
submit/Makefile
submit/connect.c
submit/submitd.c
submit/test.pl [new file with mode: 0755]

index ccac81a243c93f75b0e1d10c0ee2419f7ba88e3e..847f15e007212a93e4a068d6099dd279b26b600b 100644 (file)
@@ -4,6 +4,9 @@ TLSLF:=$(shell libgnutls-config --libs)
 CFLAGS=-O2 -Iinclude -g -Wall -W -Wno-parentheses -Wstrict-prototypes -Wmissing-prototypes -Wundef -Wredundant-decls -std=gnu99 $(TLSCF)
 LDFLAGS=$(TLSLF)
 
+CC=gcc-4.1.1
+CFLAGS+=-Wno-pointer-sign -Wdisabled-optimization -Wno-missing-field-initializers
+
 all: submitd connect
 
 submitd: submitd.o lib/libucw.a lib/libsh.a
index 7f5709ea2616b7d5ee1992143418e8ee953ae85f..d0c83ffe47dac808d3a1a2c17d6e10c1223acd7c 100644 (file)
@@ -121,6 +121,24 @@ int main(int argc UNUSED, char **argv UNUSED)
   if (connect(sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
     die("Cannot connect: %m");
 
+  log(L_INFO, "Waiting for initial message");
+  byte msg[256];
+  int i = 0;
+  do
+    {
+      if (i >= (int)sizeof(msg))
+       die("Response too long");
+      int c = read(sk, msg+i, sizeof(msg)-i);
+      if (c <= 0)
+       die("Connection broken");
+      i += c;
+    }
+  while (msg[i-1] != '\n');
+  msg[i-1] = 0;
+  if (msg[0] != '+')
+    die("%s", msg);
+  log(L_INFO, "%s", msg);
+
   gnutls_session_t s;
   gnutls_init(&s, GNUTLS_CLIENT);
   gnutls_set_default_priority(s);
index 09bf58163a812c806c05728b21e2bb3f1422ffd2..a58ae69255bb4ff914c5bab27a5853ca4f3f7bed 100644 (file)
  *  (c) 2007 Martin Mares <mj@ucw.cz>
  */
 
-#define LOCAL_DEBUG
+/*
+ *  FIXME:
+ *     - competition timeout & per-contestant exceptions
+ */
+
+#undef LOCAL_DEBUG
 
 #include "lib/lib.h"
+#include "lib/conf.h"
+#include "lib/getopt.h"
+#include "lib/ipaccess.h"
+#include "lib/fastbuf.h"
 
 #include <string.h>
+#include <stdlib.h>
 #include <unistd.h>
+#include <signal.h>
+#include <errno.h>
+#include <sys/types.h>
 #include <sys/socket.h>
+#include <sys/wait.h>
 #include <arpa/inet.h>
 #include <netinet/in.h>
 #include <gnutls/gnutls.h>
 #include <gnutls/x509.h>
 
-static int port = 8888;
+/*** CONFIGURATION ***/
+
+static uns port = 8888;
+static uns dh_bits = 1024;
+static uns max_conn = 10;
+static uns session_timeout;
+static byte *ca_cert_name = "?";
+static byte *server_cert_name = "?";
+static byte *server_key_name = "?";
+static clist access_rules;
+
+struct access_rule {
+  cnode n;
+  struct ip_addrmask addrmask;
+  uns allow_admin;
+  uns plain_text;
+  uns max_conn;
+};
+
+static struct cf_section access_conf = {
+  CF_TYPE(struct access_rule),
+  CF_ITEMS {
+    CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
+    CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
+    CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
+    CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
+    CF_END
+  }
+};
+
+static byte *
+config_init(void)
+{
+  clist_init(&access_rules);
+  return NULL;
+}
+
+static struct cf_section submitd_conf = {
+  CF_INIT(config_init),
+  CF_ITEMS {
+    CF_UNS("Port", &port),
+    CF_UNS("DHBits", &dh_bits),
+    CF_UNS("MaxConn", &max_conn),
+    CF_UNS("SessionTimeout", &session_timeout),
+    CF_STRING("CACert", &ca_cert_name),
+    CF_STRING("ServerCert", &server_cert_name),
+    CF_STRING("ServerKey", &server_key_name),
+    CF_LIST("Access", &access_rules, &access_conf),
+    CF_END
+  }
+};
+
+/*** CONNECTIONS ***/
+
+struct conn {
+  cnode n;
+  u32 ip;                              // Used by the main loop connection logic
+  pid_t pid;
+  uns id;
+  struct access_rule *rule;            // Rule matched by this connection
+  int sk;                              // Client socket
+  gnutls_session_t tls;                        // TLS session
+  struct fastbuf rx_fb, tx_fb;         // Fastbufs for communication with the client
+};
+
+static clist connections;
+static uns last_conn_id;
+static uns num_conn;
+
+static void
+conn_init(void)
+{
+  clist_init(&connections);
+}
+
+static struct conn *
+conn_new(void)
+{
+  struct conn *c = xmalloc(sizeof(*c));
+  c->id = ++last_conn_id;
+  clist_add_tail(&connections, &c->n);
+  num_conn++;
+  return c;
+}
+
+static void
+conn_free(struct conn *c)
+{
+  clist_remove(&c->n);
+  num_conn--;
+  xfree(c);
+}
+
+static struct access_rule *
+lookup_rule(u32 ip)
+{
+  CLIST_FOR_EACH(struct access_rule *, r, access_rules)
+    if (ip_addrmask_match(&r->addrmask, ip))
+      return r;
+  return NULL;
+}
+
+static uns
+conn_count(u32 ip)
+{
+  uns cnt = 0;
+  CLIST_FOR_EACH(struct conn *, c, connections)
+    if (c->ip == ip)
+      cnt++;
+  return cnt;
+}
+
+/*** TLS ***/
 
 static gnutls_certificate_credentials_t cert_cred;
 static gnutls_dh_params_t dh_params;
 
-#define DH_BITS 1024
 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
 
 static void
@@ -29,22 +154,20 @@ tls_init(void)
 {
   int err;
 
-  log(L_INFO, "Initializing TLS");
   gnutls_global_init();
   err = gnutls_certificate_allocate_credentials(&cert_cred);
   TLS_CHECK(gnutls_certificate_allocate_credentials);
-  err = gnutls_certificate_set_x509_trust_file(cert_cred, "ca-cert.pem", GNUTLS_X509_FMT_PEM);
+  err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
   if (!err)
     die("No CA certificate found");
   if (err < 0)
     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
-  err = gnutls_certificate_set_x509_key_file(cert_cred, "server-cert.pem", "server-key.pem", GNUTLS_X509_FMT_PEM);
+  err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
   if (err < 0)
     die("Unable to load X509 key file: %s", gnutls_strerror(err));
 
-  log(L_INFO, "Setting up DH parameters");
   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
-  err = gnutls_dh_params_generate2(dh_params, DH_BITS); TLS_CHECK(gnutls_dh_params_generate2);
+  err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
   gnutls_certificate_set_dh_params(cert_cred, dh_params);
 }
 
@@ -58,7 +181,7 @@ tls_new_session(int sk)
   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);                        // FIXME
   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
-  gnutls_dh_set_prime_bits(s, DH_BITS);
+  gnutls_dh_set_prime_bits(s, dh_bits);
   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
   return s;
 }
@@ -133,15 +256,192 @@ tls_log_params(gnutls_session_t s)
     proto, kx, cert, comp, cipher, mac);
 }
 
-int main(int argc UNUSED, char **argv UNUSED)
+/*** SOCKET FASTBUFS ***/
+
+static void NONRET
+client_error(char *msg, ...)
 {
-  tls_init();
+  va_list args;
+  va_start(args, msg);
+  vlog_msg(L_ERROR_R, msg, args);
+  exit(0);
+}
 
-  int sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
-  if (sk < 0)
+static int
+sk_fb_refill(struct fastbuf *f)
+{
+  struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
+  int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
+  if (cnt < 0)
+    client_error("Read error: %m");
+  f->bptr = f->buffer;
+  f->bstop = f->buffer + cnt;
+  return cnt;
+}
+
+static void
+sk_fb_spout(struct fastbuf *f)
+{
+  struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
+  int len = f->bptr - f->buffer;
+  if (!len)
+    return;
+  int cnt = careful_write(c->sk, f->buffer, len);
+  if (cnt <= 0)
+    client_error("Write error");
+  f->bptr = f->buffer;
+}
+
+static void
+init_sk_fastbufs(struct conn *c)
+{
+  struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
+
+  rf->buffer = xmalloc(1024);
+  rf->bufend = rf->buffer + 1024;
+  rf->bptr = rf->bstop = rf->buffer;
+  rf->name = "socket";
+  rf->refill = sk_fb_refill;
+
+  tf->buffer = xmalloc(1024);
+  tf->bufend = tf->buffer + 1024;
+  tf->bptr = tf->bstop = tf->buffer;
+  tf->name = rf->name;
+  tf->spout = sk_fb_spout;
+}
+
+static int
+tls_fb_refill(struct fastbuf *f)
+{
+  struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
+  DBG("TLS: Refill");
+  int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
+  DBG("TLS: Received %d bytes", cnt);
+  if (cnt < 0)
+    client_error("TLS read error: %s", gnutls_strerror(cnt));
+  f->bptr = f->buffer;
+  f->bstop = f->buffer + cnt;
+  return cnt;
+}
+
+static void
+tls_fb_spout(struct fastbuf *f)
+{
+  struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
+  int len = f->bptr - f->buffer;
+  if (!len)
+    return;
+  int cnt = gnutls_record_send(c->tls, f->buffer, len);
+  DBG("TLS: Sent %d bytes", cnt);
+  if (cnt <= 0)
+    client_error("TLS write error: %s", gnutls_strerror(cnt));
+  f->bptr = f->buffer;
+}
+
+static void
+init_tls_fastbufs(struct conn *c)
+{
+  struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
+
+  ASSERT(rf->buffer && tf->buffer);    // Already set up for the plaintext connection
+  rf->refill = tls_fb_refill;
+  tf->spout = tls_fb_spout;
+}
+
+/*** CLIENT LOOP (runs in a child process) ***/
+
+static void
+sigalrm_handler(int sig UNUSED)
+{
+  // We do not try to do any gracious shutdown to avoid races
+  client_error("Timed out");
+}
+
+static void
+client_loop(struct conn *c)
+{
+  log_pid = c->id;
+  init_sk_fastbufs(c);
+
+  signal(SIGPIPE, SIG_IGN);
+  struct sigaction sa = {
+    .sa_handler = sigalrm_handler
+  };
+  if (sigaction(SIGALRM, &sa, NULL) < 0)
+    die("Cannot setup SIGALRM handler: %m");
+
+  if (c->rule->plain_text)
+    {
+      bputsn(&c->tx_fb, "+OK");
+      bflush(&c->tx_fb);
+    }
+  else
+    {
+      bputsn(&c->tx_fb, "+TLS");
+      bflush(&c->tx_fb);
+      c->tls = tls_new_session(c->sk);
+      int err = gnutls_handshake(c->tls);
+      if (err < 0)
+       client_error("TLS handshake failed: %s", gnutls_strerror(err));
+      tls_log_params(c->tls);
+      const char *cert_err = tls_verify_cert(c->tls);
+      if (cert_err)
+       client_error("TLS certificate failure: %s", cert_err);
+      init_tls_fastbufs(c);
+    }
+
+  for (;;)
+    {
+      alarm(session_timeout);
+      byte buf[1024];
+      if (!bgets(&c->rx_fb, buf, sizeof(buf)))
+       break;
+      bputsn(&c->tx_fb, buf);
+      bflush(&c->tx_fb);
+    }
+
+  if (c->tls)
+    gnutls_bye(c->tls, GNUTLS_SHUT_WR);
+  close(c->sk);
+  if (c->tls)
+    gnutls_deinit(c->tls);
+}
+
+/*** MAIN LOOP ***/
+
+static void
+sigchld_handler(int sig UNUSED)
+{
+  /* We do not need to do anything, just interrupt the accept syscall */
+}
+
+static void
+reap_child(pid_t pid, int status)
+{
+  byte msg[EXIT_STATUS_MSG_SIZE];
+  if (format_exit_status(msg, status))
+    log(L_ERROR, "Child %d %s", (int)pid, msg);
+
+  CLIST_FOR_EACH(struct conn *, c, connections)
+    if (c->pid == pid)
+      {
+       log(L_INFO, "Connection %d closed", c->id);
+       conn_free(c);
+       return;
+      }
+  log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
+}
+
+static int listen_sk;
+
+static void
+sk_init(void)
+{
+  listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
+  if (listen_sk < 0)
     die("socket: %m");
   int one = 1;
-  if (setsockopt(sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
+  if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
     die("setsockopt(SO_REUSEADDR): %m");
 
   struct sockaddr_in sa;
@@ -149,63 +449,119 @@ int main(int argc UNUSED, char **argv UNUSED)
   sa.sin_family = AF_INET;
   sa.sin_addr.s_addr = INADDR_ANY;
   sa.sin_port = htons(port);
-  if (bind(sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
+  if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
     die("Cannot bind to port %d: %m", port);
-  if (listen(sk, 1024) < 0)
+  if (listen(listen_sk, 1024) < 0)
     die("Cannot listen on port %d: %m", port);
-  log(L_INFO, "Listening on port %d", port);
+}
 
-  for (;;)
+static void
+sk_accept(void)
+{
+  struct sockaddr_in sa;
+  int salen = sizeof(sa);
+  int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
+  if (sk < 0)
     {
-      struct sockaddr_in sa2;
-      int sa2len = sizeof(sa2);
-      int sk2 = accept(sk, (struct sockaddr *) &sa2, &sa2len);
-      if (sk2 < 0)
-       die("accept: %m");
-
-      byte ipbuf[INET_ADDRSTRLEN];
-      inet_ntop(AF_INET, &sa2.sin_addr, ipbuf, sizeof(ipbuf));
-      log(L_INFO, "Connection from %s port %d", ipbuf, ntohs(sa2.sin_port));
-
-      gnutls_session_t sess = tls_new_session(sk2);
-      int err = gnutls_handshake(sess);
-      if (err < 0)
-       {
-         log(L_ERROR_R, "Handshake failed: %s", gnutls_strerror(err));
-         goto shut;
-       }
-      tls_log_params(sess);
+      if (errno == EINTR)
+       return;
+      die("accept: %m");
+    }
 
-      const char *cert_err = tls_verify_cert(sess);
-      if (cert_err)
-       {
-         log(L_ERROR_R, "Certificate verification failed: %s", cert_err);
-         goto shut;
-       }
-
-      for (;;)
-       {
-         byte buf[1024];
-         int ret = gnutls_record_recv(sess, buf, sizeof(buf));
-         if (ret < 0)
-           {
-             log(L_ERROR_R, "Connection broken: %s", gnutls_strerror(ret));
-             break;
-           }
-         if (!ret)
-           {
-             log(L_INFO, "Client closed connection");
-             break;
-           }
-         log(L_DEBUG, "Received %d bytes", ret);
-         gnutls_record_send(sess, buf, ret);
-       }
-
-      gnutls_bye(sess, GNUTLS_SHUT_WR);
-shut:
-      close(sk2);
-      gnutls_deinit(sess);
+  byte ipbuf[INET_ADDRSTRLEN];
+  inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
+  u32 addr = ntohl(sa.sin_addr.s_addr);
+  uns port = ntohs(sa.sin_port);
+  char *err;
+
+  struct access_rule *rule = lookup_rule(addr);
+  if (!rule)
+    {
+      err = "Unauthorized";
+      goto reject;
+    }
+
+  if (num_conn >= max_conn)
+    {
+      err = "Too many connections";
+      goto reject;
+    }
+
+  if (conn_count(addr) >= rule->max_conn)
+    {
+      err = "Too many connections from this address";
+      goto reject;
+    }
+
+  struct conn *c = conn_new();
+  log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
+       ipbuf, port, c->id,
+       (rule->plain_text ? "plain-text" : "TLS"),
+       (rule->allow_admin ? "admin" : "user"));
+  c->ip = addr;
+  c->sk = sk;
+  c->rule = rule;
+
+  c->pid = fork();
+  if (c->pid < 0)
+    {
+      conn_free(c);
+      err = "Server overloaded";
+      log(L_ERROR, "Fork failed: %m");
+      goto reject2;
+    }
+  if (!c->pid)
+    {
+      close(listen_sk);
+      client_loop(c);
+      exit(0);
     }
+  close(sk);
+  return;
+
+reject:
+  log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
+reject2: ;
+  // Write an error message to the socket, but do not allow it to slow us down
+  struct linger ling = { .l_onoff=0 };
+  if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
+    log(L_ERROR, "Cannot set SO_LINGER: %m");
+  write(sk, "-", 1);
+  write(sk, err, strlen(err));
+  write(sk, "\n", 1);
+  close(sk);
+}
 
-  return 0;
+int main(int argc, char **argv)
+{
+  setproctitle_init(argc, argv);
+  cf_def_file = "config";
+  cf_declare_section("SubmitD", &submitd_conf, 0);
+
+  int opt;
+  if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
+    die("This program has no options");
+
+  log(L_INFO, "Initializing TLS");
+  tls_init();
+
+  conn_init();
+  sk_init();
+  log(L_INFO, "Listening on port %d", port);
+
+  struct sigaction sa = {
+    .sa_handler = sigchld_handler
+  };
+  if (sigaction(SIGCHLD, &sa, NULL) < 0)
+    die("Cannot setup SIGCHLD handler: %m");
+
+  for (;;)
+    {
+      int status;
+      pid_t pid = waitpid(-1, &status, WNOHANG);
+      if (pid > 0)
+       reap_child(pid, status);
+      else
+       sk_accept();
+    }
 }
diff --git a/submit/test.pl b/submit/test.pl
new file mode 100755 (executable)
index 0000000..9948a70
--- /dev/null
@@ -0,0 +1,35 @@
+#!/usr/bin/perl
+
+use strict;
+use warnings;
+
+use IO::Socket::INET;
+use IO::Socket::SSL; # qw(debug3);
+
+my $sk = new IO::Socket::INET(
+#      PeerAddr => "nikam.ms.mff.cuni.cz:443",
+       PeerAddr => "localhost:8888",
+       Proto => "tcp",
+) or die "Cannot connect to server: $!";
+
+my $z = <$sk>;
+defined $z or die "Server failed to send welcome message\n";
+$z =~ /^\+/ or die "Server reported error: $z";
+print $z;
+
+if ($z =~ /TLS/) {
+       $sk = IO::Socket::SSL->start_SSL(
+               $sk,
+               SSL_version => 'TLSv1',
+               SSL_use_cert => 1,
+               SSL_key_file => "client-key.pem",
+               SSL_cert_file => "client-cert.pem",
+               SSL_ca_file => "ca-cert.pem",
+               SSL_verify_mode => 3,
+       ) or die "Cannot establish TLS connection: " . IO::Socket::SSL::errstr() . "\n";
+}
+
+print $sk "Hello, world!\n";
+my $y = <$sk>;
+print $y;
+close $sk;