]> mj.ucw.cz Git - eval.git/blob - submit/submitd.c
211cc1b12a4b5ba104dbd2531e61b1c933112efc
[eval.git] / submit / submitd.c
1 /*
2  *  The Submit Daemon
3  *
4  *  (c) 2007 Martin Mares <mj@ucw.cz>
5  */
6
7 /*
8  *  FIXME:
9  *      - competition timeout & per-contestant exceptions
10  */
11
12 #undef LOCAL_DEBUG
13
14 #include "lib/lib.h"
15 #include "lib/conf.h"
16 #include "lib/getopt.h"
17
18 #include <string.h>
19 #include <stdlib.h>
20 #include <unistd.h>
21 #include <signal.h>
22 #include <errno.h>
23 #include <sys/types.h>
24 #include <sys/socket.h>
25 #include <sys/wait.h>
26 #include <arpa/inet.h>
27 #include <netinet/in.h>
28
29 #include "submitd.h"
30
31 /*** CONFIGURATION ***/
32
33 static uns port = 8888;
34 static uns dh_bits = 1024;
35 static uns max_conn = 10;
36 static uns session_timeout;
37 static byte *ca_cert_name = "?";
38 static byte *server_cert_name = "?";
39 static byte *server_key_name = "?";
40 static clist access_rules;
41 static uns trace_tls;
42 uns max_request_size;
43 uns max_attachment_size;
44 uns trace_commands;
45
46 static struct cf_section access_conf = {
47   CF_TYPE(struct access_rule),
48   CF_ITEMS {
49     CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
50     CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
51     CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
52     CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
53     CF_END
54   }
55 };
56
57 static byte *
58 config_init(void)
59 {
60   clist_init(&access_rules);
61   return NULL;
62 }
63
64 static struct cf_section submitd_conf = {
65   CF_INIT(config_init),
66   CF_ITEMS {
67     CF_UNS("Port", &port),
68     CF_UNS("DHBits", &dh_bits),
69     CF_UNS("MaxConn", &max_conn),
70     CF_UNS("SessionTimeout", &session_timeout),
71     CF_UNS("MaxRequestSize", &max_request_size),
72     CF_UNS("MaxAttachSize", &max_attachment_size),
73     CF_STRING("CACert", &ca_cert_name),
74     CF_STRING("ServerCert", &server_cert_name),
75     CF_STRING("ServerKey", &server_key_name),
76     CF_LIST("Access", &access_rules, &access_conf),
77     CF_UNS("TraceTLS", &trace_tls),
78     CF_UNS("TraceCommands", &trace_commands),
79     CF_END
80   }
81 };
82
83 /*** CONNECTIONS ***/
84
85 static clist connections;
86 static uns last_conn_id;
87 static uns num_conn;
88
89 static void
90 conn_init(void)
91 {
92   clist_init(&connections);
93 }
94
95 static struct conn *
96 conn_new(void)
97 {
98   struct conn *c = xmalloc_zero(sizeof(*c));
99   c->id = ++last_conn_id;
100   clist_add_tail(&connections, &c->n);
101   num_conn++;
102   return c;
103 }
104
105 static void
106 conn_free(struct conn *c)
107 {
108   xfree(c->ip_string);
109   xfree(c->cert_name);
110   clist_remove(&c->n);
111   num_conn--;
112   xfree(c);
113 }
114
115 static struct access_rule *
116 lookup_rule(u32 ip)
117 {
118   CLIST_FOR_EACH(struct access_rule *, r, access_rules)
119     if (ip_addrmask_match(&r->addrmask, ip))
120       return r;
121   return NULL;
122 }
123
124 static uns
125 conn_count(u32 ip)
126 {
127   uns cnt = 0;
128   CLIST_FOR_EACH(struct conn *, c, connections)
129     if (c->ip == ip)
130       cnt++;
131   return cnt;
132 }
133
134 /*** TLS ***/
135
136 static gnutls_certificate_credentials_t cert_cred;
137 static gnutls_dh_params_t dh_params;
138
139 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
140
141 static void
142 tls_init(void)
143 {
144   int err;
145
146   gnutls_global_init();
147   err = gnutls_certificate_allocate_credentials(&cert_cred);
148   TLS_CHECK(gnutls_certificate_allocate_credentials);
149   err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
150   if (!err)
151     die("No CA certificate found");
152   if (err < 0)
153     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
154   err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
155   if (err < 0)
156     die("Unable to load X509 key file: %s", gnutls_strerror(err));
157
158   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
159   err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
160   gnutls_certificate_set_dh_params(cert_cred, dh_params);
161 }
162
163 static gnutls_session_t
164 tls_new_session(int sk)
165 {
166   gnutls_session_t s;
167   int err;
168
169   err = gnutls_init(&s, GNUTLS_SERVER); TLS_CHECK(gnutls_init);
170   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);                 // FIXME
171   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
172   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
173   gnutls_dh_set_prime_bits(s, dh_bits);
174   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
175   return s;
176 }
177
178 static const char *
179 tls_verify_cert(struct conn *c)
180 {
181   gnutls_session_t s = c->tls;
182   uns status, num_certs;
183   int err;
184   gnutls_x509_crt_t cert;
185   const gnutls_datum_t *certs;
186
187   DBG("Verifying peer certificates");
188   err = gnutls_certificate_verify_peers2(s, &status);
189   if (err < 0)
190     return gnutls_strerror(err);
191   DBG("Verify status: %04x", status);
192   if (status & GNUTLS_CERT_INVALID)
193     return "Certificate is invalid";
194   /* XXX: We do not handle revokation. */
195   if (gnutls_certificate_type_get(s) != GNUTLS_CRT_X509)
196     return "Certificate is not X509";
197
198   err = gnutls_x509_crt_init(&cert);
199   if (err < 0)
200     return "gnutls_x509_crt_init() failed";
201   certs = gnutls_certificate_get_peers(s, &num_certs);
202   if (!certs)
203     return "No peer certificate found";
204   DBG("Got certificate list with %d peers", num_certs);
205
206   err = gnutls_x509_crt_import(cert, &certs[0], GNUTLS_X509_FMT_DER);
207   if (err < 0)
208     return "Cannot import certificate";
209   /* XXX: We do not check expiration and activation since the keys are generated for a single contest only anyway. */
210
211   byte dn[256];
212   size_t dn_len = sizeof(dn);
213   err = gnutls_x509_crt_get_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, dn, &dn_len);
214   if (err < 0)
215     return "Cannot retrieve common name";
216   if (trace_tls)
217     log(L_INFO, "Cert CN: %s", dn);
218   c->cert_name = xstrdup(dn);
219
220   /* Check certificate purpose */
221   byte purp[256];
222   int purpi = 0;
223   do
224     {
225       size_t purp_len = sizeof(purp);
226       uns crit;
227       err = gnutls_x509_crt_get_key_purpose_oid(cert, purpi++, purp, &purp_len, &crit);
228       if (err == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
229         return "Not a client certificate";
230       TLS_CHECK(gnutls_x509_crt_get_key_purpose_oid);
231     }
232   while (strcmp(purp, GNUTLS_KP_TLS_WWW_CLIENT));
233
234   DBG("Verified OK");
235   return NULL;
236 }
237
238 static void
239 tls_log_params(struct conn *c)
240 {
241   if (!trace_tls)
242     return;
243   gnutls_session_t s = c->tls;
244   const char *proto = gnutls_protocol_get_name(gnutls_protocol_get_version(s));
245   const char *kx = gnutls_kx_get_name(gnutls_kx_get(s));
246   const char *cert = gnutls_certificate_type_get_name(gnutls_certificate_type_get(s));
247   const char *comp = gnutls_compression_get_name(gnutls_compression_get(s));
248   const char *cipher = gnutls_cipher_get_name(gnutls_cipher_get(s));
249   const char *mac = gnutls_mac_get_name(gnutls_mac_get(s));
250   log(L_DEBUG, "TLS params: proto=%s kx=%s cert=%s comp=%s cipher=%s mac=%s",
251     proto, kx, cert, comp, cipher, mac);
252 }
253
254 /*** FASTBUFS OVER SOCKETS AND TLS ***/
255
256 void NONRET                             // Fatal protocol violation
257 client_error(char *msg, ...)
258 {
259   va_list args;
260   va_start(args, msg);
261   vlog_msg(L_ERROR_R, msg, args);
262   exit(0);
263 }
264
265 static int
266 sk_fb_refill(struct fastbuf *f)
267 {
268   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
269   int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
270   if (cnt < 0)
271     client_error("Read error: %m");
272   f->bptr = f->buffer;
273   f->bstop = f->buffer + cnt;
274   return cnt;
275 }
276
277 static void
278 sk_fb_spout(struct fastbuf *f)
279 {
280   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
281   int len = f->bptr - f->buffer;
282   if (!len)
283     return;
284   int cnt = careful_write(c->sk, f->buffer, len);
285   if (cnt <= 0)
286     client_error("Write error");
287   f->bptr = f->buffer;
288 }
289
290 static void
291 init_sk_fastbufs(struct conn *c)
292 {
293   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
294
295   rf->buffer = xmalloc(1024);
296   rf->bufend = rf->buffer + 1024;
297   rf->bptr = rf->bstop = rf->buffer;
298   rf->name = "socket";
299   rf->refill = sk_fb_refill;
300
301   tf->buffer = xmalloc(1024);
302   tf->bufend = tf->buffer + 1024;
303   tf->bptr = tf->bstop = tf->buffer;
304   tf->name = rf->name;
305   tf->spout = sk_fb_spout;
306 }
307
308 static int
309 tls_fb_refill(struct fastbuf *f)
310 {
311   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
312   DBG("TLS: Refill");
313   int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
314   DBG("TLS: Received %d bytes", cnt);
315   if (cnt < 0)
316     client_error("TLS read error: %s", gnutls_strerror(cnt));
317   f->bptr = f->buffer;
318   f->bstop = f->buffer + cnt;
319   return cnt;
320 }
321
322 static void
323 tls_fb_spout(struct fastbuf *f)
324 {
325   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
326   int len = f->bptr - f->buffer;
327   if (!len)
328     return;
329   int cnt = gnutls_record_send(c->tls, f->buffer, len);
330   DBG("TLS: Sent %d bytes", cnt);
331   if (cnt <= 0)
332     client_error("TLS write error: %s", gnutls_strerror(cnt));
333   f->bptr = f->buffer;
334 }
335
336 static void
337 init_tls_fastbufs(struct conn *c)
338 {
339   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
340
341   ASSERT(rf->buffer && tf->buffer);     // Already set up for the plaintext connection
342   rf->refill = tls_fb_refill;
343   tf->spout = tls_fb_spout;
344 }
345
346 /*** CLIENT LOOP (runs in a child process) ***/
347
348 static void
349 sigalrm_handler(int sig UNUSED)
350 {
351   // We do not try to do any gracious shutdown to avoid races
352   client_error("Timed out");
353 }
354
355 static void
356 client_loop(struct conn *c)
357 {
358   setproctitle("submitd: client %s", c->ip_string);
359   log_pid = c->id;
360   init_sk_fastbufs(c);
361
362   signal(SIGPIPE, SIG_IGN);
363   struct sigaction sa = {
364     .sa_handler = sigalrm_handler
365   };
366   if (sigaction(SIGALRM, &sa, NULL) < 0)
367     die("Cannot setup SIGALRM handler: %m");
368
369   if (c->rule->plain_text)
370     {
371       bputsn(&c->tx_fb, "+OK");
372       bflush(&c->tx_fb);
373     }
374   else
375     {
376       bputsn(&c->tx_fb, "+TLS");
377       bflush(&c->tx_fb);
378       c->tls = tls_new_session(c->sk);
379       int err = gnutls_handshake(c->tls);
380       if (err < 0)
381         client_error("TLS handshake failed: %s", gnutls_strerror(err));
382       tls_log_params(c);
383       const char *cert_err = tls_verify_cert(c);
384       if (cert_err)
385         client_error("TLS certificate failure: %s", cert_err);
386       init_tls_fastbufs(c);
387     }
388
389   alarm(session_timeout);
390   if (!process_init(c))
391     log(L_ERROR, "Protocol handshake failed");
392   else
393     {
394       setproctitle("submitd: client %s (%s)", c->ip_string, c->user);
395       for (;;)
396         {
397           alarm(session_timeout);
398           if (!process_command(c))
399             break;
400         }
401     }
402
403   if (c->tls)
404     gnutls_bye(c->tls, GNUTLS_SHUT_WR);
405   close(c->sk);
406   if (c->tls)
407     gnutls_deinit(c->tls);
408 }
409
410 /*** MAIN LOOP ***/
411
412 static void
413 sigchld_handler(int sig UNUSED)
414 {
415   /* We do not need to do anything, just interrupt the accept syscall */
416 }
417
418 static void
419 reap_child(pid_t pid, int status)
420 {
421   byte msg[EXIT_STATUS_MSG_SIZE];
422   if (format_exit_status(msg, status))
423     log(L_ERROR, "Child %d %s", (int)pid, msg);
424
425   CLIST_FOR_EACH(struct conn *, c, connections)
426     if (c->pid == pid)
427       {
428         log(L_INFO, "Connection %d closed", c->id);
429         conn_free(c);
430         return;
431       }
432   log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
433 }
434
435 static int listen_sk;
436
437 static void
438 sk_init(void)
439 {
440   listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
441   if (listen_sk < 0)
442     die("socket: %m");
443   int one = 1;
444   if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
445     die("setsockopt(SO_REUSEADDR): %m");
446
447   struct sockaddr_in sa;
448   bzero(&sa, sizeof(sa));
449   sa.sin_family = AF_INET;
450   sa.sin_addr.s_addr = INADDR_ANY;
451   sa.sin_port = htons(port);
452   if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
453     die("Cannot bind to port %d: %m", port);
454   if (listen(listen_sk, 1024) < 0)
455     die("Cannot listen on port %d: %m", port);
456 }
457
458 static void
459 sk_accept(void)
460 {
461   struct sockaddr_in sa;
462   int salen = sizeof(sa);
463   int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
464   if (sk < 0)
465     {
466       if (errno == EINTR)
467         return;
468       die("accept: %m");
469     }
470
471   byte ipbuf[INET_ADDRSTRLEN];
472   inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
473   u32 addr = ntohl(sa.sin_addr.s_addr);
474   uns port = ntohs(sa.sin_port);
475   char *err;
476
477   struct access_rule *rule = lookup_rule(addr);
478   if (!rule)
479     {
480       err = "Unauthorized";
481       goto reject;
482     }
483
484   if (num_conn >= max_conn)
485     {
486       err = "Too many connections";
487       goto reject;
488     }
489
490   if (conn_count(addr) >= rule->max_conn)
491     {
492       err = "Too many connections from this address";
493       goto reject;
494     }
495
496   struct conn *c = conn_new();
497   log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
498         ipbuf, port, c->id,
499         (rule->plain_text ? "plain-text" : "TLS"),
500         (rule->allow_admin ? "admin" : "user"));
501   c->ip = addr;
502   c->ip_string = xstrdup(ipbuf);
503   c->sk = sk;
504   c->rule = rule;
505
506   c->pid = fork();
507   if (c->pid < 0)
508     {
509       conn_free(c);
510       err = "Server overloaded";
511       log(L_ERROR, "Fork failed: %m");
512       goto reject2;
513     }
514   if (!c->pid)
515     {
516       close(listen_sk);
517       client_loop(c);
518       exit(0);
519     }
520   close(sk);
521   return;
522
523 reject:
524   log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
525 reject2: ;
526   // Write an error message to the socket, but do not allow it to slow us down
527   struct linger ling = { .l_onoff=0 };
528   if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
529     log(L_ERROR, "Cannot set SO_LINGER: %m");
530   write(sk, "-", 1);
531   write(sk, err, strlen(err));
532   write(sk, "\n", 1);
533   close(sk);
534 }
535
536 int main(int argc, char **argv)
537 {
538   setproctitle_init(argc, argv);
539   cf_def_file = "config";
540   cf_declare_section("SubmitD", &submitd_conf, 0);
541
542   int opt;
543   if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
544     die("This program has no options");
545
546   log(L_INFO, "Initializing TLS");
547   tls_init();
548
549   conn_init();
550   sk_init();
551   log(L_INFO, "Listening on port %d", port);
552
553   struct sigaction sa = {
554     .sa_handler = sigchld_handler
555   };
556   if (sigaction(SIGCHLD, &sa, NULL) < 0)
557     die("Cannot setup SIGCHLD handler: %m");
558
559   for (;;)
560     {
561       setproctitle("submitd: %d connections", num_conn);
562       int status;
563       pid_t pid = waitpid(-1, &status, WNOHANG);
564       if (pid > 0)
565         reap_child(pid, status);
566       else
567         sk_accept();
568     }
569 }