]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
mssql: cleanup of module importing code; specifiable DB-API module; more explicit...
authorRick Morrison <rickmorrison@gmail.com>
Sun, 18 Mar 2007 17:14:10 +0000 (17:14 +0000)
committerRick Morrison <rickmorrison@gmail.com>
Sun, 18 Mar 2007 17:14:10 +0000 (17:14 +0000)
CHANGES
lib/sqlalchemy/databases/mssql.py

diff --git a/CHANGES b/CHANGES
index 3eac53cfcb73eea2fb4de94d43bd5b7c62bb48e0..6b39299e197728c55f2b2a9877e482021c4255b2 100644 (file)
--- a/CHANGES
+++ b/CHANGES
       for large unsized string fields. Use the new "text_as_varchar" to 
       turn it on. [ticket:509]
 
+    - cleanup of module importing code; specifiable DB-API module; more 
+      explicit ordering of module preferences. [ticket:480]
+
 
 
     
index f7d64761affff61c284e707bc14e2909ad1f4468..9ae65725071f3e4badaa5068245ca19c3e555c28 100644 (file)
@@ -30,7 +30,7 @@
 
 Known issues / TODO:
 
-* No support for more than one ``IDENTITY`` column per table no
+* No support for more than one ``IDENTITY`` column per table
 
 * No support for table reflection of ``IDENTITY`` columns with
  (seed,increment) values other than (1,1)
@@ -38,8 +38,7 @@ Known issues / TODO:
 * No support for ``GUID`` type columns (yet)
 
 * pymssql has problems with binary and unicode data that this module
