From aabb6e530b05bee9cc5c2382a308a987abd6168e Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 17 Apr 2007 20:49:35 +0000 Subject: [PATCH] - the dialects within sqlalchemy.databases become a setuptools entry points. loading the built-in database dialects works the same as always, but if none found will fall back to trying pkg_resources to load an external module [ticket:521] --- CHANGES | 4 ++++ lib/sqlalchemy/databases/firebird.py | 8 ++++--- lib/sqlalchemy/databases/mssql.py | 34 +++++++++++++++------------- lib/sqlalchemy/databases/mysql.py | 8 ++++--- lib/sqlalchemy/databases/oracle.py | 10 ++++---- lib/sqlalchemy/databases/postgres.py | 20 ++++++++-------- lib/sqlalchemy/databases/sqlite.py | 28 ++++++++++++----------- lib/sqlalchemy/engine/strategies.py | 20 ++++++++-------- lib/sqlalchemy/engine/url.py | 25 +++++++++++++------- setup.py | 4 ++++ test/engine/parseconnect.py | 8 ++++++- test/sql/testtypes.py | 6 ++--- 12 files changed, 104 insertions(+), 71 deletions(-) diff --git a/CHANGES b/CHANGES index d22f527401..d285509b0b 100644 --- a/CHANGES +++ b/CHANGES @@ -17,6 +17,10 @@ related error messages. Additionally, when a "connection no longer open" condition is detected, the entire connection pool is discarded and replaced with a new instance. #516 + - the dialects within sqlalchemy.databases become a setuptools + entry points. loading the built-in database dialects works the + same as always, but if none found will fall back to trying + pkg_resources to load an external module [ticket:521] - sql: - preliminary support for unicode table names, column names and SQL statements added, for databases which can support them. diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 2ab88101a9..4695426eb2 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -15,9 +15,6 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions -def dbapi(): - import kinterbasdb - return kinterbasdb _initialized_kb = False @@ -113,6 +110,11 @@ class FBDialect(ansisql.ANSIDialect): self.type_conv = type_conv self.concurrency_level= concurrency_level + def dbapi(cls): + import kinterbasdb + return kinterbasdb + dbapi = classmethod(dbapi) + def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) if opts.get('port'): diff --git a/lib/sqlalchemy/databases/mssql.py b/lib/sqlalchemy/databases/mssql.py index 41b51d12dd..22fafad814 100644 --- a/lib/sqlalchemy/databases/mssql.py +++ b/lib/sqlalchemy/databases/mssql.py @@ -52,21 +52,6 @@ import sqlalchemy.ansisql as ansisql import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions -def dbapi(module_name=None): - if module_name: - try: - dialect_cls = dialect_mapping[module_name] - return dialect_cls.import_dbapi() - except KeyError: - raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) - else: - for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]: - try: - return dialect_cls.import_dbapi() - except ImportError, e: - pass - else: - raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') class MSNumeric(sqltypes.Numeric): def convert_result_value(self, value, dialect): @@ -331,7 +316,24 @@ class MSSQLDialect(ansisql.ANSIDialect): self.auto_identity_insert = auto_identity_insert self.text_as_varchar = False self.set_default_schema_name("dbo") - + + def dbapi(cls, module_name=None): + if module_name: + try: + dialect_cls = dialect_mapping[module_name] + return dialect_cls.import_dbapi() + except KeyError: + raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) + else: + for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]: + try: + return dialect_cls.import_dbapi() + except ImportError, e: + pass + else: + raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc') + dbapi = classmethod(dbapi) + def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) opts.update(url.query) diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index d3a42ccdc4..21f8bb3984 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -12,9 +12,6 @@ import sqlalchemy.types as sqltypes import sqlalchemy.exceptions as exceptions from array import array -def dbapi(): - import MySQLdb as mysql - return mysql def kw_colspec(self, spec): if self.unsigned: @@ -280,6 +277,11 @@ class MySQLDialect(ansisql.ANSIDialect): def __init__(self, **kwargs): ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs) + def dbapi(cls): + import MySQLdb as mysql + return mysql + dbapi = classmethod(dbapi) + def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port']) opts.update(url.query) diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index fce59a0725..f49f1d4c06 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -11,9 +11,6 @@ from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging from sqlalchemy.engine import default, base import sqlalchemy.types as sqltypes -def dbapi(): - import cx_Oracle - return cx_Oracle class OracleNumeric(sqltypes.Numeric): @@ -172,7 +169,12 @@ class OracleDialect(ansisql.ANSIDialect): self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)] else: self.ORACLE_BINARY_TYPES = [] - + + def dbapi(cls): + import cx_Oracle + return cx_Oracle + dbapi = classmethod(dbapi) + def create_connect_args(self, url): if url.database: # if we have a database, then we have a remote host diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index a93ba200cf..0eca18be38 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -16,15 +16,6 @@ try: except: mxDateTime = None -def dbapi(): - try: - import psycopg2 as psycopg - except ImportError, e: - try: - import psycopg - except ImportError, e2: - raise e - return psycopg class PGInet(sqltypes.TypeEngine): def get_col_spec(self): @@ -258,6 +249,17 @@ class PGDialect(ansisql.ANSIDialect): self.use_information_schema = use_information_schema self.paramstyle = 'pyformat' + def dbapi(cls): + try: + import psycopg2 as psycopg + except ImportError, e: + try: + import psycopg + except ImportError, e2: + raise e + return psycopg + dbapi = classmethod(dbapi) + def create_connect_args(self, url): opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port']) if opts.has_key('port'): diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 2b7e28dfb5..0222496f83 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -12,18 +12,6 @@ import sqlalchemy.engine.default as default import sqlalchemy.types as sqltypes import datetime,time -def dbapi(): - try: - from pysqlite2 import dbapi2 as sqlite - except ImportError, e: - try: - from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. - except ImportError: - try: - sqlite = __import__('sqlite') # skip ourselves - except ImportError: - raise e - return sqlite class SLNumeric(sqltypes.Numeric): def get_col_spec(self): @@ -160,6 +148,20 @@ class SQLiteDialect(ansisql.ANSIDialect): return tuple([int(x) for x in num.split('.')]) self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3")) + def dbapi(cls): + try: + from pysqlite2 import dbapi2 as sqlite + except ImportError, e: + try: + from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name. + except ImportError: + try: + sqlite = __import__('sqlite') # skip ourselves + except ImportError: + raise e + return sqlite + dbapi = classmethod(dbapi) + def compiler(self, statement, bindparams, **kwargs): return SQLiteCompiler(self, statement, bindparams, **kwargs) @@ -347,4 +349,4 @@ class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer): super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True) dialect = SQLiteDialect -poolclass = pool.SingletonThreadPool +dialect.poolclass = pool.SingletonThreadPool diff --git a/lib/sqlalchemy/engine/strategies.py b/lib/sqlalchemy/engine/strategies.py index ba9b0968a5..ed31743d8e 100644 --- a/lib/sqlalchemy/engine/strategies.py +++ b/lib/sqlalchemy/engine/strategies.py @@ -41,27 +41,26 @@ class DefaultEngineStrategy(EngineStrategy): # create url.URL object u = url.make_url(name_or_url) - # get module from sqlalchemy.databases - module = u.get_module() + dialect_cls = u.get_dialect() dialect_args = {} # consume dialect arguments from kwargs - for k in util.get_cls_kwargs(module.dialect): + for k in util.get_cls_kwargs(dialect_cls): if k in kwargs: dialect_args[k] = kwargs.pop(k) dbapi = kwargs.pop('module', None) if dbapi is None: dbapi_args = {} - for k in util.get_func_kwargs(module.dbapi): + for k in util.get_func_kwargs(dialect_cls.dbapi): if k in kwargs: dbapi_args[k] = kwargs.pop(k) - dbapi = module.dbapi(**dbapi_args) + dbapi = dialect_cls.dbapi(**dbapi_args) dialect_args['dbapi'] = dbapi # create dialect - dialect = module.dialect(**dialect_args) + dialect = dialect_cls(**dialect_args) # assemble connection arguments (cargs, cparams) = dialect.create_connect_args(u) @@ -77,7 +76,7 @@ class DefaultEngineStrategy(EngineStrategy): raise exceptions.DBAPIError("Connection failed", e) creator = kwargs.pop('creator', connect) - poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool)) + poolclass = kwargs.pop('poolclass', getattr(dialect_cls, 'poolclass', poollib.QueuePool)) pool_args = {} # consume pool arguments from kwargs, translating a few of the arguments @@ -158,17 +157,16 @@ class MockEngineStrategy(EngineStrategy): # create url.URL object u = url.make_url(name_or_url) - # get module from sqlalchemy.databases - module = u.get_module() + dialect_cls = u.get_dialect() dialect_args = {} # consume dialect arguments from kwargs - for k in util.get_cls_kwargs(module.dialect): + for k in util.get_cls_kwargs(dialect_cls): if k in kwargs: dialect_args[k] = kwargs.pop(k) # create dialect - dialect = module.dialect(**dialect_args) + dialect = dialect_cls(**dialect_args) return MockEngineStrategy.MockConnection(dialect, executor) diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index faa0ffc11c..c5ad90ee9f 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -69,19 +69,28 @@ class URL(object): s += '?' + "&".join(["%s=%s" % (k, self.query[k]) for k in keys]) return s - def get_module(self): - """Return the SQLAlchemy database module corresponding to this URL's driver name.""" + def get_dialect(self): + """Return the SQLAlchemy database dialect class corresponding to this URL's driver name.""" + dialect=None if self.drivername == 'ansi': import sqlalchemy.ansisql - return sqlalchemy.ansisql - + return sqlalchemy.ansisql.dialect + try: - return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) + module=getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername) + dialect=module.dialect except ImportError: if sys.exc_info()[2].tb_next is None: - raise exceptions.ArgumentError('unknown database %r' % self.drivername) - raise - + import pkg_resources + for res in pkg_resources.iter_entry_points('sqlalchemy.databases'): + if res.name==self.drivername: + dialect=res.load() + else: + raise + if dialect is not None: + return dialect + raise ImportError('unknown database %r' % self.drivername) + def translate_connect_args(self, names): """Translate this URL's attributes into a dictionary of connection arguments. diff --git a/setup.py b/setup.py index 48bbb9c9cb..552c9265c8 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,10 @@ setup(name = "SQLAlchemy", url = "http://www.sqlalchemy.org", packages = find_packages('lib'), package_dir = {'':'lib'}, + entry_points = { + 'sqlalchemy.databases': [ + '%s = sqlalchemy.databases.%s:dialect' % (f,f) for f in + ['sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird']]}, license = "MIT License", long_description = """\ SQLAlchemy is: diff --git a/test/engine/parseconnect.py b/test/engine/parseconnect.py index 01e4efbf13..49f71f8817 100644 --- a/test/engine/parseconnect.py +++ b/test/engine/parseconnect.py @@ -70,7 +70,13 @@ class CreateEngineTest(PersistTest): def testbadargs(self): # good arg, use MockDBAPI to prevent oracle import errors e = create_engine('oracle://', use_ansi=True, module=MockDBAPI()) - + + try: + e = create_engine("foobar://", module=MockDBAPI()) + assert False + except ImportError: + assert True + # bad arg try: e = create_engine('postgres://', use_ansi=True, module=MockDBAPI()) diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index d1256b31a5..b6d144d302 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -40,9 +40,9 @@ class MyUnicodeType(types.TypeDecorator): class AdaptTest(PersistTest): def testadapt(self): - e1 = url.URL('postgres').get_module().dialect() - e2 = url.URL('mysql').get_module().dialect() - e3 = url.URL('sqlite').get_module().dialect() + e1 = url.URL('postgres').get_dialect()() + e2 = url.URL('mysql').get_dialect()() + e3 = url.URL('sqlite').get_dialect()() type = String(40) -- 2.47.2