]> mj.ucw.cz Git - eval.git/blob - submit/submitd.c
a58ae69255bb4ff914c5bab27a5853ca4f3f7bed
[eval.git] / submit / submitd.c
1 /*
2  *  The Submit Daemon
3  *
4  *  (c) 2007 Martin Mares <mj@ucw.cz>
5  */
6
7 /*
8  *  FIXME:
9  *      - competition timeout & per-contestant exceptions
10  */
11
12 #undef LOCAL_DEBUG
13
14 #include "lib/lib.h"
15 #include "lib/conf.h"
16 #include "lib/getopt.h"
17 #include "lib/ipaccess.h"
18 #include "lib/fastbuf.h"
19
20 #include <string.h>
21 #include <stdlib.h>
22 #include <unistd.h>
23 #include <signal.h>
24 #include <errno.h>
25 #include <sys/types.h>
26 #include <sys/socket.h>
27 #include <sys/wait.h>
28 #include <arpa/inet.h>
29 #include <netinet/in.h>
30 #include <gnutls/gnutls.h>
31 #include <gnutls/x509.h>
32
33 /*** CONFIGURATION ***/
34
35 static uns port = 8888;
36 static uns dh_bits = 1024;
37 static uns max_conn = 10;
38 static uns session_timeout;
39 static byte *ca_cert_name = "?";
40 static byte *server_cert_name = "?";
41 static byte *server_key_name = "?";
42 static clist access_rules;
43
44 struct access_rule {
45   cnode n;
46   struct ip_addrmask addrmask;
47   uns allow_admin;
48   uns plain_text;
49   uns max_conn;
50 };
51
52 static struct cf_section access_conf = {
53   CF_TYPE(struct access_rule),
54   CF_ITEMS {
55     CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
56     CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
57     CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
58     CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
59     CF_END
60   }
61 };
62
63 static byte *
64 config_init(void)
65 {
66   clist_init(&access_rules);
67   return NULL;
68 }
69
70 static struct cf_section submitd_conf = {
71   CF_INIT(config_init),
72   CF_ITEMS {
73     CF_UNS("Port", &port),
74     CF_UNS("DHBits", &dh_bits),
75     CF_UNS("MaxConn", &max_conn),
76     CF_UNS("SessionTimeout", &session_timeout),
77     CF_STRING("CACert", &ca_cert_name),
78     CF_STRING("ServerCert", &server_cert_name),
79     CF_STRING("ServerKey", &server_key_name),
80     CF_LIST("Access", &access_rules, &access_conf),
81     CF_END
82   }
83 };
84
85 /*** CONNECTIONS ***/
86
87 struct conn {
88   cnode n;
89   u32 ip;                               // Used by the main loop connection logic
90   pid_t pid;
91   uns id;
92   struct access_rule *rule;             // Rule matched by this connection
93   int sk;                               // Client socket
94   gnutls_session_t tls;                 // TLS session
95   struct fastbuf rx_fb, tx_fb;          // Fastbufs for communication with the client
96 };
97
98 static clist connections;
99 static uns last_conn_id;
100 static uns num_conn;
101
102 static void
103 conn_init(void)
104 {
105   clist_init(&connections);
106 }
107
108 static struct conn *
109 conn_new(void)
110 {
111   struct conn *c = xmalloc(sizeof(*c));
112   c->id = ++last_conn_id;
113   clist_add_tail(&connections, &c->n);
114   num_conn++;
115   return c;
116 }
117
118 static void
119 conn_free(struct conn *c)
120 {
121   clist_remove(&c->n);
122   num_conn--;
123   xfree(c);
124 }
125
126 static struct access_rule *
127 lookup_rule(u32 ip)
128 {
129   CLIST_FOR_EACH(struct access_rule *, r, access_rules)
130     if (ip_addrmask_match(&r->addrmask, ip))
131       return r;
132   return NULL;
133 }
134
135 static uns
136 conn_count(u32 ip)
137 {
138   uns cnt = 0;
139   CLIST_FOR_EACH(struct conn *, c, connections)
140     if (c->ip == ip)
141       cnt++;
142   return cnt;
143 }
144
145 /*** TLS ***/
146
147 static gnutls_certificate_credentials_t cert_cred;
148 static gnutls_dh_params_t dh_params;
149
150 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
151
152 static void
153 tls_init(void)
154 {
155   int err;
156
157   gnutls_global_init();
158   err = gnutls_certificate_allocate_credentials(&cert_cred);
159   TLS_CHECK(gnutls_certificate_allocate_credentials);
160   err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
161   if (!err)
162     die("No CA certificate found");
163   if (err < 0)
164     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
165   err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
166   if (err < 0)
167     die("Unable to load X509 key file: %s", gnutls_strerror(err));
168
169   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
170   err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
171   gnutls_certificate_set_dh_params(cert_cred, dh_params);
172 }
173
174 static gnutls_session_t
175 tls_new_session(int sk)
176 {
177   gnutls_session_t s;
178   int err;
179
180   err = gnutls_init(&s, GNUTLS_SERVER); TLS_CHECK(gnutls_init);
181   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);                 // FIXME
182   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
183   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
184   gnutls_dh_set_prime_bits(s, dh_bits);
185   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
186   return s;
187 }
188
189 static const char *
190 tls_verify_cert(gnutls_session_t s)
191 {
192   uns status, num_certs;
193   int err;
194   gnutls_x509_crt_t cert;
195   const gnutls_datum_t *certs;
196
197   DBG("Verifying peer certificates");
198   err = gnutls_certificate_verify_peers2(s, &status);
199   if (err < 0)
200     return gnutls_strerror(err);
201   DBG("Verify status: %04x", status);
202   if (status & GNUTLS_CERT_INVALID)
203     return "Certificate is invalid";
204   /* XXX: We do not handle revokation. */
205   if (gnutls_certificate_type_get(s) != GNUTLS_CRT_X509)
206     return "Certificate is not X509";
207
208   err = gnutls_x509_crt_init(&cert);
209   if (err < 0)
210     return "gnutls_x509_crt_init() failed";
211   certs = gnutls_certificate_get_peers(s, &num_certs);
212   if (!certs)
213     return "No peer certificate found";
214   DBG("Got certificate list with %d peers", num_certs);
215
216   err = gnutls_x509_crt_import(cert, &certs[0], GNUTLS_X509_FMT_DER);
217   if (err < 0)
218     return "Cannot import certificate";
219   /* XXX: We do not check expiration and activation since the keys are generated for a single contest only anyway. */
220
221   byte dn[256];
222   size_t dn_len = sizeof(dn);
223   err = gnutls_x509_crt_get_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, dn, &dn_len);
224   if (err < 0)
225     return "Cannot retrieve common name";
226   log(L_INFO, "Cert CN: %s", dn);
227
228   /* Check certificate purpose */
229   byte purp[256];
230   int purpi = 0;
231   do
232     {
233       size_t purp_len = sizeof(purp);
234       uns crit;
235       err = gnutls_x509_crt_get_key_purpose_oid(cert, purpi++, purp, &purp_len, &crit);
236       if (err == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
237         return "Not a client certificate";
238       TLS_CHECK(gnutls_x509_crt_get_key_purpose_oid);
239     }
240   while (strcmp(purp, GNUTLS_KP_TLS_WWW_CLIENT));
241
242   DBG("Verified OK");
243   return NULL;
244 }
245
246 static void
247 tls_log_params(gnutls_session_t s)
248 {
249   const char *proto = gnutls_protocol_get_name(gnutls_protocol_get_version(s));
250   const char *kx = gnutls_kx_get_name(gnutls_kx_get(s));
251   const char *cert = gnutls_certificate_type_get_name(gnutls_certificate_type_get(s));
252   const char *comp = gnutls_compression_get_name(gnutls_compression_get(s));
253   const char *cipher = gnutls_cipher_get_name(gnutls_cipher_get(s));
254   const char *mac = gnutls_mac_get_name(gnutls_mac_get(s));
255   log(L_DEBUG, "TLS params: proto=%s kx=%s cert=%s comp=%s cipher=%s mac=%s",
256     proto, kx, cert, comp, cipher, mac);
257 }
258
259 /*** SOCKET FASTBUFS ***/
260
261 static void NONRET
262 client_error(char *msg, ...)
263 {
264   va_list args;
265   va_start(args, msg);
266   vlog_msg(L_ERROR_R, msg, args);
267   exit(0);
268 }
269
270 static int
271 sk_fb_refill(struct fastbuf *f)
272 {
273   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
274   int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
275   if (cnt < 0)
276     client_error("Read error: %m");
277   f->bptr = f->buffer;
278   f->bstop = f->buffer + cnt;
279   return cnt;
280 }
281
282 static void
283 sk_fb_spout(struct fastbuf *f)
284 {
285   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
286   int len = f->bptr - f->buffer;
287   if (!len)
288     return;
289   int cnt = careful_write(c->sk, f->buffer, len);
290   if (cnt <= 0)
291     client_error("Write error");
292   f->bptr = f->buffer;
293 }
294
295 static void
296 init_sk_fastbufs(struct conn *c)
297 {
298   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
299
300   rf->buffer = xmalloc(1024);
301   rf->bufend = rf->buffer + 1024;
302   rf->bptr = rf->bstop = rf->buffer;
303   rf->name = "socket";
304   rf->refill = sk_fb_refill;
305
306   tf->buffer = xmalloc(1024);
307   tf->bufend = tf->buffer + 1024;
308   tf->bptr = tf->bstop = tf->buffer;
309   tf->name = rf->name;
310   tf->spout = sk_fb_spout;
311 }
312
313 static int
314 tls_fb_refill(struct fastbuf *f)
315 {
316   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
317   DBG("TLS: Refill");
318   int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
319   DBG("TLS: Received %d bytes", cnt);
320   if (cnt < 0)
321     client_error("TLS read error: %s", gnutls_strerror(cnt));
322   f->bptr = f->buffer;
323   f->bstop = f->buffer + cnt;
324   return cnt;
325 }
326
327 static void
328 tls_fb_spout(struct fastbuf *f)
329 {
330   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
331   int len = f->bptr - f->buffer;
332   if (!len)
333     return;
334   int cnt = gnutls_record_send(c->tls, f->buffer, len);
335   DBG("TLS: Sent %d bytes", cnt);
336   if (cnt <= 0)
337     client_error("TLS write error: %s", gnutls_strerror(cnt));
338   f->bptr = f->buffer;
339 }
340
341 static void
342 init_tls_fastbufs(struct conn *c)
343 {
344   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
345
346   ASSERT(rf->buffer && tf->buffer);     // Already set up for the plaintext connection
347   rf->refill = tls_fb_refill;
348   tf->spout = tls_fb_spout;
349 }
350
351 /*** CLIENT LOOP (runs in a child process) ***/
352
353 static void
354 sigalrm_handler(int sig UNUSED)
355 {
356   // We do not try to do any gracious shutdown to avoid races
357   client_error("Timed out");
358 }
359
360 static void
361 client_loop(struct conn *c)
362 {
363   log_pid = c->id;
364   init_sk_fastbufs(c);
365
366   signal(SIGPIPE, SIG_IGN);
367   struct sigaction sa = {
368     .sa_handler = sigalrm_handler
369   };
370   if (sigaction(SIGALRM, &sa, NULL) < 0)
371     die("Cannot setup SIGALRM handler: %m");
372
373   if (c->rule->plain_text)
374     {
375       bputsn(&c->tx_fb, "+OK");
376       bflush(&c->tx_fb);
377     }
378   else
379     {
380       bputsn(&c->tx_fb, "+TLS");
381       bflush(&c->tx_fb);
382       c->tls = tls_new_session(c->sk);
383       int err = gnutls_handshake(c->tls);
384       if (err < 0)
385         client_error("TLS handshake failed: %s", gnutls_strerror(err));
386       tls_log_params(c->tls);
387       const char *cert_err = tls_verify_cert(c->tls);
388       if (cert_err)
389         client_error("TLS certificate failure: %s", cert_err);
390       init_tls_fastbufs(c);
391     }
392
393   for (;;)
394     {
395       alarm(session_timeout);
396       byte buf[1024];
397       if (!bgets(&c->rx_fb, buf, sizeof(buf)))
398         break;
399       bputsn(&c->tx_fb, buf);
400       bflush(&c->tx_fb);
401     }
402
403   if (c->tls)
404     gnutls_bye(c->tls, GNUTLS_SHUT_WR);
405   close(c->sk);
406   if (c->tls)
407     gnutls_deinit(c->tls);
408 }
409
410 /*** MAIN LOOP ***/
411
412 static void
413 sigchld_handler(int sig UNUSED)
414 {
415   /* We do not need to do anything, just interrupt the accept syscall */
416 }
417
418 static void
419 reap_child(pid_t pid, int status)
420 {
421   byte msg[EXIT_STATUS_MSG_SIZE];
422   if (format_exit_status(msg, status))
423     log(L_ERROR, "Child %d %s", (int)pid, msg);
424
425   CLIST_FOR_EACH(struct conn *, c, connections)
426     if (c->pid == pid)
427       {
428         log(L_INFO, "Connection %d closed", c->id);
429         conn_free(c);
430         return;
431       }
432   log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
433 }
434
435 static int listen_sk;
436
437 static void
438 sk_init(void)
439 {
440   listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
441   if (listen_sk < 0)
442     die("socket: %m");
443   int one = 1;
444   if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
445     die("setsockopt(SO_REUSEADDR): %m");
446
447   struct sockaddr_in sa;
448   bzero(&sa, sizeof(sa));
449   sa.sin_family = AF_INET;
450   sa.sin_addr.s_addr = INADDR_ANY;
451   sa.sin_port = htons(port);
452   if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
453     die("Cannot bind to port %d: %m", port);
454   if (listen(listen_sk, 1024) < 0)
455     die("Cannot listen on port %d: %m", port);
456 }
457
458 static void
459 sk_accept(void)
460 {
461   struct sockaddr_in sa;
462   int salen = sizeof(sa);
463   int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
464   if (sk < 0)
465     {
466       if (errno == EINTR)
467         return;
468       die("accept: %m");
469     }
470
471   byte ipbuf[INET_ADDRSTRLEN];
472   inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
473   u32 addr = ntohl(sa.sin_addr.s_addr);
474   uns port = ntohs(sa.sin_port);
475   char *err;
476
477   struct access_rule *rule = lookup_rule(addr);
478   if (!rule)
479     {
480       err = "Unauthorized";
481       goto reject;
482     }
483
484   if (num_conn >= max_conn)
485     {
486       err = "Too many connections";
487       goto reject;
488     }
489
490   if (conn_count(addr) >= rule->max_conn)
491     {
492       err = "Too many connections from this address";
493       goto reject;
494     }
495
496   struct conn *c = conn_new();
497   log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
498         ipbuf, port, c->id,
499         (rule->plain_text ? "plain-text" : "TLS"),
500         (rule->allow_admin ? "admin" : "user"));
501   c->ip = addr;
502   c->sk = sk;
503   c->rule = rule;
504
505   c->pid = fork();
506   if (c->pid < 0)
507     {
508       conn_free(c);
509       err = "Server overloaded";
510       log(L_ERROR, "Fork failed: %m");
511       goto reject2;
512     }
513   if (!c->pid)
514     {
515       close(listen_sk);
516       client_loop(c);
517       exit(0);
518     }
519   close(sk);
520   return;
521
522 reject:
523   log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
524 reject2: ;
525   // Write an error message to the socket, but do not allow it to slow us down
526   struct linger ling = { .l_onoff=0 };
527   if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
528     log(L_ERROR, "Cannot set SO_LINGER: %m");
529   write(sk, "-", 1);
530   write(sk, err, strlen(err));
531   write(sk, "\n", 1);
532   close(sk);
533 }
534
535 int main(int argc, char **argv)
536 {
537   setproctitle_init(argc, argv);
538   cf_def_file = "config";
539   cf_declare_section("SubmitD", &submitd_conf, 0);
540
541   int opt;
542   if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
543     die("This program has no options");
544
545   log(L_INFO, "Initializing TLS");
546   tls_init();
547
548   conn_init();
549   sk_init();
550   log(L_INFO, "Listening on port %d", port);
551
552   struct sigaction sa = {
553     .sa_handler = sigchld_handler
554   };
555   if (sigaction(SIGCHLD, &sa, NULL) < 0)
556     die("Cannot setup SIGCHLD handler: %m");
557
558   for (;;)
559     {
560       int status;
561       pid_t pid = waitpid(-1, &status, WNOHANG);
562       if (pid > 0)
563         reap_child(pid, status);
564       else
565         sk_accept();
566     }
567 }