diff --git a/include/proxysql_utils.h b/include/proxysql_utils.h index df2a467be..b3a550df6 100644 --- a/include/proxysql_utils.h +++ b/include/proxysql_utils.h @@ -263,7 +263,21 @@ inline void replace_checksum_zeros(char* checksum) { */ std::string get_checksum_from_hash(uint64_t hash); -void close_all_non_term_fd(std::vector excludeFDs); +/** + * @brief Closes all open file descriptors except stdin (0), stdout (1), stderr (2), and a specified exclusion list + * + * This function is typically called after fork() in the child process before exec() to ensure that + * the child process does not inherit unintended file descriptors from the parent. + * + * CRITICAL: This function is designed to be called between fork() and execve() in the child process. + * To avoid deadlocks in multi-threaded programs, it must NOT allocate on the heap. + * + * @param excludeFDs Vector of file descriptors to preserve (in addition to 0, 1, 2) + * Passed by const reference to avoid heap allocation during copy. + * + * Thread-safety: Safe to call in child process after fork() before execve() + */ +void close_all_non_term_fd(const std::vector& excludeFDs); /** * @brief Returns the expected error for query 'SELECT $$'. diff --git a/lib/proxysql_utils.cpp b/lib/proxysql_utils.cpp index deb3bd109..15e9ce5fd 100644 --- a/lib/proxysql_utils.cpp +++ b/lib/proxysql_utils.cpp @@ -587,23 +587,31 @@ std::string get_checksum_from_hash(uint64_t hash) { * @param excludeFDs Vector of file descriptors to preserve (in addition to 0, 1, 2) * @return void */ -void close_all_non_term_fd(std::vector excludeFDs) { +void close_all_non_term_fd(const std::vector& excludeFDs) { // Try close_range() syscall first (Linux 5.9+) - most efficient and safe // We use syscall directly with runtime detection to avoid hard dependency on kernel version // close_range() can ONLY be used when excludeFDs is empty, because it closes all fds >= 3 #ifdef __NR_close_range if (excludeFDs.empty()) { static int close_range_available = -1; // -1 = unknown, 0 = not available, 1 = available - if (close_range_available == -1) { - // First call: check if close_range is available - long ret = syscall(__NR_close_range, 3, ~0U, 0); - close_range_available = (ret == 0 || errno != ENOSYS) ? 1 : 0; - } if (close_range_available == 1) { // close_range is available, use it to close all fds >= 3 syscall(__NR_close_range, 3, ~0U, 0); return; } + if (close_range_available == -1) { + // First call: check if close_range is available + long ret = syscall(__NR_close_range, 3, ~0U, 0); + if (ret == 0) { + close_range_available = 1; + return; + } + // Only cache as "not available" on ENOSYS + // For other errors (EBADF, EINVAL, etc.), don't cache - might be transient + if (errno == ENOSYS) { + close_range_available = 0; + } + } } #endif @@ -642,11 +650,14 @@ void close_all_non_term_fd(std::vector excludeFDs) { struct rlimit nlimit; int rc = getrlimit(RLIMIT_NOFILE, &nlimit); if (rc == 0) { - for (unsigned int fd = 3; fd < nlimit.rlim_cur; fd++) { + // Use rlim_t for the loop variable to avoid infinite loop when rlim_cur > UINT_MAX + // Cap at INT_MAX since file descriptors are signed ints + for (rlim_t fd_rlim = 3; fd_rlim < nlimit.rlim_cur && fd_rlim <= INT_MAX; fd_rlim++) { + int fd = (int)fd_rlim; // Check if fd is in exclusion list bool exclude = false; for (size_t i = 0; i < excludeFDs.size(); i++) { - if (excludeFDs[i] == (int)fd) { + if (excludeFDs[i] == fd) { exclude = true; break; }