from typing import Collection
from typing import Deque
from typing import Dict
+from typing import FrozenSet
from typing import Generic
from typing import Iterable
from typing import Iterator
)
@HasMemoized.memoized_attribute
- def _server_default_cols(self):
+ def _server_default_cols(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
return dict(
(
table,
frozenset(
[
- col.key
- for col in columns
+ col
+ for col in cast("Iterable[Column[Any]]", columns)
if col.server_default is not None
+ or (
+ col.default is not None
+ and col.default.is_clause_element
+ )
]
),
)
)
@HasMemoized.memoized_attribute
- def _server_default_plus_onupdate_propkeys(self):
- result = set()
-
- for table, columns in self._cols_by_table.items():
- for col in columns:
- if (
- col.server_default is not None
- or col.server_onupdate is not None
- ) and col in self._columntoproperty:
- result.add(self._columntoproperty[col].key)
-
- return result
-
- @HasMemoized.memoized_attribute
- def _server_onupdate_default_cols(self):
+ def _server_onupdate_default_cols(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[Column[Any]]]:
return dict(
(
table,
frozenset(
[
- col.key
- for col in columns
+ col
+ for col in cast("Iterable[Column[Any]]", columns)
if col.server_onupdate is not None
+ or (
+ col.onupdate is not None
+ and col.onupdate.is_clause_element
+ )
]
),
)
for table, columns in self._cols_by_table.items()
)
+ @HasMemoized.memoized_attribute
+ def _server_default_col_keys(self) -> Mapping[FromClause, FrozenSet[str]]:
+ return {
+ table: frozenset(col.key for col in cols if col.key is not None)
+ for table, cols in self._server_default_cols.items()
+ }
+
+ @HasMemoized.memoized_attribute
+ def _server_onupdate_default_col_keys(
+ self,
+ ) -> Mapping[FromClause, FrozenSet[str]]:
+ return {
+ table: frozenset(col.key for col in cols if col.key is not None)
+ for table, cols in self._server_onupdate_default_cols.items()
+ }
+
+ @HasMemoized.memoized_attribute
+ def _server_default_plus_onupdate_propkeys(self) -> Set[str]:
+ result: Set[str] = set()
+
+ col_to_property = self._columntoproperty
+ for table, columns in self._server_default_cols.items():
+ result.update(
+ col_to_property[col].key
+ for col in columns.intersection(col_to_property)
+ )
+ for table, columns in self._server_onupdate_default_cols.items():
+ result.update(
+ col_to_property[col].key
+ for col in columns.intersection(col_to_property)
+ )
+ return result
+
@HasMemoized.memoized_instancemethod
def __clause_element__(self):
Column("bar", Integer, server_onupdate=FetchedValue()),
)
+ Table(
+ "test3",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("foo", String(50), default=func.lower("HI")),
+ )
+
+ Table(
+ "test4",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("foo", Integer),
+ Column("bar", Integer, onupdate=text("5 + 3")),
+ )
+
@classmethod
def setup_classes(cls):
class Thing(cls.Basic):
class Thing2(cls.Basic):
pass
+ class Thing3(cls.Basic):
+ pass
+
+ class Thing4(cls.Basic):
+ pass
+
@classmethod
def setup_mappers(cls):
Thing = cls.classes.Thing
Thing2, cls.tables.test2, eager_defaults=True
)
- def test_insert_defaults_present(self):
+ Thing3 = cls.classes.Thing3
+
+ cls.mapper_registry.map_imperatively(
+ Thing3, cls.tables.test3, eager_defaults=True
+ )
+
+ Thing4 = cls.classes.Thing4
+
+ cls.mapper_registry.map_imperatively(
+ Thing4, cls.tables.test4, eager_defaults=True
+ )
+
+ def test_server_insert_defaults_present(self):
Thing = self.classes.Thing
s = fixture_session()
s.flush,
CompiledSQL(
"INSERT INTO test (id, foo) VALUES (:id, :foo)",
- [{"foo": 5, "id": 1}, {"foo": 10, "id": 2}],
+ [
+ {"foo": 5, "id": 1},
+ {"foo": 10, "id": 2},
+ ],
),
)
self.assert_sql_count(testing.db, go, 0)
- def test_insert_defaults_present_as_expr(self):
+ def test_server_insert_defaults_present_as_expr(self):
Thing = self.classes.Thing
s = fixture_session()
testing.db,
s.flush,
CompiledSQL(
- "INSERT INTO test (id, foo) VALUES (%(id)s, 2 + 5) "
+ "INSERT INTO test (id, foo) "
+ "VALUES (%(id)s, 2 + 5) "
"RETURNING test.foo",
[{"id": 1}],
dialect="postgresql",
),
CompiledSQL(
- "INSERT INTO test (id, foo) VALUES (%(id)s, 5 + 5) "
+ "INSERT INTO test (id, foo) "
+ "VALUES (%(id)s, 5 + 5) "
"RETURNING test.foo",
[{"id": 2}],
dialect="postgresql",
self.assert_sql_count(testing.db, go, 0)
- def test_insert_defaults_nonpresent(self):
+ def test_server_insert_defaults_nonpresent(self):
Thing = self.classes.Thing
s = fixture_session()
),
)
- def test_update_defaults_nonpresent(self):
+ def test_clientsql_insert_defaults_nonpresent(self):
+ Thing3 = self.classes.Thing3
+ s = fixture_session()
+
+ t1, t2 = (Thing3(id=1), Thing3(id=2))
+
+ s.add_all([t1, t2])
+
+ self.assert_sql_execution(
+ testing.db,
+ s.commit,
+ Conditional(
+ testing.db.dialect.insert_returning,
+ [
+ Conditional(
+ testing.db.dialect.insert_executemany_returning,
+ [
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (%(id)s, lower(%(lower_1)s)) "
+ "RETURNING test3.foo",
+ [{"id": 1}, {"id": 2}],
+ dialect="postgresql",
+ ),
+ ],
+ [
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (%(id)s, lower(%(lower_1)s)) "
+ "RETURNING test3.foo",
+ [{"id": 1}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (%(id)s, lower(%(lower_1)s)) "
+ "RETURNING test3.foo",
+ [{"id": 2}],
+ dialect="postgresql",
+ ),
+ ],
+ ),
+ ],
+ [
+ CompiledSQL(
+ "INSERT INTO test3 (id, foo) "
+ "VALUES (:id, lower(:lower_1))",
+ [
+ {"id": 1, "lower_1": "HI"},
+ {"id": 2, "lower_1": "HI"},
+ ],
+ ),
+ CompiledSQL(
+ "SELECT test3.foo AS test3_foo "
+ "FROM test3 WHERE test3.id = :pk_1",
+ [{"pk_1": 1}],
+ ),
+ CompiledSQL(
+ "SELECT test3.foo AS test3_foo "
+ "FROM test3 WHERE test3.id = :pk_1",
+ [{"pk_1": 2}],
+ ),
+ ],
+ ),
+ )
+
+ def test_server_update_defaults_nonpresent(self):
Thing2 = self.classes.Thing2
s = fixture_session()
self.assert_sql_count(testing.db, go, 0)
+ def test_clientsql_update_defaults_nonpresent(self):
+ Thing4 = self.classes.Thing4
+ s = fixture_session()
+
+ t1, t2, t3, t4 = (
+ Thing4(id=1, foo=1),
+ Thing4(id=2, foo=2),
+ Thing4(id=3, foo=3),
+ Thing4(id=4, foo=4),
+ )
+
+ s.add_all([t1, t2, t3, t4])
+ s.flush()
+
+ t1.foo = 5
+ t2.foo = 6
+ t2.bar = 10
+ t3.foo = 7
+ t4.foo = 8
+ t4.bar = 12
+
+ self.assert_sql_execution(
+ testing.db,
+ s.flush,
+ Conditional(
+ testing.db.dialect.update_returning,
+ [
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 "
+ "WHERE test4.id = %(test4_id)s RETURNING test4.bar",
+ [{"foo": 5, "test4_id": 1}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s "
+ "WHERE test4.id = %(test4_id)s",
+ [{"foo": 6, "bar": 10, "test4_id": 2}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=5 + 3 WHERE "
+ "test4.id = %(test4_id)s RETURNING test4.bar",
+ [{"foo": 7, "test4_id": 3}],
+ dialect="postgresql",
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=%(foo)s, bar=%(bar)s WHERE "
+ "test4.id = %(test4_id)s",
+ [{"foo": 8, "bar": 12, "test4_id": 4}],
+ dialect="postgresql",
+ ),
+ ],
+ [
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=5 + 3 "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 5, "test4_id": 1}],
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=:bar "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 6, "bar": 10, "test4_id": 2}],
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=5 + 3 "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 7, "test4_id": 3}],
+ ),
+ CompiledSQL(
+ "UPDATE test4 SET foo=:foo, bar=:bar "
+ "WHERE test4.id = :test4_id",
+ [{"foo": 8, "bar": 12, "test4_id": 4}],
+ ),
+ CompiledSQL(
+ "SELECT test4.bar AS test4_bar FROM test4 "
+ "WHERE test4.id = :pk_1",
+ [{"pk_1": 1}],
+ ),
+ CompiledSQL(
+ "SELECT test4.bar AS test4_bar FROM test4 "
+ "WHERE test4.id = :pk_1",
+ [{"pk_1": 3}],
+ ),
+ ],
+ ),
+ )
+
+ def go():
+ eq_(t1.bar, 8)
+ eq_(t2.bar, 10)
+ eq_(t3.bar, 8)
+ eq_(t4.bar, 12)
+
+ self.assert_sql_count(testing.db, go, 0)
+
def test_update_defaults_present_as_expr(self):
Thing2 = self.classes.Thing2
s = fixture_session()
#! coding:utf-8
+from __future__ import annotations
+
+from typing import Tuple
from sqlalchemy import bindparam
from sqlalchemy import Column
class InsertTest(_InsertTestBase, fixtures.TablesTest, AssertsCompiledSQL):
__dialect__ = "default"
+ @testing.combinations(
+ ((), ("z",), ()),
+ (("x",), (), ()),
+ (("x",), ("y",), ("x", "y")),
+ (("x", "y"), ("y",), ("x", "y")),
+ )
+ def test_return_defaults_generative(
+ self,
+ initial_keys: Tuple[str, ...],
+ second_keys: Tuple[str, ...],
+ expected_keys: Tuple[str, ...],
+ ):
+ t = table("foo", column("x"), column("y"), column("z"))
+
+ initial_cols = tuple(t.c[initial_keys])
+ second_cols = tuple(t.c[second_keys])
+ expected = set(t.c[expected_keys])
+
+ stmt = t.insert().return_defaults(*initial_cols)
+ eq_(stmt._return_defaults_columns, initial_cols)
+ stmt = stmt.return_defaults(*second_cols)
+ assert isinstance(stmt._return_defaults_columns, tuple)
+ eq_(set(stmt._return_defaults_columns), expected)
+
def test_binds_that_match_columns(self):
"""test bind params named after column names
replace the normal SET/VALUES generation."""