]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- Fixed bug in :meth:`.Update.return_defaults` which would cause all
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Dec 2015 22:24:47 +0000 (17:24 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 14 Dec 2015 22:30:21 +0000 (17:30 -0500)
insert-default holding columns not otherwise included in the SET
clause (such as primary key cols) to get rendered into the RETURNING
even though this is an UPDATE.

- Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this
flag would not be honored correctly in the case that multiple
UPDATE statements were to be emitted, either as part of a flush
or a bulk update operation.  Additionally, RETURNING
would be emitted unnecessarily within update statements.

fixes #3609

doc/build/changelog/changelog_10.rst
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/crud.py
lib/sqlalchemy/testing/assertsql.py
test/orm/test_unitofworkv2.py
test/orm/test_versioning.py
test/sql/test_returning.py

index 950046cd0eb6d0fd6d1d701bf07cc95bd1f2c6a7..974aa5f1a1f3d53af5ab8f7b5ba0cd0bf3530d95 100644 (file)
 .. changelog::
     :version: 1.0.11
 
+    .. change::
+        :tags: bug, sql
+        :tickets: 3609
+        :versions: 1.1.0b1
+
+        Fixed bug in :meth:`.Update.return_defaults` which would cause all
+        insert-default holding columns not otherwise included in the SET
+        clause (such as primary key cols) to get rendered into the RETURNING
+        even though this is an UPDATE.
+
+    .. change::
+        :tags: bug, orm
+        :tickets: 3609
+        :versions: 1.1.0b1
+
+        Major fixes to the :paramref:`.Mapper.eager_defaults` flag, this
+        flag would not be honored correctly in the case that multiple
+        UPDATE statements were to be emitted, either as part of a flush
+        or a bulk update operation.  Additionally, RETURNING
+        would be emitted unnecessarily within update statements.
+
     .. change::
         :tags: bug, orm
         :tickets: 3606
index 5ade4b966518913da0bad0eed4b2a7cbc85e607a..95aa14a26f803ff530cf509f78c70ea1cca5cc78 100644 (file)
@@ -1970,12 +1970,24 @@ class Mapper(InspectionAttr):
             (
                 table,
                 frozenset([
-                    col for col in columns
+                    col.key for col in columns
                     if col.server_default is not None])
             )
             for table, columns in self._cols_by_table.items()
         )
 
