]> mj.ucw.cz Git - eval.git/blob - submit/submitd.c
b3c39280614d0e12ae8399b40ae0dd83dbace72e
[eval.git] / submit / submitd.c
1 /*
2  *  The Submit Daemon
3  *
4  *  (c) 2007 Martin Mares <mj@ucw.cz>
5  */
6
7 /*
8  *  FIXME:
9  *      - competition timeout & per-contestant exceptions
10  *      - open-data problems
11  */
12
13 #undef LOCAL_DEBUG
14
15 #include "lib/lib.h"
16 #include "lib/conf.h"
17 #include "lib/getopt.h"
18
19 #include <string.h>
20 #include <stdlib.h>
21 #include <unistd.h>
22 #include <signal.h>
23 #include <errno.h>
24 #include <sys/types.h>
25 #include <sys/socket.h>
26 #include <sys/wait.h>
27 #include <arpa/inet.h>
28 #include <netinet/in.h>
29
30 #include "submitd.h"
31
32 /*** CONFIGURATION ***/
33
34 static uns port = 8888;
35 static uns dh_bits = 1024;
36 static uns max_conn = 10;
37 static uns session_timeout;
38 static byte *ca_cert_name = "?";
39 static byte *server_cert_name = "?";
40 static byte *server_key_name = "?";
41 static clist access_rules;
42 static uns trace_tls;
43 uns max_request_size;
44 uns max_attachment_size;
45 uns trace_commands;
46
47 static struct cf_section access_conf = {
48   CF_TYPE(struct access_rule),
49   CF_ITEMS {
50     CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
51     CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
52     CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
53     CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
54     CF_END
55   }
56 };
57
58 static struct cf_section submitd_conf = {
59   CF_ITEMS {
60     CF_UNS("Port", &port),
61     CF_UNS("DHBits", &dh_bits),
62     CF_UNS("MaxConn", &max_conn),
63     CF_UNS("SessionTimeout", &session_timeout),
64     CF_UNS("MaxRequestSize", &max_request_size),
65     CF_UNS("MaxAttachSize", &max_attachment_size),
66     CF_STRING("CACert", &ca_cert_name),
67     CF_STRING("ServerCert", &server_cert_name),
68     CF_STRING("ServerKey", &server_key_name),
69     CF_LIST("Access", &access_rules, &access_conf),
70     CF_UNS("TraceTLS", &trace_tls),
71     CF_UNS("TraceCommands", &trace_commands),
72     CF_END
73   }
74 };
75
76 /*** CONNECTIONS ***/
77
78 static clist connections;
79 static uns last_conn_id;
80 static uns num_conn;
81
82 static void
83 conn_init(void)
84 {
85   clist_init(&connections);
86 }
87
88 static struct conn *
89 conn_new(void)
90 {
91   struct conn *c = xmalloc_zero(sizeof(*c));
92   c->id = ++last_conn_id;
93   clist_add_tail(&connections, &c->n);
94   num_conn++;
95   return c;
96 }
97
98 static void
99 conn_free(struct conn *c)
100 {
101   xfree(c->ip_string);
102   xfree(c->cert_name);
103   clist_remove(&c->n);
104   num_conn--;
105   xfree(c);
106 }
107
108 static struct access_rule *
109 lookup_rule(u32 ip)
110 {
111   CLIST_FOR_EACH(struct access_rule *, r, access_rules)
112     if (ip_addrmask_match(&r->addrmask, ip))
113       return r;
114   return NULL;
115 }
116
117 static uns
118 conn_count(u32 ip)
119 {
120   uns cnt = 0;
121   CLIST_FOR_EACH(struct conn *, c, connections)
122     if (c->ip == ip)
123       cnt++;
124   return cnt;
125 }
126
127 /*** TLS ***/
128
129 static gnutls_certificate_credentials_t cert_cred;
130 static gnutls_dh_params_t dh_params;
131
132 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
133
134 static void
135 tls_init(void)
136 {
137   int err;
138
139   gnutls_global_init();
140   err = gnutls_certificate_allocate_credentials(&cert_cred);
141   TLS_CHECK(gnutls_certificate_allocate_credentials);
142   err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
143   if (!err)
144     die("No CA certificate found");
145   if (err < 0)
146     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
147   err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
148   if (err < 0)
149     die("Unable to load X509 key file: %s", gnutls_strerror(err));
150
151   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
152   err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
153   gnutls_certificate_set_dh_params(cert_cred, dh_params);
154 }
155
156 static gnutls_session_t
157 tls_new_session(int sk)
158 {
159   gnutls_session_t s;
160   int err;
161
162   err = gnutls_init(&s, GNUTLS_SERVER); TLS_CHECK(gnutls_init);
163   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);                 // FIXME
164   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
165   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
166   gnutls_dh_set_prime_bits(s, dh_bits);
167   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
168   return s;
169 }
170
171 static const char *
172 tls_verify_cert(struct conn *c)
173 {
174   gnutls_session_t s = c->tls;
175   uns status, num_certs;
176   int err;
177   gnutls_x509_crt_t cert;
178   const gnutls_datum_t *certs;
179
180   DBG("Verifying peer certificates");
181   err = gnutls_certificate_verify_peers2(s, &status);
182   if (err < 0)
183     return gnutls_strerror(err);
184   DBG("Verify status: %04x", status);
185   if (status & GNUTLS_CERT_INVALID)
186     return "Certificate is invalid";
187   /* XXX: We do not handle revokation. */
188   if (gnutls_certificate_type_get(s) != GNUTLS_CRT_X509)
189     return "Certificate is not X509";
190
191   err = gnutls_x509_crt_init(&cert);
192   if (err < 0)
193     return "gnutls_x509_crt_init() failed";
194   certs = gnutls_certificate_get_peers(s, &num_certs);
195   if (!certs)
196     return "No peer certificate found";
197   DBG("Got certificate list with %d peers", num_certs);
198
199   err = gnutls_x509_crt_import(cert, &certs[0], GNUTLS_X509_FMT_DER);
200   if (err < 0)
201     return "Cannot import certificate";
202   /* XXX: We do not check expiration and activation since the keys are generated for a single contest only anyway. */
203
204   byte dn[256];
205   size_t dn_len = sizeof(dn);
206   err = gnutls_x509_crt_get_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, dn, &dn_len);
207   if (err < 0)
208     return "Cannot retrieve common name";
209   if (trace_tls)
210     log(L_INFO, "Cert CN: %s", dn);
211   c->cert_name = xstrdup(dn);
212
213   /* Check certificate purpose */
214   byte purp[256];
215   int purpi = 0;
216   do
217     {
218       size_t purp_len = sizeof(purp);
219       uns crit;
220       err = gnutls_x509_crt_get_key_purpose_oid(cert, purpi++, purp, &purp_len, &crit);
221       if (err == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
222         return "Not a client certificate";
223       TLS_CHECK(gnutls_x509_crt_get_key_purpose_oid);
224     }
225   while (strcmp(purp, GNUTLS_KP_TLS_WWW_CLIENT));
226
227   DBG("Verified OK");
228   return NULL;
229 }
230
231 static void
232 tls_log_params(struct conn *c)
233 {
234   if (!trace_tls)
235     return;
236   gnutls_session_t s = c->tls;
237   const char *proto = gnutls_protocol_get_name(gnutls_protocol_get_version(s));
238   const char *kx = gnutls_kx_get_name(gnutls_kx_get(s));
239   const char *cert = gnutls_certificate_type_get_name(gnutls_certificate_type_get(s));
240   const char *comp = gnutls_compression_get_name(gnutls_compression_get(s));
241   const char *cipher = gnutls_cipher_get_name(gnutls_cipher_get(s));
242   const char *mac = gnutls_mac_get_name(gnutls_mac_get(s));
243   log(L_DEBUG, "TLS params: proto=%s kx=%s cert=%s comp=%s cipher=%s mac=%s",
244     proto, kx, cert, comp, cipher, mac);
245 }
246
247 /*** FASTBUFS OVER SOCKETS AND TLS ***/
248
249 void NONRET                             // Fatal protocol violation
250 client_error(char *msg, ...)
251 {
252   va_list args;
253   va_start(args, msg);
254   vlog_msg(L_ERROR_R, msg, args);
255   exit(0);
256 }
257
258 static int
259 sk_fb_refill(struct fastbuf *f)
260 {
261   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
262   int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
263   if (cnt < 0)
264     client_error("Read error: %m");
265   f->bptr = f->buffer;
266   f->bstop = f->buffer + cnt;
267   return cnt;
268 }
269
270 static void
271 sk_fb_spout(struct fastbuf *f)
272 {
273   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
274   int len = f->bptr - f->buffer;
275   if (!len)
276     return;
277   int cnt = careful_write(c->sk, f->buffer, len);
278   if (cnt <= 0)
279     client_error("Write error");
280   f->bptr = f->buffer;
281 }
282
283 static void
284 init_sk_fastbufs(struct conn *c)
285 {
286   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
287
288   rf->buffer = xmalloc(1024);
289   rf->bufend = rf->buffer + 1024;
290   rf->bptr = rf->bstop = rf->buffer;
291   rf->name = "socket";
292   rf->refill = sk_fb_refill;
293
294   tf->buffer = xmalloc(1024);
295   tf->bufend = tf->buffer + 1024;
296   tf->bptr = tf->bstop = tf->buffer;
297   tf->name = rf->name;
298   tf->spout = sk_fb_spout;
299 }
300
301 static int
302 tls_fb_refill(struct fastbuf *f)
303 {
304   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
305   DBG("TLS: Refill");
306   int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
307   DBG("TLS: Received %d bytes", cnt);
308   if (cnt < 0)
309     client_error("TLS read error: %s", gnutls_strerror(cnt));
310   f->bptr = f->buffer;
311   f->bstop = f->buffer + cnt;
312   return cnt;
313 }
314
315 static void
316 tls_fb_spout(struct fastbuf *f)
317 {
318   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
319   int len = f->bptr - f->buffer;
320   if (!len)
321     return;
322   int cnt = gnutls_record_send(c->tls, f->buffer, len);
323   DBG("TLS: Sent %d bytes", cnt);
324   if (cnt <= 0)
325     client_error("TLS write error: %s", gnutls_strerror(cnt));
326   f->bptr = f->buffer;
327 }
328
329 static void
330 init_tls_fastbufs(struct conn *c)
331 {
332   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
333
334   ASSERT(rf->buffer && tf->buffer);     // Already set up for the plaintext connection
335   rf->refill = tls_fb_refill;
336   tf->spout = tls_fb_spout;
337 }
338
339 /*** CLIENT LOOP (runs in a child process) ***/
340
341 static void
342 sigalrm_handler(int sig UNUSED)
343 {
344   // We do not try to do any gracious shutdown to avoid races
345   client_error("Timed out");
346 }
347
348 static void
349 client_loop(struct conn *c)
350 {
351   setproctitle("submitd: client %s", c->ip_string);
352   log_pid = c->id;
353   init_sk_fastbufs(c);
354
355   signal(SIGPIPE, SIG_IGN);
356   struct sigaction sa = {
357     .sa_handler = sigalrm_handler
358   };
359   if (sigaction(SIGALRM, &sa, NULL) < 0)
360     die("Cannot setup SIGALRM handler: %m");
361
362   if (c->rule->plain_text)
363     {
364       bputsn(&c->tx_fb, "+OK");
365       bflush(&c->tx_fb);
366     }
367   else
368     {
369       bputsn(&c->tx_fb, "+TLS");
370       bflush(&c->tx_fb);
371       c->tls = tls_new_session(c->sk);
372       int err = gnutls_handshake(c->tls);
373       if (err < 0)
374         client_error("TLS handshake failed: %s", gnutls_strerror(err));
375       tls_log_params(c);
376       const char *cert_err = tls_verify_cert(c);
377       if (cert_err)
378         client_error("TLS certificate failure: %s", cert_err);
379       init_tls_fastbufs(c);
380     }
381
382   alarm(session_timeout);
383   if (!process_init(c))
384     log(L_ERROR, "Protocol handshake failed");
385   else
386     {
387       setproctitle("submitd: client %s (%s)", c->ip_string, c->user);
388       for (;;)
389         {
390           alarm(session_timeout);
391           if (!process_command(c))
392             break;
393         }
394     }
395
396   if (c->tls)
397     gnutls_bye(c->tls, GNUTLS_SHUT_WR);
398   close(c->sk);
399   if (c->tls)
400     gnutls_deinit(c->tls);
401 }
402
403 /*** MAIN LOOP ***/
404
405 static void
406 sigchld_handler(int sig UNUSED)
407 {
408   /* We do not need to do anything, just interrupt the accept syscall */
409 }
410
411 static void
412 reap_child(pid_t pid, int status)
413 {
414   byte msg[EXIT_STATUS_MSG_SIZE];
415   if (format_exit_status(msg, status))
416     log(L_ERROR, "Child %d %s", (int)pid, msg);
417
418   CLIST_FOR_EACH(struct conn *, c, connections)
419     if (c->pid == pid)
420       {
421         log(L_INFO, "Connection %d closed", c->id);
422         conn_free(c);
423         return;
424       }
425   log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
426 }
427
428 static int listen_sk;
429
430 static void
431 sk_init(void)
432 {
433   listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
434   if (listen_sk < 0)
435     die("socket: %m");
436   int one = 1;
437   if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
438     die("setsockopt(SO_REUSEADDR): %m");
439
440   struct sockaddr_in sa;
441   bzero(&sa, sizeof(sa));
442   sa.sin_family = AF_INET;
443   sa.sin_addr.s_addr = INADDR_ANY;
444   sa.sin_port = htons(port);
445   if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
446     die("Cannot bind to port %d: %m", port);
447   if (listen(listen_sk, 1024) < 0)
448     die("Cannot listen on port %d: %m", port);
449 }
450
451 static void
452 sk_accept(void)
453 {
454   struct sockaddr_in sa;
455   int salen = sizeof(sa);
456   int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
457   if (sk < 0)
458     {
459       if (errno == EINTR)
460         return;
461       die("accept: %m");
462     }
463
464   byte ipbuf[INET_ADDRSTRLEN];
465   inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
466   u32 addr = ntohl(sa.sin_addr.s_addr);
467   uns port = ntohs(sa.sin_port);
468   char *err;
469
470   struct access_rule *rule = lookup_rule(addr);
471   if (!rule)
472     {
473       err = "Unauthorized";
474       goto reject;
475     }
476
477   if (num_conn >= max_conn)
478     {
479       err = "Too many connections";
480       goto reject;
481     }
482
483   if (conn_count(addr) >= rule->max_conn)
484     {
485       err = "Too many connections from this address";
486       goto reject;
487     }
488
489   struct conn *c = conn_new();
490   log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
491         ipbuf, port, c->id,
492         (rule->plain_text ? "plain-text" : "TLS"),
493         (rule->allow_admin ? "admin" : "user"));
494   c->ip = addr;
495   c->ip_string = xstrdup(ipbuf);
496   c->sk = sk;
497   c->rule = rule;
498
499   c->pid = fork();
500   if (c->pid < 0)
501     {
502       conn_free(c);
503       err = "Server overloaded";
504       log(L_ERROR, "Fork failed: %m");
505       goto reject2;
506     }
507   if (!c->pid)
508     {
509       close(listen_sk);
510       client_loop(c);
511       exit(0);
512     }
513   close(sk);
514   return;
515
516 reject:
517   log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
518 reject2: ;
519   // Write an error message to the socket, but do not allow it to slow us down
520   struct linger ling = { .l_onoff=0 };
521   if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
522     log(L_ERROR, "Cannot set SO_LINGER: %m");
523   write(sk, "-", 1);
524   write(sk, err, strlen(err));
525   write(sk, "\n", 1);
526   close(sk);
527 }
528
529 int main(int argc, char **argv)
530 {
531   setproctitle_init(argc, argv);
532   cf_def_file = "config";
533   cf_declare_section("SubmitD", &submitd_conf, 0);
534   cf_declare_section("Tasks", &tasks_conf, 0);
535
536   int opt;
537   if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
538     die("This program has no options");
539
540   log(L_INFO, "Initializing TLS");
541   tls_init();
542
543   conn_init();
544   sk_init();
545   log(L_INFO, "Listening on port %d", port);
546
547   struct sigaction sa = {
548     .sa_handler = sigchld_handler
549   };
550   if (sigaction(SIGCHLD, &sa, NULL) < 0)
551     die("Cannot setup SIGCHLD handler: %m");
552
553   for (;;)
554     {
555       setproctitle("submitd: %d connections", num_conn);
556       int status;
557       pid_t pid = waitpid(-1, &status, WNOHANG);
558       if (pid > 0)
559         reap_child(pid, status);
560       else
561         sk_accept();
562     }
563 }