]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-109461: Update logging module lock to use context manager (#109462)
authorDale Collison <92315623+dcollison@users.noreply.github.com>
Wed, 27 Sep 2023 16:26:41 +0000 (17:26 +0100)
committerGitHub <noreply@github.com>
Wed, 27 Sep 2023 16:26:41 +0000 (18:26 +0200)
Co-authored-by: Victor Stinner <vstinner@python.org>
Lib/logging/__init__.py
Lib/logging/config.py
Lib/logging/handlers.py
Lib/multiprocessing/util.py
Lib/test/test_logging.py
Misc/NEWS.d/next/Library/2023-09-15-17-12-53.gh-issue-109461.VNFPTK.rst [new file with mode: 0644]

index 2d228e563094c8d20333114149aecf4565ad7a93..eb7e020d1edfc040439479e04de4bbb64ccc8b12 100644 (file)
@@ -159,12 +159,9 @@ def addLevelName(level, levelName):
 
     This is used when converting levels to text during message formatting.
     """
-    _acquireLock()
-    try:    #unlikely to cause an exception, but you never know...
+    with _lock:
         _levelToName[level] = levelName
         _nameToLevel[levelName] = level
-    finally:
-        _releaseLock()
 
 if hasattr(sys, "_getframe"):
     currentframe = lambda: sys._getframe(1)
@@ -231,25 +228,27 @@ def _checkLevel(level):
 #
 _lock = threading.RLock()
 
-def _acquireLock():
+def _prepareFork():
     """
-    Acquire the module-level lock for serializing access to shared data.
+    Prepare to fork a new child process by acquiring the module-level lock.
 
-    This should be released with _releaseLock().
+    This should be used in conjunction with _afterFork().
     """
-    if _lock:
-        try:
-            _lock.acquire()
-        except BaseException:
-            _lock.release()
-            raise
+    # Wrap the lock acquisition in a try-except to prevent the lock from being
+    # abandoned in the event of an asynchronous exception. See gh-106238.
+    try:
+        _lock.acquire()
+    except BaseException:
+        _lock.release()
+        raise
 
-def _releaseLock():
+def _afterFork():
     """
-    Release the module-level lock acquired by calling _acquireLock().
+    After a new child process has been forked, release the module-level lock.
+
+    This should be used in conjunction with _prepareFork().
     """
-    if _lock:
-        _lock.release()
+    _lock.release()
 
 
 # Prevent a held logging lock from blocking a child from logging.
@@ -264,23 +263,20 @@ else:
     _at_fork_reinit_lock_weakset = weakref.WeakSet()
 
     def _register_at_fork_reinit_lock(instance):
-        _acquireLock()
-        try:
+        with _lock:
             _at_fork_reinit_lock_weakset.add(instance)
-        finally:
-            _releaseLock()
 
     def _after_at_fork_child_reinit_locks():
         for handler in _at_fork_reinit_lock_weakset:
             handler._at_fork_reinit()
 
-        # _acquireLock() was called in the parent before forking.
+        # _prepareFork() was called in the parent before forking.
         # The lock is reinitialized to unlocked state.
         _lock._at_fork_reinit()
 
-    os.register_at_fork(before=_acquireLock,
+    os.register_at_fork(before=_prepareFork,
                         after_in_child=_after_at_fork_child_reinit_locks,
-                        after_in_parent=_releaseLock)
+                        after_in_parent=_afterFork)
 
 
 #---------------------------------------------------------------------------
@@ -883,25 +879,20 @@ def _removeHandlerRef(wr):
     # set to None. It can also be called from another thread. So we need to
     # pre-emptively grab the necessary globals and check if they're None,
     # to prevent race conditions and failures during interpreter shutdown.
-    acquire, release, handlers = _acquireLock, _releaseLock, _handlerList
-    if acquire and release and handlers:
-        acquire()
-        try:
-            handlers.remove(wr)
-        except ValueError:
-            pass
-        finally:
-            release()
+    handlers, lock = _handlerList, _lock
+    if lock and handlers:
+        with lock:
+            try:
+                handlers.remove(wr)
+            except ValueError:
+                pass
 
 def _addHandlerRef(handler):
     """
     Add a handler to the internal cleanup list using a weak reference.
     """
-    _acquireLock()
-    try:
+    with _lock:
         _handlerList.append(weakref.ref(handler, _removeHandlerRef))
-    finally:
-        _releaseLock()
 
 
 def getHandlerByName(name):
@@ -946,15 +937,12 @@ class Handler(Filterer):
         return self._name
 
     def set_name(self, name):
-        _acquireLock()
-        try:
+        with _lock:
             if self._name in _handlers:
                 del _handlers[self._name]
             self._name = name
             if name:
                 _handlers[name] = self
-        finally:
-            _releaseLock()
 
     name = property(get_name, set_name)
 
@@ -1026,11 +1014,8 @@ class Handler(Filterer):
         if isinstance(rv, LogRecord):
             record = rv
         if rv:
-            self.acquire()
-            try:
+            with self.lock:
                 self.emit(record)
-            finally:
-                self.release()
         return rv
 
     def setFormatter(self, fmt):
@@ -1058,13 +1043,10 @@ class Handler(Filterer):
         methods.
         """
         #get the module data lock, as we're updating a shared structure.
-        _acquireLock()
-        try:    #unlikely to raise an exception, but you never know...
+        with _lock:
             self._closed = True
             if self._name and self._name in _handlers:
                 del _handlers[self._name]
-        finally:
-            _releaseLock()
 
     def handleError(self, record):
         """
@@ -1141,12 +1123,9 @@ class StreamHandler(Handler):
         """
         Flushes the stream.
         """
-        self.acquire()
-        try:
+        with self.lock:
             if self.stream and hasattr(self.stream, "flush"):
                 self.stream.flush()
-        finally:
-            self.release()
 
     def emit(self, record):
         """
@@ -1182,12 +1161,9 @@ class StreamHandler(Handler):
             result = None
         else:
             result = self.stream
-            self.acquire()
-            try:
+            with self.lock:
                 self.flush()
                 self.stream = stream
-            finally:
-                self.release()
         return result
 
     def __repr__(self):
@@ -1237,8 +1213,7 @@ class FileHandler(StreamHandler):
         """
         Closes the stream.
         """
-        self.acquire()
-        try:
+        with self.lock:
             try:
                 if self.stream:
                     try:
@@ -1254,8 +1229,6 @@ class FileHandler(StreamHandler):
                 # Also see Issue #42378: we also rely on
                 # self._closed being set to True there
                 StreamHandler.close(self)
-        finally:
-            self.release()
 
     def _open(self):
         """
@@ -1391,8 +1364,7 @@ class Manager(object):
         rv = None
         if not isinstance(name, str):
             raise TypeError('A logger name must be a string')
-        _acquireLock()
-        try:
+        with _lock:
             if name in self.loggerDict:
                 rv = self.loggerDict[name]
                 if isinstance(rv, PlaceHolder):
@@ -1407,8 +1379,6 @@ class Manager(object):
                 rv.manager = self
                 self.loggerDict[name] = rv
                 self._fixupParents(rv)
-        finally:
-            _releaseLock()
         return rv
 
     def setLoggerClass(self, klass):
@@ -1471,12 +1441,11 @@ class Manager(object):
         Called when level changes are made
         """
 
-        _acquireLock()
-        for logger in self.loggerDict.values():
-            if isinstance(logger, Logger):
-                logger._cache.clear()
-        self.root._cache.clear()
-        _releaseLock()
+        with _lock:
+            for logger in self.loggerDict.values():
+                if isinstance(logger, Logger):
+                    logger._cache.clear()
+            self.root._cache.clear()
 
 #---------------------------------------------------------------------------
 #   Logger classes and functions
@@ -1701,23 +1670,17 @@ class Logger(Filterer):
         """
         Add the specified handler to this logger.
         """
-        _acquireLock()
-        try:
+        with _lock:
             if not (hdlr in self.handlers):
                 self.handlers.append(hdlr)
-        finally:
-            _releaseLock()
 
     def removeHandler(self, hdlr):
         """
         Remove the specified handler from this logger.
         """
-        _acquireLock()
-        try:
+        with _lock:
             if hdlr in self.handlers:
                 self.handlers.remove(hdlr)
-        finally:
-            _releaseLock()
 
     def hasHandlers(self):
         """
@@ -1795,16 +1758,13 @@ class Logger(Filterer):
         try:
             return self._cache[level]
         except KeyError:
-            _acquireLock()
-            try:
+            with _lock:
                 if self.manager.disable >= level:
                     is_enabled = self._cache[level] = False
                 else:
                     is_enabled = self._cache[level] = (
                         level >= self.getEffectiveLevel()
                     )
-            finally:
-                _releaseLock()
             return is_enabled
 
     def getChild(self, suffix):
@@ -1834,16 +1794,13 @@ class Logger(Filterer):
             return 1 + logger.name.count('.')
 
         d = self.manager.loggerDict
-        _acquireLock()
-        try:
+        with _lock:
             # exclude PlaceHolders - the last check is to ensure that lower-level
             # descendants aren't returned - if there are placeholders, a logger's
             # parent field might point to a grandparent or ancestor thereof.
             return set(item for item in d.values()
                        if isinstance(item, Logger) and item.parent is self and
                        _hierlevel(item) == 1 + _hierlevel(item.parent))
-        finally:
-            _releaseLock()
 
     def __repr__(self):
         level = getLevelName(self.getEffectiveLevel())
@@ -2102,8 +2059,7 @@ def basicConfig(**kwargs):
     """
     # Add thread safety in case someone mistakenly calls
     # basicConfig() from multiple threads
-    _acquireLock()
-    try:
+    with _lock:
         force = kwargs.pop('force', False)
         encoding = kwargs.pop('encoding', None)
         errors = kwargs.pop('errors', 'backslashreplace')
@@ -2152,8 +2108,6 @@ def basicConfig(**kwargs):
             if kwargs:
                 keys = ', '.join(kwargs.keys())
                 raise ValueError('Unrecognised argument(s): %s' % keys)
-    finally:
-        _releaseLock()
 
 #---------------------------------------------------------------------------
 # Utility functions at module level.
index 41283f4d62726704eb576afeca8bd0b3d42ed266..951bba73913cb30ee768bc6bbd58e4ab179b2c93 100644 (file)
@@ -83,15 +83,12 @@ def fileConfig(fname, defaults=None, disable_existing_loggers=True, encoding=Non
     formatters = _create_formatters(cp)
 
     # critical section
-    logging._acquireLock()
-    try:
+    with logging._lock:
         _clearExistingHandlers()
 
         # Handlers add themselves to logging._handlers
         handlers = _install_handlers(cp, formatters)
         _install_loggers(cp, handlers, disable_existing_loggers)
-    finally:
-        logging._releaseLock()
 
 
 def _resolve(name):
@@ -516,8 +513,7 @@ class DictConfigurator(BaseConfigurator):
             raise ValueError("Unsupported version: %s" % config['version'])
         incremental = config.pop('incremental', False)
         EMPTY_DICT = {}
-        logging._acquireLock()
-        try:
+        with logging._lock:
             if incremental:
                 handlers = config.get('handlers', EMPTY_DICT)
                 for name in handlers:
@@ -661,8 +657,6 @@ class DictConfigurator(BaseConfigurator):
                     except Exception as e:
                         raise ValueError('Unable to configure root '
                                          'logger') from e
-        finally:
-            logging._releaseLock()
 
     def configure_formatter(self, config):
         """Configure a formatter from a dictionary."""
@@ -988,9 +982,8 @@ def listen(port=DEFAULT_LOGGING_CONFIG_PORT, verify=None):
         def __init__(self, host='localhost', port=DEFAULT_LOGGING_CONFIG_PORT,
                      handler=None, ready=None, verify=None):
             ThreadingTCPServer.__init__(self, (host, port), handler)
-            logging._acquireLock()
-            self.abort = 0
-            logging._releaseLock()
+            with logging._lock:
+                self.abort = 0
             self.timeout = 1
             self.ready = ready
             self.verify = verify
@@ -1004,9 +997,8 @@ def listen(port=DEFAULT_LOGGING_CONFIG_PORT, verify=None):
                                            self.timeout)
                 if rd:
                     self.handle_request()
-                logging._acquireLock()
-                abort = self.abort
-                logging._releaseLock()
+                with logging._lock:
+                    abort = self.abort
             self.server_close()
 
     class Server(threading.Thread):
@@ -1027,9 +1019,8 @@ def listen(port=DEFAULT_LOGGING_CONFIG_PORT, verify=None):
                 self.port = server.server_address[1]
             self.ready.set()
             global _listener
-            logging._acquireLock()
-            _listener = server
-            logging._releaseLock()
+            with logging._lock:
+                _listener = server
             server.serve_until_stopped()
 
     return Server(ConfigSocketReceiver, ConfigStreamHandler, port, verify)
@@ -1039,10 +1030,7 @@ def stopListening():
     Stop the listening server which was created with a call to listen().
     """
     global _listener
-    logging._acquireLock()
-    try:
+    with logging._lock:
         if _listener:
             _listener.abort = 1
             _listener = None
-    finally:
-        logging._releaseLock()
index 671cc9596b02dd4068eb5de1aade13d23593dbcc..e75da9b7b1de64f75dc813a31ebf64724ab77071 100644 (file)
@@ -683,15 +683,12 @@ class SocketHandler(logging.Handler):
         """
         Closes the socket.
         """
-        self.acquire()
-        try:
+        with self.lock:
             sock = self.sock
             if sock:
                 self.sock = None
                 sock.close()
             logging.Handler.close(self)
-        finally:
-            self.release()
 
 class DatagramHandler(SocketHandler):
     """
@@ -953,15 +950,12 @@ class SysLogHandler(logging.Handler):
         """
         Closes the socket.
         """
-        self.acquire()
-        try:
+        with self.lock:
             sock = self.socket
             if sock:
                 self.socket = None
                 sock.close()
             logging.Handler.close(self)
-        finally:
-            self.release()
 
     def mapPriority(self, levelName):
         """
@@ -1333,11 +1327,8 @@ class BufferingHandler(logging.Handler):
 
         This version just zaps the buffer to empty.
         """
-        self.acquire()
-        try:
+        with self.lock:
             self.buffer.clear()
-        finally:
-            self.release()
 
     def close(self):
         """
@@ -1387,11 +1378,8 @@ class MemoryHandler(BufferingHandler):
         """
         Set the target handler for this handler.
         """
-        self.acquire()
-        try:
+        with self.lock:
             self.target = target
-        finally:
-            self.release()
 
     def flush(self):
         """
@@ -1401,14 +1389,11 @@ class MemoryHandler(BufferingHandler):
 
         The record buffer is only cleared if a target has been set.
         """
-        self.acquire()
-        try:
+        with self.lock:
             if self.target:
                 for record in self.buffer:
                     self.target.handle(record)
                 self.buffer.clear()
-        finally:
-            self.release()
 
     def close(self):
         """
@@ -1419,12 +1404,9 @@ class MemoryHandler(BufferingHandler):
             if self.flushOnClose:
                 self.flush()
         finally:
-            self.acquire()
-            try:
+            with self.lock:
                 self.target = None
                 BufferingHandler.close(self)
-            finally:
-                self.release()
 
 
 class QueueHandler(logging.Handler):
index 6ee0d33e88a060a68c97717aee4ed2b6792b8e3f..28c77df1c32ea803506b585bf3d267f72ba1ab9f 100644 (file)
@@ -64,8 +64,7 @@ def get_logger():
     global _logger
     import logging
 
-    logging._acquireLock()
-    try:
+    with logging._lock:
         if not _logger:
 
             _logger = logging.getLogger(LOGGER_NAME)
@@ -79,9 +78,6 @@ def get_logger():
                 atexit._exithandlers.remove((_exit_function, (), {}))
                 atexit._exithandlers.append((_exit_function, (), {}))
 
-    finally:
-        logging._releaseLock()
-
     return _logger
 
 def log_to_stderr(level=None):
index 375f65f9d16182f765fcb38730c86283f3d27e08..cca02a010b80f4d1b41574b9e3e8d355595ff117 100644 (file)
@@ -90,8 +90,7 @@ class BaseTest(unittest.TestCase):
         self._threading_key = threading_helper.threading_setup()
 
         logger_dict = logging.getLogger().manager.loggerDict
-        logging._acquireLock()
-        try:
+        with logging._lock:
             self.saved_handlers = logging._handlers.copy()
             self.saved_handler_list = logging._handlerList[:]
             self.saved_loggers = saved_loggers = logger_dict.copy()
@@ -101,8 +100,6 @@ class BaseTest(unittest.TestCase):
             for name in saved_loggers:
                 logger_states[name] = getattr(saved_loggers[name],
                                               'disabled', None)
-        finally:
-            logging._releaseLock()
 
         # Set two unused loggers
         self.logger1 = logging.getLogger("\xab\xd7\xbb")
@@ -136,8 +133,7 @@ class BaseTest(unittest.TestCase):
             self.root_logger.removeHandler(h)
             h.close()
         self.root_logger.setLevel(self.original_logging_level)
-        logging._acquireLock()
-        try:
+        with logging._lock:
             logging._levelToName.clear()
             logging._levelToName.update(self.saved_level_to_name)
             logging._nameToLevel.clear()
@@ -154,8 +150,6 @@ class BaseTest(unittest.TestCase):
             for name in self.logger_states:
                 if logger_states[name] is not None:
                     self.saved_loggers[name].disabled = logger_states[name]
-        finally:
-            logging._releaseLock()
 
         self.doCleanups()
         threading_helper.threading_cleanup(*self._threading_key)
@@ -739,11 +733,8 @@ class HandlerTest(BaseTest):
                     stream=open('/dev/null', 'wt', encoding='utf-8'))
 
             def emit(self, record):
