]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- the execution sequence pulls all rowcount/last inserted ID
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Feb 2010 23:51:54 +0000 (23:51 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 28 Feb 2010 23:51:54 +0000 (23:51 +0000)
info from the cursor before commit() is called on the
DBAPI connection in an "autocommit" scenario.  This helps
mxodbc with rowcount and is probably a good idea overall.
- cx_oracle wants list(), not tuple(), for empty execute.
- cleaned up plain SQL param handling

CHANGES
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mssql/mxodbc.py
lib/sqlalchemy/dialects/oracle/cx_oracle.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
test/engine/test_execute.py
test/sql/test_rowcount.py

diff --git a/CHANGES b/CHANGES
index c4495552a30204ea23fbdddc16c585321b98bf88..d7695c40f16677fd76eac43b44b7c82c6d9e29cc 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -190,6 +190,11 @@ CHANGES
     Note that it is *not* built/installed by default.
     See README for installation instructions.
 
+  - the execution sequence pulls all rowcount/last inserted ID
+    info from the cursor before commit() is called on the 
+    DBAPI connection in an "autocommit" scenario.  This helps 
+    mxodbc with rowcount and is probably a good idea overall.
+    
   - Opened up logging a bit such that isEnabledFor() is called 
     more often, so that changes to the log level for engine/pool
     will be reflected on next connect.   This adds a small 
index d1ccf44e2c7194edcfda121c7e1f5c0eba203e2e..254aa54fd3bea7ec0e92e8f019b0d66a94499720 100644 (file)
@@ -826,7 +826,11 @@ class MSExecutionContext(default.DefaultExecutionContext):
     def handle_dbapi_exception(self, e):
         if self._enable_identity_insert:
             try:
-                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % self.dialect.identifier_preparer.format_table(self.compiled.statement.table))
+                self.cursor.execute("SET IDENTITY_INSERT %s OFF" % 
+                                    self.dialect.\
+                                    identifier_preparer.\
+                                    format_table(self.compiled.statement.table)
+                                )
             except:
                 pass
 
index 73cf1346e02082e5f11e9946e85cfa8c094ade87..85c0ac6ac4d38e9041d0923329f9a670d478425e 100644 (file)
@@ -9,19 +9,7 @@ from sqlalchemy.dialects.mssql.pyodbc import MSExecutionContext_pyodbc
 # The pyodbc execution context seems to work for mxODBC; reuse it here
 
 class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc):
-    
-    def post_exec(self):
-        # snag rowcount before the cursor is closed
-        if not self.cursor.description:
-            self._rowcount = self.cursor.rowcount
-        super(MSExecutionContext_mxodbc, self).post_exec()
-        
-    @property
-    def rowcount(self):
-        if hasattr(self, '_rowcount'):
-            return self._rowcount
-        else:
-            return self.cursor.rowcount
+    pass
 
 class MSDialect_mxodbc(MxODBCConnector, MSDialect):
 
index 854eb875a8fa71257bf3d4217bef060cc13e030d..47909f8d177e0ca77f9aeb858c16506c9e649d2a 100644 (file)
@@ -307,6 +307,8 @@ class Oracle_cx_oracle(OracleDialect):
     driver = "cx_oracle"
     colspecs = colspecs
     
