]> mj.ucw.cz Git - moe.git/blob - submit/submitd.c
First parts of submit command.
[moe.git] / submit / submitd.c
1 /*
2  *  The Submit Daemon
3  *
4  *  (c) 2007 Martin Mares <mj@ucw.cz>
5  */
6
7 /*
8  *  FIXME:
9  *      - competition timeout & per-contestant exceptions
10  *      - open-data problems
11  */
12
13 #undef LOCAL_DEBUG
14
15 #include "lib/lib.h"
16 #include "lib/conf.h"
17 #include "lib/getopt.h"
18
19 #include <string.h>
20 #include <stdlib.h>
21 #include <unistd.h>
22 #include <signal.h>
23 #include <errno.h>
24 #include <sys/types.h>
25 #include <sys/socket.h>
26 #include <sys/wait.h>
27 #include <arpa/inet.h>
28 #include <netinet/in.h>
29
30 #include "submitd.h"
31
32 /*** CONFIGURATION ***/
33
34 static uns port = 8888;
35 static uns dh_bits = 1024;
36 static uns max_conn = 10;
37 static uns session_timeout;
38 static byte *ca_cert_name = "?";
39 static byte *server_cert_name = "?";
40 static byte *server_key_name = "?";
41 static clist access_rules;
42 static uns trace_tls;
43 uns max_request_size;
44 uns max_attachment_size;
45 uns trace_commands;
46
47 static struct cf_section access_conf = {
48   CF_TYPE(struct access_rule),
49   CF_ITEMS {
50     CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
51     CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
52     CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
53     CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
54     CF_END
55   }
56 };
57
58 static byte *
59 config_init(void)
60 {
61   clist_init(&access_rules);
62   return NULL;
63 }
64
65 static struct cf_section submitd_conf = {
66   CF_INIT(config_init),
67   CF_ITEMS {
68     CF_UNS("Port", &port),
69     CF_UNS("DHBits", &dh_bits),
70     CF_UNS("MaxConn", &max_conn),
71     CF_UNS("SessionTimeout", &session_timeout),
72     CF_UNS("MaxRequestSize", &max_request_size),
73     CF_UNS("MaxAttachSize", &max_attachment_size),
74     CF_STRING("CACert", &ca_cert_name),
75     CF_STRING("ServerCert", &server_cert_name),
76     CF_STRING("ServerKey", &server_key_name),
77     CF_LIST("Access", &access_rules, &access_conf),
78     CF_UNS("TraceTLS", &trace_tls),
79     CF_UNS("TraceCommands", &trace_commands),
80     CF_END
81   }
82 };
83
84 /*** CONNECTIONS ***/
85
86 static clist connections;
87 static uns last_conn_id;
88 static uns num_conn;
89
90 static void
91 conn_init(void)
92 {
93   clist_init(&connections);
94 }
95
96 static struct conn *
97 conn_new(void)
98 {
99   struct conn *c = xmalloc_zero(sizeof(*c));
100   c->id = ++last_conn_id;
101   clist_add_tail(&connections, &c->n);
102   num_conn++;
103   return c;
104 }
105
106 static void
107 conn_free(struct conn *c)
108 {
109   xfree(c->ip_string);
110   xfree(c->cert_name);
111   clist_remove(&c->n);
112   num_conn--;
113   xfree(c);
114 }
115
116 static struct access_rule *
117 lookup_rule(u32 ip)
118 {
119   CLIST_FOR_EACH(struct access_rule *, r, access_rules)
120     if (ip_addrmask_match(&r->addrmask, ip))
121       return r;
122   return NULL;
123 }
124
125 static uns
126 conn_count(u32 ip)
127 {
128   uns cnt = 0;
129   CLIST_FOR_EACH(struct conn *, c, connections)
130     if (c->ip == ip)
131       cnt++;
132   return cnt;
133 }
134
135 /*** TLS ***/
136
137 static gnutls_certificate_credentials_t cert_cred;
138 static gnutls_dh_params_t dh_params;
139
140 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
141
142 static void
143 tls_init(void)
144 {
145   int err;
146
147   gnutls_global_init();
148   err = gnutls_certificate_allocate_credentials(&cert_cred);
149   TLS_CHECK(gnutls_certificate_allocate_credentials);
150   err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
151   if (!err)
152     die("No CA certificate found");
153   if (err < 0)
154     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
155   err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
156   if (err < 0)
157     die("Unable to load X509 key file: %s", gnutls_strerror(err));
158
159   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
160   err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
161   gnutls_certificate_set_dh_params(cert_cred, dh_params);
162 }
163
164 static gnutls_session_t
165 tls_new_session(int sk)
166 {
167   gnutls_session_t s;
168   int err;
169
170   err = gnutls_init(&s, GNUTLS_SERVER); TLS_CHECK(gnutls_init);
171   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);                 // FIXME
172   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
173   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
174   gnutls_dh_set_prime_bits(s, dh_bits);
175   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
176   return s;
177 }
178
179 static const char *
180 tls_verify_cert(struct conn *c)
181 {
182   gnutls_session_t s = c->tls;
183   uns status, num_certs;
184   int err;
185   gnutls_x509_crt_t cert;
186   const gnutls_datum_t *certs;
187
188   DBG("Verifying peer certificates");
189   err = gnutls_certificate_verify_peers2(s, &status);
190   if (err < 0)
191     return gnutls_strerror(err);
192   DBG("Verify status: %04x", status);
193   if (status & GNUTLS_CERT_INVALID)
194     return "Certificate is invalid";
195   /* XXX: We do not handle revokation. */
196   if (gnutls_certificate_type_get(s) != GNUTLS_CRT_X509)
197     return "Certificate is not X509";
198
199   err = gnutls_x509_crt_init(&cert);
200   if (err < 0)
201     return "gnutls_x509_crt_init() failed";
202   certs = gnutls_certificate_get_peers(s, &num_certs);
203   if (!certs)
204     return "No peer certificate found";
205   DBG("Got certificate list with %d peers", num_certs);
206
207   err = gnutls_x509_crt_import(cert, &certs[0], GNUTLS_X509_FMT_DER);
208   if (err < 0)
209     return "Cannot import certificate";
210   /* XXX: We do not check expiration and activation since the keys are generated for a single contest only anyway. */
211
212   byte dn[256];
213   size_t dn_len = sizeof(dn);
214   err = gnutls_x509_crt_get_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, dn, &dn_len);
215   if (err < 0)
216     return "Cannot retrieve common name";
217   if (trace_tls)
218     log(L_INFO, "Cert CN: %s", dn);
219   c->cert_name = xstrdup(dn);
220
221   /* Check certificate purpose */
222   byte purp[256];
223   int purpi = 0;
224   do
225     {
226       size_t purp_len = sizeof(purp);
227       uns crit;
228       err = gnutls_x509_crt_get_key_purpose_oid(cert, purpi++, purp, &purp_len, &crit);
229       if (err == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
230         return "Not a client certificate";
231       TLS_CHECK(gnutls_x509_crt_get_key_purpose_oid);
232     }
233   while (strcmp(purp, GNUTLS_KP_TLS_WWW_CLIENT));
234
235   DBG("Verified OK");
236   return NULL;
237 }
238
239 static void
240 tls_log_params(struct conn *c)
241 {
242   if (!trace_tls)
243     return;
244   gnutls_session_t s = c->tls;
245   const char *proto = gnutls_protocol_get_name(gnutls_protocol_get_version(s));
246   const char *kx = gnutls_kx_get_name(gnutls_kx_get(s));
247   const char *cert = gnutls_certificate_type_get_name(gnutls_certificate_type_get(s));
248   const char *comp = gnutls_compression_get_name(gnutls_compression_get(s));
249   const char *cipher = gnutls_cipher_get_name(gnutls_cipher_get(s));
250   const char *mac = gnutls_mac_get_name(gnutls_mac_get(s));
251   log(L_DEBUG, "TLS params: proto=%s kx=%s cert=%s comp=%s cipher=%s mac=%s",
252     proto, kx, cert, comp, cipher, mac);
253 }
254
255 /*** FASTBUFS OVER SOCKETS AND TLS ***/
256
257 void NONRET                             // Fatal protocol violation
258 client_error(char *msg, ...)
259 {
260   va_list args;
261   va_start(args, msg);
262   vlog_msg(L_ERROR_R, msg, args);
263   exit(0);
264 }
265
266 static int
267 sk_fb_refill(struct fastbuf *f)
268 {
269   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
270   int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
271   if (cnt < 0)
272     client_error("Read error: %m");
273   f->bptr = f->buffer;
274   f->bstop = f->buffer + cnt;
275   return cnt;
276 }
277
278 static void
279 sk_fb_spout(struct fastbuf *f)
280 {
281   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
282   int len = f->bptr - f->buffer;
283   if (!len)
284     return;
285   int cnt = careful_write(c->sk, f->buffer, len);
286   if (cnt <= 0)
287     client_error("Write error");
288   f->bptr = f->buffer;
289 }
290
291 static void
292 init_sk_fastbufs(struct conn *c)
293 {
294   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
295
296   rf->buffer = xmalloc(1024);
297   rf->bufend = rf->buffer + 1024;
298   rf->bptr = rf->bstop = rf->buffer;
299   rf->name = "socket";
300   rf->refill = sk_fb_refill;
301
302   tf->buffer = xmalloc(1024);
303   tf->bufend = tf->buffer + 1024;
304   tf->bptr = tf->bstop = tf->buffer;
305   tf->name = rf->name;
306   tf->spout = sk_fb_spout;
307 }
308
309 static int
310 tls_fb_refill(struct fastbuf *f)
311 {
312   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
313   DBG("TLS: Refill");
314   int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
315   DBG("TLS: Received %d bytes", cnt);
316   if (cnt < 0)
317     client_error("TLS read error: %s", gnutls_strerror(cnt));
318   f->bptr = f->buffer;
319   f->bstop = f->buffer + cnt;
320   return cnt;
321 }
322
323 static void
324 tls_fb_spout(struct fastbuf *f)
325 {
326   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
327   int len = f->bptr - f->buffer;
328   if (!len)
329     return;
330   int cnt = gnutls_record_send(c->tls, f->buffer, len);
331   DBG("TLS: Sent %d bytes", cnt);
332   if (cnt <= 0)
333     client_error("TLS write error: %s", gnutls_strerror(cnt));
334   f->bptr = f->buffer;
335 }
336
337 static void
338 init_tls_fastbufs(struct conn *c)
339 {
340   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
341
342   ASSERT(rf->buffer && tf->buffer);     // Already set up for the plaintext connection
343   rf->refill = tls_fb_refill;
344   tf->spout = tls_fb_spout;
345 }
346
347 /*** CLIENT LOOP (runs in a child process) ***/
348
349 static void
350 sigalrm_handler(int sig UNUSED)
351 {
352   // We do not try to do any gracious shutdown to avoid races
353   client_error("Timed out");
354 }
355
356 static void
357 client_loop(struct conn *c)
358 {
359   setproctitle("submitd: client %s", c->ip_string);
360   log_pid = c->id;
361   init_sk_fastbufs(c);
362
363   signal(SIGPIPE, SIG_IGN);
364   struct sigaction sa = {
365     .sa_handler = sigalrm_handler
366   };
367   if (sigaction(SIGALRM, &sa, NULL) < 0)
368     die("Cannot setup SIGALRM handler: %m");
369
370   if (c->rule->plain_text)
371     {
372       bputsn(&c->tx_fb, "+OK");
373       bflush(&c->tx_fb);
374     }
375   else
376     {
377       bputsn(&c->tx_fb, "+TLS");
378       bflush(&c->tx_fb);
379       c->tls = tls_new_session(c->sk);
380       int err = gnutls_handshake(c->tls);
381       if (err < 0)
382         client_error("TLS handshake failed: %s", gnutls_strerror(err));
383       tls_log_params(c);
384       const char *cert_err = tls_verify_cert(c);
385       if (cert_err)
386         client_error("TLS certificate failure: %s", cert_err);
387       init_tls_fastbufs(c);
388     }
389
390   alarm(session_timeout);
391   if (!process_init(c))
392     log(L_ERROR, "Protocol handshake failed");
393   else
394     {
395       setproctitle("submitd: client %s (%s)", c->ip_string, c->user);
396       for (;;)
397         {
398           alarm(session_timeout);
399           if (!process_command(c))
400             break;
401         }
402     }
403
404   if (c->tls)
405     gnutls_bye(c->tls, GNUTLS_SHUT_WR);
406   close(c->sk);
407   if (c->tls)
408     gnutls_deinit(c->tls);
409 }
410
411 /*** MAIN LOOP ***/
412
413 static void
414 sigchld_handler(int sig UNUSED)
415 {
416   /* We do not need to do anything, just interrupt the accept syscall */
417 }
418
419 static void
420 reap_child(pid_t pid, int status)
421 {
422   byte msg[EXIT_STATUS_MSG_SIZE];
423   if (format_exit_status(msg, status))
424     log(L_ERROR, "Child %d %s", (int)pid, msg);
425
426   CLIST_FOR_EACH(struct conn *, c, connections)
427     if (c->pid == pid)
428       {
429         log(L_INFO, "Connection %d closed", c->id);
430         conn_free(c);
431         return;
432       }
433   log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
434 }
435
436 static int listen_sk;
437
438 static void
439 sk_init(void)
440 {
441   listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
442   if (listen_sk < 0)
443     die("socket: %m");
444   int one = 1;
445   if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
446     die("setsockopt(SO_REUSEADDR): %m");
447
448   struct sockaddr_in sa;
449   bzero(&sa, sizeof(sa));
450   sa.sin_family = AF_INET;
451   sa.sin_addr.s_addr = INADDR_ANY;
452   sa.sin_port = htons(port);
453   if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
454     die("Cannot bind to port %d: %m", port);
455   if (listen(listen_sk, 1024) < 0)
456     die("Cannot listen on port %d: %m", port);
457 }
458
459 static void
460 sk_accept(void)
461 {
462   struct sockaddr_in sa;
463   int salen = sizeof(sa);
464   int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
465   if (sk < 0)
466     {
467       if (errno == EINTR)
468         return;
469       die("accept: %m");
470     }
471
472   byte ipbuf[INET_ADDRSTRLEN];
473   inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
474   u32 addr = ntohl(sa.sin_addr.s_addr);
475   uns port = ntohs(sa.sin_port);
476   char *err;
477
478   struct access_rule *rule = lookup_rule(addr);
479   if (!rule)
480     {
481       err = "Unauthorized";
482       goto reject;
483     }
484
485   if (num_conn >= max_conn)
486     {
487       err = "Too many connections";
488       goto reject;
489     }
490
491   if (conn_count(addr) >= rule->max_conn)
492     {
493       err = "Too many connections from this address";
494       goto reject;
495     }
496
497   struct conn *c = conn_new();
498   log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
499         ipbuf, port, c->id,
500         (rule->plain_text ? "plain-text" : "TLS"),
501         (rule->allow_admin ? "admin" : "user"));
502   c->ip = addr;
503   c->ip_string = xstrdup(ipbuf);
504   c->sk = sk;
505   c->rule = rule;
506
507   c->pid = fork();
508   if (c->pid < 0)
509     {
510       conn_free(c);
511       err = "Server overloaded";
512       log(L_ERROR, "Fork failed: %m");
513       goto reject2;
514     }
515   if (!c->pid)
516     {
517       close(listen_sk);
518       client_loop(c);
519       exit(0);
520     }
521   close(sk);
522   return;
523
524 reject:
525   log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
526 reject2: ;
527   // Write an error message to the socket, but do not allow it to slow us down
528   struct linger ling = { .l_onoff=0 };
529   if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
530     log(L_ERROR, "Cannot set SO_LINGER: %m");
531   write(sk, "-", 1);
532   write(sk, err, strlen(err));
533   write(sk, "\n", 1);
534   close(sk);
535 }
536
537 int main(int argc, char **argv)
538 {
539   setproctitle_init(argc, argv);
540   cf_def_file = "config";
541   cf_declare_section("SubmitD", &submitd_conf, 0);
542   cf_declare_section("Tasks", &tasks_conf, 0);
543
544   int opt;
545   if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
546     die("This program has no options");
547
548   log(L_INFO, "Initializing TLS");
549   tls_init();
550
551   conn_init();
552   sk_init();
553   log(L_INFO, "Listening on port %d", port);
554
555   struct sigaction sa = {
556     .sa_handler = sigchld_handler
557   };
558   if (sigaction(SIGCHLD, &sa, NULL) < 0)
559     die("Cannot setup SIGCHLD handler: %m");
560
561   for (;;)
562     {
563       setproctitle("submitd: %d connections", num_conn);
564       int status;
565       pid_t pid = waitpid(-1, &status, WNOHANG);
566       if (pid > 0)
567         reap_child(pid, status);
568       else
569         sk_accept();
570     }
571 }