set(propkey_to_col).intersection(state_dict).difference(
mapper._pk_keys_by_table[table])
)
+ has_all_defaults = True
else:
params = {}
for propkey in set(propkey_to_col).intersection(
value, state.committed_state[propkey]) is not True:
params[col.key] = value
+ if mapper.base_mapper.eager_defaults:
+ has_all_defaults = mapper._server_onupdate_default_cols[table].\
+ issubset(params)
+ else:
+ has_all_defaults = True
+
if update_version_id is not None and \
mapper.version_id_col in mapper._cols_by_table[table]:
params.update(pk_params)
yield (
state, state_dict, params, mapper,
- connection, value_params)
+ connection, value_params, has_all_defaults)
def _collect_post_update_commands(base_mapper, uowtransaction, table,
type_=mapper.version_id_col.type))
stmt = table.update(clause)
- if mapper.base_mapper.eager_defaults:
- stmt = stmt.return_defaults()
- elif mapper.version_id_col is not None:
- stmt = stmt.return_defaults(mapper.version_id_col)
-
return stmt
statement = base_mapper._memo(('update', table), update_stmt)
- for (connection, paramkeys, hasvalue), \
+ for (connection, paramkeys, hasvalue, has_all_defaults), \
records in groupby(
update,
lambda rec: (
rec[4], # connection
set(rec[2]), # set of parameter keys
- bool(rec[5]))): # whether or not we have "value" parameters
-
+ bool(rec[5]), # whether or not we have "value" parameters
+ rec[6] # has_all_defaults
+ )
+ ):
rows = 0
records = list(records)
assert_singlerow = connection.dialect.supports_sane_rowcount
assert_multirow = assert_singlerow and \
connection.dialect.supports_sane_multi_rowcount
- allow_multirow = not needs_version_id
+ allow_multirow = has_all_defaults and not needs_version_id
+
+ if bookkeeping and mapper.base_mapper.eager_defaults:
+ statement = statement.return_defaults()
+ elif mapper.version_id_col is not None:
+ statement = statement.return_defaults(mapper.version_id_col)
if hasvalue:
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
c = connection.execute(
statement.values(value_params),
params)
if not allow_multirow:
check_rowcount = assert_singlerow
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
c = cached_connections[connection].\
execute(statement, params)
rows += c.rowcount
for state, state_dict, params, mapper, \
- connection, value_params in records:
+ connection, value_params, has_all_defaults in records:
if bookkeeping:
_postfetch(
mapper,
bool(rec[5]), # whether we have "value" parameters
rec[6],
rec[7])):
+
if not bookkeeping or \
(
has_all_defaults
from test.orm import _fixtures
from sqlalchemy import exc, util
from sqlalchemy.testing import fixtures, config
-from sqlalchemy import Integer, String, ForeignKey, func, literal
+from sqlalchemy import Integer, String, ForeignKey, func, \
+ literal, FetchedValue, text
from sqlalchemy.orm import mapper, relationship, backref, \
create_session, unitofwork, attributes,\
Session, exc as orm_exc
eq_(t1.returning_val, 5)
+class EagerDefaultsTest(fixtures.MappedTest):
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ 'test', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo', Integer, server_default="3")
+ )
+
+ Table(
+ 'test2', metadata,
+ Column('id', Integer, primary_key=True),
+ Column('foo', Integer),
+ Column('bar', Integer, server_onupdate=FetchedValue())
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class Thing(cls.Basic):
+ pass
+
+ class Thing2(cls.Basic):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ Thing = cls.classes.Thing
+
+ mapper(Thing, cls.tables.test, eager_defaults=True)
+
+ Thing2 = cls.classes.Thing2
+
+ mapper(Thing2, cls.tables.test2, eager_defaults=True)
+
+ def test_insert_defaults_present(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1, foo=5),
+ Thing(id=2, foo=10)
+ )
+
+ s.add_all([t1, t2])
+
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (:id, :foo)",
+ [{'foo': 5, 'id': 1}, {'foo': 10, 'id': 2}]
+ ),
+ )
+
+ def go():
+ eq_(t1.foo, 5)
+ eq_(t2.foo, 10)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_insert_defaults_present_as_expr(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1, foo=text("2 + 5")),
+ Thing(id=2, foo=text("5 + 5"))
+ )
+
+ s.add_all([t1, t2])
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) "
+ "RETURNING test.foo",
+ [{'id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) "
+ "RETURNING test.foo",
+ [{'id': 2}],
+ dialect='postgresql'
+ )
+ )
+
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (:id, 2 + 5)",
+ [{'id': 1}]
+ ),
+ CompiledSQL(
+ "INSERT INTO test (id, foo) VALUES (:id, 5 + 5)",
+ [{'id': 2}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 2}]
+ ),
+ )
+
+ def go():
+ eq_(t1.foo, 7)
+ eq_(t2.foo, 10)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_insert_defaults_nonpresent(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ t1, t2 = (
+ Thing(id=1),
+ Thing(id=2)
+ )
+
+ s.add_all([t1, t2])
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+ [{'id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (%(id)s) RETURNING test.foo",
+ [{'id': 2}],
+ dialect='postgresql'
+ ),
+ )
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (:id)",
+ [{'id': 1}, {'id': 2}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test.foo AS test_foo FROM test "
+ "WHERE test.id = :param_1",
+ [{'param_1': 2}]
+ )
+ )
+
+ def test_update_defaults_nonpresent(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2, t3, t4 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3),
+ Thing2(id=3, foo=3, bar=4),
+ Thing2(id=4, foo=4, bar=5)
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ t1.foo = 5
+ t2.foo = 6
+ t2.bar = 10
+ t3.foo = 7
+ t4.foo = 8
+ t4.bar = 12
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 5, 'test2_id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 7, 'test2_id': 3}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 8, 'bar': 12, 'test2_id': 4}],
+ dialect='postgresql'
+ ),
+ )
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 5, 'test2_id': 1}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 7, 'test2_id': 3}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 8, 'bar': 12, 'test2_id': 4}],
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 3}]
+ )
+ )
+
+ def go():
+ eq_(t1.bar, 2)
+ eq_(t2.bar, 10)
+ eq_(t3.bar, 4)
+ eq_(t4.bar, 12)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_update_defaults_present_as_expr(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2, t3, t4 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3),
+ Thing2(id=3, foo=3, bar=4),
+ Thing2(id=4, foo=4, bar=5)
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ t1.foo = 5
+ t1.bar = text("1 + 1")
+ t2.foo = 6
+ t2.bar = 10
+ t3.foo = 7
+ t4.foo = 8
+ t4.bar = text("5 + 7")
+
+ if testing.db.dialect.implicit_returning:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=1 + 1 "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 5, 'test2_id': 1}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test2.id = %(test2_id)s",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s "
+ "WHERE test2.id = %(test2_id)s "
+ "RETURNING test2.bar",
+ [{'foo': 7, 'test2_id': 3}],
+ dialect='postgresql'
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=%(foo)s, bar=5 + 7 "
+ "WHERE test2.id = %(test2_id)s RETURNING test2.bar",
+ [{'foo': 8, 'test2_id': 4}],
+ dialect='postgresql'
+ ),
+ )
+ else:
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=1 + 1 "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 5, 'test2_id': 1}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}],
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 7, 'test2_id': 3}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=5 + 7 "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 8, 'test2_id': 4}],
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 1}]
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 3}]
+ ),
+ CompiledSQL(
+ "SELECT test2.bar AS test2_bar FROM test2 "
+ "WHERE test2.id = :param_1",
+ [{'param_1': 4}]
+ )
+ )
+
+ def go():
+ eq_(t1.bar, 2)
+ eq_(t2.bar, 10)
+ eq_(t3.bar, 4)
+ eq_(t4.bar, 12)
+
+ self.assert_sql_count(testing.db, go, 0)
+
+ def test_insert_defaults_bulk_insert(self):
+ Thing = self.classes.Thing
+ s = Session()
+
+ mappings = [
+ {"id": 1},
+ {"id": 2}
+ ]
+
+ self.assert_sql_execution(
+ testing.db,
+ lambda: s.bulk_insert_mappings(Thing, mappings),
+ CompiledSQL(
+ "INSERT INTO test (id) VALUES (:id)",
+ [{'id': 1}, {'id': 2}]
+ )
+ )
+
+ def test_update_defaults_bulk_update(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2, t3, t4 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3),
+ Thing2(id=3, foo=3, bar=4),
+ Thing2(id=4, foo=4, bar=5)
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ mappings = [
+ {"id": 1, "foo": 5},
+ {"id": 2, "foo": 6, "bar": 10},
+ {"id": 3, "foo": 7},
+ {"id": 4, "foo": 8}
+ ]
+
+ self.assert_sql_execution(
+ testing.db,
+ lambda: s.bulk_update_mappings(Thing2, mappings),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 5, 'test2_id': 1}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo, bar=:bar "
+ "WHERE test2.id = :test2_id",
+ [{'foo': 6, 'bar': 10, 'test2_id': 2}]
+ ),
+ CompiledSQL(
+ "UPDATE test2 SET foo=:foo WHERE test2.id = :test2_id",
+ [{'foo': 7, 'test2_id': 3}, {'foo': 8, 'test2_id': 4}]
+ )
+ )
+
+ def test_update_defaults_present(self):
+ Thing2 = self.classes.Thing2
+ s = Session()
+
+ t1, t2 = (
+ Thing2(id=1, foo=1, bar=2),
+ Thing2(id=2, foo=2, bar=3)
+ )
+
+ s.add_all([t1, t2])
+ s.flush()
+
+ t1.bar = 5
+ t2.bar = 10
+
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ CompiledSQL(
+ "UPDATE test2 SET bar=%(bar)s WHERE test2.id = %(test2_id)s",
+ [{'bar': 5, 'test2_id': 1}, {'bar': 10, 'test2_id': 2}],
+ dialect='postgresql'
+ )
+ )
+
class TypeWoBoolTest(fixtures.MappedTest, testing.AssertsExecutionResults):
"""test support for custom datatypes that return a non-__bool__ value
when compared via __eq__(), eg. ticket 3469"""