+    execute_sequence_format = list
+    
     def __init__(self, 
                 auto_setinputsizes=True, 
                 auto_convert_lobs=True, 
index b4f2524d6e13ba22042df9045f3ae917109ec7a4..46907dfcf97ab89b3690a22c6911a0772539a649 100644 (file)
@@ -77,6 +77,10 @@ class Dialect(object):
     execution_ctx_cls
       a :class:`ExecutionContext` class used to handle statement execution
 
+    execute_sequence_format
+      either the 'tuple' or 'list' type, depending on what cursor.execute()
+      accepts for the second argument (they vary).
+      
     preparer
       a :class:`~sqlalchemy.sql.compiler.IdentifierPreparer` class used to
       quote identifiers.
@@ -1055,6 +1059,7 @@ class Connection(Connectable):
 
         In the case of 'raw' execution which accepts positional parameters,
         it may be a list of tuples or lists.
+        
         """
 
         if not multiparams:
@@ -1104,7 +1109,9 @@ class Connection(Connectable):
             keys = []
 
         context = self.__create_execution_context(
-                        compiled_sql=elem.compile(dialect=self.dialect, column_keys=keys, inline=len(params) > 1),
+                        compiled_sql=elem.compile(
+                                        dialect=self.dialect, column_keys=keys, 
+                                        inline=len(params) > 1),
                         parameters=params
                     )
         return self.__execute_context(context)
@@ -1128,9 +1135,15 @@ class Connection(Connectable):
             context.pre_exec()
             
         if context.executemany:
-            self._cursor_executemany(context.cursor, context.statement, context.parameters, context=context)
+            self._cursor_executemany(
+                            context.cursor, 
+                            context.statement, 
+                            context.parameters, context=context)
         else:
-            self._cursor_execute(context.cursor, context.statement, context.parameters[0], context=context)
+            self._cursor_execute(
+                            context.cursor, 
+                            context.statement, 
+                            context.parameters[0], context=context)
             
         if context.compiled:
             context.post_exec()
@@ -1138,10 +1151,17 @@ class Connection(Connectable):
             if context.isinsert and not context.executemany:
                 context.post_insert()
         
+        # create a resultproxy, get rowcount/implicit RETURNING
+        # rows, close cursor if no further results pending
+        r = context.get_result_proxy()._autoclose()
+
         if self.__transaction is None and context.should_autocommit:
             self._commit_impl()
-            
-        return context.get_result_proxy()._autoclose()
+        
+        if r.closed and self.should_close_with_result:
+            self.close()
+        
+        return r
         
     def _handle_dbapi_exception(self, e, statement, parameters, cursor, context):
         if getattr(self, '_reentrant_error', False):
@@ -1893,6 +1913,7 @@ class ResultProxy(object):
 
     _process_row = RowProxy
     out_parameters = None
+    _can_close_connection = False
     
     def __init__(self, context):
         self.context = context
@@ -1904,7 +1925,6 @@ class ResultProxy(object):
                         context.engine._should_log_debug()
         self._init_metadata()
 
-    
     def _init_metadata(self):
         metadata = self._cursor_description()
         if metadata is None:
@@ -1962,21 +1982,26 @@ class ResultProxy(object):
         return self.cursor.description
             
     def _autoclose(self):
+        """called by the Connection to autoclose cursors that have no pending results
+        beyond those used by an INSERT/UPDATE/DELETE with no explicit RETURNING clause.
+        
+        """
         if self.context.isinsert:
             if self.context._is_implicit_returning:
                 self.context._fetch_implicit_returning(self)
-                self.close()
+                self.close(_autoclose_connection=False)
             elif not self.context._is_explicit_returning:
-                self.close()
+                self.close(_autoclose_connection=False)
         elif self._metadata is None:
             # no results, get rowcount 
-            # (which requires open cursor on some DB's such as firebird),
+            # (which requires open cursor on some drivers
+            # such as kintersbasdb, mxodbc),
             self.rowcount
-            self.close() # autoclose
-            
+            self.close(_autoclose_connection=False)
+        
         return self
-            
-    def close(self):
+    
+    def close(self, _autoclose_connection=True):
         """Close this ResultProxy.
 
         Closes the underlying DBAPI cursor corresponding to the execution.
@@ -1992,12 +2017,14 @@ class ResultProxy(object):
 
         * all result rows are exhausted using the fetchXXX() methods.
         * cursor.description is None.
+        
         """
 
         if not self.closed:
             self.closed = True
             self.cursor.close()
-            if self.connection.should_close_with_result:
+            if _autoclose_connection and \
+                self.connection.should_close_with_result:
                 self.connection.close()
 
     def __iter__(self):
index 4d4fd7c71969657f979bba3c6fdea480f7d896f2..cd2c103938bcd50aae90e2ae0e9f72d39fce5e63 100644 (file)
@@ -30,6 +30,10 @@ class DefaultDialect(base.Dialect):
     preparer = compiler.IdentifierPreparer
     supports_alter = True
 
+    # most DBAPIs happy with this for execute().
+    # not cx_oracle.  
+    execute_sequence_format = tuple
+    
     supports_sequences = False
     sequences_optional = False
     preexecute_autoincrement_sequences = False
@@ -365,7 +369,7 @@ class DefaultExecutionContext(base.ExecutionContext):
     @util.memoized_property
     def _default_params(self):
         if self.dialect.positional:
-            return ()
+            return self.dialect.execute_sequence_format()
         else:
             return {}
         
@@ -392,21 +396,23 @@ class DefaultExecutionContext(base.ExecutionContext):
         """Apply string encoding to the keys of dictionary-based bind parameters.
 
         This is only used executing textual, non-compiled SQL expressions.
+        
         """
-
-        if self.dialect.positional or self.dialect.supports_unicode_statements:
-            if params:
+        
+        if not params:
+            return [self._default_params]
+        elif isinstance(params[0], self.dialect.execute_sequence_format):
+            return params
+        elif isinstance(params[0], dict):
+            if self.dialect.supports_unicode_statements:
                 return params
             else:
-                return [self._default_params]
+                def proc(d):
+                    return dict((k.encode(self.dialect.encoding), d[k]) for k in d)
+                return [proc(d) for d in params] or [{}]
         else:
-            def proc(d):
-                # sigh, sometimes we get positional arguments with a dialect
-                # that doesnt specify positional (because of execute_text())
-                if not isinstance(d, dict):
-                    return d
-                return dict((k.encode(self.dialect.encoding), d[k]) for k in d)
-            return [proc(d) for d in params] or [{}]
+            return [self.dialect.execute_sequence_format(p) for p in params]
+        
 
     def __convert_compiled_params(self, compiled_parameters):
         """Convert the dictionary of bind parameter values into a dict or list
@@ -423,7 +429,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                         param.append(processors[key](compiled_params[key]))
                     else:
                         param.append(compiled_params[key])
-                parameters.append(tuple(param))
+                parameters.append(self.dialect.execute_sequence_format(param))
         else:
             encode = not self.dialect.supports_unicode_statements
             for compiled_params in compiled_parameters:
@@ -442,7 +448,7 @@ class DefaultExecutionContext(base.ExecutionContext):
                         else:
                             param[key] = compiled_params[key]
                 parameters.append(param)
-        return tuple(parameters)
+        return self.dialect.execute_sequence_format(parameters)
 
     def should_autocommit_text(self, statement):
         return AUTOCOMMIT_REGEXP.match(statement)
@@ -514,7 +520,7 @@ class DefaultExecutionContext(base.ExecutionContext):
             
     def _fetch_implicit_returning(self, resultproxy):
         table = self.compiled.statement.table
-        row = resultproxy.first()
+        row = resultproxy.fetchone()
 
         self._inserted_primary_key = [v is not None and v or row[c] 
             for c, v in zip(table.primary_key, self._inserted_primary_key)
index 4a45fceb311ca82bec6c70b9693dd71d32e99f56..1752fda0ddc3d099431e36582b3590c177db10e2 100644 (file)
@@ -12,9 +12,13 @@ users, metadata = None, None
 class ExecuteTest(TestBase):
     @classmethod
     def setup_class(cls):
-        global users, metadata
+        global users, users_autoinc, metadata
         metadata = MetaData(testing.db)
         users = Table('users', metadata,
+            Column('user_id', INT, primary_key = True, autoincrement=False),
+            Column('user_name', VARCHAR(20)),
+        )
+        users_autoinc = Table('users_autoinc', metadata,
             Column('user_id', INT, primary_key = True, test_needs_autoincrement=True),
             Column('user_name', VARCHAR(20)),
         )
@@ -28,16 +32,22 @@ class ExecuteTest(TestBase):
     def teardown_class(cls):
         metadata.drop_all()
 
-    @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', 'mysql+pyodbc', '+zxjdbc', 'mysql+oursql')
+    @testing.fails_on_everything_except('firebird', 'maxdb', 'sqlite', '+pyodbc', '+mxodbc', '+zxjdbc', 'mysql+oursql')
     def test_raw_qmark(self):
         for conn in (testing.db, testing.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (?, ?)", (1,"jack"))
             conn.execute("insert into users (user_id, user_name) values (?, ?)", [2,"fred"])
-            conn.execute("insert into users (user_id, user_name) values (?, ?)", [3,"ed"], [4,"horse"])
-            conn.execute("insert into users (user_id, user_name) values (?, ?)", (5,"barney"), (6,"donkey"))
+            conn.execute("insert into users (user_id, user_name) values (?, ?)", 
+                                                                                [3,"ed"],
+                                                                                [4,"horse"])
+            conn.execute("insert into users (user_id, user_name) values (?, ?)", 
+                                                                (5,"barney"), (6,"donkey"))
             conn.execute("insert into users (user_id, user_name) values (?, ?)", 7, 'sally')
             res = conn.execute("select * from users order by user_id")
-            assert res.fetchall() == [(1, "jack"), (2, "fred"), (3, "ed"), (4, "horse"), (5, "barney"), (6, "donkey"), (7, 'sally')]
+            assert res.fetchall() == [(1, "jack"), (2, "fred"), 
+                                        (3, "ed"), (4, "horse"), 
+                                        (5, "barney"), (6, "donkey"), 
+                                        (7, 'sally')]
             conn.execute("delete from users")
 
     @testing.fails_on_everything_except('mysql+mysqldb', 'postgresql')
@@ -46,11 +56,15 @@ class ExecuteTest(TestBase):
     def test_raw_sprintf(self):
         for conn in (testing.db, testing.db.connect()):
             conn.execute("insert into users (user_id, user_name) values (%s, %s)", [1,"jack"])
-            conn.execute("insert into users (user_id, user_name) values (%s, %s)", [2,"ed"], [3,"horse"])
+            conn.execute("insert into users (user_id, user_name) values (%s, %s)", 
+                                                                            [2,"ed"], 
+                                                                            [3,"horse"])
             conn.execute("insert into users (user_id, user_name) values (%s, %s)", 4, 'sally')
             conn.execute("insert into users (user_id) values (%s)", 5)
             res = conn.execute("select * from users order by user_id")
-            assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally'), (5, None)]
+            assert res.fetchall() == [(1, "jack"), (2, "ed"), 
+                                        (3, "horse"), (4, 'sally'), 
+                                        (5, None)]
             conn.execute("delete from users")
 
     # pyformat is supported for mysql, but skipping because a few driver
@@ -59,9 +73,12 @@ class ExecuteTest(TestBase):
     @testing.fails_on_everything_except('postgresql+psycopg2', 'postgresql+pypostgresql')
     def test_raw_python(self):
         for conn in (testing.db, testing.db.connect()):
-            conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':1, 'name':'jack'})
-            conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'})
-            conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", id=4, name='sally')
+            conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)",
+                                    {'id':1, 'name':'jack'})
+            conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)",
+                                {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'})
+            conn.execute("insert into users (user_id, user_name) values (%(id)s, %(name)s)", 
+                                id=4, name='sally')
             res = conn.execute("select * from users order by user_id")
             assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
             conn.execute("delete from users")
@@ -69,9 +86,12 @@ class ExecuteTest(TestBase):
     @testing.fails_on_everything_except('sqlite', 'oracle+cx_oracle')
     def test_raw_named(self):
         for conn in (testing.db, testing.db.connect()):
-            conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':1, 'name':'jack'})
-            conn.execute("insert into users (user_id, user_name) values (:id, :name)", {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'})
-            conn.execute("insert into users (user_id, user_name) values (:id, :name)", id=4, name='sally')
+            conn.execute("insert into users (user_id, user_name) values (:id, :name)", 
+                                            {'id':1, 'name':'jack'})
+            conn.execute("insert into users (user_id, user_name) values (:id, :name)", 
+                                            {'id':2, 'name':'ed'}, {'id':3, 'name':'horse'})
+            conn.execute("insert into users (user_id, user_name) values (:id, :name)", 
+                                            id=4, name='sally')
             res = conn.execute("select * from users order by user_id")
             assert res.fetchall() == [(1, "jack"), (2, "ed"), (3, "horse"), (4, 'sally')]
             conn.execute("delete from users")
@@ -86,8 +106,8 @@ class ExecuteTest(TestBase):
 
     def test_empty_insert(self):
         """test that execute() interprets [] as a list with no params"""
-        result = testing.db.execute(users.insert().values(user_name=bindparam('name')), [])
-        eq_(testing.db.execute(users.select()).fetchall(), [
+        result = testing.db.execute(users_autoinc.insert().values(user_name=bindparam('name')), [])
+        eq_(testing.db.execute(users_autoinc.select()).fetchall(), [
             (1, None)
         ])
 
@@ -124,17 +144,25 @@ class ProxyConnectionTest(TestBase):
 
         for engine in (
             engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy())),
-            engines.testing_engine(options=dict(implicit_returning=False, proxy=MyProxy(), strategy='threadlocal'))
+            engines.testing_engine(options=dict(
+                                                    implicit_returning=False, 
+                                                    proxy=MyProxy(), 
+                                                    strategy='threadlocal'))
         ):
             m = MetaData(engine)
 
-            t1 = Table('t1', m, Column('c1', Integer, primary_key=True), Column('c2', String(50), default=func.lower('Foo'), primary_key=True))
+            t1 = Table('t1', m, 
+                    Column('c1', Integer, primary_key=True), 
+                    Column('c2', String(50), default=func.lower('Foo'), primary_key=True)
+            )
 
             m.create_all()
             try:
                 t1.insert().execute(c1=5, c2='some data')
                 t1.insert().execute(c1=6)
-                assert engine.execute("select * from t1").fetchall() == [(5, 'some data'), (6, 'foo')]
+                eq_(engine.execute("select * from t1").fetchall(),
+                    [(5, 'some data'), (6, 'foo')]
+                )
             finally:
                 m.drop_all()
             
@@ -165,7 +193,8 @@ class ProxyConnectionTest(TestBase):
                 cursor = [
                     ("CREATE TABLE t1", {}, ()),
                     ("INSERT INTO t1 (c1, c2)", {'c2': 'some data', 'c1': 5}, (5, 'some data')),
-                    ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, insert2_params),  # bind param name 'lower_2' might be incorrect
+                    # bind param name 'lower_2' might be incorrect
+                    ("INSERT INTO t1 (c1, c2)", {'c1': 6, "lower_2":"Foo"}, insert2_params),  
                     ("select * from t1", {}, ()),
                     ("DROP TABLE t1", {}, ())
                 ]
index 6da25b9145cfb36b8a6a408da4e2adde6b548c10..9577b104a67b8aa099ca44b913096e14490436e8 100644 (file)
@@ -54,22 +54,19 @@ class FoundRowsTest(TestBase, AssertsExecutionResults):
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='Z')
         print "expecting 3, dialect reports %s" % r.rowcount
-        if testing.db.dialect.supports_sane_rowcount:
-            assert r.rowcount == 3
+        assert r.rowcount == 3
 
     def test_update_rowcount2(self):
         # WHERE matches 3, 0 rows changed
         department = employees_table.c.department
         r = employees_table.update(department=='C').execute(department='C')
         print "expecting 3, dialect reports %s" % r.rowcount
-        if testing.db.dialect.supports_sane_rowcount:
-            assert r.rowcount == 3
+        assert r.rowcount == 3
 
     def test_delete_rowcount(self):
         # WHERE matches 3, 3 rows deleted
         department = employees_table.c.department
         r = employees_table.delete(department=='C').execute()
         print "expecting 3, dialect reports %s" % r.rowcount
-        if testing.db.dialect.supports_sane_rowcount:
-            assert r.rowcount == 3
+        assert r.rowcount == 3