]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Major refactoring of the MSSQL dialect. Thanks zzzeek.
authorMichael Trier <mtrier@gmail.com>
Mon, 22 Dec 2008 20:20:55 +0000 (20:20 +0000)
committerMichael Trier <mtrier@gmail.com>
Mon, 22 Dec 2008 20:20:55 +0000 (20:20 +0000)
Includes simplifying the IDENTITY handling and the exception handling. Also
includes a cleanup of the connection string handling for pyodbc to favor
the DSN syntax.

CHANGES
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
test/dialect/mssql.py
test/sql/query.py

diff --git a/CHANGES b/CHANGES
index c5571d5c391fb61e1addc8528ebb21710f5a073b..55521e4ef08c89d486397234509c865ee3e01c9f 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -196,6 +196,10 @@ CHANGES
       new doc section "Custom Comparators".
     
 - mssql
+    - Changes to the connection string parameters favor DSN as the
+      default specification for pyodbc. See the mssql.py docstring
+      for detailed usage instructions.
+
     - Added experimental support of savepoints. It
       currently does not work fully with sessions.
 
index bcfd975ab3dc0e7f69d7aaac9d9c9951ad9d71a0..8fb6bfa4ab9b8ae0ec5686628f0b793556d67492 100644 (file)
 # mssql.py
 
-"""MSSQL backend, thru either pymssq, adodbapi or pyodbc interfaces.
+"""Support for the Microsoft SQL Server database.
 
