]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Additions to support HAAlchemy plugin
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Sep 2016 19:10:32 +0000 (15:10 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 16 Sep 2016 20:20:18 +0000 (16:20 -0400)
- add a connect=True key to connection record to support
  pre-loading of _ConnectionRecord objects
- ensure _ConnectionRecord.close() leaves the record in a good
  state for reopening
- add _ConnectionRecord.record_info for persistent storage
- add "in_use" accessor based on fairy_ref being present or not
- allow for the exclusions system and SuiteRequirements to be
  usable without the full plugin_base setup.
- move some Python-env requirements to the importable
  requirements.py module.
- allow starttime to be queried
- add additional events for engine plugins
- have "dialect" be a first-class parameter to the pool,
  ensure the engine strategy supplies it up front

Change-Id: Ibf549f7a1766e49d335cd6f5e26bacfaef9a8229

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/interfaces.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/pool.py
lib/sqlalchemy/testing/config.py
lib/sqlalchemy/testing/exclusions.py
lib/sqlalchemy/testing/plugin/plugin_base.py
lib/sqlalchemy/testing/requirements.py
test/engine/test_parseconnect.py
test/engine/test_pool.py
test/requirements.py

index 83f0f0c83a9ba0ceb8719d489dff2668f547857a..b8acf298ffbb2833781d41eab70d7530399a992d 100644 (file)
@@ -1752,7 +1752,6 @@ class Engine(Connectable, log.Identified):
         self.pool = pool
         self.url = url
         self.dialect = dialect
-        self.pool._dialect = dialect
         if logging_name:
             self.logging_name = logging_name
         self.echo = echo
index 13e8bf1f48e3b11c254ad32c68c97037657e66bc..082661216d90844d9e5ac9bce27c9c8cb9465f7a 100644 (file)
@@ -900,6 +900,12 @@ class CreateEnginePlugin(object):
         """
         self.url = url
 
+    def handle_dialect_kwargs(self, dialect_cls, dialect_args):
+        """parse and modify dialect kwargs"""
+
+    def handle_pool_kwargs(self, pool_cls, pool_args):
+        """parse and modify pool kwargs"""
+
     def engine_created(self, engine):
         """Receive the :class:`.Engine` object when it is fully constructed.
 
index 82800a9189b4344da351048b22415a55bcc4e6cb..ccda14be421019a4e58ff82da9bc83db075210c8 100644 (file)
@@ -81,6 +81,9 @@ class DefaultEngineStrategy(EngineStrategy):
 
         dialect_args['dbapi'] = dbapi
 
+        for plugin in plugins:
+            plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
+
         # create dialect
         dialect = dialect_cls(**dialect_args)
 
@@ -106,7 +109,9 @@ class DefaultEngineStrategy(EngineStrategy):
             poolclass = pop_kwarg('poolclass', None)
             if poolclass is None:
                 poolclass = dialect_cls.get_pool_class(u)
-            pool_args = {}
+            pool_args = {
+                'dialect': dialect
+            }
 
             # consume pool arguments from kwargs, translating a few of
             # the arguments
@@ -121,6 +126,10 @@ class DefaultEngineStrategy(EngineStrategy):
                 tk = translate.get(k, k)
                 if tk in kwargs:
                     pool_args[k] = pop_kwarg(tk)
+
+            for plugin in plugins:
+                plugin.handle_pool_kwargs(poolclass, pool_args)
+
             pool = poolclass(creator, **pool_args)
         else:
             if isinstance(pool, poollib._DBProxy):
@@ -128,6 +137,8 @@ class DefaultEngineStrategy(EngineStrategy):
             else:
                 pool = pool
 
+            pool._dialect = dialect
+
         # create engine.
         engineclass = self.engine_cls
         engine_args = {}
index 4bd8f60ec9fd39c81995a1d463b4030e9debf102..1bdffc28b1204ba876e85f8491428671aa910280 100644 (file)
@@ -102,8 +102,8 @@ class Pool(log.Identified):
                  reset_on_return=True,
                  listeners=None,
                  events=None,
-                 _dispatch=None,
-                 _dialect=None):
+                 dialect=None,
+                 _dispatch=None):
         """
         Construct a Pool.
 
@@ -210,6 +210,15 @@ class Pool(log.Identified):
           pool.  This has been superseded by
           :func:`~sqlalchemy.event.listen`.
 
+        :param dialect: a :class:`.Dialect` that will handle the job
+         of calling rollback(), close(), or commit() on DBAPI connections.
+         If omitted, a built-in "stub" dialect is used.   Applications that
+         make use of :func:`~.create_engine` should not use this parameter
+         as it is handled by the engine creation strategy.
+
+         .. versionadded:: 1.1 - ``dialect`` is now a public parameter
+            to the :class:`.Pool`.
+
         """
         if logging_name:
             self.logging_name = self._orig_logging_name = logging_name
