*/
#include "lib/lib.h"
+#include "lib/threads.h"
#include "lib/lizard.h"
-#include <stdlib.h>
#include <sys/mman.h>
-#include <sys/user.h>
#include <fcntl.h>
#include <signal.h>
#include <setjmp.h>
#include <errno.h>
+struct lizard_buffer {
+ uns len;
+ void *ptr;
+};
+
struct lizard_buffer *
-lizard_alloc(uns max_len)
+lizard_alloc(void)
{
- static byte *zero = "/dev/zero";
- int fd = open(zero, O_RDWR);
- if (fd < 0)
- die("open(%s): %m", zero);
struct lizard_buffer *buf = xmalloc(sizeof(struct lizard_buffer));
- buf->len = ALIGN(max_len + PAGE_SIZE, PAGE_SIZE);
- buf->ptr = mmap(NULL, buf->len, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, 0);
- if (buf->ptr == MAP_FAILED)
- die("mmap(%s): %m", zero);
+ buf->len = 0;
+ buf->ptr = NULL;
+ handle_signal(SIGSEGV);
return buf;
}
void
lizard_free(struct lizard_buffer *buf)
{
- munmap(buf->ptr, buf->len);
+ unhandle_signal(SIGSEGV);
+ if (buf->ptr)
+ munmap(buf->ptr, buf->len + CPU_PAGE_SIZE);
xfree(buf);
}
-static jmp_buf safe_decompress_jump;
static void
-sigsegv_handler(int UNUSED whatsit)
+lizard_realloc(struct lizard_buffer *buf, uns max_len)
+ /* max_len needs to be aligned to CPU_PAGE_SIZE */
+{
+ if (max_len <= buf->len)
+ return;
+ if (max_len < 2*buf->len) // to ensure logarithmic cost
+ max_len = 2*buf->len;
+
+ if (buf->ptr)
+ munmap(buf->ptr, buf->len + CPU_PAGE_SIZE);
+ buf->len = max_len;
+ buf->ptr = mmap(NULL, buf->len + CPU_PAGE_SIZE, PROT_READ | PROT_WRITE, MAP_ANON | MAP_PRIVATE, -1, 0);
+ if (buf->ptr == MAP_FAILED)
+ die("mmap(anonymous, %d bytes): %m", (uns)(buf->len + CPU_PAGE_SIZE));
+ if (mprotect(buf->ptr + buf->len, CPU_PAGE_SIZE, PROT_NONE) < 0)
+ die("mprotect: %m");
+}
+
+static jmp_buf safe_decompress_jump;
+static int
+sigsegv_handler(int signal UNUSED)
{
- log(L_ERROR, "SIGSEGV caught in lizard_decompress()");
longjmp(safe_decompress_jump, 1);
+ return 1;
}
-int
-lizard_decompress_safe(byte *in, struct lizard_buffer *buf, uns expected_length)
- /* Decompresses into buf->ptr and returns the length of the uncompressed
- * file. If an error has occured, -1 is returned and errno is set. SIGSEGV
- * is caught in the case of buffer-overflow. The function is not re-entrant
- * because of a static longjmp handler. */
+byte *
+lizard_decompress_safe(const byte *in, struct lizard_buffer *buf, uns expected_length)
+ /* Decompresses in into buf, sets *ptr to the data, and returns the
+ * uncompressed length. If an error has occured, -1 is returned and errno is
+ * set. The buffer buf is automatically reallocated. SIGSEGV is caught in
+ * case of buffer-overflow. The function is not re-entrant because of a
+ * static longjmp handler. */
{
- volatile uns lock_offset = ALIGN(expected_length, PAGE_SIZE);
- if (lock_offset + PAGE_SIZE > buf->len)
- {
- errno = EFBIG;
- return -1;
- }
- mprotect(buf->ptr + lock_offset, PAGE_SIZE, PROT_NONE);
- volatile sighandler_t old_handler = signal(SIGSEGV, sigsegv_handler);
- int len, err;
+ uns lock_offset = ALIGN_TO(expected_length + 3, CPU_PAGE_SIZE); // +3 due to the unaligned access
+ if (lock_offset > buf->len)
+ lizard_realloc(buf, lock_offset);
+ volatile sh_sighandler_t old_handler = set_signal_handler(SIGSEGV, sigsegv_handler);
+ byte *ptr;
if (!setjmp(safe_decompress_jump))
{
- len = lizard_decompress(in, buf->ptr);
- err = errno;
+ ptr = buf->ptr + buf->len - lock_offset;
+ int len = lizard_decompress(in, ptr);
+ if (len != (int) expected_length)
+ {
+ ptr = NULL;
+ errno = EINVAL;
+ }
}
else
{
- len = -1;
- err = EFAULT;
+ msg(L_ERROR, "SIGSEGV caught in lizard_decompress()");
+ ptr = NULL;
+ errno = EFAULT;
}
- signal(SIGSEGV, old_handler);
- mprotect(buf->ptr + lock_offset, PAGE_SIZE, PROT_READ | PROT_WRITE);
- errno = err;
- return len;
+ set_signal_handler(SIGSEGV, old_handler);
+ return ptr;
}