If you need to load a specific driver pass ``module_name`` when
creating the engine::
- engine = create_engine('mssql://dsn', module_name='pymssql')
+ engine = create_engine('mssql+module_name://dsn')
``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and
``adodbapi``.
Examples of pyodbc connection string URLs:
-* *mssql://mydsn* - connects using the specified DSN named ``mydsn``.
+* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``.
The connection string that is created will appear like::
dsn=mydsn;TrustedConnection=Yes
-* *mssql://user:pass@mydsn* - connects using the DSN named
+* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named
``mydsn`` passing in the ``UID`` and ``PWD`` information. The
connection string that is created will appear like::
dsn=mydsn;UID=user;PWD=pass
-* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects
+* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects
using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
information, plus the additional connection configuration option
``LANGUAGE``. The connection string that is created will appear
dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
-* *mssql://user:pass@host/db* - connects using a connection string
+* *mssql+pyodbc://user:pass@host/db* - connects using a connection string
dynamically created that would appear like::
DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
-* *mssql://user:pass@host:123/db* - connects using a connection
+* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection
string that is dynamically created, which also includes the port
information using the comma syntax. If your connection string
requires the port information to be passed as a ``port`` keyword
DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
-* *mssql://user:pass@host/db?port=123* - connects using a connection
+* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection
string that is dynamically created that includes the port
information as a separate ``port`` keyword. This will create the
following connection string::
For example::
- mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
+ mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
would create the following connection string::
* *query_timeout* - allows you to override the default query timeout.
Defaults to ``None``. This is only supported on pymssql.
-* *text_as_varchar* - if enabled this will treat all TEXT column
- types as their equivalent VARCHAR(max) type. This is often used if
- you need to compare a VARCHAR to a TEXT field, which is not
- supported directly on MSSQL. Defaults to ``False``.
-
* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
should be used in place of the non-scoped version @@IDENTITY.
Defaults to ``False``. On pymssql this defaults to ``True``, and on
MSSQL_RESERVED_WORDS = set(['function'])
-class _StringType(object):
- """Base for MSSQL string types."""
-
- def __init__(self, collation=None, **kwargs):
- self.collation = kwargs.get('collate', collation)
-
- def _extend(self, spec):
- """Extend a string-type declaration with standard SQL
- COLLATE annotations.
- """
-
- if self.collation:
- collation = 'COLLATE %s' % self.collation
- else:
- collation = None
-
- return ' '.join([c for c in (spec, collation)
- if c is not None])
-
- def __repr__(self):
- attributes = inspect.getargspec(self.__init__)[0][1:]
- attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
-
- params = {}
- for attr in attributes:
- val = getattr(self, attr)
- if val is not None and val is not False:
- params[attr] = val
-
- return "%s(%s)" % (self.__class__.__name__,
- ', '.join(['%s=%r' % (k, params[k]) for k in params]))
-
- def bind_processor(self, dialect):
- if self.convert_unicode or dialect.convert_unicode:
- if self.assert_unicode is None:
- assert_unicode = dialect.assert_unicode
- else:
- assert_unicode = self.assert_unicode
-
- if not assert_unicode:
- return None
-
- def process(value):
- if not isinstance(value, (unicode, sqltypes.NoneType)):
- if assert_unicode == 'warn':
- util.warn("Unicode type received non-unicode bind "
- "param value %r" % value)
- return value
- else:
- raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value)
- else:
- return value
- return process
- else:
- return None
-
- def result_processor(self, dialect):
- return None
-
-
class MSNumeric(sqltypes.Numeric):
def result_processor(self, dialect):
if self.asdecimal:
return process
- def get_col_spec(self):
- if self.precision is None:
- return "NUMERIC"
- else:
- return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale}
-
-
-class MSFloat(sqltypes.Float):
- def get_col_spec(self):
- if self.precision is None:
- return "FLOAT"
- else:
- return "FLOAT(%(precision)s)" % {'precision': self.precision}
-
-
-class MSReal(MSFloat):
+class MSReal(sqltypes.Float):
"""A type for ``real`` numbers."""
- def __init__(self):
- """
- Construct a Real.
+ __visit_name__ = 'REAL'
- """
+ def __init__(self):
super(MSReal, self).__init__(precision=24)
- def adapt(self, impltype):
- return impltype()
-
- def get_col_spec(self):
- return "REAL"
-
-
-class MSInteger(sqltypes.Integer):
- def get_col_spec(self):
- return "INTEGER"
-
-
-class MSBigInteger(MSInteger):
- def get_col_spec(self):
- return "BIGINT"
-
-
-class MSTinyInteger(MSInteger):
- def get_col_spec(self):
- return "TINYINT"
-
-
-class MSSmallInteger(MSInteger):
- def get_col_spec(self):
- return "SMALLINT"
+class MSTinyInteger(sqltypes.Integer):
+ __visit_name__ = 'TINYINT'
+class MSTime(sqltypes.Time):
+ def __init__(self, precision=None, **kwargs):
+ self.precision = precision
+ super(MSTime, self).__init__()
-class _DateTimeType(object):
- """Base for MSSQL datetime types."""
+class MSDateTime(sqltypes.DateTime):
def bind_processor(self, dialect):
- # if we receive just a date we can manipulate it
- # into a datetime since the db-api may not do this.
+ # most DBAPIs allow a datetime.date object
+ # as a datetime.
def process(value):
if type(value) is datetime.date:
return datetime.datetime(value.year, value.month, value.day)
return value
return process
+
+class MSSmallDateTime(MSDateTime):
+ __visit_name__ = 'SMALLDATETIME'
-
-class MSDateTime(_DateTimeType, sqltypes.DateTime):
- def get_col_spec(self):
- return "DATETIME"
-
-
-class MSDate(sqltypes.Date):
- def get_col_spec(self):
- return "DATE"
-
-
-class MSTime(sqltypes.Time):
- def __init__(self, precision=None, **kwargs):
- self.precision = precision
- super(MSTime, self).__init__()
-
- def get_col_spec(self):
- if self.precision:
- return "TIME(%s)" % self.precision
- else:
- return "TIME"
-
-
-class MSSmallDateTime(_DateTimeType, sqltypes.TypeEngine):
- def get_col_spec(self):
- return "SMALLDATETIME"
-
-
-class MSDateTime2(_DateTimeType, sqltypes.TypeEngine):
+class MSDateTime2(MSDateTime):
+ __visit_name__ = 'DATETIME2'
+
def __init__(self, precision=None, **kwargs):
self.precision = precision
- def get_col_spec(self):
- if self.precision:
- return "DATETIME2(%s)" % self.precision
- else:
- return "DATETIME2"
-
-
-class MSDateTimeOffset(_DateTimeType, sqltypes.TypeEngine):
+class MSDateTimeOffset(sqltypes.TypeEngine):
+ __visit_name__ = 'DATETIMEOFFSET'
+
def __init__(self, precision=None, **kwargs):
self.precision = precision
- def get_col_spec(self):
- if self.precision:
- return "DATETIMEOFFSET(%s)" % self.precision
- else:
- return "DATETIMEOFFSET"
-
-
-class MSDateTimeAsDate(_DateTimeType, MSDate):
+class MSDateTimeAsDate(sqltypes.TypeDecorator):
""" This is an implementation of the Date type for versions of MSSQL that
do not support that specific type. In order to make it work a ``DATETIME``
column specification is used and the results get converted back to just
"""
- def get_col_spec(self):
- return "DATETIME"
+ impl = sqltypes.DateTime
- def result_processor(self, dialect):
- def process(value):
- # If the DBAPI returns the value as datetime.datetime(), truncate
- # it back to datetime.date()
- if type(value) is datetime.datetime:
- return value.date()
- return value
- return process
+ def process_bind_param(self, value, dialect):
+ if type(value) is datetime.date:
+ return datetime.datetime(value.year, value.month, value.day)
+ return value
+ def process_result_value(self, value, dialect):
+ if type(value) is datetime.datetime:
+ return value.date()
+ return value
-class MSDateTimeAsTime(MSTime):
+class MSDateTimeAsTime(sqltypes.TypeDecorator):
""" This is an implementation of the Time type for versions of MSSQL that
do not support that specific type. In order to make it work a ``DATETIME``
column specification is used and the results get converted back to just
__zero_date = datetime.date(1900, 1, 1)
- def get_col_spec(self):
- return "DATETIME"
+ impl = sqltypes.DateTime
- def bind_processor(self, dialect):
- def process(value):
- if type(value) is datetime.datetime:
- value = datetime.datetime.combine(self.__zero_date, value.time())
- elif type(value) is datetime.time:
- value = datetime.datetime.combine(self.__zero_date, value)
- return value
- return process
+ def process_bind_param(self, value, dialect):
+ if type(value) is datetime.datetime:
+ value = datetime.datetime.combine(self.__zero_date, value.time())
+ elif type(value) is datetime.time:
+ value = datetime.datetime.combine(self.__zero_date, value)
+ return value
- def result_processor(self, dialect):
- def process(value):
- if type(value) is datetime.datetime:
- return value.time()
- elif type(value) is datetime.date:
- return datetime.time(0, 0, 0)
- return value
- return process
+ def process_result_value(self, value, dialect):
+ if type(value) is datetime.datetime:
+ return value.time()
+ elif type(value) is datetime.date:
+ return datetime.time(0, 0, 0)
+ return value
-class MSDateTime_adodbapi(MSDateTime):
- def result_processor(self, dialect):
- def process(value):
- # adodbapi will return datetimes with empty time values as datetime.date() objects.
- # Promote them back to full datetime.datetime()
- if type(value) is datetime.date:
- return datetime.datetime(value.year, value.month, value.day)
- return value
- return process
+class _StringType(object):
+ """Base for MSSQL string types."""
+
+ def __init__(self, collation=None):
+ self.collation = collation
+
+ def __repr__(self):
+ attributes = inspect.getargspec(self.__init__)[0][1:]
+ attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:])
+
+ params = {}
+ for attr in attributes:
+ val = getattr(self, attr)
+ if val is not None and val is not False:
+ params[attr] = val
+
+ return "%s(%s)" % (self.__class__.__name__,
+ ', '.join(['%s=%r' % (k, params[k]) for k in params]))
-class MSText(_StringType, sqltypes.Text):
+class MSText(_StringType, sqltypes.TEXT):
"""MSSQL TEXT type, for variable-length text up to 2^31 characters."""
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args, **kw):
"""Construct a TEXT.
:param collation: Optional, a column-level collation for this string
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.Text.__init__(self, None,
- convert_unicode=kwargs.get('convert_unicode', False),
- assert_unicode=kwargs.get('assert_unicode', None))
-
- def get_col_spec(self):
- if self.dialect.text_as_varchar:
- return self._extend("VARCHAR(max)")
- else:
- return self._extend("TEXT")
-
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.Text.__init__(self, *args, **kw)
class MSNText(_StringType, sqltypes.UnicodeText):
"""MSSQL NTEXT type, for variable-length unicode text up to 2^30
characters."""
+ __visit_name__ = 'NTEXT'
+
def __init__(self, *args, **kwargs):
"""Construct a NTEXT.
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.UnicodeText.__init__(self, None,
- convert_unicode=kwargs.get('convert_unicode', True),
- assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
- def get_col_spec(self):
- if self.dialect.text_as_varchar:
- return self._extend("NVARCHAR(max)")
- else:
- return self._extend("NTEXT")
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.UnicodeText.__init__(self, None, **kw)
-class MSString(_StringType, sqltypes.String):
+class MSString(_StringType, sqltypes.VARCHAR):
"""MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum
of 8,000 characters."""
- def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
+ def __init__(self, *args, **kw):
"""Construct a VARCHAR.
:param length: Optinal, maximum data length, in characters.
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.String.__init__(self, length=length,
- convert_unicode=convert_unicode,
- assert_unicode=assert_unicode)
-
- def get_col_spec(self):
- if self.length:
- return self._extend("VARCHAR(%s)" % self.length)
- else:
- return self._extend("VARCHAR")
-
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.VARCHAR.__init__(self, *args, **kw)
-class MSNVarchar(_StringType, sqltypes.Unicode):
+class MSNVarchar(_StringType, sqltypes.NVARCHAR):
"""MSSQL NVARCHAR type.
For variable-length unicode character data up to 4,000 characters."""
- def __init__(self, length=None, **kwargs):
+ def __init__(self, *args, **kw):
"""Construct a NVARCHAR.
:param length: Optional, Maximum data length, in characters.
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.Unicode.__init__(self, length=length,
- convert_unicode=kwargs.get('convert_unicode', True),
- assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
- def adapt(self, impltype):
- return impltype(length=self.length,
- convert_unicode=self.convert_unicode,
- assert_unicode=self.assert_unicode,
- collation=self.collation)
-
- def get_col_spec(self):
- if self.length:
- return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length})
- else:
- return self._extend("NVARCHAR")
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.NVARCHAR.__init__(self, *args, **kw)
class MSChar(_StringType, sqltypes.CHAR):
"""MSSQL CHAR type, for fixed-length non-Unicode data with a maximum
of 8,000 characters."""
- def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs):
+ def __init__(self, *args, **kw):
"""Construct a CHAR.
:param length: Optinal, maximum data length, in characters.
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.CHAR.__init__(self, length=length,
- convert_unicode=convert_unicode,
- assert_unicode=assert_unicode)
-
- def get_col_spec(self):
- if self.length:
- return self._extend("CHAR(%s)" % self.length)
- else:
- return self._extend("CHAR")
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.CHAR.__init__(self, *args, **kw)
class MSNChar(_StringType, sqltypes.NCHAR):
For fixed-length unicode character data up to 4,000 characters."""
- def __init__(self, length=None, **kwargs):
+ def __init__(self, *args, **kw):
"""Construct an NCHAR.
:param length: Optional, Maximum data length, in characters.
value. Accepts a Windows Collation Name or a SQL Collation Name.
"""
- _StringType.__init__(self, **kwargs)
- sqltypes.NCHAR.__init__(self, length=length,
- convert_unicode=kwargs.get('convert_unicode', True),
- assert_unicode=kwargs.get('assert_unicode', 'warn'))
-
- def get_col_spec(self):
- if self.length:
- return self._extend("NCHAR(%(length)s)" % {'length' : self.length})
- else:
- return self._extend("NCHAR")
-
-
-class MSGenericBinary(sqltypes.Binary):
- """The Binary type assumes that a Binary specification without a length
- is an unbound Binary type whereas one with a length specification results
- in a fixed length Binary type.
-
- If you want standard MSSQL ``BINARY`` behavior use the ``MSBinary`` type.
-
- """
-
- def get_col_spec(self):
- if self.length:
- return "BINARY(%s)" % self.length
- else:
- return "IMAGE"
-
-
-class MSBinary(MSGenericBinary):
- def get_col_spec(self):
- if self.length:
- return "BINARY(%s)" % self.length
- else:
- return "BINARY"
+ collation = kw.pop('collation', None)
+ _StringType.__init__(self, collation)
+ sqltypes.NCHAR.__init__(self, *args, **kw)
+class MSBinary(sqltypes.Binary):
+ pass
-class MSVarBinary(MSGenericBinary):
- def get_col_spec(self):
- if self.length:
- return "VARBINARY(%s)" % self.length
- else:
- return "VARBINARY"
-
-
-class MSImage(MSGenericBinary):
- def get_col_spec(self):
- return "IMAGE"
+class MSVarBinary(sqltypes.Binary):
+ __visit_name__ = 'VARBINARY'
+class MSImage(sqltypes.Binary):
+ __visit_name__ = 'IMAGE'
+class MSBit(sqltypes.TypeEngine):
+ __visit_name__ = 'BIT'
+
class MSBoolean(sqltypes.Boolean):
- def get_col_spec(self):
- return "BIT"
-
def result_processor(self, dialect):
def process(value):
if value is None:
return value and True or False
return process
+class MSMoney(sqltypes.TypeEngine):
+ __visit_name__ = 'MONEY'
+
+class MSSmallMoney(MSMoney):
+ __visit_name__ = 'SMALLMONEY'
-class MSTimeStamp(sqltypes.TIMESTAMP):
- def get_col_spec(self):
- return "TIMESTAMP"
+class MSUniqueIdentifier(sqltypes.TypeEngine):
+ __visit_name__ = "UNIQUEIDENTIFIER"
-class MSMoney(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "MONEY"
+class MSVariant(sqltypes.TypeEngine):
+ __visit_name__ = 'SQL_VARIANT'
+class MSTypeCompiler(compiler.GenericTypeCompiler):
+ def _extend(self, spec, type_):
+ """Extend a string-type declaration with standard SQL
+ COLLATE annotations.
-class MSSmallMoney(MSMoney):
- def get_col_spec(self):
- return "SMALLMONEY"
+ """
+ if getattr(type_, 'collation', None):
+ collation = 'COLLATE %s' % type_.collation
+ else:
+ collation = None
-class MSUniqueIdentifier(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "UNIQUEIDENTIFIER"
+ if type_.length:
+ spec = spec + "(%d)" % type_.length
+
+ return ' '.join([c for c in (spec, collation)
+ if c is not None])
+ def visit_FLOAT(self, type_):
+ precision = getattr(type_, 'precision', None)
+ if precision is None:
+ return "FLOAT"
+ else:
+ return "FLOAT(%(precision)s)" % {'precision': precision}
-class MSVariant(sqltypes.TypeEngine):
- def get_col_spec(self):
- return "SQL_VARIANT"
+ def visit_REAL(self, type_):
+ return "REAL"
+
+ def visit_TINYINT(self, type_):
+ return "TINYINT"
+
+ def visit_DATETIMEOFFSET(self, type_):
+ if type_.precision:
+ return "DATETIMEOFFSET(%s)" % type_.precision
+ else:
+ return "DATETIMEOFFSET"
+
+ def visit_TIME(self, type_):
+ precision = getattr(type_, 'precision', None)
+ if precision:
+ return "TIME(%s)" % precision
+ else:
+ return "TIME"
+
+ def visit_DATETIME2(self, type_):
+ precision = getattr(type_, 'precision', None)
+ if precision:
+ return "DATETIME2(%s)" % precision
+ else:
+ return "DATETIME2"
+
+ def visit_SMALLDATETIME(self, type_):
+ return "SMALLDATETIME"
+
+ def visit_NTEXT(self, type_):
+ return self._extend("NTEXT", type_)
+
+ def visit_TEXT(self, type_):
+ return self._extend("TEXT", type_)
+
+ def visit_VARCHAR(self, type_):
+ return self._extend("VARCHAR", type_)
+
+ def visit_CHAR(self, type_):
+ return self._extend("CHAR", type_)
+
+ def visit_NCHAR(self, type_):
+ return self._extend("NCHAR", type_)
+
+ def visit_NVARCHAR(self, type_):
+ return self._extend("NVARCHAR", type_)
+
+ def visit_binary(self, type_):
+ if type_.length:
+ return self.visit_BINARY(type_)
+ else:
+ return self.visit_IMAGE(type_)
+
+ def visit_BINARY(self, type_):
+ if type_.length:
+ return "BINARY(%s)" % type_.length
+ else:
+ return "BINARY"
+
+ def visit_IMAGE(self, type_):
+ return "IMAGE"
+
+ def visit_VARBINARY(self, type_):
+ if type_.length:
+ return "VARBINARY(%s)" % type_.length
+ else:
+ return "VARBINARY"
+
+ def visit_boolean(self, type_):
+ return self.visit_BIT(type_)
+
+ def visit_BIT(self, type_):
+ return "BIT"
+
+ def visit_MONEY(self, type_):
+ return "MONEY"
+
+ def visit_SMALLMONEY(self, type_):
+ return 'SMALLMONEY'
+
+ def visit_UNIQUEIDENTIFIER(self, type_):
+ return "UNIQUEIDENTIFIER"
+ def visit_SQL_VARIANT(self, type_):
+ return 'SQL_VARIANT'
def _has_implicit_sequence(column):
return column.primary_key and \
break
return tbl._ms_has_sequence
-class MSSQLExecutionContext(default.DefaultExecutionContext):
+class MSExecutionContext(default.DefaultExecutionContext):
IINSERT = False
HASIDENT = False
if self.IINSERT:
self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+colspecs = {
+ sqltypes.Unicode : MSNVarchar,
+ sqltypes.Numeric : MSNumeric,
+ sqltypes.DateTime : MSDateTime,
+ sqltypes.Time : MSTime,
+ sqltypes.String : MSString,
+ sqltypes.Boolean : MSBoolean,
+ sqltypes.Text : MSText,
+ sqltypes.UnicodeText : MSNText,
+ sqltypes.CHAR: MSChar,
+ sqltypes.NCHAR: MSNChar,
+}
+
+ischema_names = {
+ 'int' : sqltypes.INTEGER,
+ 'bigint': sqltypes.BigInteger,
+ 'smallint' : sqltypes.SmallInteger,
+ 'tinyint' : MSTinyInteger,
+ 'varchar' : MSString,
+ 'nvarchar' : MSNVarchar,
+ 'char' : MSChar,
+ 'nchar' : MSNChar,
+ 'text' : MSText,
+ 'ntext' : MSNText,
+ 'decimal' : sqltypes.DECIMAL,
+ 'numeric' : sqltypes.NUMERIC,
+ 'float' : sqltypes.FLOAT,
+ 'datetime' : sqltypes.DateTime,
+ 'datetime2' : MSDateTime2,
+ 'datetimeoffset' : MSDateTimeOffset,
+ 'date': sqltypes.Date,
+ 'time': MSTime,
+ 'smalldatetime' : MSSmallDateTime,
+ 'binary' : MSBinary,
+ 'varbinary' : MSVarBinary,
+ 'bit': sqltypes.Boolean,
+ 'real' : MSReal,
+ 'image' : MSImage,
+ 'timestamp': sqltypes.TIMESTAMP,
+ 'money': MSMoney,
+ 'smallmoney': MSSmallMoney,
+ 'uniqueidentifier': MSUniqueIdentifier,
+ 'sql_variant': MSVariant,
+}
-class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
- def pre_exec(self):
- """where appropriate, issue "select scope_identity()" in the same statement"""
- super(MSSQLExecutionContext_pyodbc, self).pre_exec()
- if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \
- and len(self.parameters) == 1 and self.dialect.use_scope_identity:
- self.statement += "; select scope_identity()"
-
- def post_exec(self):
- if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
- import pyodbc
- # Fetch the last inserted id from the manipulated statement
- # We may have to skip over a number of result sets with no data (due to triggers, etc.)
- while True:
- try:
- row = self.cursor.fetchone()
- break
- except pyodbc.Error, e:
- self.cursor.nextset()
- self._last_inserted_ids = [int(row[0])]
- else:
- super(MSSQLExecutionContext_pyodbc, self).post_exec()
+class MSSQLCompiler(compiler.SQLCompiler):
+ operators = compiler.OPERATORS.copy()
+ operators.update({
+ sql_operators.concat_op: '+',
+ sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
+ })
-class MSSQLDialect(default.DefaultDialect):
- name = 'mssql'
- supports_default_values = True
- supports_empty_insert = False
- auto_identity_insert = True
- execution_ctx_cls = MSSQLExecutionContext
- text_as_varchar = False
- use_scope_identity = False
- has_window_funcs = False
- max_identifier_length = 128
- schema_name = "dbo"
+ functions = compiler.SQLCompiler.functions.copy()
+ functions.update (
+ {
+ sql_functions.now: 'CURRENT_TIMESTAMP',
+ sql_functions.current_date: 'GETDATE()',
+ 'length': lambda x: "LEN(%s)" % x,
+ sql_functions.char_length: lambda x: "LEN(%s)" % x
+ }
+ )
- colspecs = {
- sqltypes.Unicode : MSNVarchar,
- sqltypes.Integer : MSInteger,
- sqltypes.SmallInteger: MSSmallInteger,
- sqltypes.Numeric : MSNumeric,
- sqltypes.Float : MSFloat,
- sqltypes.DateTime : MSDateTime,
- sqltypes.Date : MSDate,
- sqltypes.Time : MSTime,
- sqltypes.String : MSString,
- sqltypes.Binary : MSGenericBinary,
- sqltypes.Boolean : MSBoolean,
- sqltypes.Text : MSText,
- sqltypes.UnicodeText : MSNText,
- sqltypes.CHAR: MSChar,
- sqltypes.NCHAR: MSNChar,
- sqltypes.TIMESTAMP: MSTimeStamp,
- }
-
- ischema_names = {
- 'int' : MSInteger,
- 'bigint': MSBigInteger,
- 'smallint' : MSSmallInteger,
- 'tinyint' : MSTinyInteger,
- 'varchar' : MSString,
- 'nvarchar' : MSNVarchar,
- 'char' : MSChar,
- 'nchar' : MSNChar,
- 'text' : MSText,
- 'ntext' : MSNText,
- 'decimal' : MSNumeric,
- 'numeric' : MSNumeric,
- 'float' : MSFloat,
- 'datetime' : MSDateTime,
- 'datetime2' : MSDateTime2,
- 'datetimeoffset' : MSDateTimeOffset,
- 'date': MSDate,
- 'time': MSTime,
- 'smalldatetime' : MSSmallDateTime,
- 'binary' : MSBinary,
- 'varbinary' : MSVarBinary,
- 'bit': MSBoolean,
- 'real' : MSFloat,
- 'image' : MSImage,
- 'timestamp': MSTimeStamp,
- 'money': MSMoney,
- 'smallmoney': MSSmallMoney,
- 'uniqueidentifier': MSUniqueIdentifier,
- 'sql_variant': MSVariant,
- }
-
- def __new__(cls, *args, **kwargs):
- if cls is not MSSQLDialect:
- # this gets called with the dialect specific class
- return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
- dbapi = kwargs.get('dbapi', None)
- if dbapi:
- dialect = dialect_mapping.get(dbapi.__name__)
- return dialect(**kwargs)
- else:
- return object.__new__(cls, *args, **kwargs)
+ def __init__(self, *args, **kwargs):
+ super(MSSQLCompiler, self).__init__(*args, **kwargs)
+ self.tablealiases = {}
- def __init__(self,
- auto_identity_insert=True, query_timeout=None,
- text_as_varchar=False, use_scope_identity=False,
- has_window_funcs=False, max_identifier_length=None,
- schema_name="dbo", **opts):
- self.auto_identity_insert = bool(auto_identity_insert)
- self.query_timeout = int(query_timeout or 0)
- self.schema_name = schema_name
+ def get_select_precolumns(self, select):
+ """ MS-SQL puts TOP, it's version of LIMIT here """
+ if select._distinct or select._limit:
+ s = select._distinct and "DISTINCT " or ""
+
+ if select._limit:
+ if not select._offset:
+ s += "TOP %s " % (select._limit,)
+ else:
+ if not self.dialect.has_window_funcs:
+ raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
+ return s
+ return compiler.SQLCompiler.get_select_precolumns(self, select)
- # to-do: the options below should use server version introspection to set themselves on connection
- self.text_as_varchar = bool(text_as_varchar)
- self.use_scope_identity = bool(use_scope_identity)
- self.has_window_funcs = bool(has_window_funcs)
- self.max_identifier_length = int(max_identifier_length or 0) or 128
- super(MSSQLDialect, self).__init__(**opts)
+ def limit_clause(self, select):
+ # Limit in mssql is after the select keyword
+ return ""
- @classmethod
- def dbapi(cls, module_name=None):
- if module_name:
- try:
- dialect_cls = dialect_mapping[module_name]
- return dialect_cls.import_dbapi()
- except KeyError:
- raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name)
+ def visit_select(self, select, **kwargs):
+ """Look for ``LIMIT`` and OFFSET in a select statement, and if
+ so tries to wrap it in a subquery with ``row_number()`` criterion.
+
+ """
+ if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
+ # to use ROW_NUMBER(), an ORDER BY is required.
+ orderby = self.process(select._order_by_clause)
+ if not orderby:
+ raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
+
+ _offset = select._offset
+ _limit = select._limit
+ select._mssql_visit = True
+ select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
+
+ limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
+ limitselect.append_whereclause("mssql_rn>%d" % _offset)
+ if _limit is not None:
+ limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
+ return self.process(limitselect, iswrapper=True, **kwargs)
+ else:
+ return compiler.SQLCompiler.visit_select(self, select, **kwargs)
+
+ def _schema_aliased_table(self, table):
+ if getattr(table, 'schema', None) is not None:
+ if table not in self.tablealiases:
+ self.tablealiases[table] = table.alias()
+ return self.tablealiases[table]
+ else:
+ return None
+
+ def visit_table(self, table, mssql_aliased=False, **kwargs):
+ if mssql_aliased:
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ # alias schema-qualified tables
+ alias = self._schema_aliased_table(table)
+ if alias is not None:
+ return self.process(alias, mssql_aliased=True, **kwargs)
else:
- for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]:
- try:
- return dialect_cls.import_dbapi()
- except ImportError, e:
- pass
+ return super(MSSQLCompiler, self).visit_table(table, **kwargs)
+
+ def visit_alias(self, alias, **kwargs):
+ # translate for schema-qualified table aliases
+ self.tablealiases[alias.original] = alias
+ kwargs['mssql_aliased'] = True
+ return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
+
+ def visit_savepoint(self, savepoint_stmt):
+ util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
+ return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+ def visit_rollback_to_savepoint(self, savepoint_stmt):
+ return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
+
+ def visit_column(self, column, result_map=None, **kwargs):
+ if column.table is not None and \
+ (not self.isupdate and not self.isdelete) or self.is_subquery():
+ # translate for schema-qualified table aliases
+ t = self._schema_aliased_table(column.table)
+ if t is not None:
+ converted = expression._corresponding_column_or_error(t, column)
+
+ if result_map is not None:
+ result_map[column.name.lower()] = (column.name, (column, ), column.type)
+
+ return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
+
+ return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
+
+ def visit_binary(self, binary, **kwargs):
+ """Move bind parameters to the right-hand side of an operator, where
+ possible.
+
+ """
+ if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
+ and not isinstance(binary.right, expression._BindParamClause):
+ return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
+ else:
+ if (binary.operator is operator.eq or binary.operator is operator.ne) and (
+ (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
+ (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
+ isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
+ op = binary.operator == operator.eq and "IN" or "NOT IN"
+ return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
+ return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
+
+ def visit_insert(self, insert_stmt):
+ insert_select = False
+ if insert_stmt.parameters:
+ insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
+ if insert_select:
+ self.isinsert = True
+ colparams = self._get_colparams(insert_stmt)
+ preparer = self.preparer
+
+ insert = ' '.join(["INSERT"] +
+ [self.process(x) for x in insert_stmt._prefixes])
+
+ if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
+ raise exc.CompileError(
+ "The version of %s you are using does not support empty inserts." % self.dialect.name)
+ elif not colparams and self.dialect.supports_default_values:
+ return (insert + " INTO %s DEFAULT VALUES" % (
+ (preparer.format_table(insert_stmt.table),)))
else:
- raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
+ return (insert + " INTO %s (%s) SELECT %s" %
+ (preparer.format_table(insert_stmt.table),
+ ', '.join([preparer.format_column(c[0])
+ for c in colparams]),
+ ', '.join([c[1] for c in colparams])))
+ else:
+ return super(MSSQLCompiler, self).visit_insert(insert_stmt)
+
+ def label_select_column(self, select, column, asfrom):
+ if isinstance(column, expression.Function):
+ return column.label(None)
+ else:
+ return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
+
+ def for_update_clause(self, select):
+ # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
+ return ''
+
+ def order_by_clause(self, select):
+ order_by = self.process(select._order_by_clause)
+
+ # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
+ if order_by and (not self.is_subquery() or select._limit):
+ return " ORDER BY " + order_by
+ else:
+ return ""
+
+
+class MSDDLCompiler(compiler.DDLCompiler):
+ def get_column_specification(self, column, **kwargs):
+ colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type)
+
+ if column.nullable is not None:
+ if not column.nullable or column.primary_key:
+ colspec += " NOT NULL"
+ else:
+ colspec += " NULL"
+
+ if not column.table:
+ raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
+
+ seq_col = _table_sequence_column(column.table)
+
+ # install a IDENTITY Sequence if we have an implicit IDENTITY column
+ if seq_col is column:
+ sequence = getattr(column, 'sequence', None)
+ if sequence:
+ start, increment = sequence.start or 1, sequence.increment or 1
+ else:
+ start, increment = 1, 1
+ colspec += " IDENTITY(%s,%s)" % (start, increment)
+ else:
+ default = self.get_column_default_string(column)
+ if default is not None:
+ colspec += " DEFAULT " + default
+
+ return colspec
+
+ def visit_drop_index(self, drop):
+ return "\nDROP INDEX %s.%s" % (
+ self.preparer.quote_identifier(drop.element.table.name),
+ self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote)
+ )
+
+
+class MSIdentifierPreparer(compiler.IdentifierPreparer):
+ reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS)
+
+ def __init__(self, dialect):
+ super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
+
+ def _escape_identifier(self, value):
+ #TODO: determine MSSQL's escaping rules
+ return value
+
+class MSDialect(default.DefaultDialect):
+ name = 'mssql'
+ supports_default_values = True
+ supports_empty_insert = False
+ auto_identity_insert = True
+ execution_ctx_cls = MSExecutionContext
+ text_as_varchar = False
+ use_scope_identity = False
+ has_window_funcs = False
+ max_identifier_length = 128
+ schema_name = "dbo"
+ colspecs = colspecs
+ ischema_names = ischema_names
+
+ supports_unicode_binds = True
+
+ statement_compiler = MSSQLCompiler
+ ddl_compiler = MSDDLCompiler
+ type_compiler = MSTypeCompiler
+ preparer = MSIdentifierPreparer
+
+ def __init__(self,
+ auto_identity_insert=True, query_timeout=None,
+ use_scope_identity=False,
+ has_window_funcs=False, max_identifier_length=None,
+ schema_name="dbo", **opts):
+ self.auto_identity_insert = bool(auto_identity_insert)
+ self.query_timeout = int(query_timeout or 0)
+ self.schema_name = schema_name
+
+ self.use_scope_identity = bool(use_scope_identity)
+ self.has_window_funcs = bool(has_window_funcs)
+ self.max_identifier_length = int(max_identifier_length or 0) or 128
+ super(MSDialect, self).__init__(**opts)
@base.connection_memoize(('mssql', 'server_version_info'))
def server_version_info(self, connection):
"""Return a tuple of the database's version number."""
raise NotImplementedError()
- def create_connect_args(self, url):
- opts = url.translate_connect_args(username='user')
- opts.update(url.query)
- if 'auto_identity_insert' in opts:
- self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert')))
- if 'query_timeout' in opts:
- self.query_timeout = int(opts.pop('query_timeout'))
- if 'text_as_varchar' in opts:
- self.text_as_varchar = bool(int(opts.pop('text_as_varchar')))
- if 'use_scope_identity' in opts:
- self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
- if 'has_window_funcs' in opts:
- self.has_window_funcs = bool(int(opts.pop('has_window_funcs')))
- return self.make_connect_string(opts, url.query)
-
- def type_descriptor(self, typeobj):
- newobj = sqltypes.adapt_type(typeobj, self.colspecs)
- # Some types need to know about the dialect
- if isinstance(newobj, (MSText, MSNText)):
- newobj.dialect = self
- return newobj
-
def do_begin(self, connection):
cursor = connection.cursor()
cursor.execute("SET IMPLICIT_TRANSACTIONS OFF")
if fknm and scols:
table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True))
-
-class MSSQLDialect_pymssql(MSSQLDialect):
- supports_sane_rowcount = False
- max_identifier_length = 30
-
- @classmethod
- def import_dbapi(cls):
- import pymssql as module
- # pymmsql doesn't have a Binary method. we use string
- # TODO: monkeypatching here is less than ideal
- module.Binary = lambda st: str(st)
- return module
-
- def __init__(self, **params):
- super(MSSQLDialect_pymssql, self).__init__(**params)
- self.use_scope_identity = True
-
- # pymssql understands only ascii
- if self.convert_unicode:
- util.warn("pymssql does not support unicode")
- self.encoding = params.get('encoding', 'ascii')
-
- self.colspecs = MSSQLDialect.colspecs.copy()
- self.ischema_names = MSSQLDialect.ischema_names.copy()
- self.ischema_names['date'] = MSDateTimeAsDate
- self.colspecs[sqltypes.Date] = MSDateTimeAsDate
- self.ischema_names['time'] = MSDateTimeAsTime
- self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
- def create_connect_args(self, url):
- r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
- if hasattr(self, 'query_timeout'):
- self.dbapi._mssql.set_query_timeout(self.query_timeout)
- return r
-
- def make_connect_string(self, keys, query):
- if keys.get('port'):
- # pymssql expects port as host:port, not a separate arg
- keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
- del keys['port']
- return [[], keys]
-
- def is_disconnect(self, e):
- return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e)
-
- def do_begin(self, connection):
- pass
-
-
-class MSSQLDialect_pyodbc(MSSQLDialect):
- supports_sane_rowcount = False
- supports_sane_multi_rowcount = False
- # PyODBC unicode is broken on UCS-4 builds
- supports_unicode = sys.maxunicode == 65535
- supports_unicode_statements = supports_unicode
- execution_ctx_cls = MSSQLExecutionContext_pyodbc
-
- def __init__(self, description_encoding='latin-1', **params):
- super(MSSQLDialect_pyodbc, self).__init__(**params)
- self.description_encoding = description_encoding
-
- if self.server_version_info < (10,):
- self.colspecs = MSSQLDialect.colspecs.copy()
- self.ischema_names = MSSQLDialect.ischema_names.copy()
- self.ischema_names['date'] = MSDateTimeAsDate
- self.colspecs[sqltypes.Date] = MSDateTimeAsDate
- self.ischema_names['time'] = MSDateTimeAsTime
- self.colspecs[sqltypes.Time] = MSDateTimeAsTime
-
- # FIXME: scope_identity sniff should look at server version, not the ODBC driver
- # whether use_scope_identity will work depends on the version of pyodbc
- try:
- import pyodbc
- self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset')
- except:
- pass
-
- @classmethod
- def import_dbapi(cls):
- import pyodbc as module
- return module
-
- def make_connect_string(self, keys, query):
- if 'max_identifier_length' in keys:
- self.max_identifier_length = int(keys.pop('max_identifier_length'))
-
- if 'odbc_connect' in keys:
- connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
- else:
- dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
- if dsn_connection:
- connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
- else:
- port = ''
- if 'port' in keys and not 'port' in query:
- port = ',%d' % int(keys.pop('port'))
-
- connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
- 'Server=%s%s' % (keys.pop('host', ''), port),
- 'Database=%s' % keys.pop('database', '') ]
-
- user = keys.pop("user", None)
- if user:
- connectors.append("UID=%s" % user)
- connectors.append("PWD=%s" % keys.pop('password', ''))
- else:
- connectors.append("TrustedConnection=Yes")
-
- # if set to 'Yes', the ODBC layer will try to automagically convert
- # textual data from your database encoding to your client encoding
- # This should obviously be set to 'No' if you query a cp1253 encoded
- # database from a latin1 client...
- if 'odbc_autotranslate' in keys:
- connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
-
- connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
-
- return [[";".join (connectors)], {}]
-
- def is_disconnect(self, e):
- if isinstance(e, self.dbapi.ProgrammingError):
- return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e)
- elif isinstance(e, self.dbapi.Error):
- return '[08S01]' in str(e)
- else:
- return False
-
-
- def _server_version_info(self, dbapi_con):
- """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
- version = []
- r = re.compile('[.\-]')
- for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
- try:
- version.append(int(n))
- except ValueError:
- version.append(n)
- return tuple(version)
-
-class MSSQLDialect_adodbapi(MSSQLDialect):
- supports_sane_rowcount = True
- supports_sane_multi_rowcount = True
- supports_unicode = sys.maxunicode == 65535
- supports_unicode_statements = True
-
- @classmethod
- def import_dbapi(cls):
- import adodbapi as module
- return module
-
- colspecs = MSSQLDialect.colspecs.copy()
- colspecs[sqltypes.DateTime] = MSDateTime_adodbapi
-
- ischema_names = MSSQLDialect.ischema_names.copy()
- ischema_names['datetime'] = MSDateTime_adodbapi
-
- def make_connect_string(self, keys, query):
- connectors = ["Provider=SQLOLEDB"]
- if 'port' in keys:
- connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
- else:
- connectors.append ("Data Source=%s" % keys.get("host"))
- connectors.append ("Initial Catalog=%s" % keys.get("database"))
- user = keys.get("user")
- if user:
- connectors.append("User Id=%s" % user)
- connectors.append("Password=%s" % keys.get("password", ""))
- else:
- connectors.append("Integrated Security=SSPI")
- return [[";".join (connectors)], {}]
-
- def is_disconnect(self, e):
- return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e)
-
-
-dialect_mapping = {
- 'pymssql': MSSQLDialect_pymssql,
- 'pyodbc': MSSQLDialect_pyodbc,
- 'adodbapi': MSSQLDialect_adodbapi
- }
-
-
-class MSSQLCompiler(compiler.SQLCompiler):
- operators = compiler.OPERATORS.copy()
- operators.update({
- sql_operators.concat_op: '+',
- sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y)
- })
-
- functions = compiler.SQLCompiler.functions.copy()
- functions.update (
- {
- sql_functions.now: 'CURRENT_TIMESTAMP',
- sql_functions.current_date: 'GETDATE()',
- 'length': lambda x: "LEN(%s)" % x,
- sql_functions.char_length: lambda x: "LEN(%s)" % x
- }
- )
-
- def __init__(self, *args, **kwargs):
- super(MSSQLCompiler, self).__init__(*args, **kwargs)
- self.tablealiases = {}
-
- def get_select_precolumns(self, select):
- """ MS-SQL puts TOP, it's version of LIMIT here """
- if select._distinct or select._limit:
- s = select._distinct and "DISTINCT " or ""
-
- if select._limit:
- if not select._offset:
- s += "TOP %s " % (select._limit,)
- else:
- if not self.dialect.has_window_funcs:
- raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset')
- return s
- return compiler.SQLCompiler.get_select_precolumns(self, select)
-
- def limit_clause(self, select):
- # Limit in mssql is after the select keyword
- return ""
-
- def visit_select(self, select, **kwargs):
- """Look for ``LIMIT`` and OFFSET in a select statement, and if
- so tries to wrap it in a subquery with ``row_number()`` criterion.
-
- """
- if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
- # to use ROW_NUMBER(), an ORDER BY is required.
- orderby = self.process(select._order_by_clause)
- if not orderby:
- raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.')
-
- _offset = select._offset
- _limit = select._limit
- select._mssql_visit = True
- select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias()
-
- limitselect = sql.select([c for c in select.c if c.key!='mssql_rn'])
- limitselect.append_whereclause("mssql_rn>%d" % _offset)
- if _limit is not None:
- limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset))
- return self.process(limitselect, iswrapper=True, **kwargs)
- else:
- return compiler.SQLCompiler.visit_select(self, select, **kwargs)
-
- def _schema_aliased_table(self, table):
- if getattr(table, 'schema', None) is not None:
- if table not in self.tablealiases:
- self.tablealiases[table] = table.alias()
- return self.tablealiases[table]
- else:
- return None
-
- def visit_table(self, table, mssql_aliased=False, **kwargs):
- if mssql_aliased:
- return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
- # alias schema-qualified tables
- alias = self._schema_aliased_table(table)
- if alias is not None:
- return self.process(alias, mssql_aliased=True, **kwargs)
- else:
- return super(MSSQLCompiler, self).visit_table(table, **kwargs)
-
- def visit_alias(self, alias, **kwargs):
- # translate for schema-qualified table aliases
- self.tablealiases[alias.original] = alias
- kwargs['mssql_aliased'] = True
- return super(MSSQLCompiler, self).visit_alias(alias, **kwargs)
-
- def visit_savepoint(self, savepoint_stmt):
- util.warn("Savepoint support in mssql is experimental and may lead to data loss.")
- return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
-
- def visit_rollback_to_savepoint(self, savepoint_stmt):
- return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt)
-
- def visit_column(self, column, result_map=None, **kwargs):
- if column.table is not None and \
- (not self.isupdate and not self.isdelete) or self.is_subquery():
- # translate for schema-qualified table aliases
- t = self._schema_aliased_table(column.table)
- if t is not None:
- converted = expression._corresponding_column_or_error(t, column)
-
- if result_map is not None:
- result_map[column.name.lower()] = (column.name, (column, ), column.type)
-
- return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs)
-
- return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs)
-
- def visit_binary(self, binary, **kwargs):
- """Move bind parameters to the right-hand side of an operator, where
- possible.
-
- """
- if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \
- and not isinstance(binary.right, expression._BindParamClause):
- return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs)
- else:
- if (binary.operator is operator.eq or binary.operator is operator.ne) and (
- (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \
- (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \
- isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)):
- op = binary.operator == operator.eq and "IN" or "NOT IN"
- return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs)
- return super(MSSQLCompiler, self).visit_binary(binary, **kwargs)
-
- def visit_insert(self, insert_stmt):
- insert_select = False
- if insert_stmt.parameters:
- insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)]
- if insert_select:
- self.isinsert = True
- colparams = self._get_colparams(insert_stmt)
- preparer = self.preparer
-
- insert = ' '.join(["INSERT"] +
- [self.process(x) for x in insert_stmt._prefixes])
-
- if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert:
- raise exc.CompileError(
- "The version of %s you are using does not support empty inserts." % self.dialect.name)
- elif not colparams and self.dialect.supports_default_values:
- return (insert + " INTO %s DEFAULT VALUES" % (
- (preparer.format_table(insert_stmt.table),)))
- else:
- return (insert + " INTO %s (%s) SELECT %s" %
- (preparer.format_table(insert_stmt.table),
- ', '.join([preparer.format_column(c[0])
- for c in colparams]),
- ', '.join([c[1] for c in colparams])))
- else:
- return super(MSSQLCompiler, self).visit_insert(insert_stmt)
-
- def label_select_column(self, select, column, asfrom):
- if isinstance(column, expression.Function):
- return column.label(None)
- else:
- return super(MSSQLCompiler, self).label_select_column(select, column, asfrom)
-
- def for_update_clause(self, select):
- # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use
- return ''
-
- def order_by_clause(self, select):
- order_by = self.process(select._order_by_clause)
-
- # MSSQL only allows ORDER BY in subqueries if there is a LIMIT
- if order_by and (not self.is_subquery() or select._limit):
- return " ORDER BY " + order_by
- else:
- return ""
-
-
-class MSSQLSchemaGenerator(compiler.SchemaGenerator):
- def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
-
- if column.nullable is not None:
- if not column.nullable or column.primary_key:
- colspec += " NOT NULL"
- else:
- colspec += " NULL"
-
- if not column.table:
- raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
-
- seq_col = _table_sequence_column(column.table)
-
- # install a IDENTITY Sequence if we have an implicit IDENTITY column
- if seq_col is column:
- sequence = getattr(column, 'sequence', None)
- if sequence:
- start, increment = sequence.start or 1, sequence.increment or 1
- else:
- start, increment = 1, 1
- colspec += " IDENTITY(%s,%s)" % (start, increment)
- else:
- default = self.get_column_default_string(column)
- if default is not None:
- colspec += " DEFAULT " + default
-
- return colspec
-
-class MSSQLSchemaDropper(compiler.SchemaDropper):
- def visit_index(self, index):
- self.append("\nDROP INDEX %s.%s" % (
- self.preparer.quote_identifier(index.table.name),
- self.preparer.quote(self._validate_identifier(index.name, False), index.quote)
- ))
- self.execute()
-
-
-class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
- reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS)
-
- def __init__(self, dialect):
- super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
-
- def _escape_identifier(self, value):
- #TODO: determine MSSQL's escaping rules
- return value
-
-dialect = MSSQLDialect
-dialect.statement_compiler = MSSQLCompiler
-dialect.schemagenerator = MSSQLSchemaGenerator
-dialect.schemadropper = MSSQLSchemaDropper
-dialect.preparer = MSSQLIdentifierPreparer
-