* eventdns and libevent.) */
static int n_sockets_open = 0;
+/** Mutex to protect open_sockets, max_socket, and n_sockets_open. */
+static tor_mutex_t *socket_accounting_mutex = NULL;
+
+static INLINE void
+socket_accounting_lock(void)
+{
+ if (PREDICT_UNLIKELY(!socket_accounting_mutex))
+ socket_accounting_mutex = tor_mutex_new();
+ tor_mutex_acquire(socket_accounting_mutex);
+}
+
+static INLINE void
+socket_accounting_unlock(void)
+{
+ tor_mutex_release(socket_accounting_mutex);
+}
+
/** As close(), but guaranteed to work for sockets across platforms (including
* Windows, where close()ing a socket doesn't work. Returns 0 on success, -1
* on failure. */
tor_close_socket(int s)
{
int r = 0;
-#ifdef DEBUG_SOCKET_COUNTING
- if (s > max_socket || ! bitarray_is_set(open_sockets, s)) {
- log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_"
- "socket(), or that was already closed or something.", s);
- } else {
- tor_assert(open_sockets && s <= max_socket);
- bitarray_clear(open_sockets, s);
- }
-#endif
+
/* On Windows, you have to call close() on fds returned by open(),
* and closesocket() on fds returned by socket(). On Unix, everything
* gets close()'d. We abstract this difference by always using
#else
r = close(s);
#endif
+
+ socket_accounting_lock();
+#ifdef DEBUG_SOCKET_COUNTING
+ if (s > max_socket || ! bitarray_is_set(open_sockets, s)) {
+ log_warn(LD_BUG, "Closing a socket (%d) that wasn't returned by tor_open_"
+ "socket(), or that was already closed or something.", s);
+ } else {
+ tor_assert(open_sockets && s <= max_socket);
+ bitarray_clear(open_sockets, s);
+ }
+#endif
if (r == 0) {
--n_sockets_open;
} else {
#endif
r = -1;
}
+
if (n_sockets_open < 0)
log_warn(LD_BUG, "Our socket count is below zero: %d. Please submit a "
"bug report.", n_sockets_open);
+ socket_accounting_unlock();
return r;
}
{
int s = socket(domain, type, protocol);
if (s >= 0) {
+ socket_accounting_lock();
++n_sockets_open;
mark_socket_open(s);
+ socket_accounting_unlock();
}
return s;
}
{
int s = accept(sockfd, addr, len);
if (s >= 0) {
+ socket_accounting_lock();
++n_sockets_open;
mark_socket_open(s);
+ socket_accounting_unlock();
}
return s;
}
int
get_n_open_sockets(void)
{
- return n_sockets_open;
+ int n;
+ socket_accounting_lock();
+ n = n_sockets_open;
+ socket_accounting_unlock();
+ return n;
}
/** Turn <b>socket</b> into a nonblocking socket.
int r;
r = socketpair(family, type, protocol, fd);
if (r == 0) {
+ socket_accounting_lock();
if (fd[0] >= 0) {
++n_sockets_open;
mark_socket_open(fd[0]);
++n_sockets_open;
mark_socket_open(fd[1]);
}
+ socket_accounting_unlock();
}
return r < 0 ? -errno : r;
#else