+    @_memoized_configured_property
+    def _server_onupdate_default_cols(self):
+        return dict(
+            (
+                table,
+                frozenset([
+                    col.key for col in columns
+                    if col.server_onupdate is not None])
+            )
+            for table, columns in self._cols_by_table.items()
+        )
+
     @property
     def selectable(self):
         """The :func:`.select` construct this :class:`.Mapper` selects from
index 768c1146a3ec7070d67a42f7a0aad24bad8e687d..88c96e94ce169387f5561ea2a2f5e0b4614b9fa8 100644 (file)
@@ -448,6 +448,7 @@ def _collect_update_commands(
                 set(propkey_to_col).intersection(state_dict).difference(
                     mapper._pk_keys_by_table[table])
             )
+            has_all_defaults = True
         else:
             params = {}
             for propkey in set(propkey_to_col).intersection(
@@ -463,6 +464,12 @@ def _collect_update_commands(
                         value, state.committed_state[propkey]) is not True:
                     params[col.key] = value
 
+            if mapper.base_mapper.eager_defaults:
+                has_all_defaults = mapper._server_onupdate_default_cols[table].\
+                    issubset(params)
+            else:
+                has_all_defaults = True
+
         if update_version_id is not None and \
                 mapper.version_id_col in mapper._cols_by_table[table]:
 
@@ -529,7 +536,7 @@ def _collect_update_commands(
             params.update(pk_params)
             yield (
                 state, state_dict, params, mapper,
-                connection, value_params)
+                connection, value_params, has_all_defaults)
 
 
 def _collect_post_update_commands(base_mapper, uowtransaction, table,
@@ -619,23 +626,20 @@ def _emit_update_statements(base_mapper, uowtransaction,
                     type_=mapper.version_id_col.type))
 
         stmt = table.update(clause)
-        if mapper.base_mapper.eager_defaults:
-            stmt = stmt.return_defaults()
-        elif mapper.version_id_col is not None:
-            stmt = stmt.return_defaults(mapper.version_id_col)
-
         return stmt
 
     statement = base_mapper._memo(('update', table), update_stmt)
 
-    for (connection, paramkeys, hasvalue), \
+    for (connection, paramkeys, hasvalue, has_all_defaults), \
         records in groupby(
             update,
             lambda rec: (
                 rec[4],  # connection
                 set(rec[2]),  # set of parameter keys
-                bool(rec[5]))):  # whether or not we have "value" parameters
-
+                bool(rec[5]),  # whether or not we have "value" parameters
+                rec[6]  # has_all_defaults
+            )
+    ):
         rows = 0
         records = list(records)
 
@@ -645,11 +649,16 @@ def _emit_update_statements(base_mapper, uowtransaction,
         assert_singlerow = connection.dialect.supports_sane_rowcount
         assert_multirow = assert_singlerow and \
             connection.dialect.supports_sane_multi_rowcount
-        allow_multirow = not needs_version_id
+        allow_multirow = has_all_defaults and not needs_version_id
+
+        if bookkeeping and mapper.base_mapper.eager_defaults:
+            statement = statement.return_defaults()
+        elif mapper.version_id_col is not None:
+            statement = statement.return_defaults(mapper.version_id_col)
 
         if hasvalue:
             for state, state_dict, params, mapper, \
-                    connection, value_params in records:
+                    connection, value_params, has_all_defaults in records:
                 c = connection.execute(
                     statement.values(value_params),
                     params)
@@ -669,7 +678,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
             if not allow_multirow:
                 check_rowcount = assert_singlerow
                 for state, state_dict, params, mapper, \
-                        connection, value_params in records:
+                        connection, value_params, has_all_defaults in records:
                     c = cached_connections[connection].\
                         execute(statement, params)
 
@@ -699,7 +708,7 @@ def _emit_update_statements(base_mapper, uowtransaction,
                 rows += c.rowcount
 
                 for state, state_dict, params, mapper, \
-                        connection, value_params in records:
+                        connection, value_params, has_all_defaults in records:
                     if bookkeeping:
                         _postfetch(
                             mapper,
@@ -741,6 +750,7 @@ def _emit_insert_statements(base_mapper, uowtransaction,
                 bool(rec[5]),  # whether we have "value" parameters
                 rec[6],
                 rec[7])):
+
         if not bookkeeping or \
                 (
                     has_all_defaults
index 18b96018d468fed2d4ae6f208a9b6f1a2aa3bdfd..c5495ccde5cf71ec5648d3a86ef9769001299f6e 100644 (file)
@@ -493,6 +493,7 @@ def _append_param_update(
         else:
             compiler.postfetch.append(c)
     elif implicit_return_defaults and \
+            stmt._return_defaults is not True and \
             c in implicit_return_defaults:
         compiler.returning.append(c)
 
index 24349360710b01e07dabc8dd3310727901c28803..39d0789855705f314c08e3a7d4f79742afc6ea16 100644 (file)
@@ -13,6 +13,7 @@ import contextlib
 from .. import event
 from sqlalchemy.schema import _DDLCompiles
 from sqlalchemy.engine.util import _distill_params
+from sqlalchemy.engine import url
 
 
 class AssertRule(object):
@@ -58,16 +59,25 @@ class CursorSQL(SQLMatchRule):
 
 class CompiledSQL(SQLMatchRule):
 
-    def __init__(self, statement, params=None):
+    def __init__(self, statement, params=None, dialect='default'):
         self.statement = statement
         self.params = params
+        self.dialect = dialect
 
     def _compare_sql(self, execute_observed, received_statement):
         stmt = re.sub(r'[\n\t]', '', self.statement)
         return received_statement == stmt
 
     def _compile_dialect(self, execute_observed):
-        return DefaultDialect()
+        if self.dialect == 'default':
+            return DefaultDialect()
+        else:
+            # ugh
+            if self.dialect == 'postgresql':
+                params = {'implicit_returning': True}
+            else:
+                params = {}
+            return url.URL(self.dialect).get_dialect()(**params)
 
     def _received_statement(self, execute_observed):
         """reconstruct the statement and params in terms
