From d9dfecaea93bab2f4008f594ae7ad2ae85ecb61a Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Tue, 20 Jan 2009 02:35:49 +0000 Subject: [PATCH] - mssql dialects are in place, not fully tested --- lib/sqlalchemy/dialects/mssql/__init__.py | 3 + lib/sqlalchemy/dialects/mssql/adodbapi.py | 50 + lib/sqlalchemy/dialects/mssql/base.py | 1404 ++++++++------------- lib/sqlalchemy/dialects/mssql/pymssql.py | 50 + lib/sqlalchemy/dialects/mssql/pyodbc.py | 59 + lib/sqlalchemy/dialects/mysql/base.py | 1 - lib/sqlalchemy/types.py | 5 +- test/dialect/mssql.py | 46 +- test/sql/testtypes.py | 15 +- 9 files changed, 709 insertions(+), 924 deletions(-) create mode 100644 lib/sqlalchemy/dialects/mssql/adodbapi.py create mode 100644 lib/sqlalchemy/dialects/mssql/pymssql.py create mode 100644 lib/sqlalchemy/dialects/mssql/pyodbc.py diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py index e69de29bb2..a5fabbade1 100644 --- a/lib/sqlalchemy/dialects/mssql/__init__.py +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -0,0 +1,3 @@ +from sqlalchemy.dialects.mssql import base, pyodbc + +base.dialect = pyodbc.dialect \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py new file mode 100644 index 0000000000..9a6cc2779e --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -0,0 +1,50 @@ +from sqlalchemy.dialects.mssql.base import MSDateTime, MSDialect +import sys + +class MSDateTime_adodbapi(MSDateTime): + def result_processor(self, dialect): + def process(value): + # adodbapi will return datetimes with empty time values as datetime.date() objects. + # Promote them back to full datetime.datetime() + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + return process + + +class MSDialect_adodbapi(MSDialect): + supports_sane_rowcount = True + supports_sane_multi_rowcount = True + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = True + driver = 'adodbapi' + + @classmethod + def import_dbapi(cls): + import adodbapi as module + return module + + colspecs = MSSQLDialect.colspecs.copy() + colspecs[sqltypes.DateTime] = MSDateTime_adodbapi + + def create_connect_args(self, url): + keys = url.query + + connectors = ["Provider=SQLOLEDB"] + if 'port' in keys: + connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) + else: + connectors.append ("Data Source=%s" % keys.get("host")) + connectors.append ("Initial Catalog=%s" % keys.get("database")) + user = keys.get("user") + if user: + connectors.append("User Id=%s" % user) + connectors.append("Password=%s" % keys.get("password", "")) + else: + connectors.append("Integrated Security=SSPI") + return [[";".join (connectors)], {}] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) + +dialect = MSDialect_adodbapi \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index c972b6b0ca..0c11adb4ab 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -19,7 +19,7 @@ Drivers are loaded in the order listed above based on availability. If you need to load a specific driver pass ``module_name`` when creating the engine:: - engine = create_engine('mssql://dsn', module_name='pymssql') + engine = create_engine('mssql+module_name://dsn') ``module_name`` currently accepts: ``pyodbc``, ``pymssql``, and ``adodbapi``. @@ -39,18 +39,18 @@ present, then the host token is taken directly as the DSN name. Examples of pyodbc connection string URLs: -* *mssql://mydsn* - connects using the specified DSN named ``mydsn``. +* *mssql+pyodbc://mydsn* - connects using the specified DSN named ``mydsn``. The connection string that is created will appear like:: dsn=mydsn;TrustedConnection=Yes -* *mssql://user:pass@mydsn* - connects using the DSN named +* *mssql+pyodbc://user:pass@mydsn* - connects using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` information. The connection string that is created will appear like:: dsn=mydsn;UID=user;PWD=pass -* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects +* *mssql+pyodbc://user:pass@mydsn/?LANGUAGE=us_english* - connects using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD`` information, plus the additional connection configuration option ``LANGUAGE``. The connection string that is created will appear @@ -58,12 +58,12 @@ Examples of pyodbc connection string URLs: dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english -* *mssql://user:pass@host/db* - connects using a connection string +* *mssql+pyodbc://user:pass@host/db* - connects using a connection string dynamically created that would appear like:: DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass -* *mssql://user:pass@host:123/db* - connects using a connection +* *mssql+pyodbc://user:pass@host:123/db* - connects using a connection string that is dynamically created, which also includes the port information using the comma syntax. If your connection string requires the port information to be passed as a ``port`` keyword @@ -72,7 +72,7 @@ Examples of pyodbc connection string URLs: DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass -* *mssql://user:pass@host/db?port=123* - connects using a connection +* *mssql+pyodbc://user:pass@host/db?port=123* - connects using a connection string that is dynamically created that includes the port information as a separate ``port`` keyword. This will create the following connection string:: @@ -86,7 +86,7 @@ and passed directly. For example:: - mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb + mssql+pyodbc:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb would create the following connection string:: @@ -110,11 +110,6 @@ arguments on the URL, or as keyword argument to * *query_timeout* - allows you to override the default query timeout. Defaults to ``None``. This is only supported on pymssql. -* *text_as_varchar* - if enabled this will treat all TEXT column - types as their equivalent VARCHAR(max) type. This is often used if - you need to compare a VARCHAR to a TEXT field, which is not - supported directly on MSSQL. Defaults to ``False``. - * *use_scope_identity* - allows you to specify that SCOPE_IDENTITY should be used in place of the non-scoped version @@IDENTITY. Defaults to ``False``. On pymssql this defaults to ``True``, and on @@ -252,66 +247,6 @@ from decimal import Decimal as _python_Decimal MSSQL_RESERVED_WORDS = set(['function']) -class _StringType(object): - """Base for MSSQL string types.""" - - def __init__(self, collation=None, **kwargs): - self.collation = kwargs.get('collate', collation) - - def _extend(self, spec): - """Extend a string-type declaration with standard SQL - COLLATE annotations. - """ - - if self.collation: - collation = 'COLLATE %s' % self.collation - else: - collation = None - - return ' '.join([c for c in (spec, collation) - if c is not None]) - - def __repr__(self): - attributes = inspect.getargspec(self.__init__)[0][1:] - attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) - - params = {} - for attr in attributes: - val = getattr(self, attr) - if val is not None and val is not False: - params[attr] = val - - return "%s(%s)" % (self.__class__.__name__, - ', '.join(['%s=%r' % (k, params[k]) for k in params])) - - def bind_processor(self, dialect): - if self.convert_unicode or dialect.convert_unicode: - if self.assert_unicode is None: - assert_unicode = dialect.assert_unicode - else: - assert_unicode = self.assert_unicode - - if not assert_unicode: - return None - - def process(value): - if not isinstance(value, (unicode, sqltypes.NoneType)): - if assert_unicode == 'warn': - util.warn("Unicode type received non-unicode bind " - "param value %r" % value) - return value - else: - raise exc.InvalidRequestError("Unicode type received non-unicode bind param value %r" % value) - else: - return value - return process - else: - return None - - def result_processor(self, dialect): - return None - - class MSNumeric(sqltypes.Numeric): def result_processor(self, dialect): if self.asdecimal: @@ -345,121 +280,49 @@ class MSNumeric(sqltypes.Numeric): return process - def get_col_spec(self): - if self.precision is None: - return "NUMERIC" - else: - return "NUMERIC(%(precision)s, %(scale)s)" % {'precision': self.precision, 'scale' : self.scale} - - -class MSFloat(sqltypes.Float): - def get_col_spec(self): - if self.precision is None: - return "FLOAT" - else: - return "FLOAT(%(precision)s)" % {'precision': self.precision} - - -class MSReal(MSFloat): +class MSReal(sqltypes.Float): """A type for ``real`` numbers.""" - def __init__(self): - """ - Construct a Real. + __visit_name__ = 'REAL' - """ + def __init__(self): super(MSReal, self).__init__(precision=24) - def adapt(self, impltype): - return impltype() - - def get_col_spec(self): - return "REAL" - - -class MSInteger(sqltypes.Integer): - def get_col_spec(self): - return "INTEGER" - - -class MSBigInteger(MSInteger): - def get_col_spec(self): - return "BIGINT" - - -class MSTinyInteger(MSInteger): - def get_col_spec(self): - return "TINYINT" - - -class MSSmallInteger(MSInteger): - def get_col_spec(self): - return "SMALLINT" +class MSTinyInteger(sqltypes.Integer): + __visit_name__ = 'TINYINT' +class MSTime(sqltypes.Time): + def __init__(self, precision=None, **kwargs): + self.precision = precision + super(MSTime, self).__init__() -class _DateTimeType(object): - """Base for MSSQL datetime types.""" +class MSDateTime(sqltypes.DateTime): def bind_processor(self, dialect): - # if we receive just a date we can manipulate it - # into a datetime since the db-api may not do this. + # most DBAPIs allow a datetime.date object + # as a datetime. def process(value): if type(value) is datetime.date: return datetime.datetime(value.year, value.month, value.day) return value return process + +class MSSmallDateTime(MSDateTime): + __visit_name__ = 'SMALLDATETIME' - -class MSDateTime(_DateTimeType, sqltypes.DateTime): - def get_col_spec(self): - return "DATETIME" - - -class MSDate(sqltypes.Date): - def get_col_spec(self): - return "DATE" - - -class MSTime(sqltypes.Time): - def __init__(self, precision=None, **kwargs): - self.precision = precision - super(MSTime, self).__init__() - - def get_col_spec(self): - if self.precision: - return "TIME(%s)" % self.precision - else: - return "TIME" - - -class MSSmallDateTime(_DateTimeType, sqltypes.TypeEngine): - def get_col_spec(self): - return "SMALLDATETIME" - - -class MSDateTime2(_DateTimeType, sqltypes.TypeEngine): +class MSDateTime2(MSDateTime): + __visit_name__ = 'DATETIME2' + def __init__(self, precision=None, **kwargs): self.precision = precision - def get_col_spec(self): - if self.precision: - return "DATETIME2(%s)" % self.precision - else: - return "DATETIME2" - - -class MSDateTimeOffset(_DateTimeType, sqltypes.TypeEngine): +class MSDateTimeOffset(sqltypes.TypeEngine): + __visit_name__ = 'DATETIMEOFFSET' + def __init__(self, precision=None, **kwargs): self.precision = precision - def get_col_spec(self): - if self.precision: - return "DATETIMEOFFSET(%s)" % self.precision - else: - return "DATETIMEOFFSET" - - -class MSDateTimeAsDate(_DateTimeType, MSDate): +class MSDateTimeAsDate(sqltypes.TypeDecorator): """ This is an implementation of the Date type for versions of MSSQL that do not support that specific type. In order to make it work a ``DATETIME`` column specification is used and the results get converted back to just @@ -467,20 +330,19 @@ class MSDateTimeAsDate(_DateTimeType, MSDate): """ - def get_col_spec(self): - return "DATETIME" + impl = sqltypes.DateTime - def result_processor(self, dialect): - def process(value): - # If the DBAPI returns the value as datetime.datetime(), truncate - # it back to datetime.date() - if type(value) is datetime.datetime: - return value.date() - return value - return process + def process_bind_param(self, value, dialect): + if type(value) is datetime.date: + return datetime.datetime(value.year, value.month, value.day) + return value + def process_result_value(self, value, dialect): + if type(value) is datetime.datetime: + return value.date() + return value -class MSDateTimeAsTime(MSTime): +class MSDateTimeAsTime(sqltypes.TypeDecorator): """ This is an implementation of the Time type for versions of MSSQL that do not support that specific type. In order to make it work a ``DATETIME`` column specification is used and the results get converted back to just @@ -490,65 +352,63 @@ class MSDateTimeAsTime(MSTime): __zero_date = datetime.date(1900, 1, 1) - def get_col_spec(self): - return "DATETIME" + impl = sqltypes.DateTime - def bind_processor(self, dialect): - def process(value): - if type(value) is datetime.datetime: - value = datetime.datetime.combine(self.__zero_date, value.time()) - elif type(value) is datetime.time: - value = datetime.datetime.combine(self.__zero_date, value) - return value - return process + def process_bind_param(self, value, dialect): + if type(value) is datetime.datetime: + value = datetime.datetime.combine(self.__zero_date, value.time()) + elif type(value) is datetime.time: + value = datetime.datetime.combine(self.__zero_date, value) + return value - def result_processor(self, dialect): - def process(value): - if type(value) is datetime.datetime: - return value.time() - elif type(value) is datetime.date: - return datetime.time(0, 0, 0) - return value - return process + def process_result_value(self, value, dialect): + if type(value) is datetime.datetime: + return value.time() + elif type(value) is datetime.date: + return datetime.time(0, 0, 0) + return value -class MSDateTime_adodbapi(MSDateTime): - def result_processor(self, dialect): - def process(value): - # adodbapi will return datetimes with empty time values as datetime.date() objects. - # Promote them back to full datetime.datetime() - if type(value) is datetime.date: - return datetime.datetime(value.year, value.month, value.day) - return value - return process +class _StringType(object): + """Base for MSSQL string types.""" + + def __init__(self, collation=None): + self.collation = collation + + def __repr__(self): + attributes = inspect.getargspec(self.__init__)[0][1:] + attributes.extend(inspect.getargspec(_StringType.__init__)[0][1:]) + + params = {} + for attr in attributes: + val = getattr(self, attr) + if val is not None and val is not False: + params[attr] = val + + return "%s(%s)" % (self.__class__.__name__, + ', '.join(['%s=%r' % (k, params[k]) for k in params])) -class MSText(_StringType, sqltypes.Text): +class MSText(_StringType, sqltypes.TEXT): """MSSQL TEXT type, for variable-length text up to 2^31 characters.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kw): """Construct a TEXT. :param collation: Optional, a column-level collation for this string value. Accepts a Windows Collation Name or a SQL Collation Name. """ - _StringType.__init__(self, **kwargs) - sqltypes.Text.__init__(self, None, - convert_unicode=kwargs.get('convert_unicode', False), - assert_unicode=kwargs.get('assert_unicode', None)) - - def get_col_spec(self): - if self.dialect.text_as_varchar: - return self._extend("VARCHAR(max)") - else: - return self._extend("TEXT") - + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.Text.__init__(self, *args, **kw) class MSNText(_StringType, sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" + __visit_name__ = 'NTEXT' + def __init__(self, *args, **kwargs): """Construct a NTEXT. @@ -556,23 +416,16 @@ class MSNText(_StringType, sqltypes.UnicodeText): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - _StringType.__init__(self, **kwargs) - sqltypes.UnicodeText.__init__(self, None, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def get_col_spec(self): - if self.dialect.text_as_varchar: - return self._extend("NVARCHAR(max)") - else: - return self._extend("NTEXT") + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.UnicodeText.__init__(self, None, **kw) -class MSString(_StringType, sqltypes.String): +class MSString(_StringType, sqltypes.VARCHAR): """MSSQL VARCHAR type, for variable-length non-Unicode data with a maximum of 8,000 characters.""" - def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs): + def __init__(self, *args, **kw): """Construct a VARCHAR. :param length: Optinal, maximum data length, in characters. @@ -603,24 +456,16 @@ class MSString(_StringType, sqltypes.String): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - _StringType.__init__(self, **kwargs) - sqltypes.String.__init__(self, length=length, - convert_unicode=convert_unicode, - assert_unicode=assert_unicode) - - def get_col_spec(self): - if self.length: - return self._extend("VARCHAR(%s)" % self.length) - else: - return self._extend("VARCHAR") - + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.VARCHAR.__init__(self, *args, **kw) -class MSNVarchar(_StringType, sqltypes.Unicode): +class MSNVarchar(_StringType, sqltypes.NVARCHAR): """MSSQL NVARCHAR type. For variable-length unicode character data up to 4,000 characters.""" - def __init__(self, length=None, **kwargs): + def __init__(self, *args, **kw): """Construct a NVARCHAR. :param length: Optional, Maximum data length, in characters. @@ -629,29 +474,16 @@ class MSNVarchar(_StringType, sqltypes.Unicode): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - _StringType.__init__(self, **kwargs) - sqltypes.Unicode.__init__(self, length=length, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def adapt(self, impltype): - return impltype(length=self.length, - convert_unicode=self.convert_unicode, - assert_unicode=self.assert_unicode, - collation=self.collation) - - def get_col_spec(self): - if self.length: - return self._extend("NVARCHAR(%(length)s)" % {'length' : self.length}) - else: - return self._extend("NVARCHAR") + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NVARCHAR.__init__(self, *args, **kw) class MSChar(_StringType, sqltypes.CHAR): """MSSQL CHAR type, for fixed-length non-Unicode data with a maximum of 8,000 characters.""" - def __init__(self, length=None, convert_unicode=False, assert_unicode=None, **kwargs): + def __init__(self, *args, **kw): """Construct a CHAR. :param length: Optinal, maximum data length, in characters. @@ -682,16 +514,9 @@ class MSChar(_StringType, sqltypes.CHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - _StringType.__init__(self, **kwargs) - sqltypes.CHAR.__init__(self, length=length, - convert_unicode=convert_unicode, - assert_unicode=assert_unicode) - - def get_col_spec(self): - if self.length: - return self._extend("CHAR(%s)" % self.length) - else: - return self._extend("CHAR") + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.CHAR.__init__(self, *args, **kw) class MSNChar(_StringType, sqltypes.NCHAR): @@ -699,7 +524,7 @@ class MSNChar(_StringType, sqltypes.NCHAR): For fixed-length unicode character data up to 4,000 characters.""" - def __init__(self, length=None, **kwargs): + def __init__(self, *args, **kw): """Construct an NCHAR. :param length: Optional, Maximum data length, in characters. @@ -708,59 +533,23 @@ class MSNChar(_StringType, sqltypes.NCHAR): value. Accepts a Windows Collation Name or a SQL Collation Name. """ - _StringType.__init__(self, **kwargs) - sqltypes.NCHAR.__init__(self, length=length, - convert_unicode=kwargs.get('convert_unicode', True), - assert_unicode=kwargs.get('assert_unicode', 'warn')) - - def get_col_spec(self): - if self.length: - return self._extend("NCHAR(%(length)s)" % {'length' : self.length}) - else: - return self._extend("NCHAR") - - -class MSGenericBinary(sqltypes.Binary): - """The Binary type assumes that a Binary specification without a length - is an unbound Binary type whereas one with a length specification results - in a fixed length Binary type. - - If you want standard MSSQL ``BINARY`` behavior use the ``MSBinary`` type. - - """ - - def get_col_spec(self): - if self.length: - return "BINARY(%s)" % self.length - else: - return "IMAGE" - - -class MSBinary(MSGenericBinary): - def get_col_spec(self): - if self.length: - return "BINARY(%s)" % self.length - else: - return "BINARY" + collation = kw.pop('collation', None) + _StringType.__init__(self, collation) + sqltypes.NCHAR.__init__(self, *args, **kw) +class MSBinary(sqltypes.Binary): + pass -class MSVarBinary(MSGenericBinary): - def get_col_spec(self): - if self.length: - return "VARBINARY(%s)" % self.length - else: - return "VARBINARY" - - -class MSImage(MSGenericBinary): - def get_col_spec(self): - return "IMAGE" +class MSVarBinary(sqltypes.Binary): + __visit_name__ = 'VARBINARY' +class MSImage(sqltypes.Binary): + __visit_name__ = 'IMAGE' +class MSBit(sqltypes.TypeEngine): + __visit_name__ = 'BIT' + class MSBoolean(sqltypes.Boolean): - def get_col_spec(self): - return "BIT" - def result_processor(self, dialect): def process(value): if value is None: @@ -780,31 +569,129 @@ class MSBoolean(sqltypes.Boolean): return value and True or False return process +class MSMoney(sqltypes.TypeEngine): + __visit_name__ = 'MONEY' + +class MSSmallMoney(MSMoney): + __visit_name__ = 'SMALLMONEY' -class MSTimeStamp(sqltypes.TIMESTAMP): - def get_col_spec(self): - return "TIMESTAMP" +class MSUniqueIdentifier(sqltypes.TypeEngine): + __visit_name__ = "UNIQUEIDENTIFIER" -class MSMoney(sqltypes.TypeEngine): - def get_col_spec(self): - return "MONEY" +class MSVariant(sqltypes.TypeEngine): + __visit_name__ = 'SQL_VARIANT' +class MSTypeCompiler(compiler.GenericTypeCompiler): + def _extend(self, spec, type_): + """Extend a string-type declaration with standard SQL + COLLATE annotations. -class MSSmallMoney(MSMoney): - def get_col_spec(self): - return "SMALLMONEY" + """ + if getattr(type_, 'collation', None): + collation = 'COLLATE %s' % type_.collation + else: + collation = None -class MSUniqueIdentifier(sqltypes.TypeEngine): - def get_col_spec(self): - return "UNIQUEIDENTIFIER" + if type_.length: + spec = spec + "(%d)" % type_.length + + return ' '.join([c for c in (spec, collation) + if c is not None]) + def visit_FLOAT(self, type_): + precision = getattr(type_, 'precision', None) + if precision is None: + return "FLOAT" + else: + return "FLOAT(%(precision)s)" % {'precision': precision} -class MSVariant(sqltypes.TypeEngine): - def get_col_spec(self): - return "SQL_VARIANT" + def visit_REAL(self, type_): + return "REAL" + + def visit_TINYINT(self, type_): + return "TINYINT" + + def visit_DATETIMEOFFSET(self, type_): + if type_.precision: + return "DATETIMEOFFSET(%s)" % type_.precision + else: + return "DATETIMEOFFSET" + + def visit_TIME(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "TIME(%s)" % precision + else: + return "TIME" + + def visit_DATETIME2(self, type_): + precision = getattr(type_, 'precision', None) + if precision: + return "DATETIME2(%s)" % precision + else: + return "DATETIME2" + + def visit_SMALLDATETIME(self, type_): + return "SMALLDATETIME" + + def visit_NTEXT(self, type_): + return self._extend("NTEXT", type_) + + def visit_TEXT(self, type_): + return self._extend("TEXT", type_) + + def visit_VARCHAR(self, type_): + return self._extend("VARCHAR", type_) + + def visit_CHAR(self, type_): + return self._extend("CHAR", type_) + + def visit_NCHAR(self, type_): + return self._extend("NCHAR", type_) + + def visit_NVARCHAR(self, type_): + return self._extend("NVARCHAR", type_) + + def visit_binary(self, type_): + if type_.length: + return self.visit_BINARY(type_) + else: + return self.visit_IMAGE(type_) + + def visit_BINARY(self, type_): + if type_.length: + return "BINARY(%s)" % type_.length + else: + return "BINARY" + + def visit_IMAGE(self, type_): + return "IMAGE" + + def visit_VARBINARY(self, type_): + if type_.length: + return "VARBINARY(%s)" % type_.length + else: + return "VARBINARY" + + def visit_boolean(self, type_): + return self.visit_BIT(type_) + + def visit_BIT(self, type_): + return "BIT" + + def visit_MONEY(self, type_): + return "MONEY" + + def visit_SMALLMONEY(self, type_): + return 'SMALLMONEY' + + def visit_UNIQUEIDENTIFIER(self, type_): + return "UNIQUEIDENTIFIER" + def visit_SQL_VARIANT(self, type_): + return 'SQL_VARIANT' def _has_implicit_sequence(column): return column.primary_key and \ @@ -827,7 +714,7 @@ def _table_sequence_column(tbl): break return tbl._ms_has_sequence -class MSSQLExecutionContext(default.DefaultExecutionContext): +class MSExecutionContext(default.DefaultExecutionContext): IINSERT = False HASIDENT = False @@ -869,136 +756,306 @@ class MSSQLExecutionContext(default.DefaultExecutionContext): if self.IINSERT: self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table)) +colspecs = { + sqltypes.Unicode : MSNVarchar, + sqltypes.Numeric : MSNumeric, + sqltypes.DateTime : MSDateTime, + sqltypes.Time : MSTime, + sqltypes.String : MSString, + sqltypes.Boolean : MSBoolean, + sqltypes.Text : MSText, + sqltypes.UnicodeText : MSNText, + sqltypes.CHAR: MSChar, + sqltypes.NCHAR: MSNChar, +} + +ischema_names = { + 'int' : sqltypes.INTEGER, + 'bigint': sqltypes.BigInteger, + 'smallint' : sqltypes.SmallInteger, + 'tinyint' : MSTinyInteger, + 'varchar' : MSString, + 'nvarchar' : MSNVarchar, + 'char' : MSChar, + 'nchar' : MSNChar, + 'text' : MSText, + 'ntext' : MSNText, + 'decimal' : sqltypes.DECIMAL, + 'numeric' : sqltypes.NUMERIC, + 'float' : sqltypes.FLOAT, + 'datetime' : sqltypes.DateTime, + 'datetime2' : MSDateTime2, + 'datetimeoffset' : MSDateTimeOffset, + 'date': sqltypes.Date, + 'time': MSTime, + 'smalldatetime' : MSSmallDateTime, + 'binary' : MSBinary, + 'varbinary' : MSVarBinary, + 'bit': sqltypes.Boolean, + 'real' : MSReal, + 'image' : MSImage, + 'timestamp': sqltypes.TIMESTAMP, + 'money': MSMoney, + 'smallmoney': MSSmallMoney, + 'uniqueidentifier': MSUniqueIdentifier, + 'sql_variant': MSVariant, +} -class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext): - def pre_exec(self): - """where appropriate, issue "select scope_identity()" in the same statement""" - super(MSSQLExecutionContext_pyodbc, self).pre_exec() - if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \ - and len(self.parameters) == 1 and self.dialect.use_scope_identity: - self.statement += "; select scope_identity()" - - def post_exec(self): - if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany: - import pyodbc - # Fetch the last inserted id from the manipulated statement - # We may have to skip over a number of result sets with no data (due to triggers, etc.) - while True: - try: - row = self.cursor.fetchone() - break - except pyodbc.Error, e: - self.cursor.nextset() - self._last_inserted_ids = [int(row[0])] - else: - super(MSSQLExecutionContext_pyodbc, self).post_exec() +class MSSQLCompiler(compiler.SQLCompiler): + operators = compiler.OPERATORS.copy() + operators.update({ + sql_operators.concat_op: '+', + sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) + }) -class MSSQLDialect(default.DefaultDialect): - name = 'mssql' - supports_default_values = True - supports_empty_insert = False - auto_identity_insert = True - execution_ctx_cls = MSSQLExecutionContext - text_as_varchar = False - use_scope_identity = False - has_window_funcs = False - max_identifier_length = 128 - schema_name = "dbo" + functions = compiler.SQLCompiler.functions.copy() + functions.update ( + { + sql_functions.now: 'CURRENT_TIMESTAMP', + sql_functions.current_date: 'GETDATE()', + 'length': lambda x: "LEN(%s)" % x, + sql_functions.char_length: lambda x: "LEN(%s)" % x + } + ) - colspecs = { - sqltypes.Unicode : MSNVarchar, - sqltypes.Integer : MSInteger, - sqltypes.SmallInteger: MSSmallInteger, - sqltypes.Numeric : MSNumeric, - sqltypes.Float : MSFloat, - sqltypes.DateTime : MSDateTime, - sqltypes.Date : MSDate, - sqltypes.Time : MSTime, - sqltypes.String : MSString, - sqltypes.Binary : MSGenericBinary, - sqltypes.Boolean : MSBoolean, - sqltypes.Text : MSText, - sqltypes.UnicodeText : MSNText, - sqltypes.CHAR: MSChar, - sqltypes.NCHAR: MSNChar, - sqltypes.TIMESTAMP: MSTimeStamp, - } - - ischema_names = { - 'int' : MSInteger, - 'bigint': MSBigInteger, - 'smallint' : MSSmallInteger, - 'tinyint' : MSTinyInteger, - 'varchar' : MSString, - 'nvarchar' : MSNVarchar, - 'char' : MSChar, - 'nchar' : MSNChar, - 'text' : MSText, - 'ntext' : MSNText, - 'decimal' : MSNumeric, - 'numeric' : MSNumeric, - 'float' : MSFloat, - 'datetime' : MSDateTime, - 'datetime2' : MSDateTime2, - 'datetimeoffset' : MSDateTimeOffset, - 'date': MSDate, - 'time': MSTime, - 'smalldatetime' : MSSmallDateTime, - 'binary' : MSBinary, - 'varbinary' : MSVarBinary, - 'bit': MSBoolean, - 'real' : MSFloat, - 'image' : MSImage, - 'timestamp': MSTimeStamp, - 'money': MSMoney, - 'smallmoney': MSSmallMoney, - 'uniqueidentifier': MSUniqueIdentifier, - 'sql_variant': MSVariant, - } - - def __new__(cls, *args, **kwargs): - if cls is not MSSQLDialect: - # this gets called with the dialect specific class - return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs) - dbapi = kwargs.get('dbapi', None) - if dbapi: - dialect = dialect_mapping.get(dbapi.__name__) - return dialect(**kwargs) - else: - return object.__new__(cls, *args, **kwargs) + def __init__(self, *args, **kwargs): + super(MSSQLCompiler, self).__init__(*args, **kwargs) + self.tablealiases = {} - def __init__(self, - auto_identity_insert=True, query_timeout=None, - text_as_varchar=False, use_scope_identity=False, - has_window_funcs=False, max_identifier_length=None, - schema_name="dbo", **opts): - self.auto_identity_insert = bool(auto_identity_insert) - self.query_timeout = int(query_timeout or 0) - self.schema_name = schema_name + def get_select_precolumns(self, select): + """ MS-SQL puts TOP, it's version of LIMIT here """ + if select._distinct or select._limit: + s = select._distinct and "DISTINCT " or "" + + if select._limit: + if not select._offset: + s += "TOP %s " % (select._limit,) + else: + if not self.dialect.has_window_funcs: + raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') + return s + return compiler.SQLCompiler.get_select_precolumns(self, select) - # to-do: the options below should use server version introspection to set themselves on connection - self.text_as_varchar = bool(text_as_varchar) - self.use_scope_identity = bool(use_scope_identity) - self.has_window_funcs = bool(has_window_funcs) - self.max_identifier_length = int(max_identifier_length or 0) or 128 - super(MSSQLDialect, self).__init__(**opts) + def limit_clause(self, select): + # Limit in mssql is after the select keyword + return "" - @classmethod - def dbapi(cls, module_name=None): - if module_name: - try: - dialect_cls = dialect_mapping[module_name] - return dialect_cls.import_dbapi() - except KeyError: - raise exc.InvalidRequestError("Unsupported MSSQL module '%s' requested (must be adodbpi, pymssql or pyodbc)" % module_name) + def visit_select(self, select, **kwargs): + """Look for ``LIMIT`` and OFFSET in a select statement, and if + so tries to wrap it in a subquery with ``row_number()`` criterion. + + """ + if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset: + # to use ROW_NUMBER(), an ORDER BY is required. + orderby = self.process(select._order_by_clause) + if not orderby: + raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') + + _offset = select._offset + _limit = select._limit + select._mssql_visit = True + select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() + + limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) + limitselect.append_whereclause("mssql_rn>%d" % _offset) + if _limit is not None: + limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) + return self.process(limitselect, iswrapper=True, **kwargs) + else: + return compiler.SQLCompiler.visit_select(self, select, **kwargs) + + def _schema_aliased_table(self, table): + if getattr(table, 'schema', None) is not None: + if table not in self.tablealiases: + self.tablealiases[table] = table.alias() + return self.tablealiases[table] + else: + return None + + def visit_table(self, table, mssql_aliased=False, **kwargs): + if mssql_aliased: + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + # alias schema-qualified tables + alias = self._schema_aliased_table(table) + if alias is not None: + return self.process(alias, mssql_aliased=True, **kwargs) else: - for dialect_cls in [MSSQLDialect_pyodbc, MSSQLDialect_pymssql, MSSQLDialect_adodbapi]: - try: - return dialect_cls.import_dbapi() - except ImportError, e: - pass + return super(MSSQLCompiler, self).visit_table(table, **kwargs) + + def visit_alias(self, alias, **kwargs): + # translate for schema-qualified table aliases + self.tablealiases[alias.original] = alias + kwargs['mssql_aliased'] = True + return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) + + def visit_savepoint(self, savepoint_stmt): + util.warn("Savepoint support in mssql is experimental and may lead to data loss.") + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_rollback_to_savepoint(self, savepoint_stmt): + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) + + def visit_column(self, column, result_map=None, **kwargs): + if column.table is not None and \ + (not self.isupdate and not self.isdelete) or self.is_subquery(): + # translate for schema-qualified table aliases + t = self._schema_aliased_table(column.table) + if t is not None: + converted = expression._corresponding_column_or_error(t, column) + + if result_map is not None: + result_map[column.name.lower()] = (column.name, (column, ), column.type) + + return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) + + return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) + + def visit_binary(self, binary, **kwargs): + """Move bind parameters to the right-hand side of an operator, where + possible. + + """ + if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ + and not isinstance(binary.right, expression._BindParamClause): + return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) + else: + if (binary.operator is operator.eq or binary.operator is operator.ne) and ( + (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ + (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ + isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): + op = binary.operator == operator.eq and "IN" or "NOT IN" + return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) + return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) + + def visit_insert(self, insert_stmt): + insert_select = False + if insert_stmt.parameters: + insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)] + if insert_select: + self.isinsert = True + colparams = self._get_colparams(insert_stmt) + preparer = self.preparer + + insert = ' '.join(["INSERT"] + + [self.process(x) for x in insert_stmt._prefixes]) + + if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert: + raise exc.CompileError( + "The version of %s you are using does not support empty inserts." % self.dialect.name) + elif not colparams and self.dialect.supports_default_values: + return (insert + " INTO %s DEFAULT VALUES" % ( + (preparer.format_table(insert_stmt.table),))) else: - raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi') + return (insert + " INTO %s (%s) SELECT %s" % + (preparer.format_table(insert_stmt.table), + ', '.join([preparer.format_column(c[0]) + for c in colparams]), + ', '.join([c[1] for c in colparams]))) + else: + return super(MSSQLCompiler, self).visit_insert(insert_stmt) + + def label_select_column(self, select, column, asfrom): + if isinstance(column, expression.Function): + return column.label(None) + else: + return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) + + def for_update_clause(self, select): + # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use + return '' + + def order_by_clause(self, select): + order_by = self.process(select._order_by_clause) + + # MSSQL only allows ORDER BY in subqueries if there is a LIMIT + if order_by and (not self.is_subquery() or select._limit): + return " ORDER BY " + order_by + else: + return "" + + +class MSDDLCompiler(compiler.DDLCompiler): + def get_column_specification(self, column, **kwargs): + colspec = self.preparer.format_column(column) + " " + self.dialect.type_compiler.process(column.type) + + if column.nullable is not None: + if not column.nullable or column.primary_key: + colspec += " NOT NULL" + else: + colspec += " NULL" + + if not column.table: + raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") + + seq_col = _table_sequence_column(column.table) + + # install a IDENTITY Sequence if we have an implicit IDENTITY column + if seq_col is column: + sequence = getattr(column, 'sequence', None) + if sequence: + start, increment = sequence.start or 1, sequence.increment or 1 + else: + start, increment = 1, 1 + colspec += " IDENTITY(%s,%s)" % (start, increment) + else: + default = self.get_column_default_string(column) + if default is not None: + colspec += " DEFAULT " + default + + return colspec + + def visit_drop_index(self, drop): + return "\nDROP INDEX %s.%s" % ( + self.preparer.quote_identifier(drop.element.table.name), + self.preparer.quote(self._validate_identifier(drop.element.name, False), drop.element.quote) + ) + + +class MSIdentifierPreparer(compiler.IdentifierPreparer): + reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS) + + def __init__(self, dialect): + super(MSIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') + + def _escape_identifier(self, value): + #TODO: determine MSSQL's escaping rules + return value + +class MSDialect(default.DefaultDialect): + name = 'mssql' + supports_default_values = True + supports_empty_insert = False + auto_identity_insert = True + execution_ctx_cls = MSExecutionContext + text_as_varchar = False + use_scope_identity = False + has_window_funcs = False + max_identifier_length = 128 + schema_name = "dbo" + colspecs = colspecs + ischema_names = ischema_names + + supports_unicode_binds = True + + statement_compiler = MSSQLCompiler + ddl_compiler = MSDDLCompiler + type_compiler = MSTypeCompiler + preparer = MSIdentifierPreparer + + def __init__(self, + auto_identity_insert=True, query_timeout=None, + use_scope_identity=False, + has_window_funcs=False, max_identifier_length=None, + schema_name="dbo", **opts): + self.auto_identity_insert = bool(auto_identity_insert) + self.query_timeout = int(query_timeout or 0) + self.schema_name = schema_name + + self.use_scope_identity = bool(use_scope_identity) + self.has_window_funcs = bool(has_window_funcs) + self.max_identifier_length = int(max_identifier_length or 0) or 128 + super(MSDialect, self).__init__(**opts) @base.connection_memoize(('mssql', 'server_version_info')) def server_version_info(self, connection): @@ -1018,28 +1075,6 @@ class MSSQLDialect(default.DefaultDialect): """Return a tuple of the database's version number.""" raise NotImplementedError() - def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - opts.update(url.query) - if 'auto_identity_insert' in opts: - self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert'))) - if 'query_timeout' in opts: - self.query_timeout = int(opts.pop('query_timeout')) - if 'text_as_varchar' in opts: - self.text_as_varchar = bool(int(opts.pop('text_as_varchar'))) - if 'use_scope_identity' in opts: - self.use_scope_identity = bool(int(opts.pop('use_scope_identity'))) - if 'has_window_funcs' in opts: - self.has_window_funcs = bool(int(opts.pop('has_window_funcs'))) - return self.make_connect_string(opts, url.query) - - def type_descriptor(self, typeobj): - newobj = sqltypes.adapt_type(typeobj, self.colspecs) - # Some types need to know about the dialect - if isinstance(newobj, (MSText, MSNText)): - newobj.dialect = self - return newobj - def do_begin(self, connection): cursor = connection.cursor() cursor.execute("SET IMPLICIT_TRANSACTIONS OFF") @@ -1248,414 +1283,3 @@ class MSSQLDialect(default.DefaultDialect): if fknm and scols: table.append_constraint(schema.ForeignKeyConstraint(scols, [_gen_fkref(table, s, t, c) for s, t, c in rcols], fknm, link_to_name=True)) - -class MSSQLDialect_pymssql(MSSQLDialect): - supports_sane_rowcount = False - max_identifier_length = 30 - - @classmethod - def import_dbapi(cls): - import pymssql as module - # pymmsql doesn't have a Binary method. we use string - # TODO: monkeypatching here is less than ideal - module.Binary = lambda st: str(st) - return module - - def __init__(self, **params): - super(MSSQLDialect_pymssql, self).__init__(**params) - self.use_scope_identity = True - - # pymssql understands only ascii - if self.convert_unicode: - util.warn("pymssql does not support unicode") - self.encoding = params.get('encoding', 'ascii') - - self.colspecs = MSSQLDialect.colspecs.copy() - self.ischema_names = MSSQLDialect.ischema_names.copy() - self.ischema_names['date'] = MSDateTimeAsDate - self.colspecs[sqltypes.Date] = MSDateTimeAsDate - self.ischema_names['time'] = MSDateTimeAsTime - self.colspecs[sqltypes.Time] = MSDateTimeAsTime - - def create_connect_args(self, url): - r = super(MSSQLDialect_pymssql, self).create_connect_args(url) - if hasattr(self, 'query_timeout'): - self.dbapi._mssql.set_query_timeout(self.query_timeout) - return r - - def make_connect_string(self, keys, query): - if keys.get('port'): - # pymssql expects port as host:port, not a separate arg - keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) - del keys['port'] - return [[], keys] - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) - - def do_begin(self, connection): - pass - - -class MSSQLDialect_pyodbc(MSSQLDialect): - supports_sane_rowcount = False - supports_sane_multi_rowcount = False - # PyODBC unicode is broken on UCS-4 builds - supports_unicode = sys.maxunicode == 65535 - supports_unicode_statements = supports_unicode - execution_ctx_cls = MSSQLExecutionContext_pyodbc - - def __init__(self, description_encoding='latin-1', **params): - super(MSSQLDialect_pyodbc, self).__init__(**params) - self.description_encoding = description_encoding - - if self.server_version_info < (10,): - self.colspecs = MSSQLDialect.colspecs.copy() - self.ischema_names = MSSQLDialect.ischema_names.copy() - self.ischema_names['date'] = MSDateTimeAsDate - self.colspecs[sqltypes.Date] = MSDateTimeAsDate - self.ischema_names['time'] = MSDateTimeAsTime - self.colspecs[sqltypes.Time] = MSDateTimeAsTime - - # FIXME: scope_identity sniff should look at server version, not the ODBC driver - # whether use_scope_identity will work depends on the version of pyodbc - try: - import pyodbc - self.use_scope_identity = hasattr(pyodbc.Cursor, 'nextset') - except: - pass - - @classmethod - def import_dbapi(cls): - import pyodbc as module - return module - - def make_connect_string(self, keys, query): - if 'max_identifier_length' in keys: - self.max_identifier_length = int(keys.pop('max_identifier_length')) - - if 'odbc_connect' in keys: - connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))] - else: - dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys) - if dsn_connection: - connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))] - else: - port = '' - if 'port' in keys and not 'port' in query: - port = ',%d' % int(keys.pop('port')) - - connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'), - 'Server=%s%s' % (keys.pop('host', ''), port), - 'Database=%s' % keys.pop('database', '') ] - - user = keys.pop("user", None) - if user: - connectors.append("UID=%s" % user) - connectors.append("PWD=%s" % keys.pop('password', '')) - else: - connectors.append("TrustedConnection=Yes") - - # if set to 'Yes', the ODBC layer will try to automagically convert - # textual data from your database encoding to your client encoding - # This should obviously be set to 'No' if you query a cp1253 encoded - # database from a latin1 client... - if 'odbc_autotranslate' in keys: - connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate")) - - connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()]) - - return [[";".join (connectors)], {}] - - def is_disconnect(self, e): - if isinstance(e, self.dbapi.ProgrammingError): - return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) - elif isinstance(e, self.dbapi.Error): - return '[08S01]' in str(e) - else: - return False - - - def _server_version_info(self, dbapi_con): - """Convert a pyodbc SQL_DBMS_VER string into a tuple.""" - version = [] - r = re.compile('[.\-]') - for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)): - try: - version.append(int(n)) - except ValueError: - version.append(n) - return tuple(version) - -class MSSQLDialect_adodbapi(MSSQLDialect): - supports_sane_rowcount = True - supports_sane_multi_rowcount = True - supports_unicode = sys.maxunicode == 65535 - supports_unicode_statements = True - - @classmethod - def import_dbapi(cls): - import adodbapi as module - return module - - colspecs = MSSQLDialect.colspecs.copy() - colspecs[sqltypes.DateTime] = MSDateTime_adodbapi - - ischema_names = MSSQLDialect.ischema_names.copy() - ischema_names['datetime'] = MSDateTime_adodbapi - - def make_connect_string(self, keys, query): - connectors = ["Provider=SQLOLEDB"] - if 'port' in keys: - connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port"))) - else: - connectors.append ("Data Source=%s" % keys.get("host")) - connectors.append ("Initial Catalog=%s" % keys.get("database")) - user = keys.get("user") - if user: - connectors.append("User Id=%s" % user) - connectors.append("Password=%s" % keys.get("password", "")) - else: - connectors.append("Integrated Security=SSPI") - return [[";".join (connectors)], {}] - - def is_disconnect(self, e): - return isinstance(e, self.dbapi.adodbapi.DatabaseError) and "'connection failure'" in str(e) - - -dialect_mapping = { - 'pymssql': MSSQLDialect_pymssql, - 'pyodbc': MSSQLDialect_pyodbc, - 'adodbapi': MSSQLDialect_adodbapi - } - - -class MSSQLCompiler(compiler.SQLCompiler): - operators = compiler.OPERATORS.copy() - operators.update({ - sql_operators.concat_op: '+', - sql_operators.match_op: lambda x, y: "CONTAINS (%s, %s)" % (x, y) - }) - - functions = compiler.SQLCompiler.functions.copy() - functions.update ( - { - sql_functions.now: 'CURRENT_TIMESTAMP', - sql_functions.current_date: 'GETDATE()', - 'length': lambda x: "LEN(%s)" % x, - sql_functions.char_length: lambda x: "LEN(%s)" % x - } - ) - - def __init__(self, *args, **kwargs): - super(MSSQLCompiler, self).__init__(*args, **kwargs) - self.tablealiases = {} - - def get_select_precolumns(self, select): - """ MS-SQL puts TOP, it's version of LIMIT here """ - if select._distinct or select._limit: - s = select._distinct and "DISTINCT " or "" - - if select._limit: - if not select._offset: - s += "TOP %s " % (select._limit,) - else: - if not self.dialect.has_window_funcs: - raise exc.InvalidRequestError('MSSQL does not support LIMIT with an offset') - return s - return compiler.SQLCompiler.get_select_precolumns(self, select) - - def limit_clause(self, select): - # Limit in mssql is after the select keyword - return "" - - def visit_select(self, select, **kwargs): - """Look for ``LIMIT`` and OFFSET in a select statement, and if - so tries to wrap it in a subquery with ``row_number()`` criterion. - - """ - if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset: - # to use ROW_NUMBER(), an ORDER BY is required. - orderby = self.process(select._order_by_clause) - if not orderby: - raise exc.InvalidRequestError('MSSQL requires an order_by when using an offset.') - - _offset = select._offset - _limit = select._limit - select._mssql_visit = True - select = select.column(sql.literal_column("ROW_NUMBER() OVER (ORDER BY %s)" % orderby).label("mssql_rn")).order_by(None).alias() - - limitselect = sql.select([c for c in select.c if c.key!='mssql_rn']) - limitselect.append_whereclause("mssql_rn>%d" % _offset) - if _limit is not None: - limitselect.append_whereclause("mssql_rn<=%d" % (_limit + _offset)) - return self.process(limitselect, iswrapper=True, **kwargs) - else: - return compiler.SQLCompiler.visit_select(self, select, **kwargs) - - def _schema_aliased_table(self, table): - if getattr(table, 'schema', None) is not None: - if table not in self.tablealiases: - self.tablealiases[table] = table.alias() - return self.tablealiases[table] - else: - return None - - def visit_table(self, table, mssql_aliased=False, **kwargs): - if mssql_aliased: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) - - # alias schema-qualified tables - alias = self._schema_aliased_table(table) - if alias is not None: - return self.process(alias, mssql_aliased=True, **kwargs) - else: - return super(MSSQLCompiler, self).visit_table(table, **kwargs) - - def visit_alias(self, alias, **kwargs): - # translate for schema-qualified table aliases - self.tablealiases[alias.original] = alias - kwargs['mssql_aliased'] = True - return super(MSSQLCompiler, self).visit_alias(alias, **kwargs) - - def visit_savepoint(self, savepoint_stmt): - util.warn("Savepoint support in mssql is experimental and may lead to data loss.") - return "SAVE TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) - - def visit_rollback_to_savepoint(self, savepoint_stmt): - return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint(savepoint_stmt) - - def visit_column(self, column, result_map=None, **kwargs): - if column.table is not None and \ - (not self.isupdate and not self.isdelete) or self.is_subquery(): - # translate for schema-qualified table aliases - t = self._schema_aliased_table(column.table) - if t is not None: - converted = expression._corresponding_column_or_error(t, column) - - if result_map is not None: - result_map[column.name.lower()] = (column.name, (column, ), column.type) - - return super(MSSQLCompiler, self).visit_column(converted, result_map=None, **kwargs) - - return super(MSSQLCompiler, self).visit_column(column, result_map=result_map, **kwargs) - - def visit_binary(self, binary, **kwargs): - """Move bind parameters to the right-hand side of an operator, where - possible. - - """ - if isinstance(binary.left, expression._BindParamClause) and binary.operator == operator.eq \ - and not isinstance(binary.right, expression._BindParamClause): - return self.process(expression._BinaryExpression(binary.right, binary.left, binary.operator), **kwargs) - else: - if (binary.operator is operator.eq or binary.operator is operator.ne) and ( - (isinstance(binary.left, expression._FromGrouping) and isinstance(binary.left.element, expression._ScalarSelect)) or \ - (isinstance(binary.right, expression._FromGrouping) and isinstance(binary.right.element, expression._ScalarSelect)) or \ - isinstance(binary.left, expression._ScalarSelect) or isinstance(binary.right, expression._ScalarSelect)): - op = binary.operator == operator.eq and "IN" or "NOT IN" - return self.process(expression._BinaryExpression(binary.left, binary.right, op), **kwargs) - return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) - - def visit_insert(self, insert_stmt): - insert_select = False - if insert_stmt.parameters: - insert_select = [p for p in insert_stmt.parameters.values() if isinstance(p, sql.Select)] - if insert_select: - self.isinsert = True - colparams = self._get_colparams(insert_stmt) - preparer = self.preparer - - insert = ' '.join(["INSERT"] + - [self.process(x) for x in insert_stmt._prefixes]) - - if not colparams and not self.dialect.supports_default_values and not self.dialect.supports_empty_insert: - raise exc.CompileError( - "The version of %s you are using does not support empty inserts." % self.dialect.name) - elif not colparams and self.dialect.supports_default_values: - return (insert + " INTO %s DEFAULT VALUES" % ( - (preparer.format_table(insert_stmt.table),))) - else: - return (insert + " INTO %s (%s) SELECT %s" % - (preparer.format_table(insert_stmt.table), - ', '.join([preparer.format_column(c[0]) - for c in colparams]), - ', '.join([c[1] for c in colparams]))) - else: - return super(MSSQLCompiler, self).visit_insert(insert_stmt) - - def label_select_column(self, select, column, asfrom): - if isinstance(column, expression.Function): - return column.label(None) - else: - return super(MSSQLCompiler, self).label_select_column(select, column, asfrom) - - def for_update_clause(self, select): - # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which SQLAlchemy doesn't use - return '' - - def order_by_clause(self, select): - order_by = self.process(select._order_by_clause) - - # MSSQL only allows ORDER BY in subqueries if there is a LIMIT - if order_by and (not self.is_subquery() or select._limit): - return " ORDER BY " + order_by - else: - return "" - - -class MSSQLSchemaGenerator(compiler.SchemaGenerator): - def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec() - - if column.nullable is not None: - if not column.nullable or column.primary_key: - colspec += " NOT NULL" - else: - colspec += " NULL" - - if not column.table: - raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL") - - seq_col = _table_sequence_column(column.table) - - # install a IDENTITY Sequence if we have an implicit IDENTITY column - if seq_col is column: - sequence = getattr(column, 'sequence', None) - if sequence: - start, increment = sequence.start or 1, sequence.increment or 1 - else: - start, increment = 1, 1 - colspec += " IDENTITY(%s,%s)" % (start, increment) - else: - default = self.get_column_default_string(column) - if default is not None: - colspec += " DEFAULT " + default - - return colspec - -class MSSQLSchemaDropper(compiler.SchemaDropper): - def visit_index(self, index): - self.append("\nDROP INDEX %s.%s" % ( - self.preparer.quote_identifier(index.table.name), - self.preparer.quote(self._validate_identifier(index.name, False), index.quote) - )) - self.execute() - - -class MSSQLIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS) - - def __init__(self, dialect): - super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']') - - def _escape_identifier(self, value): - #TODO: determine MSSQL's escaping rules - return value - -dialect = MSSQLDialect -dialect.statement_compiler = MSSQLCompiler -dialect.schemagenerator = MSSQLSchemaGenerator -dialect.schemadropper = MSSQLSchemaDropper -dialect.preparer = MSSQLIdentifierPreparer - diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py new file mode 100644 index 0000000000..1b5858c53c --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -0,0 +1,50 @@ +from sqlalchemy.dialects.mssql.base import MSDialect, MSDateTimeAsDate, MSDateTimeAsTime +from sqlalchemy import types as sqltypes + +class MSDialect_pymssql(MSDialect): + supports_sane_rowcount = False + max_identifier_length = 30 + driver = 'pymssql' + + # TODO: shouldnt this be based on server version <10 like pyodbc does ? + colspecs = MSSQLDialect.colspecs.copy() + colspecs[sqltypes.Date] = MSDateTimeAsDate + colspecs[sqltypes.Time] = MSDateTimeAsTime + + @classmethod + def import_dbapi(cls): + import pymssql as module + # pymmsql doesn't have a Binary method. we use string + # TODO: monkeypatching here is less than ideal + module.Binary = lambda st: str(st) + return module + + def __init__(self, **params): + super(MSSQLDialect_pymssql, self).__init__(**params) + self.use_scope_identity = True + + # pymssql understands only ascii + if self.convert_unicode: + util.warn("pymssql does not support unicode") + self.encoding = params.get('encoding', 'ascii') + + + def create_connect_args(self, url): + if hasattr(self, 'query_timeout'): + # ick, globals ? we might want to move this.... + self.dbapi._mssql.set_query_timeout(self.query_timeout) + + keys = url.query + if keys.get('port'): + # pymssql expects port as host:port, not a separate arg + keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])]) + del keys['port'] + return [[], keys] + + def is_disconnect(self, e): + return isinstance(e, self.dbapi.DatabaseError) and "Error 10054" in str(e) + + def do_begin(self, connection): + pass + +dialect = MSDialect_pymssql \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py new file mode 100644 index 0000000000..3c18f60e75 --- /dev/null +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -0,0 +1,59 @@ +from sqlalchemy.dialects.mssql.base import MSExecutionContext, MSDialect, MSDateTimeAsDate, MSDateTimeAsTime +from sqlalchemy.connectors.pyodbc import PyODBCConnector +from sqlalchemy import types as sqltypes + +import sys + +class MSExecutionContext_pyodbc(MSExecutionContext): + def pre_exec(self): + """where appropriate, issue "select scope_identity()" in the same statement""" + super(MSSQLExecutionContext_pyodbc, self).pre_exec() + if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \ + and len(self.parameters) == 1 and self.dialect.use_scope_identity: + self.statement += "; select scope_identity()" + + def post_exec(self): + if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany: + import pyodbc + # Fetch the last inserted id from the manipulated statement + # We may have to skip over a number of result sets with no data (due to triggers, etc.) + while True: + try: + row = self.cursor.fetchone() + break + except pyodbc.Error, e: + self.cursor.nextset() + self._last_inserted_ids = [int(row[0])] + else: + super(MSSQLExecutionContext_pyodbc, self).post_exec() + + +class MSDialect_pyodbc(PyODBCConnector, MSDialect): + supports_sane_rowcount = False + supports_sane_multi_rowcount = False + # PyODBC unicode is broken on UCS-4 builds + supports_unicode = sys.maxunicode == 65535 + supports_unicode_statements = supports_unicode + execution_ctx_cls = MSExecutionContext_pyodbc + + pyodbc_driver_name = 'SQL Server' + + def __init__(self, description_encoding='latin-1', **params): + super(MSDialect_pyodbc, self).__init__(**params) + self.description_encoding = description_encoding + self.use_scope_identity = self.dbapi and hasattr(self.dbapi.Cursor, 'nextset') + + if self.server_version_info < (10,): + self.colspecs = MSDialect.colspecs.copy() + self.colspecs[sqltypes.Date] = MSDateTimeAsDate + self.colspecs[sqltypes.Time] = MSDateTimeAsTime + + def is_disconnect(self, e): + if isinstance(e, self.dbapi.ProgrammingError): + return "The cursor's connection has been closed." in str(e) or 'Attempt to use a closed connection.' in str(e) + elif isinstance(e, self.dbapi.Error): + return '[08S01]' in str(e) + else: + return False + +dialect = MSDialect_pyodbc \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index ad675839e0..bb6b7ab75f 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1414,7 +1414,6 @@ class MySQLDDLCompiler(compiler.DDLCompiler): """Builds column DDL.""" colspec = [self.preparer.format_column(column), - #self.dialect.type_compiler.process(column.type.dialect_impl(self.dialect)) self.dialect.type_compiler.process(column.type) ] diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py index 2f5548236f..63bd8bfbab 100644 --- a/lib/sqlalchemy/types.py +++ b/lib/sqlalchemy/types.py @@ -465,7 +465,10 @@ class String(Concatenable, TypeEngine): self.assert_unicode = assert_unicode def adapt(self, impltype): - return impltype(length=self.length, convert_unicode=self.convert_unicode, assert_unicode=self.assert_unicode) + return impltype( + length=self.length, + convert_unicode=self.convert_unicode, + assert_unicode=self.assert_unicode) def bind_processor(self, dialect): if self.convert_unicode or dialect.convert_unicode: diff --git a/test/dialect/mssql.py b/test/dialect/mssql.py index f0b0bec76f..165d6908f5 100755 --- a/test/dialect/mssql.py +++ b/test/dialect/mssql.py @@ -1,7 +1,7 @@ import testenv; testenv.configure_for_tests() import datetime, os, pickleable, re from sqlalchemy import * -from sqlalchemy import types, exc +from sqlalchemy import types, exc, schema from sqlalchemy.orm import * from sqlalchemy.sql import table, column from sqlalchemy.databases import mssql @@ -11,7 +11,7 @@ from testlib.testing import eq_ class CompileTest(TestBase, AssertsCompiledSQL): - __dialect__ = mssql.MSSQLDialect() + __dialect__ = mssql.dialect() def test_insert(self): t = table('sometable', column('somecolumn')) @@ -258,36 +258,26 @@ class SchemaTest(TestBase): ) self.column = t.c.test_column + dialect = mssql.dialect() + self.ddl_compiler = dialect.ddl_compiler(dialect, schema.CreateTable(t)) + + def _column_spec(self): + return self.ddl_compiler.get_column_specification(self.column) + def test_that_mssql_default_nullability_emits_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NULL", column_specification) + eq_("test_column VARCHAR NULL", self._column_spec()) def test_that_mssql_none_nullability_does_not_emit_nullability(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = None - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR", column_specification) + eq_("test_column VARCHAR", self._column_spec()) def test_that_mssql_specified_nullable_emits_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = True - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NULL", column_specification) + eq_("test_column VARCHAR NULL", self._column_spec()) def test_that_mssql_specified_not_nullable_emits_not_null(self): - schemagenerator = \ - mssql.MSSQLDialect().schemagenerator(mssql.MSSQLDialect(), None) self.column.nullable = False - column_specification = \ - schemagenerator.get_column_specification(self.column) - eq_("test_column VARCHAR NOT NULL", column_specification) + eq_("test_column VARCHAR NOT NULL", self._column_spec()) def full_text_search_missing(): @@ -683,7 +673,8 @@ class TypesTest2(TestBase, AssertsExecutionResults): table_args.append(Column('c%s' % index, type_(*args, **kw), nullable=None)) binary_table = Table(*table_args) - gen = testing.db.dialect.schemagenerator(testing.db.dialect, testing.db, None, None) + dialect = mssql.dialect() + gen = dialect.ddl_compiler(dialect, schema.CreateTable(binary_table)) for col in binary_table.c: index = int(col.name[1:]) @@ -691,11 +682,7 @@ class TypesTest2(TestBase, AssertsExecutionResults): "%s %s" % (col.name, columns[index][3])) self.assert_(repr(col)) - try: - binary_table.create(checkfirst=True) - assert True - except: - raise + binary_table.create(checkfirst=True) reflected_binary = Table('test_mssql_binary', MetaData(testing.db), autoload=True) for col in reflected_binary.c: @@ -957,6 +944,9 @@ def colspec(c): class BinaryTest(TestBase, AssertsExecutionResults): """Test the Binary and VarBinary types""" + + __only_on__ = 'mssql' + def setUpAll(self): global binary_table, MyPickleType diff --git a/test/sql/testtypes.py b/test/sql/testtypes.py index 34acba4c74..29ed49d073 100644 --- a/test/sql/testtypes.py +++ b/test/sql/testtypes.py @@ -34,14 +34,13 @@ class AdaptTest(TestBase): assert ta != tb def testmsnvarchar(self): - dialect = mssql.MSSQLDialect() + dialect = mssql.dialect() # run the test twice to ensure the caching step works too for x in range(0, 1): col = Column('', Unicode(length=10)) dialect_type = col.type.dialect_impl(dialect) assert isinstance(dialect_type, mssql.MSNVarchar) - assert dialect_type.get_col_spec() == 'NVARCHAR(10)' - + eq_(dialect.type_compiler.process(dialect_type), 'NVARCHAR(10)') def testoracletimestamp(self): dialect = oracle.OracleDialect() @@ -105,7 +104,15 @@ class AdaptTest(TestBase): """ - for dialect in [oracle.dialect(), mysql.dialect(), postgres.dialect(), sqlite.dialect(), sybase.dialect(), informix.dialect(), maxdb.dialect()]: #engines.all_dialects(): + for dialect in [ + oracle.dialect(), + mysql.dialect(), + postgres.dialect(), + sqlite.dialect(), + sybase.dialect(), + informix.dialect(), + maxdb.dialect(), + mssql.dialect()]: # TODO when dialects are complete: engines.all_dialects(): for type_, expected in ( (FLOAT, "FLOAT"), (NUMERIC, "NUMERIC"), -- 2.47.3