]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- mssql dialects are in place, not fully tested
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Jan 2009 02:35:49 +0000 (02:35 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 20 Jan 2009 02:35:49 +0000 (02:35 +0000)
lib/sqlalchemy/dialects/mssql/__init__.py
lib/sqlalchemy/dialects/mssql/adodbapi.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/pymssql.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/pyodbc.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/types.py
test/dialect/mssql.py
test/sql/testtypes.py

index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..a5fabbade17a29be8b5ec6bff38061ea9aa07c87 100644 (file)
@@ -0,0 +1,3 @@
+from sqlalchemy.dialects.mssql import base, pyodbc
+
+base.dialect = pyodbc.dialect
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py
new file mode 100644 (file)
index 0000000..9a6cc27
--- /dev/null
@@ -0,0 +1,50 @@
+from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect
+from sqlalchemy import types as sqltypes
+import datetime, sys
+class MSDateTime_adodbapi(MSDateTime):
+    def result_processor(self, dialect):
+        def process(value):
+            # adodbapi will return datetimes with empty time values as datetime.date() objects.
+            # Promote them back to full datetime.datetime()
+            if type(value) is datetime.date:
+                return datetime.datetime(value.year, value.month, value.day)
+            return value
+        return process
+
+
+class MSDialect_adodbapi(MSDialect):
+    supports_sane_rowcount = True
+    supports_sane_multi_rowcount = True
+    supports_unicode = sys.maxunicode == 65535
+    supports_unicode_statements = True
+    driver = 'adodbapi'
+    
+    @classmethod
+    def import_dbapi(cls):
+        import adodbapi as module
+        return module
+
+    colspecs = MSDialect.colspecs.copy()
+    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
+
+    def create_connect_args(self, url):
+        keys = url.query
+
+        connectors = ["Provider=SQLOLEDB"]
+        if 'port' in keys:
+            connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
+        else:
+            connectors.append ("Data Source=%s" % keys.get("host"))
+        connectors.append ("Initial Catalog=%s" % keys.get("database"))
+        user = keys.get("user")
+        if user:
+            connectors.append("User Id=%s" % user)
+            connectors.append("Password=%s" % keys.get("password", ""))
+        else:
+            connectors.append("Integrated Security=SSPI")
+        return [[";".join (connectors)], {}]
+
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
+
+dialect = MSDialect_adodbapi
\ No newline at end of file
index c972b6b0cab6aa6159d3d9f2bcfe27a30d0078f6..0c11adb4ab44dd4f499e7d47c28d5bb3497a12ef 100644 (file)
@@ -19,7 +19,7 @@ Drivers are loaded in the order listed above based on availability.
 If you need to load a specific driver pass ``module_name`` when
 creating the engine::
 
-    engine = create_engine('mssql://dsn', module_name='pymssql')
+    engine = create_engine('mssql+module_name://dsn')
 
 ``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and
 ``adodbapi``.
@@ -39,18 +39,18 @@ present, then the host token is taken directly as the DSN name.
 
 Examples of pyodbc connection string URLs:
 
-* *mssql://mydsn* - connects using the specified DSN named ``mydsn``.
+* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``.
   The connection string that is created will appear like::
 
     dsn=mydsn;TrustedConnection=Yes
 
-* *mssql://user:pass@mydsn* - connects using the DSN named
+* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named
   ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
   connection string that is created will appear like::
 
     dsn=mydsn;UID=user;PWD=pass
 
-* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects
+* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects
   using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
   information, plus the additional connection configuration option
   ``LANGUAGE``. The connection string that is created will appear
@@ -58,12 +58,12 @@ Examples of pyodbc connection string URLs:
 
     dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
 
-* *mssql://user:pass@host/db* - connects using a connection string
+* *mssql+pyodbc://user:pass@host/db* - connects using a connection string
   dynamically created that would appear like::
 
     DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
 
-* *mssql://user:pass@host:123/db* - connects using a connection
+* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection
   string that is dynamically created, which also includes the port
   information using the comma syntax. If your connection string
   requires the port information to be passed as a ``port`` keyword
@@ -72,7 +72,7 @@ Examples of pyodbc connection string URLs:
 
     DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
 
-* *mssql://user:pass@host/db?port=123* - connects using a connection
+* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection
   string that is dynamically created that includes the port
   information as a separate ``port`` keyword. This will create the
   following connection string::
@@ -86,7 +86,7 @@ and passed directly.
 
 For example::
 
-    mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
+    mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
 
 would create the following connection string::
 
@@ -110,11 +110,6 @@ arguments on the URL, or as keyword argument to
 * *query_timeout* - allows you to override the default query timeout.
   Defaults to ``None``. This is only supported on pymssql.
 
-* *text_as_varchar* - if enabled this will treat all TEXT column
-  types as their equivalent VARCHAR(max) type. This is often used if
-  you need to compare a VARCHAR to a TEXT field, which is not
-  supported directly on MSSQL. Defaults to ``False``.
-
 * *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
   should be used in place of the non-scoped version @@IDENTITY.
   Defaults to ``False``. On pymssql this defaults to ``True``, and on
@@ -252,66 +247,6 @@ from decimal import Decimal as _python_Decimal
 MSSQL_RESERVED_WORDS = set(['function'])
 
 
-class _StringType(object):
-    """Base for MSSQL string types."""
-
-    def __init__(self, collation=None, **kwargs):
-        self.collation = kwargs.get('collate', collation)
-
-    def _extend(self, spec):
-        """Extend a string-type declaration with standard SQL
-        COLLATE annotations.
-        """
-
-        if self.collation:
-            collation = 'COLLATE %s' % self.collation
-        else:
-            collation = None
-
-        return ' '.join([c for c in (spec, collation)
-                         if c is not None])
-
-    def __repr__(self):
-        attributes = inspect.getargspec(self.__init__)[0][1:]
-        attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
-
-        params = {}
-        for attr in attributes:
-            val = getattr(self, attr)
-            if val is not None and val is not False:
-                params[attr] = val
-
-        return "%s(%s)" % (self.__class__.__name__,
-                           ', '.join(['%s=%r' % (k, params[k]) for k in params]))
-
-    def bind_processor(self, dialect):
-        if self.convert_unicode or dialect.convert_unicode:
-            if self.assert_unicode is None:
-                assert_unicode = dialect.assert_unicode
-            else:
-                assert_unicode = self.assert_unicode
-
-            if not assert_unicode:
-                return None
-
-            def process(value):
-                if not isinstance(value, (unicode, sqltypes.NoneType)):
-                    if assert_unicode == 'warn':
-                        util.warn("Unicode type received non-unicode bind "
-                                  "param value %r" % value)
-                        return value
-                    else:
-                        raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
-                else:
-                    return value
-            return process
-        else:
-            return None
-
-    def result_processor(self, dialect):
-        return None
-
-
 class MSNumeric(sqltypes.Numeric):
     def result_processor(self, dialect):
         if self.asdecimal:
@@ -345,121 +280,49 @@ class MSNumeric(sqltypes.Numeric):
 
         return process
 
-    def get_col_spec(self):
-        if self.precision is None:
-            return "NUMERIC"
-        else:
-            return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-
-class MSFloat(sqltypes.Float):
-    def get_col_spec(self):
-        if self.precision is None:
-            return "FLOAT"
-        else:
-            return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class MSReal(MSFloat):
+class MSReal(sqltypes.Float):
     """A type for ``real`` numbers."""
 
-    def __init__(self):
-        """
-        Construct a Real.
+    __visit_name__ = 'REAL'
 
-        """
+    def __init__(self):
         super(MSReal, self).__init__(precision=24)
 
-    def adapt(self, impltype):
-        return impltype()
-
-    def get_col_spec(self):
-        return "REAL"
-
-
-class MSInteger(sqltypes.Integer):
-    def get_col_spec(self):
-        return "INTEGER"
-
-
-class MSBigInteger(MSInteger):
-    def get_col_spec(self):
-        return "BIGINT"
-
-
-class MSTinyInteger(MSInteger):
-    def get_col_spec(self):
-        return "TINYINT"
-
-
-class MSSmallInteger(MSInteger):
-    def get_col_spec(self):
-        return "SMALLINT"
+class MSTinyInteger(sqltypes.Integer):
+    __visit_name__ = 'TINYINT'
 
+class MSTime(sqltypes.Time):
+    def __init__(self, precision=None, **kwargs):
+        self.precision = precision
+        super(MSTime, self).__init__()
 
-class _DateTimeType(object):
-    """Base for MSSQL datetime types."""
 
+class MSDateTime(sqltypes.DateTime):
     def bind_processor(self, dialect):
-        # if we receive just a date we can manipulate it
-        # into a datetime since the db-api may not do this.
+        # most DBAPIs allow a datetime.date object
+        # as a datetime.
         def process(value):
             if type(value) is datetime.date:
                 return datetime.datetime(value.year, value.month, value.day)
             return value
         return process
+    
+class MSSmallDateTime(MSDateTime):
+    __visit_name__ = 'SMALLDATETIME'
 
-
-class MSDateTime(_DateTimeType, sqltypes.DateTime):
-    def get_col_spec(self):
-        return "DATETIME"
-
-
-class MSDate(sqltypes.Date):
-    def get_col_spec(self):
-        return "DATE"
-
-
-class MSTime(sqltypes.Time):
-    def __init__(self, precision=None, **kwargs):
-        self.precision = precision
-        super(MSTime, self).__init__()
-
-    def get_col_spec(self):
-        if self.precision:
-            return "TIME(%s)" % self.precision
-        else:
-            return "TIME"
-
-
-class MSSmallDateTime(_DateTimeType, sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "SMALLDATETIME"
-
-
-class MSDateTime2(_DateTimeType, sqltypes.TypeEngine):
+class MSDateTime2(MSDateTime):
+    __visit_name__ = 'DATETIME2'
+    
     def __init__(self, precision=None, **kwargs):
         self.precision = precision
 
-    def get_col_spec(self):
-        if self.precision:
-            return "DATETIME2(%s)" % self.precision
-        else:
-            return "DATETIME2"
-
-
-class MSDateTimeOffset(_DateTimeType, sqltypes.TypeEngine):
+class MSDateTimeOffset(sqltypes.TypeEngine):
+    __visit_name__ = 'DATETIMEOFFSET'
+    
     def __init__(self, precision=None, **kwargs):
         self.precision = precision
 
-    def get_col_spec(self):
-        if self.precision:
-            return "DATETIMEOFFSET(%s)" % self.precision
-        else:
-            return "DATETIMEOFFSET"
-
-
-class MSDateTimeAsDate(_DateTimeType, MSDate):
+class MSDateTimeAsDate(sqltypes.TypeDecorator):
     """ This is an implementation of the Date type for versions of MSSQL that
     do not support that specific type. In order to make it work a ``DATETIME``
     column specification is used and the results get converted back to just
@@ -467,20 +330,19 @@ class MSDateTimeAsDate(_DateTimeType, MSDate):
 
     """
 
-    def get_col_spec(self):
-        return "DATETIME"
+    impl = sqltypes.DateTime
 
-    def result_processor(self, dialect):
-        def process(value):
-            # If the DBAPI returns the value as datetime.datetime(), truncate
-            # it back to datetime.date()
-            if type(value) is datetime.datetime:
-                return value.date()
-            return value
-        return process
+    def process_bind_param(self, value, dialect):
+        if type(value) is datetime.date:
+            return datetime.datetime(value.year, value.month, value.day)
+        return value
 
+    def process_result_value(self, value, dialect):
+        if type(value) is datetime.datetime:
+            return value.date()
+        return value
 
-class MSDateTimeAsTime(MSTime):
+class MSDateTimeAsTime(sqltypes.TypeDecorator):
     """ This is an implementation of the Time type for versions of MSSQL that
     do not support that specific type. In order to make it work a ``DATETIME``
     column specification is used and the results get converted back to just
@@ -490,65 +352,63 @@ class MSDateTimeAsTime(MSTime):
 
     __zero_date = datetime.date(1900, 1, 1)
 
-    def get_col_spec(self):
-        return "DATETIME"
+    impl = sqltypes.DateTime
 
-    def bind_processor(self, dialect):
-        def process(value):
-            if type(value) is datetime.datetime:
-                value = datetime.datetime.combine(self.__zero_date, value.time())
-            elif type(value) is datetime.time:
-                value = datetime.datetime.combine(self.__zero_date, value)
-            return value
-        return process
+    def process_bind_param(self, value, dialect):
+        if type(value) is datetime.datetime:
+            value = datetime.datetime.combine(self.__zero_date, value.time())
+        elif type(value) is datetime.time:
+            value = datetime.datetime.combine(self.__zero_date, value)
+        return value
 
-    def result_processor(self, dialect):
-        def process(value):
-            if type(value) is datetime.datetime:
-                return value.time()
-            elif type(value) is datetime.date:
-                return datetime.time(0, 0, 0)
-            return value
-        return process
+    def process_result_value(self, value, dialect):
+        if type(value) is datetime.datetime:
+            return value.time()
+        elif type(value) is datetime.date:
+            return datetime.time(0, 0, 0)
+        return value
 
 
-class MSDateTime_adodbapi(MSDateTime):
-    def result_processor(self, dialect):
-        def process(value):
-            # adodbapi will return datetimes with empty time values as datetime.date() objects.
-            # Promote them back to full datetime.datetime()
-            if type(value) is datetime.date:
-                return datetime.datetime(value.year, value.month, value.day)
-            return value
-        return process
+class _StringType(object):
+    """Base for MSSQL string types."""
+
+    def __init__(self, collation=None):
+        self.collation = collation
+
+    def __repr__(self):
+        attributes = inspect.getargspec(self.__init__)[0][1:]
+        attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
+
+        params = {}
+        for attr in attributes:
+            val = getattr(self, attr)
+            if val is not None and val is not False:
+                params[attr] = val
+
+        return "%s(%s)" % (self.__class__.__name__,
+                           ', '.join(['%s=%r' % (k, params[k]) for k in params]))
 
 
-class MSText(_StringType, sqltypes.Text):
+class MSText(_StringType, sqltypes.TEXT):
     """MSSQL TEXT type, for variable-length text up to 2^31 characters."""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kw):
         """Construct a TEXT.
 
         :param collation: Optional, a column-level collation for this string
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.Text.__init__(self, None,
-                convert_unicode=kwargs.get('convert_unicode', False),
-                assert_unicode=kwargs.get('assert_unicode', None))
-
-    def get_col_spec(self):
-        if self.dialect.text_as_varchar:
-            return self._extend("VARCHAR(max)")
-        else:
-            return self._extend("TEXT")
-
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.Text.__init__(self, *args, **kw)
 
 class MSNText(_StringType, sqltypes.UnicodeText):
     """MSSQL NTEXT type, for variable-length unicode text up to 2^30
     characters."""
 
+    __visit_name__ = 'NTEXT'
+    
     def __init__(self, *args, **kwargs):
         """Construct a NTEXT.
 
@@ -556,23 +416,16 @@ class MSNText(_StringType, sqltypes.UnicodeText):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.UnicodeText.__init__(self, None,
-                convert_unicode=kwargs.get('convert_unicode', True),
-                assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
-    def get_col_spec(self):
-        if self.dialect.text_as_varchar:
-            return self._extend("NVARCHAR(max)")
-        else:
-            return self._extend("NTEXT")
+        collation = kwargs.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.UnicodeText.__init__(self, None, **kwargs)
 
 
-class MSString(_StringType, sqltypes.String):
+class MSString(_StringType, sqltypes.VARCHAR):
     """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
     of 8,000 characters."""
 
-    def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
+    def __init__(self, *args, **kw):
         """Construct a VARCHAR.
 
         :param length: Optinal, maximum data length, in characters.
@@ -603,24 +456,16 @@ class MSString(_StringType, sqltypes.String):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.String.__init__(self, length=length,
-                convert_unicode=convert_unicode,
-                assert_unicode=assert_unicode)
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("VARCHAR(%s)" % self.length)
-        else:
-            return self._extend("VARCHAR")
-
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.VARCHAR.__init__(self, *args, **kw)
 
-class MSNVarchar(_StringType, sqltypes.Unicode):
+class MSNVarchar(_StringType, sqltypes.NVARCHAR):
     """MSSQL NVARCHAR type.
 
     For variable-length unicode character data up to 4,000 characters."""
 
-    def __init__(self, length=None, **kwargs):
+    def __init__(self, *args, **kw):
         """Construct a NVARCHAR.
 
         :param length: Optional, Maximum data length, in characters.
@@ -629,29 +474,16 @@ class MSNVarchar(_StringType, sqltypes.Unicode):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.Unicode.__init__(self, length=length,
-                convert_unicode=kwargs.get('convert_unicode', True),
-                assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
-    def adapt(self, impltype):
-        return impltype(length=self.length,
-                        convert_unicode=self.convert_unicode,
-                        assert_unicode=self.assert_unicode,
-                        collation=self.collation)
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length})
-        else:
-            return self._extend("NVARCHAR")
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.NVARCHAR.__init__(self, *args, **kw)
 
 
 class MSChar(_StringType, sqltypes.CHAR):
     """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
     of 8,000 characters."""
 
-    def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
+    def __init__(self, *args, **kw):
         """Construct a CHAR.
 
         :param length: Optinal, maximum data length, in characters.
@@ -682,16 +514,9 @@ class MSChar(_StringType, sqltypes.CHAR):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.CHAR.__init__(self, length=length,
-                convert_unicode=convert_unicode,
-                assert_unicode=assert_unicode)
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("CHAR(%s)" % self.length)
-        else:
-            return self._extend("CHAR")
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.CHAR.__init__(self, *args, **kw)
 
 
 class MSNChar(_StringType, sqltypes.NCHAR):
@@ -699,7 +524,7 @@ class MSNChar(_StringType, sqltypes.NCHAR):
 
     For fixed-length unicode character data up to 4,000 characters."""
 
-    def __init__(self, length=None, **kwargs):
+    def __init__(self, *args, **kw):
         """Construct an NCHAR.
 
         :param length: Optional, Maximum data length, in characters.
@@ -708,59 +533,23 @@ class MSNChar(_StringType, sqltypes.NCHAR):
           value. Accepts a Windows Collation Name or a SQL Collation Name.
 
         """
-        _StringType.__init__(self, **kwargs)
-        sqltypes.NCHAR.__init__(self, length=length,
-                convert_unicode=kwargs.get('convert_unicode', True),
-                assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
-    def get_col_spec(self):
-        if self.length:
-            return self._extend("NCHAR(%(length)s)" % {'length' : self.length})
-        else:
-            return self._extend("NCHAR")
-
-
-class MSGenericBinary(sqltypes.Binary):
-    """The Binary type assumes that a Binary specification without a length
-    is an unbound Binary type whereas one with a length specification results
-    in a fixed length Binary type.
-
-    If you want standard MSSQL ``BINARY`` behavior use the ``MSBinary`` type.
-
-    """
-
-    def get_col_spec(self):
-        if self.length:
-            return "BINARY(%s)" % self.length
-        else:
-            return "IMAGE"
-
-
-class MSBinary(MSGenericBinary):
-    def get_col_spec(self):
-        if self.length:
-            return "BINARY(%s)" % self.length
-        else:
-            return "BINARY"
+        collation = kw.pop('collation', None)
+        _StringType.__init__(self, collation)
+        sqltypes.NCHAR.__init__(self, *args, **kw)
 
+class MSBinary(sqltypes.Binary):
+    pass
 
-class MSVarBinary(MSGenericBinary):
-    def get_col_spec(self):
-        if self.length:
-            return "VARBINARY(%s)" % self.length
-        else:
-            return "VARBINARY"
-
-
-class MSImage(MSGenericBinary):
-    def get_col_spec(self):
-        return "IMAGE"
+class MSVarBinary(sqltypes.Binary):
+    __visit_name__ = 'VARBINARY'
 
+class MSImage(sqltypes.Binary):
+    __visit_name__ = 'IMAGE'
 
+class MSBit(sqltypes.TypeEngine):
+    __visit_name__ = 'BIT'
+    
 class MSBoolean(sqltypes.Boolean):
-    def get_col_spec(self):
-        return "BIT"
-
     def result_processor(self, dialect):
         def process(value):
             if value is None:
@@ -780,31 +569,129 @@ class MSBoolean(sqltypes.Boolean):
                 return value and True or False
         return process
 
+class MSMoney(sqltypes.TypeEngine):
+    __visit_name__ = 'MONEY'
+
+class MSSmallMoney(MSMoney):
+    __visit_name__ = 'SMALLMONEY'
 
-class MSTimeStamp(sqltypes.TIMESTAMP):
-    def get_col_spec(self):
-        return "TIMESTAMP"
 
+class MSUniqueIdentifier(sqltypes.TypeEngine):
+    __visit_name__ = "UNIQUEIDENTIFIER"
 
-class MSMoney(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "MONEY"
+class MSVariant(sqltypes.TypeEngine):
+    __visit_name__ = 'SQL_VARIANT'
 
+class MSTypeCompiler(compiler.GenericTypeCompiler):
+    def _extend(self, spec, type_):
+        """Extend a string-type declaration with standard SQL
+        COLLATE annotations.
 
-class MSSmallMoney(MSMoney):
-    def get_col_spec(self):
-        return "SMALLMONEY"
+        """
 
+        if getattr(type_, 'collation', None):
+            collation = 'COLLATE %s' % type_.collation
+        else:
+            collation = None
 
-class MSUniqueIdentifier(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "UNIQUEIDENTIFIER"
+        if type_.length:
+            spec = spec + "(%d)" % type_.length
+        
+        return ' '.join([c for c in (spec, collation)
+            if c is not None])
 
+    def visit_FLOAT(self, type_):
+        precision = getattr(type_, 'precision', None)
+        if precision is None:
+            return "FLOAT"
+        else:
+            return "FLOAT(%(precision)s)" % {'precision': precision}
 
-class MSVariant(sqltypes.TypeEngine):
-    def get_col_spec(self):
-        return "SQL_VARIANT"
+    def visit_REAL(self, type_):
+        return "REAL"
+
+    def visit_TINYINT(self, type_):
+        return "TINYINT"
+
+    def visit_DATETIMEOFFSET(self, type_):
+        if type_.precision:
+            return "DATETIMEOFFSET(%s)" % type_.precision
+        else:
+            return "DATETIMEOFFSET"
+
+    def visit_TIME(self, type_):
+        precision = getattr(type_, 'precision', None)
+        if precision:
+            return "TIME(%s)" % precision
+        else:
+            return "TIME"
+
+    def visit_DATETIME2(self, type_):
+        precision = getattr(type_, 'precision', None)
+        if precision:
+            return "DATETIME2(%s)" % precision
+        else:
+            return "DATETIME2"
+
+    def visit_SMALLDATETIME(self, type_):
+        return "SMALLDATETIME"
+
+    def visit_NTEXT(self, type_):
+        return self._extend("NTEXT", type_)
+
+    def visit_TEXT(self, type_):
+        return self._extend("TEXT", type_)
+
+    def visit_VARCHAR(self, type_):
+        return self._extend("VARCHAR", type_)
+
+    def visit_CHAR(self, type_):
+        return self._extend("CHAR", type_)
+
+    def visit_NCHAR(self, type_):
+        return self._extend("NCHAR", type_)
+
+    def visit_NVARCHAR(self, type_):
+        return self._extend("NVARCHAR", type_)
+
+    def visit_binary(self, type_):
+        if type_.length:
+            return self.visit_BINARY(type_)
+        else:
+            return self.visit_IMAGE(type_)
+
+    def visit_BINARY(self, type_):
+        if type_.length:
+            return "BINARY(%s)" % type_.length
+        else:
+            return "BINARY"
+
+    def visit_IMAGE(self, type_):
+        return "IMAGE"
+
+    def visit_VARBINARY(self, type_):
+        if type_.length:
+            return "VARBINARY(%s)" % type_.length
+        else:
+            return "VARBINARY"
+
+    def visit_boolean(self, type_):
+        return self.visit_BIT(type_)
+
+    def visit_BIT(self, type_):
+        return "BIT"
+
+    def visit_MONEY(self, type_):
+        return "MONEY"
+
+    def visit_SMALLMONEY(self, type_):
+        return 'SMALLMONEY'
+
+    def visit_UNIQUEIDENTIFIER(self, type_):
+        return "UNIQUEIDENTIFIER"
 
+    def visit_SQL_VARIANT(self, type_):
+        return 'SQL_VARIANT'
 
 def _has_implicit_sequence(column):
     return column.primary_key and  \
@@ -827,7 +714,7 @@ def _table_sequence_column(tbl):
                 break
     return tbl._ms_has_sequence
 
-class MSSQLExecutionContext(default.DefaultExecutionContext):
+class MSExecutionContext(default.DefaultExecutionContext):
     IINSERT = False
     HASIDENT = False
 
@@ -869,136 +756,306 @@ class MSSQLExecutionContext(default.DefaultExecutionContext):
         if self.IINSERT:
             self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
 
+colspecs = {
+    sqltypes.Unicode : MSNVarchar,
+    sqltypes.Numeric : MSNumeric,
+    sqltypes.DateTime : MSDateTime,
+    sqltypes.Time : MSTime,
+    sqltypes.String : MSString,
+    sqltypes.Boolean : MSBoolean,
+    sqltypes.Text : MSText,
+    sqltypes.UnicodeText : MSNText,
+    sqltypes.CHAR: MSChar,
+    sqltypes.NCHAR: MSNChar,
+}
+
+ischema_names = {
+    'int' : sqltypes.INTEGER,
+    'bigint': sqltypes.BigInteger,
+    'smallint' : sqltypes.SmallInteger,
+    'tinyint' : MSTinyInteger,
+    'varchar' : MSString,
+    'nvarchar' : MSNVarchar,
+    'char' : MSChar,
+    'nchar' : MSNChar,
+    'text' : MSText,
+    'ntext' : MSNText,
+    'decimal' : sqltypes.DECIMAL,
+    'numeric' : sqltypes.NUMERIC,
+    'float' : sqltypes.FLOAT,
+    'datetime' : sqltypes.DateTime,
+    'datetime2' : MSDateTime2,
+    'datetimeoffset' : MSDateTimeOffset,
+    'date': sqltypes.Date,
+    'time': MSTime,
+    'smalldatetime' : MSSmallDateTime,
+    'binary' : MSBinary,
+    'varbinary' : MSVarBinary,
+    'bit': sqltypes.Boolean,
+    'real' : MSReal,
+    'image' : MSImage,
+    'timestamp': sqltypes.TIMESTAMP,
+    'money': MSMoney,
+    'smallmoney': MSSmallMoney,
+    'uniqueidentifier': MSUniqueIdentifier,
+    'sql_variant': MSVariant,
+}
 
-class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
-    def pre_exec(self):
-        """where appropriate, issue "select scope_identity()" in the same statement"""
-        super(MSSQLExecutionContext_pyodbc, self).pre_exec()
-        if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \
-                and len(self.parameters) == 1 and self.dialect.use_scope_identity:
-            self.statement += "; select scope_identity()"
-
-    def post_exec(self):
-        if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
-            import pyodbc
-            # Fetch the last inserted id from the manipulated statement
-            # We may have to skip over a number of result sets with no data (due to triggers, etc.)
-            while True:
-                try:
-                    row = self.cursor.fetchone()
-                    break
-                except pyodbc.Error, e:
-                    self.cursor.nextset()
-            self._last_inserted_ids = [int(row[0])]
-        else:
-            super(MSSQLExecutionContext_pyodbc, self).post_exec()
+class MSSQLCompiler(compiler.SQLCompiler):
+    operators = compiler.OPERATORS.copy()
+    operators.update({
+        sql_operators.concat_op: '+',
+        sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
+    })
 
-class MSSQLDialect(default.DefaultDialect):
-    name = 'mssql'
-    supports_default_values = True
-    supports_empty_insert = False
-    auto_identity_insert = True
-    execution_ctx_cls = MSSQLExecutionContext
-    text_as_varchar = False
-    use_scope_identity = False
-    has_window_funcs = False
-    max_identifier_length = 128
-    schema_name = "dbo"
+    functions = compiler.SQLCompiler.functions.copy()
+    functions.update (
+        {
+            sql_functions.now: 'CURRENT_TIMESTAMP',
+            sql_functions.current_date: 'GETDATE()',
+            'length': lambda x: "LEN(%s)" % x,
+            sql_functions.char_length: lambda x: "LEN(%s)" % x
+        }
+    )
 
-    colspecs = {
-        sqltypes.Unicode : MSNVarchar,
-        sqltypes.Integer : MSInteger,
-        sqltypes.SmallInteger: MSSmallInteger,
-        sqltypes.Numeric : MSNumeric,
-        sqltypes.Float : MSFloat,
-        sqltypes.DateTime : MSDateTime,
-        sqltypes.Date : MSDate,
-        sqltypes.Time : MSTime,
-        sqltypes.String : MSString,
-        sqltypes.Binary : MSGenericBinary,
-        sqltypes.Boolean : MSBoolean,
-        sqltypes.Text : MSText,
-        sqltypes.UnicodeText : MSNText,
-        sqltypes.CHAR: MSChar,
-        sqltypes.NCHAR: MSNChar,
-        sqltypes.TIMESTAMP: MSTimeStamp,
-    }
-
-    ischema_names = {
-        'int' : MSInteger,
-        'bigint': MSBigInteger,
-        'smallint' : MSSmallInteger,
-        'tinyint' : MSTinyInteger,
-        'varchar' : MSString,
-        'nvarchar' : MSNVarchar,
-        'char' : MSChar,
-        'nchar' : MSNChar,
-        'text' : MSText,
-        'ntext' : MSNText,
-        'decimal' : MSNumeric,
-        'numeric' : MSNumeric,
-        'float' : MSFloat,
-        'datetime' : MSDateTime,
-        'datetime2' : MSDateTime2,
-        'datetimeoffset' : MSDateTimeOffset,
-        'date': MSDate,
-        'time': MSTime,
-        'smalldatetime' : MSSmallDateTime,
-        'binary' : MSBinary,
-        'varbinary' : MSVarBinary,
-        'bit': MSBoolean,
-        'real' : MSFloat,
-        'image' : MSImage,
-        'timestamp': MSTimeStamp,
-        'money': MSMoney,
-        'smallmoney': MSSmallMoney,
-        'uniqueidentifier': MSUniqueIdentifier,
-        'sql_variant': MSVariant,
-    }
-
-    def __new__(cls, *args, **kwargs):
-        if cls is not MSSQLDialect:
-            # this gets called with the dialect specific class
-            return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
-        dbapi = kwargs.get('dbapi', None)
-        if dbapi:
-            dialect = dialect_mapping.get(dbapi.__name__)
-            return dialect(**kwargs)
-        else:
-            return object.__new__(cls, *args, **kwargs)
+    def __init__(self, *args, **kwargs):
+        super(MSSQLCompiler, self).__init__(*args, **kwargs)
+        self.tablealiases = {}
 
-    def __init__(self,
-                 auto_identity_insert=True, query_timeout=None,
-                 text_as_varchar=False, use_scope_identity=False,
-                 has_window_funcs=False, max_identifier_length=None,
-                 schema_name="dbo", **opts):
-        self.auto_identity_insert = bool(auto_identity_insert)
-        self.query_timeout = int(query_timeout or 0)
-        self.schema_name = schema_name
+    def get_select_precolumns(self, select):
+        """ MS-SQL puts TOP, it's version of LIMIT here """
+        if select._distinct or select._limit:
+            s = select._distinct and "DISTINCT " or ""
+            
+            if select._limit:
+                if not select._offset:
+                    s += "TOP %s " % (select._limit,)
+                else:
+                    if not self.dialect.has_window_funcs:
+                        raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
+            return s
+        return compiler.SQLCompiler.get_select_precolumns(self, select)
 
-        # to-do: the options below should use server version introspection to set themselves on connection
-        self.text_as_varchar = bool(text_as_varchar)
-        self.use_scope_identity = bool(use_scope_identity)
-        self.has_window_funcs =  bool(has_window_funcs)
-        self.max_identifier_length = int(max_identifier_length or 0) or 128
-        super(MSSQLDialect, self).__init__(**opts)
+    def limit_clause(self, select):
+        # Limit in mssql is after the select keyword
+        return ""
 
-    @classmethod
-    def dbapi(cls, module_name=None):
-        if module_name:
-            try:
-                dialect_cls = dialect_mapping[module_name]
-                return dialect_cls.import_dbapi()
-            except KeyError:
-                raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+    def visit_select(self, select, **kwargs):
+        """Look for ``LIMIT`` and OFFSET in a select statement, and if
+        so tries to wrap it in a subquery with ``row_number()`` criterion.
+
+        """
+        if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
+            # to use ROW_NUMBER(), an ORDER BY is required.
+            orderby = self.process(select._order_by_clause)
+            if not orderby:
+                raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
+
+            _offset = select._offset
+            _limit = select._limit
+            select._mssql_visit = True
+            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
+
+            limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
+            limitselect.append_whereclause("mssql_rn>%d" % _offset)
+            if _limit is not None:
+                limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
+            return self.process(limitselect, iswrapper=True, **kwargs)
+        else:
+            return compiler.SQLCompiler.visit_select(self, select, **kwargs)
+
+    def _schema_aliased_table(self, table):
+        if getattr(table, 'schema', None) is not None:
+            if table not in self.tablealiases:
+                self.tablealiases[table] = table.alias()
+            return self.tablealiases[table]
+        else:
+            return None
+
+    def visit_table(self, table, mssql_aliased=False, **kwargs):
+        if mssql_aliased:
+            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+        # alias schema-qualified tables
+        alias = self._schema_aliased_table(table)
+        if alias is not None:
+            return self.process(alias, mssql_aliased=True, **kwargs)
         else:
-            for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
-                try:
-                    return dialect_cls.import_dbapi()
-                except ImportError, e:
-                    pass
+            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+    def visit_alias(self, alias, **kwargs):
+        # translate for schema-qualified table aliases
+        self.tablealiases[alias.original] = alias
+        kwargs['mssql_aliased'] = True
+        return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
+
+    def visit_savepoint(self, savepoint_stmt):
+        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
+        return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+    def visit_rollback_to_savepoint(self, savepoint_stmt):
+        return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+    def visit_column(self, column, result_map=None, **kwargs):
+        if column.table is not None and \
+            (not self.isupdate and not self.isdelete) or self.is_subquery():
+            # translate for schema-qualified table aliases
+            t = self._schema_aliased_table(column.table)
+            if t is not None:
+                converted = expression._corresponding_column_or_error(t, column)
+
+                if result_map is not None:
+                    result_map[column.name.lower()] = (column.name, (column, ), column.type)
+
+                return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
+
+        return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
+
+    def visit_binary(self, binary, **kwargs):
+        """Move bind parameters to the right-hand side of an operator, where
+        possible.
+
+        """
+        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
+            and not isinstance(binary.right, expression._BindParamClause):
+            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
+        else:
+            if (binary.operator is operator.eq or binary.operator is operator.ne) and (
+                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
+                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
+                 isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
+                op = binary.operator == operator.eq and "IN" or "NOT IN"
+                return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
+            return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
+
+    def visit_insert(self, insert_stmt):
+        insert_select = False
+        if insert_stmt.parameters:
+            insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
+        if insert_select:
+            self.isinsert = True
+            colparams = self._get_colparams(insert_stmt)
+            preparer = self.preparer
+
+            insert = ' '.join(["INSERT"] +
+                              [self.process(x) for x in insert_stmt._prefixes])
+
+            if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
+                raise exc.CompileError(
+                    "The version of %s you are using does not support empty inserts." % self.dialect.name)
+            elif not colparams and self.dialect.supports_default_values:
+                return (insert + " INTO %s DEFAULT VALUES" % (
+                    (preparer.format_table(insert_stmt.table),)))
             else:
-                raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
+                return (insert + " INTO %s (%s) SELECT %s" %
+                    (preparer.format_table(insert_stmt.table),
+                     ', '.join([preparer.format_column(c[0])
+                               for c in colparams]),
+                     ', '.join([c[1] for c in colparams])))
+        else:
+            return super(MSSQLCompiler, self).visit_insert(insert_stmt)
+
+    def label_select_column(self, select, column, asfrom):
+        if isinstance(column, expression.Function):
+            return column.label(None)
+        else:
+            return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
+
+    def for_update_clause(self, select):
+        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
+        return ''
+
+    def order_by_clause(self, select):
+        order_by = self.process(select._order_by_clause)
+
+        # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
+        if order_by and (not self.is_subquery() or select._limit):
+            return " ORDER BY " + order_by
+        else:
+            return ""
+
+
+class MSDDLCompiler(compiler.DDLCompiler):
+    def get_column_specification(self, column, **kwargs):
+        colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+
+        if column.nullable is not None:
+            if not column.nullable or column.primary_key:
+                colspec += " NOT NULL"
+            else:
+                colspec += " NULL"
+        
+        if not column.table:
+            raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
+            
+        seq_col = _table_sequence_column(column.table)
+
+        # install a IDENTITY Sequence if we have an implicit IDENTITY column
+        if seq_col is column:
+            sequence = getattr(column, 'sequence', None)
+            if sequence:
+                start, increment = sequence.start or 1, sequence.increment or 1
+            else:
+                start, increment = 1, 1
+            colspec += " IDENTITY(%s,%s)" % (start, increment)
+        else:
+            default = self.get_column_default_string(column)
+            if default is not None:
+                colspec += " DEFAULT " + default
+
+        return colspec
+
+    def visit_drop_index(self, drop):
+        return "\nDROP INDEX %s.%s" % (
+            self.preparer.quote_identifier(drop.element.table.name),
+            self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote)
+            )
+
+
+class MSIdentifierPreparer(compiler.IdentifierPreparer):
+    reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS)
+
+    def __init__(self, dialect):
+        super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
+
+    def _escape_identifier(self, value):
+        #TODO: determine MSSQL's escaping rules
+        return value
+
+class MSDialect(default.DefaultDialect):
+    name = 'mssql'
+    supports_default_values = True
+    supports_empty_insert = False
+    auto_identity_insert = True
+    execution_ctx_cls = MSExecutionContext
+    text_as_varchar = False
+    use_scope_identity = False
+    has_window_funcs = False
+    max_identifier_length = 128
+    schema_name = "dbo"
+    colspecs = colspecs
+    ischema_names = ischema_names
+
+    supports_unicode_binds = True
+
+    statement_compiler = MSSQLCompiler
+    ddl_compiler = MSDDLCompiler
+    type_compiler = MSTypeCompiler
+    preparer = MSIdentifierPreparer
+
+    def __init__(self,
+                 auto_identity_insert=True, query_timeout=None,
+                 use_scope_identity=False,
+                 has_window_funcs=False, max_identifier_length=None,
+                 schema_name="dbo", **opts):
+        self.auto_identity_insert = bool(auto_identity_insert)
+        self.query_timeout = int(query_timeout or 0)
+        self.schema_name = schema_name
+
+        self.use_scope_identity = bool(use_scope_identity)
+        self.has_window_funcs =  bool(has_window_funcs)
+        self.max_identifier_length = int(max_identifier_length or 0) or 128
+        super(MSDialect, self).__init__(**opts)
 
     @base.connection_memoize(('mssql', 'server_version_info'))
     def server_version_info(self, connection):
@@ -1018,28 +1075,6 @@ class MSSQLDialect(default.DefaultDialect):
         """Return a tuple of the database's version number."""
         raise NotImplementedError()
 
-    def create_connect_args(self, url):
-        opts = url.translate_connect_args(username='user')
-        opts.update(url.query)
-        if 'auto_identity_insert' in opts:
-            self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert')))
-        if 'query_timeout' in opts:
-            self.query_timeout = int(opts.pop('query_timeout'))
-        if 'text_as_varchar' in opts:
-            self.text_as_varchar = bool(int(opts.pop('text_as_varchar')))
-        if 'use_scope_identity' in opts:
-            self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
-        if 'has_window_funcs' in opts:
-            self.has_window_funcs =  bool(int(opts.pop('has_window_funcs')))
-        return self.make_connect_string(opts, url.query)
-
-    def type_descriptor(self, typeobj):
-        newobj = sqltypes.adapt_type(typeobj, self.colspecs)
-        # Some types need to know about the dialect
-        if isinstance(newobj, (MSText, MSNText)):
-            newobj.dialect = self
-        return newobj
-
     def do_begin(self, connection):
         cursor = connection.cursor()
         cursor.execute("SET IMPLICIT_TRANSACTIONS OFF")
@@ -1248,414 +1283,3 @@ class MSSQLDialect(default.DefaultDialect):
         if fknm and scols:
             table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
 
-
-class MSSQLDialect_pymssql(MSSQLDialect):
-    supports_sane_rowcount = False
-    max_identifier_length = 30
-
-    @classmethod
-    def import_dbapi(cls):
-        import pymssql as module
-        # pymmsql doesn't have a Binary method.  we use string
-        # TODO: monkeypatching here is less than ideal
-        module.Binary = lambda st: str(st)
-        return module
-
-    def __init__(self, **params):
-        super(MSSQLDialect_pymssql, self).__init__(**params)
-        self.use_scope_identity = True
-
-        # pymssql understands only ascii
-        if self.convert_unicode:
-            util.warn("pymssql does not support unicode")
-            self.encoding = params.get('encoding', 'ascii')
-
-        self.colspecs = MSSQLDialect.colspecs.copy()
-        self.ischema_names = MSSQLDialect.ischema_names.copy()
-        self.ischema_names['date'] = MSDateTimeAsDate
-        self.colspecs[sqltypes.Date] = MSDateTimeAsDate
-        self.ischema_names['time'] = MSDateTimeAsTime
-        self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
-    def create_connect_args(self, url):
-        r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
-        if hasattr(self, 'query_timeout'):
-            self.dbapi._mssql.set_query_timeout(self.query_timeout)
-        return r
-
-    def make_connect_string(self, keys, query):
-        if keys.get('port'):
-            # pymssql expects port as host:port, not a separate arg
-            keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
-            del keys['port']
-        return [[], keys]
-
-    def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
-
-    def do_begin(self, connection):
-        pass
-
-
-class MSSQLDialect_pyodbc(MSSQLDialect):
-    supports_sane_rowcount = False
-    supports_sane_multi_rowcount = False
-    # PyODBC unicode is broken on UCS-4 builds
-    supports_unicode = sys.maxunicode == 65535
-    supports_unicode_statements = supports_unicode
-    execution_ctx_cls = MSSQLExecutionContext_pyodbc
-
-    def __init__(self, description_encoding='latin-1', **params):
-        super(MSSQLDialect_pyodbc, self).__init__(**params)
-        self.description_encoding = description_encoding
-
-        if self.server_version_info < (10,):
-            self.colspecs = MSSQLDialect.colspecs.copy()
-            self.ischema_names = MSSQLDialect.ischema_names.copy()
-            self.ischema_names['date'] = MSDateTimeAsDate
-            self.colspecs[sqltypes.Date] = MSDateTimeAsDate
-            self.ischema_names['time'] = MSDateTimeAsTime
-            self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
-        # FIXME: scope_identity sniff should look at server version, not the ODBC driver
-        # whether use_scope_identity will work depends on the version of pyodbc
-        try:
-            import pyodbc
-            self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
-        except:
-            pass
-
-    @classmethod
-    def import_dbapi(cls):
-        import pyodbc as module
-        return module
-
-    def make_connect_string(self, keys, query):
-        if 'max_identifier_length' in keys:
-            self.max_identifier_length = int(keys.pop('max_identifier_length'))
-
-        if 'odbc_connect' in keys:
-            connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
-        else:
-            dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
-            if dsn_connection:
-                connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
-            else:
-                port = ''
-                if 'port' in keys and not 'port' in query:
-                    port = ',%d' % int(keys.pop('port'))
-
-                connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
-                              'Server=%s%s' % (keys.pop('host', ''), port),
-                              'Database=%s' % keys.pop('database', '') ]
-
-            user = keys.pop("user", None)
-            if user:
-                connectors.append("UID=%s" % user)
-                connectors.append("PWD=%s" % keys.pop('password', ''))
-            else:
-                connectors.append("TrustedConnection=Yes")
-
-            # if set to 'Yes', the ODBC layer will try to automagically convert 
-            # textual data from your database encoding to your client encoding 
-            # This should obviously be set to 'No' if you query a cp1253 encoded 
-            # database from a latin1 client... 
-            if 'odbc_autotranslate' in keys:
-                connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
-
-            connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
-
-        return [[";".join (connectors)], {}]
-
-    def is_disconnect(self, e):
-        if isinstance(e, self.dbapi.ProgrammingError):
-            return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
-        elif isinstance(e, self.dbapi.Error):
-            return '[08S01]' in str(e)
-        else:
-            return False
-
-
-    def _server_version_info(self, dbapi_con):
-        """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
-        version = []
-        r = re.compile('[.\-]')
-        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
-            try:
-                version.append(int(n))
-            except ValueError:
-                version.append(n)
-        return tuple(version)
-
-class MSSQLDialect_adodbapi(MSSQLDialect):
-    supports_sane_rowcount = True
-    supports_sane_multi_rowcount = True
-    supports_unicode = sys.maxunicode == 65535
-    supports_unicode_statements = True
-
-    @classmethod
-    def import_dbapi(cls):
-        import adodbapi as module
-        return module
-
-    colspecs = MSSQLDialect.colspecs.copy()
-    colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
-
-    ischema_names = MSSQLDialect.ischema_names.copy()
-    ischema_names['datetime'] = MSDateTime_adodbapi
-
-    def make_connect_string(self, keys, query):
-        connectors = ["Provider=SQLOLEDB"]
-        if 'port' in keys:
-            connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
-        else:
-            connectors.append ("Data Source=%s" % keys.get("host"))
-        connectors.append ("Initial Catalog=%s" % keys.get("database"))
-        user = keys.get("user")
-        if user:
-            connectors.append("User Id=%s" % user)
-            connectors.append("Password=%s" % keys.get("password", ""))
-        else:
-            connectors.append("Integrated Security=SSPI")
-        return [[";".join (connectors)], {}]
-
-    def is_disconnect(self, e):
-        return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
-
-
-dialect_mapping = {
-    'pymssql':  MSSQLDialect_pymssql,
-    'pyodbc':   MSSQLDialect_pyodbc,
-    'adodbapi': MSSQLDialect_adodbapi
-    }
-
-
-class MSSQLCompiler(compiler.SQLCompiler):
-    operators = compiler.OPERATORS.copy()
-    operators.update({
-        sql_operators.concat_op: '+',
-        sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
-    })
-
-    functions = compiler.SQLCompiler.functions.copy()
-    functions.update (
-        {
-            sql_functions.now: 'CURRENT_TIMESTAMP',
-            sql_functions.current_date: 'GETDATE()',
-            'length': lambda x: "LEN(%s)" % x,
-            sql_functions.char_length: lambda x: "LEN(%s)" % x
-        }
-    )
-
-    def __init__(self, *args, **kwargs):
-        super(MSSQLCompiler, self).__init__(*args, **kwargs)
-        self.tablealiases = {}
-
-    def get_select_precolumns(self, select):
-        """ MS-SQL puts TOP, it's version of LIMIT here """
-        if select._distinct or select._limit:
-            s = select._distinct and "DISTINCT " or ""
-            
-            if select._limit:
-                if not select._offset:
-                    s += "TOP %s " % (select._limit,)
-                else:
-                    if not self.dialect.has_window_funcs:
-                        raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
-            return s
-        return compiler.SQLCompiler.get_select_precolumns(self, select)
-
-    def limit_clause(self, select):
-        # Limit in mssql is after the select keyword
-        return ""
-
-    def visit_select(self, select, **kwargs):
-        """Look for ``LIMIT`` and OFFSET in a select statement, and if
-        so tries to wrap it in a subquery with ``row_number()`` criterion.
-
-        """
-        if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
-            # to use ROW_NUMBER(), an ORDER BY is required.
-            orderby = self.process(select._order_by_clause)
-            if not orderby:
-                raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
-
-            _offset = select._offset
-            _limit = select._limit
-            select._mssql_visit = True
-            select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
-
-            limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
-            limitselect.append_whereclause("mssql_rn>%d" % _offset)
-            if _limit is not None:
-                limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
-            return self.process(limitselect, iswrapper=True, **kwargs)
-        else:
-            return compiler.SQLCompiler.visit_select(self, select, **kwargs)
-
-    def _schema_aliased_table(self, table):
-        if getattr(table, 'schema', None) is not None:
-            if table not in self.tablealiases:
-                self.tablealiases[table] = table.alias()
-            return self.tablealiases[table]
-        else:
-            return None
-
-    def visit_table(self, table, mssql_aliased=False, **kwargs):
-        if mssql_aliased:
-            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
-        # alias schema-qualified tables
-        alias = self._schema_aliased_table(table)
-        if alias is not None:
-            return self.process(alias, mssql_aliased=True, **kwargs)
-        else:
-            return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
-    def visit_alias(self, alias, **kwargs):
-        # translate for schema-qualified table aliases
-        self.tablealiases[alias.original] = alias
-        kwargs['mssql_aliased'] = True
-        return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
-
-    def visit_savepoint(self, savepoint_stmt):
-        util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
-        return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
-
-    def visit_rollback_to_savepoint(self, savepoint_stmt):
-        return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
-
-    def visit_column(self, column, result_map=None, **kwargs):
-        if column.table is not None and \
-            (not self.isupdate and not self.isdelete) or self.is_subquery():
-            # translate for schema-qualified table aliases
-            t = self._schema_aliased_table(column.table)
-            if t is not None:
-                converted = expression._corresponding_column_or_error(t, column)
-
-                if result_map is not None:
-                    result_map[column.name.lower()] = (column.name, (column, ), column.type)
-
-                return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
-
-        return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
-
-    def visit_binary(self, binary, **kwargs):
-        """Move bind parameters to the right-hand side of an operator, where
-        possible.
-
-        """
-        if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
-            and not isinstance(binary.right, expression._BindParamClause):
-            return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
-        else:
-            if (binary.operator is operator.eq or binary.operator is operator.ne) and (
-                (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
-                (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
-                 isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
-                op = binary.operator == operator.eq and "IN" or "NOT IN"
-                return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
-            return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
-
-    def visit_insert(self, insert_stmt):
-        insert_select = False
-        if insert_stmt.parameters:
-            insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
-        if insert_select:
-            self.isinsert = True
-            colparams = self._get_colparams(insert_stmt)
-            preparer = self.preparer
-
-            insert = ' '.join(["INSERT"] +
-                              [self.process(x) for x in insert_stmt._prefixes])
-
-            if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
-                raise exc.CompileError(
-                    "The version of %s you are using does not support empty inserts." % self.dialect.name)
-            elif not colparams and self.dialect.supports_default_values:
-                return (insert + " INTO %s DEFAULT VALUES" % (
-                    (preparer.format_table(insert_stmt.table),)))
-            else:
-                return (insert + " INTO %s (%s) SELECT %s" %
-                    (preparer.format_table(insert_stmt.table),
-                     ', '.join([preparer.format_column(c[0])
-                               for c in colparams]),
-                     ', '.join([c[1] for c in colparams])))
-        else:
-            return super(MSSQLCompiler, self).visit_insert(insert_stmt)
-
-    def label_select_column(self, select, column, asfrom):
-        if isinstance(column, expression.Function):
-            return column.label(None)
-        else:
-            return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
-
-    def for_update_clause(self, select):
-        # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
-        return ''
-
-    def order_by_clause(self, select):
-        order_by = self.process(select._order_by_clause)
-
-        # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
-        if order_by and (not self.is_subquery() or select._limit):
-            return " ORDER BY " + order_by
-        else:
-            return ""
-
-
-class MSSQLSchemaGenerator(compiler.SchemaGenerator):
-    def get_column_specification(self, column, **kwargs):
-        colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
-        if column.nullable is not None:
-            if not column.nullable or column.primary_key:
-                colspec += " NOT NULL"
-            else:
-                colspec += " NULL"
-        
-        if not column.table:
-            raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
-            
-        seq_col = _table_sequence_column(column.table)
-
-        # install a IDENTITY Sequence if we have an implicit IDENTITY column
-        if seq_col is column:
-            sequence = getattr(column, 'sequence', None)
-            if sequence:
-                start, increment = sequence.start or 1, sequence.increment or 1
-            else:
-                start, increment = 1, 1
-            colspec += " IDENTITY(%s,%s)" % (start, increment)
-        else:
-            default = self.get_column_default_string(column)
-            if default is not None:
-                colspec += " DEFAULT " + default
-
-        return colspec
-
-class MSSQLSchemaDropper(compiler.SchemaDropper):
-    def visit_index(self, index):
-        self.append("\nDROP INDEX %s.%s" % (
-            self.preparer.quote_identifier(index.table.name),
-            self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
-            ))
-        self.execute()
-
-
-class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
-    reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS)
-
-    def __init__(self, dialect):
-        super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
-
-    def _escape_identifier(self, value):
-        #TODO: determine MSSQL's escaping rules
-        return value
-
-dialect = MSSQLDialect
-dialect.statement_compiler = MSSQLCompiler
-dialect.schemagenerator = MSSQLSchemaGenerator
-dialect.schemadropper = MSSQLSchemaDropper
-dialect.preparer = MSSQLIdentifierPreparer
-
diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py
new file mode 100644 (file)
index 0000000..1b5858c
--- /dev/null
@@ -0,0 +1,50 @@
+from sqlalchemy.dialects.mssql.base import MSDialect, MSDateTimeAsDate, MSDateTimeAsTime
+from sqlalchemy import types as sqltypes
+
+class MSDialect_pymssql(MSDialect):
+    supports_sane_rowcount = False
+    max_identifier_length = 30
+    driver = 'pymssql'
+
+    # TODO: shouldnt this be based on server version <10 like pyodbc does ?
+    colspecs = MSSQLDialect.colspecs.copy()
+    colspecs[sqltypes.Date] = MSDateTimeAsDate
+    colspecs[sqltypes.Time] = MSDateTimeAsTime
+    
+    @classmethod
+    def import_dbapi(cls):
+        import pymssql as module
+        # pymmsql doesn't have a Binary method.  we use string
+        # TODO: monkeypatching here is less than ideal
+        module.Binary = lambda st: str(st)
+        return module
+
+    def __init__(self, **params):
+        super(MSSQLDialect_pymssql, self).__init__(**params)
+        self.use_scope_identity = True
+
+        # pymssql understands only ascii
+        if self.convert_unicode:
+            util.warn("pymssql does not support unicode")
+            self.encoding = params.get('encoding', 'ascii')
+
+
+    def create_connect_args(self, url):
+        if hasattr(self, 'query_timeout'):
+            # ick, globals ?   we might want to move this....
+            self.dbapi._mssql.set_query_timeout(self.query_timeout)
+
+        keys = url.query
+        if keys.get('port'):
+            # pymssql expects port as host:port, not a separate arg
+            keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
+            del keys['port']
+        return [[], keys]
+
+    def is_disconnect(self, e):
+        return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
+
+    def do_begin(self, connection):
+        pass
+
+dialect = MSDialect_pymssql
\ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py
new file mode 100644 (file)
index 0000000..3c18f60
--- /dev/null
@@ -0,0 +1,59 @@
+from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSDateTimeAsDate, MSDateTimeAsTime
+from sqlalchemy.connectors.pyodbc import PyODBCConnector
+from sqlalchemy import types as sqltypes
+
+import sys
+
+class MSExecutionContext_pyodbc(MSExecutionContext):
+    def pre_exec(self):
+        """where appropriate, issue "select scope_identity()" in the same statement"""
+        super(MSSQLExecutionContext_pyodbc, self).pre_exec()
+        if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \
+                and len(self.parameters) == 1 and self.dialect.use_scope_identity:
+            self.statement += "; select scope_identity()"
+
+    def post_exec(self):
+        if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
+            import pyodbc
+            # Fetch the last inserted id from the manipulated statement
+            # We may have to skip over a number of result sets with no data (due to triggers, etc.)
+            while True:
+                try:
+                    row = self.cursor.fetchone()
+                    break
+                except pyodbc.Error, e:
+                    self.cursor.nextset()
+            self._last_inserted_ids = [int(row[0])]
+        else:
+            super(MSSQLExecutionContext_pyodbc, self).post_exec()
+
+
+class MSDialect_pyodbc(PyODBCConnector, MSDialect):
+    supports_sane_rowcount = False
+    supports_sane_multi_rowcount = False
+    # PyODBC unicode is broken on UCS-4 builds
+    supports_unicode = sys.maxunicode == 65535
+    supports_unicode_statements = supports_unicode
+    execution_ctx_cls = MSExecutionContext_pyodbc
+
+    pyodbc_driver_name = 'SQL Server'
+
+    def __init__(self, description_encoding='latin-1', **params):
+        super(MSDialect_pyodbc, self).__init__(**params)
+        self.description_encoding = description_encoding
+        self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset')
+        
+        if self.server_version_info < (10,):
+            self.colspecs = MSDialect.colspecs.copy()
+            self.colspecs[sqltypes.Date] = MSDateTimeAsDate
+            self.colspecs[sqltypes.Time] = MSDateTimeAsTime
+
+    def is_disconnect(self, e):
+        if isinstance(e, self.dbapi.ProgrammingError):
+            return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
+        elif isinstance(e, self.dbapi.Error):
+            return '[08S01]' in str(e)
+        else:
+            return False
+
+dialect = MSDialect_pyodbc
\ No newline at end of file
index ad675839e08cbdb092a9564fd2e67cea43be36af..bb6b7ab75f8f86c16ed43a49f601254266252aee 100644 (file)
@@ -1414,7 +1414,6 @@ class MySQLDDLCompiler(compiler.DDLCompiler):
         """Builds column DDL."""
 
         colspec = [self.preparer.format_column(column),
-                    #self.dialect.type_compiler.process(column.type.dialect_impl(self.dialect))
                     self.dialect.type_compiler.process(column.type)
                    ]
 
index 2f5548236f22fd71b413334f4c48105087d7785e..63bd8bfbab786d82add8dda6e6dbc190661d7633 100644 (file)
@@ -465,7 +465,10 @@ class String(Concatenable, TypeEngine):
         self.assert_unicode = assert_unicode
 
     def adapt(self, impltype):
-        return impltype(length=self.length, convert_unicode=self.convert_unicode, assert_unicode=self.assert_unicode)
+        return impltype(
+                    length=self.length, 
+                    convert_unicode=self.convert_unicode, 
+                    assert_unicode=self.assert_unicode)
 
     def bind_processor(self, dialect):
         if self.convert_unicode or dialect.convert_unicode:
index f0b0bec76f6c77f9d1f60fbcacaf575362a25569..165d6908f585527e6606376cc6f64152275c5ba9 100755 (executable)
@@ -1,7 +1,7 @@
 import testenv; testenv.configure_for_tests()
 import datetime, os, pickleable, re
 from sqlalchemy import *
-from sqlalchemy import types, exc
+from sqlalchemy import types, exc, schema
 from sqlalchemy.orm import *
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import mssql
@@ -11,7 +11,7 @@ from testlib.testing import eq_
 
 
 class CompileTest(TestBase, AssertsCompiledSQL):
-    __dialect__ = mssql.MSSQLDialect()
+    __dialect__ = mssql.dialect()
 
     def test_insert(self):
         t = table('sometable', column('somecolumn'))
@@ -258,36 +258,26 @@ class SchemaTest(TestBase):
         )
         self.column = t.c.test_column
 
+        dialect = mssql.dialect()
+        self.ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t))
+    
+    def _column_spec(self):
+        return self.ddl_compiler.get_column_specification(self.column)
+        
     def test_that_mssql_default_nullability_emits_null(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR NULL", column_specification)
+        eq_("test_column VARCHAR NULL", self._column_spec())
 
     def test_that_mssql_none_nullability_does_not_emit_nullability(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
         self.column.nullable = None
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR", column_specification)
+        eq_("test_column VARCHAR", self._column_spec())
 
     def test_that_mssql_specified_nullable_emits_null(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
         self.column.nullable = True
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR NULL", column_specification)
+        eq_("test_column VARCHAR NULL", self._column_spec())
 
     def test_that_mssql_specified_not_nullable_emits_not_null(self):
-        schemagenerator = \
-            mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None)
         self.column.nullable = False
-        column_specification = \
-            schemagenerator.get_column_specification(self.column)
-        eq_("test_column VARCHAR NOT NULL", column_specification)
+        eq_("test_column VARCHAR NOT NULL", self._column_spec())
 
 
 def full_text_search_missing():
@@ -683,7 +673,8 @@ class TypesTest2(TestBase, AssertsExecutionResults):
             table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None))
 
         binary_table = Table(*table_args)
