# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/python/black
- rev: 22.8.0
+ rev: 23.3.0
hooks:
- id: black
def search_for_user(session, username, email=None):
-
baked_query = bakery(lambda session: session.query(User))
baked_query += lambda q: q.filter(User.name == bindparam("username"))
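A minimal sketch of how such a baked query is then typically invoked, assuming the ``bakery``, ``session`` and ``username`` objects from the surrounding example::

    # supply concrete values for the bindparam() placeholders and execute
    result = baked_query(session).params(username=username).all()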
Base = declarative_base(cls=Base)
+
# existing mapping proceeds, Declarative will ignore any annotations
# which don't include ``Mapped[]``
class Foo(Base):
::
with engine.connect() as conn:
-
# (variable) stmt: Select[Tuple[str, int]]
stmt = select(str_col, int_col)
all the way from statement to result set::
with Session(engine) as session:
-
# (variable) stmt: Select[Tuple[int, str]]
stmt_1 = select(User.id, User.name)
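Presumably the same typing carries through to execution; a sketch, assuming the ``User`` mapping used in this example::

    # (variable) result: Result[Tuple[int, str]]
    result = session.execute(stmt_1)

    # (variable) row: Row[Tuple[int, str]] | None
    row = result.first()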
as a SELECT against two mapped classes::
with Session(engine) as session:
-
# (variable) stmt: Select[Tuple[User, Address]]
stmt_2 = select(User, Address).join_from(User, Address)
class as well as the return type expected from a statement::
with Session(engine) as session:
-
# this is in fact an Annotated type, but typing tools don't
# generally display this
str50 = Annotated[str, 50]
+
# declarative base with a type-level override, using a type that is
# expected to be used in multiple places
class Base(DeclarativeBase):
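    # plausible completion (assumption): register the str50 annotation so
    # that Mapped[str50] attributes map to String(50) columns
    type_annotation_map = {
        str50: String(50),
    }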
by mixin columns. The following mapping::
class Foo:
-
col1 = mapped_column(Integer)
col3 = mapped_column(Integer)
class Bar:
-
col2 = mapped_column(Integer)
col4 = mapped_column(Integer)
class Model(Base, Foo, Bar):
-
id = mapped_column(Integer, primary_key=True)
__tablename__ = "model"
around, as::
class Foo:
-
id = mapped_column(Integer, primary_key=True)
col1 = mapped_column(Integer)
col3 = mapped_column(Integer)
class Model(Foo, Base):
-
col2 = mapped_column(Integer)
col4 = mapped_column(Integer)
__tablename__ = "model"
class Model(Foo, Base):
-
col2 = mapped_column(Integer)
col4 = mapped_column(Integer)
__tablename__ = "model"
# which... we usually don't.
with engine.connect() as connection:
-
connection.execution_options(isolation_level="AUTOCOMMIT")
# run statement(s) in autocommit mode
# use an autocommit block
with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection:
-
# run statement in autocommit mode
    connection.execute(text("<statement>"))
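The option affects only the connection it is applied to; a rough sketch of checking that a freshly checked-out connection is back at the engine-wide default, using :meth:`_engine.Connection.get_isolation_level`::

    with engine.connect() as connection:
        # a new checkout reports the engine-level default, not AUTOCOMMIT
        print(connection.get_isolation_level())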
with conn.execution_options(stream_results=True, max_row_buffer=100).execute(
text("select * from table")
) as result:
-
for row in result:
print(f"{row}")
class JSONEncodedDict(TypeDecorator):
-
impl = VARCHAR
cache_ok = True
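The canonical :class:`.TypeDecorator` recipe then adds the two conversion hooks; a self-contained sketch of the full class::

    import json

    from sqlalchemy import VARCHAR
    from sqlalchemy.types import TypeDecorator


    class JSONEncodedDict(TypeDecorator):
        """Stores a Python dict as a JSON-encoded VARCHAR."""

        impl = VARCHAR
        cache_ok = True

        def process_bind_param(self, value, dialect):
            # serialize on the way into the database
            return json.dumps(value) if value is not None else None

        def process_result_value(self, value, dialect):
            # deserialize on the way back out
            return json.loads(value) if value is not None else None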
class RoomBooking(Base):
-
__tablename__ = "room_booking"
id: Mapped[int] = mapped_column(primary_key=True)
class EventCalendar(Base):
-
__tablename__ = "event_calendar"
id: Mapped[int] = mapped_column(primary_key=True)
class RoomBooking(Base):
-
__tablename__ = "room_booking"
room = Column(Integer(), primary_key=True)
from sqlalchemy import select
if __name__ == "__main__":
-
engine = create_engine("mysql+mysqldb://scott:tiger@localhost/test", echo_pool=True)
def do_a_thing(engine):
# also correct !
foo = relationship(Dest, foreign_keys=[Dest.foo_id, Dest.bar_id])
+
# if you're using columns from the class that you're inside of, just use the column objects !
class MyClass(Base):
foo_id = Column(...)
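    bar_id = Column(...)

    # plausible continuation (assumption): pass the column objects from this
    # class directly, rather than strings, to foreign_keys
    foo = relationship(Dest, foreign_keys=[foo_id, bar_id])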
class Mixin(MappedAsDataclass):
-
create_user: Mapped[int] = mapped_column()
update_user: Mapped[Optional[int]] = mapped_column(default=None, init=False)
class Manager(Person):
-
__mapper_args__ = {"polymorphic_identity": "manager"}
.. _mixin_inheritance_columns:
from sqlalchemy.orm import DeclarativeBase
+
# declarative base class
class Base(DeclarativeBase):
pass
class SomeClass(Base):
-
# ...
# pep-484 type will be Optional, but column will be
async def insert_objects(async_session: async_sessionmaker[AsyncSession]) -> None:
-
async with async_session() as session:
async with session.begin():
session.add_all(
async def select_and_update_objects(
async_session: async_sessionmaker[AsyncSession],
) -> None:
-
async with async_session() as session:
stmt = select(A).options(selectinload(A.bs))
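A sketch of how such a statement might then be consumed within the same session block, with the ``selectinload``-ed collection available without further awaiting::

    result = await session.execute(stmt)

    for a in result.scalars():
        # A.bs was eagerly loaded via selectinload, so plain attribute
        # access needs no additional await
        print(a, a.bs)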
as an awaitable by indicating the :attr:`_asyncio.AsyncAttrs.awaitable_attrs`
prefix::
- a1 = await (session.scalars(select(A))).one()
+ a1 = (await session.scalars(select(A))).one()
for b1 in await a1.awaitable_attrs.bs:
print(b1)
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import mapped_column
+
# declarative base class
class Base(DeclarativeBase):
pass
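A small example of the kind of mapped class such a base is then used with (table and column names assumed)::

    from typing import Optional


    class User(Base):
        __tablename__ = "user_account"

        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]
        fullname: Mapped[Optional[str]]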
>>> from sqlalchemy.orm import Session
>>> with Session(engine) as session:
- ...
... spongebob = User(
... name="spongebob",
... fullname="Spongebob Squarepants",
name = mapped_column(String)
parent_folder = relationship(
- "Folder", back_populates="child_folders",
- remote_side=[account_id, folder_id]
+ "Folder", back_populates="child_folders", remote_side=[account_id, folder_id]
)
child_folders = relationship("Folder", back_populates="parent_folder")
@event.listens_for(Session, "do_orm_execute")
def _do_orm_execute(orm_execute_state):
-
if (
orm_execute_state.is_select
and not orm_execute_state.is_column_load
self.session.commit()
def test_something_with_rollbacks(self):
-
self.session.add(Bar())
self.session.flush()
self.session.rollback()
# conn is an instance of AsyncConnection
async with engine.begin() as conn:
-
# to support SQLAlchemy DDL methods as well as legacy functions, the
# AsyncConnection.run_sync() awaitable method will pass a "sync"
# version of the AsyncConnection object to any synchronous method,
)
async with engine.connect() as conn:
-
# the default result object is the
# sqlalchemy.engine.Result object
result = await conn.execute(t1.select())
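The buffered result can then be consumed without further awaiting; a short sketch::

    # conn.execute() returns a pre-buffered Result for async Core usage,
    # so ordinary iteration needs no await
    for row in result:
        print(row)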
"""
async with async_sessionmaker() as oob_session:
-
# use AUTOCOMMIT for each connection to reduce transaction
# overhead / contention
await oob_session.connection(
async def async_main():
-
engine = create_async_engine(
"postgresql+asyncpg://scott:tiger@localhost/test",
echo=True,
# iterate through ColumnProperty objects
for col_attr in mapper.column_attrs:
-
# look at the Column mapped by the ColumnProperty
# (we look at the first column in the less common case
# of a property mapped to multiple columns at once)
@event.listens_for(col_attr, "init_scalar", retval=True, propagate=True)
def init_scalar(target, value, dict_):
-
if default.is_callable:
# the callable of ColumnDefault always accepts a context
# argument; we can pass it as None here.
if __name__ == "__main__":
-
Base = declarative_base()
event.listen(Base, "mapper_configured", configure_listener, propagate=True)
event.listen(session_factory, "do_orm_execute", self._do_orm_execute)
def _do_orm_execute(self, orm_context):
-
for opt in orm_context.user_defined_options:
if isinstance(opt, RelationshipCache):
opt = opt._process_orm_context(orm_context)
if __name__ == "__main__":
-
# set up a region based on the ScopedSessionBackend,
# pointing to the scoped_session declared in the example
# environment.
if __name__ == "__main__":
-
Base = declarative_base()
class User(HasPrivate, Base):
if __name__ == "__main__":
-
Base = declarative_base()
class Parent(HasTemporal, Base):
Base.metadata.create_all(engine)
with Session(engine) as session:
-
c = Company(
name="company1",
employees=[
Base.metadata.create_all(engine)
with Session(engine) as session:
-
c = Company(
name="company1",
employees=[
class Engineer(Person):
-
# illustrate a single-inh "conflicting" mapped_column declaration,
# where both subclasses want to share the same column that is nonetheless
# not "local" to the base class
Base.metadata.create_all(engine)
with Session(engine) as session:
-
c = Company(
name="company1",
employees=[
@classmethod
def main(cls):
-
parser = argparse.ArgumentParser("python -m examples.performance")
if cls.name is None:
session = Session(bind=engine)
for id_ in random.sample(ids, n):
-
stmt = lambdas.lambda_stmt(lambda: future_select(Customer))
stmt += lambda s: s.where(Customer.id == id_)
session.execute(stmt).scalar_one()
stmt = select(Customer.__table__).where(Customer.id == bindparam("id"))
with engine.connect() as conn:
for id_ in random.sample(ids, n):
-
row = conn.execute(stmt, {"id": id_}).first()
tuple(row)
quito.reports.append(Report(85))
async with Session() as sess:
-
sess.add_all(
[tokyo, newyork, toronto, london, dublin, brasilia, quito]
)
quito.reports.append(Report(85))
with Session() as sess:
-
sess.add_all(
[tokyo, newyork, toronto, london, dublin, brasilia, quito]
)
quito.reports.append(Report(85))
with Session() as sess:
-
sess.add_all(
[tokyo, newyork, toronto, london, dublin, brasilia, quito]
)
quito.reports.append(Report(85))
with Session() as sess:
-
sess.add_all(
[tokyo, newyork, toronto, london, dublin, brasilia, quito]
)
for color, char in [
(data[i], data[i + 1]) for i in range(0, len(data), 2)
]:
-
x = self.x + col
y = self.y + row
if 0 <= x <= MAX_X and 0 <= y <= MAX_Y:
("enemy2", 25),
("enemy1", 10),
)
- for (ship_vert, (etype, score)) in zip(
+ for ship_vert, (etype, score) in zip(
range(5, 30, ENEMY_VERT_SPACING), arrangement
):
for ship_horiz in range(0, 50, 10):
def _history_mapper(local_mapper):
-
cls = local_mapper.class_
if cls.__dict__.get("_history_mapper_configured", False):
__dialect__ = "default"
def setUp(self):
-
self.engine = engine = create_engine("sqlite://")
self.session = Session(engine)
self.make_base()
}
class SubClass(BaseClass):
-
subname = Column(String(50), unique=True)
__mapper_args__ = {"polymorphic_identity": "sub"}
"""
for instance in session.dirty:
if hasattr(instance, "new_version") and session.is_modified(instance):
-
# make it transient
instance.new_version(session)
super().__init__(**kw)
def new_version(self, session):
-
# our current identity key, which will be used on the "old"
# version of us to emit an UPDATE. this is just for assertion purposes
old_identity_key = inspect(self).key
data = Column(String)
def new_version(self, session):
-
# expire parent's reference to us
session.expire(self.parent, ["child"])
if __name__ == "__main__":
-
Base = declarative_base()
class AnimalFact(PolymorphicVerticalProperty, Base):
if __name__ == "__main__":
-
Base = declarative_base()
class AnimalFact(Base):
class _UnicodeLiteral:
def literal_processor(self, dialect):
def process(value):
-
value = value.replace("'", "''")
if dialect.identifier_preparer._double_percents:
dialect: MSDialect
def _opt_encode(self, statement):
-
if self.compiled and self.compiled.schema_translate_map:
-
rst = self.compiled.preparer._render_schema_translates
statement = rst(statement, self.compiled.schema_translate_map)
@generate_driver_url.for_db("mssql")
def generate_driver_url(url, driver, query_str):
-
backend = url.get_backend_name()
new_url = url.set(drivername="%s+%s" % (backend, driver))
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
-
log.info("identifiers in file: %s", ", ".join(idents))
to_reap = conn.exec_driver_sql(
"""
def bind_processor(self, dialect):
-
super_process = super().bind_processor(dialect)
if not dialect._need_decimal_fix:
@classmethod
def get_pool_class(cls, url):
-
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
@classmethod
def get_pool_class(cls, url):
-
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
class MySQLCompiler(compiler.SQLCompiler):
-
render_table_with_column_in_update_from = True
"""Overridden from base SQLCompiler value"""
tmp = " FOR UPDATE"
if select._for_update_arg.of and self.dialect.supports_for_update_of:
-
tables = util.OrderedSet()
for c in select._for_update_arg.of:
tables.update(sql_util.surface_selectables_only(c))
):
arg = opts[opt]
if opt in _reflection._options_of_type_string:
-
arg = self.sql_compiler.render_literal_value(
arg, sqltypes.String()
)
length = index.dialect_options[self.dialect.name]["length"]
if length is not None:
-
if isinstance(length, dict):
# length value can be a (column_name --> integer value)
# mapping specifying the prefix length for each column of the
@reflection.cache
def get_table_options(self, connection, table_name, schema=None, **kw):
-
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
-
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
]
if col_tuples:
-
correct_for_wrong_fk_case = connection.execute(
sql.text(
"""
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
-
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
-
charset = self._connection_charset
full_name = ".".join(
self.identifier_preparer._quote_free_identifiers(schema, view_name)
**kw,
):
if precision is None:
-
precision = getattr(type_, "precision", None)
if _requires_binary_precision:
# add expressions to accommodate FOR UPDATE OF
if for_update is not None and for_update.of:
-
adapter = sql_util.ClauseAdapter(inner_subquery)
for_update.of = [
adapter.traverse(elem) for elem in for_update.of
class OracleIdentifierPreparer(compiler.IdentifierPreparer):
-
reserved_words = {x.lower() for x in RESERVED_WORDS}
illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union(
["_", "$"]
# Oracle parameters and use the custom escaping here
escaped_from = kw.get("escaped_from", None)
if not escaped_from:
-
if self._bind_translate_re.search(name):
# not quite the translate use case as we want to
# also get a quick boolean if we even found
# check for has_out_parameters or RETURNING, create cx_Oracle.var
# objects if so
if self.compiled.has_out_parameters or self.compiled._oracle_returning:
-
out_parameters = self.out_parameters
assert out_parameters is not None
def _generate_cursor_outputtype_handler(self):
output_handlers = {}
- for (keyname, name, objects, type_) in self.compiled._result_columns:
+ for keyname, name, objects, type_ in self.compiled._result_columns:
handler = type_._cached_custom_processor(
self.dialect,
"cx_oracle_outputtypehandler",
and is_sql_compiler(self.compiled)
and self.compiled._oracle_returning
):
-
initial_buffer = self.fetchall_for_returning(
self.cursor, _internal=True
)
threaded=None,
**kwargs,
):
-
OracleDialect.__init__(self, **kwargs)
self.arraysize = arraysize
self.encoding_errors = encoding_errors
def output_type_handler(
cursor, name, default_type, size, precision, scale
):
-
if (
default_type == cx_Oracle.NUMBER
and default_type is not cx_Oracle.NATIVE_FLOAT
return output_type_handler
def on_connect(self):
-
output_type_handler = self._generate_connection_outputtype_handler()
def on_connect(conn):
def do_commit_twophase(
self, connection, xid, is_prepared=True, recover=False
):
-
if not is_prepared:
self.do_commit(connection.connection)
else:
thick_mode=None,
**kwargs,
):
-
super().__init__(
auto_convert_lobs,
coerce_to_decimal,
@drop_all_schema_objects_post_tables.for_db("oracle")
def _ora_drop_all_schema_objects_post_tables(cfg, eng):
-
with eng.begin() as conn:
for syn in conn.dialect._get_synonyms(conn, None, None, None):
conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}")
@stop_test_class_outside_fixtures.for_db("oracle")
def _ora_stop_test_class_outside_fixtures(config, db, cls):
-
try:
_purge_recyclebin(db)
except exc.DatabaseError as err:
log.info("db reaper connecting to %r", url)
eng = create_engine(url)
with eng.begin() as conn:
-
log.info("identifiers in file: %s", ", ".join(idents))
to_reap = conn.exec_driver_sql(
inherit_cache = True
def __init__(self, clauses, **kw):
-
type_arg = kw.pop("type_", None)
super().__init__(operators.comma_op, *clauses, **kw)
def _split_enum_values(array_string):
-
if '"' not in array_string:
# no escape char is present so it can just split on the comma
return array_string.split(",") if array_string else []
@classmethod
def adapt_emulated_to_native(cls, interval, **kw):
-
return AsyncPgInterval(precision=interval.second_precision)
adapt_connection = self._adapt_connection
async with adapt_connection._execute_mutex:
-
if not adapt_connection._started:
await adapt_connection._start_transaction()
class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
-
server_side = True
__slots__ = ("_rowbuffer",)
@classmethod
def get_pool_class(cls, url):
-
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
stringify_dialect = "postgresql"
def __init__(self, constraint=None, index_elements=None, index_where=None):
-
if constraint is not None:
if not isinstance(constraint, str) and isinstance(
constraint,
def __init__(self, *args, **kwargs):
args = list(args)
if len(args) > 1:
-
initial_arg = coercions.expect(
roles.ExpressionElementRole,
args.pop(0),
@classmethod
def get_pool_class(cls, url):
-
async_fallback = url.query.get("async_fallback", False)
if util.asbool(async_fallback):
@util.memoized_instancemethod
def _hstore_oids(self, dbapi_connection):
-
extras = self._psycopg2_extras
oids = extras.HstoreAdapter.get_oids(dbapi_connection)
if oids is not None and oids[0]:
@isolation_level.setter
def isolation_level(self, value):
-
# aiosqlite's isolation_level setter works outside the Thread
# that it's supposed to, necessitating setting check_same_thread=False.
# for improved stability, we instead invent our own awaitable version
return target_text
def visit_on_conflict_do_nothing(self, on_conflict, **kw):
-
target_text = self._on_conflict_target(on_conflict, **kw)
if target_text:
class SQLiteDDLCompiler(compiler.DDLCompiler):
def get_column_specification(self, column, **kwargs):
-
coltype = self.dialect.type_compiler_instance.process(
column.type, type_expression=column
)
return text
def visit_foreign_key_constraint(self, constraint, **kw):
-
local_table = constraint.elements[0].parent.table
remote_table = constraint.elements[0].column.table
persisted,
tablesql,
):
-
if generated:
# the type of a column "cc INTEGER GENERATED ALWAYS AS (1 + 42)"
# somehow is "INTEGER GENERATED ALWAYS"
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
-
auto_index_by_sig = {}
for idx in self.get_indexes(
connection,
stringify_dialect = "sqlite"
def __init__(self, index_elements=None, index_where=None):
-
if index_elements is not None:
self.constraint_target = None
self.inferred_target_elements = index_elements
)
def set_isolation_level(self, dbapi_connection, level):
-
if level == "AUTOCOMMIT":
dbapi_connection.isolation_level = None
else:
self._handle_dbapi_exception(e, None, None, None, None)
def _commit_impl(self) -> None:
-
if self._has_events or self.engine._has_events:
self.dispatch.commit(self)
_CoreMultiExecuteParams,
_CoreSingleExecuteParams,
]:
-
event_multiparams: _CoreMultiExecuteParams
event_params: _CoreSingleExecuteParams
)
if self._echo:
-
self._log_info(str_statement)
stats = context._get_cache_stats()
generic_setinputsizes,
context,
):
-
if imv_batch.processed_setinputsizes:
try:
dialect.do_set_input_sizes(
)
if self._echo:
-
self._log_info(sql_util._long_statement(sub_stmt))
imv_stats = f""" {
engine = engineclass(pool, dialect, u, **engine_args)
if _initialize:
-
do_on_connect = dialect.on_connect_url(u)
if do_on_connect:
def _splice_horizontally(
self, other: CursorResultMetaData
) -> CursorResultMetaData:
-
assert not self._tuplefilter
keymap = dict(self._keymap)
)
else:
-
# no dupes - copy secondary elements from compiled
# columns into self._keymap. this is the most common
# codepath for Core / ORM statement executions before the
for idx, rmap_entry in enumerate(result_columns)
]
else:
-
# name-based or text-positional cases, where we need
# to read cursor.description names
def _key_fallback(
self, key: Any, err: Optional[Exception], raiseerr: bool = True
) -> Optional[NoReturn]:
-
if raiseerr:
if self._unpickled and isinstance(key, elements.ColumnElement):
raise exc.NoSuchColumnError(
return index
def _indexes_for_keys(self, keys):
-
try:
return [self._keymap[key][0] for key in keys]
except KeyError as ke:
self._metadata = self._no_result_metadata
def _init_metadata(self, context, cursor_description):
-
if context.compiled:
compiled = context.compiled
return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
def _ensure_has_table_connection(self, arg):
-
if not isinstance(arg, Connection):
raise exc.ArgumentError(
"The argument passed to Dialect.has_table() should be a "
self._set_connection_characteristics(connection, characteristics)
def _set_connection_characteristics(self, connection, characteristics):
-
characteristic_values = [
(name, self.connection_characteristics[name], value)
for name, value in characteristics.items()
@util.memoized_instancemethod
def _gen_allowed_isolation_levels(self, dbapi_conn):
-
try:
raw_levels = list(self.get_isolation_level_values(dbapi_conn))
except NotImplementedError:
scope,
**kw,
):
-
names_fns = []
temp_names_fns = []
if ObjectKind.TABLE in kind:
class StrCompileDialect(DefaultDialect):
-
statement_compiler = compiler.StrSQLCompiler
ddl_compiler = compiler.DDLCompiler
type_compiler_cls = compiler.StrSQLTypeCompiler
[name for param, name in out_bindparams]
),
):
-
type_ = bindparam.type
impl_type = type_.dialect_impl(self.dialect)
dbapi_type = impl_type.get_dbapi_type(self.dialect.loaded_dbapi)
return [getter(None, param) for param in self.compiled_parameters]
def _setup_ins_pk_from_implicit_returning(self, result, rows):
-
if not rows:
return []
target: Union[Engine, Type[Engine], Dialect, Type[Dialect]],
identifier: str,
) -> Optional[Union[Dialect, Type[Dialect]]]:
-
if isinstance(target, type):
if issubclass(target, Engine):
return Dialect
def _construct(
cls, init: Callable[..., Any], bind: Union[Engine, Connection]
) -> Inspector:
-
if hasattr(bind.dialect, "inspector"):
cls = bind.dialect.inspector # type: ignore[attr-defined]
exclude_columns: Collection[str],
cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
) -> None:
-
orig_name = col_d["name"]
table.metadata.dispatch.column_reflect(self, table, col_d)
def _getter(
self, key: Any, raiseerr: bool = True
) -> Optional[Callable[[Row[Any]], Any]]:
-
index = self._index_for_key(key, raiseerr)
if index is not None:
@HasMemoized_ro_memoized_attribute
def _iterator_getter(self) -> Callable[..., Iterator[_R]]:
-
make_row = self._row_getter
post_creational_filter = self._post_creational_filter
return [make_row(row) for row in rows]
def _allrows(self) -> List[_R]:
-
post_creational_filter = self._post_creational_filter
make_row = self._row_getter
return all(isinstance(target.dispatch, t) for t in types)
def dispatch_parent_is(t: Type[Any]) -> bool:
-
return isinstance(
cast("_JoinedDispatcher[_ET]", target.dispatch).parent, t
)
if len(argnames) == len(argspec.args) and has_kw is bool(
argspec.varkw
):
-
formatted_def = "def %s(%s%s)" % (
dispatch_collection.name,
", ".join(dispatch_collection.arg_names),
retval: Optional[bool] = None,
asyncio: bool = False,
) -> None:
-
target, identifier = self.dispatch_target, self.identifier
dispatch_collection = getattr(target.dispatch, identifier)
"contains() doesn't apply to a scalar object endpoint; use =="
)
else:
-
return self._comparator._criterion_exists(
**{self.value_attr: other}
)
if self.sync_connection:
raise exc.InvalidRequestError("connection is already started")
self.sync_connection = self._assign_proxied(
- await (greenlet_spawn(self.sync_engine.connect))
+ await greenlet_spawn(self.sync_engine.connect)
)
return self
async def invalidate(
self, exception: Optional[BaseException] = None
) -> None:
-
"""Invalidate the underlying DBAPI connection associated with
this :class:`_engine.Connection`.
return AsyncEngine(self.sync_engine.execution_options(**opt))
async def dispose(self, close: bool = True) -> None:
-
"""Dispose of the connection pool used by this
:class:`_asyncio.AsyncEngine`.
identity_token: Optional[Any] = None,
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
) -> Optional[_O]:
-
"""Return an instance based on the given primary key identifier,
or ``None`` if not found.
bind_arguments: Optional[_BindArguments] = None,
**kw: Any,
) -> AsyncResult[Any]:
-
"""Execute a statement and return a streaming
:class:`_asyncio.AsyncResult` object.
by_module_properties: ByModuleProperties = cls.by_module
for token in map_config.cls.__module__.split("."):
-
if token not in by_module_properties:
by_module_properties[token] = util.Properties({})
name_for_collection_relationship: NameForCollectionRelationshipType,
generate_relationship: GenerateRelationshipType,
) -> None:
-
map_config = table_to_map_config.get(lcl_m2m, None)
referred_cfg = table_to_map_config.get(rem_m2m, None)
if map_config is None or referred_cfg is None:
for sup_ in scls.__mro__[1:]:
sup_sm = _mapper_or_none(sup_)
if sup_sm:
-
sm._set_concrete_base(sup_sm)
break
# first collect the primary __table__ for each class into a
# collection of metadata/schemaname -> table names
for thingy in to_map:
-
if thingy.local_table is not None:
metadata_to_table[
(thingy.local_table.metadata, thingy.local_table.schema)
metadata = mapper.class_.metadata
for rel in mapper._props.values():
-
if (
isinstance(rel, relationships.RelationshipProperty)
and rel._init_args.secondary._is_populated()
):
-
secondary_arg = rel._init_args.secondary
if isinstance(secondary_arg.argument, Table):
def iter_for_shard(
shard_id: ShardIdentifier,
) -> Union[Result[_T], IteratorResult[_TP]]:
-
bind_arguments = dict(orm_context.bind_arguments)
bind_arguments["shard_id"] = shard_id
class HybridExtensionType(InspectionAttrExtensionType):
-
HYBRID_METHOD = "HYBRID_METHOD"
"""Symbol indicating an :class:`InspectionAttr` that's
of type :class:`.hybrid_method`.
def _get_comparator(
self, comparator: Any
) -> Callable[[Any], _HybridClassLevelAccessor[_T]]:
-
proxy_attr = attributes.create_proxied_attribute(self)
def expr_comparator(
and stmt.lvalues[0].name in mapped_attr_lookup
and isinstance(stmt.lvalues[0].node, Var)
):
-
left_node = stmt.lvalues[0].node
python_type_for_type = mapped_attr_lookup[
and isinstance(stmt.rvalue.args[0].callee, RefExpr)
)
):
-
new_python_type_for_type = (
infer.infer_type_from_right_hand_nameexpr(
api,
api: SemanticAnalyzerPluginInterface,
is_mixin_scan: bool = False,
) -> Optional[List[util.SQLAlchemyAttribute]]:
-
info = util.info_for_cls(cls, api)
if info is None:
sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
if sym is not None and isinstance(sym.node, TypeInfo):
if names.has_base_type_id(sym.node, names.TYPEENGINE):
-
left_hand_explicit_type = UnionType(
[
infer.extract_python_type_from_typeengine(
elif isinstance(stmt.rvalue, CallExpr) and isinstance(
stmt.rvalue.callee, RefExpr
):
-
python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
)
)
if left_hand_explicit_type is not None:
-
return _infer_type_from_left_and_inferred_right(
api, node, left_hand_explicit_type, python_type_for_type
)
def get_class_decorator_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
-
sym = self.lookup_fully_qualified(fullname)
if sym is not None and sym.node is not None:
selected_level = self.logger.getEffectiveLevel()
if level >= selected_level:
-
if STACKLEVEL:
kwargs["stacklevel"] = (
kwargs.get("stacklevel", 1) + STACKLEVEL_OFFSET
def __go(lcls: Any) -> None:
-
_sa_util.preloaded.import_prefix("sqlalchemy.orm")
_sa_util.preloaded.import_prefix("sqlalchemy.ext")
last_parent is not False
and last_parent.key != parent_state.key
):
-
if last_parent.obj() is None:
raise orm_exc.StaleDataError(
"Removing state %s from parent "
else:
original = state.committed_state.get(self.key, _NO_HISTORY)
if original is PASSIVE_NO_RESULT:
-
loader_passive = passive | (
PASSIVE_ONLY_PERSISTENT
| NO_AUTOFLUSH
and original is not NO_VALUE
and original is not current
):
-
ret.append((instance_state(original), original))
return ret
def _initialize_collection(
self, state: InstanceState[Any]
) -> Tuple[CollectionAdapter, _AdaptedCollectionProtocol]:
-
adapter, collection = state.manager.initialize_collection(
self.key, state, self.collection_factory
)
initiator is not check_remove_token
and initiator is not check_replace_token
):
-
if not check_for_dupes_on_remove or not util.has_dupes(
# when this event is called, the item is usually
# present in the list, except for a pop() operation.
elif original is _NO_HISTORY:
return cls((), list(current), ())
else:
-
current_states = [
((c is not None) and instance_state(c) or None, c)
for c in current
backref: Optional[str] = None,
**kw: Any,
) -> QueryableAttribute[Any]:
-
manager = manager_of_class(class_)
if uselist:
factory = kw.pop("typecallable", None)
bookkeeping = False
for table, super_mapper in mappers_to_run:
-
# find bindparams in the statement. For bulk, we don't really know if
# a key in the params applies to a different table since we are
# potentially inserting for multiple tables here; looking at the
def _get_orm_crud_kv_pairs(
cls, mapper, statement, kv_iterator, needs_to_be_cacheable
):
-
core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
for k, v in kv_iterator:
"""
if orm_level_statement._returning:
-
fs = FromStatement(
orm_level_statement._returning,
dml_level_statement,
bind_arguments,
result,
):
-
execution_context = result.context
compile_state = execution_context.compiled.compile_state
bind_arguments,
is_pre_event,
):
-
(
update_options,
execution_options,
session._autoflush()
if update_options._dml_strategy == "orm":
-
if update_options._synchronize_session == "auto":
update_options = cls._do_pre_synchronize_auto(
session,
bind_arguments,
result,
):
-
# this stage of the execution is called after the
# do_orm_execute event hook. meaning for an extension like
# horizontal sharding, this step happens *within* the horizontal
bind_arguments,
update_options,
):
-
try:
eval_condition = cls._eval_condition_from_statement(
update_options, statement
bind_arguments,
is_pre_event,
):
-
(
insert_options,
execution_options,
bind_arguments: _BindArguments,
conn: Connection,
) -> _result.Result:
-
insert_options = execution_options.get(
"_sa_orm_insert_options", cls.default_insert_options
)
@classmethod
def create_for_statement(cls, statement, compiler, **kw) -> BulkORMInsert:
-
self = cast(
BulkORMInsert,
super().create_for_statement(statement, compiler, **kw),
class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
@classmethod
def create_for_statement(cls, statement, compiler, **kw):
-
self = cls.__new__(cls)
dml_strategy = statement._annotations.get(
bind_arguments: _BindArguments,
conn: Connection,
) -> _result.Result:
-
update_options = execution_options.get(
"_sa_orm_update_options", cls.default_update_options
)
is_delete_using: bool = False,
is_executemany: bool = False,
) -> bool:
-
# normal answer for "should we use RETURNING" at all.
normal_answer = (
dialect.update_returning and mapper.local_table.implicit_returning
def _do_post_synchronize_evaluate(
cls, session, statement, result, update_options
):
-
matched_objects = cls._get_matched_objects_on_criteria(
update_options,
session.identity_map.all_states(),
states = set()
for obj, state, dict_ in matched_objects:
-
to_evaluate = state.unmodified.intersection(evaluated_keys)
for key in to_evaluate:
bind_arguments: _BindArguments,
conn: Connection,
) -> _result.Result:
-
update_options = execution_options.get(
"_sa_orm_update_options", cls.default_update_options
)
is_delete_using: bool = False,
is_executemany: bool = False,
) -> bool:
-
# normal answer for "should we use RETURNING" at all.
normal_answer = (
dialect.delete_returning and mapper.local_table.implicit_returning
Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]],
Callable[[str, bool], _class_resolver],
]:
-
global _fallback_dict
if _fallback_dict is None:
# Did factory callable return a builtin?
if cls in __canned_instrumentation:
-
# if so, just convert.
# in previous major releases, this codepath wasn't working and was
# not covered by tests. prior to that it supplied a "wrapper"
def __go(lcls):
-
global keyfunc_mapping, mapped_collection
global column_keyed_dict, column_mapped_collection
global MappedCollection, KeyFuncDict
bind_arguments,
is_pre_event,
):
-
# consume result-level load_options. These may have been set up
# in an ORMExecuteState hook
(
def _column_naming_convention(
cls, label_style: SelectLabelStyle, legacy: bool
) -> _LabelConventionCallable:
-
if legacy:
def name(col, col_name=None):
bind_arguments,
is_pre_event,
):
-
# consume result-level load_options. These may have been set up
# in an ORMExecuteState hook
(
compiler: Optional[SQLCompiler],
**kw: Any,
) -> ORMFromStatementCompileState:
-
assert isinstance(statement_container, FromStatement)
if compiler is not None and compiler.stack:
self._adapt_on_names = _adapt_on_names
def _compiler_dispatch(self, compiler, **kw):
-
"""provide a fixed _compiler_dispatch method.
This is roughly similar to using the sqlalchemy.ext.compiler
select_statement._with_options
or select_statement._memoized_select_entities
):
-
for (
memoized_entities
) in select_statement._memoized_select_entities:
@classmethod
def from_statement(cls, statement, from_statement):
-
from_statement = coercions.expect(
roles.ReturnsRowsRole,
from_statement,
return statement
def _simple_statement(self):
-
statement = self._select_statement(
self.primary_columns + self.secondary_columns,
tuple(self.from_clauses) + tuple(self.eager_joins.values()),
suffixes,
group_by,
):
-
statement = Select._create_raw_select(
_raw_columns=raw_columns,
_from_obj=from_obj,
return cols
def _get_current_adapter(self):
-
adapters = []
if self._from_obj_alias:
return _adapt_clause
def _join(self, args, entities_collection):
- for (right, onclause, from_, flags) in args:
+ for right, onclause, from_, flags in args:
isouter = flags["isouter"]
full = flags["full"]
# test for joining to an unmapped selectable as the target
if r_info.is_clause_element:
-
if prop:
right_mapper = prop.mapper
)
and ext_info not in self.extra_criteria_entities
):
-
self.extra_criteria_entities[ext_info] = (
ext_info,
ext_info._adapter if ext_info.is_aliased_class else None,
search = set(self.extra_criteria_entities.values())
- for (ext_info, adapter) in search:
+ for ext_info, adapter in search:
if ext_info in self._join_entities:
continue
def to_compile_state(
cls, compile_state, entities, entities_collection, is_current_entities
):
-
for idx, entity in enumerate(entities):
if entity._is_lambda_element:
if entity._is_sequence:
return _entity_corresponds_to(self.entity_zero, entity)
def _get_entity_clauses(self, compile_state):
-
adapter = None
if not self.is_aliased_class:
class _BundleEntity(_QueryEntity):
-
_extra_entities = ()
__slots__ = (
or ("additional_entity_criteria", self.mapper)
in compile_state.global_attributes
):
-
compile_state.extra_criteria_entities[ezero] = (
ezero,
ezero._adapter if ezero.is_aliased_class else None,
"""
if typing.TYPE_CHECKING:
-
__tablename__: Any
"""String name to assign to the generated
:class:`_schema.Table` object, if not specified directly via
def _resolve_type(
self, python_type: _MatchedOnType
) -> Optional[sqltypes.TypeEngine[Any]]:
-
search: Iterable[Tuple[_MatchedOnType, Type[Any]]]
python_type_type: Type[Any]
def _as_declarative(
registry: _RegistryType, cls: Type[Any], dict_: _ClassDict
) -> Optional[_MapperConfig]:
-
# declarative scans the class for attributes. no table or mapper
# args passed separately.
return _MapperConfig.setup_mapping(registry, cls, dict_, None, {})
)
def set_cls_attribute(self, attrname: str, value: _T) -> _T:
-
manager = instrumentation.manager_of_class(self.cls)
manager.install_member(attrname, value)
return value
table: Optional[FromClause],
mapper_kw: _MapperKwArgs,
):
-
# grab class dict before the instrumentation manager has been added.
# reduces cycles
self.clsdict_view = (
return getattr(cls, key, obj) is not obj
else:
-
all_datacls_fields = {
f.name: f.metadata[sa_dataclass_metadata_key]
for f in util.dataclass_fields(cls)
if not sa_dataclass_metadata_key:
- def local_attributes_for_class() -> Iterable[
- Tuple[str, Any, Any, bool]
- ]:
+ def local_attributes_for_class() -> (
+ Iterable[Tuple[str, Any, Any, bool]]
+ ):
return (
(
name,
fixed_sa_dataclass_metadata_key = sa_dataclass_metadata_key
- def local_attributes_for_class() -> Iterable[
- Tuple[str, Any, Any, bool]
- ]:
+ def local_attributes_for_class() -> (
+ Iterable[Tuple[str, Any, Any, bool]]
+ ):
for name in names:
field = dataclass_fields.get(name, None)
if field and sa_dataclass_metadata_key in field.metadata:
local_attributes_for_class,
locally_collected_columns,
) in bases:
-
# this transfer can also take place as we scan each name
# for finer-grained control of how collected_attributes is
# populated, as this is what impacts column ordering.
self.mapper_args_fn = mapper_args_fn
def _setup_dataclasses_transforms(self) -> None:
-
dataclass_setup_arguments = self.dataclass_setup_arguments
if not dataclass_setup_arguments:
return
new_anno = {}
for name, annotation in cls_annotations.items():
if _is_mapped_annotation(annotation, klass, klass):
-
extracted = _extract_mapped_subtype(
annotation,
klass,
expect_mapped: Optional[bool],
attr_value: Any,
) -> Optional[_CollectedAnnotation]:
-
if name in self.collected_annotations:
return self.collected_annotations[name]
# copy mixin columns to the mapped class
for name, obj, annotation, is_dataclass in attributes_for_class():
-
if (
not fixed_table
and obj is None
setattr(cls, name, obj)
elif isinstance(obj, (Column, MappedColumn)):
-
if attribute_is_overridden(name, obj):
# if column has been overridden
# (like by the InstrumentedAttribute of the
look_for_dataclass_things = bool(self.dataclass_setup_arguments)
for k in list(collected_attributes):
-
if k in _include_dunders:
continue
assert expect_annotations_wo_mapped
if isinstance(value, _DCAttributeOptions):
-
if (
value._has_dataclass_arguments
and not look_for_dataclass_things
for key, c in list(our_stuff.items()):
if isinstance(c, _MapsColumns):
-
mp_to_assign = c.mapper_property_to_assign
if mp_to_assign:
our_stuff[key] = mp_to_assign
table_cls = Table
if tablename is not None:
-
args: Tuple[Any, ...] = ()
table_kw: Dict[str, Any] = {}
and self.inherits is None
and not _get_immediate_cls_attr(cls, "__no_table__")
):
-
raise exc.InvalidRequestError(
"Class %r does not have a __table__ or __tablename__ "
"specified and does not inherit from an existing "
)
if table is None:
-
# single table inheritance.
# ensure no table args
if table_args:
isdelete,
childisdelete,
):
-
if self.post_update:
-
child_post_updates = unitofwork.PostUpdateAll(
uow, self.mapper.primary_base_mapper, False
)
after_save,
before_delete,
):
-
if self.post_update:
parent_post_updates = unitofwork.PostUpdateAll(
uow, self.parent.primary_base_mapper, False
isdelete,
childisdelete,
):
-
if self.post_update:
-
if not isdelete:
parent_post_updates = unitofwork.PostUpdateAll(
uow, self.parent.primary_base_mapper, False
and not self.cascade.delete_orphan
and not self.passive_deletes == "all"
):
-
# post_update means we have to update our
# row to not reference the child object
# before we can DELETE the row
after_save,
before_delete,
):
-
uow.dependencies.update(
[
(parent_saves, after_save),
tmp.update((c, state) for c in history.added + history.deleted)
if need_cascade_pks:
-
for child in history.unchanged:
associationrow = {}
sync.update(
def _synchronize(
self, state, child, associationrow, clearkeys, uowcommit, operation
):
-
# this checks for None if uselist=True
self._verify_canload(child)
def _populate_composite_bulk_save_mappings_fn(
self,
) -> Callable[[Dict[str, Any]], None]:
-
if self._generated_composite_accessor:
get_values = self._generated_composite_accessor
else:
def _comparator_factory(
self, mapper: Mapper[Any]
) -> Type[PropComparator[_T]]:
-
comparator_callable = None
for m in self.parent.iterate_to_root():
cls, target: Any, identifier: str
) -> Union[Session, type]:
if isinstance(target, scoped_session):
-
target = target.session_factory
if not isinstance(target, sessionmaker) and (
not isinstance(target, type) or not issubclass(target, Session)
if is_instance_event:
if not raw or restore_load_context:
-
fn = event_key._listen_fn
def wrap(
propagate: bool = False,
include_key: bool = False,
) -> None:
-
target, fn = event_key.dispatch_target, event_key._listen_fn
if active_history:
] = None,
init_method: Optional[Callable[..., None]] = None,
) -> None:
-
if mapper:
self.mapper = mapper # type: ignore[assignment]
if registry:
def _get_context_loader(
self, context: ORMCompileState, path: AbstractEntityRegistry
) -> Optional[_LoadElement]:
-
load: Optional[_LoadElement] = None
search_path = path[self]
"""
instance = session.identity_map.get(key)
if instance is not None:
-
state = attributes.instance_state(instance)
if mapper.inherits and not state.mapper.isa(mapper):
require_pk_cols: bool = False,
is_user_refresh: bool = False,
):
-
"""Load the given primary key identity from the database."""
query = statement
identity_token=None,
is_user_refresh=None,
):
-
compile_options = {}
load_options = {}
if version_check:
polymorphic_discriminator=None,
**kw,
):
-
if with_polymorphic:
poly_properties = mapper._iterate_polymorphic_properties(
with_polymorphic
polymorphic_discriminator is not None
and polymorphic_discriminator is not mapper.polymorphic_on
):
-
if adapter:
pd = adapter.columns[polymorphic_discriminator]
else:
is_not_primary_key = _none_set.intersection
def _instance(row):
-
# determine the state that we'll be populating
if refresh_identity_key:
# fixed state that we're refreshing
def _populate_partial(
context, row, state, dict_, isnew, load_path, unloaded, populators
):
-
if not isnew:
if unloaded:
# extra pass, see #8166
def _validate_version_id(mapper, state, dict_, row, getter):
-
if mapper._get_state_attr_by_column(
state, dict_, mapper.version_id_col
) != getter(row):
# currently use state.key
statement = mapper._optimized_get_statement(state, attribute_names)
if statement is not None:
-
# undefer() isn't needed here because statement has the
# columns needed already, this implicitly undefers that column
stmt = FromStatement(mapper, statement)
from ..util.typing import Literal
if TYPE_CHECKING:
-
from . import AttributeEventToken
from . import Mapper
from ..sql.elements import ColumnElement
def _cols(self, mapper: Mapper[_KT]) -> Sequence[ColumnElement[_KT]]:
cols: List[ColumnElement[_KT]] = []
metadata = getattr(mapper.local_table, "metadata", None)
- for (ckey, tkey) in self.colkeys:
+ for ckey, tkey in self.colkeys:
if tkey is None or metadata is None or tkey not in metadata:
cols.append(mapper.local_table.c[ckey]) # type: ignore
else:
# while a configure_mappers() is occurring (and defer a
# configure_mappers() until construction succeeds)
with _CONFIGURE_MUTEX:
-
cast("MapperEvents", self.dispatch._events)._new_mapper_instance(
class_, self
)
)
if self._primary_key_argument:
-
coerced_pk_arg = [
self._str_arg_to_mapped_col("primary_key", c)
if isinstance(c, str)
}
def _configure_properties(self) -> None:
-
self.columns = self.c = sql_base.ColumnCollection() # type: ignore
# object attribute names mapped to MapperProperty objects
incoming_prop = explicit_col_props_by_key.get(key)
if incoming_prop:
-
new_prop = self._reconcile_prop_with_incoming_columns(
key,
inherited_prop,
Sequence[KeyedColumnElement[Any]], KeyedColumnElement[Any]
],
) -> ColumnProperty[Any]:
-
columns = util.to_list(column)
mapped_column = []
for c in columns:
incoming_prop: Optional[ColumnProperty[Any]] = None,
single_column: Optional[KeyedColumnElement[Any]] = None,
) -> ColumnProperty[Any]:
-
if incoming_prop and (
self.concrete
or not isinstance(existing_prop, properties.ColumnProperty)
def _is_orphan(self, state: InstanceState[_O]) -> bool:
orphan_possible = False
for mapper in self.iterate_to_root():
- for (key, cls) in mapper._delete_orphans:
+ for key, cls in mapper._delete_orphans:
orphan_possible = True
has_parent = attributes.manager_of_class(cls).has_parent(
@HasMemoized.memoized_instancemethod
def __clause_element__(self):
-
annotations: Dict[str, Any] = {
"entity_namespace": self,
"parententity": self,
def _get_committed_state_attr_by_column(
self, state, dict_, column, passive=PassiveFlag.PASSIVE_RETURN_NO_VALUE
):
-
prop = self._columntoproperty[column]
return state.manager[prop.key].impl.get_committed_value(
state, dict_, passive=passive
m = m.inherits
for prop in self.attrs:
-
# skip prop keys that are not instrumented on the mapped class.
# this is primarily the "_sa_polymorphic_on" property that gets
# created for an ad-hoc polymorphic_on SQL expression, issue #8704
in_expr.in_(sql.bindparam("primary_keys", expanding=True))
).order_by(*primary_key)
else:
-
q = sql.select(self).set_label_style(
LABEL_STYLE_TABLENAME_PLUS_COL
)
return
_already_compiling = True
try:
-
# double-check inside mutex
for reg in registries:
if reg._new_mappers:
def _do_configure_registries(
registries: Set[_RegistryType], cascade: bool
) -> None:
-
registry = util.preloaded.orm_decl_api.registry
orig = set(registries)
@util.preload_module("sqlalchemy.orm.decl_api")
def _dispose_registries(registries: Set[_RegistryType], cascade: bool) -> None:
-
registry = util.preloaded.orm_decl_api.registry
orig = set(registries)
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction, states
):
-
has_identity = bool(state.key)
instance_key = state.key or mapper._identity_key_from_state(state)
for state, dict_, mapper, connection in _connections_for_states(
base_mapper, uowtransaction, states
):
-
mapper.dispatch.before_delete(mapper, connection, state)
if mapper.version_id_col is not None:
connection,
update_version_id,
) in states_to_update:
-
if table not in mapper._pks_by_table:
continue
update_version_id is not None
and mapper.version_id_col in mapper._cols_by_table[table]
):
-
if not bulk and not (params or value_params):
# HACK: check for history in other tables, in case the
# history is only in a different table than the one
connection,
update_version_id,
) in states_to_update:
-
# assert table in mapper._pks_by_table
pks = mapper._pks_by_table[table]
update_version_id is not None
and mapper.version_id_col in mapper._cols_by_table[table]
):
-
col = mapper.version_id_col
params[col._label] = update_version_id
connection,
update_version_id,
) in states_to_delete:
-
if table not in mapper._pks_by_table:
continue
rec[7],
),
):
-
statement = cached_stmt
if use_orm_insert_stmt is not None:
and has_all_pks
and not hasvalue
):
-
# the "we don't need newly generated values back" section.
# here we have all the PKs, all the defaults or we don't want
# to fetch them, or the dialect doesn't support RETURNING at all
if not allow_executemany:
check_rowcount = assert_singlerow
for state, state_dict, mapper_rec, connection, params in records:
-
c = connection.execute(
statement, params, execution_options=execution_options
)
# execute deletes individually so that versioned
# rows can be verified
for params in del_objects:
-
c = connection.execute(
statement, params, execution_options=execution_options
)
"""
for state, state_dict, mapper, connection, has_identity in states:
-
if mapper._readonly_props:
readonly = state.unmodified_intersection(
[
checks = [our_type]
for check_type in checks:
-
new_sqltype = registry._resolve_type(check_type)
if new_sqltype is not None:
break
criterion: Optional[_ColumnExpressionArgument[bool]] = None,
**kwargs: Any,
) -> Exists:
-
where_criteria = (
coercions.expect(roles.WhereHavingRole, criterion)
if criterion is not None
_recursive: Dict[Any, object],
_resolve_conflict_map: Dict[_IdentityKeyType[Any], object],
) -> None:
-
if load:
for r in self._reverse_property:
if (source_state, r) in _recursive:
"foreign_keys",
"remote_side",
):
-
rel_arg = getattr(init_args, attr)
rel_arg._resolve_against_registry(self._clsregistry_resolvers[1])
argument = extracted_mapped_annotation
if extracted_mapped_annotation is None:
-
if self.argument is None:
self._raise_for_required(key, cls)
else:
Optional[FromClause],
Optional[ClauseAdapter],
]:
-
aliased = False
if alias_secondary and self.secondary is not None:
class JoinCondition:
-
primaryjoin_initial: Optional[ColumnElement[bool]]
primaryjoin: ColumnElement[bool]
secondaryjoin: Optional[ColumnElement[bool]]
support_sync: bool = True,
can_be_synced_fn: Callable[..., bool] = lambda *c: True,
):
-
self.parent_persist_selectable = parent_persist_selectable
self.parent_local_selectable = parent_local_selectable
self.child_persist_selectable = child_persist_selectable
"the relationship." % (self.prop,)
)
else:
-
not_target = util.column_set(
self.parent_persist_selectable.c
).difference(self.child_persist_selectable.c)
or not self.prop.parent.common_parent(pr.parent)
)
):
-
other_props.append((pr, fr_))
if other_props:
def col_to_bind(
element: ColumnElement[Any], **kw: Any
) -> Optional[BindParameter[Any]]:
-
if (
(not reverse_direction and "local" in element._annotations)
or reverse_direction
session_factory: sessionmaker[_S],
scopefunc: Optional[Callable[[], Any]] = None,
):
-
"""Construct a new :class:`.scoped_session`.
:param session_factory: a factory to create new :class:`.Session`
def _iterate_self_and_parents(
self, upto: Optional[SessionTransaction] = None
) -> Iterable[SessionTransaction]:
-
current = self
result: Tuple[SessionTransaction, ...] = ()
while current:
bind: _SessionBind,
execution_options: Optional[_ExecuteOptions],
) -> Connection:
-
if bind in self._connections:
if execution_options:
util.warn(
(SessionTransactionState.ACTIVE,), SessionTransactionState.PREPARED
)
def _prepare_impl(self) -> None:
-
if self._parent is None or self.nested:
self.session.dispatch.before_commit(self.session)
def rollback(
self, _capture_exception: bool = False, _to_root: bool = False
) -> None:
-
stx = self.session._transaction
assert stx is not None
if stx is not self:
sess = self.session
if not rollback_err and not sess._is_clean():
-
# if items were added, deleted, or mutated
# here, we need to re-restore the snapshot
util.warn(
_StateChangeStates.ANY, SessionTransactionState.CLOSED
)
def close(self, invalidate: bool = False) -> None:
-
if self.nested:
self.session._nested_transaction = (
self._previous_nested_transaction
def _autobegin_t(self, begin: bool = False) -> SessionTransaction:
if self._transaction is None:
-
if not begin and not self.autobegin:
raise sa_exc.InvalidRequestError(
"Autobegin is disabled on this Session; please call "
# prevent against last minute dereferences of the object
obj = state.obj()
if obj is not None:
-
instance_key = mapper._identity_key_from_state(state)
if (
def _delete_impl(
self, state: InstanceState[Any], obj: object, head: bool
) -> None:
-
if state.key is None:
if head:
raise sa_exc.InvalidRequestError(
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
bind_arguments: Optional[_BindArguments] = None,
) -> Optional[_O]:
-
# convert composite types to individual args
if (
is_composite_class(primary_key_identity)
)
if is_dict:
-
pk_synonyms = mapper._pk_synonyms
if pk_synonyms:
and not mapper.always_refresh
and with_for_update is None
):
-
instance = self._identity_lookup(
mapper,
primary_key_identity,
)
def _flush(self, objects: Optional[Sequence[object]] = None) -> None:
-
dirty = self._dirty_states
if not dirty and not self._deleted and not self._new:
self.identity_map._modified.clear()
@util.decorator
def _go(fn: _F, self: Any, *arg: Any, **kw: Any) -> Any:
-
current_state = self._state
if (
impl_class=None,
**kw,
):
-
listen_hooks = []
uselist = useobject and prop.uselist
if prop is m._props.get(
prop.key
) and not m.class_manager._attr_has_impl(prop.key):
-
desc = attributes.register_attribute_impl(
m.class_,
prop.key,
adapter,
populators,
):
-
# for a DeferredColumnLoader, this method is only used during a
# "row processor only" query; see test_deferred.py ->
# tests with "rowproc_only" in their name. As of the 1.0 series,
only_load_props=None,
**kw,
):
-
if (
(
compile_state.compile_options._render_for_subquery
)
def _memoized_attr__simple_lazy_clause(self):
-
lazywhere = sql_util._deep_annotate(
self._lazywhere, {"_orm_adapt": True}
)
or passive & PassiveFlag.RELATED_OBJECT_OK
)
):
-
self._invoke_raise_load(state, passive, "raise")
session = _state_session(state)
lazy_clause, params = self._generate_lazy_clause(state, passive)
if execution_options:
-
execution_options = util.EMPTY_DICT.merge_with(
execution_options,
{
and context.query._compile_options._only_load_props
and self.key in context.query._compile_options._only_load_props
):
-
return self._immediateload_create_row_processor(
context,
query_entity,
__slots__ = ()
def _setup_for_recursion(self, context, path, loadopt, join_depth=None):
-
effective_path = (
context.compile_state.current_path or orm_util.PathRegistry.root
) + path
adapter,
populators,
):
-
(
effective_path,
run_loader,
recursion_depth,
execution_options,
):
-
if recursion_depth:
new_opt = Load(loadopt.path.entity)
new_opt.context = (
def _apply_joins(
self, q, to_join, left_alias, parent_alias, effective_entity
):
-
ltj = len(to_join)
if ltj == 1:
to_join = [
effective_entity,
loadopt,
):
-
# note that because the subqueryload object
# does not re-use the cached query, instead always making
# use of the current invoked query, while we have two queries
new_options = orig_query._with_options
if loadopt and loadopt._extra_criteria:
-
new_options += (
orm_util.LoaderCriteriaOption(
self.entity,
adapter,
populators,
):
-
if context.refresh_state:
return self._immediateload_create_row_processor(
context,
)
if user_defined_adapter is not False:
-
# setup an adapter but dont create any JOIN, assume it's already
# in the query
(
def _init_user_defined_eager_proc(
self, loadopt, compile_state, target_attributes
):
-
# check if the opt applies at all
if "eager_from_alias" not in loadopt.local_opts:
# nope
def _setup_query_on_user_defined_adapter(
self, context, entity, path, adapter, user_defined_adapter
):
-
# apply some more wrapping to the "user defined adapter"
# if we are setting up the query for SQL render.
adapter = entity._get_entity_clauses(context)
and not should_nest_selectable
and compile_state.from_clauses
):
-
indexes = sql_util.find_left_clause_that_matches_given(
compile_state.from_clauses, query_entity.selectable
)
def _splice_nested_inner_join(
self, path, join_obj, clauses, onclause, extra_criteria, splicing=False
):
-
# recursive fn to splice a nested join into an existing one.
# splicing=False means this is the outermost call, and it
# should return a value. splicing=<from object> is the recursive
adapter,
populators,
):
-
if context.refresh_state:
return self._immediateload_create_row_processor(
context,
data[k].extend(vv[1] for vv in v)
for key, state, state_dict, overwrite in chunk:
-
if not overwrite and self.key in state_dict:
continue
mapper_entities: Sequence[_MapperEntity],
raiseerr: bool,
) -> None:
-
reconciled_lead_entity = self._reconcile_query_entities_with_us(
mapper_entities, raiseerr
)
)
start_path = self._path_with_polymorphic_path
if current_path:
-
new_path = self._adjust_effective_path_for_current_path(
start_path, current_path
)
and dest_mapper._get_state_attr_by_column(dest, dest.dict, r)
not in orm_util._none_set
):
-
raise AssertionError(
"Dependency rule tried to blank-out primary key "
"column '%s' on instance '%s'" % (r, orm_util.state_str(dest))
sess = state.session
if sess:
-
if sess._warn_on_events:
sess._flush_warning("related attribute set")
class IterateMappersMixin:
-
__slots__ = ()
def _mappers(self, uow):
represents_outer_join: bool,
nest_adapters: bool,
):
-
mapped_class_or_ac = inspected.entity
mapper = inspected.mapper
flat: bool = False,
adapt_on_names: bool = False,
) -> Union[AliasedClass[_O], FromClause]:
-
if isinstance(element, FromClause):
if adapt_on_names:
raise sa_exc.ArgumentError(
adapt_on_names: bool = False,
_use_mapper_path: bool = False,
) -> AliasedClass[_O]:
-
primary_mapper = _class_to_mapper(base)
if selectable not in (None, False) and flat:
)
def _all_mappers(self) -> Iterator[Mapper[Any]]:
-
if self.entity:
yield from self.entity.mapper.self_and_descendants
else:
def _inspect_mc(
class_: Type[_O],
) -> Optional[Mapper[_O]]:
-
try:
class_manager = opt_manager_of_class(class_)
if class_manager is None or not class_manager.is_mapped:
return None
mapper = class_manager.mapper
except exc.NO_STATE:
-
return None
else:
return mapper
def _inspect_generic_alias(
class_: Type[_O],
) -> Optional[Mapper[_O]]:
-
origin = cast("Type[_O]", typing_get_origin(class_))
return _inspect_mc(origin)
"""
if raw_annotation is None:
-
if required:
raise sa_exc.ArgumentError(
f"Python typing annotation is required for attribute "
if not hasattr(annotated, "__origin__") or not is_origin_of_cls(
annotated, _MappedAnnotationBase
):
-
if expect_mapped:
if getattr(annotated, "__origin__", None) is typing.ClassVar:
return None
fn(state, value, initiator or self._remove_token)
def _modified_event(self, state, dict_):
-
if self.key not in state.committed_state:
state.committed_state[self.key] = self.collection_history_cls(
self, state, PassiveFlag.PASSIVE_NO_FETCH
__slots__ = ()
def __init__(self, attr, state):
-
self.instance = instance = state.obj()
self.attr = attr
threadconns: Optional[threading.local] = None,
fairy: Optional[_ConnectionFairy] = None,
) -> _ConnectionFairy:
-
if not fairy:
fairy = _ConnectionRecord.checkout(pool)
def invalidate(
self, e: Optional[BaseException] = None, soft: bool = False
) -> None:
-
if self.dbapi_connection is None:
util.warn("Can't invalidate an already-closed connection.")
return
_index = 0
def get_anon(self, object_: Any) -> Tuple[str, bool]:
-
idself = id(object_)
if idself in self:
s_val = self[idself]
...
else:
-
is_sql_compiler = operator.attrgetter("is_sql")
is_ddl_compiler = operator.attrgetter("is_ddl")
is_named_from_clause = operator.attrgetter("named_with_column")
cloned_ids: Dict[int, SupportsAnnotations] = {}
def clone(elem: SupportsAnnotations, **kw: Any) -> SupportsAnnotations:
-
# ind_cols_on_fromclause means make sure an AnnotatedFromClause
# has its own .c collection independent of that which its proxying.
# this is used specifically by orm.LoaderCriteriaOption to break
is_dml = False
if TYPE_CHECKING:
-
__visit_name__: str
def _compile_w_cache(
)
if len(current_intersection) > len(selected_intersection):
-
# 'current' has a larger field of correspondence than
# 'selected'. i.e. selectable.c.a1_x->a1.c.x->table.c.x
# matches a1.c.x->table.c.x better than
)
if key in self._index:
-
existing = self._index[key][1]
if existing is named_column:
return ck1._diff(ck2)
def _whats_different(self, other: CacheKey) -> Iterator[str]:
-
k1 = self.key
k2 = other.key
if impl._resolve_literal_only:
resolved = impl._literal_coercion(element, **kw)
else:
-
original_element = element
is_clause_element = False
class BinaryElementImpl(ExpressionElementImpl, RoleImpl):
-
__slots__ = ()
def _literal_coercion(
class ByOfImpl(_CoerceLiterals, _ColumnCoercions, RoleImpl, roles.ByOfRole):
-
__slots__ = ()
_coerce_consts = True
class DDLExpressionImpl(_Deannotate, _CoerceLiterals, RoleImpl):
-
__slots__ = ()
_coerce_consts = True
# FROMS left over? boom
if the_rest:
-
froms = the_rest
if froms:
template = (
) -> MutableMapping[
str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
]:
-
# mypy is not able to see the two value types as the above Union,
# it just sees "object". don't know how to resolve
return {
has_escaped_names = escape_names and bool(self.escaped_bind_names)
if extracted_parameters:
-
# related the bound parameters collected in the original cache key
# to those collected in the incoming cache key. They will not have
# matching names but they will line up positionally in the same
resolved_extracted = None
if params:
-
pd = {}
for bindparam, name in self.bind_names.items():
escaped_name = (
def visit_textual_select(
self, taf, compound_index=None, asfrom=False, **kw
):
-
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
)
def _generate_delimited_and_list(self, clauses, **kw):
-
lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
operators.and_,
elements.True_._singleton,
)
def _format_frame_clause(self, range_, **kw):
-
return "%s AND %s" % (
"UNBOUNDED PRECEDING"
if range_[0] is elements.RANGE_UNBOUNDED
def visit_unary(
self, unary, add_to_result_map=None, result_map_targets=(), **kw
):
-
if add_to_result_map is not None:
result_map_targets += (unary,)
kw["add_to_result_map"] = add_to_result_map
def _literal_execute_expanding_parameter_literal_binds(
self, parameter, values, bind_expression_template=None
):
-
typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
if not values:
and isinstance(values[0], collections_abc.Sequence)
and not isinstance(values[0], (str, bytes))
):
-
if typ_dialect_impl._has_bind_expression:
raise NotImplementedError(
"bind_expression() on TupleType not supported with "
return (), replacement_expression
def _literal_execute_expanding_parameter(self, name, parameter, values):
-
if parameter.literal_execute:
return self._literal_execute_expanding_parameter_literal_binds(
parameter, values
if not values:
to_update = []
if typ_dialect_impl._is_tuple_type:
-
replacement_expression = self.visit_empty_set_op_expr(
parameter.type.types, parameter.expand_op
)
def _generate_generic_binary(
self, binary, opstring, eager_grouping=False, **kw
):
-
_in_binary = kw.get("_in_binary", False)
kw["_in_binary"] = True
visited_bindparam: Optional[List[str]] = None,
**kw: Any,
) -> str:
-
# TODO: accumulate_bind_names is passed by crud.py to gather
# names on a per-value basis, visited_bindparam is passed by
# visit_insert() to collect all parameters in the statement.
visited_bindparam.append(name)
if not escaped_from:
-
if self._bind_translate_re.search(name):
# not quite the translate use case as we want to
# also get a quick boolean if we even found
from_linter=None,
**kwargs,
):
-
if lateral:
if "enclosing_lateral" not in kwargs:
# if lateral is set and enclosing_lateral is not
populate_result_map: bool,
**kw: Any,
) -> str:
-
columns = [
self._label_returning_column(
stmt,
replaced_parameters = base_parameters.copy()
for i, param in enumerate(batch):
-
fmv = formatted_values_clause.replace(
"EXECMANY_INDEX__", str(i)
)
batchnum += 1
def visit_insert(self, insert_stmt, visited_bindparam=None, **kw):
-
compile_state = insert_stmt._compile_state_factory(
insert_stmt, self, **kw
)
returning_cols = self.implicit_returning or insert_stmt._returning
if returning_cols:
-
add_sentinel_cols = crud_params_struct.use_sentinel_columns
if add_sentinel_cols is not None:
elif not crud_params_single and supports_default_values:
text += " DEFAULT VALUES"
if use_insertmanyvalues:
-
self._insertmanyvalues = _InsertManyValues(
True,
self.dialect.default_metavalue_token,
)
if use_insertmanyvalues:
-
if (
implicit_sentinel
and (
)
def visit_update(self, update_stmt, **kw):
-
compile_state = update_stmt._compile_state_factory(
update_stmt, self, **kw
)
def create_table_constraints(
self, table, _include_foreign_key_constraints=None, **kw
):
-
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
constraints = []
return "NCLOB"
def _render_string_type(self, type_, name, length_override=None):
-
text = name
if length_override:
text += "(%d)" % length_override
toplevel,
kw,
):
-
cols = [stmt.table.c[_column_as_key(name)] for name in stmt._select_names]
assert compiler.stack[-1]["selectable"] is stmt
accumulated_bind_names: Set[str] = set()
if coercions._is_literal(value):
-
if (
insert_null_pk_still_autoincrements
and c.primary_key
compiler.postfetch.append(c)
else:
if c.primary_key:
-
if implicit_returning:
compiler.implicit_returning.append(c)
elif compiler.dialect.postfetch_lastrowid:
values: List[_CrudParamElementSQLExpr],
kw: Dict[str, Any],
) -> None:
-
if default_is_sequence(c.default):
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
def _append_param_update(
compiler, compile_state, stmt, c, implicit_return_defaults, values, kw
):
-
include_table = compile_state.include_table_with_column_exprs
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
name: Optional[str] = None,
**kw: Any,
) -> Union[elements.BindParameter[Any], str]:
-
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
row = {_column_as_key(key): v for key, v in row.items()}
- for (col, col_expr, param, accumulated_names) in values_0:
+ for col, col_expr, param, accumulated_names in values_0:
if col.key in row:
key = col.key
values,
kw,
):
-
for k, v in stmt_parameter_tuples:
colkey = _column_as_key(k)
if colkey is not None:
checkfirst=self.checkfirst,
_is_metadata_operation=_is_metadata_operation,
):
-
for column in table.columns:
if column.default is not None:
self.traverse_single(column.default)
tables=event_collection,
checkfirst=self.checkfirst,
):
-
for table, fkcs in collection:
if table is not None:
self.traverse_single(
checkfirst=self.checkfirst,
_is_metadata_operation=_is_metadata_operation,
):
-
DropTable(table)._invoke_with(self.connection)
# traverse client side defaults which may refer to server-side
DropConstraint(constraint)._invoke_with(self.connection)
def visit_sequence(self, sequence, drop_ok=False):
-
if not drop_ok and not self._can_drop_sequence(sequence):
return
with self.with_ddl_events(sequence):
result_type: Optional[TypeEngine[_T]] = None,
**kw: Any,
) -> OperatorExpression[_T]:
-
coerced_obj = coercions.expect(
roles.BinaryElementRole, obj, expr=expr, operator=op
)
)
elif isinstance(arg, collections_abc.Sequence):
-
if arg and isinstance(arg[0], dict):
multi_kv_generator = DMLState.get_plugin_class(
self
return self
if TYPE_CHECKING:
-
# START OVERLOADED FUNCTIONS self.returning ReturningInsert 1-8 ", *, sort_by_parameter_order: bool = False" # noqa: E501
# code within this block is **programmatically,
)
if TYPE_CHECKING:
-
# START OVERLOADED FUNCTIONS self.returning ReturningDelete 1-8
# code within this block is **programmatically,
)
if TYPE_CHECKING:
-
# START GENERATED FUNCTION ACCESSORS
# code within this block is **programmatically,
default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
if "type_" not in kwargs:
-
type_from_args = _type_from_args(fn_args)
if isinstance(type_from_args, sqltypes.ARRAY):
kwargs["type_"] = type_from_args
element: Optional[visitors.ExternallyTraversible], **kw: Any
) -> Optional[visitors.ExternallyTraversible]:
if isinstance(element, elements.BindParameter):
-
if element.key in bindparam_lookup:
bind = bindparam_lookup[element.key]
if element.expanding:
if isinstance(cell_contents, _cache_key.HasCacheKey):
def get(closure, opts, anon_map, bindparams):
-
obj = closure[idx].cell_contents
if use_inspect:
obj = inspection.inspect(obj)
return self._sa__add_getter(key, operator.itemgetter)
def _add_getter(self, key, getter_fn):
-
bind_paths = object.__getattribute__(self, "_bind_paths")
bind_path_key = (key, getter_fn)
def _get_convention(dict_, key):
-
for super_ in key.__mro__:
if super_ in _prefix_dict and _prefix_dict[super_] in dict_:
return dict_[_prefix_dict[super_]]
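The MRO walk above is how a constraint class is matched to its convention key. A small sketch of the documented feature it serves, ``MetaData(naming_convention=...)``, with an illustrative ``uq`` name template::

    from sqlalchemy import Column, Integer, MetaData, String, Table, UniqueConstraint
    from sqlalchemy.schema import CreateTable

    # illustrative convention: unnamed unique constraints get a generated name
    metadata = MetaData(
        naming_convention={"uq": "uq_%(table_name)s_%(column_0_name)s"}
    )

    t = Table(
        "user_account",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("email", String(50)),
        UniqueConstraint("email"),
    )

    # the emitted DDL contains: CONSTRAINT uq_user_account_email UNIQUE (email)
    print(CreateTable(t))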
class SchemaConst(Enum):
-
RETAIN_SCHEMA = 1
"""Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence`
or in some cases a :class:`_schema.ForeignKey` object, in situations
if the_sentinel:
the_sentinel_zero = the_sentinel[0]
if the_sentinel_zero.identity:
-
if the_sentinel_zero.identity._increment_is_negative:
if sentinel_is_explicit:
raise exc.InvalidRequestError(
_SentinelDefaultCharacterization.SENTINEL_DEFAULT
)
elif default_is_sequence(the_sentinel_zero.default):
-
if the_sentinel_zero.default._increment_is_negative:
if sentinel_is_explicit:
raise exc.InvalidRequestError(
_column: Column[Any]
if isinstance(self._colspec, str):
-
parenttable, tablekey, colname = self._resolve_col_tokens()
if self._unresolvable or tablekey not in parenttable.metadata:
# c, use hash(), so that an annotated version of the column
# is seen as the same as the non-annotated
if hash(names[effective_name]) != hash(c):
-
# different column under the same name. apply
# disambiguating label
if table_qualified:
@_generative
def alias(self, name: Optional[str] = None, flat: bool = False) -> Self:
-
"""Return a new :class:`_expression.Values`
construct that is a copy of this
one with the given name.
Iterable[Sequence[ColumnElement[Any]]]
] = None,
) -> None:
-
# this is a slightly hacky thing - the union exports a
# column that resembles just that of the *first* selectable.
# to get at a "composite" column, particularly foreign keys,
def _column_naming_convention(
cls, label_style: SelectLabelStyle
) -> _LabelConventionCallable:
-
table_qualified = label_style is LABEL_STYLE_TABLENAME_PLUS_COL
dedupe = label_style is not LABEL_STYLE_NONE
froms: List[FromClause] = []
for item in iterable_of_froms:
-
if is_subquery(item) and item.element is check_statement:
raise exc.InvalidRequestError(
"select() construct refers to itself as a FROM"
]
if self.statement._correlate_except is not None:
-
froms = [
f
for f in froms
and implicit_correlate_froms
and len(froms) > 1
):
-
froms = [
f
for f in froms
args: Tuple[_SetupJoinsElement, ...],
raw_columns: List[_ColumnsClauseElement],
) -> None:
- for (right, onclause, left, flags) in args:
+ for right, onclause, left, flags in args:
if TYPE_CHECKING:
if onclause is not None:
assert isinstance(onclause, ColumnElement)
from_clauses = self.from_clauses
if from_clauses:
-
indexes: List[int] = sql_util.find_left_clause_to_join_from(
from_clauses, right, onclause
)
]
),
):
-
potential[from_clause] = ()
all_clauses = list(potential.keys())
if is_column_element(c)
]
else:
-
prox = [
c._make_proxy(
subquery,
def self_group(
self, against: Optional[OperatorType] = None
) -> ColumnElement[Any]:
-
return self
if TYPE_CHECKING:
columns: List[_ColumnExpressionArgument[Any]],
positional: bool = False,
) -> None:
-
self._init(
text,
# convert for ORM attributes->columns, etc
@util.memoized_property
def _expression_adaptations(self):
-
# Based on
# https://www.postgresql.org/docs/current/static/functions-datetime.html.
matched_on: _MatchedOnType,
matched_on_flattened: Type[Any],
) -> Optional[Enum]:
-
# "generic form" indicates we were placed in a type map
# as ``sqlalchemy.Enum(enum.Enum)`` which indicates we need to
# get enumerated values from the datatype
type: ARRAY
def _setup_getitem(self, index):
-
arr_type = self.type
return_type: TypeEngine[Any]
def coerce_compared_value(
self, op: Optional[OperatorType], value: Any
) -> TypeEngine[Any]:
-
if value is type_api._NO_VALUE_IN_LIST:
return super().coerce_compared_value(op, value)
else:
return process
else:
-
if not self.as_uuid:
def process(value):
return element
def visit_setup_join_tuple(self, element, **kw):
- for (target, onclause, from_, flags) in element:
+ for target, onclause, from_, flags in element:
if from_ is not None:
yield from_
def visit_string_multi_dict(
self, attrname, left_parent, left, right_parent, right, **kw
):
-
for lk, rk in zip_longest(
sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
):
def visit_multi(
self, attrname, left_parent, left, right_parent, right, **kw
):
-
lhc = isinstance(left, HasCacheKey)
rhc = isinstance(right, HasCacheKey)
if lhc and rhc:
def _cached_sentinel_value_processor(
self, dialect: Dialect
) -> Optional[_SentinelProcessorType[_T]]:
-
try:
return dialect._type_memos[self]["sentinel"]
except KeyError:
@util.preload_module("sqlalchemy.engine.default")
def _default_dialect(self) -> Dialect:
-
default = util.preloaded.engine_default
# dmypy / mypy seems to sporadically keep thinking this line is
impl: Union[TypeEngine[Any], TypeEngineMixin],
**kw: Any,
) -> TypeEngine[Any]:
-
"""Given an impl, adapt this type's class to the impl assuming
"native".
if process_literal_param is not None:
impl_processor = self.impl_instance.literal_processor(dialect)
if impl_processor:
-
fixed_impl_processor = impl_processor
fixed_process_literal_param = process_literal_param
@util.memoized_property
def _has_bind_expression(self) -> bool:
-
return (
util.method_is_overridden(self, TypeDecorator.bind_expression)
or self.impl_instance._has_bind_expression
def traverse(
self, obj: Optional[ExternallyTraversible]
) -> Optional[ExternallyTraversible]:
-
return self.columns[obj]
def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
return elem
else:
if id(elem) not in cloned:
-
if "replace" in kw:
newelem = cast(
Optional[ExternallyTraversible], kw["replace"](elem)
class AssertRule:
-
is_consumed = False
errormessage = None
consume_statement = True
map_ = None
if isinstance(execute_observed.clauseelement, BaseDDLElement):
-
compiled = execute_observed.clauseelement.compile(
dialect=compare_dialect,
schema_translate_map=map_,
"""
if not ENABLE_ASYNCIO:
-
return fn(*args, **kwargs)
is_async = config._current.is_async
def combinations_list(
- arg_iterable: Iterable[
- Tuple[
- Any,
- ]
- ],
+ arg_iterable: Iterable[Tuple[Any,]],
**kw,
):
"As combination, but takes a single iterable"
return getattr(self.dbapi, key)
def connect(self, *args, **kwargs):
-
conn = self.dbapi.connect(*args, **kwargs)
if self.is_stopped:
self._safe(conn.close)
def run_test(subject, trans_on_subject, execute_on_subject):
with subject.begin() as trans:
-
if begin_nested:
if not config.requirements.savepoints.enabled:
config.skip_test("savepoints not enabled")
class TablesTest(TestBase):
-
# 'once', None
run_setup_bind = "once"
assert b_key is None
else:
-
eq_(a_key.key, b_key.key)
eq_(hash(a_key.key), hash(b_key.key))
def insertmanyvalues_fixture(
connection, randomize_rows=False, warn_on_downgraded=False
):
-
dialect = connection.dialect
orig_dialect = dialect._deliver_insertmanyvalues_batches
orig_conn = connection._exec_insertmany_context
@pre
def _set_disable_asyncio(opt, file_config):
if opt.disable_asyncio:
-
asyncio.ENABLE_ASYNCIO = False
@post
def _engine_uri(options, file_config):
-
from sqlalchemy import testing
from sqlalchemy.testing import config
from sqlalchemy.testing import provision
@post
def _requirements(options, file_config):
-
requirement_cls = file_config.get("sqla_testing", "requirement_cls")
_setup_requirements(requirement_cls)
def before_test(test, test_module_name, test_class, test_name):
-
# format looks like:
# "test.aaa_profiling.test_compiler.CompileTest.test_update_whereclause"
def pytest_collection_modifyitems(session, config, items):
-
# look for all those classes that specify __backend__ and
# expand them out into per-database test cases.
def setup_test_classes():
for test_class in test_classes:
-
# transfer legacy __backend__ and __sparse_backend__ symbols
# to be markers
add_markers = set()
return env[fn_name]
def decorate(fn, add_positional_parameters=()):
-
spec = inspect_getfullargspec(fn)
if add_positional_parameters:
spec.args.extend(add_positional_parameters)
)
else:
-
for arg in arg_sets:
if not isinstance(arg, tuple):
arg = (arg,)
@property
def platform_key(self):
-
dbapi_key = config.db.name + "_" + config.db.driver
if config.db.name == "sqlite" and config.db.dialect._is_url_file_db(
profile_f = open(self.fname, "w")
profile_f.write(self._header())
for test_key in sorted(self.data):
-
per_fn = self.data[test_key]
profile_f.write("\n# TEST: %s\n\n" % test_key)
for platform_key in sorted(per_fn):
@decorator
def wrap(fn, *args, **kw):
-
for warm in range(warmup):
fn(*args, **kw)
yield url
for drv in list(extra_drivers):
-
if "?" in drv:
-
driver_only, query_str = drv.split("?", 1)
else:
def drop_all_schema_objects(cfg, eng):
-
drop_all_schema_objects_pre_tables(cfg, eng)
drop_views(cfg, eng)
}
"""
with config.db.connect() as conn:
-
try:
supported = conn.dialect.get_isolation_level_values(
conn.connection.dbapi_connection
if test_opts.get("test_needs_autoincrement", False) and kw.get(
"primary_key", False
):
-
if col.default is None and col.server_default is None:
col.autoincrement = True
@testing.requires.ctes_with_update_delete
def test_delete_scalar_subq_round_trip(self, connection):
-
some_table = self.tables.some_table
some_other_table = self.tables.some_other_table
@requirements.duplicate_key_raises_integrity_error
def test_integrity_error(self):
-
with config.db.connect() as conn:
-
trans = conn.begin()
conn.execute(
self.tables.manual_pk.insert(), {"id": 1, "data": "d1"}
class AutocommitIsolationTest(fixtures.TablesTest):
-
run_deletes = "each"
__requires__ = ("autocommit",)
@classmethod
def define_tables(cls, metadata):
-
Table(
"t",
metadata,
)
def test_autoincrement_on_insert(self, connection):
-
connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
self._assert_round_trip(self.tables.autoinc_pk, connection)
def test_last_inserted_id(self, connection):
-
r = connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
eq_(fetched_pk, pk)
def test_autoincrement_on_insert_implicit_returning(self, connection):
-
connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
self._assert_round_trip(self.tables.autoinc_pk, connection)
def test_last_inserted_id_implicit_returning(self, connection):
-
r = connection.execute(
self.tables.autoinc_pk.insert(), dict(data="some data")
)
)
if testing.requires.view_column_reflection.enabled:
-
if testing.requires.symbol_names_w_double_quote.enabled:
names = [
"quote ' one",
(True, testing.requires.schemas), False, argnames="use_schema"
)
def test_get_table_names(self, connection, order_by, use_schema):
-
if use_schema:
schema = config.test_schema
else:
argnames="use_views,use_schema",
)
def test_get_columns(self, connection, use_views, use_schema):
-
if use_schema:
schema = config.test_schema
else:
@testing.requires.temp_table_reflection
def test_reflect_table_temp_table(self, connection):
-
table_name = self.temp_table_name()
user_tmp = self.tables[table_name]
)
@testing.requires.index_reflection
def test_get_indexes(self, connection, use_schema):
-
if use_schema:
schema = config.test_schema
else:
@testing.requires.comment_reflection_full_unicode
def test_comments_unicode_full(self, connection, metadata):
-
Table(
"unicode_comments",
metadata,
class ComponentReflectionTestExtra(ComparesIndexes, fixtures.TestBase):
-
__backend__ = True
@testing.combinations(
)
def test_reflect_lowercase_forced_tables(self):
-
m2 = MetaData()
t2_ref = Table(
quoted_name("t2", quote=True), m2, autoload_with=config.db
class ServerSideCursorsTest(
fixtures.TestBase, testing.AssertsExecutionResults
):
-
__requires__ = ("server_side_cursors",)
__backend__ = True
eq_(res, [(-5, "b"), (0, "a"), (42, "c")])
def test_select_columns(self, connection):
-
res = connection.execute(
select(self.tables.tbl_a.c.id).order_by(self.tables.tbl_a.c.id)
).fetchall()
@testing.combinations((True,), (False,), argnames="implicit_returning")
@testing.requires.schemas
def test_insert_roundtrip_translate(self, connection, implicit_returning):
-
seq_no_returning = Table(
"seq_no_returning_sch",
MetaData(),
literal_round_trip(Integer, [5], [5])
def _huge_ints():
-
return testing.combinations(
2147483649, # 32 bits
2147483648, # 32 bits
)
@testing.combinations(100, 1999, 3000, 4000, 5000, 9000, argnames="length")
def test_round_trip_pretty_large_data(self, connection, unicode_, length):
-
if unicode_:
data = "réve🐍illé" * ((length // 9) + 1)
data = data[0 : (length // 2)]
eq_(row, (data_element,))
def _index_fixtures(include_comparison):
-
if include_comparison:
# basically SQL Server and MariaDB can kind of do json
# comparison, MySQL, PG and SQLite can't. not worth it.
def _json_value_insert(self, connection, datatype, value, data_element):
data_table = self.tables.data_table
if datatype == "_decimal":
-
# Python's builtin json serializer basically doesn't support
# Decimal objects without implicit float conversion period.
# users can otherwise use simplejson which supports
data_element = {"key1": value}
with config.db.begin() as conn:
-
datatype, compare_value, p_s = self._json_value_insert(
conn, datatype, value, data_element
)
data_table = self.tables.data_table
data_element = {"key1": {"subkey1": value}}
with config.db.begin() as conn:
-
datatype, compare_value, p_s = self._json_value_insert(
conn, datatype, value, data_element
)
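Regarding the note above that the stdlib ``json`` serializer does not handle ``Decimal``: a stand-alone sketch of one common workaround using a ``default`` hook (``simplejson`` with ``use_decimal=True`` is another option); the helper name is illustrative::

    import json
    from decimal import Decimal

    data = {"key1": Decimal("1234567.89")}

    def encode_decimal(obj):  # illustrative default hook
        if isinstance(obj, Decimal):
            return str(obj)
        raise TypeError(f"not JSON serializable: {obj!r}")

    # the stdlib serializer raises TypeError on Decimal without such a hook
    print(json.dumps(data, default=encode_decimal))  # {"key1": "1234567.89"}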
)
def test_eval_none_flag_orm(self, connection):
-
Base = declarative_base()
class Data(Base):
edges[child].add(parent)
def _all_orderings(elements):
-
if len(elements) == 1:
yield list(elements)
else:
def decorate(fn):
def run_ddl(self):
-
metadata = self.metadata = schema.MetaData()
try:
result = fn(self, metadata)
@decorator
def go(fn, *args, **kw):
-
try:
return fn(*args, **kw)
finally:
consider_schemas=(None,),
include_names=None,
):
-
if include_names is not None:
include_names = set(include_names)
if typing.TYPE_CHECKING:
class greenlet(Protocol):
-
dead: bool
gr_context: Optional[Context]
if not isinstance(current, _AsyncIoGreenlet):
loop = get_event_loop()
if loop.is_running():
-
_safe_cancel_awaitable(awaitable)
raise exc.MissingGreenlet(
def _util_async_run(
fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
) -> Any:
-
"""for test suite/ util only"""
loop = get_event_loop()
if py310:
anext_ = anext
else:
-
_NOT_PROVIDED = object()
from collections.abc import AsyncIterator
check_defaults: Union[Set[str], Tuple[()]]
if spec.defaults is not None:
-
defaults = dict(
zip(
spec.args[(len(spec.args) - len(spec.defaults)) :],
check_defaults = set(defaults).intersection(messages)
check_kw = set(messages).difference(defaults)
elif spec.kwonlydefaults is not None:
-
defaults = spec.kwonlydefaults
check_defaults = set(defaults).intersection(messages)
check_kw = set(messages).difference(defaults)
) -> Type[_T]:
doc = cls.__doc__ is not None and cls.__doc__ or ""
if docstring_header is not None:
-
if constructor is not None:
docstring_header %= dict(func=constructor)
# additional issues, RO properties:
# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
-
# allow memoized and non-memoized to be freely mixed by having them
# be the same class
memoized_property = generic_fn_descriptor
category: Optional[Type[Warning]] = None,
stacklevel: int = 2,
) -> None:
-
# adjust the given stacklevel to be outside of SQLAlchemy
try:
frame = sys._getframe(stacklevel)
def sort_as_subsets(
tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T]
) -> Iterator[Sequence[_T]]:
-
edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
for parent, child in tuples:
edges[child].add(parent)
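The edge map built above drives a grouped topological sort. A self-contained sketch of the same idea; ``sort_as_levels`` and its signature are illustrative, not the library helper itself::

    from collections import defaultdict

    def sort_as_levels(tuples, allitems):
        """Yield groups of items whose dependencies have already been emitted."""
        edges = defaultdict(set)
        for parent, child in tuples:
            edges[child].add(parent)

        remaining = set(allitems)
        while remaining:
            # items with no unemitted parents can be flushed together
            level = {i for i in remaining if not edges[i] & remaining}
            if not level:
                raise ValueError("circular dependency detected")
            yield level
            remaining -= level

    print(list(sort_as_levels([("a", "b"), ("a", "c"), ("b", "d")], "abcd")))
    # [{'a'}, {'b', 'c'}, {'d'}]  (set display order may vary)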
and is_generic(annotation)
and not is_literal(annotation)
):
-
if _already_seen is None:
_already_seen = set()
cmdclass = {"build_ext": _cy_build_ext}
elif REQUIRE_EXTENSION:
-
reasons = []
if not HAS_CYTHON:
reasons.append("Cython is missing")
@classmethod
def setup_test_class(cls):
-
global t1, t2, metadata
metadata = MetaData()
t1 = Table(
("require_embedded",), ("no_embedded",), argnames="require_embedded"
)
def test_corresponding_column_isolated(self, t1, require_embedded):
-
subq = select(t1).union_all(select(t1)).subquery()
target = subq.c.x7
def test_gen_subq_to_table_single_corresponding_column(
self, t1, require_embedded
):
-
src = t1.c.x7
require_embedded = require_embedded == "require_embedded"
def test_gen_subq_to_table_many_corresponding_column(
self, t1, require_embedded
):
-
require_embedded = require_embedded == "require_embedded"
@profiling.function_call_count(variance=0.15, warmup=1)
subq = select(t1).union_all(select(t1)).subquery()
for name in ("x%d" % i for i in range(1, 10)):
-
target = subq.c[name]
src = t1.c[name]
def test_gen_subq_aliased_class_select(
self, t1, require_embedded, inheritance_model
):
-
A = inheritance_model
require_embedded = require_embedded == "require_embedded"
@profiling.function_call_count(variance=0.15, warmup=1)
def go():
-
a1a1 = aliased(A)
a1a2 = aliased(A)
subq = select(a1a1).union_all(select(a1a2)).subquery()
def test_gen_subq_aliased_class_select_cols(
self, t1, require_embedded, inheritance_model
):
-
A = inheritance_model
require_embedded = require_embedded == "require_embedded"
@profiling.function_call_count(variance=0.15, warmup=1)
def go():
-
a1a1 = aliased(A)
a1a2 = aliased(A)
subq = select(a1a1).union_all(select(a1a2)).subquery()
@classmethod
def define_tables(cls, metadata):
-
Table(
"a",
metadata,
@async_test
async def test_ok(self):
-
eq_(await greenlet_spawn(go, run1, run2), 3)
@async_test
)
def test_raise_on_cycle_two(self):
-
# this condition was arising from ticket:362 and was not treated
# properly by topological sort
self.assert_sort(tuples)
def test_ticket_1380(self):
-
# ticket:1380 regression: would raise a KeyError
tuples = [(id(i), i) for i in range(3)]
pass
class E2(event.Events):
-
_dispatch_target = T2
def event_four(self, x):
assert_raises(TypeError, should_raise)
def test_serialize(self):
-
keyed_tuple = self._fixture([1, 2, 3], ["a", None, "b"])
for loads, dumps in picklers():
default_filters=None,
data=None,
):
-
if data is None:
data = [(1, 1, 1), (2, 1, 2), (1, 3, 2), (4, 1, 2)]
if num_rows is not None:
class MergeResultTest(fixtures.TestBase):
@testing.fixture
def merge_fixture(self):
-
r1 = result.IteratorResult(
result.SimpleResultMetaData(["user_id", "user_name"]),
iter([(7, "u1"), (8, "u2")]),
@testing.fixture
def dupe_fixture(self):
-
r1 = result.IteratorResult(
result.SimpleResultMetaData(["x", "y", "z"]),
iter([(1, 2, 1), (2, 2, 1)]),
eq_(util.to_list((1, 2, 3)), [1, 2, 3])
def test_from_bytes(self):
-
eq_(util.to_list(compat.b("abc")), [compat.b("abc")])
eq_(
assert "kcol2" in cc
def test_dupes_add(self):
-
c1, c2a, c3, c2b = (
column("c1"),
column("c2"),
eq_(ci.keys(), ["c1", "c2", "c3", "c2"])
def test_dupes_construct(self):
-
c1, c2a, c3, c2b = (
column("c1"),
column("c2"),
eq_(ci.keys(), ["c1", "c2", "c3", "c2"])
def test_identical_dupe_construct(self):
-
c1, c2, c3 = (column("c1"), column("c2"), column("c3"))
cc = sql.ColumnCollection(
self._assert_collection_integrity(cc)
def test_dupes_construct_dedupe(self):
-
c1, c2a, c3, c2b = (
column("c1"),
column("c2"),
eq_(list(ci), [c1, c2, c3])
def test_identical_dupe_construct_dedupes(self):
-
c1, c2, c3 = (column("c1"), column("c2"), column("c3"))
cc = DedupeColumnCollection(
self._assert_collection_integrity(cc)
def test_remove(self):
-
c1, c2, c3 = column("c1"), column("c2"), column("c3")
cc = DedupeColumnCollection(
assert_raises(IndexError, lambda: ci[2])
def test_remove_doesnt_change_iteration(self):
-
c1, c2, c3, c4, c5 = (
column("c1"),
column("c2"),
argnames="fn,wanted,grouped",
)
def test_specs(self, fn, wanted, grouped):
-
# test direct function
if grouped is None:
parsed = util.format_argspec_plus(fn)
)
with eng.begin() as conn:
-
tbl.create(conn)
conn.execute(tbl.insert(), {"id": 1})
eq_(conn.scalar(tbl.select()), 1)
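For context, the create/insert/select round trip above relies on ``engine.begin()`` committing the block as a whole. A minimal runnable sketch against an illustrative in-memory SQLite engine::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select

    engine = create_engine("sqlite://")  # illustrative in-memory database
    metadata = MetaData()
    tbl = Table("t", metadata, Column("id", Integer, primary_key=True))

    # engine.begin() opens a connection and commits the whole block on success
    with engine.begin() as conn:
        tbl.create(conn)
        conn.execute(tbl.insert(), {"id": 1})
        assert conn.scalar(select(tbl.c.id)) == 1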
)
with engine.begin() as conn:
-
if expect_failure:
with expect_raises(DBAPIError):
conn.execute(observations.insert(), records)
@testing.requires.schemas
def test_insert_using_schema_translate(self, connection, metadata):
-
t = Table(
"t",
metadata,
class MatchTest(AssertsCompiledSQL, fixtures.TablesTest):
-
__only_on__ = "mssql"
__skip_if__ = (full_text_search_missing,)
__backend__ = True
argnames="type_obj,ddl",
)
def test_assorted_types(self, metadata, connection, type_obj, ddl):
-
table = Table("type_test", metadata, Column("col1", type_obj))
table.create(connection)
"""test #6910"""
with testing.db.connect() as c1, testing.db.connect() as c2:
-
try:
with c1.begin():
c1.exec_driver_sql(
)
def test_indexes_cols(self, metadata, connection):
-
t1 = Table("t", metadata, Column("x", Integer), Column("y", Integer))
Index("foo", t1.c.x, t1.c.y)
metadata.create_all(connection)
eq_(set(list(t2.indexes)[0].columns), {t2.c["x"], t2.c.y})
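The same composite-index round trip can also be observed through the inspector API; a small sketch against an illustrative SQLite engine (exact result keys vary per dialect)::

    from sqlalchemy import (
        Column,
        Index,
        Integer,
        MetaData,
        Table,
        create_engine,
        inspect,
    )

    engine = create_engine("sqlite://")  # illustrative backend
    metadata = MetaData()
    t1 = Table("t", metadata, Column("x", Integer), Column("y", Integer))
    Index("ix_t_xy", t1.c.x, t1.c.y)
    metadata.create_all(engine)

    # the inspector reports the composite index along with its column names
    print(inspect(engine).get_indexes("t"))
    # e.g. [{'name': 'ix_t_xy', 'column_names': ['x', 'y'], 'unique': 0, ...}]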
def test_indexes_cols_with_commas(self, metadata, connection):
-
t1 = Table(
"t",
metadata,
eq_(set(list(t2.indexes)[0].columns), {t2.c["x, col"], t2.c.y})
def test_indexes_cols_with_spaces(self, metadata, connection):
-
t1 = Table(
"t",
metadata,
eq_(set(list(t2.indexes)[0].columns), {t2.c["x col"], t2.c.y})
def test_indexes_with_filtered(self, metadata, connection):
-
t1 = Table(
"t",
metadata,
@classmethod
def define_tables(cls, metadata):
-
for i, col in enumerate(
[
Column(
eq_(value, returned)
def test_float(self, metadata, connection):
-
float_table = Table(
"float_table",
metadata,
use_returning,
insertmany,
):
-
if datatype is NVARCHAR and length != "max" and length > 4000:
return
elif unicode_ and datatype not in (NVARCHAR, UnicodeText):
try:
yield table, expected_mysql, expected_mdb
finally:
-
reserved_words.RESERVED_WORDS_MARIADB.discard("mdb_reserved")
reserved_words.RESERVED_WORDS_MYSQL.discard("mysql_reserved")
reserved_words.RESERVED_WORDS_MYSQL.discard("mdb_mysql_reserved")
class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL):
-
__dialect__ = mysql.dialect()
@testing.combinations(
(m.MSSet("1", "2"), "t.col"),
)
def test_unsupported_casts(self, type_, expected):
-
t = sql.table("t", sql.column("col"))
with expect_warnings(
"Datatype .* does not support CAST on MySQL/MariaDb;"
)
@testing.combinations(True, False, argnames="maria_db")
def test_float_cast(self, type_, expected, maria_db):
-
dialect = mysql.dialect()
if maria_db:
dialect.is_mariadb = maria_db
)
dialect = None
elif version.mysql8:
-
expected_sql = (
"INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) "
"AS new ON DUPLICATE KEY UPDATE bar = "
class MatchExpressionTest(fixtures.TestBase, AssertsCompiledSQL):
-
__dialect__ = mysql.dialect()
match_table = table(
class CompileTest(AssertsCompiledSQL, fixtures.TestBase):
-
__dialect__ = mysql.dialect()
def test_distinct_string(self):
)
def test_no_show_variables(self):
-
engine = engines.testing_engine()
def my_execute(self, statement, *args, **kw):
engine.connect()
def test_no_default_isolation_level(self):
-
engine = engines.testing_engine()
real_isolation_level = testing.db.dialect.get_isolation_level
("utf8",),
)
def test_special_encodings(self, enc):
-
eng = engines.testing_engine(
options={"connect_args": {"charset": enc, "use_unicode": 0}}
)
@testing.only_on("mariadb+mariadbconnector")
def test_mariadb_connector_special_encodings(self):
-
eng = engines.testing_engine()
conn = eng.connect()
metadata,
connection,
):
-
specs = [(mysql.ENUM("", "fleem"), mysql.ENUM("", "fleem"))]
self._run_test(metadata, connection, specs, ["enums"])
class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL):
-
__only_on__ = "mysql", "mariadb"
__backend__ = True
eq_(ref_kw["partitions"], "6")
def test_reflection_with_subpartition_options(self, connection, metadata):
-
subpartititon_text = """HASH (TO_DAYS (c2))
SUBPARTITIONS 2(
PARTITION p0 VALUES LESS THAN (1990),
class TypeRoundTripTest(fixtures.TestBase, AssertsExecutionResults):
-
__dialect__ = mysql.dialect()
__only_on__ = "mysql", "mariadb"
__backend__ = True
argnames="store, expected",
)
def test_bit_50_roundtrip(self, connection, bit_table, store, expected):
-
reflected = Table("mysql_bits", MetaData(), autoload_with=connection)
expected = expected or store
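A short sketch of the ``autoload_with`` reflection used above, run against an illustrative SQLite database rather than MySQL::

    from sqlalchemy import MetaData, Table, create_engine

    engine = create_engine("sqlite://")  # illustrative backend
    with engine.begin() as conn:
        conn.exec_driver_sql(
            "CREATE TABLE mysql_bits (id INTEGER PRIMARY KEY, flags INTEGER)"
        )

    # autoload_with reflects the column definitions from the live connection
    with engine.connect() as conn:
        reflected = Table("mysql_bits", MetaData(), autoload_with=conn)
        print(reflected.columns.keys())  # ['id', 'flags']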
@testing.requires.reflects_json_type
def test_reflection(self, metadata, connection):
-
Table("mysql_json", metadata, Column("foo", mysql.JSON))
metadata.create_all(connection)
class EnumSetTest(
fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL
):
-
__only_on__ = "mysql", "mariadb"
__dialect__ = mysql.dialect()
__backend__ = True
)
def test_enum_parse(self, metadata, connection):
-
enum_table = Table(
"mysql_enum",
metadata,
)
@testing.only_on(["oracle+cx_oracle", "oracle+oracledb"])
def test_is_disconnect(self, message, code, expected):
-
dialect = testing.db.dialect
exception_obj = dialect.dbapi.InterfaceError()
eng = engines.testing_engine()
with eng.connect() as conn:
-
trans = conn.begin()
eq_(
testing.db.dialect._get_default_schema_name(conn),
utf8_w_errors = data.encode("utf-16")
if has_errorhandler:
-
eq_(
outconverter(utf8_w_errors),
data.encode("utf-16").decode("utf-8", "ignore"),
class QuotedBindRoundTripTest(fixtures.TestBase):
-
__only_on__ = "oracle"
__backend__ = True
class ExecuteTest(fixtures.TestBase):
-
__only_on__ = "oracle"
__backend__ = True
class ConstraintTest(AssertsCompiledSQL, fixtures.TestBase):
-
__only_on__ = "oracle"
__backend__ = True
def test_oracle_has_no_on_update_cascade(
self, metadata, connection, plain_foo_table
):
-
bar = Table(
"bar",
metadata,
def test_reflect_check_include_all(
self, metadata, connection, plain_foo_table
):
-
insp = inspect(connection)
eq_(insp.get_check_constraints("foo"), [])
eq_(
__backend__ = True
def setup_test(self):
-
with testing.db.begin() as conn:
conn.exec_driver_sql("create table my_table (id integer)")
conn.exec_driver_sql(
@testing.fails_if(all_tables_compression_missing)
def test_reflect_basic_compression(self, metadata, connection):
-
tbl = Table(
"test_compress",
metadata,
def test_include_indexes_resembling_pk(
self, metadata, connection, explicit_pk
):
-
t = Table(
"sometable",
metadata,
eq_(inspect(connection).get_indexes("sometable"), expected)
def test_basic(self, metadata, connection):
-
s_table = Table(
"sometable",
metadata,
eq_(t2.c.c4.type.length, 180)
def test_long_type(self, metadata, connection):
-
t = Table("t", metadata, Column("data", oracle.LONG))
metadata.create_all(connection)
connection.execute(t.insert(), dict(data="xyz"))
r"will now invalidate all prepared caches in response "
r"to this exception\)",
):
-
result = await conn.execute(
t1.select()
.where(t1.c.name.like("some name%"))
@async_test
async def test_failed_commit_recover(self, metadata, async_testing_engine):
-
Table("t1", metadata, Column("id", Integer, primary_key=True))
t2 = Table(
async def test_rollback_twice_no_problem(
self, metadata, async_testing_engine
):
-
engine = async_testing_engine()
async with engine.connect() as conn:
-
trans = await conn.begin()
await trans.rollback()
@async_test
async def test_closed_during_execute(self, metadata, async_testing_engine):
-
engine = async_testing_engine()
async with engine.connect() as conn:
async def test_failed_rollback_recover(
self, metadata, async_testing_engine
):
-
engine = async_testing_engine()
async with engine.connect() as conn:
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
-
__dialect__ = postgresql.dialect()
def test_plain_stringify_returning(self):
)
def test_create_index_with_ops(self):
-
m = MetaData()
tbl = Table(
"testtbl",
)
def test_pg_array_agg_implicit_pg_array(self):
-
expr = pg_array_agg(column("data", Integer))
assert isinstance(expr.type, PG_ARRAY)
is_(expr.type.item_type._type_affinity, Integer)
def test_pg_array_agg_uses_base_array(self):
-
expr = pg_array_agg(column("data", sqltypes.ARRAY(Integer)))
assert isinstance(expr.type, sqltypes.ARRAY)
assert not isinstance(expr.type, PG_ARRAY)
is_(expr.type.item_type._type_affinity, Integer)
def test_pg_array_agg_uses_pg_array(self):
-
expr = pg_array_agg(column("data", PG_ARRAY(Integer)))
assert isinstance(expr.type, PG_ARRAY)
is_(expr.type.item_type._type_affinity, Integer)
def test_pg_array_agg_explicit_base_array(self):
-
expr = pg_array_agg(
column("data", sqltypes.ARRAY(Integer)),
type_=sqltypes.ARRAY(Integer),
is_(expr.type.item_type._type_affinity, Integer)
def test_pg_array_agg_explicit_pg_array(self):
-
expr = pg_array_agg(
column("data", sqltypes.ARRAY(Integer)), type_=PG_ARRAY(Integer)
)
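A compile-only sketch of the PostgreSQL-specific ``array_agg`` behavior exercised above, showing that the result type is the dialect's ``ARRAY``::

    from sqlalchemy import Integer, column, select
    from sqlalchemy.dialects import postgresql
    from sqlalchemy.dialects.postgresql import ARRAY, array_agg

    # the PostgreSQL-specific array_agg is typed as the dialect's ARRAY,
    # so element-wise operators such as .contains() remain available
    expr = array_agg(column("data", Integer))
    assert isinstance(expr.type, ARRAY)

    print(select(expr).compile(dialect=postgresql.dialect()))
    # SELECT array_agg(data) AS array_agg_1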
stmt.on_conflict_do_nothing,
stmt.on_conflict_do_update,
):
-
with testing.expect_raises_message(
exc.InvalidRequestError,
"This Insert construct already has an "
)
def test_do_nothing_no_target(self):
-
i = (
insert(self.table1)
.values(dict(name="foo"))
)
def test_do_nothing_index_elements_target(self):
-
i = (
insert(self.table1)
.values(dict(name="foo"))
)
def test_distinct_on_subquery_anon(self):
-
sq = select(self.table).alias()
q = (
select(self.table.c.id, sq.c.id)
def test_ensure_version_is_qualified(
self, future_connection, testing_engine, metadata
):
-
default_schema_name = future_connection.dialect.default_schema_name
event.listen(
metadata,
class MiscBackendTest(
fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL
):
-
__only_on__ = "postgresql"
__backend__ = True
with testing.db.connect().execution_options(
isolation_level="SERIALIZABLE"
) as conn:
-
dbapi_conn = conn.connection.dbapi_connection
is_false(dbapi_conn.autocommit)
with conn.begin():
-
existing_isolation = conn.exec_driver_sql(
"show transaction isolation level"
).scalar()
dbapi_conn.autocommit = False
with conn.begin():
-
existing_isolation = conn.exec_driver_sql(
"show transaction isolation level"
).scalar()
class OnConflictTest(fixtures.TablesTest):
-
__only_on__ = ("postgresql >= 9.5",)
__backend__ = True
run_define_tables = "each"
class InsertTest(fixtures.TestBase, AssertsExecutionResults):
-
__only_on__ = "postgresql"
__backend__ = True
def test_foreignkey_missing_insert(
self, metadata, connection, implicit_returning
):
-
Table(
"t1",
metadata,
self._assert_data_noautoincrement(connection, table)
def _ints_and_strs_setinputsizes(self, connection):
-
return (
connection.dialect._bind_typing_render_casts
and String().dialect_impl(connection.dialect).render_bind_cast
class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
-
__only_on__ = "postgresql >= 8.3"
__backend__ = True
)
def _strs_render_bind_casts(self, connection):
-
return (
connection.dialect._bind_typing_render_casts
and String().dialect_impl(connection.dialect).render_bind_cast
__backend__ = True
def test_tuple_containment(self, connection):
-
for test, exp in [
([("a", "b")], True),
([("a", "c")], False),
@classmethod
def insert_data(cls, connection):
-
connection.execute(
cls.tables.t.insert(),
{
)
def test_table_valued(self, assets_transactions, connection):
-
jb = func.jsonb_each(assets_transactions.c.contents).table_valued(
"key", "value"
)
)
def test_function_against_row_constructor(self, connection):
-
stmt = select(func.row_to_json(func.row(1, "foo")))
eq_(connection.scalar(stmt), {"f1": 1, "f2": "foo"})
def test_with_ordinality_named(self, connection):
-
stmt = select(
func.generate_series(4, 1, -1)
.table_valued("gs", with_ordinality="ordinality")
eq_(connection.execute(stmt).all(), [(4, 1), (3, 2), (2, 3), (1, 4)])
def test_with_ordinality_star(self, connection):
-
stmt = select("*").select_from(
func.generate_series(4, 1, -1).table_valued(
with_ordinality="ordinality"
)
def test_unnest_with_ordinality(self, connection):
-
array_val = postgresql.array(
[postgresql.array([14, 41, 7]), postgresql.array([54, 9, 49])]
)
)
def test_unnest_with_ordinality_named(self, connection):
-
array_val = postgresql.array(
[postgresql.array([14, 41, 7]), postgresql.array([54, 9, 49])]
)
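A compile-only sketch of the ``table_valued()`` / ``with_ordinality`` combination exercised in the tests above; the rendered SQL shown in the trailing comment is approximate::

    from sqlalchemy import func, select
    from sqlalchemy.dialects import postgresql

    # table_valued() turns the function into a FROM-clause object; the
    # with_ordinality argument adds PostgreSQL's WITH ORDINALITY counter
    gs = func.generate_series(4, 1, -1).table_valued(
        "value", with_ordinality="ordinality"
    )
    stmt = select(gs.c.value, gs.c.ordinality)
    print(stmt.compile(dialect=postgresql.dialect()))
    # roughly: SELECT anon_1.value, anon_1.ordinality
    #          FROM generate_series(...) WITH ORDINALITY AS anon_1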
argnames="cast_fn",
)
def test_render_derived_quoting_text(self, connection, cast_fn):
-
value = (
'[{"CaseSensitive":1,"the % value":"foo"}, '
'{"CaseSensitive":"2","the % value":"bar"}]'
argnames="cast_fn",
)
def test_render_derived_quoting_text_to_json(self, connection, cast_fn):
-
value = (
'[{"CaseSensitive":1,"the % value":"foo"}, '
'{"CaseSensitive":"2","the % value":"bar"}]'
assert inspect(connection).has_table("some_temp_table")
def test_cross_schema_reflection_one(self, metadata, connection):
-
meta1 = metadata
users = Table(
)
def test_uppercase_lowercase_table(self, metadata, connection):
-
a_table = Table("a", metadata, Column("x", Integer))
A_table = Table("A", metadata, Column("x", Integer))
assert inspect(connection).has_table("A")
def test_uppercase_lowercase_sequence(self, connection):
-
a_seq = Sequence("a")
A_seq = Sequence("A")
is_false(inspector.has_type("mood"))
def test_inspect_enums(self, metadata, inspect_fixture):
-
inspector, conn = inspect_fixture
enum_type = postgresql.ENUM(
for enum in "lower_case", "UpperCase", "Name.With.Dot":
for schema in None, "test_schema", "TestSchema":
-
postgresql.ENUM(
"CapsOne",
"CapsTwo",
counter = itertools.count()
for enum in "lower_case", "UpperCase", "Name.With.Dot":
for schema in None, "test_schema", "TestSchema":
-
enum_type = postgresql.ENUM(
"CapsOne",
"CapsTwo",
],
)
elif datatype == "domain":
-
def_schame = testing.config.db.dialect.default_schema_name
eq_(
inspect(connection).get_domains(schema=assert_schema),
argnames="datatype",
)
def test_name_required(self, metadata, connection, datatype):
-
assert_raises(exc.CompileError, datatype.create, connection)
assert_raises(
exc.CompileError, datatype.compile, dialect=connection.dialect
]
def test_generate_multiple_on_metadata(self, connection, metadata):
-
e1 = Enum("one", "two", "three", name="myenum", metadata=metadata)
t1 = Table("e1", metadata, Column("c1", e1))
def test_generate_multiple_schemaname_on_metadata(
self, metadata, connection
):
-
Enum("one", "two", "three", name="myenum", metadata=metadata)
Enum(
"one",
]
def test_create_drop_schema_translate_map(self, connection):
-
conn = connection.execution_options(
schema_translate_map={None: testing.config.test_schema}
)
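A minimal sketch of ``schema_translate_map`` as used above, here assuming SQLite's built-in ``main`` schema as the translation target so the example can run standalone::

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine, select

    metadata = MetaData()
    t = Table("t", metadata, Column("id", Integer, primary_key=True))

    # None-schema table names are rewritten to "main" at execution time;
    # the Table objects themselves stay schema-less
    engine = create_engine("sqlite://").execution_options(
        schema_translate_map={None: "main"}
    )

    with engine.begin() as conn:
        metadata.create_all(conn)
        conn.execute(t.insert(), {"id": 1})
        print(conn.execute(select(t)).all())  # [(1,)]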
__backend__ = True
def test_numeric_codes(self):
-
dialects = (
pg8000.dialect(),
psycopg2.dialect(),
assert row[0] >= somedate
def test_without_timezone(self, connection):
-
# get a date without a tzinfo
tztable, notztable = self.tables("tztable", "notztable")
class TimePrecisionTest(fixtures.TestBase):
-
__only_on__ = "postgresql"
__backend__ = True
argnames="with_enum, using_aggregate_order_by",
)
def test_array_agg_specific(self, with_enum, using_aggregate_order_by):
-
element = ENUM(name="pgenum") if with_enum else Integer()
element_type = type(element)
expr = (
class ArrayRoundTripTest:
-
__only_on__ = "postgresql"
__backend__ = True
)
def test_tuple_flag(self, connection, metadata):
-
t1 = Table(
"t1",
metadata,
class CoreArrayRoundTripTest(
ArrayRoundTripTest, fixtures.TablesTest, AssertsExecutionResults
):
-
ARRAY = sqltypes.ARRAY
@testing.metadata_fixture()
def special_types_table(self, metadata):
-
# create these types so that we can issue
# special SQL92 INTERVAL syntax
class y2m(types.UserDefinedType, postgresql.INTERVAL):
)
def test_bind_serialize_default(self):
-
dialect = postgresql.dialect(use_native_hstore=False)
proc = self.test_table.c.hash.type._cached_bind_processor(dialect)
eq_(
class _Int4RangeTests:
-
_col_type = INT4RANGE
_col_str = "INT4RANGE"
_col_str_arr = "INT8RANGE"
class _Int8RangeTests:
-
_col_type = INT8RANGE
_col_str = "INT8RANGE"
class _NumRangeTests:
-
_col_type = NUMRANGE
_col_str = "NUMRANGE"
class _DateRangeTests:
-
_col_type = DATERANGE
_col_str = "DATERANGE"
class _DateTimeRangeTests:
-
_col_type = TSRANGE
_col_str = "TSRANGE"
class _DateTimeTZRangeTests:
-
_col_type = TSTZRANGE
_col_str = "TSTZRANGE"
class _Int4MultiRangeTests:
-
_col_type = INT4MULTIRANGE
_col_str = "INT4MULTIRANGE"
class _Int8MultiRangeTests:
-
_col_type = INT8MULTIRANGE
_col_str = "INT8MULTIRANGE"
class _NumMultiRangeTests:
-
_col_type = NUMMULTIRANGE
_col_str = "NUMMULTIRANGE"
class _DateMultiRangeTests:
-
_col_type = DATEMULTIRANGE
_col_str = "DATEMULTIRANGE"
class _DateTimeMultiRangeTests:
-
_col_type = TSMULTIRANGE
_col_str = "TSMULTIRANGE"
class _DateTimeTZMultiRangeTests:
-
_col_type = TSTZMULTIRANGE
_col_str = "TSTZMULTIRANGE"
return self.tables.data_table
def _fixture_data(self, connection):
-
data = [
{"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}},
{"name": "r2", "data": {"k1": "r2v1", "k2": "r2v2"}},
class TestTypes(fixtures.TestBase, AssertsExecutionResults):
-
__only_on__ = "sqlite"
__backend__ = True
)
def test_cant_parse_datetime_message(self, connection):
- for (typ, disp) in [
+ for typ, disp in [
(Time, "time"),
(DateTime, "datetime"),
(Date, "date"),
class JSONTest(fixtures.TestBase):
-
__requires__ = ("json_type",)
__only_on__ = "sqlite"
__backend__ = True
class DefaultsTest(fixtures.TestBase, AssertsCompiledSQL):
-
__only_on__ = "sqlite"
__backend__ = True
def test_default_reflection(self, connection, metadata):
-
specs = [
(String(3), '"foo"'),
(sqltypes.NUMERIC(10, 2), "100.50"),
"table_info()",
)
def test_default_reflection_2(self):
-
db = testing.db
m = MetaData()
expected = ["'my_default'", "0"]
class DialectTest(
fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL
):
-
__only_on__ = "sqlite"
__backend__ = True
eq_(len(c2.c), 2)
def test_crud(self, connection):
-
(ct,) = self.tables("test_schema.created")
connection.execute(ct.insert(), {"id": 1, "name": "foo"})
eq_(connection.execute(ct.select()).fetchall(), [(1, "foo")])
)
def test_column_defaults_ddl(self):
-
t = Table(
"t",
MetaData(),
class OnConflictDDLTest(fixtures.TestBase, AssertsCompiledSQL):
-
__dialect__ = sqlite.dialect()
def test_on_conflict_clause_column_not_null(self):
)
def test_on_conflict_clause_unique_constraint(self):
-
meta = MetaData()
t = Table(
"n",
)
def test_on_conflict_clause_primary_key(self):
-
meta = MetaData()
t = Table(
"n",
)
def test_on_conflict_clause_primary_key_constraint_from_column(self):
-
meta = MetaData()
t = Table(
"n",
)
def test_on_conflict_clause_check_constraint(self):
-
meta = MetaData()
t = Table(
"n",
)
def test_on_conflict_clause_check_constraint_from_column(self):
-
meta = MetaData()
t = Table(
"n",
)
def test_on_conflict_clause_primary_key_constraint(self):
-
meta = MetaData()
t = Table(
"n",
class MatchTest(fixtures.TestBase, AssertsCompiledSQL):
-
__only_on__ = "sqlite"
__skip_if__ = (full_text_search_missing,)
__backend__ = True
@classmethod
def setup_test_class(cls):
with testing.db.begin() as conn:
-
conn.exec_driver_sql("CREATE TABLE a1 (id INTEGER PRIMARY KEY)")
conn.exec_driver_sql("CREATE TABLE a2 (id INTEGER PRIMARY KEY)")
conn.exec_driver_sql(
@testing.fixture
def temp_table_fixture(self, connection):
-
connection.exec_driver_sql(
"CREATE TEMPORARY TABLE g "
"(x INTEGER, CONSTRAINT foo_gx UNIQUE(x))"
with mock.patch.object(
dialect, "_get_table_sql", _get_table_sql
):
-
fkeys = dialect.get_foreign_keys(None, "foo")
eq_(
fkeys,
def test_unique_constraint_named_broken_temp(
self, connection, temp_table_fixture
):
-
inspector = inspect(connection)
eq_(
inspector.get_unique_constraints("g"),
class TypeReflectionTest(fixtures.TestBase):
-
__only_on__ = "sqlite"
__backend__ = True
class OnConflictTest(AssertsCompiledSQL, fixtures.TablesTest):
-
__only_on__ = ("sqlite >= 3.24.0",)
__backend__ = True
stmt.on_conflict_do_nothing,
stmt.on_conflict_do_update,
):
-
with testing.expect_raises_message(
exc.InvalidRequestError,
"This Insert construct already has an "
eq_(res, ["sqlitetemptable"])
def test_get_temp_view_names(self, connection):
-
view = (
"CREATE TEMPORARY VIEW sqlitetempview AS "
"SELECT * FROM sqliteatable"
class DDLEventWCreateHarness(DDLEventHarness):
-
requires_table_to_exist = True
def test_straight_create_drop(
)[0]
if ddl_if_type in ("callable", "callable_w_state"):
-
if ddl_if_type == "callable":
check_state = None
else:
)
def test_ddl_hastable(self, plain_tables, connection):
-
map_ = {
None: config.test_schema,
"foo": config.test_schema,
)
def test_via_engine(self, plain_tables, metadata):
-
with config.db.begin() as connection:
metadata.create_all(connection)
stmt = str(select(1).compile(dialect=e1.dialect))
with e1.connect() as conn:
-
result = conn.exec_driver_sql(stmt)
eq_(result.scalar(), 1)
@testing.emits_warning("The garbage collector is trying to clean up")
def test_execute_events(self):
-
stmts = []
cursor_stmts = []
with patch.object(
engine.dialect, "is_disconnect", Mock(return_value=orig_error)
):
-
with engine.connect() as c:
try:
c.exec_driver_sql("SELECT x FROM nonexistent")
with patch.object(
engine.dialect, "is_disconnect", Mock(return_value=orig_error)
):
-
with engine.connect() as c:
target_crec = c.connection._connection_record
try:
assert_raises(MySpecialException, conn.get_isolation_level)
def test_handle_error_not_on_connection(self, connection):
-
with expect_raises_message(
tsa.exc.InvalidRequestError,
r"The handle_error\(\) event hook as of SQLAlchemy 2.0 is "
dbapi = MockDBAPI(
foober=12, lala=18, hoho={"this": "dict"}, fooz="somevalue"
)
- for (value, expected) in [
+ for value, expected in [
("rollback", pool.reset_rollback),
("commit", pool.reset_commit),
(None, pool.reset_none),
with mock.patch.object(
engine.dialect.loaded_dbapi, "connect", mock_connect
):
-
# set up initial connection. pre_ping works on subsequent connects
engine.connect().close()
def test_reconnect(self):
with self.engine.connect() as conn:
-
eq_(conn.execute(select(1)).scalar(), 1)
assert not conn.closed
@testing.crashes("oracle", "FIXME: unknown, confirm not fails_on")
@testing.requires.check_constraints
def test_reserved(self, connection, metadata):
-
# check a table that uses a SQL reserved name doesn't cause an
# error
@classmethod
def define_tables(cls, metadata):
-
no_multibyte_period = {("plain", "col_plain", "ix_plain")}
no_has_table = [
(
reflected = set(inspect(connection).get_table_names())
if not names.issubset(reflected) and hasattr(unicodedata, "normalize"):
-
# Python source files in the utf-8 coding seem to
# normalize literals as NFC (and the above are
# explicitly NFC). Maybe this database normalizes NFD
@testing.requires.cross_schema_fk_reflection
@testing.requires.implicit_default_schema
def test_blank_schema_arg(self, connection, metadata):
-
Table(
"some_table",
metadata,
@testing.requires.schemas
def test_explicit_default_schema(self, connection, metadata):
-
schema = connection.dialect.default_schema_name
assert bool(schema)
m = MetaData()
def column_reflect(insp, table, column_info):
-
if column_info["name"] == "q":
column_info["key"] = "qyz"
elif column_info["name"] == "x":
savepoint = [None]
def go():
-
with connection.begin_nested() as sp:
savepoint[0] = sp
# force the "commit" of the savepoint that occurs
@testing.requires.autocommit
def test_no_autocommit_w_begin(self):
-
with testing.db.begin() as conn:
assert_raises_message(
exc.InvalidRequestError,
@testing.requires.autocommit
def test_no_autocommit_w_autobegin(self):
-
with testing.db.connect() as conn:
conn.execute(select(1))
users = self.tables.users
with testing.db.connect() as conn:
-
assert not conn.in_transaction()
conn.execute(users.insert(), {"user_id": 1, "user_name": "name"})
def test_rollback_inactive(self):
users = self.tables.users
with testing.db.connect() as conn:
-
conn.execute(users.insert(), {"user_id": 1, "user_name": "name"})
conn.commit()
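A stand-alone sketch of the "commit as you go" pattern the test above relies on, assuming a SQLAlchemy 2.0-style engine where the first ``execute()`` autobegins a transaction::

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")  # illustrative 2.0-style engine

    with engine.connect() as conn:
        conn.execute(text("CREATE TABLE users (user_id INTEGER, user_name TEXT)"))
        conn.commit()

        # the first execute() autobegins a transaction; it stays open until
        # commit() or rollback() is called explicitly
        assert not conn.in_transaction()
        conn.execute(text("INSERT INTO users VALUES (1, 'name')"))
        assert conn.in_transaction()
        conn.commit()
        assert not conn.in_transaction()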
assert False, "no non-default isolation level available"
def test_engine_param_stays(self):
-
eng = testing_engine()
with eng.connect() as conn:
isolation_level = eng.dialect.get_isolation_level(
transactional = True
def reset_characteristic(self, dialect, dbapi_conn):
-
dialect.reset_foo(dbapi_conn)
def set_characteristic(self, dialect, dbapi_conn, value):
-
dialect.set_foo(dbapi_conn, value)
def get_characteristic(self, dialect, dbapi_conn):
return base.Engine(pool, FooDialect(), u), connection
def test_engine_param_stays(self, characteristic_fixture):
-
engine, connection = characteristic_fixture
foo_level = engine.dialect.get_foo(engine.connect().connection)
)
def test_per_engine(self, characteristic_fixture):
-
engine, connection = characteristic_fixture
pool, dialect, url = engine.pool, engine.dialect, engine.url
eq_(eng.dialect.get_foo(conn.connection), "new_value")
def test_per_option_engine(self, characteristic_fixture):
-
engine, connection = characteristic_fixture
eng = engine.execution_options(foo="new_value")
async def run_test(subject, trans_on_subject, execute_on_subject):
async with subject.begin() as trans:
-
if begin_nested:
if not config.requirements.savepoints.enabled:
config.skip_test("savepoints not enabled")
@async_test
async def test_connection_info(self, async_engine):
-
async with async_engine.connect() as conn:
conn.info["foo"] = "bar"
@async_test
async def test_connection_eq_ne(self, async_engine):
-
async with async_engine.connect() as conn:
c2 = _async_engine.AsyncConnection(
async_engine, conn.sync_connection
@async_test
async def test_transaction_eq_ne(self, async_engine):
-
async with async_engine.connect() as conn:
t1 = await conn.begin()
"do_rollback",
mock.Mock(side_effect=Exception("can't run rollback")),
), mock.patch("sqlalchemy.util.warn") as m:
-
_finalize_fairy(
None, rec, pool, ref, echo, transaction_was_reset=False
)
@async_test
async def test_get_raw_connection(self, async_connection):
-
pooled = await async_connection.get_raw_connection()
is_(pooled, async_connection.sync_connection.connection)
@async_test
async def test_connection_not_started(self, async_engine):
-
conn = async_engine.connect()
testing.assert_raises_message(
asyncio_exc.AsyncContextNotStarted,
users = self.tables.users
async with async_engine.begin() as conn:
-
savepoint = await conn.begin_nested()
await conn.execute(delete(users))
await savepoint.rollback()
users = self.tables.users
async with async_engine.begin() as conn:
-
savepoint = await conn.begin_nested()
await conn.execute(delete(users))
await savepoint.commit()
@async_test
async def test_conn_transaction_not_started(self, async_engine):
-
async with async_engine.connect() as conn:
trans = conn.begin()
with expect_raises_message(
async def test_get_transaction(self, async_engine):
async with async_engine.connect() as conn:
async with conn.begin() as trans:
-
is_(trans.connection, conn)
is_(conn.get_transaction(), trans)
)
def test_regenerate_connection(self, connection):
-
async_connection = AsyncConnection._retrieve_proxy_for_target(
connection
)
eq_(len(ReversibleProxy._proxy_objects), 0)
def test_regen_conn_but_not_engine(self, async_engine):
-
with async_engine.sync_engine.connect() as sync_conn:
-
async_conn = AsyncConnection._retrieve_proxy_for_target(sync_conn)
async_conn2 = AsyncConnection._retrieve_proxy_for_target(sync_conn)
@async_test
async def test_orm_sessionmaker_block_one(self, async_engine):
-
User = self.classes.User
maker = sessionmaker(async_engine, class_=AsyncSession)
@async_test
async def test_orm_sessionmaker_block_two(self, async_engine):
-
User = self.classes.User
maker = sessionmaker(async_engine, class_=AsyncSession)
@async_test
async def test_async_sessionmaker_block_one(self, async_engine):
-
User = self.classes.User
maker = async_sessionmaker(async_engine)
@async_test
async def test_async_sessionmaker_block_two(self, async_engine):
-
User = self.classes.User
maker = async_sessionmaker(async_engine)
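# Hedged sketch of the async_sessionmaker usage pattern appearing above;
# the engine URL and the User model are assumptions standing in for the
# test fixtures.
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

engine = create_async_engine("postgresql+asyncpg://scott:tiger@localhost/test")
maker = async_sessionmaker(engine, expire_on_commit=False)


async def add_user(name: str) -> None:
    async with maker() as session:
        async with session.begin():
            session.add(User(name=name))
        # session.begin() commits on exit; the session closes afterwards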
@async_test
async def test_trans(self, async_session, async_engine):
async with async_engine.connect() as outer_conn:
-
User = self.classes.User
async with async_session.begin():
-
eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
u1 = User(name="u1")
@async_test
async def test_commit_as_you_go(self, async_session, async_engine):
async with async_engine.connect() as outer_conn:
-
User = self.classes.User
eq_(await outer_conn.scalar(select(func.count(User.id))), 0)
@async_test
async def test_trans_noctx(self, async_session, async_engine):
async with async_engine.connect() as outer_conn:
-
User = self.classes.User
trans = await async_session.begin()
User = self.classes.User
async with async_engine.connect() as conn:
-
await conn.begin()
await conn.begin_nested()
async def test_new_style_active_history(
self, async_session, one_to_one_fixture, _legacy_inactive_history_style
):
-
A, B = await one_to_one_fixture(_legacy_inactive_history_style)
a1 = A()
@async_test
async def test_get_transaction(self, async_session):
-
is_(async_session.get_transaction(), None)
is_(async_session.get_nested_transaction(), None)
def test_inspect_session_no_asyncio_imported(self):
with mock.patch("sqlalchemy.orm.state._async_provider", None):
-
User = self.classes.User
s1 = Session(testing.db)
)
class Employee(Base, fixtures.ComparableEntity):
-
__table__ = punion
__mapper_args__ = {"polymorphic_on": punion.c.type}
class Engineer(Employee):
-
__table__ = engineers
__mapper_args__ = {
"polymorphic_identity": "engineer",
}
class Manager(Employee):
-
__table__ = managers
__mapper_args__ = {
"polymorphic_identity": "manager",
"""test the example from the declarative docs."""
class Employee(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
name = Column(String(50))
class Engineer(Employee):
-
__tablename__ = "engineers"
__mapper_args__ = {"concrete": True}
id = Column(
name = Column(String(50))
class Manager(Employee):
-
__tablename__ = "manager"
__mapper_args__ = {"concrete": True}
id = Column(
)
def _roundtrip(self):
-
User = Base.registry._class_registry["User"]
Address = Base.registry._class_registry["Address"]
)
def _roundtrip(self):
-
User = Base.registry._class_registry["User"]
Item = Base.registry._class_registry["Item"]
)
async with async_session() as session:
-
result = await session.execute(select(A).order_by(A.id))
r: ScalarResult[A] = result.scalars()
def do_something_with_mapped_class(
cls_: MappedClassProtocol[Employee],
) -> None:
-
# EXPECTED_TYPE: Select[Any]
reveal_type(cls_.__table__.select())
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: InstrumentedAttribute[datetime]
reveal_type(Engineer.start_date)
u1 = User()
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: str
reveal_type(User.__tablename__)
session.commit()
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: AppenderQuery[Address]
reveal_type(u.addresses)
count = u.addresses.count()
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: int
reveal_type(count)
address = u.addresses.filter(Address.email_address.like("xyz")).one()
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: Address
reveal_type(address)
current_addresses = list(u.addresses)
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: list[Address]
reveal_type(current_addresses)
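# Hedged sketch of the lazy="dynamic" style mapping implied by the
# AppenderQuery annotations above; Base and the class names are illustrative
# assumptions, not the actual test fixtures.
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DynamicMapped, Mapped, mapped_column, relationship


class User(Base):
    __tablename__ = "user_account"

    id: Mapped[int] = mapped_column(primary_key=True)
    # accessed as an AppenderQuery: supports .count(), .filter(), iteration
    addresses: DynamicMapped["Address"] = relationship(back_populates="user")


class Address(Base):
    __tablename__ = "address"

    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[int] = mapped_column(ForeignKey("user_account.id"))
    user: Mapped["User"] = relationship(back_populates="addresses")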
c1 = cols[0]
if typing.TYPE_CHECKING:
-
# EXPECTED_RE_TYPE: sqlalchemy.engine.base.Engine
reveal_type(e)
def regular() -> None:
-
e = create_engine("sqlite://")
# EXPECTED_TYPE: Engine
reveal_type(e)
with e.connect() as conn:
-
# EXPECTED_TYPE: Connection
reveal_type(conn)
reveal_type(result)
with e.begin() as conn:
-
# EXPECTED_TYPE: Connection
reveal_type(conn)
reveal_type(e)
async with e.connect() as conn:
-
# EXPECTED_TYPE: AsyncConnection
reveal_type(conn)
reveal_type(ctx_async_scalar_result)
async with e.begin() as conn:
-
# EXPECTED_TYPE: AsyncConnection
reveal_type(conn)
class FirstNameLastName(FirstNameOnly):
-
last_name: Mapped[str]
@FirstNameOnly.name.getter
if TYPE_CHECKING:
-
# EXPECTED_TYPE: StatementLambdaElement
reveal_type(s5)
# test #9125
for row in sess.query(User.id, User.name):
-
# EXPECTED_TYPE: Row[Tuple[int, str]]
reveal_type(row)
for uobj1 in sess.query(User):
-
# EXPECTED_TYPE: User
reveal_type(uobj1)
if typing.TYPE_CHECKING:
-
# whether this comes out as ColumnElement, BinaryExpression, or
# SQLCoreOperations may change; the main thing is that it's
# SomeSQLColThing[bool] and not 'bool' or 'Any'.
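# Hedged illustration of the point in the comment above: comparing a mapped
# attribute produces a SQL expression object typed in terms of bool, never a
# plain Python bool.  The User class here is an assumption for illustration.
import typing

from sqlalchemy import ColumnElement

expr = User.id > 9

if typing.TYPE_CHECKING:
    # some ColumnElement[bool]-like construct, not 'bool' and not 'Any'
    reveal_type(expr)

assert isinstance(expr, ColumnElement)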
def t_legacy_query_cols_2_with_entities() -> None:
-
q1 = session.query(User)
# EXPECTED_TYPE: Query[User]
def t_from_statement() -> None:
-
t = text("select * from user")
# EXPECTED_TYPE: TextClause
def t_entity_varieties() -> None:
-
a1 = aliased(User)
s1 = select(User.id, User, User.name).where(User.name == "foo")
def t_ambiguous_result_type_two() -> None:
-
stmt = select(column("q"))
# EXPECTED_TYPE: Select[Tuple[Any]]
def t_aliased() -> None:
-
a1 = aliased(User)
s1 = select(a1)
session.commit()
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: WriteOnlyCollection[Address]
reveal_type(u.addresses)
).one()
if typing.TYPE_CHECKING:
-
# EXPECTED_TYPE: Address
reveal_type(address)
id_="isaa",
)
def test_files(self, mypy_runner, filename, path, use_plugin):
-
expected_messages = []
expected_re = re.compile(r"\s*# EXPECTED(_MYPY)?(_RE)?(_TYPE)?: (.+)")
py_ver_re = re.compile(r"^#\s*PYTHON_VERSION\s?>=\s?(\d+\.\d+)")
{"e", "f", "g"},
set(),
):
-
eq_(p1.children.union(other), control.union(other))
eq_(p1.children.difference(other), control.difference(other))
eq_((p1.children - other), (control - other))
class LazyLoadTest(fixtures.MappedTest):
@classmethod
def define_tables(cls, metadata):
-
Table(
"Parent",
metadata,
cls.mapper_registry.map_imperatively(B, b)
def test_update_one_elem_dict(self):
-
a1 = self.classes.A()
a1.elements.update({("B", 3): "elem2"})
eq_(a1.elements, {("B", 3): "elem2"})
class ScalarRemoveListObjectCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = True
cascade_scalar_deletes = True
class ScalarRemoveScalarObjectCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = True
cascade_scalar_deletes = True
class ScalarRemoveListScalarCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = False
cascade_scalar_deletes = True
class ScalarRemoveScalarScalarCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = False
cascade_scalar_deletes = True
class ScalarRemoveListObjectNoCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = True
cascade_scalar_deletes = False
class ScalarRemoveScalarObjectNoCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = True
cascade_scalar_deletes = False
class ScalarRemoveListScalarNoCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = False
cascade_scalar_deletes = False
class ScalarRemoveScalarScalarNoCascade(
ScalarRemoveTest, fixtures.DeclarativeMappedTest
):
-
run_create_tables = None
useobject = False
cascade_scalar_deletes = False
@classmethod
def setup_classes(cls):
-
Base = cls.DeclarativeBasic
class A(Base):
def test_spoiled(self):
with self._fixture() as (sess, bq):
-
result = bq.spoil()(sess).with_post_criteria(
lambda q: q.execution_options(yes=True)
)
)
class MyClass:
-
# This proves that a staticmethod will work here; don't
# flatten this back to a class assignment!
def __sa_instrumentation_manager__(cls):
return self.dbs
def teardown_test(self):
-
testing_reaper.checkin_all()
for i in range(1, 5):
os.remove("shard%d_%s.db" % (i, provision.FOLLOWER_IDENT))
class IndexPropertyArrayTest(fixtures.DeclarativeMappedTest):
-
__requires__ = ("array_type",)
__backend__ = True
class IndexPropertyJsonTest(fixtures.DeclarativeMappedTest):
-
# TODO: remove reliance on "astext" for these tests
__requires__ = ("json_type",)
__only_on__ = "postgresql"
eq_(j.other, 42)
def test_modified(self):
-
Json = self.classes.Json
s = Session(testing.db)
class SerializeTest(AssertsCompiledSQL, fixtures.MappedTest):
-
run_setup_mappers = "once"
run_inserts = "once"
run_deletes = None
assert A.__mapper__.primary_key[1] is A.__table__.c.b
def test_mapper_pk_arg_degradation_no_col(self, decl_base):
-
with expect_raises_message(
exc.ArgumentError,
"Can't determine primary_key column 'q' - no attribute is "
@testing.variation("proptype", ["relationship", "colprop"])
def test_mapper_pk_arg_degradation_is_not_a_col(self, decl_base, proptype):
-
with expect_raises_message(
exc.ArgumentError,
"Can't determine primary_key column 'b'; property does "
Base = declarative_base()
class User(Base):
-
__tablename__ = "users"
__table_args__ = {"schema": "fooschema"}
)
class Prop(Base):
-
__tablename__ = "props"
__table_args__ = {"schema": "fooschema"}
@reg.mapped
class User:
-
__tablename__ = "users"
__table_args__ = {"schema": "fooschema"}
@reg.mapped
class Prop:
-
__tablename__ = "props"
__table_args__ = {"schema": "fooschema"}
def test_as_declarative(self, metadata):
class User(fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
addresses = relationship("Address", backref="user")
class Address(fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_map_declaratively(self, metadata):
class User(fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
addresses = relationship("Address", backref="user")
class Address(fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
id = Column(Integer, primary_key=True)
def test_non_sql_expression_warning_five(self):
-
# test for #9537
with assertions.expect_warnings(
r"Attribute 'x' on class <class .*Foo5.* appears to be a "
def test_reserved_identifiers(
self, decl_base, name, expect_raise, attrtype
):
-
if attrtype.column:
clsdict = {
"__tablename__": "user",
def test_string_dependency_resolution(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
class Foo(Base, fixtures.ComparableEntity):
-
__tablename__ = "foo"
id = Column(Integer, primary_key=True)
rel = relationship("User", primaryjoin="User.addresses==Foo.id")
def test_string_dependency_resolution_synonym(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
eq_(sess.query(User).filter(User.name == "ed").one(), User(name="ed"))
class Foo(Base, fixtures.ComparableEntity):
-
__tablename__ = "foo"
id = Column(Integer, primary_key=True)
_user_id = Column(Integer)
name = Column(String(50))
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
"""test that full tinkery expressions work as written"""
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(Integer, primary_key=True)
addresses = relationship(
)
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
def test_string_dependency_resolution_module_qualified(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(Integer, primary_key=True)
addresses = relationship(
)
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
user_id = Column(Integer, ForeignKey("users.id"))
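# Hedged sketch of the string-based resolution mechanism these tests cover:
# relationship() arguments given as strings are evaluated later against the
# declarative registry, so they may name classes defined further down in the
# module.  Base and the Parent/Child classes are illustrative assumptions.
class Parent(Base):
    __tablename__ = "parent"

    id = Column(Integer, primary_key=True)
    children = relationship(
        "Child",
        primaryjoin="Parent.id == Child.parent_id",
        back_populates="parent",
    )


class Child(Base):
    __tablename__ = "child"

    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("parent.id"))
    parent = relationship("Parent", back_populates="children")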
def test_string_dependency_resolution_in_backref(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
)
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
email = Column(String(50))
def test_string_dependency_resolution_tables(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
)
class Prop(Base, fixtures.ComparableEntity):
-
__tablename__ = "props"
id = Column(Integer, primary_key=True)
name = Column(String(50))
def test_string_dependency_resolution_table_over_class(self):
# test for second half of #5774
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
)
class Prop(Base, fixtures.ComparableEntity):
-
__tablename__ = "props"
id = Column(Integer, primary_key=True)
name = Column(String(50))
def test_string_dependency_resolution_class_over_table(self):
# test for second half of #5774
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(Integer, primary_key=True)
name = Column(String(50))
def test_uncompiled_attributes_in_relationship(self):
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
user_id = Column(Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
def test_nice_dependency_error(self):
class User(Base):
-
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
addresses = relationship("Address")
class Address(Base):
-
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
foo = sa.orm.column_property(User.id == 5)
def test_nice_dependency_error_works_with_hasattr(self):
class User(Base):
-
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
addresses = relationship("Address")
)
def test_uses_get_on_class_col_fk(self):
-
# test [ticket:1492]
class Topic(Base):
-
__tablename__ = "topic"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
class Detail(Base):
-
__tablename__ = "detail"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
self, require_metaclass, assert_user_address_mapping, _column
):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
User.addresses = relationship("Address", backref="user")
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = _column(Integer, primary_key=True)
@testing.combinations(Column, mapped_column, argnames="_column")
def test_add_prop_manual(self, assert_user_address_mapping, _column):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = _column("id", Integer, primary_key=True)
)
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = _column(Integer, primary_key=True)
def test_eager_order_by(self):
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
user_id = Column("user_id", Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_order_by_multi(self):
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
user_id = Column("user_id", Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
u.addresses
def test_oops(self):
-
with testing.expect_warnings(
"Ignoring declarative-like tuple value of " "attribute 'name'"
):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
name = (Column("name", String(50)),)
def test_table_args_no_dict(self):
class Foo1(Base):
-
__tablename__ = "foo"
__table_args__ = (ForeignKeyConstraint(["id"], ["foo.bar"]),)
id = Column("id", Integer, primary_key=True)
def test_table_args_type(self):
def err():
class Foo1(Base):
-
__tablename__ = "foo"
__table_args__ = ForeignKeyConstraint(["id"], ["foo.id"])
id = Column("id", Integer, primary_key=True)
def test_table_args_none(self):
class Foo2(Base):
-
__tablename__ = "foo"
__table_args__ = None
id = Column("id", Integer, primary_key=True)
def test_table_args_dict_format(self):
class Foo2(Base):
-
__tablename__ = "foo"
__table_args__ = {"mysql_engine": "InnoDB"}
id = Column("id", Integer, primary_key=True)
def test_table_args_tuple_format(self):
class Foo2(Base):
-
__tablename__ = "foo"
__table_args__ = {"mysql_engine": "InnoDB"}
id = Column("id", Integer, primary_key=True)
class Bar(Base):
-
__tablename__ = "bar"
__table_args__ = (
ForeignKeyConstraint(["id"], ["foo.id"]),
def test_expression(self, require_metaclass):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
addresses = relationship("Address", backref="user")
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_useless_declared_attr(self):
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
user_id = Column("user_id", Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_column(self, require_metaclass):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_column_properties(self):
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
user_id = Column(Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_column_properties_2(self):
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(Integer, primary_key=True)
email = Column(String(50))
user_id = Column(Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
def test_deferred(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
def test_synonym_inline(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
from sqlalchemy.orm.properties import ColumnProperty
class CustomCompare(ColumnProperty.Comparator):
-
__hash__ = None
def __eq__(self, other):
return self.__clause_element__() == other + " FOO"
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_synonym_added(self, require_metaclass):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_reentrant_compile_via_foreignkey(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
addresses = relationship("Address", backref="user")
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_relationship_reference(self, require_metaclass):
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
user_id = Column("user_id", Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
def test_pk_with_fk_init(self):
class Bar(Base):
-
__tablename__ = "bar"
id = sa.Column(
sa.Integer, sa.ForeignKey("foo.id"), primary_key=True
ex = sa.Column(sa.Integer, primary_key=True)
class Foo(Base):
-
__tablename__ = "foo"
id = sa.Column(sa.Integer, primary_key=True)
bars = sa.orm.relationship(Bar)
meta.create_all(testing.db)
class MyObj(Base):
-
__table__ = Table("t1", Base.metadata, autoload_with=testing.db)
sess = fixture_session()
def test_synonym_for(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
__requires__ = ("predictable_gc",)
def test_same_module_same_name(self):
-
base = registry()
f1 = MockClass(base, "foo.bar.Foo")
f2 = MockClass(base, "foo.bar.Foo")
argnames="name",
)
def test_name_resolution_failures(self, name, registry):
-
Base = registry.generate_base()
f1 = MockClass(registry, "existent.Foo")
def test_warn_on_non_dc_mixin(self):
class _BaseMixin:
-
create_user: Mapped[int] = mapped_column()
update_user: Mapped[Optional[int]] = mapped_column(
default=None, init=False
b: Mapped[int] = mapped_column(default=1)
class Child(Mixin, dc_decl_base):
-
__tablename__ = "child"
_: dataclasses.KW_ONLY
)
if dataclass_scope.on_base_class:
-
with non_dc_mixin():
class Book(Mixin, MappedAsDataclass, Base, **klass_kw):
expected_annotations[Book] = {"id": int, "polymorphic_type": str}
if dataclass_scope.on_sub_class:
-
with non_dc_mixin():
class Novel(MappedAsDataclass, Book, **klass_kw):
description: Mapped[Optional[str]]
else:
-
with non_dc_mixin():
class Novel(Book):
@testing.emits_warning(r".*does not indicate a 'polymorphic_identity'")
def test_we_must_copy_mapper_args(self):
class Person(Base):
-
__tablename__ = "people"
id = Column(Integer, primary_key=True)
discriminator = Column("type", String(50))
}
class Engineer(Person):
-
primary_language = Column(String(50))
assert "inherits" not in Person.__mapper_args__
def test_we_must_only_copy_column_mapper_args(self):
class Person(Base):
-
__tablename__ = "people"
id = Column(Integer, primary_key=True)
a = Column(Integer)
def test_custom_join_condition(self):
class Foo(Base):
-
__tablename__ = "foo"
id = Column("id", Integer, primary_key=True)
class Bar(Foo):
-
__tablename__ = "bar"
bar_id = Column("id", Integer, primary_key=True)
foo_id = Column("foo_id", Integer)
def test_joined(self):
class Company(Base, fixtures.ComparableEntity):
-
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
employees = relationship("Person")
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__tablename__ = "engineers"
__mapper_args__ = {"polymorphic_identity": "engineer"}
id = Column(
primary_language = Column("primary_language", String(50))
class Manager(Person):
-
__tablename__ = "managers"
__mapper_args__ = {"polymorphic_identity": "manager"}
id = Column(
def test_add_subcol_after_the_fact(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__tablename__ = "engineers"
__mapper_args__ = {"polymorphic_identity": "engineer"}
id = Column(
def test_add_parentcol_after_the_fact(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__tablename__ = "engineers"
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column(String(50))
def test_add_sub_parentcol_after_the_fact(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__tablename__ = "engineers"
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column(String(50))
)
class Admin(Engineer):
-
__tablename__ = "admins"
__mapper_args__ = {"polymorphic_identity": "admin"}
workstation = Column(String(50))
def test_subclass_mixin(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
__mapper_args__ = {"polymorphic_on": discriminator}
class MyMixin:
-
pass
class Engineer(MyMixin, Person):
-
__tablename__ = "engineers"
__mapper_args__ = {"polymorphic_identity": "engineer"}
id = Column(
def test_with_undefined_foreignkey(self):
class Parent(Base):
-
__tablename__ = "parent"
id = Column("id", Integer, primary_key=True)
tp = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=tp)
class Child1(Parent):
-
__tablename__ = "child1"
id = Column(
"id", Integer, ForeignKey("parent.id"), primary_key=True
# though child2 doesn't exist yet
class Child2(Parent):
-
__tablename__ = "child2"
id = Column(
"id", Integer, ForeignKey("parent.id"), primary_key=True
class."""
class Company(Base, fixtures.ComparableEntity):
-
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
employees = relationship("Person")
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
class Manager(Person):
-
__mapper_args__ = {"polymorphic_identity": "manager"}
Base.metadata.create_all(testing.db)
"""
class Company(Base, fixtures.ComparableEntity):
-
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
employees = relationship("Person")
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column(String(50))
class Manager(Person):
-
__mapper_args__ = {"polymorphic_identity": "manager"}
golf_swing = Column(String(50))
"""test the somewhat unusual case of [ticket:3341]"""
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column(String(50))
def test_joined_from_single(self):
class Company(Base, fixtures.ComparableEntity):
-
__tablename__ = "companies"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
employees = relationship("Person")
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Manager(Person):
-
__mapper_args__ = {"polymorphic_identity": "manager"}
golf_swing = Column(String(50))
class Engineer(Person):
-
__tablename__ = "engineers"
__mapper_args__ = {"polymorphic_identity": "engineer"}
id = Column(Integer, ForeignKey("people.id"), primary_key=True)
def test_single_from_joined_colsonsub(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
def test_add_deferred(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
"id", Integer, primary_key=True, test_needs_autoincrement=True
"""
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language_id = Column(Integer, ForeignKey("languages.id"))
primary_language = relationship("Language")
class Language(Base, fixtures.ComparableEntity):
-
__tablename__ = "languages"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
def test_single_three_levels(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column(Integer, primary_key=True)
name = Column(String(50))
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column(String(50))
class JuniorEngineer(Engineer):
-
__mapper_args__ = {"polymorphic_identity": "junior_engineer"}
nerf_gun = Column(String(50))
class Manager(Person):
-
__mapper_args__ = {"polymorphic_identity": "manager"}
golf_swing = Column(String(50))
def test_single_detects_conflict(self):
class Person(Base):
-
__tablename__ = "people"
id = Column(Integer, primary_key=True)
name = Column(String(50))
__mapper_args__ = {"polymorphic_on": discriminator}
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column(String(50))
def go():
class Manager(Person):
-
__mapper_args__ = {"polymorphic_identity": "manager"}
golf_swing = Column(String(50))
primary_language = Column(String(50))
def go():
class Salesman(Person):
-
__mapper_args__ = {"polymorphic_identity": "manager"}
name = Column(String(50))
def test_single_no_special_cols(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
def go():
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column("primary_language", String(50))
foo_bar = Column(Integer, primary_key=True)
def test_single_no_table_args(self):
class Person(Base, fixtures.ComparableEntity):
-
__tablename__ = "people"
id = Column("id", Integer, primary_key=True)
name = Column("name", String(50))
def go():
class Engineer(Person):
-
__mapper_args__ = {"polymorphic_identity": "engineer"}
primary_language = Column("primary_language", String(50))
def test_simple_wbase(self):
class MyMixin:
-
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
return "bar" + str(self.id)
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
name = Column(String(100), nullable=False, index=True)
def test_simple_wdecorator(self):
class MyMixin:
-
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
@mapper_registry.mapped
class MyModel(MyMixin):
-
__tablename__ = "test"
name = Column(String(100), nullable=False, index=True)
def test_declarative_mixin_decorator(self):
@declarative_mixin
class MyMixin:
-
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
@mapper_registry.mapped
class MyModel(MyMixin):
-
__tablename__ = "test"
name = Column(String(100), nullable=False, index=True)
@testing.combinations(Column, mapped_column, argnames="_column")
def test_unique_column(self, _column):
class MyMixin:
-
id = _column(Integer, primary_key=True)
value = _column(String, unique=True)
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
assert MyModel.__table__.c.value.unique
@testing.combinations(Column, mapped_column, argnames="_column")
def test_hierarchical_bases_wbase(self, _column):
class MyMixinParent:
-
id = _column(
Integer, primary_key=True, test_needs_autoincrement=True
)
return "bar" + str(self.id)
class MyMixin(MyMixinParent):
-
baz = _column(String(100), nullable=False, index=True)
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
name = _column(String(100), nullable=False, index=True)
@testing.combinations(Column, mapped_column, argnames="_column")
def test_hierarchical_bases_wdecorator(self, _column):
class MyMixinParent:
-
id = _column(
Integer, primary_key=True, test_needs_autoincrement=True
)
return "bar" + str(self.id)
class MyMixin(MyMixinParent):
-
baz = _column(String(100), nullable=False, index=True)
@mapper_registry.mapped
class MyModel(MyMixin):
-
__tablename__ = "test"
name = Column(String(100), nullable=False, index=True)
def go():
class MyModel(Base, MyRelMixin):
-
__tablename__ = "foo"
assert_raises(sa.exc.InvalidRequestError, go)
eq_(class_mapper(Engineer).polymorphic_identity, "Engineer")
def test_mapper_args_declared_attr_two(self):
-
# same as test_mapper_args_declared_attr, but we repeat
# ComputedMapperArgs on both classes for no apparent reason.
return {"polymorphic_identity": cls.__name__}
class Person(Base, ComputedMapperArgs):
-
__tablename__ = "people"
id = Column(Integer, primary_key=True)
discriminator = Column("type", String(50))
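# Hedged reconstruction of the mixin idea the fragment above exercises: a
# __mapper_args__ computed per subclass via declared_attr, so each mapped
# class receives its own polymorphic_identity.  Names are illustrative.
from sqlalchemy.orm import declared_attr


class ComputedMapperArgs:
    @declared_attr
    def __mapper_args__(cls):
        return {"polymorphic_identity": cls.__name__}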
def test_table_args_composite(self):
class MyMixin1:
-
__table_args__ = {"info": {"baz": "bob"}}
class MyMixin2:
-
__table_args__ = {"info": {"foo": "bar"}}
class MyModel(Base, MyMixin1, MyMixin2):
-
__tablename__ = "test"
@declared_attr
def test_mapper_args_inherited(self):
class MyMixin:
-
__mapper_args__ = {"always_refresh": True}
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
id = Column(Integer, primary_key=True)
class MyMixin:
@declared_attr
def __mapper_args__(cls):
-
# tenuous, but illustrates the problem!
if cls.__name__ == "MyModel":
return dict(always_refresh=False)
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
id = Column(Integer, primary_key=True)
def test_mapper_args_polymorphic_on_inherited(self):
class MyMixin:
-
type_ = Column(String(50))
__mapper_args__ = {"polymorphic_on": type_}
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
id = Column(Integer, primary_key=True)
def test_mapper_args_overridden(self):
class MyMixin:
-
__mapper_args__ = dict(always_refresh=True)
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
__mapper_args__ = dict(always_refresh=False)
id = Column(Integer, primary_key=True)
def test_mapper_args_composite(self):
class MyMixin1:
-
type_ = Column(String(50))
__mapper_args__ = {"polymorphic_on": type_}
class MyMixin2:
-
__mapper_args__ = {"always_refresh": True}
class MyModel(Base, MyMixin1, MyMixin2):
-
__tablename__ = "test"
@declared_attr
def test_single_table_no_propagation(self):
class IdColumn:
-
id = Column(Integer, primary_key=True)
class Generic(Base, IdColumn):
-
__tablename__ = "base"
discriminator = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
value = Column(Integer())
class Specific(Generic):
-
__mapper_args__ = dict(polymorphic_identity="specific")
assert Specific.__table__ is Generic.__table__
id = Column(Integer, primary_key=True)
class Generic(Base, CommonMixin):
-
discriminator = Column("python_type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
class Specific(Generic):
-
__mapper_args__ = dict(polymorphic_identity="specific")
id = Column(Integer, ForeignKey("generic.id"), primary_key=True)
timestamp = Column(Integer)
class BaseType(Base, CommonMixin):
-
discriminator = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
id = Column(Integer, primary_key=True)
value = Column(Integer())
class Single(BaseType):
-
__tablename__ = None
__mapper_args__ = dict(polymorphic_identity="type1")
class Joined(BaseType):
-
__mapper_args__ = dict(polymorphic_identity="type2")
id = Column(Integer, ForeignKey("basetype.id"), primary_key=True)
return cls.__name__.lower()
class BaseType(Base, NoJoinedTableNameMixin):
-
discriminator = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
id = Column(Integer, primary_key=True)
value = Column(Integer())
class Specific(BaseType):
-
__mapper_args__ = dict(polymorphic_identity="specific")
eq_(BaseType.__table__.name, "basetype")
return cls.__name__.lower()
class BaseType(Base, TableNameMixin):
-
discriminator = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
id = Column(Integer, primary_key=True)
value = Column(Integer())
class Specific(BaseType, TableNameMixin):
-
__mapper_args__ = dict(polymorphic_identity="specific")
id = Column(Integer, ForeignKey("basetype.id"), primary_key=True)
def test_single_back_propagate(self):
class ColumnMixin:
-
timestamp = Column(Integer)
class BaseType(Base):
-
__tablename__ = "foo"
discriminator = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
id = Column(Integer, primary_key=True)
class Specific(BaseType, ColumnMixin):
-
__mapper_args__ = dict(polymorphic_identity="specific")
eq_(list(BaseType.__table__.c.keys()), ["type", "id", "timestamp"])
def test_table_in_model_and_same_column_in_mixin(self):
class ColumnMixin:
-
data = Column(Integer)
class Model(Base, ColumnMixin):
-
__table__ = Table(
"foo",
Base.metadata,
def go():
class Model(Base, ColumnMixin):
-
__table__ = Table(
"foo",
Base.metadata,
)
def test_table_in_model_and_different_named_alt_key_column_in_mixin(self):
-
# here, the __table__ has a column 'tada'. We disallow
# the add of the 'foobar' column, even though it's
# keyed to 'tada'.
def go():
class Model(Base, ColumnMixin):
-
__table__ = Table(
"foo",
Base.metadata,
def test_table_in_model_overrides_different_typed_column_in_mixin(self):
class ColumnMixin:
-
data = Column(String)
class Model(Base, ColumnMixin):
-
__table__ = Table(
"foo",
Base.metadata,
def test_mixin_column_ordering(self):
class Foo:
-
col1 = Column(Integer)
col3 = Column(Integer)
class Bar:
-
col2 = Column(Integer)
col4 = Column(Integer)
class Model(Base, Foo, Bar):
-
id = Column(Integer, primary_key=True)
__tablename__ = "model"
return column_property(Column("prop", String(50)))
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
class MyOtherModel(Base, MyMixin):
-
__tablename__ = "othertest"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
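# Hedged sketch of the shared-mixin pattern above: a declared_attr-wrapped
# column_property produces a brand new Column for every class that uses the
# mixin, which is why MyModel and MyOtherModel can both consume it.
from sqlalchemy.orm import column_property, declared_attr


class HasProp:
    @declared_attr
    def prop(cls):
        return column_property(Column("prop", String(50)))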
argnames="clstype",
)
def test_column_property_col_ref(self, decl_base, clstype):
-
if clstype == "anno":
class SomethingMixin:
return column_property(Column(String(50)))
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
id = Column(Integer, primary_key=True)
__mapper_args__ = {"polymorphic_on": type_}
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
id = Column(Integer, primary_key=True)
def test_column_in_mapper_args_used_multiple_times(self):
class MyMixin:
-
version_id = Column(Integer)
__mapper_args__ = {"version_id_col": version_id}
class ModelOne(Base, MyMixin):
-
__tablename__ = "m1"
id = Column(Integer, primary_key=True)
class ModelTwo(Base, MyMixin):
-
__tablename__ = "m2"
id = Column(Integer, primary_key=True)
return deferred(Column("data", String(50)))
class MyModel(Base, MyMixin):
-
__tablename__ = "test"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
return relationship("Target")
class Foo(Base, RefTargetMixin):
-
__tablename__ = "foo"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
class Bar(Base, RefTargetMixin):
-
__tablename__ = "bar"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
)
class Target(Base):
-
__tablename__ = "target"
id = Column(
Integer, primary_key=True, test_needs_autoincrement=True
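# Hedged sketch of the relationship-on-mixin pattern shown above: both the
# ForeignKey column and the relationship() are wrapped in declared_attr so
# that each consuming class (Foo, Bar) gets its own copies; the target_id
# column here is an illustrative addition, not the test's exact fixture.
from sqlalchemy.orm import declared_attr, relationship


class RefTargetMixin:
    @declared_attr
    def target_id(cls):
        return Column("target_id", ForeignKey("target.id"))

    @declared_attr
    def target(cls):
        return relationship("Target")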
def test_basic(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
__autoload_with__ = testing.db
addresses = relationship("Address", backref="user")
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
__autoload_with__ = testing.db
def test_rekey_wbase(self):
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
__autoload_with__ = testing.db
nom = Column("name", String(50), key="nom")
addresses = relationship("Address", backref="user")
class Address(Base, fixtures.ComparableEntity):
-
__tablename__ = "addresses"
__autoload_with__ = testing.db
def test_rekey_wdecorator(self):
@registry.mapped
class User(fixtures.ComparableMixin):
-
__tablename__ = "users"
__autoload_with__ = testing.db
nom = Column("name", String(50), key="nom")
@registry.mapped
class Address(fixtures.ComparableMixin):
-
__tablename__ = "addresses"
__autoload_with__ = testing.db
def test_supplied_fk(self):
class IMHandle(Base, fixtures.ComparableEntity):
-
__tablename__ = "imhandles"
__autoload_with__ = testing.db
user_id = Column("user_id", Integer, ForeignKey("users.id"))
class User(Base, fixtures.ComparableEntity):
-
__tablename__ = "users"
__autoload_with__ = testing.db
handles = relationship("IMHandle", backref="user")
(BIGINT().with_variant(String(), "some_other_dialect")),
)
def test_type_map_varieties(self, typ):
-
Base = declarative_base(type_annotation_map={int: typ})
class MyClass(Base):
data: Mapped[MyClass] = mapped_column()
def test_construct_lhs_sqlalchemy_type(self, decl_base):
-
with expect_raises_message(
sa_exc.ArgumentError,
"The type provided inside the 'data' attribute Mapped "
def test_construct_nullability_overrides(
self, decl_base, include_rhs_type, use_mixin
):
-
if include_rhs_type:
args = (String,)
else:
def test_pep484_newtypes_as_typemap_keys(
self, decl_base: Type[DeclarativeBase]
):
-
global str50, str30, str3050
str50 = NewType("str50", str)
r"state the generic argument using an annotation, e.g. "
r'"bs: Mapped\[List\[\'B\'\]\] = relationship\(\)"',
):
-
decl_base.registry.configure()
def test_required_no_arg(self, decl_base):
is_(a1, b1.a)
def test_wrong_annotation_type_one(self, decl_base):
-
with expect_annotation_syntax_error("A.data"):
class A(decl_base):
data: "B" = relationship() # type: ignore # noqa
def test_wrong_annotation_type_two(self, decl_base):
-
with expect_annotation_syntax_error("A.data"):
class B(decl_base):
data: B = relationship() # type: ignore # noqa
def test_wrong_annotation_type_three(self, decl_base):
-
with expect_annotation_syntax_error("A.data"):
class B(decl_base):
@testing.variation("anno_type", ["plain", "typemap", "annotated"])
@testing.variation("inh_type", ["single", "joined"])
def test_mixin_interp_on_inh(self, decl_base, inh_type, anno_type):
-
global anno_col
if anno_type.typemap:
session = fixture_session()
with self.sql_execution_asserter() as asserter:
-
session.bulk_save_objects([User(name="A"), User(name="B")])
session.add(User(name="C"))
data = data[0]
with self.sql_execution_asserter() as asserter:
-
# tests both caching and that the data dictionaries aren't
# mutated...
class BulkDMLReturningJoinedInhTest(
BulkDMLReturningInhTest, fixtures.DeclarativeMappedTest
):
-
__requires__ = ("insert_returning", "insert_executemany_returning")
__backend__ = True
sess = fixture_session(bind=self.bind)
with self.sql_execution_asserter() as asserter:
-
if not expect_entity:
row = sess.execute(outer_stmt).one()
eq_(row, (id_, "some user"))
execution_options=dict(synchronize_session="evaluate"),
)
elif update_type == "bulk":
-
data = [
{"id": john.id, "age": 25},
{"id": jack.id, "age": 37},
@event.listens_for(session, "after_bulk_update")
def do_orm_execute(bulk_ud):
-
cols = [
c.key
for c, v in (
class InheritTest(fixtures.DeclarativeMappedTest):
-
run_inserts = "each"
run_deletes = "each"
e2 = s.query(Engineer).filter_by(name="e2").first()
with self.sql_execution_asserter() as asserter:
-
assert e2 in s
q = (
@classmethod
def insert_data(cls, connection):
-
cls.e1 = e1 = Engineer(
name="dilbert",
engineer_name="dilbert",
for parent in ["a", "b", "c"]:
for child in ["a", "b", "c"]:
for direction in [ONETOMANY, MANYTOONE]:
-
name = "Test%sTo%s%s" % (
parent,
child,
Base = cls.DeclarativeBasic
class Content(Base):
-
__tablename__ = "content"
id = Column(Integer, primary_key=True)
__mapper_args__ = {"polymorphic_on": type}
class Folder(Content):
-
__tablename__ = "folder"
id = Column(ForeignKey("content.id"), primary_key=True)
"Status", "Person", "Engineer", "Manager", "Car"
)
with sessionmaker(connection).begin() as session:
-
active = Status(name="active")
dead = Status(name="dead")
__dialect__ = "default"
def _fixture(self, use_correlate_except):
-
Base = self.DeclarativeBasic
class Superclass(Base):
)
with Session(connection) as sess:
-
grandparent_otherrel1 = OtherRelated(name="GP1")
grandparent_otherrel2 = OtherRelated(name="GP2")
"sqlalchemy.engine.result.ResultMetaData._key_fallback",
_key_fallback,
):
-
eq_(s1.sub, "s1sub")
def test_optimized_get_blank_intermediary(self, registry, connection):
class DiscriminatorOrPkNoneTest(fixtures.DeclarativeMappedTest):
-
run_setup_mappers = "once"
__dialect__ = "default"
"ASingleSubA", "ASingleSubB", "AJoinedSubA", "AJoinedSubB"
)
with Session(connection) as s:
-
s.add_all(
[ASingleSubA(), ASingleSubB(), AJoinedSubA(), AJoinedSubB()]
)
@classmethod
def make_statement(cls, *filter_cond, include_metadata=False):
-
a_stmt = (
select(
A.id,
registry.metadata.create_all(connection)
with Session(connection) as sess:
-
sess.add_all(
[
A(thing1="thing1_1"),
)
def test_primary_table_only_for_requery(self):
-
session = fixture_session()
if self.redefine_colprop:
sess = fixture_session()
def go():
-
wp = with_polymorphic(Person, "*", selectable=None)
eq_(
sess.query(wp).order_by(wp.person_id).all(),
)
def test_self_referential_two_point_five_future(self):
-
# TODO: this is the first test *EVER* of an aliased class of
# an aliased class. we should add many more tests for this.
# new case added in Id810f485c5f7ed971529489b84694e02a3356d6d
sess = fixture_session()
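# Hedged sketch of the "aliased class of an aliased class" case the comment
# above refers to; Person stands in for the test's mapped class.
from sqlalchemy import select
from sqlalchemy.orm import aliased

p1 = aliased(Person)   # first alias
p2 = aliased(p1)       # an alias built from the alias
stmt = select(p2).where(p2.name == "someone")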
def go():
-
wp = with_polymorphic(Person, "*")
eq_(
sess.query(wp).order_by(wp.person_id).all(),
)
def test_correlation_w_polymorphic(self):
-
sess = fixture_session()
p_poly = with_polymorphic(Person, "*")
)
def test_correlation_w_polymorphic_flat(self):
-
sess = fixture_session()
p_poly = with_polymorphic(Person, "*", flat=True)
sess = fixture_session()
def go():
-
wp = with_polymorphic(Person, "*")
eq_(
sess.query(wp).order_by(wp.person_id).all(),
sess = fixture_session()
def go():
-
wp = with_polymorphic(Person, "*")
eq_(
sess.query(wp).order_by(wp.person_id).all(),
sess = fixture_session()
def go():
-
wp = with_polymorphic(Person, "*")
eq_(
sess.query(wp).order_by(wp.person_id).all(),
sess = fixture_session()
def go():
-
wp = with_polymorphic(Person, "*")
eq_(
sess.query(wp).order_by(wp.person_id).all(),
class SelfReferentialTestJoinedToBase(fixtures.MappedTest):
-
run_setup_mappers = "once"
@classmethod
class SelfReferentialJ2JTest(fixtures.MappedTest):
-
run_setup_mappers = "once"
@classmethod
class SelfReferentialJ2JSelfTest(fixtures.MappedTest):
-
run_setup_mappers = "once"
@classmethod
class M2MFilterTest(fixtures.MappedTest):
-
run_setup_mappers = "once"
run_inserts = "once"
run_deletes = None
class SameNameOnJoined(fixtures.MappedTest):
-
run_setup_mappers = "once"
run_inserts = None
run_deletes = None
)
def test_having(self):
-
Engineer, Manager = self.classes("Engineer", "Manager")
sess = fixture_session()
# so test it both ways even though when things are "working", there's
# no problem
if ensure_no_warning:
-
a = results.first()
else:
with expect_warnings(
@classmethod
def define_tables(cls, metadata):
-
Table(
"owners",
metadata,
self.mapper_registry.map_imperatively(User, users)
with testing.db.connect() as c:
-
sess = Session(bind=c)
u = User(name="u1")
)
def test_selects_w_orm_joins(self):
-
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
)
)
def test_orm_query_w_orm_joins(self):
-
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
)
)
def test_orm_query_basic(self):
-
User, Address, Keyword, Order, Item = self.classes(
"User", "Address", "Keyword", "Order", "Item"
)
return User, Address
def test_subqueryload(self, plain_fixture):
-
# subqueryload works pretty poorly w/ caching because it has
# to create a new query. previously, baked query went through a
# bunch of hoops to improve upon this and they were found to be
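# Hedged aside on the caching point above: selectinload() is generally the
# more cache-friendly eager loader, since it emits a simple additional
# SELECT ... WHERE ... IN (...) instead of restating the original query.
# User/Address and the session are assumptions for illustration.
from sqlalchemy import select
from sqlalchemy.orm import selectinload

stmt = select(User).options(selectinload(User.addresses))
users = session.execute(stmt).scalars().all()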
user_table = inspect(User).persist_selectable
def go():
-
my_thing = case((User.id > 9, 1), else_=2)
# include entities in the statement so that we test that
User, Order = self.classes.User, self.classes.Order
with fixture_session() as sess:
-
u = User(name="jack")
sess.add(u)
sess.commit()
User, Order = self.classes.User, self.classes.Order
with fixture_session() as sess:
-
u = User(name="jack")
o1 = Order()
o2m=False,
m2o=False,
):
-
Address, addresses, users, User = (
self.classes.Address,
self.tables.addresses,
fwd=False,
bkd=False,
):
-
keywords, items, item_keywords, Keyword, Item = (
self.tables.keywords,
self.tables.items,
assert_eq()
if hasattr(direct, "__setitem__") or hasattr(direct, "__setslice__"):
-
values = [creator(), creator()]
direct[:] = values
control[:] = values
self.assert_(e7 not in canary.removed)
def _test_list_dataclasses(self, typecallable):
-
creator = self.SimpleComparableEntity
@dataclasses.dataclass
self._test_list_dataclasses(list)
def test_list_setitem_with_slices(self):
-
# this is a "list" that has no __setslice__
# or __delslice__ methods. The __setitem__
# and __delitem__ must therefore accept slice objects.
self.assert_(e3 in canary.data)
def _test_set_dataclasses(self, typecallable):
-
creator = self.SimpleComparableEntity
@dataclasses.dataclass
)
def _test_dict_dataclasses(self, typecallable):
-
creator = self.SimpleComparableEntity
@dataclasses.dataclass
stmt = resolve_lambda(test_case, User=User, user_table=user_table)
with Session(testing.db) as s:
-
with mock.patch.object(s, "_autoflush", wrap=True) as before_flush:
r = s.execute(stmt)
r.close()
class ExplicitWithPolymorhpicTest(
_poly_fixtures._PolymorphicUnions, AssertsCompiledSQL
):
-
__dialect__ = "default"
default_punion = (
)
def test_select_where_subclass(self):
-
Engineer = self.classes.Engineer
# what will *not* work with Core, that the ORM does for now,
)
def test_select_where_columns_subclass(self):
-
Engineer = self.classes.Engineer
# what will *not* work with Core, that the ORM does for now,
class RelNaturalAliasedJoinsTest(
_poly_fixtures._PolymorphicAliasedJoins, RelationshipNaturalInheritedTest
):
-
# this is the label style for the polymorphic selectable, not the
# outside query
label_style = LABEL_STYLE_TABLENAME_PLUS_COL
@declarative
@dataclasses.dataclass
class SpecialWidget(Widget):
-
magic: bool = False
__mapper_args__ = dict(
@dataclasses.dataclass
class SurrogateWidgetPK:
-
__sa_dataclass_metadata_key__ = "sa"
widget_id: int = dataclasses.field(
@dataclasses.dataclass
class SurrogateAccountPK:
-
__sa_dataclass_metadata_key__ = "sa"
account_id = Column(
@dataclasses.dataclass
class WidgetDC:
-
__sa_dataclass_metadata_key__ = "sa"
widget_id: int = dataclasses.field(
@dataclasses.dataclass
class AccountDC:
-
__sa_dataclass_metadata_key__ = "sa"
# relationship on mixin
@dataclasses.dataclass
class WidgetDC:
-
__sa_dataclass_metadata_key__ = "sa"
widget_id: int = dataclasses.field(
@dataclasses.dataclass
class AccountDC:
-
__sa_dataclass_metadata_key__ = "sa"
# relationship on mixin
@declarative
@dataclasses.dataclass
class BaseType(CommonMixin):
-
discriminator = Column("type", String(50))
__mapper_args__ = dict(polymorphic_on=discriminator)
id = Column(Integer, primary_key=True)
@declarative
@dataclasses.dataclass
class Single(BaseType):
-
__tablename__ = None
__mapper_args__ = dict(polymorphic_identity="type1")
@declarative
@dataclasses.dataclass
class Joined(BaseType):
-
__mapper_args__ = dict(polymorphic_identity="type2")
id = Column(
Integer, ForeignKey("basetype.id"), primary_key=True
@declarative
@dataclasses.dataclass
class Single(BaseType):
-
__tablename__ = "single"
__mapper_args__ = dict(polymorphic_identity="type1")
run_deletes = None
def test_o2m_noload(self):
-
Address, addresses, users, User = (
self.classes.Address,
self.tables.addresses,
session.commit()
def test_data_loaded(self):
-
User, Task = self.classes("User", "Task")
session = fixture_session()
self._run_double_test(1)
def test_selectin(self):
-
users, orders, User, Address, Order, addresses = (
self.tables.users,
self.tables.orders,
self._run_double_test(4)
def test_subqueryload(self):
-
users, orders, User, Address, Order, addresses = (
self.tables.users,
self.tables.orders,
_fixtures.FixtureTest,
testing.AssertsExecutionResults,
):
-
run_inserts = None
@testing.combinations(
if sess:
sess.autoflush = False
try:
-
if self.lazy == "write_only" and compare_passive is not None:
eq_(
attributes.get_history(
cls._setup_stock_mapping()
def _caching_session_fixture(self):
-
cache = {}
maker = sessionmaker(testing.db, future=True)
User, Address = self.classes("User", "Address")
with self.sql_execution_asserter(testing.db) as asserter:
-
with self._caching_session_fixture() as session:
stmt = (
select(User)
sess.execute(insert(User), orig_params)
def test_chained_events_one(self):
-
sess = Session(testing.db, future=True)
@event.listens_for(sess, "do_orm_execute")
@event.listens_for(session, "do_orm_execute")
def do_orm_execute(ctx):
-
if not ctx.is_select:
assert_raises_message(
sa.exc.InvalidRequestError,
)
def test_all_mappers_accessor_two(self):
-
sess = Session(testing.db, future=True)
canary = self._flag_fixture(sess)
)
def test_chained_events_two(self):
-
sess = Session(testing.db, future=True)
def added(ctx):
["listen_on_mapper", "listen_on_base", "listen_on_mixin"],
)
def test_mapper_config_sequence(self, decl_base, listen_type):
-
canary = Mock()
if listen_type.listen_on_mapper:
sess = fixture_session(autoflush=False)
with self.sql_execution_asserter(testing.db) as asserter:
-
if case == "contains,joined":
a1 = (
sess.query(Address)
"with a refresh"
),
):
-
sess.refresh(u, ["name"])
# id was not expired
)
def test_no_joinedload(self):
-
User = self.classes.User
s = fixture_session()
)
def test_anonymous_expression_plus_flag_aliased_join_newstyle(self):
-
User = self.classes.User
Address = self.classes.Address
addresses = self.tables.addresses
sess.expunge_all()
def go():
-
# same as above, except Order is aliased, so two adapters
# are applied by the eager loader
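# Hedged illustration of eager loading against an aliased entity, which is
# the situation the comment above describes; User/Order and sess are
# assumptions standing in for the surrounding fixtures.
from sqlalchemy.orm import aliased, contains_eager

oalias = aliased(Order)
q = (
    sess.query(User)
    .join(oalias, User.orders)
    .options(contains_eager(User.orders.of_type(oalias)))
)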
q4,
q5,
]:
-
eq_(
q.all(),
[
Base.registry.dispose()
def _combinations(fn):
-
return testing.combinations(
(True,), (False,), argnames="include_property"
)(
return [_random_name() for i in range(random.randint(8, 15))]
def _ordered_name_fixture(self, glbls, clsname, base, supercls):
-
names = self._random_names()
if base is supercls:
)
def test_double_w_ac_against_subquery(self):
-
(
users,
orders,
self._run_double_test()
def test_double_w_ac(self):
-
(
users,
orders,
)
with fixture_session() as sess:
-
# load address
a1 = (
sess.query(Address)
)
def test_self_referential_roundtrip(self):
-
place, Place, place_place = (
self.tables.place,
self.classes.Place,
"sqlalchemy.orm.attributes.register_attribute_impl",
side_effect=register_attribute_impl,
) as some_mock:
-
self.mapper(A, users, properties={"bs": relationship(B)})
self.mapper(B, addresses)
eq_(recon, ["A", "B", "C"])
def test_reconstructor_init(self):
-
users = self.tables.users
recon = []
class DocumentTest(fixtures.TestBase):
def setup_test(self):
-
self.mapper = registry().map_imperatively
def test_doc_propagate(self):
self.load_tracker(Address, load)
with fixture_session(expire_on_commit=False) as sess, sess.begin():
-
# set up data and save
u = User(
id=7,
self.load_tracker(Item, load)
with fixture_session(expire_on_commit=False) as sess:
-
i1 = Item()
i1.description = "item 1"
class SubclassRelationshipTest2(
testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest
):
-
run_setup_classes = "once"
run_setup_mappers = "once"
run_inserts = "once"
class SubclassRelationshipTest3(
testing.AssertsCompiledSQL, fixtures.DeclarativeMappedTest
):
-
run_setup_classes = "once"
run_setup_mappers = "once"
run_inserts = "once"
r'does not link from relationship "Company.employees". Did you '
r'mean to use "Company.employees.of_type\(Engineer\)"\?',
):
-
if use_options:
s.query(Company).options(
joinedload(Company.employees).options(
@testing.fixture
def user_address_fixture(self, registry):
-
registry.map_imperatively(
User,
self.tables.users,
class OptionsTest(_Polymorphic):
def test_options_of_type(self):
-
with_poly = with_polymorphic(Person, [Engineer, Manager], flat=True)
for opt, serialized_path, serialized_of_type in [
(
(lambda s, users: select(users.c.id, users.c.name),),
)
def test_legacy_tuple_old_select(self, test_case):
-
User, users = self.classes.User, self.tables.users
self.mapper_registry.map_imperatively(User, users)
@testing.fixture
def assert_row_keys(self):
def go(stmt, expected, coreorm_exec, selected_columns=None):
-
if coreorm_exec == "core":
with testing.db.connect() as conn:
row = conn.execute(stmt).first()
stmt._label_style is not LABEL_STYLE_NONE
and coreorm_exec == "orm"
):
-
for k in expected:
is_not_none(getattr(row, k))
(Query.order_by, lambda meth, User: meth(User.name)),
)
def test_from_statement_text(self, meth, test_case):
-
User = self.classes.User
s = fixture_session()
q = s.query(User)
)
def test_selfref_relationship(self):
-
Node = self.classes.Node
nalias = aliased(Node)
def test_clauses(self):
User, Address = self.classes.User, self.classes.Address
- for (expr, compare) in (
+ for expr, compare in (
(func.max(User.id), "max(users.id)"),
(User.id.desc(), "users.id DESC"),
(
eq_(q.distinct().count(), 3)
def test_cols_future(self):
-
User, Address = self.classes.User, self.classes.Address
s = fixture_session()
Address.special_user == User(id=None, name=None)
)
with expect_warnings("Got None for value of column"):
-
self.assert_compile(
q,
"SELECT addresses.id AS addresses_id, "
)
)
with expect_warnings("Got None for value of column"):
-
self.assert_compile(
q,
"SELECT users.id AS users_id, users.name AS users_name "
@classmethod
def setup_mappers(cls):
-
nodes = cls.tables.nodes
Node = cls.classes.Node
@testing.fixture
def limited_cache_conn(self, connection):
-
connection.engine._compiled_cache.clear()
assert_limit = 0
)
def test_determine_join_ambiguous_fks_m2m(self):
-
self._assert_raises_ambig_join(
relationships.JoinCondition,
"Whatever.foo",
)
)
elif style.from_statement:
-
stmt = (
select(Order.id, Order.description)
.from_statement(stmt)
s = Session(testing.db, future=True)
with self.sql_execution_asserter() as asserter:
-
s.execute(stmt).all()
asserter.assert_(
User, Address = user_address_fixture
def get_statement(closure="name"):
-
stmt = select(User).options(
selectinload(User.addresses),
with_loader_criteria(
User, Address = user_address_fixture
def get_statement(closure="name"):
-
stmt = (
select(User)
.options(
)
with self.sql_execution_asserter() as asserter:
-
s.execute(stmt)
asserter.assert_(
for value in "ed@wood.com", "ed@lala.com":
s.close()
with self.sql_execution_asserter() as asserter:
-
result = go(value)
eq_(
for value in "ed@wood.com", "ed@lala.com":
with self.sql_execution_asserter() as asserter:
-
result = go(value)
eq_(
for value in "ed@wood.com", "ed@lala.com":
with self.sql_execution_asserter() as asserter:
-
result = go(value)
eq_(
for i in range(3):
def go():
-
sess = fixture_session()
u = aliased(User)
self.assert_sql_count(testing.db, go, 2)
def test_from_aliased_w_cache_three(self):
-
User, Dingaling, Address = self.user_dingaling_fixture()
for i in range(3):
)
def test_double_w_ac_against_subquery(self):
-
(
users,
orders,
self._run_double_test()
def test_double_w_ac(self):
-
(
users,
orders,
@classmethod
def insert_data(cls, connection):
-
e1 = Engineer(primary_language="java")
e2 = Engineer(primary_language="c++")
e1.paperwork = [
sess.commit()
def test_one_to_many(self):
-
Company, Programmer, Manager, GolfSwing, Language = self.classes(
"Company", "Programmer", "Manager", "GolfSwing", "Language"
)
.offset(offset)
.options(selectinload(A.bs))
):
-
# this part fails with joined eager loading
# (if you enable joined eager w/ yield_per)
eq_(a.bs, [B(id=(a.id * 6) + j) for j in range(1, 6)])
s.commit()
def test_load_composite_then_non_composite(self):
-
A, B, A2, B2 = self.classes("A", "B", "A2", "B2")
s = fixture_session()
)
def test_double_w_ac_against_subquery(self):
-
(
users,
orders,
self._run_double_test()
def test_double_w_ac(self):
-
(
users,
orders,
@classmethod
def insert_data(cls, connection):
-
e1 = Engineer(primary_language="java")
e2 = Engineer(primary_language="c++")
e1.paperwork = [
self._run_test_m2o(None, False)
def _run_test_m2o(self, director_strategy_level, photo_strategy_level):
-
# test where the innermost is m2o, e.g.
# Movie->director
cache = {}
for i in range(3):
-
subq = (
s.query(B)
.join(B.a)
s.close()
def test_subq_w_from_self_two(self):
-
A, B, C = self.classes("A", "B", "C")
s = fixture_session()
for i in range(3):
def go():
-
subq = s.query(B).join(B.a).subquery()
bq = aliased(B, subq)
assert not connection.in_transaction()
elif external_state.transaction:
-
assert t1 is not None
if (
"do_commit",
side_effect=testing.db.dialect.do_commit,
) as succeed_mock:
-
# sess.begin() -> commit(). why would do_rollback() be called?
# because of connection pool finalize_fairy *after* the commit.
# this will cause the conn.close() in session.commit() to fail,
def subtransaction_recipe_one(self):
@contextlib.contextmanager
def transaction(session):
-
if session.in_transaction():
outermost = False
else:
self.mapper_registry.map_imperatively(User, users)
with fixture_session() as sess:
-
sess.begin()
sess.begin_nested()
with subtransaction_recipe(sess):
-
sess.add(User(name="u1"))
sess.commit()
class TransactionFlagsTest(fixtures.TestBase):
def test_in_transaction(self):
with fixture_session() as s1:
-
eq_(s1.in_transaction(), False)
trans = s1.begin()
def test_in_transaction_nesting(self):
with fixture_session() as s1:
-
eq_(s1.in_transaction(), False)
trans = s1.begin()
@event.listens_for(self.session, "after_transaction_end")
def restart_savepoint(session, transaction):
if transaction.nested and not transaction._parent.nested:
-
# ensure that state is expired the way
# session.commit() at the top level normally does
# (optional step)
self.mapper_registry.map_imperatively(Book, book)
with fixture_session() as sess:
-
b1 = Book(book_id="abc", title="def")
sess.add(b1)
sess.flush()
Data = self.classes.Data
with fixture_session() as sess:
-
d1 = Data(a="hello", b="there")
sess.add(d1)
sess.flush()
@testing.fixture
def selectable_fixture(self, decl_base):
-
t1, t2 = self.tables("test", "test2")
stmt = (
@testing.requires.sequences_as_server_defaults
@testing.requires.insert_returning
def test_b(self, base, run_test):
-
seq = normalize_sequence(config, Sequence("x_seq"))
class A(base):
"""
class Datum(decl_base):
-
__tablename__ = "datum"
datum_id = Column(Integer, Identity(), primary_key=True)
class Result(decl_base):
-
__tablename__ = "result"
if pk_type.plain_autoinc:
)
class ResultDatum(decl_base):
-
__tablename__ = "result_datum"
result_id = Column(ForeignKey(Result.result_id), primary_key=True)
"""
class Datum(decl_base):
-
__tablename__ = "datum"
datum_id = Column(Integer, primary_key=True)
data = Column(String(10))
class Result(decl_base):
-
__tablename__ = "result"
result_id = Column(Integer, primary_key=True)
)
class User(fixtures.ComparableEntity):
-
if need_remove_param:
@validates("addresses", **validate_kw)
return item
class Address(fixtures.ComparableEntity):
-
if need_remove_param:
@validates("user", **validate_kw)
with patch.object(
config.db.dialect, "supports_sane_multi_rowcount", False
), patch("sqlalchemy.engine.cursor.CursorResult.rowcount", rowcount):
-
Foo = self.classes.Foo
s1 = self._fixture()
f1s1 = Foo(value="f1 value")
eq_(f1s1.version_id, 2)
def test_update_delete_no_plain_rowcount(self):
-
with patch.object(
config.db.dialect, "supports_sane_rowcount", False
), patch.object(
NUMBER = 1000000
def init_objects(self):
-
self.object_1 = column("x")
self.object_2 = bindparam("y")
@test_case
def test_apply_non_present(self):
-
self.name.apply_map(self.impl_w_non_present)
@test_case
def test_apply_present(self):
-
self.name.apply_map(self.impl_w_present)
def before_cursor_execute(
conn, cursor, statement, parameters, context, executemany
):
-
nonlocal now
now = time.time()
def runit_query_runs(status, factor=1, query_runs=5):
-
# do some heavier reading
for i in range(query_runs):
status("Heavy query run #%d" % (i + 1))
@property
def sqlite_partial_indexes(self):
-
return only_on(self._sqlite_partial_idx)
@property
)
def test_text_doesnt_explode(self, connection):
-
for s in [
select(
case(
return stmt
def three():
-
a1 = table_a.alias()
a2 = table_a.alias()
ex = exists().where(table_b.c.b == a1.c.a)
fixtures.append(_complex_fixtures)
def _statements_w_context_options_fixtures():
-
return [
select(table_a)._add_context_option(opt1, True),
select(table_a)._add_context_option(opt1, 5),
return anon_col > 5
def three():
-
l1, l2 = table_a.c.a.label(None), table_a.c.b.label(None)
stmt = select(table_a.c.a, table_a.c.b, l1, l2)
)
def test_compare_metadata_tables_annotations_two(self):
-
t1 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
t2 = Table("a", MetaData(), Column("q", Integer), Column("p", Integer))
ne_(t3._generate_cache_key(), t4._generate_cache_key())
def test_compare_comparison_associative(self):
-
l1 = table_c.c.x == table_d.c.y
l2 = table_d.c.y == table_c.c.x
l3 = table_c.c.x == table_d.c.z
is_false(l1.compare(l3))
def test_compare_clauselist_associative(self):
-
l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
l2 = and_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
is_false(l1.compare(l3))
def test_compare_clauselist_not_associative(self):
-
l1 = ClauseList(
table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
)
is_false(l1.compare(l2))
def test_compare_clauselist_assoc_different_operator(self):
-
l1 = and_(table_c.c.x == table_d.c.y, table_c.c.y == table_d.c.z)
l2 = or_(table_c.c.y == table_d.c.z, table_c.c.x == table_d.c.y)
is_false(l1.compare(l2))
def test_compare_clauselist_not_assoc_different_operator(self):
-
l1 = ClauseList(
table_c.c.x, table_c.c.y, table_d.c.y, operator=operators.sub
)
assert not hasattr(c1, "__dict__")
def test_compile_label_is_slots(self):
-
c1 = compiler._CompileLabel(column("q"), "somename")
eq_(c1.name, "somename")
)
def test_dupe_columns_use_labels_from_anon(self):
-
t = table("t", column("a"), column("b"))
a = t.alias()
)
def test_literal(self):
-
self.assert_compile(
select(literal("foo")), "SELECT :param_1 AS anon_1"
)
)
def test_over_framespec(self):
-
expr = table1.c.myid
self.assert_compile(
select(func.row_number().over(order_by=expr, rows=(0, None))),
self.assert_compile(stmt, expected, dialect=dialect)
def test_statement_hints(self):
-
stmt = (
select(table1.c.myid)
.with_statement_hint("test hint one")
[5, 6],
),
]:
-
self.assert_compile(
stmt, expected_named_stmt, params=expected_default_params_dict
)
)
def test_bind_anon_name_special_chars_uniqueify_two(self):
-
t = table("t", column("_3foo"), column("4(foo"))
self.assert_compile(
def test_construct_params_combine_extracted(
self, stmt1, stmt2, param1, param2, extparam1, extparam2
):
-
if extparam1:
keys = list(extparam1)
else:
@testing.variation("scalar_subquery", [True, False])
def test_select_in(self, scalar_subquery):
-
stmt = select(table2.c.otherid, table2.c.othername)
if scalar_subquery:
)
def test_dialect_specific_ddl(self):
-
from sqlalchemy.dialects.postgresql import ExcludeConstraint
m = MetaData()
class MyCompiler(compiler.SQLCompiler):
def visit_select(self, stmt, *arg, **kw):
-
if stmt is stmt2.element:
with self._nested_result() as nested:
contexts[stmt2.element] = nested
proxied = [obj[0] for (k, n, obj, type_) in compiled._result_columns]
for orig_obj, proxied_obj in zip(orig, proxied):
-
is_(orig_obj, proxied_obj)
("sometable", "this_name_is_too_long", "ix_sometable_t_09aa"),
("sometable", "this_name_alsois_long", "ix_sometable_t_3cf1"),
]:
-
t1 = Table(
tname, MetaData(), Column(cname, Integer, index=True)
)
class CTETest(fixtures.TestBase, AssertsCompiledSQL):
-
__dialect__ = "default_enhanced"
def test_nonrecursive(self):
eq_(stmt.compile().isupdate, False)
def test_pg_example_three(self):
-
parts = table("parts", column("part"), column("sub_part"))
included_parts = (
)
def test_textual_select_uses_independent_cte_two(self):
-
foo = table("foo", column("id"))
bar = table("bar", column("id"), column("attr"), column("foo_id"))
s1 = select(foo.c.id)
class NestingCTETest(fixtures.TestBase, AssertsCompiledSQL):
-
__dialect__ = "default_enhanced"
def test_select_with_nesting_cte_in_cte(self):
def test_nesting_cte_in_recursive_cte_positional(
self, nesting_cte_in_recursive_cte
):
-
self.assert_compile(
nesting_cte_in_recursive_cte,
"WITH RECURSIVE rec_cte(outer_cte) AS (WITH nesting AS "
def test_recursive_nesting_cte_in_recursive_cte_positional(
self, recursive_nesting_cte_in_recursive_cte
):
-
self.assert_compile(
recursive_nesting_cte_in_recursive_cte,
"WITH RECURSIVE rec_cte(outer_cte) AS ("
)
def test_literal_binds_pgarray(self):
-
m = MetaData()
t = Table(
"t",
class AutoIncrementTest(fixtures.TestBase):
-
__backend__ = True
@testing.requires.empty_inserts
)
def test_alias_union(self):
-
# same as testunion, except it's an alias of the union
u = (
class TableDeprecationTest(fixtures.TestBase):
def test_mustexists(self):
with testing.expect_deprecated("Deprecated alias of .*must_exist"):
-
with testing.expect_raises_message(
exc.InvalidRequestError, "Table 'foo' not defined"
):
self.assert_compile(func.random(), ret, dialect=dialect)
def test_cube_operators(self):
-
t = table(
"t",
column("value"),
assert_raises(TypeError, func.char_length)
def test_return_type_detection(self):
-
for fn in [func.coalesce, func.max, func.min, func.sum]:
for args, type_ in [
(
)
def test_as_comparison(self):
-
fn = func.substring("foo", "foobar").as_comparison(1, 2)
is_(fn.type._type_affinity, Boolean)
)
def test_as_comparison_annotate(self):
-
fn = func.foobar("x", "y", "q", "p", "r").as_comparison(2, 5)
from sqlalchemy.sql import annotation
eq_(fn_annotated.left._annotations, {"token": "yes"})
def test_as_comparison_many_argument(self):
-
fn = func.some_comparison("x", "y", "z", "p", "q", "r").as_comparison(
2, 5
)
eq_(expr.type.dimensions, col.type.dimensions)
def test_array_agg_array_literal_implicit_type(self):
-
expr = array([column("data", Integer), column("d2", Integer)])
assert isinstance(expr.type, PG_ARRAY)
assert "GenericFunction" not in functions._registry["_default"]
def test_register_function(self):
-
# test generic function registering
class registered_func(GenericFunction):
_register = True
)
def test_scalar_subquery(self):
-
a = table(
"a",
column("id"),
)
def test_named_table_valued(self):
-
fn = (
func.json_to_recordset( # noqa
'[{"a":1,"b":"foo"},{"a":"2","c":"bar"}]'
)
def test_named_table_valued_w_quoting(self):
-
fn = (
func.json_to_recordset( # noqa
'[{"CaseSensitive":1,"the % value":"foo"}, '
)
def test_named_table_valued_subquery(self):
-
fn = (
func.json_to_recordset( # noqa
'[{"a":1,"b":"foo"},{"a":"2","c":"bar"}]'
)
def test_named_table_valued_alias(self):
-
"""select * from json_to_recordset
('[{"a":1,"b":"foo"},{"a":"2","c":"bar"}]') as x(a int, b text);"""
),
)
def test_create_ddl(self, identity_args, text):
-
if getattr(
self, "__dialect__", None
) != "default_enhanced" and testing.against("oracle"):
@testing.combinations("sqlite", "mysql", "mariadb", "postgresql", "oracle")
def test_identity_is_ignored(self, dialect):
-
t = Table(
"foo_table",
MetaData(),
table1.c.description,
)
elif column_style == "inspectables":
-
myid, name, description = (
ORMExpr(table1.c.myid),
ORMExpr(table1.c.name),
@testing.requires.multivalues_inserts
@testing.combinations("string", "column", "expect", argnames="keytype")
def test_multivalues_insert(self, connection, keytype):
-
users = self.tables.users
if keytype == "string":
@testing.requires.sql_expressions_inserted_as_primary_key
def test_sql_expr_lastrowid(self, connection):
-
# see also test.orm.test_unitofwork.py
# ClauseAttributesTest.test_insert_pk_expression
t = self.tables.foo_no_seq
argnames="paramtype",
)
def test_page_size_adjustment(self, testing_engine, batchsize, paramtype):
-
t = self.tables.data
if paramtype == "engine" and batchsize is not None:
)
def test_disabled(self, testing_engine):
-
e = testing_engine(
options={"use_insertmanyvalues": False},
share_pool=True,
autoincrement_is_sequence=False,
connection=None,
):
-
if connection:
dialect = connection.dialect
else:
and warn_for_downgrades
and dialect.use_insertmanyvalues
):
-
if (
not separate_sentinel
and (
default_type,
sort_by_parameter_order,
):
-
t1 = Table(
"data",
metadata,
metadata,
connection,
):
-
if pk_type.plain_autoinc:
pk_col = Column("id", Integer, primary_key=True)
elif pk_type.sequence:
metadata.create_all(connection)
return
else:
-
metadata.create_all(connection)
fixtures.insertmanyvalues_fixture(
sentinel_type,
add_sentinel_flag_to_col,
):
-
if sentinel_type.identity:
sentinel_args = [Identity()]
elif sentinel_type.sequence:
metadata.create_all(engine)
with engine.connect() as conn:
-
fixtures.insertmanyvalues_fixture(
conn,
randomize_rows=bool(randomize_returning),
assert not hasattr(Foo(), "__clause_element__")
def test_col_now_has_a_clauseelement(self):
-
x = Column("foo", Integer)
assert hasattr(x, "__clause_element__")
def test_stale_checker_embedded(self):
def go(x):
-
stmt = select(lambda: x)
return stmt
def test_stale_checker_statement(self):
def go(x):
-
stmt = lambdas.lambda_stmt(lambda: select(x))
return stmt
def test_stale_checker_linked(self):
def go(x, y):
-
stmt = lambdas.lambda_stmt(lambda: select(x)) + (
lambda s: s.where(y > 5)
)
)
def test_boolean_conditionals(self):
-
tab = table("foo", column("id"), column("col"))
def run_my_statement(parameter, add_criteria=False):
ne_(s1key[0], s2key[0])
def test_stmt_lambda_w_set_of_opts(self):
-
stmt = lambdas.lambda_stmt(lambda: select(column("x")))
class MyUncacheable(ExecutableOption):
)
def test_in_parameters_one(self):
-
expr1 = select(1).where(column("q").in_(["a", "b", "c"]))
self.assert_compile(expr1, "SELECT 1 WHERE q IN (__[POSTCOMPILE_q_1])")
x = 5
def my_lambda():
-
y = 10
z = y + 18
z = 10
def my_lambda():
-
y = x + z
expr1 = users.c.name > x
z = 10
def my_lambda():
-
y = 10 + z
expr1 = users.c.name > x
)
def test_fk_mismatched_local_remote_cols(self):
-
assert_raises_message(
exc.ArgumentError,
"ForeignKeyConstraint number of constrained columns must "
@emits_warning("Table '.+' already exists within the given MetaData")
def test_already_exists(self):
-
meta1 = MetaData()
table1 = Table(
"mytable", meta1, Column("myid", Integer, primary_key=True)
)
def test_reset_exported_passes(self):
-
m = MetaData()
t = Table("t", m, Column("foo", Integer))
"Table 't' specifies columns 'a', 'b', 'c' as primary_key=True, "
"not matching locally specified columns 'b', 'c'"
):
-
Table(
"t",
m,
self._test_before_parent_attach(typ)
def test_before_parent_attach_variant_array_schematype(self):
-
target = Enum("one", "two", "three")
typ = ARRAY(target).with_variant(String(), "other")
self._test_before_parent_attach(typ, evt_target=target)
eq_ignore_whitespace(str(element), expected)
def test_create_drop_schema(self):
-
self.assert_compile(
schema.CreateSchema("sa_schema"), "CREATE SCHEMA sa_schema"
)
def test_table_w_two_same_named_columns(
self, empty_meta, scenario: Variation, both_have_keys: Variation
):
-
if scenario.inplace:
with expect_raises_message(
exc.DuplicateColumnError,
assert c in t.indexes
def test_auto_append_lowercase_table(self):
-
t = table("t", column("a"))
t2 = table("t2", column("a"))
for c in (
return t, ClauseElement(t.c.q)
def test_pickle_fk_annotated_col(self, no_pickle_annotated):
-
t, q_col = no_pickle_annotated
t2 = Table("t2", t.metadata, Column("p", ForeignKey(q_col)))
assert col.name == c[i].name
def test_name_none(self):
-
c = Column(Integer)
assert_raises_message(
exc.ArgumentError,
)
def test_name_blank(self):
-
c = Column("", Integer)
assert_raises_message(
exc.ArgumentError,
paramname,
value,
):
-
args = []
params = {}
if paramname == "type" or isinstance(
value,
override_value,
):
-
args = []
params = {}
override_args = []
self._no_error(Column("foo", ForeignKey("bar.id"), Sequence("a")))
def test_column_info(self):
-
c1 = Column("foo", String, info={"x": "y"})
c2 = Column("bar", String, info={})
c3 = Column("bat", String)
)
def test_subquery_four(self):
-
# Non-lowercase names, quoting off; should not quote
metadata = MetaData()
t1 = Table(
is_true(coerced.compare(stmt.scalar_subquery().label(None)))
def test_scalar_select(self):
-
with testing.expect_warnings(
"implicitly coercing SELECT object to scalar subquery"
):
eq_(list(coll), [table1.c.name, table1.c.description])
def test_missing_key(self):
-
with expect_raises_message(KeyError, "unknown"):
table1.c["myid", "unknown"]
def test_missing_index(self):
-
with expect_raises_message(IndexError, "5"):
table1.c["myid", 5]
)
def test_join_against_join(self):
-
j = outerjoin(table1, table2, table1.c.col1 == table2.c.col2)
jj = (
select(table1.c.col1.label("bar_col1"))
)
def test_union_correspondence(self):
-
# tests that we can correspond a column in a Select statement
# with a certain Table, against a column in a Union where one of
# its underlying Selects matches to that same Table
assert u.selected_columns.col3 is not None
def test_alias_union(self):
-
# same as testunion, except it's an alias of the union
u = (
assert u1.corresponding_column(table2.c.col3) is u1.c._all_columns[2]
def test_select_union(self):
-
# like testaliasunion, but off a Select off the union.
u = (
assert s.corresponding_column(s2.c.table2_col2) is s.c.col2
def test_union_against_join(self):
-
# same as testunion, except it's an alias of the union
u = (
t1t2 = t1.join(t2)
t2t3 = t2.join(t3)
- for (left, right, a_subset, expected) in [
+ for left, right, a_subset, expected in [
(t1, t2, None, t1.c.id == t2.c.t1id),
(t1t2, t3, t2, t1t2.c.t2_id == t3.c.t2id),
(t2t3, t1, t3, t1.c.id == t3.c.t1id),
lambda s: visitors.cloned_traverse(s, {}, {}),
lambda s: visitors.replacement_traverse(s, {}, lambda x: None),
):
-
sel = fn(select(fn(select(fn(s.subquery())).subquery())))
eq_(str(assert_s), str(sel))
)
def test_unary_boolean(self):
-
s1 = select(not_(True)).set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
eq_(
[type(entry[-1]) for entry in s1.compile()._result_columns],
self._assert_seq_result(connection.scalar(s))
def test_execute_deprecated(self, connection):
-
s = normalize_sequence(config, Sequence("my_sequence", optional=True))
with expect_deprecated(
),
)
def test_select_composition_nine(self, label_style, expected):
-
s1 = select(table1.c.myid, text("whatever"))
if label_style:
s1 = s1.set_label_style(label_style)
),
)
def test_select_composition_ten(self, label_style, expected):
-
s1 = select(table1.c.myid, text("whatever"))
if label_style:
s1 = s1.set_label_style(label_style)
),
)
def test_select_composition_eleven(self, label_style, expected):
-
stmt = select(table1.c.myid, text("whatever"))
if label_style:
stmt = stmt.set_label_style(label_style)
),
)
def test_select_selected_columns_ignores_text(self, label_style, expected):
-
stmt = select(table1.c.myid, text("whatever"), table1.c.description)
if label_style:
stmt = stmt.set_label_style(label_style)
)
def test_text_in_select_nonfrom(self):
-
generate_series = text(
"generate_series(:x, :y, :z) as s(a)"
).bindparams(x=None, y=None, z=None)
(column("q").op("+")(5).label("a"), "a DESC", (desc,)),
)
def test_order_by_expr(self, case, expected, modifiers):
-
order_by = case
for mod in modifiers:
order_by = mod(order_by)
self._test_exception(stmt, "foobar")
def test_distinct_label(self):
-
stmt = select(table1.c.myid.label("foo")).distinct("foo")
self.assert_compile(
stmt,
)
def test_distinct_label_keyword(self):
-
stmt = select(table1.c.myid.label("foo")).distinct("foo")
self.assert_compile(
stmt,
def _fixture(self):
class MyString(String):
-
# supersedes any processing that might be on
# String
def bind_expression(self, bindvalue):
def _adaptions():
for typ in _all_types(omit_special_types=True):
-
# up adapt from LowerCase to UPPERCASE,
# as well as to all non-sqltypes
up_adaptions = [typ] + typ.__subclasses__()
class UserDefinedTest(
_UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL
):
-
run_create_tables = None
run_inserts = None
run_deletes = None
class IntervalTest(fixtures.TablesTest, AssertsExecutionResults):
-
__backend__ = True
@classmethod
)
def test_labels_no_collision(self):
-
t = table("foo", column("id"), column("foo_id"))
self.assert_compile(
)
if paramstyle.qmark:
-
dialect = default.StrCompileDialect(paramstyle="qmark")
self.assert_compile(
upd,
),
)
def test_unwrap_order_by(self, expr, expected):
-
expr = coercions.expect(roles.OrderByRole, expr)
unwrapped = sql_util.unwrap_order_by(expr)
disable_format = False
for line_no, line in enumerate(original.splitlines(), 1):
-
if (
line
and not disable_format
attributes: Iterable[str],
cls: Type[Any],
):
-
sphinx_symbol_match = re.match(r":class:`(.+)`", target_cls_sphinx_name)
if not sphinx_symbol_match:
raise Exception(
def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
-
class_entries = classes[modname]
# use tempfile in same path as the module, or at least in the
delete=False,
suffix=".py",
) as buf, open(filename) as orig_py:
-
in_block = False
current_clsname = None
for line in orig_py:
def run_module(modname: str, cmd: code_writer_cmd) -> None:
-
cmd.write_status(f"importing module {modname}\n")
mod = importlib.import_module(modname)
destination_path = mod.__file__
]
if __name__ == "__main__":
-
cmd = code_writer_cmd(__file__)
with cmd.add_arguments() as parser:
def process_functions(filename: str, cmd: code_writer_cmd) -> str:
-
with NamedTemporaryFile(
mode="w",
delete=False,
if __name__ == "__main__":
-
cmd = code_writer_cmd(__file__)
with cmd.run_program():
def process_module(modname: str, filename: str, cmd: code_writer_cmd) -> str:
-
# use tempfile in same path as the module, or at least in the
# current working directory, so that black / zimports use
# local pyproject.toml
def run_module(modname: str, cmd: code_writer_cmd) -> None:
-
cmd.write_status(f"importing module {modname}\n")
mod = importlib.import_module(modname)
destination_path = mod.__file__
]
if __name__ == "__main__":
-
cmd = code_writer_cmd(__file__)
with cmd.add_arguments() as parser:
def run_operation(
name: str, source: str, dest: str, cmd: code_writer_cmd
) -> None:
-
source_data = Path(source).read_text().replace(remove_str, "")
dest_data = header.format(source=source, this_file=this_file) + source_data
[testenv:lint]
basepython = python3
deps=
- flake8==5.0.0
- #flake8-import-order
- git+https://github.com/sqlalchemyorg/flake8-import-order@fix_options
+ flake8==6.0.0
+ flake8-import-order
flake8-builtins
flake8-future-annotations>=0.0.5
flake8-docstrings>=1.6.0
# in case it requires a version pin
pydocstyle
pygments
- black==22.8.0
+ black==23.3.0
slotscheck>=0.12,<0.13
# this is to satisfy the mypy plugin dependency