@@ -237,8 +246,8 @@ class Pool(log.Identified):
 
         if _dispatch:
             self.dispatch._update(_dispatch, only_propagate=False)
-        if _dialect:
-            self._dialect = _dialect
+        if dialect:
+            self._dialect = dialect
         if events:
             for fn, target in events:
                 event.listen(self, target, fn)
@@ -445,11 +454,16 @@ class _ConnectionRecord(object):
 
     """
 
-    def __init__(self, pool):
+    def __init__(self, pool, connect=True):
         self.__pool = pool
-        self.__connect(first_connect_check=True)
+        if connect:
+            self.__connect(first_connect_check=True)
         self.finalize_callback = deque()
 
+    fairy_ref = None
+
+    starttime = None
+
     connection = None
     """A reference to the actual DBAPI connection being tracked.
 
@@ -468,6 +482,31 @@ class _ConnectionRecord(object):
         This dictionary is shared among the :attr:`._ConnectionFairy.info`
         and :attr:`.Connection.info` accessors.
 
+        .. note::
+
+            The lifespan of this dictionary is linked to the
+            DBAPI connection itself, meaning that it is **discarded** each time
+            the DBAPI connection is closed and/or invalidated.   The
+            :attr:`._ConnectionRecord.record_info` dictionary remains
+            persistent throughout the lifespan of the
+            :class:`._ConnectionRecord` container.
+
+        """
+        return {}
+
+    @util.memoized_property
+    def record_info(self):
+        """An "info' dictionary associated with the connection record
+        itself.
+
+        Unlike the :attr:`._ConnectionRecord.info` dictionary, which is linked
+        to the lifespan of the DBAPI connection, this dictionary is linked
+        to the lifespan of the :class:`._ConnectionRecord` container itself
+        and will remain persisent throughout the life of the
+        :class:`._ConnectionRecord`.
+
+        .. versionadded:: 1.1
+
         """
         return {}
 
@@ -505,6 +544,14 @@ class _ConnectionRecord(object):
             pool.dispatch.checkin(connection, self)
         pool._return_conn(self)
 
+    @property
+    def in_use(self):
+        return self.fairy_ref is not None
+
+    @property
+    def last_connect_time(self):
+        return self.starttime
+
     def close(self):
         if self.connection is not None:
             self.__close()
@@ -590,6 +637,7 @@ class _ConnectionRecord(object):
         if self.__pool.dispatch.close:
             self.__pool.dispatch.close(self.connection, self)
         self.__pool._close_connection(self.connection)
+        self.connection = None
 
     def __connect(self, first_connect_check=False):
         pool = self.__pool
@@ -812,9 +860,30 @@ class _ConnectionFairy(object):
         with the :attr:`._ConnectionRecord.info` and :attr:`.Connection.info`
         accessors.
 
+        The dictionary associated with a particular DBAPI connection is
+        discarded when the connection itself is discarded.
+
         """
         return self._connection_record.info
 
