]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the dialects within sqlalchemy.databases become a setuptools
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Apr 2007 20:49:35 +0000 (20:49 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 17 Apr 2007 20:49:35 +0000 (20:49 +0000)
entry points. loading the built-in database dialects works the
same as always, but if none found will fall back to trying
pkg_resources to load an external module [ticket:521]

12 files changed:
CHANGES
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/url.py
setup.py
test/engine/parseconnect.py
test/sql/testtypes.py

diff --git a/CHANGES b/CHANGES
index d22f527401bb4e01d537bb175f1e7e4d462174a5..d285509b0bfbb8df038c8b78cca2b177ed35ca94 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       related error messages.  Additionally, when a "connection no 
       longer open" condition is detected, the entire connection pool 
       is discarded and replaced with a new instance.  #516
+    - the dialects within sqlalchemy.databases become a setuptools
+      entry points. loading the built-in database dialects works the
+      same as always, but if none found will fall back to trying
+      pkg_resources to load an external module [ticket:521]
 - sql:
     - preliminary support for unicode table names, column names and 
       SQL statements added, for databases which can support them.
index 2ab88101a99240cfa591c501f8059409e12d6ec9..4695426eb28dfc48ba31ea66957547b4d7843c09 100644 (file)
@@ -15,9 +15,6 @@ import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 
-def dbapi():
-    import kinterbasdb
-    return kinterbasdb
 
 _initialized_kb = False
 
@@ -113,6 +110,11 @@ class FBDialect(ansisql.ANSIDialect):
         self.type_conv = type_conv
         self.concurrency_level= concurrency_level
 
+    def dbapi(cls):
+        import kinterbasdb
+        return kinterbasdb
+    dbapi = classmethod(dbapi)
+    
     def create_connect_args(self, url):
         opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
         if opts.get('port'):
index 41b51d12ddff7f711567d69eaa27600208730d23..22fafad814eb271e9c96b7f0db950cf911d78a43 100644 (file)
@@ -52,21 +52,6 @@ import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 
-def dbapi(module_name=None):
-    if module_name:
-        try:
-            dialect_cls = dialect_mapping[module_name]
-            return dialect_cls.import_dbapi()
-        except KeyError:
-            raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
-    else:
-        for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]:
-            try:
-                return dialect_cls.import_dbapi()
-            except ImportError, e:
-                pass
-        else:
-            raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
     
 class MSNumeric(sqltypes.Numeric):
     def convert_result_value(self, value, dialect):
@@ -331,7 +316,24 @@ class MSSQLDialect(ansisql.ANSIDialect):
         self.auto_identity_insert = auto_identity_insert
         self.text_as_varchar = False
         self.set_default_schema_name("dbo")
-            
+
+    def dbapi(cls, module_name=None):
+        if module_name:
+            try:
+                dialect_cls = dialect_mapping[module_name]
+                return dialect_cls.import_dbapi()
+            except KeyError:
+                raise exceptions.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+        else:
+            for dialect_cls in [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]:
+                try:
+                    return dialect_cls.import_dbapi()
+                except ImportError, e:
+                    pass
+            else:
+                raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
+    dbapi = classmethod(dbapi)
+    
     def create_connect_args(self, url):
         opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
         opts.update(url.query)
index d3a42ccdc47b545a9a43309ac852bcf52cafd50e..21f8bb3984fa97873048181c81382ac7913b2357 100644 (file)
@@ -12,9 +12,6 @@ import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 from array import array
 
-def dbapi():
-    import MySQLdb as mysql
-    return mysql
 
 def kw_colspec(self, spec):
     if self.unsigned:
@@ -280,6 +277,11 @@ class MySQLDialect(ansisql.ANSIDialect):
     def __init__(self, **kwargs):
         ansisql.ANSIDialect.__init__(self, default_paramstyle='format', **kwargs)
 
+    def dbapi(cls):
+        import MySQLdb as mysql
+        return mysql
+    dbapi = classmethod(dbapi)
+    
     def create_connect_args(self, url):
         opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port'])
         opts.update(url.query)
index fce59a0725a36e4bdb667b04db5758198f03897e..f49f1d4c0673616a552277d031a6087784573b5e 100644 (file)
@@ -11,9 +11,6 @@ from sqlalchemy import util, sql, engine, schema, ansisql, exceptions, logging
 from sqlalchemy.engine import default, base
 import sqlalchemy.types as sqltypes
 
-def dbapi():
-    import cx_Oracle
-    return cx_Oracle
 
 
 class OracleNumeric(sqltypes.Numeric):
@@ -172,7 +169,12 @@ class OracleDialect(ansisql.ANSIDialect):
             self.ORACLE_BINARY_TYPES = [getattr(self.dbapi, k) for k in ["BFILE", "CLOB", "NCLOB", "BLOB", "LONG_BINARY", "LONG_STRING"] if hasattr(self.dbapi, k)]
         else:
             self.ORACLE_BINARY_TYPES = []
