]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- fix some tests related to the URL change and try to make
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 May 2015 13:07:36 +0000 (09:07 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 23 May 2015 13:07:36 +0000 (09:07 -0400)
the URL design a little simpler

lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/url.py
test/engine/test_reconnect.py

index e2a086de42cee556593f1bef61eece9588374063..a539ee9f715782ea22e08a7e3b0f0b2ad1ef63f1 100644 (file)
@@ -48,7 +48,8 @@ class DefaultEngineStrategy(EngineStrategy):
         # create url.URL object
         u = url.make_url(name_or_url)
 
-        entrypoint, dialect_cls = u._get_dialect_plus_entrypoint()
+        entrypoint = u._get_entrypoint()
+        dialect_cls = entrypoint.get_dialect_cls(u)
 
         if kwargs.pop('_coerce_config', False):
             def pop_kwarg(key, default=None):
index 07f6a5730a708f7649de75a27372e3a781b0ca3d..32e3f8a6bd0faa8acd4636090744424e0095fa36 100644 (file)
@@ -117,7 +117,13 @@ class URL(object):
         else:
             return self.drivername.split('+')[1]
 
-    def _get_dialect_plus_entrypoint(self):
+    def _get_entrypoint(self):
+        """Return the "entry point" dialect class.
+
+        This is normally the dialect itself except in the case when the
+        returned class implements the get_dialect_cls() method.
+
+        """
         if '+' not in self.drivername:
             name = self.drivername
         else:
@@ -129,16 +135,16 @@ class URL(object):
         if hasattr(cls, 'dialect') and \
                 isinstance(cls.dialect, type) and \
                 issubclass(cls.dialect, Dialect):
-            return cls.dialect, cls.dialect
+            return cls.dialect
         else:
-            dialect_cls = cls.get_dialect_cls(self)
-            return cls, dialect_cls
+            return cls
 
     def get_dialect(self):
         """Return the SQLAlchemy database dialect class corresponding
         to this URL's driver name.
         """
-        entrypoint, dialect_cls = self._get_dialect_plus_entrypoint()
+        entrypoint = self._get_entrypoint()
+        dialect_cls = entrypoint.get_dialect_cls(self)
         return dialect_cls
 
     def translate_connect_args(self, names=[], **kw):
index 61931969368d07fbdaee658e5054b206a9204b81..39ebcc91ba24db0788575f169bbc1f9d1b4e7b95 100644 (file)
@@ -370,6 +370,9 @@ class MockReconnectTest(fixtures.TestBase):
         mock_dialect = Mock()
 
         class MyURL(URL):
+            def _get_entrypoint(self):
+                return Dialect
+
             def get_dialect(self):
                 return Dialect
 
@@ -420,6 +423,7 @@ class CursorErrTest(fixtures.TestBase):
         from sqlalchemy.engine import default
         url = Mock(
             get_dialect=lambda: default.DefaultDialect,
+            _get_entrypoint=lambda: default.DefaultDialect,
             translate_connect_args=lambda: {}, query={},)
         eng = testing_engine(
             url, options=dict(module=dbapi, _initialize=initialize))