]> mj.ucw.cz Git - moe.git/blob - submit/submitd.c
More TODO bits.
[moe.git] / submit / submitd.c
1 /*
2  *  The Submit Daemon
3  *
4  *  (c) 2007 Martin Mares <mj@ucw.cz>
5  */
6
7 #undef LOCAL_DEBUG
8
9 #include "lib/lib.h"
10 #include "lib/conf.h"
11 #include "lib/getopt.h"
12
13 #include <string.h>
14 #include <stdlib.h>
15 #include <unistd.h>
16 #include <signal.h>
17 #include <errno.h>
18 #include <sys/types.h>
19 #include <sys/socket.h>
20 #include <sys/wait.h>
21 #include <arpa/inet.h>
22 #include <netinet/in.h>
23
24 #include "submitd.h"
25
26 /*** CONFIGURATION ***/
27
28 static uns port = 8888;
29 static uns dh_bits = 1024;
30 static uns max_conn = 10;
31 static uns session_timeout;
32 static byte *ca_cert_name = "?";
33 static byte *server_cert_name = "?";
34 static byte *server_key_name = "?";
35 static clist access_rules;
36 static uns trace_tls;
37 uns max_request_size;
38 uns max_attachment_size;
39 uns trace_commands;
40
41 static struct cf_section access_conf = {
42   CF_TYPE(struct access_rule),
43   CF_ITEMS {
44     CF_USER("IP", PTR_TO(struct access_rule, addrmask), &ip_addrmask_type),
45     CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
46     CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
47     CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
48     CF_END
49   }
50 };
51
52 static struct cf_section submitd_conf = {
53   CF_ITEMS {
54     CF_UNS("Port", &port),
55     CF_UNS("DHBits", &dh_bits),
56     CF_UNS("MaxConn", &max_conn),
57     CF_UNS("SessionTimeout", &session_timeout),
58     CF_UNS("MaxRequestSize", &max_request_size),
59     CF_UNS("MaxAttachSize", &max_attachment_size),
60     CF_STRING("CACert", &ca_cert_name),
61     CF_STRING("ServerCert", &server_cert_name),
62     CF_STRING("ServerKey", &server_key_name),
63     CF_LIST("Access", &access_rules, &access_conf),
64     CF_UNS("TraceTLS", &trace_tls),
65     CF_UNS("TraceCommands", &trace_commands),
66     CF_END
67   }
68 };
69
70 /*** CONNECTIONS ***/
71
72 static clist connections;
73 static uns last_conn_id;
74 static uns num_conn;
75
76 static void
77 conn_init(void)
78 {
79   clist_init(&connections);
80 }
81
82 static struct conn *
83 conn_new(void)
84 {
85   struct conn *c = xmalloc_zero(sizeof(*c));
86   c->id = ++last_conn_id;
87   clist_add_tail(&connections, &c->n);
88   num_conn++;
89   return c;
90 }
91
92 static void
93 conn_free(struct conn *c)
94 {
95   xfree(c->ip_string);
96   xfree(c->cert_name);
97   clist_remove(&c->n);
98   num_conn--;
99   xfree(c);
100 }
101
102 static struct access_rule *
103 lookup_rule(u32 ip)
104 {
105   CLIST_FOR_EACH(struct access_rule *, r, access_rules)
106     if (ip_addrmask_match(&r->addrmask, ip))
107       return r;
108   return NULL;
109 }
110
111 static uns
112 conn_count(u32 ip)
113 {
114   uns cnt = 0;
115   CLIST_FOR_EACH(struct conn *, c, connections)
116     if (c->ip == ip)
117       cnt++;
118   return cnt;
119 }
120
121 /*** TLS ***/
122
123 static gnutls_certificate_credentials_t cert_cred;
124 static gnutls_dh_params_t dh_params;
125
126 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
127
128 static void
129 tls_init(void)
130 {
131   int err;
132
133   gnutls_global_init();
134   err = gnutls_certificate_allocate_credentials(&cert_cred);
135   TLS_CHECK(gnutls_certificate_allocate_credentials);
136   err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
137   if (!err)
138     die("No CA certificate found");
139   if (err < 0)
140     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
141   err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
142   if (err < 0)
143     die("Unable to load X509 key file: %s", gnutls_strerror(err));
144
145   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
146   err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
147   gnutls_certificate_set_dh_params(cert_cred, dh_params);
148 }
149
150 static gnutls_session_t
151 tls_new_session(int sk)
152 {
153   gnutls_session_t s;
154   int err;
155
156   err = gnutls_init(&s, GNUTLS_SERVER); TLS_CHECK(gnutls_init);
157   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);                 // FIXME
158   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
159   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
160   gnutls_dh_set_prime_bits(s, dh_bits);
161   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
162   return s;
163 }
164
165 static const char *
166 tls_verify_cert(struct conn *c)
167 {
168   gnutls_session_t s = c->tls;
169   uns status, num_certs;
170   int err;
171   gnutls_x509_crt_t cert;
172   const gnutls_datum_t *certs;
173
174   DBG("Verifying peer certificates");
175   err = gnutls_certificate_verify_peers2(s, &status);
176   if (err < 0)
177     return gnutls_strerror(err);
178   DBG("Verify status: %04x", status);
179   if (status & GNUTLS_CERT_INVALID)
180     return "Certificate is invalid";
181   /* XXX: We do not handle revokation. */
182   if (gnutls_certificate_type_get(s) != GNUTLS_CRT_X509)
183     return "Certificate is not X509";
184
185   err = gnutls_x509_crt_init(&cert);
186   if (err < 0)
187     return "gnutls_x509_crt_init() failed";
188   certs = gnutls_certificate_get_peers(s, &num_certs);
189   if (!certs)
190     return "No peer certificate found";
191   DBG("Got certificate list with %d peers", num_certs);
192
193   err = gnutls_x509_crt_import(cert, &certs[0], GNUTLS_X509_FMT_DER);
194   if (err < 0)
195     return "Cannot import certificate";
196   /* XXX: We do not check expiration and activation since the keys are generated for a single contest only anyway. */
197
198   byte dn[256];
199   size_t dn_len = sizeof(dn);
200   err = gnutls_x509_crt_get_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, dn, &dn_len);
201   if (err < 0)
202     return "Cannot retrieve common name";
203   if (trace_tls)
204     log(L_INFO, "Cert CN: %s", dn);
205   c->cert_name = xstrdup(dn);
206
207   /* Check certificate purpose */
208   byte purp[256];
209   int purpi = 0;
210   do
211     {
212       size_t purp_len = sizeof(purp);
213       uns crit;
214       err = gnutls_x509_crt_get_key_purpose_oid(cert, purpi++, purp, &purp_len, &crit);
215       if (err == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
216         return "Not a client certificate";
217       TLS_CHECK(gnutls_x509_crt_get_key_purpose_oid);
218     }
219   while (strcmp(purp, GNUTLS_KP_TLS_WWW_CLIENT));
220
221   DBG("Verified OK");
222   return NULL;
223 }
224
225 static void
226 tls_log_params(struct conn *c)
227 {
228   if (!trace_tls)
229     return;
230   gnutls_session_t s = c->tls;
231   const char *proto = gnutls_protocol_get_name(gnutls_protocol_get_version(s));
232   const char *kx = gnutls_kx_get_name(gnutls_kx_get(s));
233   const char *cert = gnutls_certificate_type_get_name(gnutls_certificate_type_get(s));
234   const char *comp = gnutls_compression_get_name(gnutls_compression_get(s));
235   const char *cipher = gnutls_cipher_get_name(gnutls_cipher_get(s));
236   const char *mac = gnutls_mac_get_name(gnutls_mac_get(s));
237   log(L_DEBUG, "TLS params: proto=%s kx=%s cert=%s comp=%s cipher=%s mac=%s",
238     proto, kx, cert, comp, cipher, mac);
239 }
240
241 /*** FASTBUFS OVER SOCKETS AND TLS ***/
242
243 void NONRET                             // Fatal protocol violation
244 client_error(char *msg, ...)
245 {
246   va_list args;
247   va_start(args, msg);
248   vlog_msg(L_ERROR_R, msg, args);
249   exit(0);
250 }
251
252 static int
253 sk_fb_refill(struct fastbuf *f)
254 {
255   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
256   int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
257   if (cnt < 0)
258     client_error("Read error: %m");
259   f->bptr = f->buffer;
260   f->bstop = f->buffer + cnt;
261   return cnt;
262 }
263
264 static void
265 sk_fb_spout(struct fastbuf *f)
266 {
267   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
268   int len = f->bptr - f->buffer;
269   if (!len)
270     return;
271   int cnt = careful_write(c->sk, f->buffer, len);
272   if (cnt <= 0)
273     client_error("Write error");
274   f->bptr = f->buffer;
275 }
276
277 static void
278 init_sk_fastbufs(struct conn *c)
279 {
280   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
281
282   rf->buffer = xmalloc(1024);
283   rf->bufend = rf->buffer + 1024;
284   rf->bptr = rf->bstop = rf->buffer;
285   rf->name = "socket";
286   rf->refill = sk_fb_refill;
287
288   tf->buffer = xmalloc(1024);
289   tf->bufend = tf->buffer + 1024;
290   tf->bptr = tf->bstop = tf->buffer;
291   tf->name = rf->name;
292   tf->spout = sk_fb_spout;
293 }
294
295 static int
296 tls_fb_refill(struct fastbuf *f)
297 {
298   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
299   DBG("TLS: Refill");
300   int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
301   DBG("TLS: Received %d bytes", cnt);
302   if (cnt < 0)
303     client_error("TLS read error: %s", gnutls_strerror(cnt));
304   f->bptr = f->buffer;
305   f->bstop = f->buffer + cnt;
306   return cnt;
307 }
308
309 static void
310 tls_fb_spout(struct fastbuf *f)
311 {
312   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
313   int len = f->bptr - f->buffer;
314   if (!len)
315     return;
316   int cnt = gnutls_record_send(c->tls, f->buffer, len);
317   DBG("TLS: Sent %d bytes", cnt);
318   if (cnt <= 0)
319     client_error("TLS write error: %s", gnutls_strerror(cnt));
320   f->bptr = f->buffer;
321 }
322
323 static void
324 init_tls_fastbufs(struct conn *c)
325 {
326   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
327
328   ASSERT(rf->buffer && tf->buffer);     // Already set up for the plaintext connection
329   rf->refill = tls_fb_refill;
330   tf->spout = tls_fb_spout;
331 }
332
333 /*** CLIENT LOOP (runs in a child process) ***/
334
335 static void
336 sigalrm_handler(int sig UNUSED)
337 {
338   // We do not try to do any gracious shutdown to avoid races
339   client_error("Timed out");
340 }
341
342 static void
343 client_loop(struct conn *c)
344 {
345   setproctitle("submitd: client %s", c->ip_string);
346   log_pid = c->id;
347   init_sk_fastbufs(c);
348
349   signal(SIGPIPE, SIG_IGN);
350   struct sigaction sa = {
351     .sa_handler = sigalrm_handler
352   };
353   if (sigaction(SIGALRM, &sa, NULL) < 0)
354     die("Cannot setup SIGALRM handler: %m");
355
356   if (c->rule->plain_text)
357     {
358       bputsn(&c->tx_fb, "+OK");
359       bflush(&c->tx_fb);
360     }
361   else
362     {
363       bputsn(&c->tx_fb, "+TLS");
364       bflush(&c->tx_fb);
365       c->tls = tls_new_session(c->sk);
366       int err = gnutls_handshake(c->tls);
367       if (err < 0)
368         client_error("TLS handshake failed: %s", gnutls_strerror(err));
369       tls_log_params(c);
370       const char *cert_err = tls_verify_cert(c);
371       if (cert_err)
372         client_error("TLS certificate failure: %s", cert_err);
373       init_tls_fastbufs(c);
374     }
375
376   alarm(session_timeout);
377   if (!process_init(c))
378     log(L_ERROR, "Protocol handshake failed");
379   else
380     {
381       setproctitle("submitd: client %s (%s)", c->ip_string, c->user);
382       for (;;)
383         {
384           alarm(session_timeout);
385           if (!process_command(c))
386             break;
387         }
388     }
389
390   if (c->tls)
391     gnutls_bye(c->tls, GNUTLS_SHUT_WR);
392   close(c->sk);
393   if (c->tls)
394     gnutls_deinit(c->tls);
395 }
396
397 /*** MAIN LOOP ***/
398
399 static void
400 sigchld_handler(int sig UNUSED)
401 {
402   /* We do not need to do anything, just interrupt the accept syscall */
403 }
404
405 static void
406 reap_child(pid_t pid, int status)
407 {
408   byte msg[EXIT_STATUS_MSG_SIZE];
409   if (format_exit_status(msg, status))
410     log(L_ERROR, "Child %d %s", (int)pid, msg);
411
412   CLIST_FOR_EACH(struct conn *, c, connections)
413     if (c->pid == pid)
414       {
415         log(L_INFO, "Connection %d closed", c->id);
416         conn_free(c);
417         return;
418       }
419   log(L_ERROR, "Cannot find connection for child process %d", (int)pid);
420 }
421
422 static int listen_sk;
423
424 static void
425 sk_init(void)
426 {
427   listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
428   if (listen_sk < 0)
429     die("socket: %m");
430   int one = 1;
431   if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
432     die("setsockopt(SO_REUSEADDR): %m");
433
434   struct sockaddr_in sa;
435   bzero(&sa, sizeof(sa));
436   sa.sin_family = AF_INET;
437   sa.sin_addr.s_addr = INADDR_ANY;
438   sa.sin_port = htons(port);
439   if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
440     die("Cannot bind to port %d: %m", port);
441   if (listen(listen_sk, 1024) < 0)
442     die("Cannot listen on port %d: %m", port);
443 }
444
445 static void
446 sk_accept(void)
447 {
448   struct sockaddr_in sa;
449   int salen = sizeof(sa);
450   int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
451   if (sk < 0)
452     {
453       if (errno == EINTR)
454         return;
455       die("accept: %m");
456     }
457
458   byte ipbuf[INET_ADDRSTRLEN];
459   inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
460   u32 addr = ntohl(sa.sin_addr.s_addr);
461   uns port = ntohs(sa.sin_port);
462   char *err;
463
464   struct access_rule *rule = lookup_rule(addr);
465   if (!rule)
466     {
467       err = "Unauthorized";
468       goto reject;
469     }
470
471   if (num_conn >= max_conn)
472     {
473       err = "Too many connections";
474       goto reject;
475     }
476
477   if (conn_count(addr) >= rule->max_conn)
478     {
479       err = "Too many connections from this address";
480       goto reject;
481     }
482
483   struct conn *c = conn_new();
484   log(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
485         ipbuf, port, c->id,
486         (rule->plain_text ? "plain-text" : "TLS"),
487         (rule->allow_admin ? "admin" : "user"));
488   c->ip = addr;
489   c->ip_string = xstrdup(ipbuf);
490   c->sk = sk;
491   c->rule = rule;
492
493   c->pid = fork();
494   if (c->pid < 0)
495     {
496       conn_free(c);
497       err = "Server overloaded";
498       log(L_ERROR, "Fork failed: %m");
499       goto reject2;
500     }
501   if (!c->pid)
502     {
503       close(listen_sk);
504       client_loop(c);
505       exit(0);
506     }
507   close(sk);
508   return;
509
510 reject:
511   log(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
512 reject2: ;
513   // Write an error message to the socket, but do not allow it to slow us down
514   struct linger ling = { .l_onoff=0 };
515   if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
516     log(L_ERROR, "Cannot set SO_LINGER: %m");
517   write(sk, "-", 1);
518   write(sk, err, strlen(err));
519   write(sk, "\n", 1);
520   close(sk);
521 }
522
523 int main(int argc, char **argv)
524 {
525   setproctitle_init(argc, argv);
526   cf_def_file = "config";
527   cf_declare_section("SubmitD", &submitd_conf, 0);
528   cf_declare_section("Tasks", &tasks_conf, 0);
529
530   int opt;
531   if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
532     die("This program has no options");
533
534   log(L_INFO, "Initializing TLS");
535   tls_init();
536
537   conn_init();
538   sk_init();
539   log(L_INFO, "Listening on port %d", port);
540
541   struct sigaction sa = {
542     .sa_handler = sigchld_handler
543   };
544   if (sigaction(SIGCHLD, &sa, NULL) < 0)
545     die("Cannot setup SIGCHLD handler: %m");
546
547   for (;;)
548     {
549       setproctitle("submitd: %d connections", num_conn);
550       int status;
551       pid_t pid = waitpid(-1, &status, WNOHANG);
552       if (pid > 0)
553         reap_child(pid, status);
554       else
555         sk_accept();
556     }
557 }