@@ -159,7 +169,7 @@ class CompiledSQL(SQLMatchRule):
             'Testing for compiled statement %r partial params %r, '
             'received %%(received_statement)r with params '
             '%%(received_parameters)r' % (
-                self.statement, expected_params
+                self.statement.replace('%', '%%'), expected_params
             )
         )
 
@@ -170,6 +180,7 @@ class RegexSQL(CompiledSQL):
         self.regex = re.compile(regex)
         self.orig_regex = regex
         self.params = params
+        self.dialect = 'default'
 
     def _failure_message(self, expected_params):
         return (
index 09240dfdb6dcb76a1c6b3f870bae63fae44b1b53..c8ce13c913ca43d355e170abb008d14f49a16f5b 100644 (file)
@@ -5,7 +5,8 @@ from sqlalchemy.testing.schema import Table, Column
 from test.orm import _fixtures
 from sqlalchemy import exc, util
 from sqlalchemy.testing import fixtures, config
-from sqlalchemy import Integer, String, ForeignKey, func, literal
+from sqlalchemy import Integer, String, ForeignKey, func, \
+    literal, FetchedValue, text
 from sqlalchemy.orm import mapper, relationship, backref, \
     create_session, unitofwork, attributes,\
     Session, exc as orm_exc
@@ -1848,6 +1849,450 @@ class NoAttrEventInFlushTest(fixtures.MappedTest):
         eq_(t1.returning_val, 5)
 
 
+class EagerDefaultsTest(fixtures.MappedTest):
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            'test', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('foo', Integer, server_default="3")
+        )
+
+        Table(
+            'test2', metadata,
+            Column('id', Integer, primary_key=True),
+            Column('foo', Integer),
+            Column('bar', Integer, server_onupdate=FetchedValue())
+        )
+
+    @classmethod
+    def setup_classes(cls):
+        class Thing(cls.Basic):
+            pass
+
+        class Thing2(cls.Basic):
+            pass
+
+    @classmethod
+    def setup_mappers(cls):
+        Thing = cls.classes.Thing
+
+        mapper(Thing, cls.tables.test, eager_defaults=True)
+
+        Thing2 = cls.classes.Thing2
+
+        mapper(Thing2, cls.tables.test2, eager_defaults=True)
+
+    def test_insert_defaults_present(self):
+        Thing = self.classes.Thing
+        s = Session()
+
+        t1, t2 = (
+            Thing(id=1, foo=5),
+            Thing(id=2, foo=10)
+        )
+
+        s.add_all([t1, t2])
+
+        self.assert_sql_execution(
+            testing.db,
+            s.flush,
+            CompiledSQL(
+                "INSERT INTO test (id, foo) VALUES (:id, :foo)",
+                [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}]
+            ),
+        )
+
+        def go():
+            eq_(t1.foo, 5)
+            eq_(t2.foo, 10)
+
+        self.assert_sql_count(testing.db, go, 0)
+
+    def test_insert_defaults_present_as_expr(self):
+        Thing = self.classes.Thing
+        s = Session()
+
+        t1, t2 = (
+            Thing(id=1, foo=text("2 + 5")),
+            Thing(id=2, foo=text("5 + 5"))
+        )
+
+        s.add_all([t1, t2])
+
+        if testing.db.dialect.implicit_returning:
+            self.assert_sql_execution(
+                testing.db,
+                s.flush,
+                CompiledSQL(
+                    "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) "
+                    "RETURNING test.foo",
+                    [{'id': 1}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) "
+                    "RETURNING test.foo",
+                    [{'id': 2}],
+                    dialect='postgresql'
+                )
+            )
+
+        else:
+            self.assert_sql_execution(
+                testing.db,
+                s.flush,
+                CompiledSQL(
+                    "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
+                    [{'id': 1}]
+                ),
+                CompiledSQL(
+                    "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
+                    [{'id': 2}]
+                ),
+                CompiledSQL(
+                    "SELECT test.foo AS test_foo FROM test "
+                    "WHERE test.id = :param_1",
+                    [{'param_1': 1}]
+                ),
+                CompiledSQL(
+                    "SELECT test.foo AS test_foo FROM test "
+                    "WHERE test.id = :param_1",
+                    [{'param_1': 2}]
+                ),
+            )
+
+        def go():
+            eq_(t1.foo, 7)
+            eq_(t2.foo, 10)
+
+        self.assert_sql_count(testing.db, go, 0)
+
+    def test_insert_defaults_nonpresent(self):
+        Thing = self.classes.Thing
+        s = Session()
+
+        t1, t2 = (
+            Thing(id=1),
+            Thing(id=2)
+        )
+
+        s.add_all([t1, t2])
+
+        if testing.db.dialect.implicit_returning:
+            self.assert_sql_execution(
+                testing.db,
+                s.commit,
+                CompiledSQL(
+                    "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+                    [{'id': 1}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+                    [{'id': 2}],
+                    dialect='postgresql'
+                ),
+            )
+        else:
+            self.assert_sql_execution(
+                testing.db,
+                s.commit,
+                CompiledSQL(
+                    "INSERT INTO test (id) VALUES (:id)",
+                    [{'id': 1}, {'id': 2}]
+                ),
+                CompiledSQL(
+                    "SELECT test.foo AS test_foo FROM test "
+                    "WHERE test.id = :param_1",
+                    [{'param_1': 1}]
+                ),
+                CompiledSQL(
+                    "SELECT test.foo AS test_foo FROM test "
+                    "WHERE test.id = :param_1",
+                    [{'param_1': 2}]
+                )
+            )
+
+    def test_update_defaults_nonpresent(self):
+        Thing2 = self.classes.Thing2
+        s = Session()
+
+        t1, t2, t3, t4 = (
+            Thing2(id=1, foo=1, bar=2),
+            Thing2(id=2, foo=2, bar=3),
+            Thing2(id=3, foo=3, bar=4),
+            Thing2(id=4, foo=4, bar=5)
+        )
+
+        s.add_all([t1, t2, t3, t4])
+        s.flush()
+
+        t1.foo = 5
+        t2.foo = 6
+        t2.bar = 10
+        t3.foo = 7
+        t4.foo = 8
+        t4.bar = 12
+
+        if testing.db.dialect.implicit_returning:
+            self.assert_sql_execution(
+                testing.db,
+                s.flush,
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s "
+                    "WHERE test2.id = %(test2_id)s "
+                    "RETURNING test2.bar",
+                    [{'foo': 5, 'test2_id': 1}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+                    "WHERE test2.id = %(test2_id)s",
+                    [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s "
+                    "WHERE test2.id = %(test2_id)s "
+                    "RETURNING test2.bar",
+                    [{'foo': 7, 'test2_id': 3}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+                    "WHERE test2.id = %(test2_id)s",
+                    [{'foo': 8, 'bar': 12, 'test2_id': 4}],
+                    dialect='postgresql'
+                ),
+            )
+        else:
+            self.assert_sql_execution(
+                testing.db,
+                s.flush,
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+                    [{'foo': 5, 'test2_id': 1}]
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo, bar=:bar "
+                    "WHERE test2.id = :test2_id",
+                    [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+                    [{'foo': 7, 'test2_id': 3}]
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo, bar=:bar "
+                    "WHERE test2.id = :test2_id",
+                    [{'foo': 8, 'bar': 12, 'test2_id': 4}],
+                ),
+                CompiledSQL(
+                    "SELECT test2.bar AS test2_bar FROM test2 "
+                    "WHERE test2.id = :param_1",
+                    [{'param_1': 1}]
+                ),
+                CompiledSQL(
+                    "SELECT test2.bar AS test2_bar FROM test2 "
+                    "WHERE test2.id = :param_1",
+                    [{'param_1': 3}]
+                )
+            )
+
+        def go():
+            eq_(t1.bar, 2)
+            eq_(t2.bar, 10)
+            eq_(t3.bar, 4)
+            eq_(t4.bar, 12)
+
+        self.assert_sql_count(testing.db, go, 0)
+
+    def test_update_defaults_present_as_expr(self):
+        Thing2 = self.classes.Thing2
+        s = Session()
+
+        t1, t2, t3, t4 = (
+            Thing2(id=1, foo=1, bar=2),
+            Thing2(id=2, foo=2, bar=3),
+            Thing2(id=3, foo=3, bar=4),
+            Thing2(id=4, foo=4, bar=5)
+        )
+
+        s.add_all([t1, t2, t3, t4])
+        s.flush()
+
+        t1.foo = 5
+        t1.bar = text("1 + 1")
+        t2.foo = 6
+        t2.bar = 10
+        t3.foo = 7
+        t4.foo = 8
+        t4.bar = text("5 + 7")
+
+        if testing.db.dialect.implicit_returning:
+            self.assert_sql_execution(
+                testing.db,
+                s.flush,
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s, bar=1 + 1 "
+                    "WHERE test2.id = %(test2_id)s "
+                    "RETURNING test2.bar",
+                    [{'foo': 5, 'test2_id': 1}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+                    "WHERE test2.id = %(test2_id)s",
+                    [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s "
+                    "WHERE test2.id = %(test2_id)s "
+                    "RETURNING test2.bar",
+                    [{'foo': 7, 'test2_id': 3}],
+                    dialect='postgresql'
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=%(foo)s, bar=5 + 7 "
+                    "WHERE test2.id = %(test2_id)s RETURNING test2.bar",
+                    [{'foo': 8, 'test2_id': 4}],
+                    dialect='postgresql'
+                ),
+            )
+        else:
+            self.assert_sql_execution(
+                testing.db,
+                s.flush,
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo, bar=1 + 1 "
+                    "WHERE test2.id = :test2_id",
+                    [{'foo': 5, 'test2_id': 1}]
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo, bar=:bar "
+                    "WHERE test2.id = :test2_id",
+                    [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+                    [{'foo': 7, 'test2_id': 3}]
+                ),
+                CompiledSQL(
+                    "UPDATE test2 SET foo=:foo, bar=5 + 7 "
+                    "WHERE test2.id = :test2_id",
+                    [{'foo': 8, 'test2_id': 4}],
+                ),
+                CompiledSQL(
+                    "SELECT test2.bar AS test2_bar FROM test2 "
+                    "WHERE test2.id = :param_1",
+                    [{'param_1': 1}]
+                ),
+                CompiledSQL(
+                    "SELECT test2.bar AS test2_bar FROM test2 "
+                    "WHERE test2.id = :param_1",
+                    [{'param_1': 3}]
+                ),
+                CompiledSQL(
+                    "SELECT test2.bar AS test2_bar FROM test2 "
+                    "WHERE test2.id = :param_1",
+                    [{'param_1': 4}]
+                )
+            )
+
+        def go():
+            eq_(t1.bar, 2)
+            eq_(t2.bar, 10)
+            eq_(t3.bar, 4)
+            eq_(t4.bar, 12)
+
+        self.assert_sql_count(testing.db, go, 0)
+
+    def test_insert_defaults_bulk_insert(self):
+        Thing = self.classes.Thing
+        s = Session()
+
+        mappings = [
+            {"id": 1},
+            {"id": 2}
+        ]
+
+        self.assert_sql_execution(
+            testing.db,
+            lambda: s.bulk_insert_mappings(Thing, mappings),
+            CompiledSQL(
+                "INSERT INTO test (id) VALUES (:id)",
+                [{'id': 1}, {'id': 2}]
+            )
+        )
+
+    def test_update_defaults_bulk_update(self):
+        Thing2 = self.classes.Thing2
+        s = Session()
+
+        t1, t2, t3, t4 = (
+            Thing2(id=1, foo=1, bar=2),
+            Thing2(id=2, foo=2, bar=3),
+            Thing2(id=3, foo=3, bar=4),
+            Thing2(id=4, foo=4, bar=5)
+        )
+
+        s.add_all([t1, t2, t3, t4])
+        s.flush()
+
+        mappings = [
+            {"id": 1, "foo": 5},
+            {"id": 2, "foo": 6, "bar": 10},
+            {"id": 3, "foo": 7},
+            {"id": 4, "foo": 8}
+        ]
+
+        self.assert_sql_execution(
+            testing.db,
+            lambda: s.bulk_update_mappings(Thing2, mappings),
+            CompiledSQL(
+                "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+                [{'foo': 5, 'test2_id': 1}]
+            ),
+            CompiledSQL(
+                "UPDATE test2 SET foo=:foo, bar=:bar "
+                "WHERE test2.id = :test2_id",
+                [{'foo': 6, 'bar': 10, 'test2_id': 2}]
+            ),
+            CompiledSQL(
+                "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+                [{'foo': 7, 'test2_id': 3}, {'foo': 8, 'test2_id': 4}]
+            )
+        )
+
+    def test_update_defaults_present(self):
+        Thing2 = self.classes.Thing2
+        s = Session()
+
+        t1, t2 = (
+            Thing2(id=1, foo=1, bar=2),
+            Thing2(id=2, foo=2, bar=3)
+        )
+
+        s.add_all([t1, t2])
+        s.flush()
+
+        t1.bar = 5
+        t2.bar = 10
+
+        self.assert_sql_execution(
+            testing.db,
+            s.commit,
+            CompiledSQL(
+                "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s",
+                [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}],
+                dialect='postgresql'
+            )
+        )
+
 class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults):
     """test support for custom datatypes that return a non-__bool__ value
     when compared via __eq__(), eg. ticket 3469"""
index f42069230f4b641ebca01ad89e025342deb51f80..124053d4709a27b8f248636962738775a5e71daf 100644 (file)
@@ -894,19 +894,26 @@ class ServerVersioningTest(fixtures.MappedTest):
         class Bar(cls.Basic):
             pass
 
-    def _fixture(self, expire_on_commit=True):
+    def _fixture(self, expire_on_commit=True, eager_defaults=False):
         Foo, version_table = self.classes.Foo, self.tables.version_table
 
         mapper(
             Foo, version_table, version_id_col=version_table.c.version_id,
             version_id_generator=False,
+            eager_defaults=eager_defaults
         )
 
         s1 = Session(expire_on_commit=expire_on_commit)
         return s1
 
     def test_insert_col(self):
-        sess = self._fixture()
+        self._test_insert_col()
+
+    def test_insert_col_eager_defaults(self):
+        self._test_insert_col(eager_defaults=True)
+
+    def _test_insert_col(self, **kw):
+        sess = self._fixture(**kw)
 
         f1 = self.classes.Foo(value='f1')
         sess.add(f1)
@@ -935,7 +942,13 @@ class ServerVersioningTest(fixtures.MappedTest):
         self.assert_sql_execution(testing.db, sess.flush, *statements)
 
     def test_update_col(self):
-        sess = self._fixture()
+        self._test_update_col()
+
+    def test_update_col_eager_defaults(self):
+        self._test_update_col(eager_defaults=True)
+
+    def _test_update_col(self, **kw):
+        sess = self._fixture(**kw)
 
         f1 = self.classes.Foo(value='f1')
         sess.add(f1)
index cd9f632b9be198dc794f3ade03634273d5e96bb4..77a0c60075fba56ed3a6133a5b160623602cd8cb 100644 (file)
@@ -387,6 +387,31 @@ class ReturnDefaultsTest(fixtures.TablesTest):
             {"data": None, 'upddef': 1}
         )
 
+    def test_insert_all(self):
+        t1 = self.tables.t1
+        result = testing.db.execute(
+            t1.insert().values(upddef=1).return_defaults()
+        )
+        eq_(
+            dict(result.returned_defaults),
+            {"id": 1, "data": None, "insdef": 0}
+        )
+
+    def test_update_all(self):
+        t1 = self.tables.t1
+        testing.db.execute(
+            t1.insert().values(upddef=1)
+        )
+        result = testing.db.execute(
+            t1.update().
+            values(insdef=2).return_defaults()
+        )
+        eq_(
+            dict(result.returned_defaults),
+            {'upddef': 1}
+        )
+
+
 
 class ImplicitReturningFlag(fixtures.TestBase):
     __backend__ = True