]> mj.ucw.cz Git - moe.git/blob - submit/submitd.c
22be143b6aad69f635c0dec767fc7306fa5493d1
[moe.git] / submit / submitd.c
1 /*
2  *  The Submit Daemon
3  *
4  *  (c) 2007 Martin Mares <mj@ucw.cz>
5  */
6
7 /*
8  *  FIXME:
9  *      - competition timeout & per-contestant exceptions
10  */
11
12 #undef LOCAL_DEBUG
13
14 #include "lib/lib.h"
15 #include "lib/conf.h"
16 #include "lib/getopt.h"
17
18 #include <string.h>
19 #include <stdlib.h>
20 #include <unistd.h>
21 #include <signal.h>
22 #include <errno.h>
23 #include <sys/types.h>
24 #include <sys/socket.h>
25 #include <sys/wait.h>
26 #include <arpa/inet.h>
27 #include <netinet/in.h>
28
29 #include "submitd.h"
30
31 /*** CONFIGURATION ***/
32
33 static uns port = 8888;
34 static uns dh_bits = 1024;
35 static uns max_conn = 10;
36 static uns session_timeout;
37 static byte *ca_cert_name = "?";
38 static byte *server_cert_name = "?";
39 static byte *server_key_name = "?";
40 static clist access_rules;
41 static uns trace_tls;
42 uns max_request_size;
43 uns max_attachment_size;
44 uns trace_commands;
45
46 static struct cf_section access_conf = {
47   CF_TYPE(struct access_rule),
48   CF_ITEMS {
49     CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
50     CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
51     CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
52     CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
53     CF_END
54   }
55 };
56
57 static byte *
58 config_init(void)
59 {
60   clist_init(&access_rules);
61   return NULL;
62 }
63
64 static struct cf_section submitd_conf = {
65   CF_INIT(config_init),
66   CF_ITEMS {
67     CF_UNS("Port", &port),
68     CF_UNS("DHBits", &dh_bits),
69     CF_UNS("MaxConn", &max_conn),
70     CF_UNS("SessionTimeout", &session_timeout),
71     CF_UNS("MaxRequestSize", &max_request_size),
72     CF_UNS("MaxAttachSize", &max_attachment_size),
73     CF_STRING("CACert", &ca_cert_name),
74     CF_STRING("ServerCert", &server_cert_name),
75     CF_STRING("ServerKey", &server_key_name),
76     CF_LIST("Access", &access_rules, &access_conf),
77     CF_UNS("TraceTLS", &trace_tls),
78     CF_UNS("TraceCommands", &trace_commands),
79     CF_END
80   }
81 };
82
83 /*** CONNECTIONS ***/
84
85 static clist connections;
86 static uns last_conn_id;
87 static uns num_conn;
88
89 static void
90 conn_init(void)
91 {
92   clist_init(&connections);
93 }
94
95 static struct conn *
96 conn_new(void)
97 {
98   struct conn *c = xmalloc_zero(sizeof(*c));
99   c->id = ++last_conn_id;
100   clist_add_tail(&connections, &c->n);
101   num_conn++;
102   return c;
103 }
104
105 static void
106 conn_free(struct conn *c)
107 {
108   clist_remove(&c->n);
109   num_conn--;
110   xfree(c);
111 }
112
113 static struct access_rule *
114 lookup_rule(u32 ip)
115 {
116   CLIST_FOR_EACH(struct access_rule *, r, access_rules)
117     if (ip_addrmask_match(&r->addrmask, ip))
118       return r;
119   return NULL;
120 }
121
122 static uns
123 conn_count(u32 ip)
124 {
125   uns cnt = 0;
126   CLIST_FOR_EACH(struct conn *, c, connections)
127     if (c->ip == ip)
128       cnt++;
129   return cnt;
130 }
131
132 /*** TLS ***/
133
134 static gnutls_certificate_credentials_t cert_cred;
135 static gnutls_dh_params_t dh_params;
136
137 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
138
139 static void
140 tls_init(void)
141 {
142   int err;
143
144   gnutls_global_init();
145   err = gnutls_certificate_allocate_credentials(&cert_cred);
146   TLS_CHECK(gnutls_certificate_allocate_credentials);
147   err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
148   if (!err)
149     die("No CA certificate found");
150   if (err < 0)
151     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
152   err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
153   if (err < 0)
154     die("Unable to load X509 key file: %s", gnutls_strerror(err));
155
156   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
157   err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
158   gnutls_certificate_set_dh_params(cert_cred, dh_params);
159 }
160
161 static gnutls_session_t
162 tls_new_session(int sk)
163 {
164   gnutls_session_t s;
165   int err;
166
167   err = gnutls_init(&s, GNUTLS_SERVER); TLS_CHECK(gnutls_init);
168   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);                 // FIXME
169   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
170   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
171   gnutls_dh_set_prime_bits(s, dh_bits);
172   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
173   return s;
174 }
175
176 static const char *
177 tls_verify_cert(struct conn *c)
178 {
179   gnutls_session_t s = c->tls;
180   uns status, num_certs;
181   int err;
182   gnutls_x509_crt_t cert;
183   const gnutls_datum_t *certs;
184
185   DBG("Verifying peer certificates");
186   err = gnutls_certificate_verify_peers2(s, &status);
187   if (err < 0)
188     return gnutls_strerror(err);
189   DBG("Verify status: %04x", status);
190   if (status & GNUTLS_CERT_INVALID)
191     return "Certificate is invalid";
192   /* XXX: We do not handle revokation. */
193   if (gnutls_certificate_type_get(s) != GNUTLS_CRT_X509)
194     return "Certificate is not X509";
195
196   err = gnutls_x509_crt_init(&cert);
197   if (err < 0)
198     return "gnutls_x509_crt_init() failed";
199   certs = gnutls_certificate_get_peers(s, &num_certs);
200   if (!certs)
201     return "No peer certificate found";
202   DBG("Got certificate list with %d peers", num_certs);
203
204   err = gnutls_x509_crt_import(cert, &certs[0], GNUTLS_X509_FMT_DER);
205   if (err < 0)
206     return "Cannot import certificate";
207   /* XXX: We do not check expiration and activation since the keys are generated for a single contest only anyway. */
208
209   byte dn[256];
210   size_t dn_len = sizeof(dn);
211   err = gnutls_x509_crt_get_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, dn, &dn_len);
212   if (err < 0)
213     return "Cannot retrieve common name";
214   if (trace_tls)
215     log(L_INFO, "Cert CN: %s", dn);
216   c->cert_name = xstrdup(dn);
217
218   /* Check certificate purpose */
219   byte purp[256];
220   int purpi = 0;
221   do
222     {
223       size_t purp_len = sizeof(purp);
224       uns crit;
225       err = gnutls_x509_crt_get_key_purpose_oid(cert, purpi++, purp, &purp_len, &crit);
226       if (err == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
227         return "Not a client certificate";
228       TLS_CHECK(gnutls_x509_crt_get_key_purpose_oid);
229     }
230   while (strcmp(purp, GNUTLS_KP_TLS_WWW_CLIENT));
231
232   DBG("Verified OK");
233   return NULL;
234 }
235
236 static void
237 tls_log_params(struct conn *c)
238 {
239   if (!trace_tls)
240     return;
241   gnutls_session_t s = c->tls;
242   const char *proto = gnutls_protocol_get_name(gnutls_protocol_get_version(s));
243   const char *kx = gnutls_kx_get_name(gnutls_kx_get(s));
244   const char *cert = gnutls_certificate_type_get_name(gnutls_certificate_type_get(s));
245   const char *comp = gnutls_compression_get_name(gnutls_compression_get(s));
246   const char *cipher = gnutls_cipher_get_name(gnutls_cipher_get(s));
247   const char *mac = gnutls_mac_get_name(gnutls_mac_get(s));
248   log(L_DEBUG, "TLS params: proto=%s kx=%s cert=%s comp=%s cipher=%s mac=%s",
249     proto, kx, cert, comp, cipher, mac);
250 }
251
252 /*** SOCKET FASTBUFS ***/
253
254 void NONRET
255 client_error(char *msg, ...)
256 {
257   va_list args;
258   va_start(args, msg);
259   vlog_msg(L_ERROR_R, msg, args);
260   exit(0);
261 }
262
263 static int
264 sk_fb_refill(struct fastbuf *f)
265 {
266   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
267   int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
268   if (cnt < 0)
269     client_error("Read error: %m");
270   f->bptr = f->buffer;
271   f->bstop = f->buffer + cnt;
272   return cnt;
273 }
274
275 static void
276 sk_fb_spout(struct fastbuf *f)
277 {
278   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
279   int len = f->bptr - f->buffer;
280   if (!len)
281     return;
282   int cnt = careful_write(c->sk, f->buffer, len);
283   if (cnt <= 0)
284     client_error("Write error");
285   f->bptr = f->buffer;
286 }
287
288 static void
289 init_sk_fastbufs(struct conn *c)
290 {
291   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
292
293   rf->buffer = xmalloc(1024);
294   rf->bufend = rf->buffer + 1024;
295   rf->bptr = rf->bstop = rf->buffer;
296   rf->name = "socket";
297   rf->refill = sk_fb_refill;
298
299   tf->buffer = xmalloc(1024);
300   tf->bufend = tf->buffer + 1024;
301   tf->bptr = tf->bstop = tf->buffer;
302   tf->name = rf->name;
303   tf->spout = sk_fb_spout;
304 }
305
306 static int
307 tls_fb_refill(struct fastbuf *f)
308 {
309   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
310   DBG("TLS: Refill");
311   int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
312   DBG("TLS: Received %d bytes", cnt);
313   if (cnt < 0)
314     client_error("TLS read error: %s", gnutls_strerror(cnt));
315   f->bptr = f->buffer;
316   f->bstop = f->buffer + cnt;
317   return cnt;
318 }
319
320 static void
321 tls_fb_spout(struct fastbuf *f)
322 {
323   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
324   int len = f->bptr - f->buffer;
325   if (!len)
326     return;
327   int cnt = gnutls_record_send(c->tls, f->buffer, len);
328   DBG("TLS: Sent %d bytes", cnt);
329   if (cnt <= 0)
330     client_error("TLS write error: %s", gnutls_strerror(cnt));
331   f->bptr = f->buffer;
332 }
333
334 static void
335 init_tls_fastbufs(struct conn *c)
336 {
337   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
338
339   ASSERT(rf->buffer && tf->buffer);     // Already set up for the plaintext connection
340   rf->refill = tls_fb_refill;
341   tf->spout = tls_fb_spout;
342 }
343
344 /*** CLIENT LOOP (runs in a child process) ***/
345
346 static void
347 sigalrm_handler(int sig UNUSED)
348 {
349   // We do not try to do any gracious shutdown to avoid races
350   client_error("Timed out");
351 }
352
353 static void
354 client_loop(struct conn *c)
355 {
356   log_pid = c->id;
357   init_sk_fastbufs(c);
358
359   signal(SIGPIPE, SIG_IGN);
360   struct sigaction sa = {
361     .sa_handler = sigalrm_handler
362   };
363   if (sigaction(SIGALRM, &sa, NULL) < 0)
364     die("Cannot setup SIGALRM handler: %m");
365
366   if (c->rule->plain_text)
367     {
368       bputsn(&c->tx_fb, "+OK");
369       bflush(&c->tx_fb);
370     }
371   else
372     {
373       bputsn(&c->tx_fb, "+TLS");
374       bflush(&c->tx_fb);
375       c->tls = tls_new_session(c->sk);
376       int err = gnutls_handshake(c->tls);
377       if (err < 0)
378         client_error("TLS handshake failed: %s", gnutls_strerror(err));
379       tls_log_params(c);
380       const char *cert_err = tls_verify_cert(c);
381       if (cert_err)
382         client_error("TLS certificate failure: %s", cert_err);
383       init_tls_fastbufs(c);
384     }
385
386   alarm(session_timeout);
387   if (!process_init(c))
388     log(L_ERROR, "Protocol handshake failed");
389   else
390     for (;;)
391       {
392         alarm(session_timeout);
393         if (!process_command(c))
394           break;
395       }
396
397   if (c->tls)
398     gnutls_bye(c->tls, GNUTLS_SHUT_WR);
399   close(c->sk);
400   if (c->tls)
401     gnutls_deinit(c->tls);
402 }
403
404 /*** MAIN LOOP ***/
405
406 static void
407 sigchld_handler(int sig UNUSED)
408 {
409   /* We do not need to do anything, just interrupt the accept syscall */
410 }
411
412 static void
413 reap_child(pid_t pid, int status)
414 {
415   byte msg[EXIT_STATUS_MSG_SIZE];
416   if (format_exit_status(msg, status))
417     log(L_ERROR, "Child %d %s", (int)pid, msg);
418
419   CLIST_FOR_EACH(struct conn *, c, connections)
420     if (c->pid == pid)
421       {
422         log(L_INFO, "Connection %d closed", c->id);
423         conn_free(c);
424         return;
425       }
426   log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
427 }
428
429 static int listen_sk;
430
431 static void
432 sk_init(void)
433 {
434   listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
435   if (listen_sk < 0)
436     die("socket: %m");
437   int one = 1;
438   if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
439     die("setsockopt(SO_REUSEADDR): %m");
440
441   struct sockaddr_in sa;
442   bzero(&sa, sizeof(sa));
443   sa.sin_family = AF_INET;
444   sa.sin_addr.s_addr = INADDR_ANY;
445   sa.sin_port = htons(port);
446   if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
447     die("Cannot bind to port %d: %m", port);
448   if (listen(listen_sk, 1024) < 0)
449     die("Cannot listen on port %d: %m", port);
450 }
451
452 static void
453 sk_accept(void)
454 {
455   struct sockaddr_in sa;
456   int salen = sizeof(sa);
457   int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
458   if (sk < 0)
459     {
460       if (errno == EINTR)
461         return;
462       die("accept: %m");
463     }
464
465   byte ipbuf[INET_ADDRSTRLEN];
466   inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
467   u32 addr = ntohl(sa.sin_addr.s_addr);
468   uns port = ntohs(sa.sin_port);
469   char *err;
470
471   struct access_rule *rule = lookup_rule(addr);
472   if (!rule)
473     {
474       err = "Unauthorized";
475       goto reject;
476     }
477
478   if (num_conn >= max_conn)
479     {
480       err = "Too many connections";
481       goto reject;
482     }
483
484   if (conn_count(addr) >= rule->max_conn)
485     {
486       err = "Too many connections from this address";
487       goto reject;
488     }
489
490   struct conn *c = conn_new();
491   log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
492         ipbuf, port, c->id,
493         (rule->plain_text ? "plain-text" : "TLS"),
494         (rule->allow_admin ? "admin" : "user"));
495   c->ip = addr;
496   c->sk = sk;
497   c->rule = rule;
498
499   c->pid = fork();
500   if (c->pid < 0)
501     {
502       conn_free(c);
503       err = "Server overloaded";
504       log(L_ERROR, "Fork failed: %m");
505       goto reject2;
506     }
507   if (!c->pid)
508     {
509       close(listen_sk);
510       client_loop(c);
511       exit(0);
512     }
513   close(sk);
514   return;
515
516 reject:
517   log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
518 reject2: ;
519   // Write an error message to the socket, but do not allow it to slow us down
520   struct linger ling = { .l_onoff=0 };
521   if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
522     log(L_ERROR, "Cannot set SO_LINGER: %m");
523   write(sk, "-", 1);
524   write(sk, err, strlen(err));
525   write(sk, "\n", 1);
526   close(sk);
527 }
528
529 int main(int argc, char **argv)
530 {
531   setproctitle_init(argc, argv);
532   cf_def_file = "config";
533   cf_declare_section("SubmitD", &submitd_conf, 0);
534
535   int opt;
536   if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
537     die("This program has no options");
538
539   log(L_INFO, "Initializing TLS");
540   tls_init();
541
542   conn_init();
543   sk_init();
544   log(L_INFO, "Listening on port %d", port);
545
546   struct sigaction sa = {
547     .sa_handler = sigchld_handler
548   };
549   if (sigaction(SIGCHLD, &sa, NULL) < 0)
550     die("Cannot setup SIGCHLD handler: %m");
551
552   for (;;)
553     {
554       int status;
555       pid_t pid = waitpid(-1, &status, WNOHANG);
556       if (pid > 0)
557         reap_child(pid, status);
558       else
559         sk_accept();
560     }
561 }