From 7a59a7fcb18a65633da119c46914ec7ec048eecc Mon Sep 17 00:00:00 2001 From: Stefan Metzmacher Date: Tue, 9 Feb 2021 08:56:42 +0100 Subject: [PATCH] third_party: Update socket_wrapper to version 1.3.2 This brings support for fd-passing of INET sockets. BUG: https://bugzilla.samba.org/show_bug.cgi?id=11899 Signed-off-by: Stefan Metzmacher Reviewed-by: Andreas Schneider (cherry picked from commit ab943babc3eb454186558f6e863996dfcf7a20ea) --- buildtools/wafsamba/samba_third_party.py | 2 +- third_party/socket_wrapper/socket_wrapper.c | 1605 ++++++++++++++++--- third_party/socket_wrapper/wscript | 7 +- 3 files changed, 1349 insertions(+), 265 deletions(-) diff --git a/buildtools/wafsamba/samba_third_party.py b/buildtools/wafsamba/samba_third_party.py index bc2b21f2a559..d4d8f7a8dd07 100644 --- a/buildtools/wafsamba/samba_third_party.py +++ b/buildtools/wafsamba/samba_third_party.py @@ -24,7 +24,7 @@ Build.BuildContext.CHECK_CMOCKA = CHECK_CMOCKA @conf def CHECK_SOCKET_WRAPPER(conf): - return conf.CHECK_BUNDLED_SYSTEM_PKG('socket_wrapper', minversion='1.2.5') + return conf.CHECK_BUNDLED_SYSTEM_PKG('socket_wrapper', minversion='1.3.2') Build.BuildContext.CHECK_SOCKET_WRAPPER = CHECK_SOCKET_WRAPPER @conf diff --git a/third_party/socket_wrapper/socket_wrapper.c b/third_party/socket_wrapper/socket_wrapper.c index ffdd31a51bfd..a950a0a0dbc8 100644 --- a/third_party/socket_wrapper/socket_wrapper.c +++ b/third_party/socket_wrapper/socket_wrapper.c @@ -178,24 +178,67 @@ enum swrap_dbglvl_e { # endif #endif +#define socket_wrapper_init_mutex(m) \ + _socket_wrapper_init_mutex(m, #m) + /* Add new global locks here please */ -# define SWRAP_LOCK_ALL \ - swrap_mutex_lock(&libc_symbol_binding_mutex); \ +# define SWRAP_REINIT_ALL do { \ + int ret; \ + ret = socket_wrapper_init_mutex(&sockets_mutex); \ + if (ret != 0) exit(-1); \ + ret = socket_wrapper_init_mutex(&socket_reset_mutex); \ + if (ret != 0) exit(-1); \ + ret = socket_wrapper_init_mutex(&first_free_mutex); \ + if (ret != 0) exit(-1); \ + ret = socket_wrapper_init_mutex(&sockets_si_global); \ + if (ret != 0) exit(-1); \ + ret = socket_wrapper_init_mutex(&autobind_start_mutex); \ + if (ret != 0) exit(-1); \ + ret = socket_wrapper_init_mutex(&pcap_dump_mutex); \ + if (ret != 0) exit(-1); \ + ret = socket_wrapper_init_mutex(&mtu_update_mutex); \ + if (ret != 0) exit(-1); \ +} while(0) + +# define SWRAP_LOCK_ALL do { \ + swrap_mutex_lock(&sockets_mutex); \ + swrap_mutex_lock(&socket_reset_mutex); \ + swrap_mutex_lock(&first_free_mutex); \ + swrap_mutex_lock(&sockets_si_global); \ + swrap_mutex_lock(&autobind_start_mutex); \ + swrap_mutex_lock(&pcap_dump_mutex); \ + swrap_mutex_lock(&mtu_update_mutex); \ +} while(0) -# define SWRAP_UNLOCK_ALL \ - swrap_mutex_unlock(&libc_symbol_binding_mutex); \ +# define SWRAP_UNLOCK_ALL do { \ + swrap_mutex_unlock(&mtu_update_mutex); \ + swrap_mutex_unlock(&pcap_dump_mutex); \ + swrap_mutex_unlock(&autobind_start_mutex); \ + swrap_mutex_unlock(&sockets_si_global); \ + swrap_mutex_unlock(&first_free_mutex); \ + swrap_mutex_unlock(&socket_reset_mutex); \ + swrap_mutex_unlock(&sockets_mutex); \ +} while(0) #define SOCKET_INFO_CONTAINER(si) \ (struct socket_info_container *)(si) #define SWRAP_LOCK_SI(si) do { \ struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \ - swrap_mutex_lock(&sic->meta.mutex); \ + if (sic != NULL) { \ + swrap_mutex_lock(&sockets_si_global); \ + } else { \ + abort(); \ + } \ } while(0) #define SWRAP_UNLOCK_SI(si) do { \ struct socket_info_container *sic = SOCKET_INFO_CONTAINER(si); \ - swrap_mutex_unlock(&sic->meta.mutex); \ + if (sic != NULL) { \ + swrap_mutex_unlock(&sockets_si_global); \ + } else { \ + abort(); \ + } \ } while(0) #if defined(HAVE_GETTIMEOFDAY_TZ) || defined(HAVE_GETTIMEOFDAY_TZ_VOID) @@ -253,10 +296,15 @@ struct swrap_address { } sa; }; -int first_free; +static int first_free; struct socket_info { + /* + * Remember to update swrap_unix_scm_right_magic + * on any change. + */ + int family; int type; int protocol; @@ -268,6 +316,7 @@ struct socket_info int pktinfo; int tcp_nodelay; int listening; + int fd_passed; /* The unix path so we can unlink it on close() */ struct sockaddr_un un_addr; @@ -286,7 +335,13 @@ struct socket_info_meta { unsigned int refcount; int next_free; - pthread_mutex_t mutex; + /* + * As long as we don't use shared memory + * for the sockets array, we use + * sockets_si_global as a single mutex. + * + * pthread_mutex_t mutex; + */ }; struct socket_info_container @@ -309,32 +364,44 @@ static size_t socket_fds_max = SOCKET_WRAPPER_MAX_SOCKETS_LIMIT; /* Hash table to map fds to corresponding socket_info index */ static int *socket_fds_idx; -/* Mutex to synchronize access to global libc.symbols */ -static pthread_mutex_t libc_symbol_binding_mutex = PTHREAD_MUTEX_INITIALIZER; - /* Mutex for syncronizing port selection during swrap_auto_bind() */ -static pthread_mutex_t autobind_start_mutex; +static pthread_mutex_t autobind_start_mutex = PTHREAD_MUTEX_INITIALIZER; /* Mutex to guard the initialization of array of socket_info structures */ -static pthread_mutex_t sockets_mutex; +static pthread_mutex_t sockets_mutex = PTHREAD_MUTEX_INITIALIZER; /* Mutex to guard the socket reset in swrap_close() and swrap_remove_stale() */ -static pthread_mutex_t socket_reset_mutex; +static pthread_mutex_t socket_reset_mutex = PTHREAD_MUTEX_INITIALIZER; /* Mutex to synchronize access to first free index in socket_info array */ -static pthread_mutex_t first_free_mutex; +static pthread_mutex_t first_free_mutex = PTHREAD_MUTEX_INITIALIZER; + +/* + * Mutex to synchronize access to to socket_info structures + * We use a single global mutex in order to avoid leaking + * ~ 38M copy on write memory per fork. + * max_sockets=65535 * sizeof(struct socket_info_container)=592 = 38796720 + */ +static pthread_mutex_t sockets_si_global = PTHREAD_MUTEX_INITIALIZER; /* Mutex to synchronize access to packet capture dump file */ -static pthread_mutex_t pcap_dump_mutex; +static pthread_mutex_t pcap_dump_mutex = PTHREAD_MUTEX_INITIALIZER; /* Mutex for synchronizing mtu value fetch*/ -static pthread_mutex_t mtu_update_mutex; +static pthread_mutex_t mtu_update_mutex = PTHREAD_MUTEX_INITIALIZER; /* Function prototypes */ bool socket_wrapper_enabled(void); +#if ! defined(HAVE_CONSTRUCTOR_ATTRIBUTE) && defined(HAVE_PRAGMA_INIT) +/* xlC and other oldschool compilers support (only) this */ +#pragma init (swrap_constructor) +#endif void swrap_constructor(void) CONSTRUCTOR_ATTRIBUTE; +#if ! defined(HAVE_DESTRUCTOR_ATTRIBUTE) && defined(HAVE_PRAGMA_FINI) +#pragma fini (swrap_destructor) +#endif void swrap_destructor(void) DESTRUCTOR_ATTRIBUTE; #ifndef HAVE_GETPROGNAME @@ -565,7 +632,6 @@ static char *socket_wrapper_dir(void); enum swrap_lib { SWRAP_LIBC, - SWRAP_LIBNSL, SWRAP_LIBSOCKET, }; @@ -574,8 +640,6 @@ static const char *swrap_str_lib(enum swrap_lib lib) switch (lib) { case SWRAP_LIBC: return "libc"; - case SWRAP_LIBNSL: - return "libnsl"; case SWRAP_LIBSOCKET: return "libsocket"; } @@ -613,7 +677,6 @@ static void *swrap_load_lib_handle(enum swrap_lib lib) #endif switch (lib) { - case SWRAP_LIBNSL: case SWRAP_LIBSOCKET: #ifdef HAVE_LIBSOCKET handle = swrap.libc.socket_handle; @@ -695,25 +758,29 @@ static void *_swrap_bind_symbol(enum swrap_lib lib, const char *fn_name) return func; } -static void swrap_mutex_lock(pthread_mutex_t *mutex) +#define swrap_mutex_lock(m) _swrap_mutex_lock(m, #m, __func__, __LINE__) +static void _swrap_mutex_lock(pthread_mutex_t *mutex, const char *name, const char *caller, unsigned line) { int ret; ret = pthread_mutex_lock(mutex); if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, "Couldn't lock pthread mutex - %s", - strerror(ret)); + SWRAP_LOG(SWRAP_LOG_ERROR, "PID(%d):PPID(%d): %s(%u): Couldn't lock pthread mutex(%s) - %s", + getpid(), getppid(), caller, line, name, strerror(ret)); + abort(); } } -static void swrap_mutex_unlock(pthread_mutex_t *mutex) +#define swrap_mutex_unlock(m) _swrap_mutex_unlock(m, #m, __func__, __LINE__) +static void _swrap_mutex_unlock(pthread_mutex_t *mutex, const char *name, const char *caller, unsigned line) { int ret; ret = pthread_mutex_unlock(mutex); if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, "Couldn't unlock pthread mutex - %s", - strerror(ret)); + SWRAP_LOG(SWRAP_LOG_ERROR, "PID(%d):PPID(%d): %s(%u): Couldn't unlock pthread mutex(%s) - %s", + getpid(), getppid(), caller, line, name, strerror(ret)); + abort(); } } @@ -723,35 +790,18 @@ static void swrap_mutex_unlock(pthread_mutex_t *mutex) * This is an optimization to avoid locking each time we check if the symbol is * bound. */ +#define _swrap_bind_symbol_generic(lib, sym_name) do { \ + swrap.libc.symbols._libc_##sym_name.obj = \ + _swrap_bind_symbol(lib, #sym_name); \ +} while(0); + #define swrap_bind_symbol_libc(sym_name) \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap_mutex_lock(&libc_symbol_binding_mutex); \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap.libc.symbols._libc_##sym_name.obj = \ - _swrap_bind_symbol(SWRAP_LIBC, #sym_name); \ - } \ - swrap_mutex_unlock(&libc_symbol_binding_mutex); \ - } + _swrap_bind_symbol_generic(SWRAP_LIBC, sym_name) #define swrap_bind_symbol_libsocket(sym_name) \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap_mutex_lock(&libc_symbol_binding_mutex); \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap.libc.symbols._libc_##sym_name.obj = \ - _swrap_bind_symbol(SWRAP_LIBSOCKET, #sym_name); \ - } \ - swrap_mutex_unlock(&libc_symbol_binding_mutex); \ - } + _swrap_bind_symbol_generic(SWRAP_LIBSOCKET, sym_name) -#define swrap_bind_symbol_libnsl(sym_name) \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap_mutex_lock(&libc_symbol_binding_mutex); \ - if (swrap.libc.symbols._libc_##sym_name.obj == NULL) { \ - swrap.libc.symbols._libc_##sym_name.obj = \ - _swrap_bind_symbol(SWRAP_LIBNSL, #sym_name); \ - } \ - swrap_mutex_unlock(&libc_symbol_binding_mutex); \ - } +static void swrap_bind_symbol_all(void); /**************************************************************************** * IMPORTANT @@ -770,7 +820,7 @@ static int libc_accept4(int sockfd, socklen_t *addrlen, int flags) { - swrap_bind_symbol_libsocket(accept4); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_accept4.f(sockfd, addr, addrlen, flags); } @@ -779,7 +829,7 @@ static int libc_accept4(int sockfd, static int libc_accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(accept); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_accept.f(sockfd, addr, addrlen); } @@ -789,14 +839,14 @@ static int libc_bind(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { - swrap_bind_symbol_libsocket(bind); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_bind.f(sockfd, addr, addrlen); } static int libc_close(int fd) { - swrap_bind_symbol_libc(close); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_close.f(fd); } @@ -805,21 +855,21 @@ static int libc_connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) { - swrap_bind_symbol_libsocket(connect); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_connect.f(sockfd, addr, addrlen); } static int libc_dup(int fd) { - swrap_bind_symbol_libc(dup); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_dup.f(fd); } static int libc_dup2(int oldfd, int newfd) { - swrap_bind_symbol_libc(dup2); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_dup2.f(oldfd, newfd); } @@ -827,7 +877,7 @@ static int libc_dup2(int oldfd, int newfd) #ifdef HAVE_EVENTFD static int libc_eventfd(int count, int flags) { - swrap_bind_symbol_libc(eventfd); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_eventfd.f(count, flags); } @@ -839,7 +889,7 @@ static int libc_vfcntl(int fd, int cmd, va_list ap) void *arg; int rc; - swrap_bind_symbol_libc(fcntl); + swrap_bind_symbol_all(); arg = va_arg(ap, void *); @@ -852,7 +902,7 @@ static int libc_getpeername(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(getpeername); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_getpeername.f(sockfd, addr, addrlen); } @@ -861,7 +911,7 @@ static int libc_getsockname(int sockfd, struct sockaddr *addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(getsockname); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_getsockname.f(sockfd, addr, addrlen); } @@ -872,7 +922,7 @@ static int libc_getsockopt(int sockfd, void *optval, socklen_t *optlen) { - swrap_bind_symbol_libsocket(getsockopt); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_getsockopt.f(sockfd, level, @@ -887,7 +937,7 @@ static int libc_vioctl(int d, unsigned long int request, va_list ap) void *arg; int rc; - swrap_bind_symbol_libc(ioctl); + swrap_bind_symbol_all(); arg = va_arg(ap, void *); @@ -898,14 +948,14 @@ static int libc_vioctl(int d, unsigned long int request, va_list ap) static int libc_listen(int sockfd, int backlog) { - swrap_bind_symbol_libsocket(listen); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_listen.f(sockfd, backlog); } static FILE *libc_fopen(const char *name, const char *mode) { - swrap_bind_symbol_libc(fopen); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_fopen.f(name, mode); } @@ -913,7 +963,7 @@ static FILE *libc_fopen(const char *name, const char *mode) #ifdef HAVE_FOPEN64 static FILE *libc_fopen64(const char *name, const char *mode) { - swrap_bind_symbol_libc(fopen64); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_fopen64.f(name, mode); } @@ -924,7 +974,7 @@ static int libc_vopen(const char *pathname, int flags, va_list ap) int mode = 0; int fd; - swrap_bind_symbol_libc(open); + swrap_bind_symbol_all(); if (flags & O_CREAT) { mode = va_arg(ap, int); @@ -952,7 +1002,7 @@ static int libc_vopen64(const char *pathname, int flags, va_list ap) int mode = 0; int fd; - swrap_bind_symbol_libc(open64); + swrap_bind_symbol_all(); if (flags & O_CREAT) { mode = va_arg(ap, int); @@ -968,7 +1018,7 @@ static int libc_vopenat(int dirfd, const char *path, int flags, va_list ap) int mode = 0; int fd; - swrap_bind_symbol_libc(openat); + swrap_bind_symbol_all(); if (flags & O_CREAT) { mode = va_arg(ap, int); @@ -997,28 +1047,28 @@ static int libc_openat(int dirfd, const char *path, int flags, ...) static int libc_pipe(int pipefd[2]) { - swrap_bind_symbol_libsocket(pipe); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_pipe.f(pipefd); } static int libc_read(int fd, void *buf, size_t count) { - swrap_bind_symbol_libc(read); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_read.f(fd, buf, count); } static ssize_t libc_readv(int fd, const struct iovec *iov, int iovcnt) { - swrap_bind_symbol_libsocket(readv); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_readv.f(fd, iov, iovcnt); } static int libc_recv(int sockfd, void *buf, size_t len, int flags) { - swrap_bind_symbol_libsocket(recv); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_recv.f(sockfd, buf, len, flags); } @@ -1030,7 +1080,7 @@ static int libc_recvfrom(int sockfd, struct sockaddr *src_addr, socklen_t *addrlen) { - swrap_bind_symbol_libsocket(recvfrom); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_recvfrom.f(sockfd, buf, @@ -1042,21 +1092,21 @@ static int libc_recvfrom(int sockfd, static int libc_recvmsg(int sockfd, struct msghdr *msg, int flags) { - swrap_bind_symbol_libsocket(recvmsg); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_recvmsg.f(sockfd, msg, flags); } static int libc_send(int sockfd, const void *buf, size_t len, int flags) { - swrap_bind_symbol_libsocket(send); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_send.f(sockfd, buf, len, flags); } static int libc_sendmsg(int sockfd, const struct msghdr *msg, int flags) { - swrap_bind_symbol_libsocket(sendmsg); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_sendmsg.f(sockfd, msg, flags); } @@ -1068,7 +1118,7 @@ static int libc_sendto(int sockfd, const struct sockaddr *dst_addr, socklen_t addrlen) { - swrap_bind_symbol_libsocket(sendto); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_sendto.f(sockfd, buf, @@ -1084,7 +1134,7 @@ static int libc_setsockopt(int sockfd, const void *optval, socklen_t optlen) { - swrap_bind_symbol_libsocket(setsockopt); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_setsockopt.f(sockfd, level, @@ -1096,7 +1146,7 @@ static int libc_setsockopt(int sockfd, #ifdef HAVE_SIGNALFD static int libc_signalfd(int fd, const sigset_t *mask, int flags) { - swrap_bind_symbol_libsocket(signalfd); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_signalfd.f(fd, mask, flags); } @@ -1104,14 +1154,14 @@ static int libc_signalfd(int fd, const sigset_t *mask, int flags) static int libc_socket(int domain, int type, int protocol) { - swrap_bind_symbol_libsocket(socket); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_socket.f(domain, type, protocol); } static int libc_socketpair(int domain, int type, int protocol, int sv[2]) { - swrap_bind_symbol_libsocket(socketpair); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_socketpair.f(domain, type, protocol, sv); } @@ -1119,7 +1169,7 @@ static int libc_socketpair(int domain, int type, int protocol, int sv[2]) #ifdef HAVE_TIMERFD_CREATE static int libc_timerfd_create(int clockid, int flags) { - swrap_bind_symbol_libc(timerfd_create); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_timerfd_create.f(clockid, flags); } @@ -1127,20 +1177,20 @@ static int libc_timerfd_create(int clockid, int flags) static ssize_t libc_write(int fd, const void *buf, size_t count) { - swrap_bind_symbol_libc(write); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_write.f(fd, buf, count); } static ssize_t libc_writev(int fd, const struct iovec *iov, int iovcnt) { - swrap_bind_symbol_libsocket(writev); + swrap_bind_symbol_all(); return swrap.libc.symbols._libc_writev.f(fd, iov, iovcnt); } /* DO NOT call this function during library initialization! */ -static void swrap_bind_symbol_all(void) +static void __swrap_bind_symbol_all_once(void) { #ifdef HAVE_ACCEPT4 swrap_bind_symbol_libsocket(accept4); @@ -1192,6 +1242,13 @@ static void swrap_bind_symbol_all(void) swrap_bind_symbol_libsocket(writev); } +static void swrap_bind_symbol_all(void) +{ + static pthread_once_t all_symbol_binding_once = PTHREAD_ONCE_INIT; + + pthread_once(&all_symbol_binding_once, __swrap_bind_symbol_all_once); +} + /********************************************************* * SWRAP HELPER FUNCTIONS *********************************************************/ @@ -1438,7 +1495,7 @@ static char *socket_wrapper_dir(void) char *t; bool ok; - if (s == NULL) { + if (s == NULL || s[0] == '\0') { SWRAP_LOG(SWRAP_LOG_WARN, "SOCKET_WRAPPER_DIR not set"); return NULL; } @@ -1446,7 +1503,8 @@ static char *socket_wrapper_dir(void) swrap_dir = realpath(s, NULL); if (swrap_dir == NULL) { SWRAP_LOG(SWRAP_LOG_ERROR, - "Unable to resolve socket_wrapper dir path: %s", + "Unable to resolve socket_wrapper dir path: %s - %s", + s, strerror(errno)); abort(); } @@ -1524,26 +1582,31 @@ done: return max_mtu; } -static int socket_wrapper_init_mutex(pthread_mutex_t *m) +static int _socket_wrapper_init_mutex(pthread_mutex_t *m, const char *name) { pthread_mutexattr_t ma; - int ret; - - ret = pthread_mutexattr_init(&ma); - if (ret != 0) { - return ret; - } - - ret = pthread_mutexattr_settype(&ma, PTHREAD_MUTEX_ERRORCHECK); - if (ret != 0) { - goto done; - } - - ret = pthread_mutex_init(m, &ma); + bool need_destroy = false; + int ret = 0; + +#define __CHECK(cmd) do { \ + ret = cmd; \ + if (ret != 0) { \ + SWRAP_LOG(SWRAP_LOG_ERROR, \ + "%s: %s - failed %d", \ + name, #cmd, ret); \ + goto done; \ + } \ +} while(0) + *m = (pthread_mutex_t)PTHREAD_MUTEX_INITIALIZER; + __CHECK(pthread_mutexattr_init(&ma)); + need_destroy = true; + __CHECK(pthread_mutexattr_settype(&ma, PTHREAD_MUTEX_ERRORCHECK)); + __CHECK(pthread_mutex_init(m, &ma)); done: - pthread_mutexattr_destroy(&ma); - + if (need_destroy) { + pthread_mutexattr_destroy(&ma); + } return ret; } @@ -1618,7 +1681,9 @@ static void socket_wrapper_init_sockets(void) { size_t max_sockets; size_t i; - int ret; + int ret = 0; + + swrap_bind_symbol_all(); swrap_mutex_lock(&sockets_mutex); @@ -1627,6 +1692,10 @@ static void socket_wrapper_init_sockets(void) return; } + SWRAP_LOG(SWRAP_LOG_DEBUG, + "SOCKET_WRAPPER_PACKAGE[%s] SOCKET_WRAPPER_VERSION[%s]", + SOCKET_WRAPPER_PACKAGE, SOCKET_WRAPPER_VERSION); + /* * Intialize the static cache early before * any thread is able to start. @@ -1650,44 +1719,18 @@ static void socket_wrapper_init_sockets(void) } swrap_mutex_lock(&first_free_mutex); + swrap_mutex_lock(&sockets_si_global); first_free = 0; for (i = 0; i < max_sockets; i++) { swrap_set_next_free(&sockets[i].info, i+1); - ret = socket_wrapper_init_mutex(&sockets[i].meta.mutex); - if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to initialize pthread mutex"); - goto done; - } } /* mark the end of the free list */ swrap_set_next_free(&sockets[max_sockets-1].info, -1); - ret = socket_wrapper_init_mutex(&autobind_start_mutex); - if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to initialize pthread mutex"); - goto done; - } - - ret = socket_wrapper_init_mutex(&pcap_dump_mutex); - if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to initialize pthread mutex"); - goto done; - } - - ret = socket_wrapper_init_mutex(&mtu_update_mutex); - if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to initialize pthread mutex"); - goto done; - } - -done: + swrap_mutex_unlock(&sockets_si_global); swrap_mutex_unlock(&first_free_mutex); swrap_mutex_unlock(&sockets_mutex); if (ret != 0) { @@ -1779,7 +1822,7 @@ static int find_socket_info_index(int fd) return socket_fds_idx[fd]; } -static int swrap_add_socket_info(struct socket_info *si_input) +static int swrap_add_socket_info(const struct socket_info *si_input) { struct socket_info *si = NULL; int si_index = -1; @@ -1822,6 +1865,7 @@ static int swrap_create_socket(struct socket_info *si, int fd) "trying to add %d", socket_fds_max, fd); + errno = EMFILE; return -1; } @@ -1846,31 +1890,40 @@ static int convert_un_in(const struct sockaddr_un *un, struct sockaddr *in, sock if (p) p++; else p = un->sun_path; if (sscanf(p, SOCKET_FORMAT, &type, &iface, &prt) != 3) { + SWRAP_LOG(SWRAP_LOG_ERROR, "sun_path[%s] p[%s]", + un->sun_path, p); errno = EINVAL; return -1; } - SWRAP_LOG(SWRAP_LOG_TRACE, "type %c iface %u port %u", - type, iface, prt); - if (iface == 0 || iface > MAX_WRAPPED_INTERFACES) { + SWRAP_LOG(SWRAP_LOG_ERROR, "type %c iface %u port %u", + type, iface, prt); errno = EINVAL; return -1; } if (prt > 0xFFFF) { + SWRAP_LOG(SWRAP_LOG_ERROR, "type %c iface %u port %u", + type, iface, prt); errno = EINVAL; return -1; } + SWRAP_LOG(SWRAP_LOG_TRACE, "type %c iface %u port %u", + type, iface, prt); + switch(type) { case SOCKET_TYPE_CHAR_TCP: case SOCKET_TYPE_CHAR_UDP: { struct sockaddr_in *in2 = (struct sockaddr_in *)(void *)in; if ((*len) < sizeof(*in2)) { - errno = EINVAL; - return -1; + SWRAP_LOG(SWRAP_LOG_ERROR, + "V4: *len(%zu) < sizeof(*in2)=%zu", + (size_t)*len, sizeof(*in2)); + errno = EINVAL; + return -1; } memset(in2, 0, sizeof(*in2)); @@ -1887,6 +1940,10 @@ static int convert_un_in(const struct sockaddr_un *un, struct sockaddr *in, sock struct sockaddr_in6 *in2 = (struct sockaddr_in6 *)(void *)in; if ((*len) < sizeof(*in2)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "V6: *len(%zu) < sizeof(*in2)=%zu", + (size_t)*len, sizeof(*in2)); + SWRAP_LOG(SWRAP_LOG_ERROR, "LINE:%d", __LINE__); errno = EINVAL; return -1; } @@ -1902,6 +1959,8 @@ static int convert_un_in(const struct sockaddr_un *un, struct sockaddr *in, sock } #endif default: + SWRAP_LOG(SWRAP_LOG_ERROR, "type %c iface %u port %u", + type, iface, prt); errno = EINVAL; return -1; } @@ -2932,7 +2991,7 @@ static int swrap_pcap_get_fd(const char *fname) file_hdr.link_type = 0x0065; /* 101 RAW IP */ if (write(fd, &file_hdr, sizeof(file_hdr)) != sizeof(file_hdr)) { - close(fd); + libc_close(fd); fd = -1; } return fd; @@ -3437,6 +3496,9 @@ static int swrap_socket(int family, int type, int protocol) ret = swrap_create_socket(si, fd); if (ret == -1) { + int saved_errno = errno; + libc_close(fd); + errno = saved_errno; return -1; } @@ -3585,10 +3647,12 @@ static int swrap_accept(int s, ret = libc_accept(s, &un_addr.sa.s, &un_addr.sa_socklen); #endif if (ret == -1) { - if (errno == ENOTSOCK) { + int saved_errno = errno; + if (saved_errno == ENOTSOCK) { /* Remove stale fds */ swrap_remove_stale(s); } + errno = saved_errno; return ret; } @@ -3597,6 +3661,50 @@ static int swrap_accept(int s, /* Check if we have a stale fd and remove it */ swrap_remove_stale(fd); + if (un_addr.sa.un.sun_path[0] == '\0') { + /* + * FreeBSD seems to have a problem where + * accept4() on the unix socket doesn't + * ECONNABORTED for already disconnected connections. + * + * Let's try libc_getpeername() to get the peer address + * as a fallback, but it'll likely return ENOTCONN, + * which we have to map to ECONNABORTED. + */ + un_addr.sa_socklen = sizeof(struct sockaddr_un), + ret = libc_getpeername(fd, &un_addr.sa.s, &un_addr.sa_socklen); + if (ret == -1) { + int saved_errno = errno; + libc_close(fd); + if (saved_errno == ENOTCONN) { + /* + * If the connection is already disconnected + * we should return ECONNABORTED. + */ + saved_errno = ECONNABORTED; + } + errno = saved_errno; + return ret; + } + } + + ret = libc_getsockname(fd, + &un_my_addr.sa.s, + &un_my_addr.sa_socklen); + if (ret == -1) { + int saved_errno = errno; + libc_close(fd); + if (saved_errno == ENOTCONN) { + /* + * If the connection is already disconnected + * we should return ECONNABORTED. + */ + saved_errno = ECONNABORTED; + } + errno = saved_errno; + return ret; + } + SWRAP_LOCK_SI(parent_si); ret = sockaddr_convert_from_un(parent_si, @@ -3606,8 +3714,10 @@ static int swrap_accept(int s, &in_addr.sa.s, &in_addr.sa_socklen); if (ret == -1) { + int saved_errno = errno; SWRAP_UNLOCK_SI(parent_si); - close(fd); + libc_close(fd); + errno = saved_errno; return ret; } @@ -3635,14 +3745,6 @@ static int swrap_accept(int s, *addrlen = in_addr.sa_socklen; } - ret = libc_getsockname(fd, - &un_my_addr.sa.s, - &un_my_addr.sa_socklen); - if (ret == -1) { - close(fd); - return ret; - } - ret = sockaddr_convert_from_un(child_si, &un_my_addr.sa.un, un_my_addr.sa_socklen, @@ -3650,7 +3752,9 @@ static int swrap_accept(int s, &in_my_addr.sa.s, &in_my_addr.sa_socklen); if (ret == -1) { - close(fd); + int saved_errno = errno; + libc_close(fd); + errno = saved_errno; return ret; } @@ -3665,7 +3769,9 @@ static int swrap_accept(int s, idx = swrap_create_socket(&new_si, fd); if (idx == -1) { - close (fd); + int saved_errno = errno; + libc_close(fd); + errno = saved_errno; return -1; } @@ -4959,16 +5065,21 @@ static int swrap_msghdr_add_socket_info(struct socket_info *si, return rc; } -static int swrap_sendmsg_copy_cmsg(struct cmsghdr *cmsg, +static int swrap_sendmsg_copy_cmsg(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space); -static int swrap_sendmsg_filter_cmsg_socket(struct cmsghdr *cmsg, - uint8_t **cm_data, - size_t *cm_data_space); - -static int swrap_sendmsg_filter_cmsghdr(struct msghdr *msg, +static int swrap_sendmsg_filter_cmsg_ipproto_ip(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space); +static int swrap_sendmsg_filter_cmsg_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space); + +static int swrap_sendmsg_filter_cmsghdr(const struct msghdr *_msg, uint8_t **cm_data, - size_t *cm_data_space) { + size_t *cm_data_space) +{ + struct msghdr *msg = discard_const_p(struct msghdr, _msg); struct cmsghdr *cmsg; int rc = -1; @@ -4982,9 +5093,14 @@ static int swrap_sendmsg_filter_cmsghdr(struct msghdr *msg, cmsg = CMSG_NXTHDR(msg, cmsg)) { switch (cmsg->cmsg_level) { case IPPROTO_IP: - rc = swrap_sendmsg_filter_cmsg_socket(cmsg, - cm_data, - cm_data_space); + rc = swrap_sendmsg_filter_cmsg_ipproto_ip(cmsg, + cm_data, + cm_data_space); + break; + case SOL_SOCKET: + rc = swrap_sendmsg_filter_cmsg_sol_socket(cmsg, + cm_data, + cm_data_space); break; default: rc = swrap_sendmsg_copy_cmsg(cmsg, @@ -4992,12 +5108,19 @@ static int swrap_sendmsg_filter_cmsghdr(struct msghdr *msg, cm_data_space); break; } + if (rc < 0) { + int saved_errno = errno; + SAFE_FREE(*cm_data); + *cm_data_space = 0; + errno = saved_errno; + return rc; + } } return rc; } -static int swrap_sendmsg_copy_cmsg(struct cmsghdr *cmsg, +static int swrap_sendmsg_copy_cmsg(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space) { @@ -5020,14 +5143,14 @@ static int swrap_sendmsg_copy_cmsg(struct cmsghdr *cmsg, return 0; } -static int swrap_sendmsg_filter_cmsg_pktinfo(struct cmsghdr *cmsg, +static int swrap_sendmsg_filter_cmsg_pktinfo(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space); -static int swrap_sendmsg_filter_cmsg_socket(struct cmsghdr *cmsg, - uint8_t **cm_data, - size_t *cm_data_space) +static int swrap_sendmsg_filter_cmsg_ipproto_ip(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) { int rc = -1; @@ -5053,7 +5176,7 @@ static int swrap_sendmsg_filter_cmsg_socket(struct cmsghdr *cmsg, return rc; } -static int swrap_sendmsg_filter_cmsg_pktinfo(struct cmsghdr *cmsg, +static int swrap_sendmsg_filter_cmsg_pktinfo(const struct cmsghdr *cmsg, uint8_t **cm_data, size_t *cm_data_space) { @@ -5067,53 +5190,957 @@ static int swrap_sendmsg_filter_cmsg_pktinfo(struct cmsghdr *cmsg, */ return 0; } -#endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ -static ssize_t swrap_sendmsg_before(int fd, - struct socket_info *si, - struct msghdr *msg, - struct iovec *tmp_iov, - struct sockaddr_un *tmp_un, - const struct sockaddr_un **to_un, - const struct sockaddr **to, - int *bcast) +static int swrap_sendmsg_filter_cmsg_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) { - size_t i, len = 0; - ssize_t ret = -1; + int rc = -1; - if (to_un) { - *to_un = NULL; - } - if (to) { - *to = NULL; - } - if (bcast) { - *bcast = 0; + switch (cmsg->cmsg_type) { + case SCM_RIGHTS: + SWRAP_LOG(SWRAP_LOG_TRACE, + "Ignoring SCM_RIGHTS on inet socket!"); + rc = 0; + break; +#ifdef SCM_CREDENTIALS + case SCM_CREDENTIALS: + SWRAP_LOG(SWRAP_LOG_TRACE, + "Ignoring SCM_CREDENTIALS on inet socket!"); + rc = 0; + break; +#endif /* SCM_CREDENTIALS */ + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + cm_data, + cm_data_space); + break; } - SWRAP_LOCK_SI(si); + return rc; +} - switch (si->type) { - case SOCK_STREAM: { - unsigned long mtu; +static const uint64_t swrap_unix_scm_right_magic = 0x8e0e13f27c42fc36; - if (!si->connected) { - errno = ENOTCONN; - goto out; +/* + * We only allow up to 6 fds at a time + * as that's more than enough for Samba + * and it means we can keep the logic simple + * and work with fixed size arrays. + * + * We also keep sizeof(struct swrap_unix_scm_rights) + * under PIPE_BUF (4096) in order to allow a non-blocking + * write into the pipe. + */ +#ifndef PIPE_BUF +#define PIPE_BUF 4096 +#endif +#define SWRAP_MAX_PASSED_FDS ((size_t)6) +#define SWRAP_MAX_PASSED_SOCKET_INFO SWRAP_MAX_PASSED_FDS +struct swrap_unix_scm_rights_payload { + uint8_t num_idxs; + int8_t idxs[SWRAP_MAX_PASSED_FDS]; + struct socket_info infos[SWRAP_MAX_PASSED_SOCKET_INFO]; +}; +struct swrap_unix_scm_rights { + uint64_t magic; + char package_name[sizeof(SOCKET_WRAPPER_PACKAGE)]; + char package_version[sizeof(SOCKET_WRAPPER_VERSION)]; + uint32_t full_size; + uint32_t payload_size; + struct swrap_unix_scm_rights_payload payload; +}; + +static void swrap_dec_fd_passed_array(size_t num, struct socket_info **array) +{ + int saved_errno = errno; + size_t i; + + for (i = 0; i < num; i++) { + struct socket_info *si = array[i]; + if (si == NULL) { + continue; } - if (msg->msg_iovlen == 0) { - break; + SWRAP_LOCK_SI(si); + swrap_dec_refcount(si); + if (si->fd_passed > 0) { + si->fd_passed -= 1; } + SWRAP_UNLOCK_SI(si); + array[i] = NULL; + } - mtu = socket_wrapper_mtu(); - for (i = 0; i < (size_t)msg->msg_iovlen; i++) { - size_t nlen; - nlen = len + msg->msg_iov[i].iov_len; - if (nlen < len) { - /* overflow */ - errno = EMSGSIZE; - goto out; + errno = saved_errno; +} + +static void swrap_undo_si_idx_array(size_t num, int *array) +{ + int saved_errno = errno; + size_t i; + + swrap_mutex_lock(&first_free_mutex); + + for (i = 0; i < num; i++) { + struct socket_info *si = NULL; + + if (array[i] == -1) { + continue; + } + + si = swrap_get_socket_info(array[i]); + if (si == NULL) { + continue; + } + + SWRAP_LOCK_SI(si); + swrap_dec_refcount(si); + SWRAP_UNLOCK_SI(si); + + swrap_set_next_free(si, first_free); + first_free = array[i]; + array[i] = -1; + } + + swrap_mutex_unlock(&first_free_mutex); + errno = saved_errno; +} + +static void swrap_close_fd_array(size_t num, const int *array) +{ + int saved_errno = errno; + size_t i; + + for (i = 0; i < num; i++) { + if (array[i] == -1) { + continue; + } + libc_close(array[i]); + } + + errno = saved_errno; +} + +union __swrap_fds { + const uint8_t *p; + int *fds; +}; + +union __swrap_cmsghdr { + const uint8_t *p; + struct cmsghdr *cmsg; +}; + +static int swrap_sendmsg_unix_scm_rights(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space, + int *scm_rights_pipe_fd) +{ + struct swrap_unix_scm_rights info; + struct swrap_unix_scm_rights_payload *payload = NULL; + int si_idx_array[SWRAP_MAX_PASSED_FDS]; + struct socket_info *si_array[SWRAP_MAX_PASSED_FDS] = { NULL, }; + size_t info_idx = 0; + size_t size_fds_in; + size_t num_fds_in; + union __swrap_fds __fds_in = { .p = NULL, }; + const int *fds_in = NULL; + size_t num_fds_out; + size_t size_fds_out; + union __swrap_fds __fds_out = { .p = NULL, }; + int *fds_out = NULL; + size_t cmsg_len; + size_t cmsg_space; + size_t new_cm_data_space; + union __swrap_cmsghdr __new_cmsg = { .p = NULL, }; + struct cmsghdr *new_cmsg = NULL; + uint8_t *p = NULL; + size_t i; + int pipefd[2] = { -1, -1 }; + int rc; + ssize_t sret; + + /* + * We pass this a buffer to the kernel make sure any padding + * is also cleared. + */ + ZERO_STRUCT(info); + info.magic = swrap_unix_scm_right_magic; + memcpy(info.package_name, + SOCKET_WRAPPER_PACKAGE, + sizeof(info.package_name)); + memcpy(info.package_version, + SOCKET_WRAPPER_VERSION, + sizeof(info.package_version)); + info.full_size = sizeof(info); + info.payload_size = sizeof(info.payload); + payload = &info.payload; + + if (*scm_rights_pipe_fd != -1) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "Two SCM_RIGHTS headers are not supported by socket_wrapper"); + errno = EINVAL; + return -1; + } + + if (cmsg->cmsg_len < CMSG_LEN(0)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu < CMSG_LEN(0)=%zu", + (size_t)cmsg->cmsg_len, + CMSG_LEN(0)); + errno = EINVAL; + return -1; + } + size_fds_in = cmsg->cmsg_len - CMSG_LEN(0); + if ((size_fds_in % sizeof(int)) != 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu => (size_fds_in=%zu %% sizeof(int)=%zu) != 0", + (size_t)cmsg->cmsg_len, + size_fds_in, + sizeof(int)); + errno = EINVAL; + return -1; + } + num_fds_in = size_fds_in / sizeof(int); + if (num_fds_in > SWRAP_MAX_PASSED_FDS) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu > " + "SWRAP_MAX_PASSED_FDS(%zu)", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in, + SWRAP_MAX_PASSED_FDS); + errno = EINVAL; + return -1; + } + if (num_fds_in == 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in); + errno = EINVAL; + return -1; + } + __fds_in.p = CMSG_DATA(cmsg); + fds_in = __fds_in.fds; + num_fds_out = num_fds_in + 1; + + SWRAP_LOG(SWRAP_LOG_TRACE, + "num_fds_in=%zu num_fds_out=%zu", + num_fds_in, num_fds_out); + + size_fds_out = sizeof(int) * num_fds_out; + cmsg_len = CMSG_LEN(size_fds_out); + cmsg_space = CMSG_SPACE(size_fds_out); + + new_cm_data_space = *cm_data_space + cmsg_space; + + p = realloc((*cm_data), new_cm_data_space); + if (p == NULL) { + return -1; + } + (*cm_data) = p; + p = (*cm_data) + (*cm_data_space); + memset(p, 0, cmsg_space); + __new_cmsg.p = p; + new_cmsg = __new_cmsg.cmsg; + *new_cmsg = *cmsg; + __fds_out.p = CMSG_DATA(new_cmsg); + fds_out = __fds_out.fds; + memcpy(fds_out, fds_in, size_fds_in); + new_cmsg->cmsg_len = cmsg->cmsg_len; + + for (i = 0; i < num_fds_in; i++) { + size_t j; + + payload->idxs[i] = -1; + payload->num_idxs++; + + si_idx_array[i] = find_socket_info_index(fds_in[i]); + if (si_idx_array[i] == -1) { + continue; + } + + si_array[i] = swrap_get_socket_info(si_idx_array[i]); + if (si_array[i] == NULL) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_in[%zu]=%d si_idx_array[%zu]=%d missing!", + i, fds_in[i], i, si_idx_array[i]); + errno = EINVAL; + return -1; + } + + for (j = 0; j < i; j++) { + if (si_array[j] == si_array[i]) { + payload->idxs[i] = payload->idxs[j]; + break; + } + } + if (payload->idxs[i] == -1) { + if (info_idx >= SWRAP_MAX_PASSED_SOCKET_INFO) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_in[%zu]=%d,si_idx_array[%zu]=%d: " + "info_idx=%zu >= SWRAP_MAX_PASSED_FDS(%zu)!", + i, fds_in[i], i, si_idx_array[i], + info_idx, + SWRAP_MAX_PASSED_SOCKET_INFO); + errno = EINVAL; + return -1; + } + payload->idxs[i] = info_idx; + info_idx += 1; + continue; + } + } + + for (i = 0; i < num_fds_in; i++) { + struct socket_info *si = si_array[i]; + + if (si == NULL) { + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_in[%zu]=%d not an inet socket", + i, fds_in[i]); + continue; + } + + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_in[%zu]=%d si_idx_array[%zu]=%d " + "passing as info.idxs[%zu]=%d!", + i, fds_in[i], + i, si_idx_array[i], + i, payload->idxs[i]); + + SWRAP_LOCK_SI(si); + si->fd_passed += 1; + payload->infos[payload->idxs[i]] = *si; + payload->infos[payload->idxs[i]].fd_passed = 0; + SWRAP_UNLOCK_SI(si); + } + + rc = pipe(pipefd); + if (rc == -1) { + int saved_errno = errno; + SWRAP_LOG(SWRAP_LOG_ERROR, + "pipe() failed - %d %s", + saved_errno, + strerror(saved_errno)); + swrap_dec_fd_passed_array(num_fds_in, si_array); + errno = saved_errno; + return -1; + } + + sret = write(pipefd[1], &info, sizeof(info)); + if (sret != sizeof(info)) { + int saved_errno = errno; + if (sret != -1) { + saved_errno = EINVAL; + } + SWRAP_LOG(SWRAP_LOG_ERROR, + "write() failed - sret=%zd - %d %s", + sret, saved_errno, + strerror(saved_errno)); + swrap_dec_fd_passed_array(num_fds_in, si_array); + libc_close(pipefd[1]); + libc_close(pipefd[0]); + errno = saved_errno; + return -1; + } + libc_close(pipefd[1]); + + /* + * Add the pipe read end to the end of the passed fd array + */ + fds_out[num_fds_in] = pipefd[0]; + new_cmsg->cmsg_len = cmsg_len; + + /* we're done ... */ + *scm_rights_pipe_fd = pipefd[0]; + *cm_data_space = new_cm_data_space; + + return 0; +} + +static int swrap_sendmsg_unix_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space, + int *scm_rights_pipe_fd) +{ + int rc = -1; + + switch (cmsg->cmsg_type) { + case SCM_RIGHTS: + rc = swrap_sendmsg_unix_scm_rights(cmsg, + cm_data, + cm_data_space, + scm_rights_pipe_fd); + break; + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + cm_data, + cm_data_space); + break; + } + + return rc; +} + +static int swrap_recvmsg_unix_scm_rights(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) +{ + int scm_rights_pipe_fd = -1; + struct swrap_unix_scm_rights info; + struct swrap_unix_scm_rights_payload *payload = NULL; + int si_idx_array[SWRAP_MAX_PASSED_FDS]; + size_t size_fds_in; + size_t num_fds_in; + union __swrap_fds __fds_in = { .p = NULL, }; + const int *fds_in = NULL; + size_t num_fds_out; + size_t size_fds_out; + union __swrap_fds __fds_out = { .p = NULL, }; + int *fds_out = NULL; + size_t cmsg_len; + size_t cmsg_space; + size_t new_cm_data_space; + union __swrap_cmsghdr __new_cmsg = { .p = NULL, }; + struct cmsghdr *new_cmsg = NULL; + uint8_t *p = NULL; + ssize_t sret; + size_t i; + int cmp; + + if (cmsg->cmsg_len < CMSG_LEN(0)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu < CMSG_LEN(0)=%zu", + (size_t)cmsg->cmsg_len, + CMSG_LEN(0)); + errno = EINVAL; + return -1; + } + size_fds_in = cmsg->cmsg_len - CMSG_LEN(0); + if ((size_fds_in % sizeof(int)) != 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu => (size_fds_in=%zu %% sizeof(int)=%zu) != 0", + (size_t)cmsg->cmsg_len, + size_fds_in, + sizeof(int)); + errno = EINVAL; + return -1; + } + num_fds_in = size_fds_in / sizeof(int); + if (num_fds_in > (SWRAP_MAX_PASSED_FDS + 1)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu > SWRAP_MAX_PASSED_FDS+1(%zu)", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in, + SWRAP_MAX_PASSED_FDS+1); + errno = EINVAL; + return -1; + } + if (num_fds_in <= 1) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "cmsg->cmsg_len=%zu,size_fds_in=%zu => " + "num_fds_in=%zu", + (size_t)cmsg->cmsg_len, + size_fds_in, + num_fds_in); + errno = EINVAL; + return -1; + } + __fds_in.p = CMSG_DATA(cmsg); + fds_in = __fds_in.fds; + num_fds_out = num_fds_in - 1; + + SWRAP_LOG(SWRAP_LOG_TRACE, + "num_fds_in=%zu num_fds_out=%zu", + num_fds_in, num_fds_out); + + for (i = 0; i < num_fds_in; i++) { + /* Check if we have a stale fd and remove it */ + swrap_remove_stale(fds_in[i]); + } + + scm_rights_pipe_fd = fds_in[num_fds_out]; + size_fds_out = sizeof(int) * num_fds_out; + cmsg_len = CMSG_LEN(size_fds_out); + cmsg_space = CMSG_SPACE(size_fds_out); + + new_cm_data_space = *cm_data_space + cmsg_space; + + p = realloc((*cm_data), new_cm_data_space); + if (p == NULL) { + swrap_close_fd_array(num_fds_in, fds_in); + return -1; + } + (*cm_data) = p; + p = (*cm_data) + (*cm_data_space); + memset(p, 0, cmsg_space); + __new_cmsg.p = p; + new_cmsg = __new_cmsg.cmsg; + *new_cmsg = *cmsg; + __fds_out.p = CMSG_DATA(new_cmsg); + fds_out = __fds_out.fds; + memcpy(fds_out, fds_in, size_fds_out); + new_cmsg->cmsg_len = cmsg_len; + + sret = read(scm_rights_pipe_fd, &info, sizeof(info)); + if (sret != sizeof(info)) { + int saved_errno = errno; + if (sret != -1) { + saved_errno = EINVAL; + } + SWRAP_LOG(SWRAP_LOG_ERROR, + "read() failed - sret=%zd - %d %s", + sret, saved_errno, + strerror(saved_errno)); + swrap_close_fd_array(num_fds_in, fds_in); + errno = saved_errno; + return -1; + } + libc_close(scm_rights_pipe_fd); + payload = &info.payload; + + if (info.magic != swrap_unix_scm_right_magic) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "info.magic=0x%llx != swrap_unix_scm_right_magic=0x%llx", + (unsigned long long)info.magic, + (unsigned long long)swrap_unix_scm_right_magic); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + cmp = memcmp(info.package_name, + SOCKET_WRAPPER_PACKAGE, + sizeof(info.package_name)); + if (cmp != 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "info.package_name='%.*s' != '%s'", + (int)sizeof(info.package_name), + info.package_name, + SOCKET_WRAPPER_PACKAGE); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + cmp = memcmp(info.package_version, + SOCKET_WRAPPER_VERSION, + sizeof(info.package_version)); + if (cmp != 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "info.package_version='%.*s' != '%s'", + (int)sizeof(info.package_version), + info.package_version, + SOCKET_WRAPPER_VERSION); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + if (info.full_size != sizeof(info)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "info.full_size=%zu != sizeof(info)=%zu", + (size_t)info.full_size, + sizeof(info)); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + if (info.payload_size != sizeof(info.payload)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "info.payload_size=%zu != sizeof(info.payload)=%zu", + (size_t)info.payload_size, + sizeof(info.payload)); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + if (payload->num_idxs != num_fds_out) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "info.num_idxs=%u != num_fds_out=%zu", + payload->num_idxs, num_fds_out); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + for (i = 0; i < num_fds_out; i++) { + size_t j; + + si_idx_array[i] = -1; + + if (payload->idxs[i] == -1) { + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_out[%zu]=%d not an inet socket", + i, fds_out[i]); + continue; + } + + if (payload->idxs[i] < 0) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_out[%zu]=%d info.idxs[%zu]=%d < 0!", + i, fds_out[i], i, payload->idxs[i]); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + if (payload->idxs[i] >= payload->num_idxs) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "fds_out[%zu]=%d info.idxs[%zu]=%d >= %u!", + i, fds_out[i], i, payload->idxs[i], + payload->num_idxs); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EINVAL; + return -1; + } + + if ((size_t)fds_out[i] >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + fds_out[i]); + swrap_close_fd_array(num_fds_out, fds_out); + errno = EMFILE; + return -1; + } + + SWRAP_LOG(SWRAP_LOG_TRACE, + "fds_in[%zu]=%d " + "received as info.idxs[%zu]=%d!", + i, fds_out[i], + i, payload->idxs[i]); + + for (j = 0; j < i; j++) { + if (payload->idxs[j] == -1) { + continue; + } + if (payload->idxs[j] == payload->idxs[i]) { + si_idx_array[i] = si_idx_array[j]; + } + } + if (si_idx_array[i] == -1) { + const struct socket_info *si = &payload->infos[payload->idxs[i]]; + + si_idx_array[i] = swrap_add_socket_info(si); + if (si_idx_array[i] == -1) { + int saved_errno = errno; + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + fds_out[i]); + swrap_undo_si_idx_array(i, si_idx_array); + swrap_close_fd_array(num_fds_out, fds_out); + errno = saved_errno; + return -1; + } + SWRAP_LOG(SWRAP_LOG_TRACE, + "Imported %s socket for protocol %s, fd=%d", + si->family == AF_INET ? "IPv4" : "IPv6", + si->type == SOCK_DGRAM ? "UDP" : "TCP", + fds_out[i]); + } + } + + for (i = 0; i < num_fds_out; i++) { + if (si_idx_array[i] == -1) { + continue; + } + set_socket_info_index(fds_out[i], si_idx_array[i]); + } + + /* we're done ... */ + *cm_data_space = new_cm_data_space; + + return 0; +} + +static int swrap_recvmsg_unix_sol_socket(const struct cmsghdr *cmsg, + uint8_t **cm_data, + size_t *cm_data_space) +{ + int rc = -1; + + switch (cmsg->cmsg_type) { + case SCM_RIGHTS: + rc = swrap_recvmsg_unix_scm_rights(cmsg, + cm_data, + cm_data_space); + break; + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + cm_data, + cm_data_space); + break; + } + + return rc; +} + +#endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + +static int swrap_sendmsg_before_unix(const struct msghdr *_msg_in, + struct msghdr *msg_tmp, + int *scm_rights_pipe_fd) +{ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + struct msghdr *msg_in = discard_const_p(struct msghdr, _msg_in); + struct cmsghdr *cmsg = NULL; + uint8_t *cm_data = NULL; + size_t cm_data_space = 0; + int rc = -1; + + *msg_tmp = *msg_in; + *scm_rights_pipe_fd = -1; + + /* Nothing to do */ + if (msg_in->msg_controllen == 0 || msg_in->msg_control == NULL) { + return 0; + } + + for (cmsg = CMSG_FIRSTHDR(msg_in); + cmsg != NULL; + cmsg = CMSG_NXTHDR(msg_in, cmsg)) { + switch (cmsg->cmsg_level) { + case SOL_SOCKET: + rc = swrap_sendmsg_unix_sol_socket(cmsg, + &cm_data, + &cm_data_space, + scm_rights_pipe_fd); + break; + + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + &cm_data, + &cm_data_space); + break; + } + if (rc < 0) { + int saved_errno = errno; + SAFE_FREE(cm_data); + errno = saved_errno; + return rc; + } + } + + msg_tmp->msg_controllen = cm_data_space; + msg_tmp->msg_control = cm_data; + + return 0; +#else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + *msg_tmp = *_msg_in; + return 0; +#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */ +} + +static ssize_t swrap_sendmsg_after_unix(struct msghdr *msg_tmp, + ssize_t ret, + int scm_rights_pipe_fd) +{ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + int saved_errno = errno; + SAFE_FREE(msg_tmp->msg_control); + if (scm_rights_pipe_fd != -1) { + libc_close(scm_rights_pipe_fd); + } + errno = saved_errno; +#endif /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + return ret; +} + +static int swrap_recvmsg_before_unix(struct msghdr *msg_in, + struct msghdr *msg_tmp, + uint8_t **tmp_control) +{ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + const size_t cm_extra_space = CMSG_SPACE(sizeof(int)); + uint8_t *cm_data = NULL; + size_t cm_data_space = 0; + + *msg_tmp = *msg_in; + *tmp_control = NULL; + + SWRAP_LOG(SWRAP_LOG_TRACE, + "msg_in->msg_controllen=%zu", + (size_t)msg_in->msg_controllen); + + /* Nothing to do */ + if (msg_in->msg_controllen == 0 || msg_in->msg_control == NULL) { + return 0; + } + + /* + * We need to give the kernel a bit more space in order + * recv the pipe fd, added by swrap_sendmsg_before_unix()). + * swrap_recvmsg_after_unix() will hide it again. + */ + cm_data_space = msg_in->msg_controllen; + if (cm_data_space < (INT32_MAX - cm_extra_space)) { + cm_data_space += cm_extra_space; + } + cm_data = calloc(1, cm_data_space); + if (cm_data == NULL) { + return -1; + } + + msg_tmp->msg_controllen = cm_data_space; + msg_tmp->msg_control = cm_data; + *tmp_control = cm_data; + + SWRAP_LOG(SWRAP_LOG_TRACE, + "msg_tmp->msg_controllen=%zu", + (size_t)msg_tmp->msg_controllen); + return 0; +#else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + *msg_tmp = *msg_in; + *tmp_control = NULL; + return 0; +#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */ +} + +static ssize_t swrap_recvmsg_after_unix(struct msghdr *msg_tmp, + uint8_t **tmp_control, + struct msghdr *msg_out, + ssize_t ret) +{ +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + struct cmsghdr *cmsg = NULL; + uint8_t *cm_data = NULL; + size_t cm_data_space = 0; + int rc = -1; + + if (ret < 0) { + int saved_errno = errno; + SWRAP_LOG(SWRAP_LOG_TRACE, "ret=%zd - %d - %s", ret, + saved_errno, strerror(saved_errno)); + SAFE_FREE(*tmp_control); + /* msg_out should not be touched on error */ + errno = saved_errno; + return ret; + } + + SWRAP_LOG(SWRAP_LOG_TRACE, + "msg_tmp->msg_controllen=%zu", + (size_t)msg_tmp->msg_controllen); + + /* Nothing to do */ + if (msg_tmp->msg_controllen == 0 || msg_tmp->msg_control == NULL) { + int saved_errno = errno; + *msg_out = *msg_tmp; + SAFE_FREE(*tmp_control); + errno = saved_errno; + return ret; + } + + for (cmsg = CMSG_FIRSTHDR(msg_tmp); + cmsg != NULL; + cmsg = CMSG_NXTHDR(msg_tmp, cmsg)) { + switch (cmsg->cmsg_level) { + case SOL_SOCKET: + rc = swrap_recvmsg_unix_sol_socket(cmsg, + &cm_data, + &cm_data_space); + break; + + default: + rc = swrap_sendmsg_copy_cmsg(cmsg, + &cm_data, + &cm_data_space); + break; + } + if (rc < 0) { + int saved_errno = errno; + SAFE_FREE(cm_data); + SAFE_FREE(*tmp_control); + errno = saved_errno; + return rc; + } + } + + /* + * msg_tmp->msg_control (*tmp_control) was created by + * swrap_recvmsg_before_unix() and msg_out->msg_control + * is still the buffer of the caller. + */ + msg_tmp->msg_control = msg_out->msg_control; + msg_tmp->msg_controllen = msg_out->msg_controllen; + *msg_out = *msg_tmp; + + cm_data_space = MIN(cm_data_space, msg_out->msg_controllen); + memcpy(msg_out->msg_control, cm_data, cm_data_space); + msg_out->msg_controllen = cm_data_space; + SAFE_FREE(cm_data); + SAFE_FREE(*tmp_control); + + SWRAP_LOG(SWRAP_LOG_TRACE, + "msg_out->msg_controllen=%zu", + (size_t)msg_out->msg_controllen); + return ret; +#else /* HAVE_STRUCT_MSGHDR_MSG_CONTROL */ + int saved_errno = errno; + *msg_out = *msg_tmp; + SAFE_FREE(*tmp_control); + errno = saved_errno; + return ret; +#endif /* ! HAVE_STRUCT_MSGHDR_MSG_CONTROL */ +} + +static ssize_t swrap_sendmsg_before(int fd, + struct socket_info *si, + struct msghdr *msg, + struct iovec *tmp_iov, + struct sockaddr_un *tmp_un, + const struct sockaddr_un **to_un, + const struct sockaddr **to, + int *bcast) +{ + size_t i, len = 0; + ssize_t ret = -1; + + if (to_un) { + *to_un = NULL; + } + if (to) { + *to = NULL; + } + if (bcast) { + *bcast = 0; + } + + SWRAP_LOCK_SI(si); + + switch (si->type) { + case SOCK_STREAM: { + unsigned long mtu; + + if (!si->connected) { + errno = ENOTCONN; + goto out; + } + + if (msg->msg_iovlen == 0) { + break; + } + + mtu = socket_wrapper_mtu(); + for (i = 0; i < (size_t)msg->msg_iovlen; i++) { + size_t nlen; + nlen = len + msg->msg_iov[i].iov_len; + if (nlen < len) { + /* overflow */ + errno = EMSGSIZE; + goto out; } if (nlen > mtu) { break; @@ -5215,28 +6242,6 @@ static ssize_t swrap_sendmsg_before(int fd, goto out; } -#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL - if (msg->msg_controllen > 0 && msg->msg_control != NULL) { - uint8_t *cmbuf = NULL; - size_t cmlen = 0; - - ret = swrap_sendmsg_filter_cmsghdr(msg, &cmbuf, &cmlen); - if (ret < 0) { - free(cmbuf); - goto out; - } - - if (cmlen == 0) { - msg->msg_controllen = 0; - msg->msg_control = NULL; - } else if (cmlen < msg->msg_controllen && cmbuf != NULL) { - memcpy(msg->msg_control, cmbuf, cmlen); - msg->msg_controllen = cmlen; - } - free(cmbuf); - } -#endif - ret = 0; out: SWRAP_UNLOCK_SI(si); @@ -6003,7 +7008,13 @@ static ssize_t swrap_recvmsg(int s, struct msghdr *omsg, int flags) si = find_socket_info(s); if (si == NULL) { - return libc_recvmsg(s, omsg, flags); + uint8_t *tmp_control = NULL; + rc = swrap_recvmsg_before_unix(omsg, &msg, &tmp_control); + if (rc < 0) { + return rc; + } + ret = libc_recvmsg(s, &msg, flags); + return swrap_recvmsg_after_unix(&msg, &tmp_control, omsg, ret); } tmp.iov_base = NULL; @@ -6126,7 +7137,15 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) int bcast = 0; if (!si) { - return libc_sendmsg(s, omsg, flags); + int scm_rights_pipe_fd = -1; + + rc = swrap_sendmsg_before_unix(omsg, &msg, + &scm_rights_pipe_fd); + if (rc < 0) { + return rc; + } + ret = libc_sendmsg(s, &msg, flags); + return swrap_sendmsg_after_unix(&msg, ret, scm_rights_pipe_fd); } ZERO_STRUCT(un_addr); @@ -6148,20 +7167,32 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) SWRAP_UNLOCK_SI(si); #ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL - if (msg.msg_controllen > 0 && msg.msg_control != NULL) { - /* omsg is a const so use a local buffer for modifications */ - uint8_t cmbuf[omsg->msg_controllen]; + if (omsg != NULL && omsg->msg_controllen > 0 && omsg->msg_control != NULL) { + uint8_t *cmbuf = NULL; + size_t cmlen = 0; - memcpy(cmbuf, omsg->msg_control, omsg->msg_controllen); + rc = swrap_sendmsg_filter_cmsghdr(omsg, &cmbuf, &cmlen); + if (rc < 0) { + return rc; + } - msg.msg_control = cmbuf; /* ancillary data, see below */ - msg.msg_controllen = omsg->msg_controllen; /* ancillary data buffer len */ + if (cmlen == 0) { + msg.msg_controllen = 0; + msg.msg_control = NULL; + } else { + msg.msg_control = cmbuf; + msg.msg_controllen = cmlen; + } } msg.msg_flags = omsg->msg_flags; /* flags on received message */ #endif - rc = swrap_sendmsg_before(s, si, &msg, &tmp, &un_addr, &to_un, &to, &bcast); if (rc < 0) { + int saved_errno = errno; +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + SAFE_FREE(msg.msg_control); +#endif + errno = saved_errno; return -1; } @@ -6187,6 +7218,11 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) /* we capture it as one single packet */ buf = (uint8_t *)malloc(remain); if (!buf) { + int saved_errno = errno; +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + SAFE_FREE(msg.msg_control); +#endif + errno = saved_errno; return -1; } @@ -6203,7 +7239,12 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) swrap_dir = socket_wrapper_dir(); if (swrap_dir == NULL) { - free(buf); + int saved_errno = errno; +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + SAFE_FREE(msg.msg_control); +#endif + SAFE_FREE(buf); + errno = saved_errno; return -1; } @@ -6234,6 +7275,14 @@ static ssize_t swrap_sendmsg(int s, const struct msghdr *omsg, int flags) swrap_sendmsg_after(s, si, &msg, to, ret); +#ifdef HAVE_STRUCT_MSGHDR_MSG_CONTROL + { + int saved_errno = errno; + SAFE_FREE(msg.msg_control); + errno = saved_errno; + } +#endif + return ret; } @@ -6385,6 +7434,10 @@ static int swrap_close(int fd) goto out; } + if (si->fd_passed) { + goto set_next_free; + } + if (si->myname.sa_socklen > 0 && si->peername.sa_socklen > 0) { swrap_pcap_dump_packet(si, NULL, SWRAP_CLOSE_SEND, NULL, 0); } @@ -6398,6 +7451,7 @@ static int swrap_close(int fd) unlink(si->un_addr.sun_path); } +set_next_free: swrap_set_next_free(si, first_free); first_free = si_index; @@ -6437,6 +7491,17 @@ static int swrap_dup(int fd) return -1; } + if ((size_t)dup_fd >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + dup_fd); + libc_close(dup_fd); + errno = EMFILE; + return -1; + } + SWRAP_LOCK_SI(si); swrap_inc_refcount(si); @@ -6482,6 +7547,16 @@ static int swrap_dup2(int fd, int newfd) return newfd; } + if ((size_t)newfd >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + newfd); + errno = EMFILE; + return -1; + } + if (find_socket_info(newfd)) { /* dup2() does an implicit close of newfd, which we * need to emulate */ @@ -6539,14 +7614,26 @@ static int swrap_vfcntl(int fd, int cmd, va_list va) return -1; } + /* Make sure we don't have an entry for the fd */ + swrap_remove_stale(dup_fd); + + if ((size_t)dup_fd >= socket_fds_max) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "The max socket index limit of %zu has been reached, " + "trying to add %d", + socket_fds_max, + dup_fd); + libc_close(dup_fd); + errno = EMFILE; + return -1; + } + SWRAP_LOCK_SI(si); swrap_inc_refcount(si); SWRAP_UNLOCK_SI(si); - /* Make sure we don't have an entry for the fd */ - swrap_remove_stale(dup_fd); set_socket_info_index(dup_fd, idx); @@ -6632,7 +7719,7 @@ static void swrap_thread_parent(void) static void swrap_thread_child(void) { - SWRAP_UNLOCK_ALL; + SWRAP_REINIT_ALL; } /**************************** @@ -6640,7 +7727,20 @@ static void swrap_thread_child(void) ***************************/ void swrap_constructor(void) { - int ret; + if (PIPE_BUF < sizeof(struct swrap_unix_scm_rights)) { + SWRAP_LOG(SWRAP_LOG_ERROR, + "PIPE_BUF=%zu < " + "sizeof(struct swrap_unix_scm_rights)=%zu\n" + "sizeof(struct swrap_unix_scm_rights_payload)=%zu " + "sizeof(struct socket_info)=%zu", + (size_t)PIPE_BUF, + sizeof(struct swrap_unix_scm_rights), + sizeof(struct swrap_unix_scm_rights_payload), + sizeof(struct socket_info)); + exit(-1); + } + + SWRAP_REINIT_ALL; /* * If we hold a lock and the application forks, then the child @@ -6650,27 +7750,6 @@ void swrap_constructor(void) pthread_atfork(&swrap_thread_prepare, &swrap_thread_parent, &swrap_thread_child); - - ret = socket_wrapper_init_mutex(&sockets_mutex); - if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to initialize pthread mutex"); - exit(-1); - } - - ret = socket_wrapper_init_mutex(&socket_reset_mutex); - if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to initialize pthread mutex"); - exit(-1); - } - - ret = socket_wrapper_init_mutex(&first_free_mutex); - if (ret != 0) { - SWRAP_LOG(SWRAP_LOG_ERROR, - "Failed to initialize pthread mutex"); - exit(-1); - } } /**************************** diff --git a/third_party/socket_wrapper/wscript b/third_party/socket_wrapper/wscript index 06eb3cd79ec6..5f237d30057f 100644 --- a/third_party/socket_wrapper/wscript +++ b/third_party/socket_wrapper/wscript @@ -2,7 +2,7 @@ import os -VERSION="1.2.4" +VERSION="1.3.2" def configure(conf): if conf.CHECK_SOCKET_WRAPPER(): @@ -94,6 +94,11 @@ def build(bld): # breaks preloading! bld.SAMBA_LIBRARY('socket_wrapper', source='socket_wrapper.c', + cflags='-D%s="%s" -D%s="%s"' % ( + "SOCKET_WRAPPER_PACKAGE", + "samba_socket_wrapper", + "SOCKET_WRAPPER_VERSION", + VERSION), deps='dl pthread tirpc', install=False, realname='libsocket-wrapper.so') -- 2.25.1