-                self.sub_handler.acquire()
-                try:
+                with self.sub_handler.lock:
                     self.sub_handler.emit(record)
-                finally:
-                    self.sub_handler.release()
 
         self.assertEqual(len(logging._handlers), 0)
         refed_h = _OurHandler()
@@ -759,29 +750,22 @@ class HandlerTest(BaseTest):
         fork_happened__release_locks_and_end_thread = threading.Event()
 
         def lock_holder_thread_fn():
-            logging._acquireLock()
-            try:
-                refed_h.acquire()
-                try:
-                    # Tell the main thread to do the fork.
-                    locks_held__ready_to_fork.set()
-
-                    # If the deadlock bug exists, the fork will happen
-                    # without dealing with the locks we hold, deadlocking
-                    # the child.
-
-                    # Wait for a successful fork or an unreasonable amount of
-                    # time before releasing our locks.  To avoid a timing based
-                    # test we'd need communication from os.fork() as to when it
-                    # has actually happened.  Given this is a regression test
-                    # for a fixed issue, potentially less reliably detecting
-                    # regression via timing is acceptable for simplicity.
-                    # The test will always take at least this long. :(
-                    fork_happened__release_locks_and_end_thread.wait(0.5)
-                finally:
-                    refed_h.release()
-            finally:
-                logging._releaseLock()
+            with logging._lock, refed_h.lock:
+                # Tell the main thread to do the fork.
+                locks_held__ready_to_fork.set()
+
+                # If the deadlock bug exists, the fork will happen
+                # without dealing with the locks we hold, deadlocking
+                # the child.
+
+                # Wait for a successful fork or an unreasonable amount of
+                # time before releasing our locks.  To avoid a timing based
+                # test we'd need communication from os.fork() as to when it
+                # has actually happened.  Given this is a regression test
+                # for a fixed issue, potentially less reliably detecting
+                # regression via timing is acceptable for simplicity.
+                # The test will always take at least this long. :(
+                fork_happened__release_locks_and_end_thread.wait(0.5)
 
         lock_holder_thread = threading.Thread(
                 target=lock_holder_thread_fn,
diff --git a/Misc/NEWS.d/next/Library/2023-09-15-17-12-53.gh-issue-109461.VNFPTK.rst b/Misc/NEWS.d/next/Library/2023-09-15-17-12-53.gh-issue-109461.VNFPTK.rst
new file mode 100644 (file)
index 0000000..28f0c16
--- /dev/null
@@ -0,0 +1 @@
+:mod:`logging`: Use a context manager for lock acquisition.