-        gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None)
+        dialect = mssql.dialect()
+        gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table))
 
         for col in binary_table.c:
             index = int(col.name[1:])
@@ -691,11 +682,7 @@ class TypesTest2(TestBase, AssertsExecutionResults):
                            "%s %s" % (col.name, columns[index][3]))
             self.assert_(repr(col))
 
-        try:
-            binary_table.create(checkfirst=True)
-            assert True
-        except:
-            raise
+        binary_table.create(checkfirst=True)
 
         reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True)
         for col in reflected_binary.c:
@@ -957,6 +944,9 @@ def colspec(c):
 
 class BinaryTest(TestBase, AssertsExecutionResults):
     """Test the Binary and VarBinary types"""
+    
+    __only_on__ = 'mssql'
+    
     def setUpAll(self):
         global binary_table, MyPickleType
 
index 34acba4c740bda598a3bbe130aa4a64a0f45524a..29ed49d073b33c8750a03b13b94d6d1e96e15daa 100644 (file)
@@ -34,14 +34,13 @@ class AdaptTest(TestBase):
                     assert ta != tb
 
     def testmsnvarchar(self):
-        dialect = mssql.MSSQLDialect()
+        dialect = mssql.dialect()
         # run the test twice to ensure the caching step works too
         for x in range(0, 1):
             col = Column('', Unicode(length=10))
             dialect_type = col.type.dialect_impl(dialect)
             assert isinstance(dialect_type, mssql.MSNVarchar)
-            assert dialect_type.get_col_spec() == 'NVARCHAR(10)'
-
+            eq_(dialect.type_compiler.process(dialect_type), 'NVARCHAR(10)')
 
     def testoracletimestamp(self):
         dialect = oracle.OracleDialect()
@@ -105,7 +104,15 @@ class AdaptTest(TestBase):
         
         """
         
-        for dialect in [oracle.dialect(), mysql.dialect(), postgres.dialect(), sqlite.dialect(), sybase.dialect(), informix.dialect(), maxdb.dialect()]: #engines.all_dialects():
+        for dialect in [
+                oracle.dialect(), 
+                mysql.dialect(), 
+                postgres.dialect(), 
+                sqlite.dialect(), 
+                sybase.dialect(), 
+                informix.dialect(), 
+                maxdb.dialect(), 
+                mssql.dialect()]: # TODO when dialects are complete:  engines.all_dialects():
             for type_, expected in (
                 (FLOAT, "FLOAT"),
                 (NUMERIC, "NUMERIC"),