]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- ensure rowcount is returned for an UPDATE with no implicit returning
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Aug 2013 21:37:59 +0000 (17:37 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sun, 25 Aug 2013 21:37:59 +0000 (17:37 -0400)
- modernize test for that
- use py3k compatible next() in test_returning/test_versioning

lib/sqlalchemy/engine/base.py
lib/sqlalchemy/testing/mock.py
test/engine/test_execute.py
test/orm/test_versioning.py
test/sql/test_returning.py

index 257eaa18a21c594f81d3e4d7d0bec266e1fcf303..735113a267f88455b1573ada0566c393448a73c8 100644 (file)
@@ -898,11 +898,10 @@ class Connection(Connectable):
             elif not context._is_explicit_returning:
                 result.close(_autoclose_connection=False)
                 result._metadata = None
-        elif context.isupdate:
-            if context._is_implicit_returning:
-                context._fetch_implicit_update_returning(result)
-                result.close(_autoclose_connection=False)
-                result._metadata = None
+        elif context.isupdate and context._is_implicit_returning:
+            context._fetch_implicit_update_returning(result)
+            result.close(_autoclose_connection=False)
+            result._metadata = None
 
         elif result._metadata is None:
             # no results, get rowcount
index 650962384a099dbabad1963f95e9039c57bf8283..fa2d477a7000bace14012014bf34ba0f66e34e7d 100644 (file)
@@ -4,10 +4,10 @@ from __future__ import absolute_import
 from ..util import py33
 
 if py33:
-    from unittest.mock import MagicMock, Mock, call
+    from unittest.mock import MagicMock, Mock, call, patch
 else:
     try:
-        from mock import MagicMock, Mock, call
+        from mock import MagicMock, Mock, call, patch
     except ImportError:
         raise ImportError(
                 "SQLAlchemy's test suite requires the "
index 1d2aebf972203a9b0d16f6e6f556d20f2437ace2..9623c080a6491b21c4904ebcc4c2496041f202f6 100644 (file)
@@ -17,9 +17,9 @@ from sqlalchemy.testing.engines import testing_engine
 import logging.handlers
 from sqlalchemy.dialects.oracle.zxjdbc import ReturningParam
 from sqlalchemy.engine import result as _result, default
-from sqlalchemy.engine.base import Connection, Engine
+from sqlalchemy.engine.base import Engine
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing.mock import Mock, call
+from sqlalchemy.testing.mock import Mock, call, patch
 
 
 users, metadata, users_autoinc = None, None, None
@@ -29,11 +29,11 @@ class ExecuteTest(fixtures.TestBase):
         global users, users_autoinc, metadata
         metadata = MetaData(testing.db)
         users = Table('users', metadata,
-            Column('user_id', INT, primary_key = True, autoincrement=False),
+            Column('user_id', INT, primary_key=True, autoincrement=False),
             Column('user_name', VARCHAR(20)),
         )
         users_autoinc = Table('users_autoinc', metadata,
-            Column('user_id', INT, primary_key = True,
+            Column('user_id', INT, primary_key=True,
                                     test_needs_autoincrement=True),
             Column('user_name', VARCHAR(20)),
         )
@@ -892,42 +892,42 @@ class ResultProxyTest(fixtures.TestBase):
     def test_no_rowcount_on_selects_inserts(self):
         """assert that rowcount is only called on deletes and updates.
 
-        This because cursor.rowcount can be expensive on some dialects
-        such as Firebird.
+        This because cursor.rowcount may can be expensive on some dialects
+        such as Firebird, however many dialects require it be called
+        before the cursor is closed.
 
         """
 
         metadata = self.metadata
 
         engine = engines.testing_engine()
-        metadata.bind = engine
 
         t = Table('t1', metadata,
             Column('data', String(10))
         )
-        metadata.create_all()
+        metadata.create_all(engine)
 
-        class BreakRowcountMixin(object):
-            @property
-            def rowcount(self):
-                assert False
+        with patch.object(engine.dialect.execution_ctx_cls, "rowcount") as mock_rowcount:
+            mock_rowcount.__get__ = Mock()
+            engine.execute(t.insert(),
+                                {'data': 'd1'},
+                                {'data': 'd2'},
+                                {'data': 'd3'})
 
-        execution_ctx_cls = engine.dialect.execution_ctx_cls
-        engine.dialect.execution_ctx_cls = type("FakeCtx",
-                                            (BreakRowcountMixin,
-                                            execution_ctx_cls),
-                                            {})
+            eq_(len(mock_rowcount.__get__.mock_calls), 0)
 
-        try:
-            r = t.insert().execute({'data': 'd1'}, {'data': 'd2'},
-                                   {'data': 'd3'})
-            eq_(t.select().execute().fetchall(), [('d1', ), ('d2', ),
-                ('d3', )])
-            assert_raises(AssertionError, t.update().execute, {'data'
-                          : 'd4'})
-            assert_raises(AssertionError, t.delete().execute)
-        finally:
-            engine.dialect.execution_ctx_cls = execution_ctx_cls
+            eq_(
+                    engine.execute(t.select()).fetchall(),
+                    [('d1', ), ('d2', ), ('d3', )]
+            )
+            eq_(len(mock_rowcount.__get__.mock_calls), 0)
+
+            engine.execute(t.update(), {'data': 'd4'})
+
+            eq_(len(mock_rowcount.__get__.mock_calls), 1)
+
+            engine.execute(t.delete())
+            eq_(len(mock_rowcount.__get__.mock_calls), 2)
 
 
     @testing.requires.python26
index d8d92830f8814d39cf9a6ceab7ed903f000b1ce2..026793c971c92264148bec6b3b7907a87837bd1e 100644 (file)
@@ -668,7 +668,7 @@ class ServerVersioningTest(fixtures.MappedTest):
             if hasattr(stmt, "_counter"):
                 return stmt._counter
             else:
-                stmt._counter = str(counter.next())
+                stmt._counter = str(next(counter))
                 return stmt._counter
 
         Table('version_table', metadata,
index 179d2d26107992d64ced46892f694a23500bd952..19f5d26c053554e85822a47dbc29421349a95e84 100644 (file)
@@ -201,7 +201,7 @@ class ReturnDefaultsTest(fixtures.TablesTest):
 
         @compiles(IncDefault)
         def compile(element, compiler, **kw):
-            return str(counter.next())
+            return str(next(counter))
 
         Table("t1", metadata,
                 Column("id", Integer, primary_key=True, test_needs_autoincrement=True),