-            
+
+    def dbapi(cls):
+        import cx_Oracle
+        return cx_Oracle
+    dbapi = classmethod(dbapi)
+    
     def create_connect_args(self, url):
         if url.database:
             # if we have a database, then we have a remote host
index a93ba200cf289a0d96346c86beec19c942084518..0eca18be3800acd3782a72c8ac99b275553d114c 100644 (file)
@@ -16,15 +16,6 @@ try:
 except:
     mxDateTime = None
 
-def dbapi():
-    try:
-        import psycopg2 as psycopg
-    except ImportError, e:
-        try:
-            import psycopg
-        except ImportError, e2:
-            raise e
-    return psycopg
     
 class PGInet(sqltypes.TypeEngine):
     def get_col_spec(self):
@@ -258,6 +249,17 @@ class PGDialect(ansisql.ANSIDialect):
         self.use_information_schema = use_information_schema
         self.paramstyle = 'pyformat'
 
+    def dbapi(cls):
+        try:
+            import psycopg2 as psycopg
+        except ImportError, e:
+            try:
+                import psycopg
+            except ImportError, e2:
+                raise e
+        return psycopg
+    dbapi = classmethod(dbapi)
+    
     def create_connect_args(self, url):
         opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
         if opts.has_key('port'):
index 2b7e28dfb5c73a837553fe02d599198996fe121b..0222496f83719b9c0a8aa1aa016d8c4decc41963 100644 (file)
@@ -12,18 +12,6 @@ import sqlalchemy.engine.default as default
 import sqlalchemy.types as sqltypes
 import datetime,time
 
-def dbapi():
-    try:
-        from pysqlite2 import dbapi2 as sqlite
-    except ImportError, e:
-        try:
-            from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
-        except ImportError:
-            try:
-                sqlite = __import__('sqlite') # skip ourselves
-            except ImportError:
-                raise e
-    return sqlite
     
 class SLNumeric(sqltypes.Numeric):
     def get_col_spec(self):
@@ -160,6 +148,20 @@ class SQLiteDialect(ansisql.ANSIDialect):
             return tuple([int(x) for x in num.split('.')])
         self.supports_cast = (self.dbapi is None or vers(self.dbapi.sqlite_version) >= vers("3.2.3"))
 
+    def dbapi(cls):
+        try:
+            from pysqlite2 import dbapi2 as sqlite
+        except ImportError, e:
+            try:
+                from sqlite3 import dbapi2 as sqlite #try the 2.5+ stdlib name.
+            except ImportError:
+                try:
+                    sqlite = __import__('sqlite') # skip ourselves
+                except ImportError:
+                    raise e
+        return sqlite
+    dbapi = classmethod(dbapi)
+
     def compiler(self, statement, bindparams, **kwargs):
         return SQLiteCompiler(self, statement, bindparams, **kwargs)
 
@@ -347,4 +349,4 @@ class SQLiteIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
         super(SQLiteIdentifierPreparer, self).__init__(dialect, omit_schema=True)
 
 dialect = SQLiteDialect
-poolclass = pool.SingletonThreadPool
+dialect.poolclass = pool.SingletonThreadPool
index ba9b0968a5a48a3fc51450a641c135b6f906992b..ed31743d8e1c75b287d03aee7a1307e1a5566c5c 100644 (file)
@@ -41,27 +41,26 @@ class DefaultEngineStrategy(EngineStrategy):
         # create url.URL object
         u = url.make_url(name_or_url)
 
-        # get module from sqlalchemy.databases
-        module = u.get_module()
+        dialect_cls = u.get_dialect()
 
         dialect_args = {}
         # consume dialect arguments from kwargs
-        for k in util.get_cls_kwargs(module.dialect):
+        for k in util.get_cls_kwargs(dialect_cls):
             if k in kwargs:
                 dialect_args[k] = kwargs.pop(k)
 
         dbapi = kwargs.pop('module', None)
         if dbapi is None:
             dbapi_args = {}
-            for k in util.get_func_kwargs(module.dbapi):
+            for k in util.get_func_kwargs(dialect_cls.dbapi):
                 if k in kwargs:
                     dbapi_args[k] = kwargs.pop(k)
-            dbapi = module.dbapi(**dbapi_args)
+            dbapi = dialect_cls.dbapi(**dbapi_args)
         
         dialect_args['dbapi'] = dbapi
         
         # create dialect
-        dialect = module.dialect(**dialect_args)
+        dialect = dialect_cls(**dialect_args)
 
         # assemble connection arguments
         (cargs, cparams) = dialect.create_connect_args(u)
