]> mj.ucw.cz Git - moe.git/blob - submit/submitd.c
Distinguish between die() and err().
[moe.git] / submit / submitd.c
1 /*
2  *  The Submit Daemon
3  *
4  *  (c) 2007 Martin Mares <mj@ucw.cz>
5  */
6
7 #undef LOCAL_DEBUG
8
9 #include "lib/lib.h"
10 #include "lib/conf.h"
11 #include "lib/getopt.h"
12
13 #include <string.h>
14 #include <stdlib.h>
15 #include <unistd.h>
16 #include <signal.h>
17 #include <errno.h>
18 #include <sys/types.h>
19 #include <sys/socket.h>
20 #include <sys/wait.h>
21 #include <arpa/inet.h>
22 #include <netinet/in.h>
23
24 #include "submitd.h"
25
26 /*** CONFIGURATION ***/
27
28 static char *log_name;
29 static uns port = 8888;
30 static uns dh_bits = 1024;
31 static uns max_conn = 10;
32 static uns session_timeout;
33 uns max_versions;
34 static char *ca_cert_name = "?";
35 static char *server_cert_name = "?";
36 static char *server_key_name = "?";
37 char *history_format;
38 static clist access_rules;
39 static uns trace_tls;
40 uns max_request_size;
41 uns max_attachment_size;
42 uns trace_commands;
43
44 static struct cf_section ip_node_conf = {
45   CF_TYPE(struct ip_node),
46   CF_ITEMS {
47     CF_USER("IP", PTR_TO(struct ip_node, addrmask), &ip_addrmask_type),
48     CF_END
49   }
50 };
51
52 static struct cf_section access_conf = {
53   CF_TYPE(struct access_rule),
54   CF_ITEMS {
55     CF_LIST("IP", PTR_TO(struct access_rule, ip_list), &ip_node_conf),
56     CF_UNS("Admin", PTR_TO(struct access_rule, allow_admin)),
57     CF_UNS("PlainText", PTR_TO(struct access_rule, plain_text)),
58     CF_UNS("MaxConn", PTR_TO(struct access_rule, max_conn)),
59     CF_END
60   }
61 };
62
63 static struct cf_section submitd_conf = {
64   CF_ITEMS {
65     CF_STRING("LogFile", &log_name),
66     CF_UNS("Port", &port),
67     CF_UNS("DHBits", &dh_bits),
68     CF_UNS("MaxConn", &max_conn),
69     CF_UNS("SessionTimeout", &session_timeout),
70     CF_UNS("MaxRequestSize", &max_request_size),
71     CF_UNS("MaxAttachSize", &max_attachment_size),
72     CF_UNS("MaxVersions", &max_versions),
73     CF_STRING("CACert", &ca_cert_name),
74     CF_STRING("ServerCert", &server_cert_name),
75     CF_STRING("ServerKey", &server_key_name),
76     CF_STRING("History", &history_format),
77     CF_LIST("Access", &access_rules, &access_conf),
78     CF_UNS("TraceTLS", &trace_tls),
79     CF_UNS("TraceCommands", &trace_commands),
80     CF_END
81   }
82 };
83
84 /*** CONNECTIONS ***/
85
86 static clist connections;
87 static uns last_conn_id;
88 static uns num_conn;
89
90 static void
91 conn_init(void)
92 {
93   clist_init(&connections);
94 }
95
96 static struct conn *
97 conn_new(void)
98 {
99   struct conn *c = xmalloc_zero(sizeof(*c));
100   c->id = ++last_conn_id;
101   clist_add_tail(&connections, &c->n);
102   num_conn++;
103   return c;
104 }
105
106 static void
107 conn_free(struct conn *c)
108 {
109   xfree(c->ip_string);
110   xfree(c->cert_name);
111   clist_remove(&c->n);
112   num_conn--;
113   xfree(c);
114 }
115
116 static struct access_rule *
117 lookup_rule(u32 ip)
118 {
119   CLIST_FOR_EACH(struct access_rule *, r, access_rules)
120     CLIST_FOR_EACH(struct ip_node *, n, r->ip_list)
121       if (ip_addrmask_match(&n->addrmask, ip))
122         return r;
123   return NULL;
124 }
125
126 static uns
127 conn_count(u32 ip)
128 {
129   uns cnt = 0;
130   CLIST_FOR_EACH(struct conn *, c, connections)
131     if (c->ip == ip)
132       cnt++;
133   return cnt;
134 }
135
136 /*** TLS ***/
137
138 static gnutls_certificate_credentials_t cert_cred;
139 static gnutls_dh_params_t dh_params;
140
141 #define TLS_CHECK(name) if (err < 0) die(#name " failed: %s", gnutls_strerror(err))
142
143 static void
144 tls_init(void)
145 {
146   int err;
147
148   gnutls_global_init();
149   err = gnutls_certificate_allocate_credentials(&cert_cred);
150   TLS_CHECK(gnutls_certificate_allocate_credentials);
151   err = gnutls_certificate_set_x509_trust_file(cert_cred, ca_cert_name, GNUTLS_X509_FMT_PEM);
152   if (!err)
153     die("No CA certificate found");
154   if (err < 0)
155     die("Unable to load X509 trust file: %s", gnutls_strerror(err));
156   err = gnutls_certificate_set_x509_key_file(cert_cred, server_cert_name, server_key_name, GNUTLS_X509_FMT_PEM);
157   if (err < 0)
158     die("Unable to load X509 key file: %s", gnutls_strerror(err));
159
160   err = gnutls_dh_params_init(&dh_params); TLS_CHECK(gnutls_dh_params_init);
161   err = gnutls_dh_params_generate2(dh_params, dh_bits); TLS_CHECK(gnutls_dh_params_generate2);
162   gnutls_certificate_set_dh_params(cert_cred, dh_params);
163 }
164
165 static gnutls_session_t
166 tls_new_session(int sk)
167 {
168   gnutls_session_t s;
169   int err;
170
171   err = gnutls_init(&s, GNUTLS_SERVER); TLS_CHECK(gnutls_init);
172   err = gnutls_set_default_priority(s); TLS_CHECK(gnutls_set_default_priority);
173   gnutls_credentials_set(s, GNUTLS_CRD_CERTIFICATE, cert_cred);
174   gnutls_certificate_server_set_request(s, GNUTLS_CERT_REQUEST);
175   gnutls_dh_set_prime_bits(s, dh_bits);
176   gnutls_transport_set_ptr(s, (gnutls_transport_ptr_t) sk);
177   return s;
178 }
179
180 static const char *
181 tls_verify_cert(struct conn *c)
182 {
183   gnutls_session_t s = c->tls;
184   uns status, num_certs;
185   int err;
186   gnutls_x509_crt_t cert;
187   const gnutls_datum_t *certs;
188
189   DBG("Verifying peer certificates");
190   err = gnutls_certificate_verify_peers2(s, &status);
191   if (err < 0)
192     return gnutls_strerror(err);
193   DBG("Verify status: %04x", status);
194   if (status & GNUTLS_CERT_INVALID)
195     return "Certificate is invalid";
196   /* XXX: We do not handle revokation. */
197   if (gnutls_certificate_type_get(s) != GNUTLS_CRT_X509)
198     return "Certificate is not X509";
199
200   err = gnutls_x509_crt_init(&cert);
201   if (err < 0)
202     return "gnutls_x509_crt_init() failed";
203   certs = gnutls_certificate_get_peers(s, &num_certs);
204   if (!certs)
205     return "No peer certificate found";
206   DBG("Got certificate list with %d peers", num_certs);
207
208   err = gnutls_x509_crt_import(cert, &certs[0], GNUTLS_X509_FMT_DER);
209   if (err < 0)
210     return "Cannot import certificate";
211   /* XXX: We do not check expiration and activation since the keys are generated for a single contest only anyway. */
212
213   char dn[256];
214   size_t dn_len = sizeof(dn);
215   err = gnutls_x509_crt_get_dn_by_oid(cert, GNUTLS_OID_X520_COMMON_NAME, 0, 0, dn, &dn_len);
216   if (err < 0)
217     return "Cannot retrieve common name";
218   if (trace_tls)
219     msg(L_INFO, "Cert CN: %s", dn);
220   c->cert_name = xstrdup(dn);
221
222   /* Check certificate purpose */
223   char purp[256];
224   int purpi = 0;
225   do
226     {
227       size_t purp_len = sizeof(purp);
228       uns crit;
229       err = gnutls_x509_crt_get_key_purpose_oid(cert, purpi++, purp, &purp_len, &crit);
230       if (err == GNUTLS_E_REQUESTED_DATA_NOT_AVAILABLE)
231         return "Not a client certificate";
232       TLS_CHECK(gnutls_x509_crt_get_key_purpose_oid);
233     }
234   while (strcmp(purp, GNUTLS_KP_TLS_WWW_CLIENT));
235
236   DBG("Verified OK");
237   return NULL;
238 }
239
240 static void
241 tls_log_params(struct conn *c)
242 {
243   if (!trace_tls)
244     return;
245   gnutls_session_t s = c->tls;
246   const char *proto = gnutls_protocol_get_name(gnutls_protocol_get_version(s));
247   const char *kx = gnutls_kx_get_name(gnutls_kx_get(s));
248   const char *cert = gnutls_certificate_type_get_name(gnutls_certificate_type_get(s));
249   const char *comp = gnutls_compression_get_name(gnutls_compression_get(s));
250   const char *cipher = gnutls_cipher_get_name(gnutls_cipher_get(s));
251   const char *mac = gnutls_mac_get_name(gnutls_mac_get(s));
252   msg(L_DEBUG, "TLS params: proto=%s kx=%s cert=%s comp=%s cipher=%s mac=%s",
253     proto, kx, cert, comp, cipher, mac);
254 }
255
256 /*** FASTBUFS OVER SOCKETS AND TLS ***/
257
258 void NONRET                             // Fatal protocol violation
259 client_error(char *msg, ...)
260 {
261   va_list args;
262   va_start(args, msg);
263   vmsg(L_ERROR_R, msg, args);
264   exit(0);
265 }
266
267 static int
268 sk_fb_refill(struct fastbuf *f)
269 {
270   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
271   int cnt = read(c->sk, f->buffer, f->bufend - f->buffer);
272   if (cnt < 0)
273     client_error("Read error: %m");
274   f->bptr = f->buffer;
275   f->bstop = f->buffer + cnt;
276   return cnt;
277 }
278
279 static void
280 sk_fb_spout(struct fastbuf *f)
281 {
282   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
283   int len = f->bptr - f->buffer;
284   if (!len)
285     return;
286   int cnt = careful_write(c->sk, f->buffer, len);
287   if (cnt <= 0)
288     client_error("Write error");
289   f->bptr = f->buffer;
290 }
291
292 static void
293 init_sk_fastbufs(struct conn *c)
294 {
295   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
296
297   rf->buffer = xmalloc(1024);
298   rf->bufend = rf->buffer + 1024;
299   rf->bptr = rf->bstop = rf->buffer;
300   rf->name = "socket";
301   rf->refill = sk_fb_refill;
302
303   tf->buffer = xmalloc(1024);
304   tf->bufend = tf->buffer + 1024;
305   tf->bptr = tf->bstop = tf->buffer;
306   tf->name = rf->name;
307   tf->spout = sk_fb_spout;
308 }
309
310 static int
311 tls_fb_refill(struct fastbuf *f)
312 {
313   struct conn *c = SKIP_BACK(struct conn, rx_fb, f);
314   DBG("TLS: Refill");
315   int cnt = gnutls_record_recv(c->tls, f->buffer, f->bufend - f->buffer);
316   DBG("TLS: Received %d bytes", cnt);
317   if (cnt < 0)
318     client_error("TLS read error: %s", gnutls_strerror(cnt));
319   f->bptr = f->buffer;
320   f->bstop = f->buffer + cnt;
321   return cnt;
322 }
323
324 static void
325 tls_fb_spout(struct fastbuf *f)
326 {
327   struct conn *c = SKIP_BACK(struct conn, tx_fb, f);
328   int len = f->bptr - f->buffer;
329   if (!len)
330     return;
331   int cnt = gnutls_record_send(c->tls, f->buffer, len);
332   DBG("TLS: Sent %d bytes", cnt);
333   if (cnt <= 0)
334     client_error("TLS write error: %s", gnutls_strerror(cnt));
335   f->bptr = f->buffer;
336 }
337
338 static void
339 init_tls_fastbufs(struct conn *c)
340 {
341   struct fastbuf *rf = &c->rx_fb, *tf = &c->tx_fb;
342
343   ASSERT(rf->buffer && tf->buffer);     // Already set up for the plaintext connection
344   rf->refill = tls_fb_refill;
345   tf->spout = tls_fb_spout;
346 }
347
348 /*** CLIENT LOOP (runs in a child process) ***/
349
350 static void
351 sigalrm_handler(int sig UNUSED)
352 {
353   // We do not try to do any gracious shutdown to avoid races
354   client_error("Timed out");
355 }
356
357 static void
358 client_loop(struct conn *c)
359 {
360   setproctitle("submitd: client %s", c->ip_string);
361   log_pid = c->id;
362   init_sk_fastbufs(c);
363
364   signal(SIGPIPE, SIG_IGN);
365   struct sigaction sa = {
366     .sa_handler = sigalrm_handler
367   };
368   if (sigaction(SIGALRM, &sa, NULL) < 0)
369     die("Cannot setup SIGALRM handler: %m");
370
371   if (c->rule->plain_text)
372     {
373       bputsn(&c->tx_fb, "+OK");
374       bflush(&c->tx_fb);
375     }
376   else
377     {
378       bputsn(&c->tx_fb, "+TLS");
379       bflush(&c->tx_fb);
380       c->tls = tls_new_session(c->sk);
381       int err = gnutls_handshake(c->tls);
382       if (err < 0)
383         client_error("TLS handshake failed: %s", gnutls_strerror(err));
384       tls_log_params(c);
385       const char *cert_err = tls_verify_cert(c);
386       if (cert_err)
387         client_error("TLS certificate failure: %s", cert_err);
388       init_tls_fastbufs(c);
389     }
390
391   alarm(session_timeout);
392   if (!process_init(c))
393     msg(L_ERROR, "Protocol handshake failed");
394   else
395     {
396       setproctitle("submitd: client %s (%s)", c->ip_string, c->user);
397       for (;;)
398         {
399           alarm(session_timeout);
400           if (!process_command(c))
401             break;
402         }
403     }
404
405   if (c->tls)
406     gnutls_bye(c->tls, GNUTLS_SHUT_WR);
407   close(c->sk);
408   if (c->tls)
409     gnutls_deinit(c->tls);
410 }
411
412 /*** MAIN LOOP ***/
413
414 static void
415 sigchld_handler(int sig UNUSED)
416 {
417   /* We do not need to do anything, just interrupt the accept syscall */
418 }
419
420 static void
421 reap_child(pid_t pid, int status)
422 {
423   char buf[EXIT_STATUS_MSG_SIZE];
424   if (format_exit_status(buf, status))
425     msg(L_ERROR, "Child %d %s", (int)pid, buf);
426
427   CLIST_FOR_EACH(struct conn *, c, connections)
428     if (c->pid == pid)
429       {
430         msg(L_INFO, "Connection %d closed", c->id);
431         conn_free(c);
432         return;
433       }
434   msg(L_ERROR, "Cannot find connection for child process %d", (int)pid);
435 }
436
437 static int listen_sk;
438
439 static void
440 sk_init(void)
441 {
442   listen_sk = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);
443   if (listen_sk < 0)
444     die("socket: %m");
445   int one = 1;
446   if (setsockopt(listen_sk, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) < 0)
447     die("setsockopt(SO_REUSEADDR): %m");
448
449   struct sockaddr_in sa;
450   bzero(&sa, sizeof(sa));
451   sa.sin_family = AF_INET;
452   sa.sin_addr.s_addr = INADDR_ANY;
453   sa.sin_port = htons(port);
454   if (bind(listen_sk, (struct sockaddr *) &sa, sizeof(sa)) < 0)
455     die("Cannot bind to port %d: %m", port);
456   if (listen(listen_sk, 1024) < 0)
457     die("Cannot listen on port %d: %m", port);
458 }
459
460 static void
461 sk_accept(void)
462 {
463   struct sockaddr_in sa;
464   int salen = sizeof(sa);
465   int sk = accept(listen_sk, (struct sockaddr *) &sa, &salen);
466   if (sk < 0)
467     {
468       if (errno == EINTR)
469         return;
470       die("accept: %m");
471     }
472
473   char ipbuf[INET_ADDRSTRLEN];
474   inet_ntop(AF_INET, &sa.sin_addr, ipbuf, sizeof(ipbuf));
475   u32 addr = ntohl(sa.sin_addr.s_addr);
476   uns port = ntohs(sa.sin_port);
477   char *err;
478
479   struct access_rule *rule = lookup_rule(addr);
480   if (!rule)
481     {
482       err = "Unauthorized";
483       goto reject;
484     }
485
486   if (num_conn >= max_conn)
487     {
488       err = "Too many connections";
489       goto reject;
490     }
491
492   if (conn_count(addr) >= rule->max_conn)
493     {
494       err = "Too many connections from this address";
495       goto reject;
496     }
497
498   struct conn *c = conn_new();
499   msg(L_INFO, "Connection from %s:%d (id %d, %s, %s)",
500         ipbuf, port, c->id,
501         (rule->plain_text ? "plain-text" : "TLS"),
502         (rule->allow_admin ? "admin" : "user"));
503   c->ip = addr;
504   c->ip_string = xstrdup(ipbuf);
505   c->sk = sk;
506   c->rule = rule;
507
508   c->pid = fork();
509   if (c->pid < 0)
510     {
511       conn_free(c);
512       err = "Server overloaded";
513       msg(L_ERROR, "Fork failed: %m");
514       goto reject2;
515     }
516   if (!c->pid)
517     {
518       close(listen_sk);
519       client_loop(c);
520       exit(0);
521     }
522   close(sk);
523   return;
524
525 reject:
526   msg(L_ERROR_R, "Connection from %s:%d rejected (%s)", ipbuf, port, err);
527 reject2: ;
528   // Write an error message to the socket, but do not allow it to slow us down
529   struct linger ling = { .l_onoff=0 };
530   if (setsockopt(sk, SOL_SOCKET, SO_LINGER, &ling, sizeof(ling)) < 0)
531     msg(L_ERROR, "Cannot set SO_LINGER: %m");
532   write(sk, "-", 1);
533   write(sk, err, strlen(err));
534   write(sk, "\n", 1);
535   close(sk);
536 }
537
538 int main(int argc, char **argv)
539 {
540   setproctitle_init(argc, argv);
541   cf_def_file = "cf/submitd";
542   cf_declare_section("SubmitD", &submitd_conf, 0);
543   cf_declare_section("Tasks", &tasks_conf, 0);
544
545   int opt;
546   if ((opt = cf_getopt(argc, argv, CF_SHORT_OPTS, CF_NO_LONG_OPTS, NULL)) >= 0)
547     die("This program has no options");
548
549   log_file(log_name);
550
551   msg(L_INFO, "Initializing TLS");
552   tls_init();
553
554   conn_init();
555   sk_init();
556   msg(L_INFO, "Listening on port %d", port);
557
558   struct sigaction sa = {
559     .sa_handler = sigchld_handler
560   };
561   if (sigaction(SIGCHLD, &sa, NULL) < 0)
562     die("Cannot setup SIGCHLD handler: %m");
563
564   for (;;)
565     {
566       setproctitle("submitd: %d connections", num_conn);
567       int status;
568       pid_t pid = waitpid(-1, &status, WNOHANG);
569       if (pid > 0)
570         reap_child(pid, status);
571       else
572         sk_accept();
573     }
574 }