git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff

author    Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 30 Apr 2015 21:51:14 +0000 (17:51 -0400)
committer Mike Bayer <mike_mp@zzzcomputing.com>
          Thu, 30 Apr 2015 21:51:14 +0000 (17:51 -0400)

- New features added to support engine/pool plugins with advanced
functionality.   Added a new "soft invalidate" feature to the
connection pool at the level of the checked out connection wrapper
as well as the :class:`._ConnectionRecord`.  This works similarly
to a modern pool invalidation in that connections aren't actively
closed, but are recycled only on next checkout; this is essentially
a per-connection version of that feature.  A new event
:class:`.PoolEvents.soft_invalidate` is added to complement it.
fixes #3379

- Added new flag
:attr:`.ExceptionContext.invalidate_pool_on_disconnect`.
Allows an error handler within :meth:`.ConnectionEvents.handle_error`
to maintain a "disconnect" condition, but to handle calling invalidate
on individual connections in a specific manner within the event.

- Added new event :class:`.DialectEvents.do_connect`, which allows
interception / replacement of when the :meth:`.Dialect.connect`
hook is called to create a DBAPI connection.  Also added
dialect plugin hooks :meth:`.Dialect.get_dialect_cls` and
:meth:`.Dialect.engine_created` which allow external plugins to
add events to existing dialects using entry points.
fixes #3355
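
As a minimal sketch of the soft-invalidation change above (not part of this
commit; it assumes a throwaway in-memory SQLite engine), the new
:class:`.PoolEvents.soft_invalidate` event and the ``soft=True`` flag can be
exercised like this:

    from sqlalchemy import create_engine, event

    engine = create_engine("sqlite://")

    @event.listens_for(engine, "soft_invalidate")
    def on_soft_invalidate(dbapi_connection, connection_record, exception):
        # fires in place of "invalidate" when soft=True is used; the
        # DBAPI connection is not closed at this point
        print("soft invalidated: %r (%r)" % (dbapi_connection, exception))

    conn = engine.connect()
    # mark only this connection's record for recycle on its next checkout;
    # the underlying DBAPI connection stays open for now
    conn.connection.invalidate(soft=True)
    conn.close()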

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/events.py
lib/sqlalchemy/pool.py
test/engine/test_execute.py
test/engine/test_parseconnect.py
test/engine/test_pool.py

doc/build/changelog/changelog_10.rst
index 43689160177d9cffbcfded1d798d156d154e7a97..fe5c7a744779f7da7e6d6f3ba9c61eb43bf66f3d 100644 (file)
 .. changelog::
     :version: 1.0.3
 
+    .. change::
+        :tags: feature, engine
+        :tickets: 3379
+
+        New features added to support engine/pool plugins with advanced
+        functionality.   Added a new "soft invalidate" feature to the
+        connection pool at the level of the checked out connection wrapper
+        as well as the :class:`._ConnectionRecord`.  This works similarly
+        to a modern pool invalidation in that connections aren't actively
+        closed, but are recycled only on next checkout; this is essentially
+        a per-connection version of that feature.  A new event
+        :class:`.PoolEvents.soft_invalidate` is added to complement it.
+
+        Also added new flag
+        :attr:`.ExceptionContext.invalidate_pool_on_disconnect`.
+        Allows an error handler within :meth:`.ConnectionEvents.handle_error`
+        to maintain a "disconnect" condition, but to handle calling invalidate
+        on individual connections in a specific manner within the event.
+
+    .. change::
+        :tags: feature, engine
+        :tickets: 3355
+
+        Added new event :class:`.DialectEvents.do_connect`, which allows
+        interception / replacement of when the :meth:`.Dialect.connect`
+        hook is called to create a DBAPI connection.  Also added
+        dialect plugin hooks :meth:`.Dialect.get_dialect_cls` and
+        :meth:`.Dialect.engine_created` which allow external plugins to
+        add events to existing dialects using entry points.
+
     .. change::
         :tags: bug, orm
         :tickets: 3403, 3320
lib/sqlalchemy/engine/base.py
index 5921ab9ba02cbde5773f4b1647e8fe32244bae62..af310c4506c7c5ff0c40eb7c6fc94de326972887 100644 (file)
@@ -1254,6 +1254,8 @@ class Connection(Connectable):
             if context:
                 context.is_disconnect = self._is_disconnect
 
+        invalidate_pool_on_disconnect = True
+
         if self._reentrant_error:
             util.raise_from_cause(
                 exc.DBAPIError.instance(statement,
@@ -1316,6 +1318,11 @@ class Connection(Connectable):
                     sqlalchemy_exception.connection_invalidated = \
                         self._is_disconnect = ctx.is_disconnect
 
+                # set up potentially user-defined value for
+                # invalidate pool.
+                invalidate_pool_on_disconnect = \
+                    ctx.invalidate_pool_on_disconnect
+
             if should_wrap and context:
                 context.handle_dbapi_exception(e)
 
@@ -1340,7 +1347,8 @@ class Connection(Connectable):
                 del self._is_disconnect
                 if not self.invalidated:
                     dbapi_conn_wrapper = self.__connection
-                    self.engine.pool._invalidate(dbapi_conn_wrapper, e)
+                    if invalidate_pool_on_disconnect:
+                        self.engine.pool._invalidate(dbapi_conn_wrapper, e)
                     self.invalidate(e)
             if self.should_close_with_result:
                 self.close()
lib/sqlalchemy/engine/interfaces.py
index da8fa81eb4309c9a92d1fbf01a578a0fc8499b7b..2dd1921626aff43d28b4a10cd04727d99fc224b1 100644 (file)
@@ -733,6 +733,41 @@ class Dialect(object):
 
         raise NotImplementedError()
 
+    @classmethod
+    def get_dialect_cls(cls, url):
+        """Given a URL, return the :class:`.Dialect` that will be used.
+
+        This is a hook that allows an external plugin to provide functionality
+        around an existing dialect, by allowing the plugin to be loaded
+        from the url based on an entrypoint, and then the plugin returns
+        the actual dialect to be used.
+
+        By default this just returns the cls.
+
+        .. versionadded:: 1.0.3
+
+        """
+        return cls
+
+    @classmethod
+    def engine_created(cls, engine):
+        """A convenience hook called before returning the final :class:`.Engine`.
+
+        If the dialect returned a different class from the
+        :meth:`.get_dialect_cls`
+        method, then the hook is called on both classes, first on
+        the dialect class returned by the :meth:`.get_dialect_cls` method and
+        then on the class on which the method was called.
+
+        The hook should be used by dialects and/or wrappers to apply special
+        events to the engine or its components.   In particular, it allows
+        a dialect-wrapping class to apply dialect-level events.
+
+        .. versionadded:: 1.0.3
+
+        """
+        pass
+
 
 class ExecutionContext(object):
     """A messenger object for a Dialect that corresponds to a single
@@ -1085,3 +1120,21 @@ class ExceptionContext(object):
     changing this flag.
 
     """
+
+    invalidate_pool_on_disconnect = True
+    """Represent whether all connections in the pool should be invalidated
+    when a "disconnect" condition is in effect.
+
+    Setting this flag to False within the scope of the
+    :meth:`.ConnectionEvents.handle_error` event will have the effect such
+    that the full collection of connections in the pool will not be
+    invalidated during a disconnect; only the current connection that is the
+    subject of the error will actually be invalidated.
+
+    The purpose of this flag is for custom disconnect-handling schemes where
+    the invalidation of other connections in the pool is to be performed
+    based on other conditions, or even on a per-connection basis.
+
+    .. versionadded:: 1.0.3
+
+    """
\ No newline at end of file
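
A minimal sketch (not from this commit) of how the new
``invalidate_pool_on_disconnect`` attribute might be used from within the
:meth:`.ConnectionEvents.handle_error` event; the database URL is a
placeholder:

    from sqlalchemy import create_engine, event

    engine = create_engine("postgresql://scott:tiger@localhost/test")

    @event.listens_for(engine, "handle_error")
    def limit_invalidation(context):
        # keep the "disconnect" semantics for the failing connection, but
        # suppress the pool-wide invalidation; other pooled connections
        # stay untouched and can be invalidated selectively elsewhere
        if context.is_disconnect:
            context.invalidate_pool_on_disconnect = False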
lib/sqlalchemy/engine/strategies.py
index 1fd105d679b407bcc3a1448e5cdd918d4e567f78..a802e5d90eafdf1b83fceb77894d2ff3abeaea72 100644 (file)
@@ -48,7 +48,8 @@ class DefaultEngineStrategy(EngineStrategy):
         # create url.URL object
         u = url.make_url(name_or_url)
 
-        dialect_cls = u.get_dialect()
+        entrypoint = u.get_dialect()
+        dialect_cls = entrypoint.get_dialect_cls(u)
 
         if kwargs.pop('_coerce_config', False):
             def pop_kwarg(key, default=None):
@@ -81,11 +82,18 @@ class DefaultEngineStrategy(EngineStrategy):
         # assemble connection arguments
         (cargs, cparams) = dialect.create_connect_args(u)
         cparams.update(pop_kwarg('connect_args', {}))
+        cargs = list(cargs)  # allow mutability
 
         # look for existing pool or create
         pool = pop_kwarg('pool', None)
         if pool is None:
-            def connect():
+            def connect(connection_record=None):
+                if dialect._has_events:
+                    for fn in dialect.dispatch.do_connect:
+                        connection = fn(
+                            dialect, connection_record, cargs, cparams)
+                        if connection is not None:
+                            return connection
                 return dialect.connect(*cargs, **cparams)
 
             creator = pop_kwarg('creator', connect)
@@ -157,6 +165,10 @@ class DefaultEngineStrategy(EngineStrategy):
                 dialect.initialize(c)
             event.listen(pool, 'first_connect', first_connect, once=True)
 
+        dialect_cls.engine_created(engine)
+        if entrypoint is not dialect_cls:
+            entrypoint.engine_created(engine)
+
         return engine
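
To illustrate how the ``get_dialect_cls()`` / ``engine_created()`` hooks wired
in above are meant to be consumed, here is a hypothetical wrapper "dialect"
(all names invented for illustration; the ``registry.register()`` call mirrors
the approach used in the test suite below, and an entry point would work the
same way):

    from sqlalchemy import event
    from sqlalchemy.dialects import registry


    class TimingPlugin(object):
        """Hypothetical plugin loaded from a URL such as 'timing://'."""

        @classmethod
        def get_dialect_cls(cls, u):
            # hand control back to a real dialect; for the sake of the
            # example, simply re-target the URL at SQLite
            u.drivername = "sqlite"
            return u.get_dialect()

        @classmethod
        def engine_created(cls, engine):
            # called on this class as well as on the real dialect class;
            # a convenient place to attach engine-level events
            @event.listens_for(engine, "before_cursor_execute")
            def log_statement(conn, cursor, statement,
                              parameters, context, executemany):
                engine.logger.info("executing: %s", statement)


    registry.register("timing", __name__, "TimingPlugin")
    # create_engine("timing://") now resolves to TimingPlugin, which defers
    # to the SQLite dialect and then receives the engine_created() call.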
 
 
lib/sqlalchemy/events.py
index 22e066c88f386f4e00a2819f884f1720547998d3..b2d4b54a9df1ae352b0e99849a676e58acdf6f13 100644 (file)
@@ -371,7 +371,9 @@ class PoolEvents(event.Events):
         """Called when a DBAPI connection is to be "invalidated".
 
         This event is called any time the :meth:`._ConnectionRecord.invalidate`
-        method is invoked, either from API usage or via "auto-invalidation".
+        method is invoked, either from API usage or via "auto-invalidation",
+        without the ``soft`` flag.
+
         The event occurs before a final attempt to call ``.close()`` on the
         connection occurs.
 
@@ -392,6 +394,21 @@ class PoolEvents(event.Events):
 
         """
 
+    def soft_invalidate(self, dbapi_connection, connection_record, exception):
+        """Called when a DBAPI connection is to be "soft invalidated".
+
+        This event is called any time the :meth:`._ConnectionRecord.invalidate`
+        method is invoked with the ``soft`` flag.
+
+        Soft invalidation refers to when the connection record that tracks
+        this connection will force a reconnect after the current connection
+        is checked in.   It does not actively close the dbapi_connection
+        at the point at which it is called.
+
+        .. versionadded:: 1.0.3
+
+        """
+
 
 class ConnectionEvents(event.Events):
     """Available events for :class:`.Connectable`, which includes
@@ -707,6 +724,16 @@ class ConnectionEvents(event.Events):
                     "failed" in str(context.original_exception):
                     raise MySpecialException("failed operation")
 
+        .. warning::  Because the :meth:`.ConnectionEvents.handle_error`
+           event specifically provides for exceptions to be re-thrown as
+           the ultimate exception raised by the failed statement,
+           **stack traces will be misleading** if the user-defined event
+           handler itself fails and throws an unexpected exception;
+           the stack trace may not illustrate the actual code line that
+           failed!  It is advised to code carefully here and use
+           logging and/or inline debugging if unexpected exceptions are
+           occurring.
+
         Alternatively, a "chained" style of event handling can be
         used, by configuring the handler with the ``retval=True``
         modifier and returning the new exception instance from the
@@ -1007,6 +1034,23 @@ class DialectEvents(event.Events):
         else:
             return target
 
+    def do_connect(self, dialect, conn_rec, cargs, cparams):
+        """Receive connection arguments before a connection is made.
+
+        Return a DBAPI connection to halt further events from invoking;
+        the returned connection will be used.
+
+        Alternatively, the event can manipulate the cargs and/or cparams
+        collections; cargs will always be a Python list that can be mutated
+        in-place and cparams a Python dictionary.  Return None to
+        allow control to pass to the next event handler and ultimately
+        to allow the dialect to connect normally, given the updated
+        arguments.
+
+        .. versionadded:: 1.0.3
+
+        """
+
     def do_executemany(self, cursor, statement, parameters, context):
         """Receive a cursor to have executemany() called.
 
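
A minimal sketch of the new ``do_connect`` hook from an application's point of
view (the URL and the credential helper are placeholders invented for the
example); mutating ``cparams`` in place and returning ``None`` lets the dialect
connect normally with the updated arguments:

    from sqlalchemy import create_engine, event


    def get_fresh_token():
        # stand-in for whatever credential service an application uses
        return "secret"


    engine = create_engine("postgresql://scott@localhost/test")

    @event.listens_for(engine, "do_connect")
    def provide_token(dialect, conn_rec, cargs, cparams):
        # cargs is a mutable list and cparams a dict; modify them in place
        # and return None, or return a DBAPI connection outright to bypass
        # dialect.connect() entirely
        cparams["password"] = get_fresh_token()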
lib/sqlalchemy/pool.py
index 999cc11207cd41f0dc042379615ebd95ff84aa9c..902309d75977857b06c473068f160f51f6afb593 100644 (file)
@@ -219,6 +219,7 @@ class Pool(log.Identified):
         log.instance_logger(self, echoflag=echo)
         self._threadconns = threading.local()
         self._creator = creator
+        self._wrapped_creator = self._maybe_wrap_callable(creator)
         self._recycle = recycle
         self._invalidate_time = 0
         self._use_threadlocal = use_threadlocal
@@ -249,6 +250,32 @@ class Pool(log.Identified):
             for l in listeners:
                 self.add_listener(l)
 
+    def _maybe_wrap_callable(self, fn):
+        """Detect if creator accepts a single argument, or is sent
+        as a legacy style no-arg function.
+
+        """
+
+        try:
+            argspec = util.get_callable_argspec(fn, no_self=True)
+        except TypeError:
+            return lambda ctx: fn()
+
+        defaulted = argspec[3] is not None and len(argspec[3]) or 0
+        positionals = len(argspec[0]) - defaulted
+
+        # look for the exact arg signature that DefaultStrategy
+        # sends us
+        if (argspec[0], argspec[3]) == (['connection_record'], (None,)):
+            return fn
+        # or just a single positional
+        elif positionals == 1:
+            return fn
+        # all other cases, just wrap and assume legacy "creator" callable
+        # thing
+        else:
+            return lambda ctx: fn()
+
     def _close_connection(self, connection):
         self.logger.debug("Closing connection %r", connection)
         try:
@@ -428,6 +455,8 @@ class _ConnectionRecord(object):
 
     """
 
+    _soft_invalidate_time = 0
+
     @util.memoized_property
     def info(self):
         """The ``.info`` dictionary associated with the DBAPI connection.
@@ -476,7 +505,7 @@ class _ConnectionRecord(object):
         if self.connection is not None:
             self.__close()
 
-    def invalidate(self, e=None):
+    def invalidate(self, e=None, soft=False):
         """Invalidate the DBAPI connection held by this :class:`._ConnectionRecord`.
 
         This method is called for all connection invalidations, including
@@ -484,6 +513,13 @@ class _ConnectionRecord(object):
         :meth:`.Connection.invalidate` methods are called, as well as when any
         so-called "automatic invalidation" condition occurs.
 
+        :param e: an exception object indicating a reason for the invalidation.
+
+        :param soft: if True, the connection isn't closed; instead, this
+         connection will be recycled on next checkout.
+
+         .. versionadded:: 1.0.3
+
         .. seealso::
 
             :ref:`pool_connection_invalidation`
@@ -492,22 +528,31 @@ class _ConnectionRecord(object):
         # already invalidated
         if self.connection is None:
             return
-        self.__pool.dispatch.invalidate(self.connection, self, e)
+        if soft:
+            self.__pool.dispatch.soft_invalidate(self.connection, self, e)
+        else:
+            self.__pool.dispatch.invalidate(self.connection, self, e)
         if e is not None:
             self.__pool.logger.info(
-                "Invalidate connection %r (reason: %s:%s)",
+                "%sInvalidate connection %r (reason: %s:%s)",
+                "Soft " if soft else "",
                 self.connection, e.__class__.__name__, e)
         else:
             self.__pool.logger.info(
-                "Invalidate connection %r", self.connection)
-        self.__close()
-        self.connection = None
+                "%sInvalidate connection %r",
+                "Soft " if soft else "",
+                self.connection)
+        if soft:
+            self._soft_invalidate_time = time.time()
+        else:
+            self.__close()
+            self.connection = None
 
     def get_connection(self):
         recycle = False
         if self.connection is None:
-            self.connection = self.__connect()
             self.info.clear()
+            self.connection = self.__connect()
             if self.__pool.dispatch.connect:
                 self.__pool.dispatch.connect(self.connection, self)
         elif self.__pool._recycle > -1 and \
@@ -523,11 +568,18 @@ class _ConnectionRecord(object):
                 self.connection
             )
             recycle = True
+        elif self._soft_invalidate_time > self.starttime:
+            self.__pool.logger.info(
+                "Connection %r invalidated due to local soft invalidation; " +
+                "recycling",
+                self.connection
+            )
+            recycle = True
 
         if recycle:
             self.__close()
-            self.connection = self.__connect()
             self.info.clear()
+            self.connection = self.__connect()
             if self.__pool.dispatch.connect:
                 self.__pool.dispatch.connect(self.connection, self)
         return self.connection
@@ -539,7 +591,7 @@ class _ConnectionRecord(object):
     def __connect(self):
         try:
             self.starttime = time.time()
-            connection = self.__pool._creator()
+            connection = self.__pool._wrapped_creator(self)
             self.__pool.logger.debug("Created new connection %r", connection)
             return connection
         except Exception as e:
@@ -740,7 +792,7 @@ class _ConnectionFairy(object):
         """
         return self._connection_record.info
 
-    def invalidate(self, e=None):
+    def invalidate(self, e=None, soft=False):
         """Mark this connection as invalidated.
 
         This method can be called directly, and is also called as a result
@@ -749,6 +801,13 @@ class _ConnectionFairy(object):
         further use by the pool.  The invalidation mechanism proceeds
         via the :meth:`._ConnectionRecord.invalidate` internal method.
 
+        :param e: an exception object indicating a reason for the invalidation.
+
+        :param soft: if True, the connection isn't closed; instead, this
+         connection will be recycled on next checkout.
+
+         .. versionadded:: 1.0.3
+
         .. seealso::
 
             :ref:`pool_connection_invalidation`
@@ -759,9 +818,10 @@ class _ConnectionFairy(object):
             util.warn("Can't invalidate an already-closed connection.")
             return
         if self._connection_record:
-            self._connection_record.invalidate(e=e)
-        self.connection = None
-        self._checkin()
+            self._connection_record.invalidate(e=e, soft=soft)
+        if not soft:
+            self.connection = None
+            self._checkin()
 
     def cursor(self, *args, **kwargs):
         """Return a new DBAPI cursor for the underlying connection.
test/engine/test_execute.py
index b0256d325497967b7e5a7ae2ea86b6ec1350a97f..cba3972f62684dc7d905fbc4ec05fa0bf2e12f54 100644 (file)
@@ -1,7 +1,7 @@
 # coding: utf-8
 
 from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, \
-    config, is_
+    config, is_, is_not_
 import re
 from sqlalchemy.testing.util import picklers
 from sqlalchemy.interfaces import ConnectionProxy
@@ -1943,6 +1943,47 @@ class HandleErrorTest(fixtures.TestBase):
         self._test_alter_disconnect(True, False)
         self._test_alter_disconnect(False, False)
 
+    @testing.requires.independent_connections
+    def _test_alter_invalidate_pool_to_false(self, set_to_false):
+        orig_error = True
+
+        engine = engines.testing_engine()
+
+        @event.listens_for(engine, "handle_error")
+        def evt(ctx):
+            if set_to_false:
+                ctx.invalidate_pool_on_disconnect = False
+
+        c1, c2, c3 = engine.pool.connect(), \
+            engine.pool.connect(), engine.pool.connect()
+        crecs = [conn._connection_record for conn in (c1, c2, c3)]
+        c1.close()
+        c2.close()
+        c3.close()
+
+        with patch.object(engine.dialect, "is_disconnect",
+                          Mock(return_value=orig_error)):
+
+            with engine.connect() as c:
+                target_crec = c.connection._connection_record
+                try:
+                    c.execute("SELECT x FROM nonexistent")
+                    assert False
+                except tsa.exc.StatementError as st:
+                    eq_(st.connection_invalidated, True)
+
+        for crec in crecs:
+            if crec is target_crec or not set_to_false:
+                is_not_(crec.connection, crec.get_connection())
+            else:
+                is_(crec.connection, crec.get_connection())
+
+    def test_alter_invalidate_pool_to_false(self):
+        self._test_alter_invalidate_pool_to_false(True)
+
+    def test_alter_invalidate_pool_stays_true(self):
+        self._test_alter_invalidate_pool_to_false(False)
+
     def test_handle_error_event_connect_isolation_level(self):
         engine = engines.testing_engine()
 
@@ -2133,7 +2174,7 @@ class HandleInvalidatedOnConnectTest(fixtures.TestBase):
 
         conn.invalidate()
 
-        eng.pool._creator = Mock(
+        eng.pool._wrapped_creator = Mock(
             side_effect=self.ProgrammingError(
                 "Cannot operate on a closed database."))
 
@@ -2532,3 +2573,87 @@ class DialectEventTest(fixtures.TestBase):
 
     def test_cursor_execute_wo_replace(self):
         self._test_cursor_execute(False)
+
+    def test_connect_replace_params(self):
+        e = engines.testing_engine(options={"_initialize": False})
+
+        @event.listens_for(e, "do_connect")
+        def evt(dialect, conn_rec, cargs, cparams):
+            cargs[:] = ['foo', 'hoho']
+            cparams.clear()
+            cparams['bar'] = 'bat'
+            conn_rec.info['boom'] = "bap"
+
+        m1 = Mock()
+        e.dialect.connect = m1.real_connect
+
+        with e.connect() as conn:
+            eq_(m1.mock_calls, [call.real_connect('foo', 'hoho', bar='bat')])
+            eq_(conn.info['boom'], 'bap')
+
+    def test_connect_do_connect(self):
+        e = engines.testing_engine(options={"_initialize": False})
+
+        m1 = Mock()
+
+        @event.listens_for(e, "do_connect")
+        def evt1(dialect, conn_rec, cargs, cparams):
+            cargs[:] = ['foo', 'hoho']
+            cparams.clear()
+            cparams['bar'] = 'bat'
+            conn_rec.info['boom'] = "one"
+
+        @event.listens_for(e, "do_connect")
+        def evt2(dialect, conn_rec, cargs, cparams):
+            conn_rec.info['bap'] = "two"
+            return m1.our_connect(cargs, cparams)
+
+        with e.connect() as conn:
+            # called with args
+            eq_(
+                m1.mock_calls,
+                [call.our_connect(['foo', 'hoho'], {'bar': 'bat'})])
+
+            eq_(conn.info['boom'], "one")
+            eq_(conn.info['bap'], "two")
+
+            # returned our mock connection
+            is_(conn.connection.connection, m1.our_connect())
+
+    def test_connect_do_connect_info_there_after_recycle(self):
+        # test that info is maintained after the do_connect()
+        # event for a soft invalidation.
+
+        e = engines.testing_engine(options={"_initialize": False})
+
+        @event.listens_for(e, "do_connect")
+        def evt1(dialect, conn_rec, cargs, cparams):
+            conn_rec.info['boom'] = "one"
+
+        conn = e.connect()
+        eq_(conn.info['boom'], "one")
+
+        conn.connection.invalidate(soft=True)
+        conn.close()
+        conn = e.connect()
+        eq_(conn.info['boom'], "one")
+
+    def test_connect_do_connect_info_there_after_invalidate(self):
+        # test that info is maintained after the do_connect()
+        # event for a hard invalidation.
+
+        e = engines.testing_engine(options={"_initialize": False})
+
+        @event.listens_for(e, "do_connect")
+        def evt1(dialect, conn_rec, cargs, cparams):
+            assert not conn_rec.info
+            conn_rec.info['boom'] = "one"
+
+        conn = e.connect()
+        eq_(conn.info['boom'], "one")
+
+        conn.connection.invalidate()
+        conn = e.connect()
+        eq_(conn.info['boom'], "one")
+
+
test/engine/test_parseconnect.py
index 9f1b5cebada243e251f1efe96964b770b01d04af..fb1f338e65efe4a18a81a2120361c56ed99dd5c9 100644 (file)
@@ -5,7 +5,7 @@ from sqlalchemy.engine.default import DefaultDialect
 import sqlalchemy as tsa
 from sqlalchemy.testing import fixtures
 from sqlalchemy import testing
-from sqlalchemy.testing.mock import Mock, MagicMock
+from sqlalchemy.testing.mock import Mock, MagicMock, call
 from sqlalchemy import event
 from sqlalchemy import select
 
@@ -340,6 +340,33 @@ class TestRegNewDBAPI(fixtures.TestBase):
         e = create_engine("mysql+my_mock_dialect://")
         assert isinstance(e.dialect, MockDialect)
 
+    @testing.requires.sqlite
+    def test_wrapper_hooks(self):
+        def get_dialect_cls(url):
+            url.drivername = "sqlite"
+            return url.get_dialect()
+
+        global WrapperFactory
+        WrapperFactory = Mock()
+        WrapperFactory.get_dialect_cls.side_effect = get_dialect_cls
+
+        from sqlalchemy.dialects import registry
+        registry.register("wrapperdialect", __name__, "WrapperFactory")
+
+        from sqlalchemy.dialects import sqlite
+        e = create_engine("wrapperdialect://")
+
+        eq_(e.dialect.name, "sqlite")
+        assert isinstance(e.dialect, sqlite.dialect)
+
+        eq_(
+            WrapperFactory.mock_calls,
+            [
+                call.get_dialect_cls(url.make_url("sqlite://")),
+                call.engine_created(e)
+            ]
+        )
+
 
 class MockDialect(DefaultDialect):
     @classmethod
test/engine/test_pool.py
index ff45b2d513c14df0bdff6c5327a14679032ff873..3d93cda899415fb56d580b6aaaf7de585a1d35e3 100644 (file)
@@ -4,11 +4,11 @@ from sqlalchemy import pool, select, event
 import sqlalchemy as tsa
 from sqlalchemy import testing
 from sqlalchemy.testing.util import gc_collect, lazy_gc
-from sqlalchemy.testing import eq_, assert_raises, is_not_
+from sqlalchemy.testing import eq_, assert_raises, is_not_, is_
 from sqlalchemy.testing.engines import testing_engine
 from sqlalchemy.testing import fixtures
 import random
-from sqlalchemy.testing.mock import Mock, call
+from sqlalchemy.testing.mock import Mock, call, patch
 import weakref
 
 join_timeout = 10
@@ -335,6 +335,13 @@ class PoolEventsTest(PoolTestBase):
 
         return p, canary
 
+    def _soft_invalidate_event_fixture(self):
+        p = self._queuepool_fixture()
+        canary = Mock()
+        event.listen(p, 'soft_invalidate', canary)
+
+        return p, canary
+
     def test_first_connect_event(self):
         p, canary = self._first_connect_event_fixture()
 
@@ -438,6 +445,31 @@ class PoolEventsTest(PoolTestBase):
         c1.close()
         eq_(canary, ['reset'])
 
+    def test_soft_invalidate_event_no_exception(self):
+        p, canary = self._soft_invalidate_event_fixture()
+
+        c1 = p.connect()
+        c1.close()
+        assert not canary.called
+        c1 = p.connect()
+        dbapi_con = c1.connection
+        c1.invalidate(soft=True)
+        assert canary.call_args_list[0][0][0] is dbapi_con
+        assert canary.call_args_list[0][0][2] is None
+
+    def test_soft_invalidate_event_exception(self):
+        p, canary = self._soft_invalidate_event_fixture()
+
+        c1 = p.connect()
+        c1.close()
+        assert not canary.called
+        c1 = p.connect()
+        dbapi_con = c1.connection
+        exc = Exception("hi")
+        c1.invalidate(exc, soft=True)
+        assert canary.call_args_list[0][0][0] is dbapi_con
+        assert canary.call_args_list[0][0][2] is exc
+
     def test_invalidate_event_no_exception(self):
         p, canary = self._invalidate_event_fixture()
 
@@ -1130,6 +1162,44 @@ class QueuePoolTest(PoolTestBase):
 
         eq_(len(success), 12, "successes: %s" % success)
 
+    def test_connrec_invalidated_within_checkout_no_race(self):
+        """Test that a concurrent ConnectionRecord.invalidate() which
+        occurs after the ConnectionFairy has called _ConnectionRecord.checkout()
+        but before the ConnectionFairy tests "fairy.connection is None"
+        will not result in an InvalidRequestError.
+
+        This use case assumes that a listener on the checkout() event
+        will be raising DisconnectionError so that a reconnect attempt
+        may occur.
+
+        """
+        dbapi = MockDBAPI()
+
+        def creator():
+            return dbapi.connect()
+
+        p = pool.QueuePool(creator=creator, pool_size=1, max_overflow=0)
+
+        conn = p.connect()
+        conn.close()
+
+        _existing_checkout = pool._ConnectionRecord.checkout
+
+        @classmethod
+        def _decorate_existing_checkout(cls, *arg, **kw):
+            fairy = _existing_checkout(*arg, **kw)
+            connrec = fairy._connection_record
+            connrec.invalidate()
+            return fairy
+
+        with patch(
+                "sqlalchemy.pool._ConnectionRecord.checkout",
+                _decorate_existing_checkout):
+            conn = p.connect()
+            is_(conn._connection_record.connection, None)
+        conn.close()
+
+
     @testing.requires.threading_with_mock
     @testing.requires.timing_intensive
     def test_notify_waiters(self):
@@ -1323,12 +1393,36 @@ class QueuePoolTest(PoolTestBase):
         c2 = p.connect()
         assert id(c2.connection) == c_id
 
+        c2_rec = c2._connection_record
         p._invalidate(c2)
+        assert c2_rec.connection is None
         c2.close()
         time.sleep(.5)
         c3 = p.connect()
         assert id(c3.connection) != c_id
 
+    @testing.requires.timing_intensive
+    def test_recycle_on_soft_invalidate(self):
+        p = self._queuepool_fixture(pool_size=1,
+                           max_overflow=0)
+        c1 = p.connect()
+        c_id = id(c1.connection)
+        c1.close()
+        c2 = p.connect()
+        assert id(c2.connection) == c_id
+
+        c2_rec = c2._connection_record
+        c2.invalidate(soft=True)
+        assert c2_rec.connection is c2.connection
+
+        c2.close()
+        time.sleep(.5)
+        c3 = p.connect()
+        assert id(c3.connection) != c_id
+        assert c3._connection_record is c2_rec
+        assert c2_rec.connection is c3.connection
+
+
     def _assert_cleanup_on_pooled_reconnect(self, dbapi, p):
         # p is QueuePool with size=1, max_overflow=2,
         # and one connection in the pool that will need to