-  does **not** work around adodbapi fails testtypes.py unit test on
-  unicode data too -- issue with the test?
+  does **not** work around
 """
 
 import sys, StringIO, string, types, re, datetime
@@ -52,96 +51,6 @@ import sqlalchemy.ansisql as ansisql
 import sqlalchemy.types as sqltypes
 import sqlalchemy.exceptions as exceptions
 
-dbmodule = None
-dialect = None
-
-def use_adodbapi():
-    global dbmodule, connect, make_connect_string, do_commit, sane_rowcount, dialect, colspecs, ischema_names
-    import adodbapi as dbmodule
-    # ADODBAPI has a non-standard Connection method
-    connect = dbmodule.Connection
-    def make_connect_string(keys):
-        connectors = ["Provider=SQLOLEDB"]
-        connectors.append ("Data Source=%s" % keys.get("host"))
-        connectors.append ("Initial Catalog=%s" % keys.get("database"))
-        user = keys.get("user")
-        if user:
-            connectors.append("User Id=%s" % user)
-            connectors.append("Password=%s" % keys.get("password", ""))
-        else:
-            connectors.append("Integrated Security=SSPI")
-        return [[";".join (connectors)], {}]
-    sane_rowcount = True
-    dialect = MSSQLDialect
-    colspecs[sqltypes.Unicode] = AdoMSUnicode
-    ischema_names['nvarchar'] = AdoMSUnicode
-    
-def use_pymssql():
-    global dbmodule, connect, make_connect_string, do_commit, sane_rowcount, dialect, colspecs, ischema_names
-    import pymssql as dbmodule
-    connect = dbmodule.connect
-    # pymmsql doesn't have a Binary method.  we use string
-    dbmodule.Binary = lambda st: str(st)
-    def make_connect_string(keys):
-        if keys.get('port'):
-            # pymssql expects port as host:port, not a separate arg
-            keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
-            del keys['port'] 
-        return [[], keys]
-    do_commit = True
-    sane_rowcount = False
-    dialect = PyMSSQLDialect
-    colspecs[sqltypes.Unicode] = MSUnicode
-    ischema_names['nvarchar'] = MSUnicode
-    
-def use_pyodbc():
-    global dbmodule, connect, make_connect_string, do_commit, sane_rowcount, dialect, colspecs, ischema_names
-    import pyodbc as dbmodule
-    connect = dbmodule.connect
-    def make_connect_string(keys):
-        connectors = ["Driver={SQL Server}"]
-        connectors.append("Server=%s" % keys.get("host"))
-        connectors.append("Database=%s" % keys.get("database"))
-        user = keys.get("user")
-        if user:
-            connectors.append("UID=%s" % user)
-            connectors.append("PWD=%s" % keys.get("password", ""))
-        else:
-            connectors.append ("TrustedConnection=Yes")
-        return [[";".join (connectors)], {}]
-    do_commit = True
-    sane_rowcount = False
-    dialect = MSSQLDialect
-    import warnings
-    warnings.warn('pyodbc support in sqlalchemy.databases.mssql is experimental - use at your own risk.')
-    colspecs[sqltypes.Unicode] = AdoMSUnicode
-    ischema_names['nvarchar'] = AdoMSUnicode
-
-def use_default():
-    import_errors = []
-    def try_use(f):
-        try:
-            f()
-        except ImportError, e:
-            import_errors.append(e)
-            return False
-        else:
-            return True
-    for f in [
-            # XXX - is this the best default ordering? For now, it retains the current (2007-Jan-11) 
-            # default - that is, adodbapi first, pymssql second - and adds pyodbc as a third option.
-            # However, my tests suggest that the exact opposite order may be the best!
-            use_adodbapi,
-            use_pymssql,
-            use_pyodbc,
-            ]:
-        if try_use(f):
-            return dbmodule # informational return, so the user knows what he's using.
-    else:
-        return None
-        # cant raise this right now since all dialects need to be importable/loadable
-        #raise ImportError(import_errors)
-        
 
 class MSNumeric(sqltypes.Numeric):
     def convert_result_value(self, value, dialect):
@@ -232,19 +141,16 @@ class MSString(sqltypes.String):
         return "VARCHAR(%(length)s)" % {'length' : self.length}
 
 class MSNVarchar(MSString):
-    """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True.
-    """
-
+    """NVARCHAR string, does Unicode conversion if `dialect.convert_encoding` is True.  """
     impl = sqltypes.Unicode
 
     def get_col_spec(self):
         if self.length:
             return "NVARCHAR(%(length)s)" % {'length' : self.length}
+        elif self.dialect.text_as_varchar:
+            return "NVARCHAR(max)"
         else:
-            if self.dialect.text_as_varchar:
-                return "NVARCHAR(max)"
-            else:
-                return "NTEXT"
+            return "NTEXT"
 
 class AdoMSNVarchar(MSNVarchar):
     def convert_bind_param(self, value, dialect):
@@ -255,7 +161,6 @@ class AdoMSNVarchar(MSNVarchar):
 
 class MSUnicode(sqltypes.Unicode):
     """Unicode subclass, does Unicode conversion in all cases, uses NVARCHAR impl."""
-
     impl = MSNVarchar
 
 class AdoMSUnicode(MSUnicode):
@@ -298,43 +203,6 @@ class MSBoolean(sqltypes.Boolean):
         else:
             return value and True or False
         
-colspecs = {
-    sqltypes.Integer : MSInteger,
-    sqltypes.Smallinteger: MSSmallInteger,
-    sqltypes.Numeric : MSNumeric,
-    sqltypes.Float : MSFloat,
-    sqltypes.DateTime : MSDateTime,
-    sqltypes.Date : MSDate,
-    sqltypes.String : MSString,
-    sqltypes.Unicode : MSUnicode,
-    sqltypes.Binary : MSBinary,
-    sqltypes.Boolean : MSBoolean,
-    sqltypes.TEXT : MSText,
-    sqltypes.CHAR: MSChar,
-    sqltypes.NCHAR: MSNChar,
-}
-
-ischema_names = {
-    'int' : MSInteger,
-    'smallint' : MSSmallInteger,
-    'tinyint' : MSTinyInteger,
-    'varchar' : MSString,
-    'nvarchar' : MSUnicode,
-    'char' : MSChar,
-    'nchar' : MSNChar,
-    'text' : MSText,
-    'ntext' : MSText, 
-    'decimal' : MSNumeric,
-    'numeric' : MSNumeric,
-    'float' : MSFloat,
-    'datetime' : MSDateTime,
-    'smalldatetime' : MSDate,
-    'binary' : MSBinary,
-    'bit': MSBoolean,
-    'real' : MSFloat,
-    'image' : MSBinary
-}
-
 def descriptor():
     return {'name':'mssql',
     'description':'MSSQL',
@@ -406,11 +274,63 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
 
 
 class MSSQLDialect(ansisql.ANSIDialect):
-    def __init__(self, module=None, auto_identity_insert=True, **params):
-        self.module = module or dbmodule or use_default()
+    colspecs = {
+        sqltypes.Integer : MSInteger,
+        sqltypes.Smallinteger: MSSmallInteger,
+        sqltypes.Numeric : MSNumeric,
+        sqltypes.Float : MSFloat,
+        sqltypes.DateTime : MSDateTime,
+        sqltypes.Date : MSDate,
+        sqltypes.String : MSString,
+        sqltypes.Unicode : MSUnicode,
+        sqltypes.Binary : MSBinary,
+        sqltypes.Boolean : MSBoolean,
+        sqltypes.TEXT : MSText,
+        sqltypes.CHAR: MSChar,
+        sqltypes.NCHAR: MSNChar,
+    }
+
+    ischema_names = {
+        'int' : MSInteger,
+        'smallint' : MSSmallInteger,
+        'tinyint' : MSTinyInteger,
+        'varchar' : MSString,
+        'nvarchar' : MSUnicode,
+        'char' : MSChar,
+        'nchar' : MSNChar,
+        'text' : MSText,
+        'ntext' : MSText,
+        'decimal' : MSNumeric,
+        'numeric' : MSNumeric,
+        'float' : MSFloat,
+        'datetime' : MSDateTime,
+        'smalldatetime' : MSDate,
+        'binary' : MSBinary,
+        'bit': MSBoolean,
+        'real' : MSFloat,
+        'image' : MSBinary
+    }
+
+    def __new__(cls, module_name=None, *args, **kwargs):
+        if cls != MSSQLDialect:
+            return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
+        if module_name:
+            dialect = dialect_mapping.get(module_name)
+            if not dialect:
+                raise exceptions.InvalidRequestError('Unsupported MSSQL module requested (must be adodbpi, pymssql or pyodbc): ' + module_name)
+            if not hasattr(dialect, 'module'):
+                raise dialect.saved_import_error
+            return dialect(*args, **kwargs)
+        else:
+            for dialect in dialect_preference:
+                if hasattr(dialect, 'module'):
+                    return dialect(*args, **kwargs)
+            raise ImportError('No DBAPI module detected for MSSQL - please install adodbapi, pymssql or pyodbc')
+
+    def __init__(self, module_name=None, auto_identity_insert=True, **params):
+        super(MSSQLDialect, self).__init__(**params)
         self.auto_identity_insert = auto_identity_insert
         self.text_as_varchar = False
-        ansisql.ANSIDialect.__init__(self, **params)
         self.set_default_schema_name("dbo")
         
     def create_connect_args(self, url):
@@ -422,13 +342,13 @@ class MSSQLDialect(ansisql.ANSIDialect):
             self.query_timeout = int(opts.pop('query_timeout'))
         if opts.has_key('text_as_varchar'):
             self.text_as_varchar = bool(opts.pop('text_as_varchar'))
-        return make_connect_string(opts)
+        return self.make_connect_string(opts)
 
     def create_execution_context(self):
         return MSSQLExecutionContext(self)
 
     def type_descriptor(self, typeobj):
-        newobj = sqltypes.adapt_type(typeobj, colspecs)
+        newobj = sqltypes.adapt_type(typeobj, self.colspecs)
         # Some types need to know about the dialect
         if isinstance(newobj, (MSText, MSNVarchar)):
             newobj.dialect = self
@@ -437,8 +357,9 @@ class MSSQLDialect(ansisql.ANSIDialect):
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
+    # this is only implemented in the dbapi-specific subclasses
     def supports_sane_rowcount(self):
-        return sane_rowcount
+        raise NotImplementedError()
 
     def compiler(self, statement, bindparams, **kwargs):
         return MSSQLCompiler(self, statement, bindparams, **kwargs)
@@ -556,7 +477,7 @@ class MSSQLDialect(ansisql.ANSIDialect):
             for a in (charlen, numericprec, numericscale):
                 if a is not None:
                     args.append(a)
-            coltype = ischema_names[type]
+            coltype = self.ischema_names[type]
             if coltype == MSString and charlen == -1:
                 coltype = MSText()                
             else:
@@ -628,7 +549,17 @@ class MSSQLDialect(ansisql.ANSIDialect):
         if fknm and scols:
             table.append_constraint(schema.ForeignKeyConstraint(scols, ['%s.%s' % (t,c) for (s,t,c) in rcols], fknm))
 
-class PyMSSQLDialect(MSSQLDialect):
+class MSSQLDialect_pymssql(MSSQLDialect):
+    try:
+        import pymssql as module
+        # pymmsql doesn't have a Binary method.  we use string
+        module.Binary = lambda st: str(st)
+    except ImportError, e:
+        saved_import_error = e
+
+    def supports_sane_rowcount(self):
+        return True
+
     def do_rollback(self, connection):
         # pymssql throws an error on repeated rollbacks. Ignore it.
         try:
@@ -637,11 +568,19 @@ class PyMSSQLDialect(MSSQLDialect):
             pass
 
     def create_connect_args(self, url):
-        r = super(PyMSSQLDialect, self).create_connect_args(url)
+        r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
         if hasattr(self, 'query_timeout'):
-            dbmodule._mssql.set_query_timeout(self.query_timeout)
+            self.module._mssql.set_query_timeout(self.query_timeout)
         return r
 
+    def make_connect_string(self, keys):
+        if keys.get('port'):
+            # pymssql expects port as host:port, not a separate arg
+            keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
+            del keys['port']
+        return [[], keys]
+
+
 ##    This code is leftover from the initial implementation, for reference
 ##    def do_begin(self, connection):
 ##        """implementations might want to put logic here for turning autocommit on/off, etc."""
@@ -673,6 +612,68 @@ class PyMSSQLDialect(MSSQLDialect):
 ##        r.query("begin tran")
 ##        r.fetch_array()
 
+class MSSQLDialect_pyodbc(MSSQLDialect):
+    try:
+        import pyodbc as module
+    except ImportError, e:
+        saved_import_error = e
+
+    colspecs = MSSQLDialect.colspecs.copy()
+    colspecs[sqltypes.Unicode] = AdoMSUnicode
+    ischema_names = MSSQLDialect.ischema_names.copy()
+    ischema_names['nvarchar'] = AdoMSUnicode
+
+    def supports_sane_rowcount(self):
+        return False
+
+    def make_connect_string(self, keys):
+        connectors = ["Driver={SQL Server}"]
+        connectors.append("Server=%s" % keys.get("host"))
+        connectors.append("Database=%s" % keys.get("database"))
+        user = keys.get("user")
+        if user:
+            connectors.append("UID=%s" % user)
+            connectors.append("PWD=%s" % keys.get("password", ""))
+        else:
+            connectors.append ("TrustedConnection=Yes")
+        return [[";".join (connectors)], {}]
+
+
+class MSSQLDialect_adodbapi(MSSQLDialect):
+    try:
+        import adodbapi as module
+    except ImportError, e:
+        saved_import_error = e
+
+    colspecs = MSSQLDialect.colspecs.copy()
+    colspecs[sqltypes.Unicode] = AdoMSUnicode
+    ischema_names = MSSQLDialect.ischema_names.copy()
+    ischema_names['nvarchar'] = AdoMSUnicode
+
+    def supports_sane_rowcount(self):
+        return True
+
+    def make_connect_string(self, keys):
+        connectors = ["Provider=SQLOLEDB"]
+        connectors.append ("Data Source=%s" % keys.get("host"))
+        connectors.append ("Initial Catalog=%s" % keys.get("database"))
+        user = keys.get("user")
+        if user:
+            connectors.append("User Id=%s" % user)
+            connectors.append("Password=%s" % keys.get("password", ""))
+        else:
+            connectors.append("Integrated Security=SSPI")
+        return [[";".join (connectors)], {}]
+
+
+dialect_mapping = {
+    'pymssql':  MSSQLDialect_pymssql,
+    'pyodbc':   MSSQLDialect_pyodbc,
+    'adodbapi': MSSQLDialect_adodbapi
+    }
+dialect_preference = [MSSQLDialect_adodbapi, MSSQLDialect_pymssql, MSSQLDialect_pyodbc]
+
+
 class MSSQLCompiler(ansisql.ANSICompiler):
     def __init__(self, dialect, statement, parameters, **kwargs):
         super(MSSQLCompiler, self).__init__(dialect, statement, parameters, **kwargs)
@@ -781,7 +782,5 @@ class MSSQLIdentifierPreparer(ansisql.ANSIIdentifierPreparer):
         #TODO: determin MSSQL's case folding rules
         return value
 
-use_default()
-
-
+dialect = MSSQLDialect