+    @property
+    def record_info(self):
+        """Info dictionary associated with the :class:`._ConnectionRecord
+        container referred to by this :class:`.ConnectionFairy`.
+
+        Unlike the :attr:`._ConnectionFairy.info` dictionary, the lifespan
+        of this dictionary is persistent across connections that are
+        disconnected and/or invalidated within the lifespan of a
+        :class:`._ConnectionRecord`.
+
+        .. versionadded:: 1.1
+
+        """
+        if self._connection_record:
+            return self._connection_record.record_info
+        else:
+            return None
+
     def invalidate(self, e=None, soft=False):
         """Mark this connection as invalidated.
 
@@ -938,7 +1007,7 @@ class SingletonThreadPool(Pool):
                               use_threadlocal=self._use_threadlocal,
                               reset_on_return=self._reset_on_return,
                               _dispatch=self.dispatch,
-                              _dialect=self._dialect)
+                              dialect=self._dialect)
 
     def dispose(self):
         """Dispose of this pool."""
@@ -1098,7 +1167,7 @@ class QueuePool(Pool):
                               use_threadlocal=self._use_threadlocal,
                               reset_on_return=self._reset_on_return,
                               _dispatch=self.dispatch,
-                              _dialect=self._dialect)
+                              dialect=self._dialect)
 
     def dispose(self):
         while True:
@@ -1168,7 +1237,7 @@ class NullPool(Pool):
                               use_threadlocal=self._use_threadlocal,
                               reset_on_return=self._reset_on_return,
                               _dispatch=self.dispatch,
-                              _dialect=self._dialect)
+                              dialect=self._dialect)
 
     def dispose(self):
         pass
@@ -1210,7 +1279,7 @@ class StaticPool(Pool):
                               echo=self.echo,
                               logging_name=self._orig_logging_name,
                               _dispatch=self.dispatch,
-                              _dialect=self._dialect)
+                              dialect=self._dialect)
 
     def _create_connection(self):
         return self._conn
@@ -1264,7 +1333,7 @@ class AssertionPool(Pool):
         return self.__class__(self._creator, echo=self.echo,
                               logging_name=self._orig_logging_name,
                               _dispatch=self.dispatch,
-                              _dialect=self._dialect)
+                              dialect=self._dialect)
 
     def _do_get(self):
         if self._checked_out:
index da5997661ece7a25e18eb1d3dadcd5cfdd422db6..6648f9130faebddb454bcf797c1ce35c862647b7 100644 (file)
@@ -15,7 +15,11 @@ file_config = None
 test_schema = None
 test_schema_2 = None
 _current = None
-_skip_test_exception = None
+
+try:
+    from unittest import SkipTest as _skip_test_exception
+except ImportError:
+    _skip_test_exception = None
 
 
 class Config(object):
@@ -90,3 +94,4 @@ class Config(object):
 
 def skip_test(msg):
     raise _skip_test_exception(msg)
+
index b672656a0b898db85e329afd9fdd6abe1411bc9a..fb1041db38eae0471cbd2b4e7372b26d38ac7f89 100644 (file)
@@ -109,21 +109,21 @@ class compound(object):
         else:
             all_fails._expect_success(config._current)
 
-    def _do(self, config, fn, *args, **kw):
+    def _do(self, cfg, fn, *args, **kw):
         for skip in self.skips:
-            if skip(config):
+            if skip(cfg):
                 msg = "'%s' : %s" % (
                     fn.__name__,
-                    skip._as_string(config)
+                    skip._as_string(cfg)
                 )
                 config.skip_test(msg)
 
         try:
             return_value = fn(*args, **kw)
         except Exception as ex:
-            self._expect_failure(config, ex, name=fn.__name__)
+            self._expect_failure(cfg, ex, name=fn.__name__)
         else:
-            self._expect_success(config, name=fn.__name__)
+            self._expect_success(cfg, name=fn.__name__)
             return return_value
 
     def _expect_failure(self, config, ex, name='block'):
@@ -208,8 +208,10 @@ class Predicate(object):
         if negate:
             bool_ = not negate
         return self.description % {
-            "driver": config.db.url.get_driver_name(),
-            "database": config.db.url.get_backend_name(),
+            "driver": config.db.url.get_driver_name()
+            if config else "<no driver>",
+            "database": config.db.url.get_backend_name()
+            if config else "<no database>",
             "doesnt_support": "doesn't support" if bool_ else "does support",
             "does_support": "does support" if bool_ else "doesn't support"
         }
index fc9d71165b928c3c494acf42ad28cb68dcf08c6b..6581195dff252ee2cd8047c1abbf30e020fe88ea 100644 (file)
@@ -267,6 +267,7 @@ def _engine_uri(options, file_config):
     if not db_urls:
         db_urls.append(file_config.get('db', 'default'))
 
+    config._current = None
     for db_url in db_urls:
         cfg = provision.setup_config(
             db_url, options, file_config, provision.FOLLOWER_IDENT)
index a9370a30e27beff5424e8b1e18b31d8d4c16b62e..b0f466892c3be0818218abc2b91f4f94bc75ce20 100644 (file)
@@ -15,6 +15,8 @@ to provide specific inclusion/exclusions.
 
 """
 
+import sys
+
 from . import exclusions
 from .. import util
 
@@ -707,6 +709,44 @@ class SuiteRequirements(Requirements):
             "Stability issues with coverage + py3k"
         )
 
