From: Mike Bayer Date: Sat, 23 May 2015 13:07:36 +0000 (-0400) Subject: - fix some tests related to the URL change and try to make X-Git-Tag: rel_1_0_5~20 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e9921ad356fee4edb56007ae39793fb2211f13cf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git - fix some tests related to the URL change and try to make the URL design a little simpler --- diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index e2a086de42..a539ee9f71 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -48,7 +48,8 @@ class DefaultEngineStrategy(EngineStrategy): # create url.URL object u = url.make_url(name_or_url) - entrypoint, dialect_cls = u._get_dialect_plus_entrypoint() + entrypoint = u._get_entrypoint() + dialect_cls = entrypoint.get_dialect_cls(u) if kwargs.pop('_coerce_config', False): def pop_kwarg(key, default=None): diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index 07f6a5730a..32e3f8a6bd 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -117,7 +117,13 @@ class URL(object): else: return self.drivername.split('+')[1] - def _get_dialect_plus_entrypoint(self): + def _get_entrypoint(self): + """Return the "entry point" dialect class. + + This is normally the dialect itself except in the case when the + returned class implements the get_dialect_cls() method. + + """ if '+' not in self.drivername: name = self.drivername else: @@ -129,16 +135,16 @@ class URL(object): if hasattr(cls, 'dialect') and \ isinstance(cls.dialect, type) and \ issubclass(cls.dialect, Dialect): - return cls.dialect, cls.dialect + return cls.dialect else: - dialect_cls = cls.get_dialect_cls(self) - return cls, dialect_cls + return cls def get_dialect(self): """Return the SQLAlchemy database dialect class corresponding to this URL's driver name. """ - entrypoint, dialect_cls = self._get_dialect_plus_entrypoint() + entrypoint = self._get_entrypoint() + dialect_cls = entrypoint.get_dialect_cls(self) return dialect_cls def translate_connect_args(self, names=[], **kw): diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index 6193196936..39ebcc91ba 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -370,6 +370,9 @@ class MockReconnectTest(fixtures.TestBase): mock_dialect = Mock() class MyURL(URL): + def _get_entrypoint(self): + return Dialect + def get_dialect(self): return Dialect @@ -420,6 +423,7 @@ class CursorErrTest(fixtures.TestBase): from sqlalchemy.engine import default url = Mock( get_dialect=lambda: default.DefaultDialect, + _get_entrypoint=lambda: default.DefaultDialect, translate_connect_args=lambda: {}, query={},) eng = testing_engine( url, options=dict(module=dbapi, _initialize=initialize))