-* ``IDENTITY`` columns are supported by using SA ``schema.Sequence()``
-  objects. In other words::
+Driver
+------
+
+The MSSQL dialect will work with three different available drivers:
+
+* *pymssql* - http://pymssql.sourceforge.net/
+
+* *pyodbc* - http://pyodbc.sourceforge.net/. This is the recommeded
+  driver.
+
+* *adodbapi* - http://adodbapi.sourceforge.net/
+
+Drivers are loaded in the order listed above based on availability.
+Currently the pyodbc driver offers the greatest level of
+compatibility.
+
+Connecting
+----------
+
+Connecting with create_engine() uses the standard URL approach of
+``mssql://user:pass@host/dbname[?key=value&key=value...]``.
+
+If the database name is present, the tokens are converted to a
+connection string with the specified values. If the database is not
+present, then the host token is taken directly as the DSN name.
+
+Examples of pyodbc connection string URLs:
+
+* *mssql://mydsn* - connects using the specified DSN named ``mydsn``.
+  The connection string that is created will appear like::
+
+    dsn=mydsn;TrustedConnection=Yes
+
+* *mssql://user:pass@mydsn* - connects using the DSN named
+  ``mydsn`` passing in the ``UID`` and ``PWD`` information. The
+  connection string that is created will appear like::
+
+    dsn=mydsn;UID=user;PWD=pass
+
+* *mssql://user:pass@mydsn/?LANGUAGE=us_english* - connects
+  using the DSN named ``mydsn`` passing in the ``UID`` and ``PWD``
+  information, plus the additional connection configuration option
+  ``LANGUAGE``. The connection string that is created will appear
+  like::
+
+    dsn=mydsn;UID=user;PWD=pass;LANGUAGE=us_english
+
+* *mssql://user:pass@host/db* - connects using a connection string
+  dynamically created that would appear like::
+
+    DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass
+
+* *mssql://user:pass@host:123/db* - connects using a connection
+  string that is dynamically created, which also includes the port
+  information using the comma syntax. If your connection string
+  requires the port information to be passed as a ``port`` keyword
+  see the next example. This will create the following connection
+  string::
+
+    DRIVER={SQL Server};Server=host,123;Database=db;UID=user;PWD=pass
+
+* *mssql://user:pass@host/db?port=123* - connects using a connection
+  string that is dynamically created that includes the port
+  information as a separate ``port`` keyword. This will create the
+  following connection string::
+
+    DRIVER={SQL Server};Server=host;Database=db;UID=user;PWD=pass;port=123
+
+If you require a connection string that is outside the options
+presented above, use the ``odbc_connect`` keyword to pass in a
+urlencoded connection string. What gets passed in will be urldecoded
+and passed directly.
+
+For example::
+
+    mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddb
+
+would create the following connection string::
+
+    dsn=mydsn;Database=db
+
+Encoding your connection string can be easily accomplished through
+the python shell. For example::
+
+    >>> import urllib
+    >>> urllib.quote_plus('dsn=mydsn;Database=db')
+    'dsn%3Dmydsn%3BDatabase%3Ddb'
+
+Additional arguments which may be specified either as query string
+arguments on the URL, or as keyword argument to
+:func:`~sqlalchemy.create_engine()` are:
+
+* *auto_identity_insert* - enables support for IDENTITY inserts by
+  automatically turning IDENTITY INSERT ON and OFF as required.
+  Defaults to ``True`.
+
+* *query_timeout* - allows you to override the default query timeout.
+  Defaults to ``None``. This is only supported on pymssql.
+
+* *text_as_varchar* - if enabled this will treat all TEXT column
+  types as their equivalent VARCHAR(max) type. This is often used if
+  you need to compare a VARCHAR to a TEXT field, which is not
+  supported directly on MSSQL. Defaults to ``False``.
+
+* *use_scope_identity* - allows you to specify that SCOPE_IDENTITY
+  should be used in place of the non-scoped version @@IDENTITY.
+  Defaults to ``False``. On pymssql this defaults to ``True``, and on
+  pyodbc this defaults to ``True`` if the version of pyodbc being
+  used supports it.
+
+* *has_window_funcs* - indicates whether or not window functions
+  (LIMIT and OFFSET) are supported on the version of MSSQL being
+  used. If you're running MSSQL 2005 or later turn this on to get
+  OFFSET support. Defaults to ``False``.
+
+* *max_identifier_length* - allows you to se the maximum length of
+  identfiers supported by the database. Defaults to 128. For pymssql
+  the default is 30.
+
+* *schema_name* - use to set the schema name. Defaults to ``dbo``.
+
+Auto Increment Behavior
+-----------------------
+
+``IDENTITY`` columns are supported by using SQLAlchemy
+``schema.Sequence()`` objects. In other words::
 
     Table('test', mss_engine,
-           Column('id',   Integer, Sequence('blah',100,10), primary_key=True),
+           Column('id', Integer,
+                  Sequence('blah',100,10), primary_key=True),
            Column('name', String(20))
          ).create()
 
-  would yield::
+would yield::
 
    CREATE TABLE test (
      id INTEGER NOT NULL IDENTITY(100,10) PRIMARY KEY,
      name VARCHAR(20) NULL,
      )
 
-  Note that the start & increment values for sequences are optional
-  and will default to 1,1.
+Note that the ``start`` and ``increment`` values for sequences are
+optional and will default to 1,1.
 
 * Support for ``SET IDENTITY_INSERT ON`` mode (automagic on / off for
   ``INSERT`` s)
 
-* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on ``INSERT``
+* Support for auto-fetching of ``@@IDENTITY/@@SCOPE_IDENTITY()`` on
+  ``INSERT``
+
+LIMIT/OFFSET Support
+--------------------
+
+MSSQL has no support for the LIMIT or OFFSET keysowrds. LIMIT is
+supported directly through the ``TOP`` Transact SQL keyword::
+
+    select.limit
+
+will yield::
 
-* ``select._limit`` implemented as ``SELECT TOP n``
+    SELECT TOP n
 
-* Experimental implemention of LIMIT / OFFSET with row_number()
+If the ``has_window_funcs`` flag is set then LIMIT with OFFSET
+support is available through the ``ROW_NUMBER OVER`` construct. This
+construct requires an ``ORDER BY`` to be specified as well and is
+only available on MSSQL 2005 and later.
 
-* Support for three levels of column nullability provided. The default
-  nullability allows nulls::
+Nullability
+-----------
+MSSQL has support for three levels of column nullability. The default
+nullability allows nulls and is explicit in the CREATE TABLE
+construct::
 
     name VARCHAR(20) NULL
 
-  If ``nullable=None`` is specified then no specification is made. In other
-  words the database's configured default is used. This will render::
+If ``nullable=None`` is specified then no specification is made. In
+other words the database's configured default is used. This will
+render::
 
     name VARCHAR(20)
 
-  If ``nullable`` is True or False then the column will be ``NULL` or
-  ``NOT NULL`` respectively.
+If ``nullable`` is ``True`` or ``False`` then the column will be
+``NULL` or ``NOT NULL`` respectively.
 
-Known issues / TODO:
+Known Issues
+------------
 
 * No support for more than one ``IDENTITY`` column per table
 
@@ -50,7 +194,7 @@ Known issues / TODO:
   does **not** work around
 
 """
-import datetime, operator, re, sys
+import datetime, operator, re, sys, urllib
 
 from sqlalchemy import sql, schema, exc, util
 from sqlalchemy.sql import compiler, expression, operators as sqlops, functions as sql_functions
@@ -299,77 +443,92 @@ class MSVariant(sqltypes.TypeEngine):
     def get_col_spec(self):
         return "SQL_VARIANT"
 
-class MSSQLExecutionContext(default.DefaultExecutionContext):
-    def __init__(self, *args, **kwargs):
-        self.IINSERT = self.HASIDENT = False
-        super(MSSQLExecutionContext, self).__init__(*args, **kwargs)
 
-    def _has_implicit_sequence(self, column):
-        if column.primary_key and column.autoincrement:
-            if isinstance(column.type, sqltypes.Integer) and not column.foreign_keys:
-                if column.default is None or (isinstance(column.default, schema.Sequence) and \
-                                              column.default.optional):
-                    return True
-        return False
+def _has_implicit_sequence(column):
+    return column.primary_key and  \
+        column.autoincrement and \
+        isinstance(column.type, sqltypes.Integer) and \
+        not column.foreign_keys and \
+        (
+            column.default is None or 
+            (
+                isinstance(column.default, schema.Sequence) and 
+                column.default.optional)
+            )
+
+def _table_sequence_column(tbl):
+    if not hasattr(tbl, '_ms_has_sequence'):
+        tbl._ms_has_sequence = None
+        for column in tbl.c:
+            if getattr(column, 'sequence', False) or _has_implicit_sequence(column):
+                tbl._ms_has_sequence = column
+                break
+    return tbl._ms_has_sequence
+
+class MSSQLExecutionContext(default.DefaultExecutionContext):
+    IINSERT = False
+    HASIDENT = False
 
     def pre_exec(self):
-        """MS-SQL has a special mode for inserting non-NULL values
-        into IDENTITY columns.
+        """Activate IDENTITY_INSERT if needed."""
 
-        Activate it if the feature is turned on and needed.
-        """
         if self.compiled.isinsert:
             tbl = self.compiled.statement.table
-            if not hasattr(tbl, 'has_sequence'):
-                tbl.has_sequence = None
-                for column in tbl.c:
-                    if getattr(column, 'sequence', False) or self._has_implicit_sequence(column):
-                        tbl.has_sequence = column
-                        break
-
-            self.HASIDENT = bool(tbl.has_sequence)
+            
+            seq_column = _table_sequence_column(tbl)
+            self.HASIDENT = bool(seq_column)
             if self.dialect.auto_identity_insert and self.HASIDENT:
-                if isinstance(self.compiled_parameters, list):
-                    self.IINSERT = tbl.has_sequence.key in self.compiled_parameters[0]
-                else:
-                    self.IINSERT = tbl.has_sequence.key in self.compiled_parameters
+                self.IINSERT = tbl._ms_has_sequence.key in self.compiled_parameters[0]
             else:
                 self.IINSERT = False
 
             if self.IINSERT:
-                self.cursor.execute("SET IDENTITY_INSERT %s ON" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+                self.cursor.execute("SET IDENTITY_INSERT %s ON" % 
+                    self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
 
-        super(MSSQLExecutionContext, self).pre_exec()
+    def handle_dbapi_exception(self, e):
+        if self.IINSERT:
+            try:
+                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+            except:
+                pass
 
     def post_exec(self):
-        """Turn off the INDENTITY_INSERT mode if it's been activated,
-        and fetch recently inserted IDENTIFY values (works only for
-        one column).
-        """
+        """Disable IDENTITY_INSERT if enabled."""
 
-        if self.compiled.isinsert and (not self.executemany) and self.HASIDENT and not self.IINSERT:
-            if not len(self._last_inserted_ids) or self._last_inserted_ids[0] is None:
+        if self.compiled.isinsert and not self.executemany and self.HASIDENT and not self.IINSERT:
+            if not self._last_inserted_ids or self._last_inserted_ids[0] is None:
                 if self.dialect.use_scope_identity:
                     self.cursor.execute("SELECT scope_identity() AS lastrowid")
                 else:
                     self.cursor.execute("SELECT @@identity AS lastrowid")
                 row = self.cursor.fetchone()
                 self._last_inserted_ids = [int(row[0])] + self._last_inserted_ids[1:]
-        super(MSSQLExecutionContext, self).post_exec()
+
+        if self.IINSERT:
+            self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
 
 
 class MSSQLExecutionContext_pyodbc (MSSQLExecutionContext):
     def pre_exec(self):
         """where appropriate, issue "select scope_identity()" in the same statement"""
         super(MSSQLExecutionContext_pyodbc, self).pre_exec()
-        if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) \
+        if self.compiled.isinsert and self.HASIDENT and not self.IINSERT \
                 and len(self.parameters) == 1 and self.dialect.use_scope_identity:
             self.statement += "; select scope_identity()"
 
     def post_exec(self):
-        if self.compiled.isinsert and self.HASIDENT and (not self.IINSERT) and self.dialect.use_scope_identity:
-            # do nothing - id was fetched in dialect.do_execute()
-            pass
+        if self.HASIDENT and not self.IINSERT and self.dialect.use_scope_identity and not self.executemany:
+            import pyodbc
+            # Fetch the last inserted id from the manipulated statement
+            # We may have to skip over a number of result sets with no data (due to triggers, etc.)
+            while True:
+                try:
+                    row = self.cursor.fetchone()
+                    break
+                except pyodbc.Error, e:
+                    self.cursor.nextset()
+            self._last_inserted_ids = [int(row[0])]
         else:
             super(MSSQLExecutionContext_pyodbc, self).post_exec()
 
@@ -377,7 +536,13 @@ class MSSQLDialect(default.DefaultDialect):
     name = 'mssql'
     supports_default_values = True
     supports_empty_insert = False
+    auto_identity_insert = True
     execution_ctx_cls = MSSQLExecutionContext
+    text_as_varchar = False
+    use_scope_identity = False
+    has_window_funcs = False
+    max_identifier_length = 128
+    schema_name = "dbo"
 
     colspecs = {
         sqltypes.Unicode : MSNVarchar,
@@ -426,23 +591,33 @@ class MSSQLDialect(default.DefaultDialect):
         'sql_variant': MSVariant,
     }
 
-    def __new__(cls, dbapi=None, *args, **kwargs):
-        if cls != MSSQLDialect:
+    def __new__(cls, *args, **kwargs):
+        if cls is not MSSQLDialect:
+            # this gets called with the dialect specific class
             return super(MSSQLDialect, cls).__new__(cls, *args, **kwargs)
+        dbapi = kwargs.get('dbapi', None)
         if dbapi:
             dialect = dialect_mapping.get(dbapi.__name__)
-            return dialect(*args, **kwargs)
+            return dialect(**kwargs)
         else:
             return object.__new__(cls, *args, **kwargs)
 
-    def __init__(self, auto_identity_insert=True, **params):
-        super(MSSQLDialect, self).__init__(**params)
-        self.auto_identity_insert = auto_identity_insert
-        self.text_as_varchar = False
-        self.use_scope_identity = False
-        self.has_window_funcs = False
-        self.set_default_schema_name("dbo")
+    def __init__(self,
+                 auto_identity_insert=True, query_timeout=None, text_as_varchar=False,
+                 use_scope_identity=False,  has_window_funcs=False, max_identifier_length=None,
+                 schema_name="dbo", **opts):
+        self.auto_identity_insert = bool(auto_identity_insert)
+        self.query_timeout = int(query_timeout or 0)
+        self.schema_name = schema_name
 
+        # to-do: the options below should use server version introspection to set themselves on connection
+        self.text_as_varchar = bool(text_as_varchar)
+        self.use_scope_identity = bool(use_scope_identity)
+        self.has_window_funcs =  bool(has_window_funcs)
+        self.max_identifier_length = int(max_identifier_length or 0) or 128
+        super(MSSQLDialect, self).__init__(**opts)
+
+    @classmethod
     def dbapi(cls, module_name=None):
         if module_name:
             try:
@@ -458,8 +633,8 @@ class MSSQLDialect(default.DefaultDialect):
                     pass
             else:
                 raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
-    dbapi = classmethod(dbapi)
 
+    @base.connection_memoize(('mssql', 'server_version_info'))
     def server_version_info(self, connection):
         """A tuple of the database server version.
 
@@ -472,14 +647,11 @@ class MSSQLDialect(default.DefaultDialect):
         cached per-Connection.
         """
         return connection.dialect._server_version_info(connection.connection)
-    server_version_info = base.connection_memoize(
-        ('mssql', 'server_version_info'))(server_version_info)
 
     def _server_version_info(self, dbapi_con):
         """Return a tuple of the database's version number."""
-
         raise NotImplementedError()
-    
+
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         opts.update(url.query)
@@ -493,7 +665,7 @@ class MSSQLDialect(default.DefaultDialect):
             self.use_scope_identity = bool(int(opts.pop('use_scope_identity')))
         if 'has_window_funcs' in opts:
             self.has_window_funcs =  bool(int(opts.pop('has_window_funcs')))
-        return self.make_connect_string(opts)
+        return self.make_connect_string(opts, url.query)
 
     def type_descriptor(self, typeobj):
         newobj = sqltypes.adapt_type(typeobj, self.colspecs)
@@ -505,51 +677,10 @@ class MSSQLDialect(default.DefaultDialect):
     def get_default_schema_name(self, connection):
         return self.schema_name
 
-    def set_default_schema_name(self, schema_name):
-        self.schema_name = schema_name
-
-    def last_inserted_ids(self):
-        return self.context.last_inserted_ids
-
-    def do_execute(self, cursor, statement, params, context=None, **kwargs):
-        if params == {}:
-            params = ()
-        try:
-            super(MSSQLDialect, self).do_execute(cursor, statement, params, context=context, **kwargs)
-        finally:
-            if context.IINSERT:
-                cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
-
-    def do_executemany(self, cursor, statement, params, context=None, **kwargs):
-        try:
-            super(MSSQLDialect, self).do_executemany(cursor, statement, params, context=context, **kwargs)
-        finally:
-            if context.IINSERT:
-                cursor.execute("SET IDENTITY_INSERT %s OFF" % self.identifier_preparer.format_table(context.compiled.statement.table))
-
-    def _execute(self, c, statement, parameters):
-        try:
-            if parameters == {}:
-                parameters = ()
-            c.execute(statement, parameters)
-            self.context.rowcount = c.rowcount
-            c.DBPROP_COMMITPRESERVE = "Y"
-        except Exception, e:
-            raise exc.DBAPIError.instance(statement, parameters, e)
-
     def table_names(self, connection, schema):
         from sqlalchemy.databases import information_schema as ischema
         return ischema.table_names(connection, schema)
 
-    def raw_connection(self, connection):
-        """Pull the raw pymmsql connection out--sensative to "pool.ConnectionFairy" and pymssql.pymssqlCnx Classes"""
-        try:
-            # TODO: probably want to move this to individual dialect subclasses to
-            # save on the exception throw + simplify
-            return connection.connection.__dict__['_pymssqlCnx__cnx']
-        except:
-            return connection.connection.adoConn
-
     def uppercase_table(self, t):
         # convert all names to uppercase -- fixes refs to INFORMATION_SCHEMA for case-senstive DBs, and won't matter for case-insensitive
         t.name = t.name.upper()
@@ -559,6 +690,7 @@ class MSSQLDialect(default.DefaultDialect):
             c.name = c.name.upper()
         return t
 
+
     def has_table(self, connection, tablename, schema=None):
         import sqlalchemy.databases.information_schema as ischema
 
@@ -645,7 +777,7 @@ class MSSQLDialect(default.DefaultDialect):
                 ic = table.c[col_name]
                 ic.autoincrement = True
                 # setup a psuedo-sequence to represent the identity attribute - we interpret this at table.create() time as the identity attribute
-                ic.sequence = schema.Sequence(ic.name + '_identity')
+                ic.sequence = schema.Sequence(ic.name + '_identity', 1, 1)
                 # MSSQL: only one identity per table allowed
                 cursor.close()
                 break
@@ -722,16 +854,13 @@ class MSSQLDialect_pymssql(MSSQLDialect):
     supports_sane_rowcount = False
     max_identifier_length = 30
 
+    @classmethod
     def import_dbapi(cls):
         import pymssql as module
         # pymmsql doesn't have a Binary method.  we use string
         # TODO: monkeypatching here is less than ideal
         module.Binary = lambda st: str(st)
         return module
-    import_dbapi = classmethod(import_dbapi)
-
-    ischema_names = MSSQLDialect.ischema_names.copy()
-
 
     def __init__(self, **params):
         super(MSSQLDialect_pymssql, self).__init__(**params)
@@ -739,23 +868,16 @@ class MSSQLDialect_pymssql(MSSQLDialect):
 
         # pymssql understands only ascii
         if self.convert_unicode:
+            util.warn("pymssql does not support unicode")
             self.encoding = params.get('encoding', 'ascii')
 
-    def do_rollback(self, connection):
-        # pymssql throws an error on repeated rollbacks. Ignore it.
-        # TODO: this is normal behavior for most DBs.  are we sure we want to ignore it ?
-        try:
-            connection.rollback()
-        except:
-            pass
-
     def create_connect_args(self, url):
         r = super(MSSQLDialect_pymssql, self).create_connect_args(url)
         if hasattr(self, 'query_timeout'):
             self.dbapi._mssql.set_query_timeout(self.query_timeout)
         return r
 
-    def make_connect_string(self, keys):
+    def make_connect_string(self, keys, query):
         if keys.get('port'):
             # pymssql expects port as host:port, not a separate arg
             keys['host'] = ''.join([keys.get('host', ''), ':', str(keys['port'])])
@@ -776,6 +898,7 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
 
     def __init__(self, **params):
         super(MSSQLDialect_pyodbc, self).__init__(**params)
+        # FIXME: scope_identity sniff should look at server version, not the ODBC driver
         # whether use_scope_identity will work depends on the version of pyodbc
         try:
             import pyodbc
@@ -783,10 +906,10 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
         except:
             pass
 
+    @classmethod
     def import_dbapi(cls):
         import pyodbc as module
         return module
-    import_dbapi = classmethod(import_dbapi)
 
     colspecs = MSSQLDialect.colspecs.copy()
     if supports_unicode:
@@ -800,45 +923,41 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
     ischema_names['smalldatetime'] = MSDate_pyodbc
     ischema_names['datetime'] = MSDateTime_pyodbc
 
-    def make_connect_string(self, keys):
+    def make_connect_string(self, keys, query):
         if 'max_identifier_length' in keys:
             self.max_identifier_length = int(keys.pop('max_identifier_length'))
-        if 'dsn' in keys:
-            connectors = ['dsn=%s' % keys.pop('dsn')]
+
+        if 'odbc_connect' in keys:
+            connectors = [urllib.unquote_plus(keys.pop('odbc_connect'))]
         else:
-            port = ''
-            if 'port' in keys and (
-                keys.get('driver', 'SQL Server') == 'SQL Server'):
-                port = ',%d' % int(keys.pop('port'))
+            dsn_connection = 'dsn' in keys or ('host' in keys and 'database' not in keys)
+            if dsn_connection:
+                connectors= ['dsn=%s' % (keys.pop('host', '') or keys.pop('dsn', ''))]
+            else:
+                port = ''
+                if 'port' in keys and not 'port' in query:
+                    port = ',%d' % int(keys.pop('port'))
+
+                connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
+                              'Server=%s%s' % (keys.pop('host', ''), port),
+                              'Database=%s' % keys.pop('database', '') ]
+
+            user = keys.pop("user", None)
+            if user:
+                connectors.append("UID=%s" % user)
+                connectors.append("PWD=%s" % keys.pop('password', ''))
+            else:
+                connectors.append("TrustedConnection=Yes")
 
-            connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
-                          'Server=%s%s' % (keys.pop('host', ''), port),
-                          'Database=%s' % keys.pop('database', '') ]
+            # if set to 'Yes', the ODBC layer will try to automagically convert 
+            # textual data from your database encoding to your client encoding 
+            # This should obviously be set to 'No' if you query a cp1253 encoded 
+            # database from a latin1 client... 
+            if 'odbc_autotranslate' in keys:
+                connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
 
-            if 'port' in keys and not port:
-                connectors.append('Port=%d' % int(keys.pop('port')))
+            connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
 
-        user = keys.pop("user", None)
-        if user:
-            connectors.append("UID=%s" % user)
-            connectors.append("PWD=%s" % keys.pop('password', ''))
-        else:
-            connectors.append("TrustedConnection=Yes")
-
-        # if set to 'Yes', the ODBC layer will try to automagically convert 
-        # textual data from your database encoding to your client encoding 
-        # This should obviously be set to 'No' if you query a cp1253 encoded 
-        # database from a latin1 client... 
-        if 'odbc_autotranslate' in keys:
-            connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
-
-        # Allow specification of partial ODBC connect string
-        if 'odbc_options' in keys: 
-            odbc_options=keys.pop('odbc_options')
-            if odbc_options[0]=="'" and odbc_options[-1]=="'":
-                odbc_options=odbc_options[1:-1]
-            connectors.append(odbc_options)
-        connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
         return [[";".join (connectors)], {}]
 
     def is_disconnect(self, e):
@@ -850,23 +969,8 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
             return False
 
 
-    def do_execute(self, cursor, statement, parameters, context=None, **kwargs):
-        super(MSSQLDialect_pyodbc, self).do_execute(cursor, statement, parameters, context=context, **kwargs)
-        if context and context.HASIDENT and (not context.IINSERT) and context.dialect.use_scope_identity:
-            import pyodbc
-            # Fetch the last inserted id from the manipulated statement
-            # We may have to skip over a number of result sets with no data (due to triggers, etc.)
-            while True:
-                try:
-                    row = cursor.fetchone()
-                    break
-                except pyodbc.Error, e:
-                    cursor.nextset()
-            context._last_inserted_ids = [int(row[0])]
-
     def _server_version_info(self, dbapi_con):
         """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
-
         version = []
         r = re.compile('[.\-]')
         for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
@@ -882,10 +986,10 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
     supports_unicode = sys.maxunicode == 65535
     supports_unicode_statements = True
 
+    @classmethod
     def import_dbapi(cls):
         import adodbapi as module
         return module
-    import_dbapi = classmethod(import_dbapi)
 
     colspecs = MSSQLDialect.colspecs.copy()
     colspecs[sqltypes.Unicode] = AdoMSNVarchar
@@ -895,7 +999,7 @@ class MSSQLDialect_adodbapi(MSSQLDialect):
     ischema_names['nvarchar'] = AdoMSNVarchar
     ischema_names['datetime'] = MSDateTime_adodbapi
 
-    def make_connect_string(self, keys):
+    def make_connect_string(self, keys, query):
         connectors = ["Provider=SQLOLEDB"]
         if 'port' in keys:
             connectors.append ("Data Source=%s, %s" % (keys.get("host"), keys.get("port")))
@@ -963,7 +1067,7 @@ class MSSQLCompiler(compiler.DefaultCompiler):
         so tries to wrap it in a subquery with ``row_number()`` criterion.
 
         """
-        if self.dialect.has_window_funcs and (not getattr(select, '_mssql_visit', None)) and select._offset:
+        if self.dialect.has_window_funcs and not getattr(select, '_mssql_visit', None) and select._offset:
             # to use ROW_NUMBER(), an ORDER BY is required.
             orderby = self.process(select._order_by_clause)
             if not orderby:
@@ -1073,21 +1177,25 @@ class MSSQLSchemaGenerator(compiler.SchemaGenerator):
     def get_column_specification(self, column, **kwargs):
         colspec = self.preparer.format_column(column) + " " + column.type.dialect_impl(self.dialect).get_col_spec()
 
-        # install a IDENTITY Sequence if we have an implicit IDENTITY column
-        if (not getattr(column.table, 'has_sequence', False)) and column.primary_key and \
-                column.autoincrement and isinstance(column.type, sqltypes.Integer) and not column.foreign_keys:
-            if column.default is None or (isinstance(column.default, schema.Sequence) and column.default.optional):
-                column.sequence = schema.Sequence(column.name + '_seq')
-
         if column.nullable is not None:
             if not column.nullable:
                 colspec += " NOT NULL"
             else:
                 colspec += " NULL"
+        
+        if not column.table:
+            raise exc.InvalidRequestError("mssql requires Table-bound columns in order to generate DDL")
+            
+        seq_col = _table_sequence_column(column.table)
 
-        if hasattr(column, 'sequence'):
-            column.table.has_sequence = column
-            colspec += " IDENTITY(%s,%s)" % (column.sequence.start or 1, column.sequence.increment or 1)
+        # install a IDENTITY Sequence if we have an implicit IDENTITY column
+        if seq_col is column:
+            sequence = getattr(column, 'sequence', None)
+            if sequence:
+                start, increment = sequence.start or 1, sequence.increment or 1
+            else:
+                start, increment = 1, 1
+            colspec += " IDENTITY(%s,%s)" % (start, increment)
         else:
             default = self.get_column_default_string(column)
             if default is not None:
@@ -1104,11 +1212,6 @@ class MSSQLSchemaDropper(compiler.SchemaDropper):
         self.execute()
 
 
-class MSSQLDefaultRunner(base.DefaultRunner):
-    # TODO: does ms-sql have standalone sequences ?
-    # A: No, only auto-incrementing IDENTITY property of a column
-    pass
-
 class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
     reserved_words = compiler.IdentifierPreparer.reserved_words.union(MSSQL_RESERVED_WORDS)
 
@@ -1116,7 +1219,7 @@ class MSSQLIdentifierPreparer(compiler.IdentifierPreparer):
         super(MSSQLIdentifierPreparer, self).__init__(dialect, initial_quote='[', final_quote=']')
 
     def _escape_identifier(self, value):
-        #TODO: determin MSSQL's escapeing rules
+        #TODO: determine MSSQL's escaping rules
         return value
 
 dialect = MSSQLDialect
@@ -1124,4 +1227,3 @@ dialect.statement_compiler = MSSQLCompiler
 dialect.schemagenerator = MSSQLSchemaGenerator
 dialect.schemadropper = MSSQLSchemaDropper
 dialect.preparer = MSSQLIdentifierPreparer
-dialect.defaultrunner = MSSQLDefaultRunner
index 9efd73a89fdb585f568b80f3d76df653a2c35793..7af74ec7bbb436f015469e41a9a3ad5111b20243 100644 (file)
@@ -350,6 +350,11 @@ class ExecutionContext(object):
 
         raise NotImplementedError()
 
+    def handle_dbapi_exception(self, e):
+        """Receive a DBAPI exception which occured upon execute, result fetch, etc."""
+        
+        raise NotImplementedError()
+        
     def should_autocommit_text(self, statement):
         """Parse the given textual statement and return True if it refers to a "committable" statement"""
 
@@ -714,7 +719,7 @@ class Connection(Connectable):
         try:
             self.engine.dialect.do_begin(self.connection)
         except Exception, e:
-            self._handle_dbapi_exception(e, None, None, None)
+            self._handle_dbapi_exception(e, None, None, None, None)
             raise
 
     def _rollback_impl(self):
@@ -725,7 +730,7 @@ class Connection(Connectable):
                 self.engine.dialect.do_rollback(self.connection)
                 self.__transaction = None
             except Exception, e:
-                self._handle_dbapi_exception(e, None, None, None)
+                self._handle_dbapi_exception(e, None, None, None, None)
                 raise
         else:
             self.__transaction = None
@@ -737,7 +742,7 @@ class Connection(Connectable):
             self.engine.dialect.do_commit(self.connection)
             self.__transaction = None
         except Exception, e:
-            self._handle_dbapi_exception(e, None, None, None)
+            self._handle_dbapi_exception(e, None, None, None, None)
             raise
 
     def _savepoint_impl(self, name=None):
@@ -897,13 +902,17 @@ class Connection(Connectable):
             schema_item = None
         return ddl(None, schema_item, self, *params, **multiparams)
 
-    def _handle_dbapi_exception(self, e, statement, parameters, cursor):
+    def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
         if getattr(self, '_reentrant_error', False):
             raise exc.DBAPIError.instance(None, None, e)
         self._reentrant_error = True
         try:
             if not isinstance(e, self.dialect.dbapi.Error):
                 return
+                
+            if context:
+                context.handle_dbapi_exception(e)
+                
             is_disconnect = self.dialect.is_disconnect(e)
             if is_disconnect:
                 self.invalidate(e)
@@ -923,7 +932,7 @@ class Connection(Connectable):
             dialect = self.engine.dialect
             return dialect.execution_ctx_cls(dialect, connection=self, **kwargs)
         except Exception, e:
-            self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None)
+            self._handle_dbapi_exception(e, kwargs.get('statement', None), kwargs.get('parameters', None), None, None)
             raise
 
     def _cursor_execute(self, cursor, statement, parameters, context=None):
@@ -933,7 +942,7 @@ class Connection(Connectable):
         try:
             self.dialect.do_execute(cursor, statement, parameters, context=context)
         except Exception, e:
-            self._handle_dbapi_exception(e, statement, parameters, cursor)
+            self._handle_dbapi_exception(e, statement, parameters, cursor, context)
             raise
 
     def _cursor_executemany(self, cursor, statement, parameters, context=None):
@@ -943,7 +952,7 @@ class Connection(Connectable):
         try:
             self.dialect.do_executemany(cursor, statement, parameters, context=context)
         except Exception, e:
-            self._handle_dbapi_exception(e, statement, parameters, cursor)
+            self._handle_dbapi_exception(e, statement, parameters, cursor, context)
             raise
 
     # poor man's multimethod/generic function thingy
@@ -1623,7 +1632,7 @@ class ResultProxy(object):
             self.close()
             return l
         except Exception, e:
-            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
             raise
 
     def fetchmany(self, size=None):
@@ -1636,7 +1645,7 @@ class ResultProxy(object):
                 self.close()
             return l
         except Exception, e:
-            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
             raise
 
     def fetchone(self):
@@ -1649,7 +1658,7 @@ class ResultProxy(object):
                 self.close()
                 return None
         except Exception, e:
-            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
             raise
 
     def scalar(self):
@@ -1657,7 +1666,7 @@ class ResultProxy(object):
         try:
             row = self._fetchone_impl()
         except Exception, e:
-            self.connection._handle_dbapi_exception(e, None, None, self.cursor)
+            self.connection._handle_dbapi_exception(e, None, None, self.cursor, self.context)
             raise
             
         try:
index 11fd43df561d3a9e5d856b37dda814013ce3980b..682ab526c6ebd10268436e38259f6242354ac64b 100644 (file)
@@ -259,6 +259,9 @@ class DefaultExecutionContext(base.ExecutionContext):
     def post_exec(self):
         pass
     
+    def handle_dbapi_exception(self, e):
+        pass
+
     def get_result_proxy(self):
         return base.ResultProxy(self)
 
@@ -306,7 +309,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             try:
                 self.cursor.setinputsizes(*inputsizes)
             except Exception, e:
-                self._connection._handle_dbapi_exception(e, None, None, None)
+                self._connection._handle_dbapi_exception(e, None, None, None, self)
                 raise
         else:
             inputsizes = {}
@@ -318,7 +321,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             try:
                 self.cursor.setinputsizes(**inputsizes)
             except Exception, e:
-                self._connection._handle_dbapi_exception(e, None, None, None)
+                self._connection._handle_dbapi_exception(e, None, None, None, self)
                 raise
 
     def __process_defaults(self):
index 5d97cf1484363b04ea07c75d8a161f9caf4d1aec..e38ee82b78f9db9db39278d00f4c8b6f86dc38e7 100755 (executable)
@@ -251,7 +251,10 @@ class GenerativeQueryTest(TestBase):
 class SchemaTest(TestBase):
 
     def setUp(self):
-        self.column = Column('test_column', Integer)
+        t = Table('sometable', MetaData(), 
+            Column('test_column', Integer)
+        )
+        self.column = t.c.test_column
 
     def test_that_mssql_default_nullability_emits_null(self):
         schemagenerator = \
@@ -399,18 +402,73 @@ class MatchTest(TestBase, AssertsCompiledSQL):
 class ParseConnectTest(TestBase, AssertsCompiledSQL):
     __only_on__ = 'mssql'
 
+    def test_pyodbc_connect_dsn_trusted(self):
+        u = url.make_url('mssql://mydsn')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
+
+    def test_pyodbc_connect_old_style_dsn_trusted(self):
+        u = url.make_url('mssql:///?dsn=mydsn')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['dsn=mydsn;TrustedConnection=Yes'], {}], connection)
+
+    def test_pyodbc_connect_dsn_non_trusted(self):
+        u = url.make_url('mssql://username:password@mydsn')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['dsn=mydsn;UID=username;PWD=password'], {}], connection)
+
+    def test_pyodbc_connect_dsn_extra(self):
+        u = url.make_url('mssql://username:password@mydsn/?LANGUAGE=us_english&foo=bar')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['dsn=mydsn;UID=username;PWD=password;LANGUAGE=us_english;foo=bar'], {}], connection)
+
     def test_pyodbc_connect(self):
         u = url.make_url('mssql://username:password@hostspec/database')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
 
+    def test_pyodbc_connect_comma_port(self):
+        u = url.make_url('mssql://username:password@hostspec:12345/database')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['DRIVER={SQL Server};Server=hostspec,12345;Database=database;UID=username;PWD=password'], {}], connection)
+
+    def test_pyodbc_connect_config_port(self):
+        u = url.make_url('mssql://username:password@hostspec/database?port=12345')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;port=12345'], {}], connection)
+
     def test_pyodbc_extra_connect(self):
         u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar')
         dialect = mssql.MSSQLDialect_pyodbc()
         connection = dialect.create_connect_args(u)
         self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
 
+    def test_pyodbc_odbc_connect(self):
+        u = url.make_url('mssql:///?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+
+    def test_pyodbc_odbc_connect_with_dsn(self):
+        u = url.make_url('mssql:///?odbc_connect=dsn%3Dmydsn%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['dsn=mydsn;Database=database;UID=username;PWD=password'], {}], connection)
+
+    def test_pyodbc_odbc_connect_ignores_other_values(self):
+        u = url.make_url('mssql://userdiff:passdiff@localhost/dbdiff?odbc_connect=DRIVER%3D%7BSQL+Server%7D%3BServer%3Dhostspec%3BDatabase%3Ddatabase%3BUID%3Dusername%3BPWD%3Dpassword')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+
+
 class TypesTest(TestBase):
     __only_on__ = 'mssql'
 
@@ -443,7 +501,7 @@ class TypesTest(TestBase):
             numeric_table.insert().execute(numericcol=Decimal('1E-7'))
             numeric_table.insert().execute(numericcol=Decimal('1E-8'))
         except:
-            assert False 
+            assert False
 
 if __name__ == "__main__":
     testenv.main()
index 9b3d4cec5eab6c2618b272a1f76e9abde3ac9d34..acfe4a4b0992679ee09698650989aadcbc3521c5 100644 (file)
@@ -59,7 +59,7 @@ class QueryTest(TestBase):
 
             result = table.insert().execute(**values)
             ret = values.copy()
-
+            
             for col, id in zip(table.primary_key, result.last_inserted_ids()):
                 ret[col.key] = id