+    @property
+    def python2(self):
+        return exclusions.skip_if(
+            lambda: sys.version_info >= (3,),
+            "Python version 2.xx is required."
+        )
+
+    @property
+    def python3(self):
+        return exclusions.skip_if(
+            lambda: sys.version_info < (3,),
+            "Python version 3.xx is required."
+        )
+
+    @property
+    def cpython(self):
+        return exclusions.only_if(
+            lambda: util.cpython,
+            "cPython interpreter needed"
+        )
+
+    @property
+    def non_broken_pickle(self):
+        from sqlalchemy.util import pickle
+        return exclusions.only_if(
+            lambda: not util.pypy and pickle.__name__ == 'cPickle'
+                or sys.version_info >= (3, 2),
+            "Needs cPickle+cPython or newer Python 3 pickle"
+        )
+
+    @property
+    def predictable_gc(self):
+        """target platform must remove all cycles unconditionally when
+        gc.collect() is called, as well as clean out unreferenced subclasses.
+
+        """
+        return self.cpython
+
     @property
     def no_coverage(self):
         """Test should be skipped if coverage is enabled.
index 0e1f6c3d2e8f7d6f4c3474e8369f2f88f384194e..894fff2808f1f02c394deef60b6b2b29e50ea300 100644 (file)
@@ -6,6 +6,7 @@ import sqlalchemy as tsa
 from sqlalchemy.testing import fixtures
 from sqlalchemy import testing
 from sqlalchemy.testing.mock import Mock, MagicMock, call
+from sqlalchemy.testing import mock
 from sqlalchemy.dialects import registry
 from sqlalchemy.dialects import plugins
 
@@ -403,6 +404,8 @@ class TestRegNewDBAPI(fixtures.TestBase):
             MyEnginePlugin.mock_calls,
             [
                 call(e.url, {}),
+                call.handle_dialect_kwargs(sqlite.dialect, mock.ANY),
+                call.handle_pool_kwargs(mock.ANY, {"dialect": e.dialect}),
                 call.engine_created(e)
             ]
         )
index 057289199f3ad9580ff0888fbe2d4f38b6108819..5b87c90b89daec6eaf29f8c1ebdf31561d74fff6 100644 (file)
@@ -232,6 +232,81 @@ class PoolTest(PoolTestBase):
         assert not c2.info
         assert 'foo2' in c.info
 
+    def test_rec_info(self):
+        p = self._queuepool_fixture(pool_size=1, max_overflow=0)
+
+        c = p.connect()
+        self.assert_(not c.record_info)
+        self.assert_(c.record_info is c._connection_record.record_info)
+
+        c.record_info['foo'] = 'bar'
+        c.close()
+        del c
+
+        c = p.connect()
+        self.assert_('foo' in c.record_info)
+
+        c.invalidate()
+        c = p.connect()
+        self.assert_('foo' in c.record_info)
+
+        c.record_info['foo2'] = 'bar2'
+        c.detach()
+        is_(c.record_info, None)
+        is_(c._connection_record, None)
+
+        c2 = p.connect()
+
+        assert c2.record_info
+        assert 'foo2' in c2.record_info
+
+    def test_rec_unconnected(self):
+        # test production of a _ConnectionRecord with an
+        # initally unconnected state.
+
+        dbapi = MockDBAPI()
+        p1 = pool.Pool(
+            creator=lambda: dbapi.connect('foo.db')
+        )
+
+        r1 = pool._ConnectionRecord(p1, connect=False)
+
+        assert not r1.connection
+        c1 = r1.get_connection()
+        is_(c1, r1.connection)
+
+    def test_rec_close_reopen(self):
+        # test that _ConnectionRecord.close() allows
+        # the record to be reusable
+        dbapi = MockDBAPI()
+        p1 = pool.Pool(
+            creator=lambda: dbapi.connect('foo.db')
+        )
+
+        r1 = pool._ConnectionRecord(p1)
+
+        c1 = r1.connection
+        c2 = r1.get_connection()
+        is_(c1, c2)
+
+        r1.close()
+
+        assert not r1.connection
+        eq_(
+            c1.mock_calls,
+            [call.close()]
+        )
+
+        c2 = r1.get_connection()
+
+        is_not_(c1, c2)
+        is_(c2, r1.connection)
+
+        eq_(
+            c2.mock_calls,
+            []
+        )
+
 
 class PoolDialectTest(PoolTestBase):
     def _dialect(self):
index 87e3bb726448de88ff0897130fce0c85220b913b..3a2fcf03b757162343ff9973e02c1362a0c2dd48 100644 (file)
@@ -747,45 +747,6 @@ class DefaultRequirements(SuiteRequirements):
     def duplicate_key_raises_integrity_error(self):
         return fails_on("postgresql+pg8000")
 
-    @property
-    def python2(self):
-        return skip_if(
-                lambda: sys.version_info >= (3,),
-                "Python version 2.xx is required."
-                )
-
-    @property
-    def python3(self):
-        return skip_if(
-                lambda: sys.version_info < (3,),
-                "Python version 3.xx is required."
-                )
-
-    @property
-    def cpython(self):
-        return only_if(lambda: util.cpython,
-               "cPython interpreter needed"
-             )
-
-
-    @property
-    def non_broken_pickle(self):
-        from sqlalchemy.util import pickle
-        return only_if(
-            lambda: not util.pypy and pickle.__name__ == 'cPickle'
-                or sys.version_info >= (3, 2),
-            "Needs cPickle+cPython or newer Python 3 pickle"
-        )
-
-
-    @property
-    def predictable_gc(self):
-        """target platform must remove all cycles unconditionally when
-        gc.collect() is called, as well as clean out unreferenced subclasses.
-
-        """
-        return self.cpython
-
     @property
     def hstore(self):
         def check_hstore(config):