@@ -77,7 +76,7 @@ class DefaultEngineStrategy(EngineStrategy):
                     raise exceptions.DBAPIError("Connection failed", e)
             creator = kwargs.pop('creator', connect)
 
-            poolclass = kwargs.pop('poolclass', getattr(module, 'poolclass', poollib.QueuePool))
+            poolclass = kwargs.pop('poolclass', getattr(dialect_cls, 'poolclass', poollib.QueuePool))
             pool_args = {}
 
             # consume pool arguments from kwargs, translating a few of the arguments
@@ -158,17 +157,16 @@ class MockEngineStrategy(EngineStrategy):
         # create url.URL object
         u = url.make_url(name_or_url)
 
-        # get module from sqlalchemy.databases
-        module = u.get_module()
+        dialect_cls = u.get_dialect()
 
         dialect_args = {}
         # consume dialect arguments from kwargs
-        for k in util.get_cls_kwargs(module.dialect):
+        for k in util.get_cls_kwargs(dialect_cls):
             if k in kwargs:
                 dialect_args[k] = kwargs.pop(k)
 
         # create dialect
-        dialect = module.dialect(**dialect_args)
+        dialect = dialect_cls(**dialect_args)
 
         return MockEngineStrategy.MockConnection(dialect, executor)
 
index faa0ffc11cd94e53099adc2e2f0a8b24c682e6b9..c5ad90ee9f83d3aa639df13642d84e83cfd017de 100644 (file)
@@ -69,19 +69,28 @@ class URL(object):
             s += '?' + "&".join(["%s=%s" % (k, self.query[k]) for k in keys])
         return s
 
-    def get_module(self):
-        """Return the SQLAlchemy database module corresponding to this URL's driver name."""
+    def get_dialect(self):
+        """Return the SQLAlchemy database dialect class corresponding to this URL's driver name."""
+        dialect=None
         if self.drivername == 'ansi':
             import sqlalchemy.ansisql
-            return sqlalchemy.ansisql
-            
+            return sqlalchemy.ansisql.dialect
+
         try:
-            return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
+            module=getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
+            dialect=module.dialect
         except ImportError:
             if sys.exc_info()[2].tb_next is None:
-                raise exceptions.ArgumentError('unknown database %r' % self.drivername)
-            raise
-
+                import pkg_resources
+                for res in pkg_resources.iter_entry_points('sqlalchemy.databases'):
+                    if res.name==self.drivername:
+                        dialect=res.load()
+            else:
+               raise
+        if dialect is not None:
+            return dialect
+        raise ImportError('unknown database %r' % self.drivername) 
+  
     def translate_connect_args(self, names):
         """Translate this URL's attributes into a dictionary of connection arguments.
 
index 48bbb9c9cb7fbb56d09755a0d8d0fbfa82b5f9d9..552c9265c8dd30848935e18d136efaa98ba20c4c 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -10,6 +10,10 @@ setup(name = "SQLAlchemy",
     url = "http://www.sqlalchemy.org",
     packages = find_packages('lib'),
     package_dir = {'':'lib'},
+    entry_points = { 
+      'sqlalchemy.databases': [
+        '%s = sqlalchemy.databases.%s:dialect' % (f,f) for f in 
+          ['sqlite', 'postgres', 'mysql', 'oracle', 'mssql', 'firebird']]},
     license = "MIT License",
     long_description = """\
 SQLAlchemy is:
index 01e4efbf13a6458ae5dcfecbcde93675b81e973b..49f71f881709099b1a082a0ce591a92d0754556d 100644 (file)
@@ -70,7 +70,13 @@ class CreateEngineTest(PersistTest):
     def testbadargs(self):
         # good arg, use MockDBAPI to prevent oracle import errors
         e = create_engine('oracle://', use_ansi=True, module=MockDBAPI())
-
+        
+        try:
+            e = create_engine("foobar://", module=MockDBAPI())
+            assert False
+        except ImportError:
+            assert True 
+            
         # bad arg
         try:
             e = create_engine('postgres://', use_ansi=True, module=MockDBAPI())
index d1256b31a564bbad018bbdb5a422526107fe13c4..b6d144d302f40e38541aa5f9702697f8721dc8ea 100644 (file)
@@ -40,9 +40,9 @@ class MyUnicodeType(types.TypeDecorator):
 
 class AdaptTest(PersistTest):
     def testadapt(self):
-        e1 = url.URL('postgres').get_module().dialect()
-        e2 = url.URL('mysql').get_module().dialect()
-        e3 = url.URL('sqlite').get_module().dialect()
+        e1 = url.URL('postgres').get_dialect()()
+        e2 = url.URL('mysql').get_dialect()()
+        e3 = url.URL('sqlite').get_dialect()()
         
         type = String(40)