From: Federico Caselli Date: Mon, 29 Jan 2024 20:16:02 +0000 (+0100) Subject: Update black to 24.1.1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9b153ff18f12eab7b74a20ce53538666600f8bbf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Update black to 24.1.1 Change-Id: Iadaea7b798d8e99302e1acb430dc7b758ca61137 --- diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f169100aa6..d523c0499a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ # See https://pre-commit.com/hooks.html for more hooks repos: - repo: https://github.com/python/black - rev: 23.3.0 + rev: 24.1.1 hooks: - id: black diff --git a/doc/build/changelog/migration_05.rst b/doc/build/changelog/migration_05.rst index d26a22c0d0..8b48f13f6b 100644 --- a/doc/build/changelog/migration_05.rst +++ b/doc/build/changelog/migration_05.rst @@ -443,8 +443,7 @@ Schema/Types :: - class MyType(AdaptOldConvertMethods, TypeEngine): - ... + class MyType(AdaptOldConvertMethods, TypeEngine): ... * The ``quote`` flag on ``Column`` and ``Table`` as well as the ``quote_schema`` flag on ``Table`` now control quoting @@ -589,8 +588,7 @@ Removed :: class MyQuery(Query): - def get(self, ident): - ... + def get(self, ident): ... session = sessionmaker(query_cls=MyQuery)() diff --git a/doc/build/changelog/migration_08.rst b/doc/build/changelog/migration_08.rst index 0f661cca79..7b42aae474 100644 --- a/doc/build/changelog/migration_08.rst +++ b/doc/build/changelog/migration_08.rst @@ -1394,8 +1394,7 @@ yet, we'll be adding the ``inspector`` argument into it directly:: @event.listens_for(Table, "column_reflect") - def listen_for_col(inspector, table, column_info): - ... + def listen_for_col(inspector, table, column_info): ... :ticket:`2418` diff --git a/doc/build/changelog/migration_14.rst b/doc/build/changelog/migration_14.rst index ae93003ae6..aef07864d6 100644 --- a/doc/build/changelog/migration_14.rst +++ b/doc/build/changelog/migration_14.rst @@ -552,8 +552,7 @@ SQLAlchemy has for a long time used a parameter-injecting decorator to help reso mutually-dependent module imports, like this:: @util.dependency_for("sqlalchemy.sql.dml") - def insert(self, dml, *args, **kw): - ... + def insert(self, dml, *args, **kw): ... Where the above function would be rewritten to no longer have the ``dml`` parameter on the outside. This would confuse code-linting tools into seeing a missing parameter @@ -2274,8 +2273,7 @@ in any way:: addresses = relationship(Address, backref=backref("user", viewonly=True)) - class Address(Base): - ... + class Address(Base): ... u1 = session.query(User).filter_by(name="x").first() diff --git a/doc/build/core/connections.rst b/doc/build/core/connections.rst index 994daa8f54..1de53fdc85 100644 --- a/doc/build/core/connections.rst +++ b/doc/build/core/connections.rst @@ -1490,10 +1490,8 @@ Basic guidelines include: def my_stmt(parameter, thing=False): stmt = lambda_stmt(lambda: select(table)) - stmt += ( - lambda s: s.where(table.c.x > parameter) - if thing - else s.where(table.c.y == parameter) + stmt += lambda s: ( + s.where(table.c.x > parameter) if thing else s.where(table.c.y == parameter) ) return stmt diff --git a/doc/build/errors.rst b/doc/build/errors.rst index 48fdedeace..55ac40ae5f 100644 --- a/doc/build/errors.rst +++ b/doc/build/errors.rst @@ -1777,8 +1777,7 @@ and associating the :class:`_engine.Engine` with the Base = declarative_base(metadata=metadata_obj) - class MyClass(Base): - ... + class MyClass(Base): ... session = Session() @@ -1796,8 +1795,7 @@ engine:: Base = declarative_base() - class MyClass(Base): - ... + class MyClass(Base): ... session = Session() diff --git a/doc/build/orm/basic_relationships.rst b/doc/build/orm/basic_relationships.rst index 7e3ce5ec55..0860f69fcf 100644 --- a/doc/build/orm/basic_relationships.rst +++ b/doc/build/orm/basic_relationships.rst @@ -1116,15 +1116,13 @@ class were available, we could also apply it afterwards:: # we create a Parent class which knows nothing about Child - class Parent(Base): - ... + class Parent(Base): ... # ... later, in Module B, which is imported after module A: - class Child(Base): - ... + class Child(Base): ... from module_a import Parent diff --git a/doc/build/orm/collection_api.rst b/doc/build/orm/collection_api.rst index eff6d87cb4..b256af92a1 100644 --- a/doc/build/orm/collection_api.rst +++ b/doc/build/orm/collection_api.rst @@ -533,8 +533,7 @@ methods can be changed as well: ... @collection.iterator - def hey_use_this_instead_for_iteration(self): - ... + def hey_use_this_instead_for_iteration(self): ... There is no requirement to be "list-like" or "set-like" at all. Collection classes can be any shape, so long as they have the append, remove and iterate diff --git a/doc/build/orm/extensions/mypy.rst b/doc/build/orm/extensions/mypy.rst index 042af37091..8275e94866 100644 --- a/doc/build/orm/extensions/mypy.rst +++ b/doc/build/orm/extensions/mypy.rst @@ -179,8 +179,7 @@ following:: ) name: Mapped[Optional[str]] = Mapped._special_method(Column(String)) - def __init__(self, id: Optional[int] = ..., name: Optional[str] = ...) -> None: - ... + def __init__(self, id: Optional[int] = ..., name: Optional[str] = ...) -> None: ... some_user = User(id=5, name="user") diff --git a/doc/build/orm/inheritance.rst b/doc/build/orm/inheritance.rst index 574b4fc739..3764270d8c 100644 --- a/doc/build/orm/inheritance.rst +++ b/doc/build/orm/inheritance.rst @@ -203,12 +203,10 @@ and ``Employee``:: } - class Manager(Employee): - ... + class Manager(Employee): ... - class Engineer(Employee): - ... + class Engineer(Employee): ... If the foreign key constraint is on a table corresponding to a subclass, the relationship should target that subclass instead. In the example @@ -248,8 +246,7 @@ established between the ``Manager`` and ``Company`` classes:: } - class Engineer(Employee): - ... + class Engineer(Employee): ... Above, the ``Manager`` class will have a ``Manager.company`` attribute; ``Company`` will have a ``Company.managers`` attribute that always diff --git a/doc/build/orm/persistence_techniques.rst b/doc/build/orm/persistence_techniques.rst index 982f27ebdc..69fad33b22 100644 --- a/doc/build/orm/persistence_techniques.rst +++ b/doc/build/orm/persistence_techniques.rst @@ -713,20 +713,16 @@ connections:: pass - class User(BaseA): - ... + class User(BaseA): ... - class Address(BaseA): - ... + class Address(BaseA): ... - class GameInfo(BaseB): - ... + class GameInfo(BaseB): ... - class GameStats(BaseB): - ... + class GameStats(BaseB): ... Session = sessionmaker() diff --git a/examples/asyncio/async_orm.py b/examples/asyncio/async_orm.py index 592323be42..daf810c65d 100644 --- a/examples/asyncio/async_orm.py +++ b/examples/asyncio/async_orm.py @@ -2,6 +2,7 @@ for asynchronous ORM use. """ + from __future__ import annotations import asyncio diff --git a/examples/asyncio/async_orm_writeonly.py b/examples/asyncio/async_orm_writeonly.py index 263c0d2919..8ddc0ecdb2 100644 --- a/examples/asyncio/async_orm_writeonly.py +++ b/examples/asyncio/async_orm_writeonly.py @@ -2,6 +2,7 @@ of ORM collections under asyncio. """ + from __future__ import annotations import asyncio diff --git a/examples/asyncio/basic.py b/examples/asyncio/basic.py index 6cfa9ed014..5994fc765e 100644 --- a/examples/asyncio/basic.py +++ b/examples/asyncio/basic.py @@ -6,7 +6,6 @@ within a coroutine. """ - import asyncio from sqlalchemy import Column diff --git a/examples/custom_attributes/custom_management.py b/examples/custom_attributes/custom_management.py index aa9ea7a689..da22ee3276 100644 --- a/examples/custom_attributes/custom_management.py +++ b/examples/custom_attributes/custom_management.py @@ -9,6 +9,7 @@ descriptors with a user-defined system. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/dogpile_caching/caching_query.py b/examples/dogpile_caching/caching_query.py index b184863156..8c85d74811 100644 --- a/examples/dogpile_caching/caching_query.py +++ b/examples/dogpile_caching/caching_query.py @@ -19,6 +19,7 @@ The rest of what's here are standard SQLAlchemy and dogpile.cache constructs. """ + from dogpile.cache.api import NO_VALUE from sqlalchemy import event @@ -28,7 +29,6 @@ from sqlalchemy.orm.interfaces import UserDefinedOption class ORMCache: - """An add-on for an ORM :class:`.Session` optionally loads full results from a dogpile cache region. diff --git a/examples/dogpile_caching/environment.py b/examples/dogpile_caching/environment.py index 4b5a317917..4962826280 100644 --- a/examples/dogpile_caching/environment.py +++ b/examples/dogpile_caching/environment.py @@ -2,6 +2,7 @@ bootstrap fixture data if necessary. """ + from hashlib import md5 import os diff --git a/examples/dogpile_caching/fixture_data.py b/examples/dogpile_caching/fixture_data.py index 8387a2cb27..775fb63b1a 100644 --- a/examples/dogpile_caching/fixture_data.py +++ b/examples/dogpile_caching/fixture_data.py @@ -3,6 +3,7 @@ a few US/Canadian cities. Then, 100 Person records are installed, each with a randomly selected postal code. """ + import random from .environment import Base diff --git a/examples/dogpile_caching/model.py b/examples/dogpile_caching/model.py index cae2ae2776..926a5fa5d6 100644 --- a/examples/dogpile_caching/model.py +++ b/examples/dogpile_caching/model.py @@ -7,6 +7,7 @@ PostalCode --(has a)--> City City --(has a)--> Country """ + from sqlalchemy import Column from sqlalchemy import ForeignKey from sqlalchemy import Integer diff --git a/examples/dogpile_caching/relationship_caching.py b/examples/dogpile_caching/relationship_caching.py index 058d552225..a5b654b06c 100644 --- a/examples/dogpile_caching/relationship_caching.py +++ b/examples/dogpile_caching/relationship_caching.py @@ -6,6 +6,7 @@ related PostalCode, City, Country objects should be pulled from long term cache. """ + import os from sqlalchemy import select diff --git a/examples/generic_associations/discriminator_on_association.py b/examples/generic_associations/discriminator_on_association.py index f0f1d7ed99..93c1b29ef9 100644 --- a/examples/generic_associations/discriminator_on_association.py +++ b/examples/generic_associations/discriminator_on_association.py @@ -15,6 +15,7 @@ it uses a fixed number of tables to serve any number of potential parent objects, but is also slightly more complex. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/generic_associations/generic_fk.py b/examples/generic_associations/generic_fk.py index 5c70f93aac..d45166d333 100644 --- a/examples/generic_associations/generic_fk.py +++ b/examples/generic_associations/generic_fk.py @@ -17,6 +17,7 @@ queued up, here it is. The author recommends "table_per_related" or "table_per_association" instead of this approach. """ + from sqlalchemy import and_ from sqlalchemy import Column from sqlalchemy import create_engine diff --git a/examples/generic_associations/table_per_association.py b/examples/generic_associations/table_per_association.py index 2e412869f0..04786bd49b 100644 --- a/examples/generic_associations/table_per_association.py +++ b/examples/generic_associations/table_per_association.py @@ -11,6 +11,7 @@ has no dependency on the system. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/generic_associations/table_per_related.py b/examples/generic_associations/table_per_related.py index 5b83e6e68f..23c75b0b9d 100644 --- a/examples/generic_associations/table_per_related.py +++ b/examples/generic_associations/table_per_related.py @@ -16,6 +16,7 @@ but there really isn't any - the management and targeting of these tables is completely automated. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import ForeignKey diff --git a/examples/inheritance/concrete.py b/examples/inheritance/concrete.py index f7f6b3ac64..e718e2fc35 100644 --- a/examples/inheritance/concrete.py +++ b/examples/inheritance/concrete.py @@ -1,4 +1,5 @@ """Concrete-table (table-per-class) inheritance example.""" + from __future__ import annotations from typing import Annotated diff --git a/examples/inheritance/joined.py b/examples/inheritance/joined.py index 7dee935fab..c2ba6942cc 100644 --- a/examples/inheritance/joined.py +++ b/examples/inheritance/joined.py @@ -1,4 +1,5 @@ """Joined-table (table-per-subclass) inheritance example.""" + from __future__ import annotations from typing import Annotated diff --git a/examples/inheritance/single.py b/examples/inheritance/single.py index 8da75dd7c4..6337bb4b2e 100644 --- a/examples/inheritance/single.py +++ b/examples/inheritance/single.py @@ -1,4 +1,5 @@ """Single-table (table-per-hierarchy) inheritance example.""" + from __future__ import annotations from typing import Annotated diff --git a/examples/materialized_paths/materialized_paths.py b/examples/materialized_paths/materialized_paths.py index f458270c72..19d3ed491c 100644 --- a/examples/materialized_paths/materialized_paths.py +++ b/examples/materialized_paths/materialized_paths.py @@ -26,6 +26,7 @@ already stored in the path itself. Updates require going through all descendants and changing the prefix. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import func diff --git a/examples/performance/__init__.py b/examples/performance/__init__.py index 7e24b9b8fd..34db251e5c 100644 --- a/examples/performance/__init__.py +++ b/examples/performance/__init__.py @@ -205,6 +205,7 @@ We can run our new script directly:: """ # noqa + import argparse import cProfile import gc diff --git a/examples/performance/bulk_updates.py b/examples/performance/bulk_updates.py index 8b782353df..de5e6dc27d 100644 --- a/examples/performance/bulk_updates.py +++ b/examples/performance/bulk_updates.py @@ -3,6 +3,7 @@ of rows in bulk (under construction! there's just one test at the moment) """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import Identity diff --git a/examples/performance/large_resultsets.py b/examples/performance/large_resultsets.py index b93459150e..3617141127 100644 --- a/examples/performance/large_resultsets.py +++ b/examples/performance/large_resultsets.py @@ -13,6 +13,7 @@ full blown ORM doesn't do terribly either even though mapped objects provide a huge amount of functionality. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import Identity diff --git a/examples/performance/short_selects.py b/examples/performance/short_selects.py index 553c2fed5f..bc6a9c79ac 100644 --- a/examples/performance/short_selects.py +++ b/examples/performance/short_selects.py @@ -3,6 +3,7 @@ record by primary key """ + import random from sqlalchemy import bindparam diff --git a/examples/performance/single_inserts.py b/examples/performance/single_inserts.py index 904fda2d03..4b8132c50a 100644 --- a/examples/performance/single_inserts.py +++ b/examples/performance/single_inserts.py @@ -4,6 +4,7 @@ within a distinct transaction, and afterwards returns to essentially a a database connection, inserts the row, commits and closes. """ + from sqlalchemy import bindparam from sqlalchemy import Column from sqlalchemy import create_engine diff --git a/examples/sharding/asyncio.py b/examples/sharding/asyncio.py index 4b32034c9f..a63b0fcaaa 100644 --- a/examples/sharding/asyncio.py +++ b/examples/sharding/asyncio.py @@ -8,6 +8,7 @@ in exactly the same way. The main change is how the the routine that generates new primary keys. """ + from __future__ import annotations import asyncio diff --git a/examples/sharding/separate_databases.py b/examples/sharding/separate_databases.py index f836aaec00..9a700734c5 100644 --- a/examples/sharding/separate_databases.py +++ b/examples/sharding/separate_databases.py @@ -1,4 +1,5 @@ """Illustrates sharding using distinct SQLite databases.""" + from __future__ import annotations import datetime diff --git a/examples/sharding/separate_schema_translates.py b/examples/sharding/separate_schema_translates.py index 095ae1cc69..fd754356e5 100644 --- a/examples/sharding/separate_schema_translates.py +++ b/examples/sharding/separate_schema_translates.py @@ -4,6 +4,7 @@ where a different "schema_translates_map" can be used for each shard. In this example we will set a "shard id" at all times. """ + from __future__ import annotations import datetime diff --git a/examples/sharding/separate_tables.py b/examples/sharding/separate_tables.py index 1caaaf329b..3084e9f069 100644 --- a/examples/sharding/separate_tables.py +++ b/examples/sharding/separate_tables.py @@ -1,5 +1,6 @@ """Illustrates sharding using a single SQLite database, that will however have multiple tables using a naming convention.""" + from __future__ import annotations import datetime diff --git a/examples/versioned_rows/versioned_rows.py b/examples/versioned_rows/versioned_rows.py index 96d2e399ec..80803b3932 100644 --- a/examples/versioned_rows/versioned_rows.py +++ b/examples/versioned_rows/versioned_rows.py @@ -3,6 +3,7 @@ an UPDATE statement on a single row into an INSERT statement, so that a new row is inserted with the new data, keeping the old row intact. """ + from sqlalchemy import Column from sqlalchemy import create_engine from sqlalchemy import event diff --git a/examples/versioned_rows/versioned_rows_w_versionid.py b/examples/versioned_rows/versioned_rows_w_versionid.py index fcf8082814..d030ed065c 100644 --- a/examples/versioned_rows/versioned_rows_w_versionid.py +++ b/examples/versioned_rows/versioned_rows_w_versionid.py @@ -6,6 +6,7 @@ This example adds a numerical version_id to the Versioned class as well as the ability to see which row is the most "current" version. """ + from sqlalchemy import Boolean from sqlalchemy import Column from sqlalchemy import create_engine diff --git a/lib/sqlalchemy/connectors/asyncio.py b/lib/sqlalchemy/connectors/asyncio.py index 5126a46608..5add8e4a12 100644 --- a/lib/sqlalchemy/connectors/asyncio.py +++ b/lib/sqlalchemy/connectors/asyncio.py @@ -36,17 +36,13 @@ class AsyncIODBAPIConnection(Protocol): """ - async def close(self) -> None: - ... + async def close(self) -> None: ... - async def commit(self) -> None: - ... + async def commit(self) -> None: ... - def cursor(self) -> AsyncIODBAPICursor: - ... + def cursor(self) -> AsyncIODBAPICursor: ... - async def rollback(self) -> None: - ... + async def rollback(self) -> None: ... class AsyncIODBAPICursor(Protocol): @@ -56,8 +52,7 @@ class AsyncIODBAPICursor(Protocol): """ - def __aenter__(self) -> Any: - ... + def __aenter__(self) -> Any: ... @property def description( @@ -67,52 +62,41 @@ class AsyncIODBAPICursor(Protocol): ... @property - def rowcount(self) -> int: - ... + def rowcount(self) -> int: ... arraysize: int lastrowid: int - async def close(self) -> None: - ... + async def close(self) -> None: ... async def execute( self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: - ... + ) -> Any: ... async def executemany( self, operation: Any, parameters: _DBAPIMultiExecuteParams, - ) -> Any: - ... + ) -> Any: ... - async def fetchone(self) -> Optional[Any]: - ... + async def fetchone(self) -> Optional[Any]: ... - async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: - ... + async def fetchmany(self, size: Optional[int] = ...) -> Sequence[Any]: ... - async def fetchall(self) -> Sequence[Any]: - ... + async def fetchall(self) -> Sequence[Any]: ... - async def setinputsizes(self, sizes: Sequence[Any]) -> None: - ... + async def setinputsizes(self, sizes: Sequence[Any]) -> None: ... - def setoutputsize(self, size: Any, column: Any) -> None: - ... + def setoutputsize(self, size: Any, column: Any) -> None: ... async def callproc( self, procname: str, parameters: Sequence[Any] = ... - ) -> Any: - ... + ) -> Any: ... - async def nextset(self) -> Optional[bool]: - ... + async def nextset(self) -> Optional[bool]: ... class AsyncAdapt_dbapi_cursor: diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index 7e1cd3afe8..f204d80a8e 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -217,9 +217,11 @@ class PyODBCConnector(Connector): cursor.setinputsizes( [ - (dbtype, None, None) - if not isinstance(dbtype, tuple) - else dbtype + ( + (dbtype, None, None) + if not isinstance(dbtype, tuple) + else dbtype + ) for key, dbtype, sqltype in list_of_tuples ] ) diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index e015dccdc9..9f5b010dd7 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1426,7 +1426,6 @@ class ROWVERSION(TIMESTAMP): class NTEXT(sqltypes.UnicodeText): - """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" @@ -1596,12 +1595,12 @@ class UNIQUEIDENTIFIER(sqltypes.Uuid[sqltypes._UUID_RETURN]): @overload def __init__( self: UNIQUEIDENTIFIER[_python_UUID], as_uuid: Literal[True] = ... - ): - ... + ): ... @overload - def __init__(self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ...): - ... + def __init__( + self: UNIQUEIDENTIFIER[str], as_uuid: Literal[False] = ... + ): ... def __init__(self, as_uuid: bool = True): """Construct a :class:`_mssql.UNIQUEIDENTIFIER` type. @@ -2483,10 +2482,12 @@ class MSSQLCompiler(compiler.SQLCompiler): type_expression = "ELSE CAST(JSON_VALUE(%s, %s) AS %s)" % ( self.process(binary.left, **kw), self.process(binary.right, **kw), - "FLOAT" - if isinstance(binary.type, sqltypes.Float) - else "NUMERIC(%s, %s)" - % (binary.type.precision, binary.type.scale), + ( + "FLOAT" + if isinstance(binary.type, sqltypes.Float) + else "NUMERIC(%s, %s)" + % (binary.type.precision, binary.type.scale) + ), ) elif binary.type._type_affinity is sqltypes.Boolean: # the NULL handling is particularly weird with boolean, so @@ -2522,7 +2523,6 @@ class MSSQLCompiler(compiler.SQLCompiler): class MSSQLStrictCompiler(MSSQLCompiler): - """A subclass of MSSQLCompiler which disables the usage of bind parameters where not allowed natively by MS-SQL. diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 1177163883..0c5f2372de 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -207,6 +207,7 @@ class NumericSqlVariant(TypeDecorator): int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the correct value as string. """ + impl = Unicode cache_ok = True diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index f27dee1bd5..76ea046de9 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -369,7 +369,6 @@ from ...engine import cursor as _cursor class _ms_numeric_pyodbc: - """Turns Decimals with adjusted() < 0 or > 7 into strings. The routines here are needed for older pyodbc versions diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 6b8b2e4b18..af1a030ced 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1945,17 +1945,19 @@ class MySQLDDLCompiler(compiler.DDLCompiler): columns = [ self.sql_compiler.process( - elements.Grouping(expr) - if ( - isinstance(expr, elements.BinaryExpression) - or ( - isinstance(expr, elements.UnaryExpression) - and expr.modifier - not in (operators.desc_op, operators.asc_op) + ( + elements.Grouping(expr) + if ( + isinstance(expr, elements.BinaryExpression) + or ( + isinstance(expr, elements.UnaryExpression) + and expr.modifier + not in (operators.desc_op, operators.asc_op) + ) + or isinstance(expr, functions.FunctionElement) ) - or isinstance(expr, functions.FunctionElement) - ) - else expr, + else expr + ), include_table=False, literal_binds=True, ) @@ -1984,12 +1986,14 @@ class MySQLDDLCompiler(compiler.DDLCompiler): # mapping specifying the prefix length for each column of the # index columns = ", ".join( - "%s(%d)" % (expr, length[col.name]) - if col.name in length - else ( - "%s(%d)" % (expr, length[expr]) - if expr in length - else "%s" % expr + ( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr + ) ) for col, expr in zip(index.expressions, columns) ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 68c9928919..4540e00b6a 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -1479,9 +1479,9 @@ class OracleDialect(default.DefaultDialect): self.use_ansi = use_ansi self.optimize_limits = optimize_limits self.exclude_tablespaces = exclude_tablespaces - self.enable_offset_fetch = ( - self._supports_offset_fetch - ) = enable_offset_fetch + self.enable_offset_fetch = self._supports_offset_fetch = ( + enable_offset_fetch + ) def initialize(self, connection): super().initialize(connection) @@ -2538,10 +2538,12 @@ class OracleDialect(default.DefaultDialect): return ( ( (schema, self.normalize_name(table)), - {"text": comment} - if comment is not None - and not comment.startswith(ignore_mat_view) - else default(), + ( + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default() + ), ) for table, comment in result ) @@ -3083,9 +3085,11 @@ class OracleDialect(default.DefaultDialect): table_uc[constraint_name] = uc = { "name": constraint_name, "column_names": [], - "duplicates_index": constraint_name - if constraint_name_orig in index_names - else None, + "duplicates_index": ( + constraint_name + if constraint_name_orig in index_names + else None + ), } else: uc = table_uc[constraint_name] @@ -3097,9 +3101,11 @@ class OracleDialect(default.DefaultDialect): return ( ( key, - list(unique_cons[key].values()) - if key in unique_cons - else default(), + ( + list(unique_cons[key].values()) + if key in unique_cons + else default() + ), ) for key in ( (schema, self.normalize_name(obj_name)) @@ -3222,9 +3228,11 @@ class OracleDialect(default.DefaultDialect): return ( ( key, - check_constraints[key] - if key in check_constraints - else default(), + ( + check_constraints[key] + if key in check_constraints + else default() + ), ) for key in ( (schema, self.normalize_name(obj_name)) diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 69ee82bd23..9346224664 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -840,9 +840,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): ) for param in self.parameters: - param[ - quoted_bind_names.get(name, name) - ] = out_parameters[name] + param[quoted_bind_names.get(name, name)] = ( + out_parameters[name] + ) def _generate_cursor_outputtype_handler(self): output_handlers = {} diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 9e81e8368c..e88c27d2de 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -46,7 +46,6 @@ def All(other, arrexpr, operator=operators.eq): class array(expression.ExpressionClauseList[_T]): - """A PostgreSQL ARRAY literal. This is used to produce ARRAY literals in SQL expressions, e.g.:: @@ -110,17 +109,17 @@ class array(expression.ExpressionClauseList[_T]): main_type = ( type_arg if type_arg is not None - else self._type_tuple[0] - if self._type_tuple - else sqltypes.NULLTYPE + else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE ) if isinstance(main_type, ARRAY): self.type = ARRAY( main_type.item_type, - dimensions=main_type.dimensions + 1 - if main_type.dimensions is not None - else 2, + dimensions=( + main_type.dimensions + 1 + if main_type.dimensions is not None + else 2 + ), ) else: self.type = ARRAY(main_type) @@ -226,7 +225,6 @@ class ARRAY(sqltypes.ARRAY): """ class Comparator(sqltypes.ARRAY.Comparator): - """Define comparison operations for :class:`_types.ARRAY`. Note that these operations are in addition to those provided diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index fe6f17a74f..4655f50a86 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -488,19 +488,15 @@ class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): class _AsyncpgConnection(Protocol): async def executemany( self, operation: Any, seq_of_parameters: Sequence[Tuple[Any, ...]] - ) -> Any: - ... + ) -> Any: ... - async def reload_schema_state(self) -> None: - ... + async def reload_schema_state(self) -> None: ... async def prepare( self, operation: Any, *, name: Optional[str] = None - ) -> Any: - ... + ) -> Any: ... - def is_closed(self) -> bool: - ... + def is_closed(self) -> bool: ... def transaction( self, @@ -508,22 +504,17 @@ class _AsyncpgConnection(Protocol): isolation: Optional[str] = None, readonly: bool = False, deferrable: bool = False, - ) -> Any: - ... + ) -> Any: ... - def fetchrow(self, operation: str) -> Any: - ... + def fetchrow(self, operation: str) -> Any: ... - async def close(self) -> None: - ... + async def close(self) -> None: ... - def terminate(self) -> None: - ... + def terminate(self) -> None: ... class _AsyncpgCursor(Protocol): - def fetch(self, size: int) -> Any: - ... + def fetch(self, size: int) -> Any: ... class AsyncAdapt_asyncpg_cursor(AsyncAdapt_dbapi_cursor): @@ -832,9 +823,9 @@ class AsyncAdapt_asyncpg_connection(AsyncAdapt_dbapi_connection): translated_error = exception_mapping[super_]( "%s: %s" % (type(error), error) ) - translated_error.pgcode = ( - translated_error.sqlstate - ) = getattr(error, "sqlstate", None) + translated_error.pgcode = translated_error.sqlstate = ( + getattr(error, "sqlstate", None) + ) raise translated_error from error else: super()._handle_exception(error) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index ef70000c1b..f9347c9986 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2092,9 +2092,11 @@ class PGCompiler(compiler.SQLCompiler): text += "\n FETCH FIRST (%s)%s ROWS %s" % ( self.process(select._fetch_clause, **kw), " PERCENT" if select._fetch_clause_options["percent"] else "", - "WITH TIES" - if select._fetch_clause_options["with_ties"] - else "ONLY", + ( + "WITH TIES" + if select._fetch_clause_options["with_ties"] + else "ONLY" + ), ) return text @@ -2264,9 +2266,11 @@ class PGDDLCompiler(compiler.DDLCompiler): ", ".join( [ self.sql_compiler.process( - expr.self_group() - if not isinstance(expr, expression.ColumnClause) - else expr, + ( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr + ), include_table=False, literal_binds=True, ) @@ -2591,17 +2595,21 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( - "(%d)" % type_.precision - if getattr(type_, "precision", None) is not None - else "", + ( + "(%d)" % type_.precision + if getattr(type_, "precision", None) is not None + else "" + ), (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) @@ -3107,9 +3115,7 @@ class PGDialect(default.DefaultDialect): def get_deferrable(self, connection): raise NotImplementedError() - def _split_multihost_from_url( - self, url: URL - ) -> Union[ + def _split_multihost_from_url(self, url: URL) -> Union[ Tuple[None, None], Tuple[Tuple[Optional[str], ...], Tuple[Optional[int], ...]], ]: @@ -3641,9 +3647,11 @@ class PGDialect(default.DefaultDialect): # dictionary with (name, ) if default search path or (schema, name) # as keys enums = dict( - ((rec["name"],), rec) - if rec["visible"] - else ((rec["schema"], rec["name"]), rec) + ( + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + ) for rec in self._load_enums( connection, schema="*", info_cache=kw.get("info_cache") ) @@ -3671,9 +3679,9 @@ class PGDialect(default.DefaultDialect): for row_dict in rows: # ensure that each table has an entry, even if it has no columns if row_dict["name"] is None: - columns[ - (schema, row_dict["table_name"]) - ] = ReflectionDefaults.columns() + columns[(schema, row_dict["table_name"])] = ( + ReflectionDefaults.columns() + ) continue table_cols = columns[(schema, row_dict["table_name"])] @@ -4036,13 +4044,15 @@ class PGDialect(default.DefaultDialect): return ( ( (schema, table_name), - { - "constrained_columns": [] if cols is None else cols, - "name": pk_name, - "comment": comment, - } - if pk_name is not None - else default(), + ( + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + "comment": comment, + } + if pk_name is not None + else default() + ), ) for table_name, cols, pk_name, comment, _ in result ) diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index f227d0fac5..4404ecd37b 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -257,9 +257,9 @@ class OnConflictClause(ClauseElement): self.inferred_target_elements = index_elements self.inferred_target_whereclause = index_where elif constraint is None: - self.constraint_target = ( - self.inferred_target_elements - ) = self.inferred_target_whereclause = None + self.constraint_target = self.inferred_target_elements = ( + self.inferred_target_whereclause + ) = None class OnConflictDoNothing(OnConflictClause): diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index a0a34a9648..56bec1dc73 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -163,7 +163,6 @@ class EnumDropper(NamedTypeDropper): class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum): - """PostgreSQL ENUM type. This is a subclass of :class:`_types.Enum` which includes diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 6faf5e11cd..980f144935 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -723,12 +723,12 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): __abstract__ = True @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload - def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: - ... + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... def adapt( self, diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index 879389989c..2acf63bef6 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -38,15 +38,14 @@ class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]): @overload def __init__( self: PGUuid[_python_UUID], as_uuid: Literal[True] = ... - ) -> None: - ... + ) -> None: ... @overload - def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None: - ... + def __init__( + self: PGUuid[str], as_uuid: Literal[False] = ... + ) -> None: ... - def __init__(self, as_uuid: bool = True) -> None: - ... + def __init__(self, as_uuid: bool = True) -> None: ... class BYTEA(sqltypes.LargeBinary): @@ -129,14 +128,12 @@ class MONEY(sqltypes.TypeEngine[str]): class OID(sqltypes.TypeEngine[int]): - """Provide the PostgreSQL OID type.""" __visit_name__ = "OID" class REGCONFIG(sqltypes.TypeEngine[str]): - """Provide the PostgreSQL REGCONFIG type. .. versionadded:: 2.0.0rc1 @@ -147,7 +144,6 @@ class REGCONFIG(sqltypes.TypeEngine[str]): class TSQUERY(sqltypes.TypeEngine[str]): - """Provide the PostgreSQL TSQUERY type. .. versionadded:: 2.0.0rc1 @@ -158,7 +154,6 @@ class TSQUERY(sqltypes.TypeEngine[str]): class REGCLASS(sqltypes.TypeEngine[str]): - """Provide the PostgreSQL REGCLASS type. .. versionadded:: 1.2.7 @@ -169,7 +164,6 @@ class REGCLASS(sqltypes.TypeEngine[str]): class TIMESTAMP(sqltypes.TIMESTAMP): - """Provide the PostgreSQL TIMESTAMP type.""" __visit_name__ = "TIMESTAMP" @@ -190,7 +184,6 @@ class TIMESTAMP(sqltypes.TIMESTAMP): class TIME(sqltypes.TIME): - """PostgreSQL TIME type.""" __visit_name__ = "TIME" @@ -211,7 +204,6 @@ class TIME(sqltypes.TIME): class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): - """PostgreSQL INTERVAL type.""" __visit_name__ = "INTERVAL" @@ -281,7 +273,6 @@ PGBit = BIT class TSVECTOR(sqltypes.TypeEngine[str]): - """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL text search type TSVECTOR. @@ -298,7 +289,6 @@ class TSVECTOR(sqltypes.TypeEngine[str]): class CITEXT(sqltypes.TEXT): - """Provide the PostgreSQL CITEXT type. .. versionadded:: 2.0.7 diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 59ba49c25e..6db8214652 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -2030,9 +2030,9 @@ class SQLiteDialect(default.DefaultDialect): ) if self.dbapi.sqlite_version_info < (3, 35) or util.pypy: - self.update_returning = ( - self.delete_returning - ) = self.insert_returning = False + self.update_returning = self.delete_returning = ( + self.insert_returning + ) = False if self.dbapi.sqlite_version_info < (3, 32, 0): # https://www.sqlite.org/limits.html diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index 42e5b0fc7a..dcf5e4482e 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -198,9 +198,9 @@ class OnConflictClause(ClauseElement): self.inferred_target_elements = index_elements self.inferred_target_whereclause = index_where else: - self.constraint_target = ( - self.inferred_target_elements - ) = self.inferred_target_whereclause = None + self.constraint_target = self.inferred_target_elements = ( + self.inferred_target_whereclause + ) = None class OnConflictDoNothing(OnConflictClause): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 17b3f81186..b3577ecca2 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -209,9 +209,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): @property def _schema_translate_map(self) -> Optional[SchemaTranslateMapType]: - schema_translate_map: Optional[ - SchemaTranslateMapType - ] = self._execution_options.get("schema_translate_map", None) + schema_translate_map: Optional[SchemaTranslateMapType] = ( + self._execution_options.get("schema_translate_map", None) + ) return schema_translate_map @@ -222,9 +222,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): """ name = obj.schema - schema_translate_map: Optional[ - SchemaTranslateMapType - ] = self._execution_options.get("schema_translate_map", None) + schema_translate_map: Optional[SchemaTranslateMapType] = ( + self._execution_options.get("schema_translate_map", None) + ) if ( schema_translate_map @@ -255,12 +255,10 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> Connection: - ... + ) -> Connection: ... @overload - def execution_options(self, **opt: Any) -> Connection: - ... + def execution_options(self, **opt: Any) -> Connection: ... def execution_options(self, **opt: Any) -> Connection: r"""Set non-SQL options for the connection which take effect @@ -1266,8 +1264,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload def scalar( @@ -1276,8 +1273,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Any: - ... + ) -> Any: ... def scalar( self, @@ -1315,8 +1311,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload def scalars( @@ -1325,8 +1320,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... def scalars( self, @@ -1360,8 +1354,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Unpack[_Ts]]: - ... + ) -> CursorResult[Unpack[_Ts]]: ... @overload def execute( @@ -1370,8 +1363,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Unpack[TupleAny]]: - ... + ) -> CursorResult[Unpack[TupleAny]]: ... def execute( self, @@ -2021,9 +2013,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): engine_events = self._has_events or self.engine._has_events if self.dialect._has_events: - do_execute_dispatch: Iterable[ - Any - ] = self.dialect.dispatch.do_execute + do_execute_dispatch: Iterable[Any] = ( + self.dialect.dispatch.do_execute + ) else: do_execute_dispatch = () @@ -2384,9 +2376,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): None, cast(Exception, e), dialect.loaded_dbapi.Error, - hide_parameters=engine.hide_parameters - if engine is not None - else False, + hide_parameters=( + engine.hide_parameters if engine is not None else False + ), connection_invalidated=is_disconnect, dialect=dialect, ) @@ -2423,9 +2415,9 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): break if sqlalchemy_exception and is_disconnect != ctx.is_disconnect: - sqlalchemy_exception.connection_invalidated = ( - is_disconnect - ) = ctx.is_disconnect + sqlalchemy_exception.connection_invalidated = is_disconnect = ( + ctx.is_disconnect + ) if newraise: raise newraise.with_traceback(exc_info[2]) from e @@ -3033,12 +3025,10 @@ class Engine( insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> OptionEngine: - ... + ) -> OptionEngine: ... @overload - def execution_options(self, **opt: Any) -> OptionEngine: - ... + def execution_options(self, **opt: Any) -> OptionEngine: ... def execution_options(self, **opt: Any) -> OptionEngine: """Return a new :class:`_engine.Engine` that will provide diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index c30db98c09..e04057d44c 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -82,13 +82,11 @@ def create_engine( query_cache_size: int = ..., use_insertmanyvalues: bool = ..., **kwargs: Any, -) -> Engine: - ... +) -> Engine: ... @overload -def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: - ... +def create_engine(url: Union[str, URL], **kwargs: Any) -> Engine: ... @util.deprecated_params( @@ -816,13 +814,11 @@ def create_pool_from_url( timeout: float = ..., use_lifo: bool = ..., **kwargs: Any, -) -> Pool: - ... +) -> Pool: ... @overload -def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: - ... +def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: ... def create_pool_from_url(url: Union[str, URL], **kwargs: Any) -> Pool: diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index c56065bfe6..6798beadb9 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -153,7 +153,7 @@ class CursorResultMetaData(ResultMetaData): "_translated_indexes", "_safe_for_cache", "_unpickled", - "_key_to_index" + "_key_to_index", # don't need _unique_filters support here for now. Can be added # if a need arises. ) @@ -227,9 +227,11 @@ class CursorResultMetaData(ResultMetaData): { key: ( # int index should be None for ambiguous key - value[0] + offset - if value[0] is not None and key not in keymap - else None, + ( + value[0] + offset + if value[0] is not None and key not in keymap + else None + ), value[1] + offset, *value[2:], ) @@ -364,13 +366,11 @@ class CursorResultMetaData(ResultMetaData): ) = context.result_column_struct num_ctx_cols = len(result_columns) else: - result_columns = ( # type: ignore - cols_are_ordered - ) = ( + result_columns = cols_are_ordered = ( # type: ignore num_ctx_cols - ) = ( - ad_hoc_textual - ) = loose_column_name_matching = textual_ordered = False + ) = ad_hoc_textual = loose_column_name_matching = ( + textual_ordered + ) = False # merge cursor.description with the column info # present in the compiled structure, if any diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 4e4561df38..7eb7d0eb8b 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -853,9 +853,11 @@ class DefaultDialect(Dialect): ordered_rows = [ rows_by_sentinel[ tuple( - _resolver(parameters[_spk]) # type: ignore # noqa: E501 - if _resolver - else parameters[_spk] # type: ignore # noqa: E501 + ( + _resolver(parameters[_spk]) # type: ignore # noqa: E501 + if _resolver + else parameters[_spk] # type: ignore # noqa: E501 + ) for _resolver, _spk in zip( sentinel_value_resolvers, imv.sentinel_param_keys, @@ -1462,9 +1464,11 @@ class DefaultExecutionContext(ExecutionContext): assert positiontup is not None for compiled_params in self.compiled_parameters: l_param: List[Any] = [ - flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] + ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + ) for key in positiontup ] core_positional_parameters.append( @@ -1485,18 +1489,20 @@ class DefaultExecutionContext(ExecutionContext): for compiled_params in self.compiled_parameters: if escaped_names: d_param = { - escaped_names.get(key, key): flattened_processors[key]( - compiled_params[key] + escaped_names.get(key, key): ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] ) - if key in flattened_processors - else compiled_params[key] for key in compiled_params } else: d_param = { - key: flattened_processors[key](compiled_params[key]) - if key in flattened_processors - else compiled_params[key] + key: ( + flattened_processors[key](compiled_params[key]) + if key in flattened_processors + else compiled_params[key] + ) for key in compiled_params } @@ -2158,17 +2164,21 @@ class DefaultExecutionContext(ExecutionContext): if compiled.positional: parameters = self.dialect.execute_sequence_format( [ - processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key] + ( + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + ) for key in compiled.positiontup or () ] ) else: parameters = { - key: processors[key](compiled_params[key]) # type: ignore - if key in processors - else compiled_params[key] + key: ( + processors[key](compiled_params[key]) # type: ignore + if key in processors + else compiled_params[key] + ) for key in compiled_params } return self._execute_scalar( diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index 5953b86ca3..62476696e8 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -118,17 +118,13 @@ class DBAPIConnection(Protocol): """ # noqa: E501 - def close(self) -> None: - ... + def close(self) -> None: ... - def commit(self) -> None: - ... + def commit(self) -> None: ... - def cursor(self) -> DBAPICursor: - ... + def cursor(self) -> DBAPICursor: ... - def rollback(self) -> None: - ... + def rollback(self) -> None: ... autocommit: bool @@ -174,53 +170,43 @@ class DBAPICursor(Protocol): ... @property - def rowcount(self) -> int: - ... + def rowcount(self) -> int: ... arraysize: int lastrowid: int - def close(self) -> None: - ... + def close(self) -> None: ... def execute( self, operation: Any, parameters: Optional[_DBAPISingleExecuteParams] = None, - ) -> Any: - ... + ) -> Any: ... def executemany( self, operation: Any, parameters: _DBAPIMultiExecuteParams, - ) -> Any: - ... + ) -> Any: ... - def fetchone(self) -> Optional[Any]: - ... + def fetchone(self) -> Optional[Any]: ... - def fetchmany(self, size: int = ...) -> Sequence[Any]: - ... + def fetchmany(self, size: int = ...) -> Sequence[Any]: ... - def fetchall(self) -> Sequence[Any]: - ... + def fetchall(self) -> Sequence[Any]: ... - def setinputsizes(self, sizes: Sequence[Any]) -> None: - ... + def setinputsizes(self, sizes: Sequence[Any]) -> None: ... - def setoutputsize(self, size: Any, column: Any) -> None: - ... + def setoutputsize(self, size: Any, column: Any) -> None: ... - def callproc(self, procname: str, parameters: Sequence[Any] = ...) -> Any: - ... + def callproc( + self, procname: str, parameters: Sequence[Any] = ... + ) -> Any: ... - def nextset(self) -> Optional[bool]: - ... + def nextset(self) -> Optional[bool]: ... - def __getattr__(self, key: str) -> Any: - ... + def __getattr__(self, key: str) -> Any: ... _CoreSingleExecuteParams = Mapping[str, Any] @@ -1303,8 +1289,7 @@ class Dialect(EventTarget): if TYPE_CHECKING: - def _overrides_default(self, method_name: str) -> bool: - ... + def _overrides_default(self, method_name: str) -> bool: ... def get_columns( self, diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index b74b9d343b..e353dff9d7 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -118,8 +118,7 @@ class ResultMetaData: @overload def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: Literal[True] = ... - ) -> NoReturn: - ... + ) -> NoReturn: ... @overload def _key_fallback( @@ -127,14 +126,12 @@ class ResultMetaData: key: Any, err: Optional[Exception], raiseerr: Literal[False] = ..., - ) -> None: - ... + ) -> None: ... @overload def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: bool = ... - ) -> Optional[NoReturn]: - ... + ) -> Optional[NoReturn]: ... def _key_fallback( self, key: Any, err: Optional[Exception], raiseerr: bool = True @@ -737,8 +734,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): raise_for_second_row: bool, raise_for_none: Literal[True], scalar: bool, - ) -> _R: - ... + ) -> _R: ... @overload def _only_one_row( @@ -746,8 +742,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): raise_for_second_row: bool, raise_for_none: bool, scalar: bool, - ) -> Optional[_R]: - ... + ) -> Optional[_R]: ... def _only_one_row( self, @@ -1137,18 +1132,15 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): return self._column_slices(col_expressions) @overload - def scalars(self: Result[_T, Unpack[TupleAny]]) -> ScalarResult[_T]: - ... + def scalars(self: Result[_T, Unpack[TupleAny]]) -> ScalarResult[_T]: ... @overload def scalars( self: Result[_T, Unpack[TupleAny]], index: Literal[0] - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload - def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: - ... + def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: ... def scalars(self, index: _KeyIndexType = 0) -> ScalarResult[Any]: """Return a :class:`_engine.ScalarResult` filtering object which @@ -1479,12 +1471,10 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): ) @overload - def scalar_one(self: Result[_T]) -> _T: - ... + def scalar_one(self: Result[_T]) -> _T: ... @overload - def scalar_one(self) -> Any: - ... + def scalar_one(self) -> Any: ... def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -1504,12 +1494,10 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): ) @overload - def scalar_one_or_none(self: Result[_T]) -> Optional[_T]: - ... + def scalar_one_or_none(self: Result[_T]) -> Optional[_T]: ... @overload - def scalar_one_or_none(self) -> Optional[Any]: - ... + def scalar_one_or_none(self) -> Optional[Any]: ... def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one scalar result or ``None``. @@ -1562,12 +1550,10 @@ class Result(_WithKeys, ResultInternal[Row[Unpack[_Ts]]]): ) @overload - def scalar(self: Result[_T]) -> Optional[_T]: - ... + def scalar(self: Result[_T]) -> Optional[_T]: ... @overload - def scalar(self) -> Any: - ... + def scalar(self) -> Any: ... def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -1922,11 +1908,9 @@ class TupleResult(FilterResult[_R], util.TypingOnly): """ ... - def __iter__(self) -> Iterator[_R]: - ... + def __iter__(self) -> Iterator[_R]: ... - def __next__(self) -> _R: - ... + def __next__(self) -> _R: ... def first(self) -> Optional[_R]: """Fetch the first object or ``None`` if no object is present. @@ -1960,12 +1944,10 @@ class TupleResult(FilterResult[_R], util.TypingOnly): ... @overload - def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: - ... + def scalar_one(self: TupleResult[Tuple[_T]]) -> _T: ... @overload - def scalar_one(self) -> Any: - ... + def scalar_one(self) -> Any: ... def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -1983,12 +1965,12 @@ class TupleResult(FilterResult[_R], util.TypingOnly): ... @overload - def scalar_one_or_none(self: TupleResult[Tuple[_T]]) -> Optional[_T]: - ... + def scalar_one_or_none( + self: TupleResult[Tuple[_T]], + ) -> Optional[_T]: ... @overload - def scalar_one_or_none(self) -> Optional[Any]: - ... + def scalar_one_or_none(self) -> Optional[Any]: ... def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. @@ -2006,12 +1988,10 @@ class TupleResult(FilterResult[_R], util.TypingOnly): ... @overload - def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: - ... + def scalar(self: TupleResult[Tuple[_T]]) -> Optional[_T]: ... @overload - def scalar(self) -> Any: - ... + def scalar(self) -> Any: ... def scalar(self) -> Any: """Fetch the first column of the first row, and close the result diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index 5e6db0599e..79d8026c62 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -377,8 +377,7 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): if TYPE_CHECKING: - def __getitem__(self, key: _KeyType) -> Any: - ... + def __getitem__(self, key: _KeyType) -> Any: ... else: __getitem__ = BaseRow._get_by_key_impl_mapping diff --git a/lib/sqlalchemy/engine/url.py b/lib/sqlalchemy/engine/url.py index db4f2879c7..1eeb73a236 100644 --- a/lib/sqlalchemy/engine/url.py +++ b/lib/sqlalchemy/engine/url.py @@ -253,14 +253,12 @@ class URL(NamedTuple): @overload def _assert_value( val: str, - ) -> str: - ... + ) -> str: ... @overload def _assert_value( val: Sequence[str], - ) -> Union[str, Tuple[str, ...]]: - ... + ) -> Union[str, Tuple[str, ...]]: ... def _assert_value( val: Union[str, Sequence[str]], diff --git a/lib/sqlalchemy/event/attr.py b/lib/sqlalchemy/event/attr.py index 509b674c8f..3af9fa52b8 100644 --- a/lib/sqlalchemy/event/attr.py +++ b/lib/sqlalchemy/event/attr.py @@ -391,16 +391,14 @@ class _EmptyListener(_InstanceLevelDispatch[_ET]): class _MutexProtocol(Protocol): - def __enter__(self) -> bool: - ... + def __enter__(self) -> bool: ... def __exit__( self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], - ) -> Optional[bool]: - ... + ) -> Optional[bool]: ... class _CompoundListener(_InstanceLevelDispatch[_ET]): diff --git a/lib/sqlalchemy/event/base.py b/lib/sqlalchemy/event/base.py index 18a3462478..1f52e2eb79 100644 --- a/lib/sqlalchemy/event/base.py +++ b/lib/sqlalchemy/event/base.py @@ -42,9 +42,9 @@ from .registry import _EventKey from .. import util from ..util.typing import Literal -_registrars: MutableMapping[ - str, List[Type[_HasEventsDispatch[Any]]] -] = util.defaultdict(list) +_registrars: MutableMapping[str, List[Type[_HasEventsDispatch[Any]]]] = ( + util.defaultdict(list) +) def _is_event_name(name: str) -> bool: @@ -240,8 +240,7 @@ class _HasEventsDispatch(Generic[_ET]): if typing.TYPE_CHECKING: - def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: - ... + def __getattr__(self, name: str) -> _InstanceLevelDispatch[_ET]: ... def __init_subclass__(cls) -> None: """Intercept new Event subclasses and create associated _Dispatch @@ -430,12 +429,10 @@ class dispatcher(Generic[_ET]): @overload def __get__( self, obj: Literal[None], cls: Type[Any] - ) -> Type[_Dispatch[_ET]]: - ... + ) -> Type[_Dispatch[_ET]]: ... @overload - def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: - ... + def __get__(self, obj: Any, cls: Type[Any]) -> _DispatchCommon[_ET]: ... def __get__(self, obj: Any, cls: Type[Any]) -> Any: if obj is None: diff --git a/lib/sqlalchemy/event/legacy.py b/lib/sqlalchemy/event/legacy.py index 067b720584..57e561c390 100644 --- a/lib/sqlalchemy/event/legacy.py +++ b/lib/sqlalchemy/event/legacy.py @@ -147,9 +147,9 @@ def _standard_listen_example( ) text %= { - "current_since": " (arguments as of %s)" % current_since - if current_since - else "", + "current_since": ( + " (arguments as of %s)" % current_since if current_since else "" + ), "event_name": fn.__name__, "has_kw_arguments": ", **kw" if dispatch_collection.has_kw else "", "named_event_arguments": ", ".join(dispatch_collection.arg_names), @@ -177,9 +177,9 @@ def _legacy_listen_examples( % { "since": since, "event_name": fn.__name__, - "has_kw_arguments": " **kw" - if dispatch_collection.has_kw - else "", + "has_kw_arguments": ( + " **kw" if dispatch_collection.has_kw else "" + ), "named_event_arguments": ", ".join(args), "sample_target": sample_target, } diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index c048735e21..773620f8bb 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -66,9 +66,9 @@ _RefCollectionToListenerType = Dict[ "weakref.ref[_ListenerFnType]", ] -_key_to_collection: Dict[ - _EventKeyTupleType, _RefCollectionToListenerType -] = collections.defaultdict(dict) +_key_to_collection: Dict[_EventKeyTupleType, _RefCollectionToListenerType] = ( + collections.defaultdict(dict) +) """ Given an original listen() argument, can locate all listener collections and the listener fn contained diff --git a/lib/sqlalchemy/exc.py b/lib/sqlalchemy/exc.py index c4025a2b8c..7d7eff3606 100644 --- a/lib/sqlalchemy/exc.py +++ b/lib/sqlalchemy/exc.py @@ -571,8 +571,7 @@ class DBAPIError(StatementError): connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> StatementError: - ... + ) -> StatementError: ... @overload @classmethod @@ -586,8 +585,7 @@ class DBAPIError(StatementError): connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> DontWrapMixin: - ... + ) -> DontWrapMixin: ... @overload @classmethod @@ -601,8 +599,7 @@ class DBAPIError(StatementError): connection_invalidated: bool = False, dialect: Optional[Dialect] = None, ismulti: Optional[bool] = None, - ) -> BaseException: - ... + ) -> BaseException: ... @classmethod def instance( diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index b6c4d41ff7..b1720205b6 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -254,45 +254,39 @@ class AssociationProxyExtensionType(InspectionAttrExtensionType): class _GetterProtocol(Protocol[_T_co]): - def __call__(self, instance: Any) -> _T_co: - ... + def __call__(self, instance: Any) -> _T_co: ... # mypy 0.990 we are no longer allowed to make this Protocol[_T_con] -class _SetterProtocol(Protocol): - ... +class _SetterProtocol(Protocol): ... class _PlainSetterProtocol(_SetterProtocol, Protocol[_T_con]): - def __call__(self, instance: Any, value: _T_con) -> None: - ... + def __call__(self, instance: Any, value: _T_con) -> None: ... class _DictSetterProtocol(_SetterProtocol, Protocol[_T_con]): - def __call__(self, instance: Any, key: Any, value: _T_con) -> None: - ... + def __call__(self, instance: Any, key: Any, value: _T_con) -> None: ... # mypy 0.990 we are no longer allowed to make this Protocol[_T_con] -class _CreatorProtocol(Protocol): - ... +class _CreatorProtocol(Protocol): ... class _PlainCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): - def __call__(self, value: _T_con) -> Any: - ... + def __call__(self, value: _T_con) -> Any: ... class _KeyCreatorProtocol(_CreatorProtocol, Protocol[_T_con]): - def __call__(self, key: Any, value: Optional[_T_con]) -> Any: - ... + def __call__(self, key: Any, value: Optional[_T_con]) -> Any: ... class _LazyCollectionProtocol(Protocol[_T]): def __call__( self, - ) -> Union[MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T]]: - ... + ) -> Union[ + MutableSet[_T], MutableMapping[Any, _T], MutableSequence[_T] + ]: ... class _GetSetFactoryProtocol(Protocol): @@ -300,8 +294,7 @@ class _GetSetFactoryProtocol(Protocol): self, collection_class: Optional[Type[Any]], assoc_instance: AssociationProxyInstance[Any], - ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: - ... + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... class _ProxyFactoryProtocol(Protocol): @@ -311,15 +304,13 @@ class _ProxyFactoryProtocol(Protocol): creator: _CreatorProtocol, value_attr: str, parent: AssociationProxyInstance[Any], - ) -> Any: - ... + ) -> Any: ... class _ProxyBulkSetProtocol(Protocol): def __call__( self, proxy: _AssociationCollection[Any], collection: Iterable[Any] - ) -> None: - ... + ) -> None: ... class _AssociationProxyProtocol(Protocol[_T]): @@ -337,18 +328,15 @@ class _AssociationProxyProtocol(Protocol[_T]): proxy_bulk_set: Optional[_ProxyBulkSetProtocol] @util.ro_memoized_property - def info(self) -> _InfoType: - ... + def info(self) -> _InfoType: ... def for_class( self, class_: Type[Any], obj: Optional[object] = None - ) -> AssociationProxyInstance[_T]: - ... + ) -> AssociationProxyInstance[_T]: ... def _default_getset( self, collection_class: Any - ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: - ... + ) -> Tuple[_GetterProtocol[Any], _SetterProtocol]: ... class AssociationProxy( @@ -419,18 +407,17 @@ class AssociationProxy( self._attribute_options = _DEFAULT_ATTRIBUTE_OPTIONS @overload - def __get__(self, instance: Literal[None], owner: Literal[None]) -> Self: - ... + def __get__( + self, instance: Literal[None], owner: Literal[None] + ) -> Self: ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> AssociationProxyInstance[_T]: - ... + ) -> AssociationProxyInstance[_T]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... + def __get__(self, instance: object, owner: Any) -> _T: ... def __get__( self, instance: object, owner: Any @@ -861,12 +848,10 @@ class AssociationProxyInstance(SQLORMOperations[_T]): return self.parent.info @overload - def get(self: _Self, obj: Literal[None]) -> _Self: - ... + def get(self: _Self, obj: Literal[None]) -> _Self: ... @overload - def get(self, obj: Any) -> _T: - ... + def get(self, obj: Any) -> _T: ... def get( self, obj: Any @@ -1432,12 +1417,10 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): self.setter(object_, value) @overload - def __getitem__(self, index: int) -> _T: - ... + def __getitem__(self, index: int) -> _T: ... @overload - def __getitem__(self, index: slice) -> MutableSequence[_T]: - ... + def __getitem__(self, index: slice) -> MutableSequence[_T]: ... def __getitem__( self, index: Union[int, slice] @@ -1448,12 +1431,10 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): return [self._get(member) for member in self.col[index]] @overload - def __setitem__(self, index: int, value: _T) -> None: - ... + def __setitem__(self, index: int, value: _T) -> None: ... @overload - def __setitem__(self, index: slice, value: Iterable[_T]) -> None: - ... + def __setitem__(self, index: slice, value: Iterable[_T]) -> None: ... def __setitem__( self, index: Union[int, slice], value: Union[_T, Iterable[_T]] @@ -1492,12 +1473,10 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): self._set(self.col[i], item) @overload - def __delitem__(self, index: int) -> None: - ... + def __delitem__(self, index: int) -> None: ... @overload - def __delitem__(self, index: slice) -> None: - ... + def __delitem__(self, index: slice) -> None: ... def __delitem__(self, index: Union[slice, int]) -> None: del self.col[index] @@ -1624,8 +1603,9 @@ class _AssociationList(_AssociationSingleItem[_T], MutableSequence[_T]): if typing.TYPE_CHECKING: # TODO: no idea how to do this without separate "stub" - def index(self, value: Any, start: int = ..., stop: int = ...) -> int: - ... + def index( + self, value: Any, start: int = ..., stop: int = ... + ) -> int: ... else: @@ -1701,12 +1681,12 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): return repr(dict(self)) @overload - def get(self, __key: _KT, /) -> Optional[_VT]: - ... + def get(self, __key: _KT, /) -> Optional[_VT]: ... @overload - def get(self, __key: _KT, /, default: Union[_VT, _T]) -> Union[_VT, _T]: - ... + def get( + self, __key: _KT, /, default: Union[_VT, _T] + ) -> Union[_VT, _T]: ... def get( self, __key: _KT, /, default: Optional[Union[_VT, _T]] = None @@ -1738,14 +1718,12 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): return ValuesView(self) @overload - def pop(self, __key: _KT, /) -> _VT: - ... + def pop(self, __key: _KT, /) -> _VT: ... @overload def pop( self, __key: _KT, /, default: Union[_VT, _T] = ... - ) -> Union[_VT, _T]: - ... + ) -> Union[_VT, _T]: ... def pop(self, __key: _KT, /, *arg: Any, **kw: Any) -> Union[_VT, _T]: member = self.col.pop(__key, *arg, **kw) @@ -1758,16 +1736,15 @@ class _AssociationDict(_AssociationCollection[_VT], MutableMapping[_KT, _VT]): @overload def update( self, __m: SupportsKeysAndGetItem[_KT, _VT], **kwargs: _VT - ) -> None: - ... + ) -> None: ... @overload - def update(self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT) -> None: - ... + def update( + self, __m: Iterable[tuple[_KT, _VT]], **kwargs: _VT + ) -> None: ... @overload - def update(self, **kwargs: _VT) -> None: - ... + def update(self, **kwargs: _VT) -> None: ... def update(self, *a: Any, **kw: Any) -> None: up: Dict[_KT, _VT] = {} diff --git a/lib/sqlalchemy/ext/asyncio/base.py b/lib/sqlalchemy/ext/asyncio/base.py index 330651b074..9899364d1f 100644 --- a/lib/sqlalchemy/ext/asyncio/base.py +++ b/lib/sqlalchemy/ext/asyncio/base.py @@ -44,12 +44,10 @@ class ReversibleProxy(Generic[_PT]): __slots__ = ("__weakref__",) @overload - def _assign_proxied(self, target: _PT) -> _PT: - ... + def _assign_proxied(self, target: _PT) -> _PT: ... @overload - def _assign_proxied(self, target: None) -> None: - ... + def _assign_proxied(self, target: None) -> None: ... def _assign_proxied(self, target: Optional[_PT]) -> Optional[_PT]: if target is not None: @@ -82,15 +80,13 @@ class ReversibleProxy(Generic[_PT]): cls, target: _PT, regenerate: Literal[True] = ..., - ) -> Self: - ... + ) -> Self: ... @overload @classmethod def _retrieve_proxy_for_target( cls, target: _PT, regenerate: bool = True - ) -> Optional[Self]: - ... + ) -> Optional[Self]: ... @classmethod def _retrieve_proxy_for_target( diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index aabd4b961a..2b3a85465d 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -418,12 +418,10 @@ class AsyncConnection( insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> AsyncConnection: - ... + ) -> AsyncConnection: ... @overload - async def execution_options(self, **opt: Any) -> AsyncConnection: - ... + async def execution_options(self, **opt: Any) -> AsyncConnection: ... async def execution_options(self, **opt: Any) -> AsyncConnection: r"""Set non-SQL options for the connection which take effect @@ -521,8 +519,7 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncResult[Unpack[_Ts]]]: - ... + ) -> GeneratorStartableContext[AsyncResult[Unpack[_Ts]]]: ... @overload def stream( @@ -531,8 +528,7 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncResult[Unpack[TupleAny]]]: - ... + ) -> GeneratorStartableContext[AsyncResult[Unpack[TupleAny]]]: ... @asyncstartablecontext async def stream( @@ -608,8 +604,7 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Unpack[_Ts]]: - ... + ) -> CursorResult[Unpack[_Ts]]: ... @overload async def execute( @@ -618,8 +613,7 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> CursorResult[Unpack[TupleAny]]: - ... + ) -> CursorResult[Unpack[TupleAny]]: ... async def execute( self, @@ -675,8 +669,7 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload async def scalar( @@ -685,8 +678,7 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> Any: - ... + ) -> Any: ... async def scalar( self, @@ -717,8 +709,7 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload async def scalars( @@ -727,8 +718,7 @@ class AsyncConnection( parameters: Optional[_CoreAnyExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... async def scalars( self, @@ -760,8 +750,7 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: - ... + ) -> GeneratorStartableContext[AsyncScalarResult[_T]]: ... @overload def stream_scalars( @@ -770,8 +759,7 @@ class AsyncConnection( parameters: Optional[_CoreSingleExecuteParams] = None, *, execution_options: Optional[CoreExecuteOptionsParameter] = None, - ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: - ... + ) -> GeneratorStartableContext[AsyncScalarResult[Any]]: ... @asyncstartablecontext async def stream_scalars( @@ -1108,12 +1096,10 @@ class AsyncEngine(ProxyComparable[Engine], AsyncConnectable): insertmanyvalues_page_size: int = ..., schema_translate_map: Optional[SchemaTranslateMapType] = ..., **opt: Any, - ) -> AsyncEngine: - ... + ) -> AsyncEngine: ... @overload - def execution_options(self, **opt: Any) -> AsyncEngine: - ... + def execution_options(self, **opt: Any) -> AsyncEngine: ... def execution_options(self, **opt: Any) -> AsyncEngine: """Return a new :class:`_asyncio.AsyncEngine` that will provide @@ -1426,15 +1412,13 @@ class AsyncTransaction( @overload -def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: - ... +def _get_sync_engine_or_connection(async_engine: AsyncEngine) -> Engine: ... @overload def _get_sync_engine_or_connection( async_engine: AsyncConnection, -) -> Connection: - ... +) -> Connection: ... def _get_sync_engine_or_connection( diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index 14c0840d95..c02c64706b 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -347,12 +347,10 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[Unpack[_Ts]]]): return await greenlet_spawn(self._only_one_row, True, False, False) @overload - async def scalar_one(self: AsyncResult[_T]) -> _T: - ... + async def scalar_one(self: AsyncResult[_T]) -> _T: ... @overload - async def scalar_one(self) -> Any: - ... + async def scalar_one(self) -> Any: ... async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -372,12 +370,10 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[Unpack[_Ts]]]): @overload async def scalar_one_or_none( self: AsyncResult[_T], - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload - async def scalar_one_or_none(self) -> Optional[Any]: - ... + async def scalar_one_or_none(self) -> Optional[Any]: ... async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one scalar result or ``None``. @@ -426,12 +422,10 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[Unpack[_Ts]]]): return await greenlet_spawn(self._only_one_row, True, True, False) @overload - async def scalar(self: AsyncResult[_T]) -> Optional[_T]: - ... + async def scalar(self: AsyncResult[_T]) -> Optional[_T]: ... @overload - async def scalar(self) -> Any: - ... + async def scalar(self) -> Any: ... async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result set. @@ -475,18 +469,15 @@ class AsyncResult(_WithKeys, AsyncCommon[Row[Unpack[_Ts]]]): @overload def scalars( self: AsyncResult[_T, Unpack[TupleAny]], index: Literal[0] - ) -> AsyncScalarResult[_T]: - ... + ) -> AsyncScalarResult[_T]: ... @overload def scalars( self: AsyncResult[_T, Unpack[TupleAny]], - ) -> AsyncScalarResult[_T]: - ... + ) -> AsyncScalarResult[_T]: ... @overload - def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: - ... + def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: ... def scalars(self, index: _KeyIndexType = 0) -> AsyncScalarResult[Any]: """Return an :class:`_asyncio.AsyncScalarResult` filtering object which @@ -862,11 +853,9 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): """ ... - async def __aiter__(self) -> AsyncIterator[_R]: - ... + async def __aiter__(self) -> AsyncIterator[_R]: ... - async def __anext__(self) -> _R: - ... + async def __anext__(self) -> _R: ... async def first(self) -> Optional[_R]: """Fetch the first object or ``None`` if no object is present. @@ -900,12 +889,10 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): ... @overload - async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: - ... + async def scalar_one(self: AsyncTupleResult[Tuple[_T]]) -> _T: ... @overload - async def scalar_one(self) -> Any: - ... + async def scalar_one(self) -> Any: ... async def scalar_one(self) -> Any: """Return exactly one scalar result or raise an exception. @@ -925,12 +912,10 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): @overload async def scalar_one_or_none( self: AsyncTupleResult[Tuple[_T]], - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload - async def scalar_one_or_none(self) -> Optional[Any]: - ... + async def scalar_one_or_none(self) -> Optional[Any]: ... async def scalar_one_or_none(self) -> Optional[Any]: """Return exactly one or no scalar result. @@ -948,12 +933,12 @@ class AsyncTupleResult(AsyncCommon[_R], util.TypingOnly): ... @overload - async def scalar(self: AsyncTupleResult[Tuple[_T]]) -> Optional[_T]: - ... + async def scalar( + self: AsyncTupleResult[Tuple[_T]], + ) -> Optional[_T]: ... @overload - async def scalar(self) -> Any: - ... + async def scalar(self) -> Any: ... async def scalar(self) -> Any: """Fetch the first column of the first row, and close the result diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 850b4b750f..8fdb5a7c6d 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -540,8 +540,7 @@ class async_scoped_session(Generic[_AS]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[_Ts]]: - ... + ) -> Result[Unpack[_Ts]]: ... @overload async def execute( @@ -553,8 +552,7 @@ class async_scoped_session(Generic[_AS]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: - ... + ) -> CursorResult[Unpack[TupleAny]]: ... @overload async def execute( @@ -566,8 +564,7 @@ class async_scoped_session(Generic[_AS]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[TupleAny]]: - ... + ) -> Result[Unpack[TupleAny]]: ... async def execute( self, @@ -1019,8 +1016,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload async def scalar( @@ -1031,8 +1027,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... async def scalar( self, @@ -1074,8 +1069,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload async def scalars( @@ -1086,8 +1080,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... async def scalars( self, @@ -1217,8 +1210,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Unpack[_Ts]]: - ... + ) -> AsyncResult[Unpack[_Ts]]: ... @overload async def stream( @@ -1229,8 +1221,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Unpack[TupleAny]]: - ... + ) -> AsyncResult[Unpack[TupleAny]]: ... async def stream( self, @@ -1269,8 +1260,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[_T]: - ... + ) -> AsyncScalarResult[_T]: ... @overload async def stream_scalars( @@ -1281,8 +1271,7 @@ class async_scoped_session(Generic[_AS]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[Any]: - ... + ) -> AsyncScalarResult[Any]: ... async def stream_scalars( self, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index f7a2469868..f8c823cff0 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -402,8 +402,7 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[_Ts]]: - ... + ) -> Result[Unpack[_Ts]]: ... @overload async def execute( @@ -415,8 +414,7 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: - ... + ) -> CursorResult[Unpack[TupleAny]]: ... @overload async def execute( @@ -428,8 +426,7 @@ class AsyncSession(ReversibleProxy[Session]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[TupleAny]]: - ... + ) -> Result[Unpack[TupleAny]]: ... async def execute( self, @@ -475,8 +472,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload async def scalar( @@ -487,8 +483,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... async def scalar( self, @@ -532,8 +527,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload async def scalars( @@ -544,8 +538,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... async def scalars( self, @@ -659,8 +652,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Unpack[_Ts]]: - ... + ) -> AsyncResult[Unpack[_Ts]]: ... @overload async def stream( @@ -671,8 +663,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncResult[Unpack[TupleAny]]: - ... + ) -> AsyncResult[Unpack[TupleAny]]: ... async def stream( self, @@ -714,8 +705,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[_T]: - ... + ) -> AsyncScalarResult[_T]: ... @overload async def stream_scalars( @@ -726,8 +716,7 @@ class AsyncSession(ReversibleProxy[Session]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> AsyncScalarResult[Any]: - ... + ) -> AsyncScalarResult[Any]: ... async def stream_scalars( self, @@ -1690,8 +1679,7 @@ class async_sessionmaker(Generic[_AS]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... @overload def __init__( @@ -1702,8 +1690,7 @@ class async_sessionmaker(Generic[_AS]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... def __init__( self, diff --git a/lib/sqlalchemy/ext/automap.py b/lib/sqlalchemy/ext/automap.py index 37be38ec68..3efb4ddf9c 100644 --- a/lib/sqlalchemy/ext/automap.py +++ b/lib/sqlalchemy/ext/automap.py @@ -715,8 +715,9 @@ _VT = TypeVar("_VT", bound=Any) class PythonNameForTableType(Protocol): - def __call__(self, base: Type[Any], tablename: str, table: Table) -> str: - ... + def __call__( + self, base: Type[Any], tablename: str, table: Table + ) -> str: ... def classname_for_table( @@ -763,8 +764,7 @@ class NameForScalarRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], constraint: ForeignKeyConstraint, - ) -> str: - ... + ) -> str: ... def name_for_scalar_relationship( @@ -804,8 +804,7 @@ class NameForCollectionRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], constraint: ForeignKeyConstraint, - ) -> str: - ... + ) -> str: ... def name_for_collection_relationship( @@ -850,8 +849,7 @@ class GenerateRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> Relationship[Any]: - ... + ) -> Relationship[Any]: ... @overload def __call__( @@ -863,8 +861,7 @@ class GenerateRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> ORMBackrefArgument: - ... + ) -> ORMBackrefArgument: ... def __call__( self, @@ -877,8 +874,7 @@ class GenerateRelationshipType(Protocol): local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, - ) -> Union[ORMBackrefArgument, Relationship[Any]]: - ... + ) -> Union[ORMBackrefArgument, Relationship[Any]]: ... @overload @@ -890,8 +886,7 @@ def generate_relationship( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, -) -> Relationship[Any]: - ... +) -> Relationship[Any]: ... @overload @@ -903,8 +898,7 @@ def generate_relationship( local_cls: Type[Any], referred_cls: Type[Any], **kw: Any, -) -> ORMBackrefArgument: - ... +) -> ORMBackrefArgument: ... def generate_relationship( diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index ad8b3444ad..71fda2fb39 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -86,8 +86,7 @@ class ShardChooser(Protocol): mapper: Optional[Mapper[_T]], instance: Any, clause: Optional[ClauseElement], - ) -> Any: - ... + ) -> Any: ... class IdentityChooser(Protocol): @@ -100,8 +99,7 @@ class IdentityChooser(Protocol): execution_options: OrmExecuteOptionsParameter, bind_arguments: _BindArguments, **kw: Any, - ) -> Any: - ... + ) -> Any: ... class ShardedQuery(Query[_T]): diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index ddb5d4d9f2..de8cec8fdb 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -904,13 +904,11 @@ class HybridExtensionType(InspectionAttrExtensionType): class _HybridGetterType(Protocol[_T_co]): - def __call__(s, self: Any) -> _T_co: - ... + def __call__(s, self: Any) -> _T_co: ... class _HybridSetterType(Protocol[_T_con]): - def __call__(s, self: Any, value: _T_con) -> None: - ... + def __call__(s, self: Any, value: _T_con) -> None: ... class _HybridUpdaterType(Protocol[_T_con]): @@ -918,25 +916,21 @@ class _HybridUpdaterType(Protocol[_T_con]): s, cls: Any, value: Union[_T_con, _ColumnExpressionArgument[_T_con]], - ) -> List[Tuple[_DMLColumnArgument, Any]]: - ... + ) -> List[Tuple[_DMLColumnArgument, Any]]: ... class _HybridDeleterType(Protocol[_T_co]): - def __call__(s, self: Any) -> None: - ... + def __call__(s, self: Any) -> None: ... class _HybridExprCallableType(Protocol[_T_co]): def __call__( s, cls: Any - ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: - ... + ) -> Union[_HasClauseElement[_T_co], SQLColumnExpression[_T_co]]: ... class _HybridComparatorCallableType(Protocol[_T]): - def __call__(self, cls: Any) -> Comparator[_T]: - ... + def __call__(self, cls: Any) -> Comparator[_T]: ... class _HybridClassLevelAccessor(QueryableAttribute[_T]): @@ -947,23 +941,24 @@ class _HybridClassLevelAccessor(QueryableAttribute[_T]): if TYPE_CHECKING: - def getter(self, fget: _HybridGetterType[_T]) -> hybrid_property[_T]: - ... + def getter( + self, fget: _HybridGetterType[_T] + ) -> hybrid_property[_T]: ... - def setter(self, fset: _HybridSetterType[_T]) -> hybrid_property[_T]: - ... + def setter( + self, fset: _HybridSetterType[_T] + ) -> hybrid_property[_T]: ... - def deleter(self, fdel: _HybridDeleterType[_T]) -> hybrid_property[_T]: - ... + def deleter( + self, fdel: _HybridDeleterType[_T] + ) -> hybrid_property[_T]: ... @property - def overrides(self) -> hybrid_property[_T]: - ... + def overrides(self) -> hybrid_property[_T]: ... def update_expression( self, meth: _HybridUpdaterType[_T] - ) -> hybrid_property[_T]: - ... + ) -> hybrid_property[_T]: ... class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): @@ -1025,14 +1020,12 @@ class hybrid_method(interfaces.InspectionAttrInfo, Generic[_P, _R]): @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> Callable[_P, SQLCoreOperations[_R]]: - ... + ) -> Callable[_P, SQLCoreOperations[_R]]: ... @overload def __get__( self, instance: object, owner: Type[object] - ) -> Callable[_P, _R]: - ... + ) -> Callable[_P, _R]: ... def __get__( self, instance: Optional[object], owner: Type[object] @@ -1106,18 +1099,15 @@ class hybrid_property(interfaces.InspectionAttrInfo, ORMDescriptor[_T]): util.update_wrapper(self, fget) @overload - def __get__(self, instance: Any, owner: Literal[None]) -> Self: - ... + def __get__(self, instance: Any, owner: Literal[None]) -> Self: ... @overload def __get__( self, instance: Literal[None], owner: Type[object] - ) -> _HybridClassLevelAccessor[_T]: - ... + ) -> _HybridClassLevelAccessor[_T]: ... @overload - def __get__(self, instance: object, owner: Type[object]) -> _T: - ... + def __get__(self, instance: object, owner: Type[object]) -> _T: ... def __get__( self, instance: Optional[object], owner: Optional[Type[object]] diff --git a/lib/sqlalchemy/ext/instrumentation.py b/lib/sqlalchemy/ext/instrumentation.py index e84dde2687..5f3c71282b 100644 --- a/lib/sqlalchemy/ext/instrumentation.py +++ b/lib/sqlalchemy/ext/instrumentation.py @@ -214,9 +214,9 @@ class ExtendedInstrumentationRegistry(InstrumentationFactory): )(instance) -orm_instrumentation._instrumentation_factory = ( - _instrumentation_factory -) = ExtendedInstrumentationRegistry() +orm_instrumentation._instrumentation_factory = _instrumentation_factory = ( + ExtendedInstrumentationRegistry() +) orm_instrumentation.instrumentation_finders = instrumentation_finders @@ -436,17 +436,15 @@ def _install_lookups(lookups): instance_dict = lookups["instance_dict"] manager_of_class = lookups["manager_of_class"] opt_manager_of_class = lookups["opt_manager_of_class"] - orm_base.instance_state = ( - attributes.instance_state - ) = orm_instrumentation.instance_state = instance_state - orm_base.instance_dict = ( - attributes.instance_dict - ) = orm_instrumentation.instance_dict = instance_dict - orm_base.manager_of_class = ( - attributes.manager_of_class - ) = orm_instrumentation.manager_of_class = manager_of_class - orm_base.opt_manager_of_class = ( - orm_util.opt_manager_of_class - ) = ( + orm_base.instance_state = attributes.instance_state = ( + orm_instrumentation.instance_state + ) = instance_state + orm_base.instance_dict = attributes.instance_dict = ( + orm_instrumentation.instance_dict + ) = instance_dict + orm_base.manager_of_class = attributes.manager_of_class = ( + orm_instrumentation.manager_of_class + ) = manager_of_class + orm_base.opt_manager_of_class = orm_util.opt_manager_of_class = ( attributes.opt_manager_of_class ) = orm_instrumentation.opt_manager_of_class = opt_manager_of_class diff --git a/lib/sqlalchemy/ext/mutable.py b/lib/sqlalchemy/ext/mutable.py index 6f9a7b4503..fc53981c1b 100644 --- a/lib/sqlalchemy/ext/mutable.py +++ b/lib/sqlalchemy/ext/mutable.py @@ -800,15 +800,12 @@ class MutableDict(Mutable, Dict[_KT, _VT]): @overload def setdefault( self: MutableDict[_KT, Optional[_T]], key: _KT, value: None = None - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload - def setdefault(self, key: _KT, value: _VT) -> _VT: - ... + def setdefault(self, key: _KT, value: _VT) -> _VT: ... - def setdefault(self, key: _KT, value: object = None) -> object: - ... + def setdefault(self, key: _KT, value: object = None) -> object: ... else: @@ -829,17 +826,14 @@ class MutableDict(Mutable, Dict[_KT, _VT]): if TYPE_CHECKING: @overload - def pop(self, __key: _KT, /) -> _VT: - ... + def pop(self, __key: _KT, /) -> _VT: ... @overload - def pop(self, __key: _KT, default: _VT | _T, /) -> _VT | _T: - ... + def pop(self, __key: _KT, default: _VT | _T, /) -> _VT | _T: ... def pop( self, __key: _KT, __default: _VT | _T | None = None, / - ) -> _VT | _T: - ... + ) -> _VT | _T: ... else: diff --git a/lib/sqlalchemy/ext/mypy/apply.py b/lib/sqlalchemy/ext/mypy/apply.py index 4185d29b94..eb9019453d 100644 --- a/lib/sqlalchemy/ext/mypy/apply.py +++ b/lib/sqlalchemy/ext/mypy/apply.py @@ -161,9 +161,9 @@ def re_apply_declarative_assignments( # update the SQLAlchemyAttribute with the better # information - mapped_attr_lookup[ - stmt.lvalues[0].name - ].type = python_type_for_type + mapped_attr_lookup[stmt.lvalues[0].name].type = ( + python_type_for_type + ) update_cls_metadata = True @@ -223,9 +223,11 @@ def apply_type_to_mapped_statement( lvalue.is_inferred_def = False left_node.type = api.named_type( NAMED_TYPE_SQLA_MAPPED, - [AnyType(TypeOfAny.special_form)] - if python_type_for_type is None - else [python_type_for_type], + ( + [AnyType(TypeOfAny.special_form)] + if python_type_for_type is None + else [python_type_for_type] + ), ) # so to have it skip the right side totally, we can do this: diff --git a/lib/sqlalchemy/ext/mypy/decl_class.py b/lib/sqlalchemy/ext/mypy/decl_class.py index d7dff91cbd..3d578b346e 100644 --- a/lib/sqlalchemy/ext/mypy/decl_class.py +++ b/lib/sqlalchemy/ext/mypy/decl_class.py @@ -58,9 +58,9 @@ def scan_declarative_assignments_and_apply_types( elif cls.fullname.startswith("builtins"): return None - mapped_attributes: Optional[ - List[util.SQLAlchemyAttribute] - ] = util.get_mapped_attributes(info, api) + mapped_attributes: Optional[List[util.SQLAlchemyAttribute]] = ( + util.get_mapped_attributes(info, api) + ) # used by assign.add_additional_orm_attributes among others util.establish_as_sqlalchemy(info) diff --git a/lib/sqlalchemy/ext/mypy/util.py b/lib/sqlalchemy/ext/mypy/util.py index 10cdb56b05..7f04c481d3 100644 --- a/lib/sqlalchemy/ext/mypy/util.py +++ b/lib/sqlalchemy/ext/mypy/util.py @@ -212,8 +212,7 @@ def add_global( @overload def get_callexpr_kwarg( callexpr: CallExpr, name: str, *, expr_types: None = ... -) -> Optional[Union[CallExpr, NameExpr]]: - ... +) -> Optional[Union[CallExpr, NameExpr]]: ... @overload @@ -222,8 +221,7 @@ def get_callexpr_kwarg( name: str, *, expr_types: Tuple[TypingType[_TArgType], ...], -) -> Optional[_TArgType]: - ... +) -> Optional[_TArgType]: ... def get_callexpr_kwarg( @@ -315,9 +313,11 @@ def unbound_to_instance( return Instance( bound_type, [ - unbound_to_instance(api, arg) - if isinstance(arg, UnboundType) - else arg + ( + unbound_to_instance(api, arg) + if isinstance(arg, UnboundType) + else arg + ) for arg in typ.args ], ) diff --git a/lib/sqlalchemy/inspection.py b/lib/sqlalchemy/inspection.py index 9b499d0387..4842c89ab7 100644 --- a/lib/sqlalchemy/inspection.py +++ b/lib/sqlalchemy/inspection.py @@ -74,8 +74,7 @@ class _InspectableTypeProtocol(Protocol[_TCov]): """ - def _sa_inspect_type(self) -> _TCov: - ... + def _sa_inspect_type(self) -> _TCov: ... class _InspectableProtocol(Protocol[_TCov]): @@ -84,35 +83,31 @@ class _InspectableProtocol(Protocol[_TCov]): """ - def _sa_inspect_instance(self) -> _TCov: - ... + def _sa_inspect_instance(self) -> _TCov: ... @overload def inspect( subject: Type[_InspectableTypeProtocol[_IN]], raiseerr: bool = True -) -> _IN: - ... +) -> _IN: ... @overload -def inspect(subject: _InspectableProtocol[_IN], raiseerr: bool = True) -> _IN: - ... +def inspect( + subject: _InspectableProtocol[_IN], raiseerr: bool = True +) -> _IN: ... @overload -def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: - ... +def inspect(subject: Inspectable[_IN], raiseerr: bool = True) -> _IN: ... @overload -def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: - ... +def inspect(subject: Any, raiseerr: Literal[False] = ...) -> Optional[Any]: ... @overload -def inspect(subject: Any, raiseerr: bool = True) -> Any: - ... +def inspect(subject: Any, raiseerr: bool = True) -> Any: ... def inspect(subject: Any, raiseerr: bool = True) -> Any: diff --git a/lib/sqlalchemy/log.py b/lib/sqlalchemy/log.py index 3f40b562b4..e6bfbadfed 100644 --- a/lib/sqlalchemy/log.py +++ b/lib/sqlalchemy/log.py @@ -264,14 +264,12 @@ class echo_property: @overload def __get__( self, instance: Literal[None], owner: Type[Identified] - ) -> echo_property: - ... + ) -> echo_property: ... @overload def __get__( self, instance: Identified, owner: Type[Identified] - ) -> _EchoFlagType: - ... + ) -> _EchoFlagType: ... def __get__( self, instance: Optional[Identified], owner: Type[Identified] diff --git a/lib/sqlalchemy/orm/_orm_constructors.py b/lib/sqlalchemy/orm/_orm_constructors.py index 3a7f826e1d..f2c4f8ef42 100644 --- a/lib/sqlalchemy/orm/_orm_constructors.py +++ b/lib/sqlalchemy/orm/_orm_constructors.py @@ -387,9 +387,9 @@ def orm_insert_sentinel( return mapped_column( name=name, - default=default - if default is not None - else _InsertSentinelColumnDefault(), + default=( + default if default is not None else _InsertSentinelColumnDefault() + ), _omit_from_statements=omit_from_statements, insert_sentinel=True, use_existing_column=True, @@ -562,8 +562,7 @@ def composite( info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[Any]: - ... +) -> Composite[Any]: ... @overload @@ -585,8 +584,7 @@ def composite( info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[_CC]: - ... +) -> Composite[_CC]: ... @overload @@ -608,8 +606,7 @@ def composite( info: Optional[_InfoType] = None, doc: Optional[str] = None, **__kw: Any, -) -> Composite[_CC]: - ... +) -> Composite[_CC]: ... def composite( @@ -2183,8 +2180,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> AliasedType[_O]: - ... +) -> AliasedType[_O]: ... @overload @@ -2194,8 +2190,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> AliasedClass[_O]: - ... +) -> AliasedClass[_O]: ... @overload @@ -2205,8 +2200,7 @@ def aliased( name: Optional[str] = None, flat: bool = False, adapt_on_names: bool = False, -) -> FromClause: - ... +) -> FromClause: ... def aliased( diff --git a/lib/sqlalchemy/orm/_typing.py b/lib/sqlalchemy/orm/_typing.py index 6c815169c5..95fbd9e7e2 100644 --- a/lib/sqlalchemy/orm/_typing.py +++ b/lib/sqlalchemy/orm/_typing.py @@ -108,13 +108,13 @@ class _ORMAdapterProto(Protocol): """ - def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: - ... + def __call__(self, obj: _CE, key: Optional[str] = None) -> _CE: ... class _LoaderCallable(Protocol): - def __call__(self, state: InstanceState[Any], passive: PassiveFlag) -> Any: - ... + def __call__( + self, state: InstanceState[Any], passive: PassiveFlag + ) -> Any: ... def is_orm_option( @@ -138,39 +138,33 @@ def is_composite_class(obj: Any) -> bool: if TYPE_CHECKING: - def insp_is_mapper_property(obj: Any) -> TypeGuard[MapperProperty[Any]]: - ... + def insp_is_mapper_property( + obj: Any, + ) -> TypeGuard[MapperProperty[Any]]: ... - def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: - ... + def insp_is_mapper(obj: Any) -> TypeGuard[Mapper[Any]]: ... - def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: - ... + def insp_is_aliased_class(obj: Any) -> TypeGuard[AliasedInsp[Any]]: ... def insp_is_attribute( obj: InspectionAttr, - ) -> TypeGuard[QueryableAttribute[Any]]: - ... + ) -> TypeGuard[QueryableAttribute[Any]]: ... def attr_is_internal_proxy( obj: InspectionAttr, - ) -> TypeGuard[QueryableAttribute[Any]]: - ... + ) -> TypeGuard[QueryableAttribute[Any]]: ... def prop_is_relationship( prop: MapperProperty[Any], - ) -> TypeGuard[RelationshipProperty[Any]]: - ... + ) -> TypeGuard[RelationshipProperty[Any]]: ... def is_collection_impl( impl: AttributeImpl, - ) -> TypeGuard[CollectionAttributeImpl]: - ... + ) -> TypeGuard[CollectionAttributeImpl]: ... def is_has_collection_adapter( impl: AttributeImpl, - ) -> TypeGuard[HasCollectionAdapter]: - ... + ) -> TypeGuard[HasCollectionAdapter]: ... else: insp_is_mapper_property = operator.attrgetter("is_property") diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index dc9743b8b3..d9b2d8213d 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -542,12 +542,12 @@ class InstrumentedAttribute(QueryableAttribute[_T]): self.impl.delete(instance_state(instance), instance_dict(instance)) @overload - def __get__(self, instance: None, owner: Any) -> InstrumentedAttribute[_T]: - ... + def __get__( + self, instance: None, owner: Any + ) -> InstrumentedAttribute[_T]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... + def __get__(self, instance: object, owner: Any) -> _T: ... def __get__( self, instance: Optional[object], owner: Any @@ -1538,8 +1538,7 @@ class HasCollectionAdapter: dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -1548,8 +1547,7 @@ class HasCollectionAdapter: dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -1560,8 +1558,7 @@ class HasCollectionAdapter: passive: PassiveFlag = ..., ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: - ... + ]: ... def get_collection( self, @@ -1592,8 +1589,7 @@ if TYPE_CHECKING: def _is_collection_attribute_impl( impl: AttributeImpl, - ) -> TypeGuard[CollectionAttributeImpl]: - ... + ) -> TypeGuard[CollectionAttributeImpl]: ... else: _is_collection_attribute_impl = operator.attrgetter("collection") @@ -2049,8 +2045,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -2059,8 +2054,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -2071,8 +2065,7 @@ class CollectionAttributeImpl(HasCollectionAdapter, AttributeImpl): passive: PassiveFlag = PASSIVE_OFF, ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: - ... + ]: ... def get_collection( self, diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index 50f6703b5e..86af81cd6e 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -308,29 +308,23 @@ def _assertions( if TYPE_CHECKING: - def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: - ... + def manager_of_class(cls: Type[_O]) -> ClassManager[_O]: ... @overload - def opt_manager_of_class(cls: AliasedClass[Any]) -> None: - ... + def opt_manager_of_class(cls: AliasedClass[Any]) -> None: ... @overload def opt_manager_of_class( cls: _ExternalEntityType[_O], - ) -> Optional[ClassManager[_O]]: - ... + ) -> Optional[ClassManager[_O]]: ... def opt_manager_of_class( cls: _ExternalEntityType[_O], - ) -> Optional[ClassManager[_O]]: - ... + ) -> Optional[ClassManager[_O]]: ... - def instance_state(instance: _O) -> InstanceState[_O]: - ... + def instance_state(instance: _O) -> InstanceState[_O]: ... - def instance_dict(instance: object) -> Dict[str, Any]: - ... + def instance_dict(instance: object) -> Dict[str, Any]: ... else: # these can be replaced by sqlalchemy.ext.instrumentation @@ -512,8 +506,7 @@ def _entity_descriptor(entity: _EntityType[Any], key: str) -> Any: if TYPE_CHECKING: - def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: - ... + def _state_mapper(state: InstanceState[_O]) -> Mapper[_O]: ... else: _state_mapper = util.dottedgetter("manager.mapper") @@ -684,27 +677,25 @@ class SQLORMOperations(SQLCoreOperations[_T_co], TypingOnly): if typing.TYPE_CHECKING: - def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: - ... + def of_type( + self, class_: _EntityType[Any] + ) -> PropComparator[_T_co]: ... def and_( self, *criteria: _ColumnExpressionArgument[bool] - ) -> PropComparator[bool]: - ... + ) -> PropComparator[bool]: ... def any( # noqa: A001 self, criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def has( self, criterion: Optional[_ColumnExpressionArgument[bool]] = None, **kwargs: Any, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... class ORMDescriptor(Generic[_T_co], TypingOnly): @@ -718,23 +709,19 @@ class ORMDescriptor(Generic[_T_co], TypingOnly): @overload def __get__( self, instance: Any, owner: Literal[None] - ) -> ORMDescriptor[_T_co]: - ... + ) -> ORMDescriptor[_T_co]: ... @overload def __get__( self, instance: Literal[None], owner: Any - ) -> SQLCoreOperations[_T_co]: - ... + ) -> SQLCoreOperations[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T_co: - ... + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: object, owner: Any - ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: - ... + ) -> Union[ORMDescriptor[_T_co], SQLCoreOperations[_T_co], _T_co]: ... class _MappedAnnotationBase(Generic[_T_co], TypingOnly): @@ -820,29 +807,23 @@ class Mapped( @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: - ... + ) -> InstrumentedAttribute[_T_co]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T_co: - ... + def __get__(self, instance: object, owner: Any) -> _T_co: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], _T_co]: - ... + ) -> Union[InstrumentedAttribute[_T_co], _T_co]: ... @classmethod - def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: - ... + def _empty_constructor(cls, arg1: Any) -> Mapped[_T_co]: ... def __set__( self, instance: Any, value: Union[SQLCoreOperations[_T_co], _T_co] - ) -> None: - ... + ) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... class _MappedAttribute(Generic[_T_co], TypingOnly): @@ -919,24 +900,20 @@ class DynamicMapped(_MappedAnnotationBase[_T_co]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: - ... + ) -> InstrumentedAttribute[_T_co]: ... @overload def __get__( self, instance: object, owner: Any - ) -> AppenderQuery[_T_co]: - ... + ) -> AppenderQuery[_T_co]: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: - ... + ) -> Union[InstrumentedAttribute[_T_co], AppenderQuery[_T_co]]: ... def __set__( self, instance: Any, value: typing.Collection[_T_co] - ) -> None: - ... + ) -> None: ... class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): @@ -975,21 +952,19 @@ class WriteOnlyMapped(_MappedAnnotationBase[_T_co]): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T_co]: - ... + ) -> InstrumentedAttribute[_T_co]: ... @overload def __get__( self, instance: object, owner: Any - ) -> WriteOnlyCollection[_T_co]: - ... + ) -> WriteOnlyCollection[_T_co]: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co]]: - ... + ) -> Union[ + InstrumentedAttribute[_T_co], WriteOnlyCollection[_T_co] + ]: ... def __set__( self, instance: Any, value: typing.Collection[_T_co] - ) -> None: - ... + ) -> None: ... diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index c2ef0980e6..d59570bc20 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -83,8 +83,7 @@ def _bulk_insert( render_nulls: bool, use_orm_insert_stmt: Literal[None] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., -) -> None: - ... +) -> None: ... @overload @@ -97,8 +96,7 @@ def _bulk_insert( render_nulls: bool, use_orm_insert_stmt: Optional[dml.Insert] = ..., execution_options: Optional[OrmExecuteOptionsParameter] = ..., -) -> cursor.CursorResult[Any]: - ... +) -> cursor.CursorResult[Any]: ... def _bulk_insert( @@ -238,8 +236,7 @@ def _bulk_update( update_changed_only: bool, use_orm_update_stmt: Literal[None] = ..., enable_check_rowcount: bool = True, -) -> None: - ... +) -> None: ... @overload @@ -251,8 +248,7 @@ def _bulk_update( update_changed_only: bool, use_orm_update_stmt: Optional[dml.Update] = ..., enable_check_rowcount: bool = True, -) -> _result.Result[Unpack[TupleAny]]: - ... +) -> _result.Result[Unpack[TupleAny]]: ... def _bulk_update( @@ -379,14 +375,16 @@ class ORMDMLState(AbstractORMCompileState): if desc is NO_VALUE: yield ( coercions.expect(roles.DMLColumnRole, k), - coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, - ) - if needs_to_be_cacheable - else v, + ( + coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) + if needs_to_be_cacheable + else v + ), ) else: yield from core_get_crud_kv_pairs( @@ -407,13 +405,15 @@ class ORMDMLState(AbstractORMCompileState): else: yield ( k, - v - if not needs_to_be_cacheable - else coercions.expect( - roles.ExpressionElementRole, - v, - type_=sqltypes.NullType(), - is_crud=True, + ( + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=sqltypes.NullType(), + is_crud=True, + ) ), ) @@ -530,9 +530,9 @@ class ORMDMLState(AbstractORMCompileState): fs = fs.execution_options(**orm_level_statement._execution_options) fs = fs.options(*orm_level_statement._with_options) self.select_statement = fs - self.from_statement_ctx = ( - fsc - ) = ORMFromStatementCompileState.create_for_statement(fs, compiler) + self.from_statement_ctx = fsc = ( + ORMFromStatementCompileState.create_for_statement(fs, compiler) + ) fsc.setup_dml_returning_compile_state(dml_mapper) dml_level_statement = dml_level_statement._generate() diff --git a/lib/sqlalchemy/orm/clsregistry.py b/lib/sqlalchemy/orm/clsregistry.py index 2cce129cbf..26113d8b24 100644 --- a/lib/sqlalchemy/orm/clsregistry.py +++ b/lib/sqlalchemy/orm/clsregistry.py @@ -83,9 +83,9 @@ def add_class( _ModuleMarker, decl_class_registry["_sa_module_registry"] ) except KeyError: - decl_class_registry[ - "_sa_module_registry" - ] = root_module = _ModuleMarker("_sa_module_registry", None) + decl_class_registry["_sa_module_registry"] = root_module = ( + _ModuleMarker("_sa_module_registry", None) + ) tokens = cls.__module__.split(".") @@ -542,9 +542,7 @@ class _class_resolver: _fallback_dict: Mapping[str, Any] = None # type: ignore -def _resolver( - cls: Type[Any], prop: RelationshipProperty[Any] -) -> Tuple[ +def _resolver(cls: Type[Any], prop: RelationshipProperty[Any]) -> Tuple[ Callable[[str], Callable[[], Union[Type[Any], Table, _ModNS]]], Callable[[str, bool], _class_resolver], ]: diff --git a/lib/sqlalchemy/orm/collections.py b/lib/sqlalchemy/orm/collections.py index 6e5ded17af..eeef7241c8 100644 --- a/lib/sqlalchemy/orm/collections.py +++ b/lib/sqlalchemy/orm/collections.py @@ -167,8 +167,7 @@ _FN = TypeVar("_FN", bound="Callable[..., Any]") class _CollectionConverterProtocol(Protocol): - def __call__(self, collection: _COL) -> _COL: - ... + def __call__(self, collection: _COL) -> _COL: ... class _AdaptedCollectionProtocol(Protocol): @@ -548,9 +547,9 @@ class CollectionAdapter: self.empty ), "This collection adapter is not in the 'empty' state" self.empty = False - self.owner_state.dict[ - self._key - ] = self.owner_state._empty_collections.pop(self._key) + self.owner_state.dict[self._key] = ( + self.owner_state._empty_collections.pop(self._key) + ) def _refuse_empty(self) -> NoReturn: raise sa_exc.InvalidRequestError( @@ -1554,14 +1553,14 @@ class InstrumentedDict(Dict[_KT, _VT]): """An instrumented version of the built-in dict.""" -__canned_instrumentation: util.immutabledict[ - Any, _CollectionFactoryType -] = util.immutabledict( - { - list: InstrumentedList, - set: InstrumentedSet, - dict: InstrumentedDict, - } +__canned_instrumentation: util.immutabledict[Any, _CollectionFactoryType] = ( + util.immutabledict( + { + list: InstrumentedList, + set: InstrumentedSet, + dict: InstrumentedDict, + } + ) ) __interfaces: util.immutabledict[ diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index b51f2b9613..dba3435a26 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -441,8 +441,7 @@ class ORMCompileState(AbstractORMCompileState): statement: Union[Select, FromStatement], compiler: Optional[SQLCompiler], **kw: Any, - ) -> ORMCompileState: - ... + ) -> ORMCompileState: ... def _append_dedupe_col_collection(self, obj, col_collection): dedupe = self.dedupe_columns @@ -526,14 +525,14 @@ class ORMCompileState(AbstractORMCompileState): and len(statement._compile_options._current_path) > 10 and execution_options.get("compiled_cache", True) is not None ): - execution_options: util.immutabledict[ - str, Any - ] = execution_options.union( - { - "compiled_cache": None, - "_cache_disable_reason": "excess depth for " - "ORM loader options", - } + execution_options: util.immutabledict[str, Any] = ( + execution_options.union( + { + "compiled_cache": None, + "_cache_disable_reason": "excess depth for " + "ORM loader options", + } + ) ) bind_arguments["clause"] = statement @@ -759,9 +758,11 @@ class ORMFromStatementCompileState(ORMCompileState): self.statement = statement self._label_convention = self._column_naming_convention( - statement._label_style - if not statement._is_textual and not statement.is_dml - else LABEL_STYLE_NONE, + ( + statement._label_style + if not statement._is_textual and not statement.is_dml + else LABEL_STYLE_NONE + ), self.use_legacy_query_style, ) @@ -807,9 +808,9 @@ class ORMFromStatementCompileState(ORMCompileState): for entity in self._entities: entity.setup_compile_state(self) - compiler._ordered_columns = ( - compiler._textual_ordered_columns - ) = False + compiler._ordered_columns = compiler._textual_ordered_columns = ( + False + ) # enable looser result column matching. this is shown to be # needed by test_query.py::TextTest @@ -1376,11 +1377,15 @@ class ORMSelectCompileState(ORMCompileState, SelectState): def get_columns_clause_froms(cls, statement): return cls._normalize_froms( itertools.chain.from_iterable( - element._from_objects - if "parententity" not in element._annotations - else [ - element._annotations["parententity"].__clause_element__() - ] + ( + element._from_objects + if "parententity" not in element._annotations + else [ + element._annotations[ + "parententity" + ].__clause_element__() + ] + ) for element in statement._raw_columns ) ) @@ -1509,9 +1514,11 @@ class ORMSelectCompileState(ORMCompileState, SelectState): # the original expressions outside of the label references # in order to have them render. unwrapped_order_by = [ - elem.element - if isinstance(elem, sql.elements._label_reference) - else elem + ( + elem.element + if isinstance(elem, sql.elements._label_reference) + else elem + ) for elem in self.order_by ] @@ -2430,9 +2437,12 @@ def _column_descriptions( "type": ent.type, "aliased": getattr(insp_ent, "is_aliased_class", False), "expr": ent.expr, - "entity": getattr(insp_ent, "entity", None) - if ent.entity_zero is not None and not insp_ent.is_clause_element - else None, + "entity": ( + getattr(insp_ent, "entity", None) + if ent.entity_zero is not None + and not insp_ent.is_clause_element + else None + ), } for ent, insp_ent in [ (_ent, _ent.entity_zero) for _ent in ctx._entities diff --git a/lib/sqlalchemy/orm/decl_api.py b/lib/sqlalchemy/orm/decl_api.py index 8aa1edc46b..72dded0e09 100644 --- a/lib/sqlalchemy/orm/decl_api.py +++ b/lib/sqlalchemy/orm/decl_api.py @@ -313,17 +313,13 @@ class _declared_directive(_declared_attr_common, Generic[_T]): self, fn: Callable[..., _T], cascading: bool = False, - ): - ... + ): ... - def __get__(self, instance: Optional[object], owner: Any) -> _T: - ... + def __get__(self, instance: Optional[object], owner: Any) -> _T: ... - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... def __call__(self, fn: Callable[..., _TT]) -> _declared_directive[_TT]: # extensive fooling of mypy underway... @@ -428,14 +424,11 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): self, fn: _DeclaredAttrDecorated[_T], cascading: bool = False, - ): - ... + ): ... - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... # this is the Mapped[] API where at class descriptor get time we want # the type checker to see InstrumentedAttribute[_T]. However the @@ -444,17 +437,14 @@ class declared_attr(interfaces._MappedAttribute[_T], _declared_attr_common): @overload def __get__( self, instance: None, owner: Any - ) -> InstrumentedAttribute[_T]: - ... + ) -> InstrumentedAttribute[_T]: ... @overload - def __get__(self, instance: object, owner: Any) -> _T: - ... + def __get__(self, instance: object, owner: Any) -> _T: ... def __get__( self, instance: Optional[object], owner: Any - ) -> Union[InstrumentedAttribute[_T], _T]: - ... + ) -> Union[InstrumentedAttribute[_T], _T]: ... @hybridmethod def _stateful(cls, **kw: Any) -> _stateful_declared_attr[_T]: @@ -620,9 +610,9 @@ class MappedAsDataclass(metaclass=DCTransformDeclarative): for k, v in apply_dc_transforms.items() } else: - cls._sa_apply_dc_transforms = ( - current_transforms - ) = apply_dc_transforms + cls._sa_apply_dc_transforms = current_transforms = ( + apply_dc_transforms + ) super().__init_subclass__(**kw) @@ -753,11 +743,9 @@ class DeclarativeBase( if typing.TYPE_CHECKING: - def _sa_inspect_type(self) -> Mapper[Self]: - ... + def _sa_inspect_type(self) -> Mapper[Self]: ... - def _sa_inspect_instance(self) -> InstanceState[Self]: - ... + def _sa_inspect_instance(self) -> InstanceState[Self]: ... _sa_registry: ClassVar[_RegistryType] @@ -838,8 +826,7 @@ class DeclarativeBase( """ - def __init__(self, **kw: Any): - ... + def __init__(self, **kw: Any): ... def __init_subclass__(cls, **kw: Any) -> None: if DeclarativeBase in cls.__bases__: @@ -924,11 +911,9 @@ class DeclarativeBaseNoMeta( if typing.TYPE_CHECKING: - def _sa_inspect_type(self) -> Mapper[Self]: - ... + def _sa_inspect_type(self) -> Mapper[Self]: ... - def _sa_inspect_instance(self) -> InstanceState[Self]: - ... + def _sa_inspect_instance(self) -> InstanceState[Self]: ... __tablename__: Any """String name to assign to the generated @@ -963,8 +948,7 @@ class DeclarativeBaseNoMeta( """ - def __init__(self, **kw: Any): - ... + def __init__(self, **kw: Any): ... def __init_subclass__(cls, **kw: Any) -> None: if DeclarativeBaseNoMeta in cls.__bases__: @@ -1585,8 +1569,7 @@ class registry: ), ) @overload - def mapped_as_dataclass(self, __cls: Type[_O], /) -> Type[_O]: - ... + def mapped_as_dataclass(self, __cls: Type[_O], /) -> Type[_O]: ... @overload def mapped_as_dataclass( @@ -1602,8 +1585,7 @@ class registry: match_args: Union[_NoArg, bool] = ..., kw_only: Union[_NoArg, bool] = ..., dataclass_callable: Union[_NoArg, Callable[..., Type[Any]]] = ..., - ) -> Callable[[Type[_O]], Type[_O]]: - ... + ) -> Callable[[Type[_O]], Type[_O]]: ... def mapped_as_dataclass( self, diff --git a/lib/sqlalchemy/orm/decl_base.py b/lib/sqlalchemy/orm/decl_base.py index 6acdb58d46..0513eac66a 100644 --- a/lib/sqlalchemy/orm/decl_base.py +++ b/lib/sqlalchemy/orm/decl_base.py @@ -98,8 +98,7 @@ class MappedClassProtocol(Protocol[_O]): __mapper__: Mapper[_O] __table__: FromClause - def __call__(self, **kw: Any) -> _O: - ... + def __call__(self, **kw: Any) -> _O: ... class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): @@ -111,11 +110,9 @@ class _DeclMappedClassProtocol(MappedClassProtocol[_O], Protocol): _sa_apply_dc_transforms: Optional[_DataclassArguments] - def __declare_first__(self) -> None: - ... + def __declare_first__(self) -> None: ... - def __declare_last__(self) -> None: - ... + def __declare_last__(self) -> None: ... class _DataclassArguments(TypedDict): @@ -908,9 +905,9 @@ class _ClassScanMapperConfig(_MapperConfig): "@declared_attr.cascading; " "skipping" % (name, cls) ) - collected_attributes[name] = column_copies[ - obj - ] = ret = obj.__get__(obj, cls) + collected_attributes[name] = column_copies[obj] = ( + ret + ) = obj.__get__(obj, cls) setattr(cls, name, ret) else: if is_dataclass_field: @@ -947,9 +944,9 @@ class _ClassScanMapperConfig(_MapperConfig): ): ret = ret.descriptor - collected_attributes[name] = column_copies[ - obj - ] = ret + collected_attributes[name] = column_copies[obj] = ( + ret + ) if ( isinstance(ret, (Column, MapperProperty)) diff --git a/lib/sqlalchemy/orm/dependency.py b/lib/sqlalchemy/orm/dependency.py index 9bdd92428e..71c06fbeb1 100644 --- a/lib/sqlalchemy/orm/dependency.py +++ b/lib/sqlalchemy/orm/dependency.py @@ -167,9 +167,11 @@ class DependencyProcessor: sum_ = state.manager[self.key].impl.get_all_pending( state, state.dict, - self._passive_delete_flag - if isdelete - else attributes.PASSIVE_NO_INITIALIZE, + ( + self._passive_delete_flag + if isdelete + else attributes.PASSIVE_NO_INITIALIZE + ), ) if not sum_: diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 4d5775ee2d..d82a33d0a3 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -422,13 +422,13 @@ class CompositeProperty( and self.composite_class not in _composite_getters ): if self._generated_composite_accessor is not None: - _composite_getters[ - self.composite_class - ] = self._generated_composite_accessor + _composite_getters[self.composite_class] = ( + self._generated_composite_accessor + ) elif hasattr(self.composite_class, "__composite_values__"): - _composite_getters[ - self.composite_class - ] = lambda obj: obj.__composite_values__() + _composite_getters[self.composite_class] = ( + lambda obj: obj.__composite_values__() + ) @util.preload_module("sqlalchemy.orm.properties") @util.preload_module("sqlalchemy.orm.decl_base") diff --git a/lib/sqlalchemy/orm/dynamic.py b/lib/sqlalchemy/orm/dynamic.py index d5db03a19d..7496e5c30d 100644 --- a/lib/sqlalchemy/orm/dynamic.py +++ b/lib/sqlalchemy/orm/dynamic.py @@ -172,8 +172,7 @@ class AppenderMixin(AbstractCollectionWriter[_T]): if TYPE_CHECKING: - def __iter__(self) -> Iterator[_T]: - ... + def __iter__(self) -> Iterator[_T]: ... def __getitem__(self, index: Any) -> Union[_T, List[_T]]: sess = self.session diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index 828dad2b6f..0dbb62c167 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -729,9 +729,9 @@ class _EventsHold(event.RefCollection[_ET]): class _InstanceEventsHold(_EventsHold[_ET]): - all_holds: weakref.WeakKeyDictionary[ - Any, Any - ] = weakref.WeakKeyDictionary() + all_holds: weakref.WeakKeyDictionary[Any, Any] = ( + weakref.WeakKeyDictionary() + ) def resolve(self, class_: Type[_O]) -> Optional[ClassManager[_O]]: return instrumentation.opt_manager_of_class(class_) diff --git a/lib/sqlalchemy/orm/instrumentation.py b/lib/sqlalchemy/orm/instrumentation.py index 97d92c00ba..1452596beb 100644 --- a/lib/sqlalchemy/orm/instrumentation.py +++ b/lib/sqlalchemy/orm/instrumentation.py @@ -85,13 +85,11 @@ class _ExpiredAttributeLoaderProto(Protocol): state: state.InstanceState[Any], toload: Set[str], passive: base.PassiveFlag, - ) -> None: - ... + ) -> None: ... class _ManagerFactory(Protocol): - def __call__(self, class_: Type[_O]) -> ClassManager[_O]: - ... + def __call__(self, class_: Type[_O]) -> ClassManager[_O]: ... class ClassManager( diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index dd9e558cd3..64de1f4027 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -738,6 +738,7 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): :attr:`.TypeEngine.comparator_factory` """ + __slots__ = "prop", "_parententity", "_adapt_to_entity" __visit_name__ = "orm_prop_comparator" @@ -841,13 +842,11 @@ class PropComparator(SQLORMOperations[_T_co], Generic[_T_co], ColumnOperators): def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... def of_type(self, class_: _EntityType[Any]) -> PropComparator[_T_co]: r"""Redefine this object in terms of a polymorphic subclass, diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index b430cbf424..50258149af 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -185,20 +185,22 @@ def instances( return go unique_filters = [ - _no_unique - if context.yield_per - else _not_hashable( - ent.column.type, # type: ignore - legacy=context.load_options._legacy_uniquing, - uncertain=ent._null_column_type, - ) - if ( - not ent.use_id_for_hash - and (ent._non_hashable_value or ent._null_column_type) + ( + _no_unique + if context.yield_per + else ( + _not_hashable( + ent.column.type, # type: ignore + legacy=context.load_options._legacy_uniquing, + uncertain=ent._null_column_type, + ) + if ( + not ent.use_id_for_hash + and (ent._non_hashable_value or ent._null_column_type) + ) + else id if ent.use_id_for_hash else None + ) ) - else id - if ent.use_id_for_hash - else None for ent in context.compile_state._entities ] diff --git a/lib/sqlalchemy/orm/mapped_collection.py b/lib/sqlalchemy/orm/mapped_collection.py index 24ac0cc1b9..13c6b689e1 100644 --- a/lib/sqlalchemy/orm/mapped_collection.py +++ b/lib/sqlalchemy/orm/mapped_collection.py @@ -117,9 +117,7 @@ class _SerializableColumnGetterV2(_PlainColumnGetter[_KT]): return self.__class__, (self.colkeys,) @classmethod - def _reduce_from_cols( - cls, cols: Sequence[ColumnElement[_KT]] - ) -> Tuple[ + def _reduce_from_cols(cls, cols: Sequence[ColumnElement[_KT]]) -> Tuple[ Type[_SerializableColumnGetterV2[_KT]], Tuple[Sequence[Tuple[Optional[str], Optional[str]]]], ]: diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index e91b1a6bd0..e51ff7df4e 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -134,9 +134,9 @@ _WithPolymorphicArg = Union[ ] -_mapper_registries: weakref.WeakKeyDictionary[ - _RegistryType, bool -] = weakref.WeakKeyDictionary() +_mapper_registries: weakref.WeakKeyDictionary[_RegistryType, bool] = ( + weakref.WeakKeyDictionary() +) def _all_registries() -> Set[registry]: @@ -1608,9 +1608,11 @@ class Mapper( if self._primary_key_argument: coerced_pk_arg = [ - self._str_arg_to_mapped_col("primary_key", c) - if isinstance(c, str) - else c + ( + self._str_arg_to_mapped_col("primary_key", c) + if isinstance(c, str) + else c + ) for c in ( coercions.expect( roles.DDLConstraintColumnRole, @@ -2467,9 +2469,11 @@ class Mapper( return "Mapper[%s%s(%s)]" % ( self.class_.__name__, self.non_primary and " (non-primary)" or "", - self.local_table.description - if self.local_table is not None - else self.persist_selectable.description, + ( + self.local_table.description + if self.local_table is not None + else self.persist_selectable.description + ), ) def _is_orphan(self, state: InstanceState[_O]) -> bool: diff --git a/lib/sqlalchemy/orm/path_registry.py b/lib/sqlalchemy/orm/path_registry.py index c97afe7e61..76484b3e68 100644 --- a/lib/sqlalchemy/orm/path_registry.py +++ b/lib/sqlalchemy/orm/path_registry.py @@ -45,11 +45,9 @@ if TYPE_CHECKING: from ..util.typing import _LiteralStar from ..util.typing import TypeGuard - def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: - ... + def is_root(path: PathRegistry) -> TypeGuard[RootRegistry]: ... - def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: - ... + def is_entity(path: PathRegistry) -> TypeGuard[AbstractEntityRegistry]: ... else: is_root = operator.attrgetter("is_root") @@ -185,26 +183,21 @@ class PathRegistry(HasCacheKey): return id(self) @overload - def __getitem__(self, entity: _StrPathToken) -> TokenRegistry: - ... + def __getitem__(self, entity: _StrPathToken) -> TokenRegistry: ... @overload - def __getitem__(self, entity: int) -> _PathElementType: - ... + def __getitem__(self, entity: int) -> _PathElementType: ... @overload - def __getitem__(self, entity: slice) -> _PathRepresentation: - ... + def __getitem__(self, entity: slice) -> _PathRepresentation: ... @overload def __getitem__( self, entity: _InternalEntityType[Any] - ) -> AbstractEntityRegistry: - ... + ) -> AbstractEntityRegistry: ... @overload - def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: - ... + def __getitem__(self, entity: MapperProperty[Any]) -> PropRegistry: ... def __getitem__( self, @@ -320,13 +313,11 @@ class PathRegistry(HasCacheKey): @overload @classmethod - def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: - ... + def per_mapper(cls, mapper: Mapper[Any]) -> CachingEntityRegistry: ... @overload @classmethod - def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: - ... + def per_mapper(cls, mapper: AliasedInsp[Any]) -> SlotsEntityRegistry: ... @classmethod def per_mapper( @@ -808,11 +799,9 @@ if TYPE_CHECKING: def path_is_entity( path: PathRegistry, - ) -> TypeGuard[AbstractEntityRegistry]: - ... + ) -> TypeGuard[AbstractEntityRegistry]: ... - def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: - ... + def path_is_property(path: PathRegistry) -> TypeGuard[PropRegistry]: ... else: path_is_entity = operator.attrgetter("is_entity") diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 0c2529d5d1..abe69bf468 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -140,11 +140,13 @@ def post_update(base_mapper, states, uowtransaction, post_update_cols): state_dict, sub_mapper, connection, - mapper._get_committed_state_attr_by_column( - state, state_dict, mapper.version_id_col - ) - if mapper.version_id_col is not None - else None, + ( + mapper._get_committed_state_attr_by_column( + state, state_dict, mapper.version_id_col + ) + if mapper.version_id_col is not None + else None + ), ) for state, state_dict, sub_mapper, connection in states_to_update if table in sub_mapper._pks_by_table @@ -703,10 +705,10 @@ def _collect_delete_commands( params = {} for col in mapper._pks_by_table[table]: - params[ - col.key - ] = value = mapper._get_committed_state_attr_by_column( - state, state_dict, col + params[col.key] = value = ( + mapper._get_committed_state_attr_by_column( + state, state_dict, col + ) ) if value is None: raise orm_exc.FlushError( @@ -934,9 +936,11 @@ def _emit_update_statements( c.context.compiled_parameters[0], value_params, True, - c.returned_defaults - if not c.context.executemany - else None, + ( + c.returned_defaults + if not c.context.executemany + else None + ), ) if check_rowcount: @@ -1069,9 +1073,11 @@ def _emit_insert_statements( last_inserted_params, value_params, False, - result.returned_defaults - if not result.context.executemany - else None, + ( + result.returned_defaults + if not result.context.executemany + else None + ), ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) @@ -1261,9 +1267,11 @@ def _emit_insert_statements( result.context.compiled_parameters[0], value_params, False, - result.returned_defaults - if not result.context.executemany - else None, + ( + result.returned_defaults + if not result.context.executemany + else None + ), ) else: _postfetch_bulk_save(mapper_rec, state_dict, table) diff --git a/lib/sqlalchemy/orm/properties.py b/lib/sqlalchemy/orm/properties.py index 6e2e73dc46..7a5eb8625b 100644 --- a/lib/sqlalchemy/orm/properties.py +++ b/lib/sqlalchemy/orm/properties.py @@ -429,8 +429,7 @@ class ColumnProperty( if TYPE_CHECKING: - def __clause_element__(self) -> NamedColumn[_PT]: - ... + def __clause_element__(self) -> NamedColumn[_PT]: ... def _memoized_method___clause_element__( self, @@ -636,9 +635,11 @@ class MappedColumn( return [ ( self.column, - self._sort_order - if self._sort_order is not _NoArg.NO_ARG - else 0, + ( + self._sort_order + if self._sort_order is not _NoArg.NO_ARG + else 0 + ), ) ] diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 4aaae3ee4f..b1a01f00a1 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -170,7 +170,6 @@ class Query( Executable, Generic[_T], ): - """ORM-level SQL construction object. .. legacy:: The ORM :class:`.Query` object is a legacy construct @@ -209,9 +208,9 @@ class Query( _memoized_select_entities = () - _compile_options: Union[ - Type[CacheableOptions], CacheableOptions - ] = ORMCompileState.default_compile_options + _compile_options: Union[Type[CacheableOptions], CacheableOptions] = ( + ORMCompileState.default_compile_options + ) _with_options: Tuple[ExecutableOption, ...] load_options = QueryContext.default_load_options + { @@ -748,18 +747,15 @@ class Query( @overload def as_scalar( self: Query[Tuple[_MAYBE_ENTITY]], - ) -> ScalarSelect[_MAYBE_ENTITY]: - ... + ) -> ScalarSelect[_MAYBE_ENTITY]: ... @overload def as_scalar( self: Query[Tuple[_NOT_ENTITY]], - ) -> ScalarSelect[_NOT_ENTITY]: - ... + ) -> ScalarSelect[_NOT_ENTITY]: ... @overload - def as_scalar(self) -> ScalarSelect[Any]: - ... + def as_scalar(self) -> ScalarSelect[Any]: ... @util.deprecated( "1.4", @@ -777,18 +773,15 @@ class Query( @overload def scalar_subquery( self: Query[Tuple[_MAYBE_ENTITY]], - ) -> ScalarSelect[Any]: - ... + ) -> ScalarSelect[Any]: ... @overload def scalar_subquery( self: Query[Tuple[_NOT_ENTITY]], - ) -> ScalarSelect[_NOT_ENTITY]: - ... + ) -> ScalarSelect[_NOT_ENTITY]: ... @overload - def scalar_subquery(self) -> ScalarSelect[Any]: - ... + def scalar_subquery(self) -> ScalarSelect[Any]: ... def scalar_subquery(self) -> ScalarSelect[Any]: """Return the full SELECT statement represented by this @@ -836,14 +829,12 @@ class Query( @overload def only_return_tuples( self: Query[_O], value: Literal[True] - ) -> RowReturningQuery[_O]: - ... + ) -> RowReturningQuery[_O]: ... @overload def only_return_tuples( self: Query[_O], value: Literal[False] - ) -> Query[_O]: - ... + ) -> Query[_O]: ... @_generative def only_return_tuples(self, value: bool) -> Query[Any]: @@ -1489,15 +1480,13 @@ class Query( return None @overload - def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: - ... + def with_entities(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def with_entities( self, _colexpr: roles.TypedColumnsClauseRole[_T], - ) -> RowReturningQuery[Tuple[_T]]: - ... + ) -> RowReturningQuery[Tuple[_T]]: ... # START OVERLOADED FUNCTIONS self.with_entities RowReturningQuery 2-8 @@ -1507,14 +1496,12 @@ class Query( @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], / - ) -> RowReturningQuery[_T0, _T1]: - ... + ) -> RowReturningQuery[_T0, _T1]: ... @overload def with_entities( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], / - ) -> RowReturningQuery[_T0, _T1, _T2]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2]: ... @overload def with_entities( @@ -1524,8 +1511,7 @@ class Query( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: ... @overload def with_entities( @@ -1536,8 +1522,7 @@ class Query( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: ... @overload def with_entities( @@ -1549,8 +1534,7 @@ class Query( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload def with_entities( @@ -1563,8 +1547,7 @@ class Query( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload def with_entities( @@ -1581,16 +1564,14 @@ class Query( *entities: _ColumnsClauseArgument[Any], ) -> RowReturningQuery[ _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny] - ]: - ... + ]: ... # END OVERLOADED FUNCTIONS self.with_entities @overload def with_entities( self, *entities: _ColumnsClauseArgument[Any] - ) -> Query[Any]: - ... + ) -> Query[Any]: ... @_generative def with_entities( @@ -1752,12 +1733,10 @@ class Query( populate_existing: bool = False, autoflush: bool = False, **opt: Any, - ) -> Self: - ... + ) -> Self: ... @overload - def execution_options(self, **opt: Any) -> Self: - ... + def execution_options(self, **opt: Any) -> Self: ... @_generative def execution_options(self, **kwargs: Any) -> Self: diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index 58b413bed9..a054eb96a6 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -1346,9 +1346,11 @@ class RelationshipProperty( state, dict_, column, - passive=PassiveFlag.PASSIVE_OFF - if state.persistent - else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK, + passive=( + PassiveFlag.PASSIVE_OFF + if state.persistent + else PassiveFlag.PASSIVE_NO_FETCH ^ PassiveFlag.INIT_OK + ), ) if current_value is LoaderCallableStatus.NEVER_SET: @@ -2039,9 +2041,11 @@ class RelationshipProperty( "the single_parent=True flag." % { "rel": self, - "direction": "many-to-one" - if self.direction is MANYTOONE - else "many-to-many", + "direction": ( + "many-to-one" + if self.direction is MANYTOONE + else "many-to-many" + ), "clsname": self.parent.class_.__name__, "relatedcls": self.mapper.class_.__name__, }, @@ -3105,9 +3109,9 @@ class JoinCondition: def _setup_pairs(self) -> None: sync_pairs: _MutableColumnPairs = [] - lrp: util.OrderedSet[ - Tuple[ColumnElement[Any], ColumnElement[Any]] - ] = util.OrderedSet([]) + lrp: util.OrderedSet[Tuple[ColumnElement[Any], ColumnElement[Any]]] = ( + util.OrderedSet([]) + ) secondary_sync_pairs: _MutableColumnPairs = [] def go( @@ -3184,9 +3188,9 @@ class JoinCondition: # level configuration that benefits from this warning. if to_ not in self._track_overlapping_sync_targets: - self._track_overlapping_sync_targets[ - to_ - ] = weakref.WeakKeyDictionary({self.prop: from_}) + self._track_overlapping_sync_targets[to_] = ( + weakref.WeakKeyDictionary({self.prop: from_}) + ) else: other_props = [] prop_to_from = self._track_overlapping_sync_targets[to_] @@ -3419,9 +3423,7 @@ class JoinCondition: dest_selectable, ) - def create_lazy_clause( - self, reverse_direction: bool = False - ) -> Tuple[ + def create_lazy_clause(self, reverse_direction: bool = False) -> Tuple[ ColumnElement[bool], Dict[str, ColumnElement[Any]], Dict[ColumnElement[Any], ColumnElement[Any]], diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 2e87f41879..ca8fdc95e5 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -91,8 +91,7 @@ class QueryPropertyDescriptor(Protocol): """ - def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: - ... + def __get__(self, instance: Any, owner: Type[_T]) -> Query[_T]: ... _O = TypeVar("_O", bound=object) @@ -687,8 +686,7 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[_Ts]]: - ... + ) -> Result[Unpack[_Ts]]: ... @overload def execute( @@ -700,8 +698,7 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: - ... + ) -> CursorResult[Unpack[TupleAny]]: ... @overload def execute( @@ -713,8 +710,7 @@ class scoped_session(Generic[_S]): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[TupleAny]]: - ... + ) -> Result[Unpack[TupleAny]]: ... def execute( self, @@ -1579,14 +1575,12 @@ class scoped_session(Generic[_S]): return self._proxied.merge(instance, load=load, options=options) @overload - def query(self, _entity: _EntityType[_O]) -> Query[_O]: - ... + def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[_T]: - ... + ) -> RowReturningQuery[_T]: ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -1596,14 +1590,12 @@ class scoped_session(Generic[_S]): @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], / - ) -> RowReturningQuery[_T0, _T1]: - ... + ) -> RowReturningQuery[_T0, _T1]: ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], / - ) -> RowReturningQuery[_T0, _T1, _T2]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2]: ... @overload def query( @@ -1613,8 +1605,7 @@ class scoped_session(Generic[_S]): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: ... @overload def query( @@ -1625,8 +1616,7 @@ class scoped_session(Generic[_S]): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: ... @overload def query( @@ -1638,8 +1628,7 @@ class scoped_session(Generic[_S]): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload def query( @@ -1652,8 +1641,7 @@ class scoped_session(Generic[_S]): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload def query( @@ -1670,16 +1658,14 @@ class scoped_session(Generic[_S]): *entities: _ColumnsClauseArgument[Any], ) -> RowReturningQuery[ _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny] - ]: - ... + ]: ... # END OVERLOADED FUNCTIONS self.query @overload def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any - ) -> Query[Any]: - ... + ) -> Query[Any]: ... def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any @@ -1831,8 +1817,7 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload def scalar( @@ -1843,8 +1828,7 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... def scalar( self, @@ -1886,8 +1870,7 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload def scalars( @@ -1898,8 +1881,7 @@ class scoped_session(Generic[_S]): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... def scalars( self, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 4315ac7f30..61006ccf0a 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -151,9 +151,9 @@ __all__ = [ "object_session", ] -_sessions: weakref.WeakValueDictionary[ - int, Session -] = weakref.WeakValueDictionary() +_sessions: weakref.WeakValueDictionary[int, Session] = ( + weakref.WeakValueDictionary() +) """Weak-referencing dictionary of :class:`.Session` objects. """ @@ -193,8 +193,7 @@ class _ConnectionCallableProto(Protocol): mapper: Optional[Mapper[Any]] = None, instance: Optional[object] = None, **kw: Any, - ) -> Connection: - ... + ) -> Connection: ... def _state_session(state: InstanceState[Any]) -> Optional[Session]: @@ -1005,9 +1004,11 @@ class SessionTransaction(_StateChange, TransactionalContext): def _begin(self, nested: bool = False) -> SessionTransaction: return SessionTransaction( self.session, - SessionTransactionOrigin.BEGIN_NESTED - if nested - else SessionTransactionOrigin.SUBTRANSACTION, + ( + SessionTransactionOrigin.BEGIN_NESTED + if nested + else SessionTransactionOrigin.SUBTRANSACTION + ), self, ) @@ -1824,9 +1825,11 @@ class Session(_SessionClassMethods, EventTarget): ) trans = SessionTransaction( self, - SessionTransactionOrigin.BEGIN - if begin - else SessionTransactionOrigin.AUTOBEGIN, + ( + SessionTransactionOrigin.BEGIN + if begin + else SessionTransactionOrigin.AUTOBEGIN + ), ) assert self._transaction is trans return trans @@ -2062,8 +2065,7 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: Literal[True] = ..., - ) -> Any: - ... + ) -> Any: ... @overload def _execute_internal( @@ -2076,8 +2078,7 @@ class Session(_SessionClassMethods, EventTarget): _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, _scalar_result: bool = ..., - ) -> Result[Unpack[TupleAny]]: - ... + ) -> Result[Unpack[TupleAny]]: ... def _execute_internal( self, @@ -2194,15 +2195,15 @@ class Session(_SessionClassMethods, EventTarget): ) if compile_state_cls: - result: Result[ - Unpack[TupleAny] - ] = compile_state_cls.orm_execute_statement( - self, - statement, - params or {}, - execution_options, - bind_arguments, - conn, + result: Result[Unpack[TupleAny]] = ( + compile_state_cls.orm_execute_statement( + self, + statement, + params or {}, + execution_options, + bind_arguments, + conn, + ) ) else: result = conn.execute( @@ -2224,8 +2225,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[_Ts]]: - ... + ) -> Result[Unpack[_Ts]]: ... @overload def execute( @@ -2237,8 +2237,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> CursorResult[Unpack[TupleAny]]: - ... + ) -> CursorResult[Unpack[TupleAny]]: ... @overload def execute( @@ -2250,8 +2249,7 @@ class Session(_SessionClassMethods, EventTarget): bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, - ) -> Result[Unpack[TupleAny]]: - ... + ) -> Result[Unpack[TupleAny]]: ... def execute( self, @@ -2332,8 +2330,7 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Optional[_T]: - ... + ) -> Optional[_T]: ... @overload def scalar( @@ -2344,8 +2341,7 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> Any: - ... + ) -> Any: ... def scalar( self, @@ -2382,8 +2378,7 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[_T]: - ... + ) -> ScalarResult[_T]: ... @overload def scalars( @@ -2394,8 +2389,7 @@ class Session(_SessionClassMethods, EventTarget): execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, - ) -> ScalarResult[Any]: - ... + ) -> ScalarResult[Any]: ... def scalars( self, @@ -2804,14 +2798,12 @@ class Session(_SessionClassMethods, EventTarget): ) @overload - def query(self, _entity: _EntityType[_O]) -> Query[_O]: - ... + def query(self, _entity: _EntityType[_O]) -> Query[_O]: ... @overload def query( self, _colexpr: TypedColumnsClauseRole[_T] - ) -> RowReturningQuery[_T]: - ... + ) -> RowReturningQuery[_T]: ... # START OVERLOADED FUNCTIONS self.query RowReturningQuery 2-8 @@ -2821,14 +2813,12 @@ class Session(_SessionClassMethods, EventTarget): @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], / - ) -> RowReturningQuery[_T0, _T1]: - ... + ) -> RowReturningQuery[_T0, _T1]: ... @overload def query( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], / - ) -> RowReturningQuery[_T0, _T1, _T2]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2]: ... @overload def query( @@ -2838,8 +2828,7 @@ class Session(_SessionClassMethods, EventTarget): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3]: ... @overload def query( @@ -2850,8 +2839,7 @@ class Session(_SessionClassMethods, EventTarget): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4]: ... @overload def query( @@ -2863,8 +2851,7 @@ class Session(_SessionClassMethods, EventTarget): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload def query( @@ -2877,8 +2864,7 @@ class Session(_SessionClassMethods, EventTarget): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], /, - ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... + ) -> RowReturningQuery[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload def query( @@ -2895,16 +2881,14 @@ class Session(_SessionClassMethods, EventTarget): *entities: _ColumnsClauseArgument[Any], ) -> RowReturningQuery[ _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny] - ]: - ... + ]: ... # END OVERLOADED FUNCTIONS self.query @overload def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any - ) -> Query[Any]: - ... + ) -> Query[Any]: ... def query( self, *entities: _ColumnsClauseArgument[Any], **kwargs: Any @@ -3785,9 +3769,9 @@ class Session(_SessionClassMethods, EventTarget): if correct_keys: primary_key_identity = dict(primary_key_identity) for k in correct_keys: - primary_key_identity[ - pk_synonyms[k] - ] = primary_key_identity[k] + primary_key_identity[pk_synonyms[k]] = ( + primary_key_identity[k] + ) try: primary_key_identity = list( @@ -5005,8 +4989,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... @overload def __init__( @@ -5017,8 +5000,7 @@ class sessionmaker(_SessionClassMethods, Generic[_S]): expire_on_commit: bool = ..., info: Optional[_InfoType] = ..., **kw: Any, - ): - ... + ): ... def __init__( self, diff --git a/lib/sqlalchemy/orm/state.py b/lib/sqlalchemy/orm/state.py index 234a028a15..3c1a28e906 100644 --- a/lib/sqlalchemy/orm/state.py +++ b/lib/sqlalchemy/orm/state.py @@ -80,8 +80,7 @@ if not TYPE_CHECKING: class _InstanceDictProto(Protocol): - def __call__(self) -> Optional[IdentityMap]: - ... + def __call__(self) -> Optional[IdentityMap]: ... class _InstallLoaderCallableProto(Protocol[_O]): @@ -99,8 +98,7 @@ class _InstallLoaderCallableProto(Protocol[_O]): state: InstanceState[_O], dict_: _InstanceDict, row: Row[Unpack[TupleAny]], - ) -> None: - ... + ) -> None: ... @inspection._self_inspects diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index d7671e0794..e38a05f061 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -1195,9 +1195,11 @@ class LazyLoader( key, self, loadopt, - loadopt._generate_extra_criteria(context) - if loadopt._extra_criteria - else None, + ( + loadopt._generate_extra_criteria(context) + if loadopt._extra_criteria + else None + ), ), key, ) @@ -1672,9 +1674,11 @@ class SubqueryLoader(PostLoader): elif ltj > 2: middle = [ ( - orm_util.AliasedClass(item[0]) - if not inspect(item[0]).is_aliased_class - else item[0].entity, + ( + orm_util.AliasedClass(item[0]) + if not inspect(item[0]).is_aliased_class + else item[0].entity + ), item[1], ) for item in to_join[1:-1] @@ -2328,9 +2332,11 @@ class JoinedLoader(AbstractRelationshipLoader): to_adapt = orm_util.AliasedClass( self.mapper, - alias=alt_selectable._anonymous_fromclause(flat=True) - if alt_selectable is not None - else None, + alias=( + alt_selectable._anonymous_fromclause(flat=True) + if alt_selectable is not None + else None + ), flat=True, use_mapper_path=True, ) diff --git a/lib/sqlalchemy/orm/strategy_options.py b/lib/sqlalchemy/orm/strategy_options.py index d6f676e99e..bdf6802f99 100644 --- a/lib/sqlalchemy/orm/strategy_options.py +++ b/lib/sqlalchemy/orm/strategy_options.py @@ -320,9 +320,11 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): loader = self._set_relationship_strategy( attr, {"lazy": "joined"}, - opts={"innerjoin": innerjoin} - if innerjoin is not None - else util.EMPTY_DICT, + opts=( + {"innerjoin": innerjoin} + if innerjoin is not None + else util.EMPTY_DICT + ), ) return loader @@ -777,12 +779,10 @@ class _AbstractLoad(traversals.GenerativeOnTraversal, LoaderOption): return self @overload - def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: - ... + def _coerce_strat(self, strategy: _StrategySpec) -> _StrategyKey: ... @overload - def _coerce_strat(self, strategy: Literal[None]) -> None: - ... + def _coerce_strat(self, strategy: Literal[None]) -> None: ... def _coerce_strat( self, strategy: Optional[_StrategySpec] @@ -2081,9 +2081,9 @@ class _AttributeStrategyLoad(_LoadElement): d["_extra_criteria"] = () if self._path_with_polymorphic_path: - d[ - "_path_with_polymorphic_path" - ] = self._path_with_polymorphic_path.serialize() + d["_path_with_polymorphic_path"] = ( + self._path_with_polymorphic_path.serialize() + ) if self._of_type: if self._of_type.is_aliased_class: diff --git a/lib/sqlalchemy/orm/util.py b/lib/sqlalchemy/orm/util.py index 4309cb119e..370d3cad20 100644 --- a/lib/sqlalchemy/orm/util.py +++ b/lib/sqlalchemy/orm/util.py @@ -165,8 +165,7 @@ class _DeStringifyAnnotation(Protocol): *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, include_generic: bool = False, - ) -> Type[Any]: - ... + ) -> Type[Any]: ... de_stringify_annotation = cast( @@ -182,8 +181,7 @@ class _DeStringifyUnionElements(Protocol): originating_module: str, *, str_cleanup_fn: Optional[Callable[[str, str], str]] = None, - ) -> Type[Any]: - ... + ) -> Type[Any]: ... de_stringify_union_elements = cast( @@ -193,8 +191,7 @@ de_stringify_union_elements = cast( class _EvalNameOnly(Protocol): - def __call__(self, name: str, module_name: str) -> Any: - ... + def __call__(self, name: str, module_name: str) -> Any: ... eval_name_only = cast(_EvalNameOnly, _de_stringify_partial(_eval_name_only)) @@ -757,12 +754,16 @@ class AliasedClass( insp, alias, name, - with_polymorphic_mappers - if with_polymorphic_mappers - else mapper.with_polymorphic_mappers, - with_polymorphic_discriminator - if with_polymorphic_discriminator is not None - else mapper.polymorphic_on, + ( + with_polymorphic_mappers + if with_polymorphic_mappers + else mapper.with_polymorphic_mappers + ), + ( + with_polymorphic_discriminator + if with_polymorphic_discriminator is not None + else mapper.polymorphic_on + ), base_alias, use_mapper_path, adapt_on_names, @@ -973,9 +974,9 @@ class AliasedInsp( self._weak_entity = weakref.ref(entity) self.mapper = mapper - self.selectable = ( - self.persist_selectable - ) = self.local_table = selectable + self.selectable = self.persist_selectable = self.local_table = ( + selectable + ) self.name = name self.polymorphic_on = polymorphic_on self._base_alias = weakref.ref(_base_alias or self) @@ -1231,8 +1232,7 @@ class AliasedInsp( self, obj: _CE, key: Optional[str] = None, - ) -> _CE: - ... + ) -> _CE: ... else: _orm_adapt_element = _adapt_element diff --git a/lib/sqlalchemy/orm/writeonly.py b/lib/sqlalchemy/orm/writeonly.py index 3764a6bb5c..6e5756d42d 100644 --- a/lib/sqlalchemy/orm/writeonly.py +++ b/lib/sqlalchemy/orm/writeonly.py @@ -196,8 +196,7 @@ class WriteOnlyAttributeImpl( dict_: _InstanceDict, user_data: Literal[None] = ..., passive: Literal[PassiveFlag.PASSIVE_OFF] = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -206,8 +205,7 @@ class WriteOnlyAttributeImpl( dict_: _InstanceDict, user_data: _AdaptedCollectionProtocol = ..., passive: PassiveFlag = ..., - ) -> CollectionAdapter: - ... + ) -> CollectionAdapter: ... @overload def get_collection( @@ -218,8 +216,7 @@ class WriteOnlyAttributeImpl( passive: PassiveFlag = ..., ) -> Union[ Literal[LoaderCallableStatus.PASSIVE_NO_RESULT], CollectionAdapter - ]: - ... + ]: ... def get_collection( self, diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index 7818825de3..24bdc25d32 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -147,17 +147,14 @@ class _AsyncConnDialect(_ConnDialect): class _CreatorFnType(Protocol): - def __call__(self) -> DBAPIConnection: - ... + def __call__(self) -> DBAPIConnection: ... class _CreatorWRecFnType(Protocol): - def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: - ... + def __call__(self, rec: ConnectionPoolEntry) -> DBAPIConnection: ... class Pool(log.Identified, event.EventTarget): - """Abstract base class for connection pools.""" dispatch: dispatcher[Pool] @@ -633,7 +630,6 @@ class ConnectionPoolEntry(ManagesConnection): class _ConnectionRecord(ConnectionPoolEntry): - """Maintains a position in a connection pool which references a pooled connection. @@ -729,11 +725,13 @@ class _ConnectionRecord(ConnectionPoolEntry): rec.fairy_ref = ref = weakref.ref( fairy, - lambda ref: _finalize_fairy( - None, rec, pool, ref, echo, transaction_was_reset=False - ) - if _finalize_fairy is not None - else None, + lambda ref: ( + _finalize_fairy( + None, rec, pool, ref, echo, transaction_was_reset=False + ) + if _finalize_fairy is not None + else None + ), ) _strong_ref_connection_records[ref] = rec if echo: @@ -1074,14 +1072,11 @@ class PoolProxiedConnection(ManagesConnection): if typing.TYPE_CHECKING: - def commit(self) -> None: - ... + def commit(self) -> None: ... - def cursor(self) -> DBAPICursor: - ... + def cursor(self) -> DBAPICursor: ... - def rollback(self) -> None: - ... + def rollback(self) -> None: ... @property def is_valid(self) -> bool: @@ -1189,7 +1184,6 @@ class _AdhocProxiedConnection(PoolProxiedConnection): class _ConnectionFairy(PoolProxiedConnection): - """Proxies a DBAPI connection and provides return-on-dereference support. diff --git a/lib/sqlalchemy/pool/impl.py b/lib/sqlalchemy/pool/impl.py index fed0bfc8f0..e2bb81bf0d 100644 --- a/lib/sqlalchemy/pool/impl.py +++ b/lib/sqlalchemy/pool/impl.py @@ -43,7 +43,6 @@ if typing.TYPE_CHECKING: class QueuePool(Pool): - """A :class:`_pool.Pool` that imposes a limit on the number of open connections. @@ -55,9 +54,9 @@ class QueuePool(Pool): _is_asyncio = False # type: ignore[assignment] - _queue_class: Type[ - sqla_queue.QueueCommon[ConnectionPoolEntry] - ] = sqla_queue.Queue + _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( + sqla_queue.Queue + ) _pool: sqla_queue.QueueCommon[ConnectionPoolEntry] @@ -250,15 +249,14 @@ class QueuePool(Pool): class AsyncAdaptedQueuePool(QueuePool): _is_asyncio = True # type: ignore[assignment] - _queue_class: Type[ - sqla_queue.QueueCommon[ConnectionPoolEntry] - ] = sqla_queue.AsyncAdaptedQueue + _queue_class: Type[sqla_queue.QueueCommon[ConnectionPoolEntry]] = ( + sqla_queue.AsyncAdaptedQueue + ) _dialect = _AsyncConnDialect() class NullPool(Pool): - """A Pool which does not pool connections. Instead it literally opens and closes the underlying DB-API connection @@ -298,7 +296,6 @@ class NullPool(Pool): class SingletonThreadPool(Pool): - """A Pool that maintains one connection per thread. Maintains one connection per each thread, never moving a connection to a @@ -418,7 +415,6 @@ class SingletonThreadPool(Pool): class StaticPool(Pool): - """A Pool of exactly one connection, used for all requests. Reconnect-related functions such as ``recycle`` and connection @@ -482,7 +478,6 @@ class StaticPool(Pool): class AssertionPool(Pool): - """A :class:`_pool.Pool` that allows at most one checked out connection at any given time. diff --git a/lib/sqlalchemy/sql/_elements_constructors.py b/lib/sqlalchemy/sql/_elements_constructors.py index 9dd2a58a1b..27bac59e12 100644 --- a/lib/sqlalchemy/sql/_elements_constructors.py +++ b/lib/sqlalchemy/sql/_elements_constructors.py @@ -437,13 +437,11 @@ def outparam( @overload -def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: - ... +def not_(clause: BinaryExpression[_T]) -> BinaryExpression[_T]: ... @overload -def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: - ... +def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: ... def not_(clause: _ColumnExpressionArgument[_T]) -> ColumnElement[_T]: diff --git a/lib/sqlalchemy/sql/_selectable_constructors.py b/lib/sqlalchemy/sql/_selectable_constructors.py index 736b4961ec..1737597172 100644 --- a/lib/sqlalchemy/sql/_selectable_constructors.py +++ b/lib/sqlalchemy/sql/_selectable_constructors.py @@ -332,20 +332,17 @@ def outerjoin( @overload -def select(__ent0: _TCCA[_T0], /) -> Select[_T0]: - ... +def select(__ent0: _TCCA[_T0], /) -> Select[_T0]: ... @overload -def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1], /) -> Select[_T0, _T1]: - ... +def select(__ent0: _TCCA[_T0], __ent1: _TCCA[_T1], /) -> Select[_T0, _T1]: ... @overload def select( __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], / -) -> Select[_T0, _T1, _T2]: - ... +) -> Select[_T0, _T1, _T2]: ... @overload @@ -355,8 +352,7 @@ def select( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], /, -) -> Select[_T0, _T1, _T2, _T3]: - ... +) -> Select[_T0, _T1, _T2, _T3]: ... @overload @@ -367,8 +363,7 @@ def select( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], /, -) -> Select[_T0, _T1, _T2, _T3, _T4]: - ... +) -> Select[_T0, _T1, _T2, _T3, _T4]: ... @overload @@ -380,8 +375,7 @@ def select( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], /, -) -> Select[_T0, _T1, _T2, _T3, _T4, _T5]: - ... +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload @@ -394,8 +388,7 @@ def select( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], /, -) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload @@ -409,8 +402,7 @@ def select( __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], /, -) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: - ... +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... @overload @@ -425,8 +417,7 @@ def select( __ent7: _TCCA[_T7], __ent8: _TCCA[_T8], /, -) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]: - ... +) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8]: ... @overload @@ -445,8 +436,7 @@ def select( *entities: _ColumnsClauseArgument[Any], ) -> Select[ _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, _T8, _T9, Unpack[TupleAny] -]: - ... +]: ... # END OVERLOADED FUNCTIONS select @@ -455,8 +445,7 @@ def select( @overload def select( *entities: _ColumnsClauseArgument[Any], **__kw: Any -) -> Select[Unpack[TupleAny]]: - ... +) -> Select[Unpack[TupleAny]]: ... def select( diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 689ed19a9f..2b50f2bdab 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -84,15 +84,13 @@ _CLE = TypeVar("_CLE", bound="ClauseElement") class _HasClauseElement(Protocol, Generic[_T_co]): """indicates a class that has a __clause_element__() method""" - def __clause_element__(self) -> roles.ExpressionElementRole[_T_co]: - ... + def __clause_element__(self) -> roles.ExpressionElementRole[_T_co]: ... class _CoreAdapterProto(Protocol): """protocol for the ClauseAdapter/ColumnAdapter.traverse() method.""" - def __call__(self, obj: _CE) -> _CE: - ... + def __call__(self, obj: _CE) -> _CE: ... # match column types that are not ORM entities @@ -289,56 +287,47 @@ _AutoIncrementType = Union[bool, Literal["auto", "ignore_fk"]] if TYPE_CHECKING: - def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: - ... + def is_sql_compiler(c: Compiled) -> TypeGuard[SQLCompiler]: ... - def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: - ... + def is_ddl_compiler(c: Compiled) -> TypeGuard[DDLCompiler]: ... - def is_named_from_clause(t: FromClauseRole) -> TypeGuard[NamedFromClause]: - ... + def is_named_from_clause( + t: FromClauseRole, + ) -> TypeGuard[NamedFromClause]: ... - def is_column_element(c: ClauseElement) -> TypeGuard[ColumnElement[Any]]: - ... + def is_column_element( + c: ClauseElement, + ) -> TypeGuard[ColumnElement[Any]]: ... def is_keyed_column_element( c: ClauseElement, - ) -> TypeGuard[KeyedColumnElement[Any]]: - ... + ) -> TypeGuard[KeyedColumnElement[Any]]: ... - def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: - ... + def is_text_clause(c: ClauseElement) -> TypeGuard[TextClause]: ... - def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: - ... + def is_from_clause(c: ClauseElement) -> TypeGuard[FromClause]: ... - def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: - ... + def is_tuple_type(t: TypeEngine[Any]) -> TypeGuard[TupleType]: ... - def is_table_value_type(t: TypeEngine[Any]) -> TypeGuard[TableValueType]: - ... + def is_table_value_type( + t: TypeEngine[Any], + ) -> TypeGuard[TableValueType]: ... - def is_selectable(t: Any) -> TypeGuard[Selectable]: - ... + def is_selectable(t: Any) -> TypeGuard[Selectable]: ... def is_select_base( t: Union[Executable, ReturnsRows] - ) -> TypeGuard[SelectBase]: - ... + ) -> TypeGuard[SelectBase]: ... def is_select_statement( t: Union[Executable, ReturnsRows] - ) -> TypeGuard[Select[Unpack[TupleAny]]]: - ... + ) -> TypeGuard[Select[Unpack[TupleAny]]]: ... - def is_table(t: FromClause) -> TypeGuard[TableClause]: - ... + def is_table(t: FromClause) -> TypeGuard[TableClause]: ... - def is_subquery(t: FromClause) -> TypeGuard[Subquery]: - ... + def is_subquery(t: FromClause) -> TypeGuard[Subquery]: ... - def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: - ... + def is_dml(c: ClauseElement) -> TypeGuard[UpdateBase]: ... else: is_sql_compiler = operator.attrgetter("is_sql") @@ -389,20 +378,17 @@ def _unexpected_kw(methname: str, kw: Dict[str, Any]) -> NoReturn: @overload def Nullable( val: "SQLCoreOperations[_T]", -) -> "SQLCoreOperations[Optional[_T]]": - ... +) -> "SQLCoreOperations[Optional[_T]]": ... @overload def Nullable( val: roles.ExpressionElementRole[_T], -) -> roles.ExpressionElementRole[Optional[_T]]: - ... +) -> roles.ExpressionElementRole[Optional[_T]]: ... @overload -def Nullable(val: Type[_T]) -> Type[Optional[_T]]: - ... +def Nullable(val: Type[_T]) -> Type[Optional[_T]]: ... def Nullable( @@ -426,25 +412,21 @@ def Nullable( @overload def NotNullable( val: "SQLCoreOperations[Optional[_T]]", -) -> "SQLCoreOperations[_T]": - ... +) -> "SQLCoreOperations[_T]": ... @overload def NotNullable( val: roles.ExpressionElementRole[Optional[_T]], -) -> roles.ExpressionElementRole[_T]: - ... +) -> roles.ExpressionElementRole[_T]: ... @overload -def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: - ... +def NotNullable(val: Type[Optional[_T]]) -> Type[_T]: ... @overload -def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: - ... +def NotNullable(val: Optional[Type[_T]]) -> Type[_T]: ... def NotNullable( diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 14e48bd2b8..db382b874b 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -67,16 +67,14 @@ class SupportsAnnotations(ExternallyTraversible): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: - ... + ) -> SupportsAnnotations: ... def _deannotate( self, @@ -99,9 +97,11 @@ class SupportsAnnotations(ExternallyTraversible): tuple( ( key, - value._gen_cache_key(anon_map, []) - if isinstance(value, HasCacheKey) - else value, + ( + value._gen_cache_key(anon_map, []) + if isinstance(value, HasCacheKey) + else value + ), ) for key, value in [ (key, self._annotations[key]) @@ -119,8 +119,7 @@ class SupportsWrappingAnnotations(SupportsAnnotations): if TYPE_CHECKING: @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... + def entity_namespace(self) -> _EntityNamespace: ... def _annotate(self, values: _AnnotationDict) -> Self: """return a copy of this ClauseElement with annotations @@ -141,16 +140,14 @@ class SupportsWrappingAnnotations(SupportsAnnotations): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: - ... + ) -> SupportsAnnotations: ... def _deannotate( self, @@ -214,16 +211,14 @@ class SupportsCloneAnnotations(SupportsWrappingAnnotations): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> SupportsAnnotations: - ... + ) -> SupportsAnnotations: ... def _deannotate( self, @@ -316,16 +311,14 @@ class Annotated(SupportsAnnotations): self, values: Literal[None] = ..., clone: bool = ..., - ) -> Self: - ... + ) -> Self: ... @overload def _deannotate( self, values: Sequence[str] = ..., clone: bool = ..., - ) -> Annotated: - ... + ) -> Annotated: ... def _deannotate( self, @@ -395,9 +388,9 @@ class Annotated(SupportsAnnotations): # so that the resulting objects are pickleable; additionally, other # decisions can be made up front about the type of object being annotated # just once per class rather than per-instance. -annotated_classes: Dict[ - Type[SupportsWrappingAnnotations], Type[Annotated] -] = {} +annotated_classes: Dict[Type[SupportsWrappingAnnotations], Type[Annotated]] = ( + {} +) _SA = TypeVar("_SA", bound="SupportsAnnotations") @@ -487,15 +480,13 @@ def _deep_annotate( @overload def _deep_deannotate( element: Literal[None], values: Optional[Sequence[str]] = None -) -> Literal[None]: - ... +) -> Literal[None]: ... @overload def _deep_deannotate( element: _SA, values: Optional[Sequence[str]] = None -) -> _SA: - ... +) -> _SA: ... def _deep_deannotate( diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index ee5583a74b..798a35eed4 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -154,14 +154,12 @@ _never_select_column = operator.attrgetter("_omit_from_statements") class _EntityNamespace(Protocol): - def __getattr__(self, key: str) -> SQLCoreOperations[Any]: - ... + def __getattr__(self, key: str) -> SQLCoreOperations[Any]: ... class _HasEntityNamespace(Protocol): @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... + def entity_namespace(self) -> _EntityNamespace: ... def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: @@ -260,8 +258,7 @@ _SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType") class _GenerativeType(Protocol): - def _generate(self) -> Self: - ... + def _generate(self) -> Self: ... def _generative(fn: _Fn) -> _Fn: @@ -800,14 +797,11 @@ class _MetaOptions(type): if TYPE_CHECKING: - def __getattr__(self, key: str) -> Any: - ... + def __getattr__(self, key: str) -> Any: ... - def __setattr__(self, key: str, value: Any) -> None: - ... + def __setattr__(self, key: str, value: Any) -> None: ... - def __delattr__(self, key: str) -> None: - ... + def __delattr__(self, key: str) -> None: ... class Options(metaclass=_MetaOptions): @@ -965,14 +959,11 @@ class Options(metaclass=_MetaOptions): if TYPE_CHECKING: - def __getattr__(self, key: str) -> Any: - ... + def __getattr__(self, key: str) -> Any: ... - def __setattr__(self, key: str, value: Any) -> None: - ... + def __setattr__(self, key: str, value: Any) -> None: ... - def __delattr__(self, key: str) -> None: - ... + def __delattr__(self, key: str) -> None: ... class CacheableOptions(Options, HasCacheKey): @@ -1057,24 +1048,21 @@ class Executable(roles.StatementRole): **kw: Any, ) -> Tuple[ Compiled, Optional[Sequence[BindParameter[Any]]], CacheStats - ]: - ... + ]: ... def _execute_on_connection( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> CursorResult[Any]: - ... + ) -> CursorResult[Any]: ... def _execute_on_scalar( self, connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: CoreExecuteOptionsParameter, - ) -> Any: - ... + ) -> Any: ... @util.ro_non_memoized_property def _all_selected_columns(self): @@ -1179,12 +1167,10 @@ class Executable(roles.StatementRole): is_delete_using: bool = ..., is_update_from: bool = ..., **opt: Any, - ) -> Self: - ... + ) -> Self: ... @overload - def execution_options(self, **opt: Any) -> Self: - ... + def execution_options(self, **opt: Any) -> Self: ... @_generative def execution_options(self, **kw: Any) -> Self: @@ -1590,20 +1576,17 @@ class ColumnCollection(Generic[_COLKEY, _COL_co]): return iter([col for _, col, _ in self._collection]) @overload - def __getitem__(self, key: Union[str, int]) -> _COL_co: - ... + def __getitem__(self, key: Union[str, int]) -> _COL_co: ... @overload def __getitem__( self, key: Tuple[Union[str, int], ...] - ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: - ... + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... @overload def __getitem__( self, key: slice - ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: - ... + ) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: ... def __getitem__( self, key: Union[str, int, slice, Tuple[Union[str, int], ...]] diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 0435be7462..ba8a5403e7 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -44,8 +44,7 @@ if typing.TYPE_CHECKING: class _CacheKeyTraversalDispatchType(Protocol): def __call__( s, self: HasCacheKey, visitor: _CacheKeyTraversal - ) -> _CacheKeyTraversalDispatchTypeReturn: - ... + ) -> _CacheKeyTraversalDispatchTypeReturn: ... class CacheConst(enum.Enum): @@ -303,11 +302,13 @@ class HasCacheKey: result += ( attrname, obj["compile_state_plugin"], - obj["plugin_subject"]._gen_cache_key( - anon_map, bindparams - ) - if obj["plugin_subject"] - else None, + ( + obj["plugin_subject"]._gen_cache_key( + anon_map, bindparams + ) + if obj["plugin_subject"] + else None + ), ) elif meth is InternalTraversal.dp_annotations_key: # obj is here is the _annotations dict. Table uses @@ -619,9 +620,9 @@ class _CacheKeyTraversal(HasTraversalDispatch): InternalTraversal.dp_memoized_select_entities ) - visit_string = ( - visit_boolean - ) = visit_operator = visit_plain_obj = CACHE_IN_PLACE + visit_string = visit_boolean = visit_operator = visit_plain_obj = ( + CACHE_IN_PLACE + ) visit_statement_hint_list = CACHE_IN_PLACE visit_type = STATIC_CACHE_KEY visit_anon_name = ANON_NAME @@ -668,9 +669,11 @@ class _CacheKeyTraversal(HasTraversalDispatch): ) -> Tuple[Any, ...]: return ( attrname, - obj._gen_cache_key(anon_map, bindparams) - if isinstance(obj, HasCacheKey) - else obj, + ( + obj._gen_cache_key(anon_map, bindparams) + if isinstance(obj, HasCacheKey) + else obj + ), ) def visit_multi_list( @@ -684,9 +687,11 @@ class _CacheKeyTraversal(HasTraversalDispatch): return ( attrname, tuple( - elem._gen_cache_key(anon_map, bindparams) - if isinstance(elem, HasCacheKey) - else elem + ( + elem._gen_cache_key(anon_map, bindparams) + if isinstance(elem, HasCacheKey) + else elem + ) for elem in obj ), ) @@ -847,12 +852,16 @@ class _CacheKeyTraversal(HasTraversalDispatch): return tuple( ( target._gen_cache_key(anon_map, bindparams), - onclause._gen_cache_key(anon_map, bindparams) - if onclause is not None - else None, - from_._gen_cache_key(anon_map, bindparams) - if from_ is not None - else None, + ( + onclause._gen_cache_key(anon_map, bindparams) + if onclause is not None + else None + ), + ( + from_._gen_cache_key(anon_map, bindparams) + if from_ is not None + else None + ), tuple([(key, flags[key]) for key in sorted(flags)]), ) for (target, onclause, from_, flags) in obj @@ -946,9 +955,11 @@ class _CacheKeyTraversal(HasTraversalDispatch): tuple( ( key, - value._gen_cache_key(anon_map, bindparams) - if isinstance(value, HasCacheKey) - else value, + ( + value._gen_cache_key(anon_map, bindparams) + if isinstance(value, HasCacheKey) + else value + ), ) for key, value in [(key, obj[key]) for key in sorted(obj)] ), @@ -994,9 +1005,11 @@ class _CacheKeyTraversal(HasTraversalDispatch): attrname, tuple( ( - key._gen_cache_key(anon_map, bindparams) - if hasattr(key, "__clause_element__") - else key, + ( + key._gen_cache_key(anon_map, bindparams) + if hasattr(key, "__clause_element__") + else key + ), value._gen_cache_key(anon_map, bindparams), ) for key, value in obj @@ -1017,9 +1030,11 @@ class _CacheKeyTraversal(HasTraversalDispatch): attrname, tuple( ( - k._gen_cache_key(anon_map, bindparams) - if hasattr(k, "__clause_element__") - else k, + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k + ), obj[k]._gen_cache_key(anon_map, bindparams), ) for k in obj diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 3d33924d89..22d6091552 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -165,8 +165,7 @@ def expect( role: Type[roles.TruncatedLabelRole], element: Any, **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -176,8 +175,7 @@ def expect( *, as_key: Literal[True] = ..., **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -185,8 +183,7 @@ def expect( role: Type[roles.LiteralValueRole], element: Any, **kw: Any, -) -> BindParameter[Any]: - ... +) -> BindParameter[Any]: ... @overload @@ -194,8 +191,7 @@ def expect( role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, -) -> Column[Any]: - ... +) -> Column[Any]: ... @overload @@ -203,8 +199,7 @@ def expect( role: Type[roles.DDLConstraintColumnRole], element: Any, **kw: Any, -) -> Union[Column[Any], str]: - ... +) -> Union[Column[Any], str]: ... @overload @@ -212,8 +207,7 @@ def expect( role: Type[roles.StatementOptionRole], element: Any, **kw: Any, -) -> DQLDMLClauseElement: - ... +) -> DQLDMLClauseElement: ... @overload @@ -221,8 +215,7 @@ def expect( role: Type[roles.LabeledColumnExprRole[Any]], element: _ColumnExpressionArgument[_T], **kw: Any, -) -> NamedColumn[_T]: - ... +) -> NamedColumn[_T]: ... @overload @@ -234,8 +227,7 @@ def expect( ], element: _ColumnExpressionArgument[_T], **kw: Any, -) -> ColumnElement[_T]: - ... +) -> ColumnElement[_T]: ... @overload @@ -249,8 +241,7 @@ def expect( ], element: Any, **kw: Any, -) -> ColumnElement[Any]: - ... +) -> ColumnElement[Any]: ... @overload @@ -258,8 +249,7 @@ def expect( role: Type[roles.DMLTableRole], element: _DMLTableArgument, **kw: Any, -) -> _DMLTableElement: - ... +) -> _DMLTableElement: ... @overload @@ -267,8 +257,7 @@ def expect( role: Type[roles.HasCTERole], element: HasCTE, **kw: Any, -) -> HasCTE: - ... +) -> HasCTE: ... @overload @@ -276,8 +265,7 @@ def expect( role: Type[roles.SelectStatementRole], element: SelectBase, **kw: Any, -) -> SelectBase: - ... +) -> SelectBase: ... @overload @@ -285,8 +273,7 @@ def expect( role: Type[roles.FromClauseRole], element: _FromClauseArgument, **kw: Any, -) -> FromClause: - ... +) -> FromClause: ... @overload @@ -296,8 +283,7 @@ def expect( *, explicit_subquery: Literal[True] = ..., **kw: Any, -) -> Subquery: - ... +) -> Subquery: ... @overload @@ -305,8 +291,7 @@ def expect( role: Type[roles.ColumnsClauseRole], element: _ColumnsClauseArgument[Any], **kw: Any, -) -> _ColumnsClauseElement: - ... +) -> _ColumnsClauseElement: ... @overload @@ -314,8 +299,7 @@ def expect( role: Type[roles.JoinTargetRole], element: _JoinTargetProtocol, **kw: Any, -) -> _JoinTargetProtocol: - ... +) -> _JoinTargetProtocol: ... # catchall for not-yet-implemented overloads @@ -324,8 +308,7 @@ def expect( role: Type[_SR], element: Any, **kw: Any, -) -> Any: - ... +) -> Any: ... def expect( @@ -870,9 +853,11 @@ class InElementImpl(RoleImpl): if non_literal_expressions: return elements.ClauseList( *[ - non_literal_expressions[o] - if o in non_literal_expressions - else expr._bind_param(operator, o) + ( + non_literal_expressions[o] + if o in non_literal_expressions + else expr._bind_param(operator, o) + ) for o in element ] ) @@ -1150,9 +1135,9 @@ class ColumnsClauseImpl(_SelectIsNotFrom, _CoerceLiterals, RoleImpl): % { "column": util.ellipses_string(element), "argname": "for argument %s" % (argname,) if argname else "", - "literal_column": "literal_column" - if guess_is_literal - else "column", + "literal_column": ( + "literal_column" if guess_is_literal else "column" + ), } ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index ea19e9a86d..e2bdce3291 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -384,8 +384,7 @@ class _ResultMapAppender(Protocol): name: str, objects: Sequence[Any], type_: TypeEngine[Any], - ) -> None: - ... + ) -> None: ... # integer indexes into ResultColumnsEntry used by cursor.py. @@ -739,7 +738,6 @@ class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): class Compiled: - """Represent a compiled SQL or DDL expression. The ``__str__`` method of the ``Compiled`` object should produce @@ -969,7 +967,6 @@ class TypeCompiler(util.EnsureKWArg): class _CompileLabel( roles.BinaryElementRole[Any], elements.CompilerColumnElement ): - """lightweight label object which acts as an expression.Label.""" __visit_name__ = "label" @@ -1039,19 +1036,19 @@ class SQLCompiler(Compiled): extract_map = EXTRACT_MAP - bindname_escape_characters: ClassVar[ - Mapping[str, str] - ] = util.immutabledict( - { - "%": "P", - "(": "A", - ")": "Z", - ":": "C", - ".": "_", - "[": "_", - "]": "_", - " ": "_", - } + bindname_escape_characters: ClassVar[Mapping[str, str]] = ( + util.immutabledict( + { + "%": "P", + "(": "A", + ")": "Z", + ":": "C", + ".": "_", + "[": "_", + "]": "_", + " ": "_", + } + ) ) """A mapping (e.g. dict or similar) containing a lookup of characters keyed to replacement characters which will be applied to all @@ -1791,11 +1788,15 @@ class SQLCompiler(Compiled): for key, value in ( ( self.bind_names[bindparam], - bindparam.type._cached_bind_processor(self.dialect) - if not bindparam.type._is_tuple_type - else tuple( - elem_type._cached_bind_processor(self.dialect) - for elem_type in cast(TupleType, bindparam.type).types + ( + bindparam.type._cached_bind_processor(self.dialect) + if not bindparam.type._is_tuple_type + else tuple( + elem_type._cached_bind_processor(self.dialect) + for elem_type in cast( + TupleType, bindparam.type + ).types + ) ), ) for bindparam in self.bind_names @@ -2101,11 +2102,11 @@ class SQLCompiler(Compiled): if parameter in self.literal_execute_params: if escaped_name not in replacement_expressions: - replacement_expressions[ - escaped_name - ] = self.render_literal_bindparam( - parameter, - render_literal_value=parameters.pop(escaped_name), + replacement_expressions[escaped_name] = ( + self.render_literal_bindparam( + parameter, + render_literal_value=parameters.pop(escaped_name), + ) ) continue @@ -2314,12 +2315,14 @@ class SQLCompiler(Compiled): else: return row_fn( ( - autoinc_getter(lastrowid, parameters) - if autoinc_getter is not None - else lastrowid + ( + autoinc_getter(lastrowid, parameters) + if autoinc_getter is not None + else lastrowid + ) + if col is autoinc_col + else getter(parameters) ) - if col is autoinc_col - else getter(parameters) for getter, col in getters ) @@ -2349,11 +2352,15 @@ class SQLCompiler(Compiled): getters = cast( "List[Tuple[Callable[[Any], Any], bool]]", [ - (operator.itemgetter(ret[col]), True) - if col in ret - else ( - operator.methodcaller("get", param_key_getter(col), None), - False, + ( + (operator.itemgetter(ret[col]), True) + if col in ret + else ( + operator.methodcaller( + "get", param_key_getter(col), None + ), + False, + ) ) for col in table.primary_key ], @@ -2422,9 +2429,9 @@ class SQLCompiler(Compiled): resolve_dict[order_by_elem.name] ) ): - kwargs[ - "render_label_as_label" - ] = element.element._order_by_label_element + kwargs["render_label_as_label"] = ( + element.element._order_by_label_element + ) return self.process( element.element, within_columns_clause=within_columns_clause, @@ -2670,9 +2677,9 @@ class SQLCompiler(Compiled): ) if populate_result_map: - self._ordered_columns = ( - self._textual_ordered_columns - ) = taf.positional + self._ordered_columns = self._textual_ordered_columns = ( + taf.positional + ) # enable looser result column matching when the SQL text links to # Column objects by name only @@ -2799,24 +2806,44 @@ class SQLCompiler(Compiled): def _format_frame_clause(self, range_, **kw): return "%s AND %s" % ( - "UNBOUNDED PRECEDING" - if range_[0] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" - if range_[0] is elements.RANGE_CURRENT - else "%s PRECEDING" - % (self.process(elements.literal(abs(range_[0])), **kw),) - if range_[0] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[0]), **kw),), - "UNBOUNDED FOLLOWING" - if range_[1] is elements.RANGE_UNBOUNDED - else "CURRENT ROW" - if range_[1] is elements.RANGE_CURRENT - else "%s PRECEDING" - % (self.process(elements.literal(abs(range_[1])), **kw),) - if range_[1] < 0 - else "%s FOLLOWING" - % (self.process(elements.literal(range_[1]), **kw),), + ( + "UNBOUNDED PRECEDING" + if range_[0] is elements.RANGE_UNBOUNDED + else ( + "CURRENT ROW" + if range_[0] is elements.RANGE_CURRENT + else ( + "%s PRECEDING" + % ( + self.process( + elements.literal(abs(range_[0])), **kw + ), + ) + if range_[0] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[0]), **kw),) + ) + ) + ), + ( + "UNBOUNDED FOLLOWING" + if range_[1] is elements.RANGE_UNBOUNDED + else ( + "CURRENT ROW" + if range_[1] is elements.RANGE_CURRENT + else ( + "%s PRECEDING" + % ( + self.process( + elements.literal(abs(range_[1])), **kw + ), + ) + if range_[1] < 0 + else "%s FOLLOWING" + % (self.process(elements.literal(range_[1]), **kw),) + ) + ) + ), ) def visit_over(self, over, **kwargs): @@ -3057,9 +3084,12 @@ class SQLCompiler(Compiled): + self.process( elements.Cast( binary.right, - binary.right.type - if binary.right.type._type_affinity is sqltypes.Numeric - else sqltypes.Numeric(), + ( + binary.right.type + if binary.right.type._type_affinity + is sqltypes.Numeric + else sqltypes.Numeric() + ), ), **kw, ) @@ -4214,12 +4244,14 @@ class SQLCompiler(Compiled): "%s%s" % ( self.preparer.quote(col.name), - " %s" - % self.dialect.type_compiler_instance.process( - col.type, **kwargs - ) - if alias._render_derived_w_types - else "", + ( + " %s" + % self.dialect.type_compiler_instance.process( + col.type, **kwargs + ) + if alias._render_derived_w_types + else "" + ), ) for col in alias.c ) @@ -4611,9 +4643,9 @@ class SQLCompiler(Compiled): compile_state = select_stmt._compile_state_factory( select_stmt, self, **kwargs ) - kwargs[ - "ambiguous_table_name_map" - ] = compile_state._ambiguous_table_name_map + kwargs["ambiguous_table_name_map"] = ( + compile_state._ambiguous_table_name_map + ) select_stmt = compile_state.statement @@ -5856,9 +5888,9 @@ class SQLCompiler(Compiled): insert_stmt._post_values_clause is not None ), sentinel_columns=add_sentinel_cols, - num_sentinel_columns=len(add_sentinel_cols) - if add_sentinel_cols - else 0, + num_sentinel_columns=( + len(add_sentinel_cols) if add_sentinel_cols else 0 + ), implicit_sentinel=implicit_sentinel, ) elif compile_state._has_multi_parameters: @@ -5952,9 +5984,9 @@ class SQLCompiler(Compiled): insert_stmt._post_values_clause is not None ), sentinel_columns=add_sentinel_cols, - num_sentinel_columns=len(add_sentinel_cols) - if add_sentinel_cols - else 0, + num_sentinel_columns=( + len(add_sentinel_cols) if add_sentinel_cols else 0 + ), sentinel_param_keys=named_sentinel_params, implicit_sentinel=implicit_sentinel, embed_values_counter=embed_sentinel_value, @@ -6439,8 +6471,7 @@ class DDLCompiler(Compiled): schema_translate_map: Optional[SchemaTranslateMapType] = ..., render_schema_translate: bool = ..., compile_kwargs: Mapping[str, Any] = ..., - ): - ... + ): ... @util.memoized_property def sql_compiler(self): @@ -7168,17 +7199,14 @@ class StrSQLTypeCompiler(GenericTypeCompiler): class _SchemaForObjectCallable(Protocol): - def __call__(self, obj: Any) -> str: - ... + def __call__(self, obj: Any) -> str: ... class _BindNameForColProtocol(Protocol): - def __call__(self, col: ColumnClause[Any]) -> str: - ... + def __call__(self, col: ColumnClause[Any]) -> str: ... class IdentifierPreparer: - """Handle quoting and case-folding of identifiers based on options.""" reserved_words = RESERVED_WORDS diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py index fc6f51de1c..499a19d97c 100644 --- a/lib/sqlalchemy/sql/crud.py +++ b/lib/sqlalchemy/sql/crud.py @@ -394,8 +394,7 @@ def _create_bind_param( required: bool = False, name: Optional[str] = None, **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -404,8 +403,7 @@ def _create_bind_param( col: ColumnElement[Any], value: Any, **kw: Any, -) -> str: - ... +) -> str: ... def _create_bind_param( @@ -859,10 +857,12 @@ def _append_param_parameter( c, value, required=value is REQUIRED, - name=_col_bind_name(c) - if not _compile_state_isinsert(compile_state) - or not compile_state._has_multi_parameters - else "%s_m0" % _col_bind_name(c), + name=( + _col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c) + ), accumulate_bind_names=accumulated_bind_names, **kw, ) @@ -884,10 +884,12 @@ def _append_param_parameter( compiler, c, value, - name=_col_bind_name(c) - if not _compile_state_isinsert(compile_state) - or not compile_state._has_multi_parameters - else "%s_m0" % _col_bind_name(c), + name=( + _col_bind_name(c) + if not _compile_state_isinsert(compile_state) + or not compile_state._has_multi_parameters + else "%s_m0" % _col_bind_name(c) + ), accumulate_bind_names=accumulated_bind_names, **kw, ) @@ -1213,8 +1215,7 @@ def _create_insert_prefetch_bind_param( c: ColumnElement[Any], process: Literal[True] = ..., **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -1223,8 +1224,7 @@ def _create_insert_prefetch_bind_param( c: ColumnElement[Any], process: Literal[False], **kw: Any, -) -> elements.BindParameter[Any]: - ... +) -> elements.BindParameter[Any]: ... def _create_insert_prefetch_bind_param( @@ -1247,8 +1247,7 @@ def _create_update_prefetch_bind_param( c: ColumnElement[Any], process: Literal[True] = ..., **kw: Any, -) -> str: - ... +) -> str: ... @overload @@ -1257,8 +1256,7 @@ def _create_update_prefetch_bind_param( c: ColumnElement[Any], process: Literal[False], **kw: Any, -) -> elements.BindParameter[Any]: - ... +) -> elements.BindParameter[Any]: ... def _create_update_prefetch_bind_param( diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 378de6ea5b..aacfa82645 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -95,8 +95,7 @@ class DDLIfCallable(Protocol): dialect: Dialect, compiler: Optional[DDLCompiler] = ..., checkfirst: bool, - ) -> bool: - ... + ) -> bool: ... class DDLIf(typing.NamedTuple): @@ -1021,10 +1020,12 @@ class SchemaDropper(InvokeDropDDLBase): reversed( sort_tables_and_constraints( unsorted_tables, - filter_fn=lambda constraint: False - if not self.dialect.supports_alter - or constraint.name is None - else None, + filter_fn=lambda constraint: ( + False + if not self.dialect.supports_alter + or constraint.name is None + else None + ), ) ) ) diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 072acafed3..5bf8d582e5 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -296,9 +296,11 @@ def _match_impl( operator=operators.match_op, ), result_type=type_api.MATCHTYPE, - negate_op=operators.not_match_op - if op is operators.match_op - else operators.match_op, + negate_op=( + operators.not_match_op + if op is operators.match_op + else operators.match_op + ), **kw, ) @@ -340,9 +342,11 @@ def _between_impl( group=False, ), op, - negate=operators.not_between_op - if op is operators.between_op - else operators.between_op, + negate=( + operators.not_between_op + if op is operators.between_op + else operators.between_op + ), modifiers=kw, ) diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index f35815ca4f..a0ab097f05 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -94,14 +94,11 @@ if TYPE_CHECKING: from .selectable import Select from .selectable import Selectable - def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: - ... + def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]: ... - def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: - ... + def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]: ... - def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: - ... + def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]: ... else: isupdate = operator.attrgetter("isupdate") @@ -141,9 +138,11 @@ class DMLState(CompileState): @classmethod def get_entity_description(cls, statement: UpdateBase) -> Dict[str, Any]: return { - "name": statement.table.name - if is_named_from_clause(statement.table) - else None, + "name": ( + statement.table.name + if is_named_from_clause(statement.table) + else None + ), "table": statement.table, } @@ -167,8 +166,7 @@ class DMLState(CompileState): if TYPE_CHECKING: @classmethod - def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: - ... + def get_plugin_class(cls, statement: Executable) -> Type[DMLState]: ... @classmethod def _get_multi_crud_kv_pairs( @@ -194,13 +192,15 @@ class DMLState(CompileState): return [ ( coercions.expect(roles.DMLColumnRole, k), - v - if not needs_to_be_cacheable - else coercions.expect( - roles.ExpressionElementRole, - v, - type_=NullType(), - is_crud=True, + ( + v + if not needs_to_be_cacheable + else coercions.expect( + roles.ExpressionElementRole, + v, + type_=NullType(), + is_crud=True, + ) ), ) for k, v in kv_iterator @@ -310,12 +310,14 @@ class InsertDMLState(DMLState): def _process_multi_values(self, statement: ValuesBase) -> None: for parameters in statement._multi_values: multi_parameters: List[MutableMapping[_DMLColumnElement, Any]] = [ - { - c.key: value - for c, value in zip(statement.table.c, parameter_set) - } - if isinstance(parameter_set, collections_abc.Sequence) - else parameter_set + ( + { + c.key: value + for c, value in zip(statement.table.c, parameter_set) + } + if isinstance(parameter_set, collections_abc.Sequence) + else parameter_set + ) for parameter_set in parameters ] @@ -400,9 +402,9 @@ class UpdateBase( __visit_name__ = "update_base" - _hints: util.immutabledict[ - Tuple[_DMLTableElement, str], str - ] = util.EMPTY_DICT + _hints: util.immutabledict[Tuple[_DMLTableElement, str], str] = ( + util.EMPTY_DICT + ) named_with_column = False _label_style: SelectLabelStyle = ( @@ -411,9 +413,9 @@ class UpdateBase( table: _DMLTableElement _return_defaults = False - _return_defaults_columns: Optional[ - Tuple[_ColumnsClauseElement, ...] - ] = None + _return_defaults_columns: Optional[Tuple[_ColumnsClauseElement, ...]] = ( + None + ) _supplemental_returning: Optional[Tuple[_ColumnsClauseElement, ...]] = None _returning: Tuple[_ColumnsClauseElement, ...] = () @@ -1303,8 +1305,7 @@ class Insert(ValuesBase): /, *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[_T0]: - ... + ) -> ReturningInsert[_T0]: ... @overload def returning( @@ -1314,8 +1315,7 @@ class Insert(ValuesBase): /, *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[_T0, _T1]: - ... + ) -> ReturningInsert[_T0, _T1]: ... @overload def returning( @@ -1326,8 +1326,7 @@ class Insert(ValuesBase): /, *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[_T0, _T1, _T2]: - ... + ) -> ReturningInsert[_T0, _T1, _T2]: ... @overload def returning( @@ -1339,8 +1338,7 @@ class Insert(ValuesBase): /, *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[_T0, _T1, _T2, _T3]: - ... + ) -> ReturningInsert[_T0, _T1, _T2, _T3]: ... @overload def returning( @@ -1353,8 +1351,7 @@ class Insert(ValuesBase): /, *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4]: - ... + ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4]: ... @overload def returning( @@ -1368,8 +1365,7 @@ class Insert(ValuesBase): /, *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5]: - ... + ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload def returning( @@ -1384,8 +1380,7 @@ class Insert(ValuesBase): /, *, sort_by_parameter_order: bool = False, - ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... + ) -> ReturningInsert[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload def returning( @@ -1403,8 +1398,7 @@ class Insert(ValuesBase): sort_by_parameter_order: bool = False, ) -> ReturningInsert[ _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny] - ]: - ... + ]: ... # END OVERLOADED FUNCTIONS self.returning @@ -1414,16 +1408,14 @@ class Insert(ValuesBase): *cols: _ColumnsClauseArgument[Any], sort_by_parameter_order: bool = False, **__kw: Any, - ) -> ReturningInsert[Any]: - ... + ) -> ReturningInsert[Any]: ... def returning( self, *cols: _ColumnsClauseArgument[Any], sort_by_parameter_order: bool = False, **__kw: Any, - ) -> ReturningInsert[Any]: - ... + ) -> ReturningInsert[Any]: ... class ReturningInsert(Insert, TypedReturnsRows[Unpack[_Ts]]): @@ -1613,20 +1605,17 @@ class Update(DMLWhereBase, ValuesBase): # statically generated** by tools/generate_tuple_map_overloads.py @overload - def returning(self, __ent0: _TCCA[_T0], /) -> ReturningUpdate[_T0]: - ... + def returning(self, __ent0: _TCCA[_T0], /) -> ReturningUpdate[_T0]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], / - ) -> ReturningUpdate[_T0, _T1]: - ... + ) -> ReturningUpdate[_T0, _T1]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], / - ) -> ReturningUpdate[_T0, _T1, _T2]: - ... + ) -> ReturningUpdate[_T0, _T1, _T2]: ... @overload def returning( @@ -1636,8 +1625,7 @@ class Update(DMLWhereBase, ValuesBase): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], /, - ) -> ReturningUpdate[_T0, _T1, _T2, _T3]: - ... + ) -> ReturningUpdate[_T0, _T1, _T2, _T3]: ... @overload def returning( @@ -1648,8 +1636,7 @@ class Update(DMLWhereBase, ValuesBase): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], /, - ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4]: - ... + ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4]: ... @overload def returning( @@ -1661,8 +1648,7 @@ class Update(DMLWhereBase, ValuesBase): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], /, - ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5]: - ... + ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload def returning( @@ -1675,8 +1661,7 @@ class Update(DMLWhereBase, ValuesBase): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], /, - ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... + ) -> ReturningUpdate[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload def returning( @@ -1693,21 +1678,18 @@ class Update(DMLWhereBase, ValuesBase): *entities: _ColumnsClauseArgument[Any], ) -> ReturningUpdate[ _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny] - ]: - ... + ]: ... # END OVERLOADED FUNCTIONS self.returning @overload def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningUpdate[Any]: - ... + ) -> ReturningUpdate[Any]: ... def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningUpdate[Any]: - ... + ) -> ReturningUpdate[Any]: ... class ReturningUpdate(Update, TypedReturnsRows[Unpack[_Ts]]): @@ -1759,20 +1741,17 @@ class Delete(DMLWhereBase, UpdateBase): # statically generated** by tools/generate_tuple_map_overloads.py @overload - def returning(self, __ent0: _TCCA[_T0], /) -> ReturningDelete[_T0]: - ... + def returning(self, __ent0: _TCCA[_T0], /) -> ReturningDelete[_T0]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], / - ) -> ReturningDelete[_T0, _T1]: - ... + ) -> ReturningDelete[_T0, _T1]: ... @overload def returning( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], / - ) -> ReturningDelete[_T0, _T1, _T2]: - ... + ) -> ReturningDelete[_T0, _T1, _T2]: ... @overload def returning( @@ -1782,8 +1761,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], /, - ) -> ReturningDelete[_T0, _T1, _T2, _T3]: - ... + ) -> ReturningDelete[_T0, _T1, _T2, _T3]: ... @overload def returning( @@ -1794,8 +1772,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], /, - ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4]: - ... + ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4]: ... @overload def returning( @@ -1807,8 +1784,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], /, - ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5]: - ... + ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload def returning( @@ -1821,8 +1797,7 @@ class Delete(DMLWhereBase, UpdateBase): __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], /, - ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... + ) -> ReturningDelete[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload def returning( @@ -1839,21 +1814,18 @@ class Delete(DMLWhereBase, UpdateBase): *entities: _ColumnsClauseArgument[Any], ) -> ReturningDelete[ _T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7, Unpack[TupleAny] - ]: - ... + ]: ... # END OVERLOADED FUNCTIONS self.returning @overload def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningDelete[Unpack[TupleAny]]: - ... + ) -> ReturningDelete[Unpack[TupleAny]]: ... def returning( self, *cols: _ColumnsClauseArgument[Any], **__kw: Any - ) -> ReturningDelete[Unpack[TupleAny]]: - ... + ) -> ReturningDelete[Unpack[TupleAny]]: ... class ReturningDelete(Update, TypedReturnsRows[Unpack[_Ts]]): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 973b332d47..bf7e9438d9 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -132,8 +132,7 @@ def literal( value: Any, type_: _TypeEngineArgument[_T], literal_execute: bool = False, -) -> BindParameter[_T]: - ... +) -> BindParameter[_T]: ... @overload @@ -141,8 +140,7 @@ def literal( value: _T, type_: None = None, literal_execute: bool = False, -) -> BindParameter[_T]: - ... +) -> BindParameter[_T]: ... @overload @@ -150,8 +148,7 @@ def literal( value: Any, type_: Optional[_TypeEngineArgument[Any]] = None, literal_execute: bool = False, -) -> BindParameter[Any]: - ... +) -> BindParameter[Any]: ... def literal( @@ -390,8 +387,7 @@ class ClauseElement( def get_children( self, *, omit_attrs: typing_Tuple[str, ...] = ..., **kw: Any - ) -> Iterable[ClauseElement]: - ... + ) -> Iterable[ClauseElement]: ... @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -786,8 +782,7 @@ class DQLDMLClauseElement(ClauseElement): bind: Optional[Union[Engine, Connection]] = None, dialect: Optional[Dialect] = None, **kw: Any, - ) -> SQLCompiler: - ... + ) -> SQLCompiler: ... class CompilerColumnElement( @@ -820,18 +815,15 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): if typing.TYPE_CHECKING: @util.non_memoized_property - def _propagate_attrs(self) -> _PropagateAttrsType: - ... + def _propagate_attrs(self) -> _PropagateAttrsType: ... def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... @overload def op( @@ -842,8 +834,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): *, return_type: _TypeEngineArgument[_OPT], python_impl: Optional[Callable[..., Any]] = None, - ) -> Callable[[Any], BinaryExpression[_OPT]]: - ... + ) -> Callable[[Any], BinaryExpression[_OPT]]: ... @overload def op( @@ -853,8 +844,7 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): is_comparison: bool = ..., return_type: Optional[_TypeEngineArgument[Any]] = ..., python_impl: Optional[Callable[..., Any]] = ..., - ) -> Callable[[Any], BinaryExpression[Any]]: - ... + ) -> Callable[[Any], BinaryExpression[Any]]: ... def op( self, @@ -863,38 +853,30 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): is_comparison: bool = False, return_type: Optional[_TypeEngineArgument[Any]] = None, python_impl: Optional[Callable[..., Any]] = None, - ) -> Callable[[Any], BinaryExpression[Any]]: - ... + ) -> Callable[[Any], BinaryExpression[Any]]: ... def bool_op( self, opstring: str, precedence: int = 0, python_impl: Optional[Callable[..., Any]] = None, - ) -> Callable[[Any], BinaryExpression[bool]]: - ... + ) -> Callable[[Any], BinaryExpression[bool]]: ... - def __and__(self, other: Any) -> BooleanClauseList: - ... + def __and__(self, other: Any) -> BooleanClauseList: ... - def __or__(self, other: Any) -> BooleanClauseList: - ... + def __or__(self, other: Any) -> BooleanClauseList: ... - def __invert__(self) -> ColumnElement[_T_co]: - ... + def __invert__(self) -> ColumnElement[_T_co]: ... - def __lt__(self, other: Any) -> ColumnElement[bool]: - ... + def __lt__(self, other: Any) -> ColumnElement[bool]: ... - def __le__(self, other: Any) -> ColumnElement[bool]: - ... + def __le__(self, other: Any) -> ColumnElement[bool]: ... # declare also that this class has an hash method otherwise # it may be assumed to be None by type checkers since the # object defines __eq__ and python sets it to None in that case: # https://docs.python.org/3/reference/datamodel.html#object.__hash__ - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... def __eq__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... @@ -902,226 +884,172 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 ... - def is_distinct_from(self, other: Any) -> ColumnElement[bool]: - ... + def is_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: - ... + def is_not_distinct_from(self, other: Any) -> ColumnElement[bool]: ... - def __gt__(self, other: Any) -> ColumnElement[bool]: - ... + def __gt__(self, other: Any) -> ColumnElement[bool]: ... - def __ge__(self, other: Any) -> ColumnElement[bool]: - ... + def __ge__(self, other: Any) -> ColumnElement[bool]: ... - def __neg__(self) -> UnaryExpression[_T_co]: - ... + def __neg__(self) -> UnaryExpression[_T_co]: ... - def __contains__(self, other: Any) -> ColumnElement[bool]: - ... + def __contains__(self, other: Any) -> ColumnElement[bool]: ... - def __getitem__(self, index: Any) -> ColumnElement[Any]: - ... + def __getitem__(self, index: Any) -> ColumnElement[Any]: ... @overload - def __lshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: - ... + def __lshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @overload - def __lshift__(self, other: Any) -> ColumnElement[Any]: - ... + def __lshift__(self, other: Any) -> ColumnElement[Any]: ... - def __lshift__(self, other: Any) -> ColumnElement[Any]: - ... + def __lshift__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: - ... + def __rshift__(self: _SQO[int], other: Any) -> ColumnElement[int]: ... @overload - def __rshift__(self, other: Any) -> ColumnElement[Any]: - ... + def __rshift__(self, other: Any) -> ColumnElement[Any]: ... - def __rshift__(self, other: Any) -> ColumnElement[Any]: - ... + def __rshift__(self, other: Any) -> ColumnElement[Any]: ... @overload - def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: - ... + def concat(self: _SQO[str], other: Any) -> ColumnElement[str]: ... @overload - def concat(self, other: Any) -> ColumnElement[Any]: - ... + def concat(self, other: Any) -> ColumnElement[Any]: ... - def concat(self, other: Any) -> ColumnElement[Any]: - ... + def concat(self, other: Any) -> ColumnElement[Any]: ... def like( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def ilike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... - def bitwise_xor(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_xor(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_or(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_or(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_and(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_and(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_not(self) -> UnaryExpression[_T_co]: - ... + def bitwise_not(self) -> UnaryExpression[_T_co]: ... - def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_lshift(self, other: Any) -> BinaryExpression[Any]: ... - def bitwise_rshift(self, other: Any) -> BinaryExpression[Any]: - ... + def bitwise_rshift(self, other: Any) -> BinaryExpression[Any]: ... def in_( self, other: Union[ Iterable[Any], BindParameter[Any], roles.InElementRole ], - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def not_in( self, other: Union[ Iterable[Any], BindParameter[Any], roles.InElementRole ], - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def notin_( self, other: Union[ Iterable[Any], BindParameter[Any], roles.InElementRole ], - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def not_like( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def notlike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def not_ilike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... def notilike( self, other: Any, escape: Optional[str] = None - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... - def is_(self, other: Any) -> BinaryExpression[bool]: - ... + def is_(self, other: Any) -> BinaryExpression[bool]: ... - def is_not(self, other: Any) -> BinaryExpression[bool]: - ... + def is_not(self, other: Any) -> BinaryExpression[bool]: ... - def isnot(self, other: Any) -> BinaryExpression[bool]: - ... + def isnot(self, other: Any) -> BinaryExpression[bool]: ... def startswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def istartswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def endswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def iendswith( self, other: Any, escape: Optional[str] = None, autoescape: bool = False, - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... - def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: - ... + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... - def icontains(self, other: Any, **kw: Any) -> ColumnElement[bool]: - ... + def icontains(self, other: Any, **kw: Any) -> ColumnElement[bool]: ... - def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: - ... + def match(self, other: Any, **kwargs: Any) -> ColumnElement[bool]: ... def regexp_match( self, pattern: Any, flags: Optional[str] = None - ) -> ColumnElement[bool]: - ... + ) -> ColumnElement[bool]: ... def regexp_replace( self, pattern: Any, replacement: Any, flags: Optional[str] = None - ) -> ColumnElement[str]: - ... + ) -> ColumnElement[str]: ... - def desc(self) -> UnaryExpression[_T_co]: - ... + def desc(self) -> UnaryExpression[_T_co]: ... - def asc(self) -> UnaryExpression[_T_co]: - ... + def asc(self) -> UnaryExpression[_T_co]: ... - def nulls_first(self) -> UnaryExpression[_T_co]: - ... + def nulls_first(self) -> UnaryExpression[_T_co]: ... - def nullsfirst(self) -> UnaryExpression[_T_co]: - ... + def nullsfirst(self) -> UnaryExpression[_T_co]: ... - def nulls_last(self) -> UnaryExpression[_T_co]: - ... + def nulls_last(self) -> UnaryExpression[_T_co]: ... - def nullslast(self) -> UnaryExpression[_T_co]: - ... + def nullslast(self) -> UnaryExpression[_T_co]: ... - def collate(self, collation: str) -> CollationClause: - ... + def collate(self, collation: str) -> CollationClause: ... def between( self, cleft: Any, cright: Any, symmetric: bool = False - ) -> BinaryExpression[bool]: - ... + ) -> BinaryExpression[bool]: ... - def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]: - ... + def distinct(self: _SQO[_T_co]) -> UnaryExpression[_T_co]: ... - def any_(self) -> CollectionAggregate[Any]: - ... + def any_(self) -> CollectionAggregate[Any]: ... - def all_(self) -> CollectionAggregate[Any]: - ... + def all_(self) -> CollectionAggregate[Any]: ... # numeric overloads. These need more tweaking # in particular they all need to have a variant for Optiona[_T] @@ -1132,159 +1060,126 @@ class SQLCoreOperations(Generic[_T_co], ColumnOperators, TypingOnly): def __add__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload def __add__( self: _SQO[str], other: Any, - ) -> ColumnElement[str]: - ... + ) -> ColumnElement[str]: ... - def __add__(self, other: Any) -> ColumnElement[Any]: - ... + def __add__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __radd__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: - ... + def __radd__(self: _SQO[str], other: Any) -> ColumnElement[str]: ... - def __radd__(self, other: Any) -> ColumnElement[Any]: - ... + def __radd__(self, other: Any) -> ColumnElement[Any]: ... @overload def __sub__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __sub__(self, other: Any) -> ColumnElement[Any]: - ... + def __sub__(self, other: Any) -> ColumnElement[Any]: ... - def __sub__(self, other: Any) -> ColumnElement[Any]: - ... + def __sub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rsub__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __rsub__(self, other: Any) -> ColumnElement[Any]: - ... + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... - def __rsub__(self, other: Any) -> ColumnElement[Any]: - ... + def __rsub__(self, other: Any) -> ColumnElement[Any]: ... @overload def __mul__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __mul__(self, other: Any) -> ColumnElement[Any]: - ... + def __mul__(self, other: Any) -> ColumnElement[Any]: ... - def __mul__(self, other: Any) -> ColumnElement[Any]: - ... + def __mul__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rmul__( self: _SQO[_NMT], other: Any, - ) -> ColumnElement[_NMT]: - ... + ) -> ColumnElement[_NMT]: ... @overload - def __rmul__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... - def __rmul__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmul__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __mod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __mod__(self, other: Any) -> ColumnElement[Any]: - ... + def __mod__(self, other: Any) -> ColumnElement[Any]: ... - def __mod__(self, other: Any) -> ColumnElement[Any]: - ... + def __mod__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __rmod__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: ... @overload - def __rmod__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... - def __rmod__(self, other: Any) -> ColumnElement[Any]: - ... + def __rmod__(self, other: Any) -> ColumnElement[Any]: ... @overload def __truediv__( self: _SQO[int], other: Any - ) -> ColumnElement[_NUMERIC]: - ... + ) -> ColumnElement[_NUMERIC]: ... @overload - def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: - ... + def __truediv__(self: _SQO[_NT], other: Any) -> ColumnElement[_NT]: ... @overload - def __truediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... - def __truediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __truediv__(self, other: Any) -> ColumnElement[Any]: ... @overload def __rtruediv__( self: _SQO[_NMT], other: Any - ) -> ColumnElement[_NUMERIC]: - ... + ) -> ColumnElement[_NUMERIC]: ... @overload - def __rtruediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... - def __rtruediv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rtruediv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __floordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __floordiv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NMT]: ... @overload - def __floordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __floordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __floordiv__(self, other: Any) -> ColumnElement[Any]: ... @overload - def __rfloordiv__(self: _SQO[_NMT], other: Any) -> ColumnElement[_NMT]: - ... + def __rfloordiv__( + self: _SQO[_NMT], other: Any + ) -> ColumnElement[_NMT]: ... @overload - def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... - def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: - ... + def __rfloordiv__(self, other: Any) -> ColumnElement[Any]: ... class SQLColumnExpression( @@ -1536,14 +1431,12 @@ class ColumnElement( @overload def self_group( self: ColumnElement[_T], against: Optional[OperatorType] = None - ) -> ColumnElement[_T]: - ... + ) -> ColumnElement[_T]: ... @overload def self_group( self: ColumnElement[Any], against: Optional[OperatorType] = None - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... def self_group( self, against: Optional[OperatorType] = None @@ -1559,12 +1452,10 @@ class ColumnElement( return self @overload - def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: - ... + def _negate(self: ColumnElement[bool]) -> ColumnElement[bool]: ... @overload - def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: - ... + def _negate(self: ColumnElement[_T]) -> ColumnElement[_T]: ... def _negate(self) -> ColumnElement[Any]: if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: @@ -1768,9 +1659,11 @@ class ColumnElement( assert key is not None co: ColumnClause[_T] = ColumnClause( - coercions.expect(roles.TruncatedLabelRole, name) - if name_is_truncatable - else name, + ( + coercions.expect(roles.TruncatedLabelRole, name) + if name_is_truncatable + else name + ), type_=getattr(self, "type", None), _selectable=selectable, ) @@ -2082,9 +1975,12 @@ class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): if unique: self.key = _anonymous_label.safe_construct( id(self), - key - if key is not None and not isinstance(key, _anonymous_label) - else "param", + ( + key + if key is not None + and not isinstance(key, _anonymous_label) + else "param" + ), sanitize_key=True, ) self._key_is_anon = True @@ -2145,9 +2041,9 @@ class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): check_value = value[0] else: check_value = value - cast( - "BindParameter[TupleAny]", self - ).type = type_._resolve_values_to_types(check_value) + cast("BindParameter[TupleAny]", self).type = ( + type_._resolve_values_to_types(check_value) + ) else: cast("BindParameter[TupleAny]", self).type = type_ else: @@ -2653,9 +2549,11 @@ class TextClause( ] positional_input_cols = [ - ColumnClause(col.key, types.pop(col.key)) - if col.key in types - else col + ( + ColumnClause(col.key, types.pop(col.key)) + if col.key in types + else col + ) for col in input_cols ] keyed_input_cols: List[NamedColumn[Any]] = [ @@ -3167,9 +3065,11 @@ class BooleanClauseList(ExpressionClauseList[bool]): # which will link elements against the operator. flattened_clauses = itertools.chain.from_iterable( - (c for c in to_flat._flattened_operator_clauses) - if getattr(to_flat, "operator", None) is operator - else (to_flat,) + ( + (c for c in to_flat._flattened_operator_clauses) + if getattr(to_flat, "operator", None) is operator + else (to_flat,) + ) for to_flat in convert_clauses ) @@ -4027,8 +3927,7 @@ class BinaryExpression(OperatorExpression[_T]): def __invert__( self: BinaryExpression[_T], - ) -> BinaryExpression[_T]: - ... + ) -> BinaryExpression[_T]: ... @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: @@ -4594,9 +4493,11 @@ class NamedColumn(KeyedColumnElement[_T]): **kw: Any, ) -> typing_Tuple[str, ColumnClause[_T]]: c = ColumnClause( - coercions.expect(roles.TruncatedLabelRole, name or self.name) - if name_is_truncatable - else (name or self.name), + ( + coercions.expect(roles.TruncatedLabelRole, name or self.name) + if name_is_truncatable + else (name or self.name) + ), type_=self.type, _selectable=selectable, is_literal=False, @@ -5024,9 +4925,11 @@ class ColumnClause( ) ) c = self._constructor( - coercions.expect(roles.TruncatedLabelRole, name or self.name) - if name_is_truncatable - else (name or self.name), + ( + coercions.expect(roles.TruncatedLabelRole, name or self.name) + if name_is_truncatable + else (name or self.name) + ), type_=self.type, _selectable=selectable, is_literal=is_literal, @@ -5169,13 +5072,11 @@ class quoted_name(util.MemoizedSlots, str): @overload @classmethod - def construct(cls, value: str, quote: Optional[bool]) -> quoted_name: - ... + def construct(cls, value: str, quote: Optional[bool]) -> quoted_name: ... @overload @classmethod - def construct(cls, value: None, quote: Optional[bool]) -> None: - ... + def construct(cls, value: None, quote: Optional[bool]) -> None: ... @classmethod def construct( diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index 19ad313024..088b506c76 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -84,9 +84,9 @@ if TYPE_CHECKING: _T = TypeVar("_T", bound=Any) _S = TypeVar("_S", bound=Any) -_registry: util.defaultdict[ - str, Dict[str, Type[Function[Any]]] -] = util.defaultdict(dict) +_registry: util.defaultdict[str, Dict[str, Type[Function[Any]]]] = ( + util.defaultdict(dict) +) def register_function( @@ -486,16 +486,14 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): return WithinGroup(self, *order_by) @overload - def filter(self) -> Self: - ... + def filter(self) -> Self: ... @overload def filter( self, __criterion0: _ColumnExpressionArgument[bool], *criterion: _ColumnExpressionArgument[bool], - ) -> FunctionFilter[_T]: - ... + ) -> FunctionFilter[_T]: ... def filter( self, *criterion: _ColumnExpressionArgument[bool] @@ -945,12 +943,10 @@ class _FunctionGenerator: @overload def __call__( self, *c: Any, type_: _TypeEngineArgument[_T], **kwargs: Any - ) -> Function[_T]: - ... + ) -> Function[_T]: ... @overload - def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: - ... + def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: ... def __call__(self, *c: Any, **kwargs: Any) -> Function[Any]: o = self.opts.copy() @@ -981,24 +977,19 @@ class _FunctionGenerator: # statically generated** by tools/generate_sql_functions.py @property - def aggregate_strings(self) -> Type[aggregate_strings]: - ... + def aggregate_strings(self) -> Type[aggregate_strings]: ... @property - def ansifunction(self) -> Type[AnsiFunction[Any]]: - ... + def ansifunction(self) -> Type[AnsiFunction[Any]]: ... @property - def array_agg(self) -> Type[array_agg[Any]]: - ... + def array_agg(self) -> Type[array_agg[Any]]: ... @property - def cast(self) -> Type[Cast[Any]]: - ... + def cast(self) -> Type[Cast[Any]]: ... @property - def char_length(self) -> Type[char_length]: - ... + def char_length(self) -> Type[char_length]: ... # set ColumnElement[_T] as a separate overload, to appease mypy # which seems to not want to accept _T from _ColumnExpressionArgument. @@ -1011,8 +1002,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: - ... + ) -> coalesce[_T]: ... @overload def coalesce( @@ -1020,8 +1010,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: - ... + ) -> coalesce[_T]: ... @overload def coalesce( @@ -1029,68 +1018,53 @@ class _FunctionGenerator: col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: - ... + ) -> coalesce[_T]: ... def coalesce( self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> coalesce[_T]: - ... + ) -> coalesce[_T]: ... @property - def concat(self) -> Type[concat]: - ... + def concat(self) -> Type[concat]: ... @property - def count(self) -> Type[count]: - ... + def count(self) -> Type[count]: ... @property - def cube(self) -> Type[cube[Any]]: - ... + def cube(self) -> Type[cube[Any]]: ... @property - def cume_dist(self) -> Type[cume_dist]: - ... + def cume_dist(self) -> Type[cume_dist]: ... @property - def current_date(self) -> Type[current_date]: - ... + def current_date(self) -> Type[current_date]: ... @property - def current_time(self) -> Type[current_time]: - ... + def current_time(self) -> Type[current_time]: ... @property - def current_timestamp(self) -> Type[current_timestamp]: - ... + def current_timestamp(self) -> Type[current_timestamp]: ... @property - def current_user(self) -> Type[current_user]: - ... + def current_user(self) -> Type[current_user]: ... @property - def dense_rank(self) -> Type[dense_rank]: - ... + def dense_rank(self) -> Type[dense_rank]: ... @property - def extract(self) -> Type[Extract]: - ... + def extract(self) -> Type[Extract]: ... @property - def grouping_sets(self) -> Type[grouping_sets[Any]]: - ... + def grouping_sets(self) -> Type[grouping_sets[Any]]: ... @property - def localtime(self) -> Type[localtime]: - ... + def localtime(self) -> Type[localtime]: ... @property - def localtimestamp(self) -> Type[localtimestamp]: - ... + def localtimestamp(self) -> Type[localtimestamp]: ... # set ColumnElement[_T] as a separate overload, to appease mypy # which seems to not want to accept _T from _ColumnExpressionArgument. @@ -1103,8 +1077,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: - ... + ) -> max[_T]: ... @overload def max( # noqa: A001 @@ -1112,8 +1085,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: - ... + ) -> max[_T]: ... @overload def max( # noqa: A001 @@ -1121,16 +1093,14 @@ class _FunctionGenerator: col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: - ... + ) -> max[_T]: ... def max( # noqa: A001 self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> max[_T]: - ... + ) -> max[_T]: ... # set ColumnElement[_T] as a separate overload, to appease mypy # which seems to not want to accept _T from _ColumnExpressionArgument. @@ -1143,8 +1113,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: - ... + ) -> min[_T]: ... @overload def min( # noqa: A001 @@ -1152,8 +1121,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: - ... + ) -> min[_T]: ... @overload def min( # noqa: A001 @@ -1161,60 +1129,47 @@ class _FunctionGenerator: col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: - ... + ) -> min[_T]: ... def min( # noqa: A001 self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> min[_T]: - ... + ) -> min[_T]: ... @property - def mode(self) -> Type[mode[Any]]: - ... + def mode(self) -> Type[mode[Any]]: ... @property - def next_value(self) -> Type[next_value]: - ... + def next_value(self) -> Type[next_value]: ... @property - def now(self) -> Type[now]: - ... + def now(self) -> Type[now]: ... @property - def orderedsetagg(self) -> Type[OrderedSetAgg[Any]]: - ... + def orderedsetagg(self) -> Type[OrderedSetAgg[Any]]: ... @property - def percent_rank(self) -> Type[percent_rank]: - ... + def percent_rank(self) -> Type[percent_rank]: ... @property - def percentile_cont(self) -> Type[percentile_cont[Any]]: - ... + def percentile_cont(self) -> Type[percentile_cont[Any]]: ... @property - def percentile_disc(self) -> Type[percentile_disc[Any]]: - ... + def percentile_disc(self) -> Type[percentile_disc[Any]]: ... @property - def random(self) -> Type[random]: - ... + def random(self) -> Type[random]: ... @property - def rank(self) -> Type[rank]: - ... + def rank(self) -> Type[rank]: ... @property - def rollup(self) -> Type[rollup[Any]]: - ... + def rollup(self) -> Type[rollup[Any]]: ... @property - def session_user(self) -> Type[session_user]: - ... + def session_user(self) -> Type[session_user]: ... # set ColumnElement[_T] as a separate overload, to appease mypy # which seems to not want to accept _T from _ColumnExpressionArgument. @@ -1227,8 +1182,7 @@ class _FunctionGenerator: col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: - ... + ) -> sum[_T]: ... @overload def sum( # noqa: A001 @@ -1236,8 +1190,7 @@ class _FunctionGenerator: col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: - ... + ) -> sum[_T]: ... @overload def sum( # noqa: A001 @@ -1245,24 +1198,20 @@ class _FunctionGenerator: col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: - ... + ) -> sum[_T]: ... def sum( # noqa: A001 self, col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ) -> sum[_T]: - ... + ) -> sum[_T]: ... @property - def sysdate(self) -> Type[sysdate]: - ... + def sysdate(self) -> Type[sysdate]: ... @property - def user(self) -> Type[user]: - ... + def user(self) -> Type[user]: ... # END GENERATED FUNCTION ACCESSORS @@ -1342,8 +1291,7 @@ class Function(FunctionElement[_T]): *clauses: _ColumnExpressionOrLiteralArgument[_T], type_: None = ..., packagenames: Optional[Tuple[str, ...]] = ..., - ): - ... + ): ... @overload def __init__( @@ -1352,8 +1300,7 @@ class Function(FunctionElement[_T]): *clauses: _ColumnExpressionOrLiteralArgument[Any], type_: _TypeEngineArgument[_T] = ..., packagenames: Optional[Tuple[str, ...]] = ..., - ): - ... + ): ... def __init__( self, @@ -1632,8 +1579,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]): col: ColumnElement[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ): - ... + ): ... @overload def __init__( @@ -1641,8 +1587,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]): col: _ColumnExpressionArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ): - ... + ): ... @overload def __init__( @@ -1650,8 +1595,7 @@ class ReturnTypeFromArgs(GenericFunction[_T]): col: _ColumnExpressionOrLiteralArgument[_T], *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any, - ): - ... + ): ... def __init__( self, *args: _ColumnExpressionOrLiteralArgument[Any], **kwargs: Any @@ -1771,6 +1715,7 @@ class count(GenericFunction[int]): """ + type = sqltypes.Integer() inherit_cache = True @@ -2023,6 +1968,7 @@ class cube(GenericFunction[_T]): .. versionadded:: 1.2 """ + _has_args = True inherit_cache = True @@ -2040,6 +1986,7 @@ class rollup(GenericFunction[_T]): .. versionadded:: 1.2 """ + _has_args = True inherit_cache = True @@ -2073,6 +2020,7 @@ class grouping_sets(GenericFunction[_T]): .. versionadded:: 1.2 """ + _has_args = True inherit_cache = True diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index a53ebae797..726fa2411f 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -407,9 +407,9 @@ class LambdaElement(elements.ClauseElement): while parent is not None: assert parent.closure_cache_key is not CacheConst.NO_CACHE - parent_closure_cache_key: Tuple[ - Any, ... - ] = parent.closure_cache_key + parent_closure_cache_key: Tuple[Any, ...] = ( + parent.closure_cache_key + ) cache_key = ( (parent.fn.__code__,) + parent_closure_cache_key + cache_key @@ -535,8 +535,7 @@ class StatementLambdaElement( role: Type[SQLRole], opts: Union[Type[LambdaOptions], LambdaOptions] = LambdaOptions, apply_propagate_attrs: Optional[ClauseElement] = None, - ): - ... + ): ... def __add__( self, other: _StmtLambdaElementType[Any] @@ -737,9 +736,9 @@ class AnalyzedCode: "closure_trackers", "build_py_wrappers", ) - _fns: weakref.WeakKeyDictionary[ - CodeType, AnalyzedCode - ] = weakref.WeakKeyDictionary() + _fns: weakref.WeakKeyDictionary[CodeType, AnalyzedCode] = ( + weakref.WeakKeyDictionary() + ) _generation_mutex = threading.RLock() @@ -1184,12 +1183,12 @@ class AnalyzedFunction: # rewrite the original fn. things that look like they will # become bound parameters are wrapped in a PyWrapper. - self.tracker_instrumented_fn = ( - tracker_instrumented_fn - ) = self._rewrite_code_obj( - fn, - [new_closure[name] for name in fn.__code__.co_freevars], - new_globals, + self.tracker_instrumented_fn = tracker_instrumented_fn = ( + self._rewrite_code_obj( + fn, + [new_closure[name] for name in fn.__code__.co_freevars], + new_globals, + ) ) # now invoke the function. This will give us a new SQL diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index 53fad3ea21..a5390ad6d0 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -77,8 +77,7 @@ class OperatorType(Protocol): right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... @overload def __call__( @@ -87,8 +86,7 @@ class OperatorType(Protocol): right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> Operators: - ... + ) -> Operators: ... def __call__( self, @@ -96,8 +94,7 @@ class OperatorType(Protocol): right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> Operators: - ... + ) -> Operators: ... add = cast(OperatorType, _uncast_add) @@ -466,8 +463,7 @@ class custom_op(OperatorType, Generic[_T]): right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... @overload def __call__( @@ -476,8 +472,7 @@ class custom_op(OperatorType, Generic[_T]): right: Optional[Any] = None, *other: Any, **kwargs: Any, - ) -> Operators: - ... + ) -> Operators: ... def __call__( self, @@ -545,13 +540,11 @@ class ColumnOperators(Operators): def operate( self, op: OperatorType, *other: Any, **kwargs: Any - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... def reverse_operate( self, op: OperatorType, other: Any, **kwargs: Any - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... def __lt__(self, other: Any) -> ColumnOperators: """Implement the ``<`` operator. @@ -574,8 +567,7 @@ class ColumnOperators(Operators): # https://docs.python.org/3/reference/datamodel.html#object.__hash__ if TYPE_CHECKING: - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... else: __hash__ = Operators.__hash__ @@ -623,8 +615,7 @@ class ColumnOperators(Operators): # deprecated 1.4; see #5435 if TYPE_CHECKING: - def isnot_distinct_from(self, other: Any) -> ColumnOperators: - ... + def isnot_distinct_from(self, other: Any) -> ColumnOperators: ... else: isnot_distinct_from = is_not_distinct_from @@ -964,8 +955,7 @@ class ColumnOperators(Operators): # deprecated 1.4; see #5429 if TYPE_CHECKING: - def notin_(self, other: Any) -> ColumnOperators: - ... + def notin_(self, other: Any) -> ColumnOperators: ... else: notin_ = not_in @@ -994,8 +984,7 @@ class ColumnOperators(Operators): def notlike( self, other: Any, escape: Optional[str] = None - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... else: notlike = not_like @@ -1024,8 +1013,7 @@ class ColumnOperators(Operators): def notilike( self, other: Any, escape: Optional[str] = None - ) -> ColumnOperators: - ... + ) -> ColumnOperators: ... else: notilike = not_ilike @@ -1063,8 +1051,7 @@ class ColumnOperators(Operators): # deprecated 1.4; see #5429 if TYPE_CHECKING: - def isnot(self, other: Any) -> ColumnOperators: - ... + def isnot(self, other: Any) -> ColumnOperators: ... else: isnot = is_not @@ -1728,8 +1715,7 @@ class ColumnOperators(Operators): # deprecated 1.4; see #5435 if TYPE_CHECKING: - def nullsfirst(self) -> ColumnOperators: - ... + def nullsfirst(self) -> ColumnOperators: ... else: nullsfirst = nulls_first @@ -1747,8 +1733,7 @@ class ColumnOperators(Operators): # deprecated 1.4; see #5429 if TYPE_CHECKING: - def nullslast(self) -> ColumnOperators: - ... + def nullslast(self) -> ColumnOperators: ... else: nullslast = nulls_last @@ -1968,8 +1953,7 @@ def is_true(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def istrue(a: Any) -> Any: - ... + def istrue(a: Any) -> Any: ... else: istrue = is_true @@ -1984,8 +1968,7 @@ def is_false(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def isfalse(a: Any) -> Any: - ... + def isfalse(a: Any) -> Any: ... else: isfalse = is_false @@ -2007,8 +1990,7 @@ def is_not_distinct_from(a: Any, b: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def isnot_distinct_from(a: Any, b: Any) -> Any: - ... + def isnot_distinct_from(a: Any, b: Any) -> Any: ... else: isnot_distinct_from = is_not_distinct_from @@ -2030,8 +2012,7 @@ def is_not(a: Any, b: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def isnot(a: Any, b: Any) -> Any: - ... + def isnot(a: Any, b: Any) -> Any: ... else: isnot = is_not @@ -2063,8 +2044,7 @@ def not_like_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: if TYPE_CHECKING: @_operator_fn - def notlike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: - ... + def notlike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: ... else: notlike_op = not_like_op @@ -2086,8 +2066,7 @@ def not_ilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: if TYPE_CHECKING: @_operator_fn - def notilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: - ... + def notilike_op(a: Any, b: Any, escape: Optional[str] = None) -> Any: ... else: notilike_op = not_ilike_op @@ -2109,8 +2088,9 @@ def not_between_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any: if TYPE_CHECKING: @_operator_fn - def notbetween_op(a: Any, b: Any, c: Any, symmetric: bool = False) -> Any: - ... + def notbetween_op( + a: Any, b: Any, c: Any, symmetric: bool = False + ) -> Any: ... else: notbetween_op = not_between_op @@ -2132,8 +2112,7 @@ def not_in_op(a: Any, b: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def notin_op(a: Any, b: Any) -> Any: - ... + def notin_op(a: Any, b: Any) -> Any: ... else: notin_op = not_in_op @@ -2198,8 +2177,7 @@ if TYPE_CHECKING: @_operator_fn def notstartswith_op( a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False - ) -> Any: - ... + ) -> Any: ... else: notstartswith_op = not_startswith_op @@ -2243,8 +2221,7 @@ if TYPE_CHECKING: @_operator_fn def notendswith_op( a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False - ) -> Any: - ... + ) -> Any: ... else: notendswith_op = not_endswith_op @@ -2288,8 +2265,7 @@ if TYPE_CHECKING: @_operator_fn def notcontains_op( a: Any, b: Any, escape: Optional[str] = None, autoescape: bool = False - ) -> Any: - ... + ) -> Any: ... else: notcontains_op = not_contains_op @@ -2346,8 +2322,7 @@ def not_match_op(a: Any, b: Any, **kw: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def notmatch_op(a: Any, b: Any, **kw: Any) -> Any: - ... + def notmatch_op(a: Any, b: Any, **kw: Any) -> Any: ... else: notmatch_op = not_match_op @@ -2392,8 +2367,7 @@ def nulls_first_op(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def nullsfirst_op(a: Any) -> Any: - ... + def nullsfirst_op(a: Any) -> Any: ... else: nullsfirst_op = nulls_first_op @@ -2408,8 +2382,7 @@ def nulls_last_op(a: Any) -> Any: if TYPE_CHECKING: @_operator_fn - def nullslast_op(a: Any) -> Any: - ... + def nullslast_op(a: Any) -> Any: ... else: nullslast_op = nulls_last_op diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 42c561cb4b..ae70ac3a5b 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -227,8 +227,7 @@ class AnonymizedFromClauseRole(StrictFromClauseRole): def _anonymous_fromclause( self, *, name: Optional[str] = None, flat: bool = False - ) -> FromClause: - ... + ) -> FromClause: ... class ReturnsRowsRole(SQLRole): @@ -246,8 +245,7 @@ class StatementRole(SQLRole): if TYPE_CHECKING: @util.memoized_property - def _propagate_attrs(self) -> _PropagateAttrsType: - ... + def _propagate_attrs(self) -> _PropagateAttrsType: ... else: _propagate_attrs = util.EMPTY_DICT diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 7d3d1f521e..5759982d09 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -160,15 +160,15 @@ class SchemaConst(Enum): """ -RETAIN_SCHEMA: Final[ - Literal[SchemaConst.RETAIN_SCHEMA] -] = SchemaConst.RETAIN_SCHEMA -BLANK_SCHEMA: Final[ - Literal[SchemaConst.BLANK_SCHEMA] -] = SchemaConst.BLANK_SCHEMA -NULL_UNSPECIFIED: Final[ - Literal[SchemaConst.NULL_UNSPECIFIED] -] = SchemaConst.NULL_UNSPECIFIED +RETAIN_SCHEMA: Final[Literal[SchemaConst.RETAIN_SCHEMA]] = ( + SchemaConst.RETAIN_SCHEMA +) +BLANK_SCHEMA: Final[Literal[SchemaConst.BLANK_SCHEMA]] = ( + SchemaConst.BLANK_SCHEMA +) +NULL_UNSPECIFIED: Final[Literal[SchemaConst.NULL_UNSPECIFIED]] = ( + SchemaConst.NULL_UNSPECIFIED +) def _get_table_key(name: str, schema: Optional[str]) -> str: @@ -345,12 +345,10 @@ class Table( if TYPE_CHECKING: @util.ro_non_memoized_property - def primary_key(self) -> PrimaryKeyConstraint: - ... + def primary_key(self) -> PrimaryKeyConstraint: ... @util.ro_non_memoized_property - def foreign_keys(self) -> Set[ForeignKey]: - ... + def foreign_keys(self) -> Set[ForeignKey]: ... _columns: DedupeColumnCollection[Column[Any]] @@ -402,18 +400,15 @@ class Table( if TYPE_CHECKING: @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: - ... + def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: ... @util.ro_non_memoized_property def exported_columns( self, - ) -> ReadOnlyColumnCollection[str, Column[Any]]: - ... + ) -> ReadOnlyColumnCollection[str, Column[Any]]: ... @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: - ... + def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: ... def _gen_cache_key( self, anon_map: anon_map, bindparams: List[BindParameter[Any]] @@ -2465,9 +2460,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): dialect_option_key, dialect_option_value, ) in dialect_options.items(): - column_kwargs[ - dialect_name + "_" + dialect_option_key - ] = dialect_option_value + column_kwargs[dialect_name + "_" + dialect_option_key] = ( + dialect_option_value + ) server_default = self.server_default server_onupdate = self.server_onupdate @@ -2638,19 +2633,23 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) try: c = self._constructor( - coercions.expect( - roles.TruncatedLabelRole, name if name else self.name - ) - if name_is_truncatable - else (name or self.name), + ( + coercions.expect( + roles.TruncatedLabelRole, name if name else self.name + ) + if name_is_truncatable + else (name or self.name) + ), self.type, # this may actually be ._proxy_key when the key is incoming key=key if key else name if name else self.key, primary_key=self.primary_key, nullable=self.nullable, - _proxies=list(compound_select_cols) - if compound_select_cols - else [self], + _proxies=( + list(compound_select_cols) + if compound_select_cols + else [self] + ), *fk, ) except TypeError as err: @@ -2715,9 +2714,9 @@ def insert_sentinel( return Column( name=name, type_=type_api.INTEGERTYPE if type_ is None else type_, - default=default - if default is not None - else _InsertSentinelColumnDefault(), + default=( + default if default is not None else _InsertSentinelColumnDefault() + ), _omit_from_statements=omit_from_statements, insert_sentinel=True, ) @@ -2890,7 +2889,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _resolve_colspec_argument( self, - ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]: + ) -> Tuple[ + Union[str, Column[Any]], + Optional[Column[Any]], + ]: argument = self._colspec return self._parse_colspec_argument(argument) @@ -2898,7 +2900,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): def _parse_colspec_argument( self, argument: _DDLColumnArgument, - ) -> Tuple[Union[str, Column[Any]], Optional[Column[Any]],]: + ) -> Tuple[ + Union[str, Column[Any]], + Optional[Column[Any]], + ]: _colspec = coercions.expect(roles.DDLReferredColumnRole, argument) if isinstance(_colspec, str): @@ -3181,14 +3186,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): return self._resolve_column() @overload - def _resolve_column(self, *, raiseerr: Literal[True] = ...) -> Column[Any]: - ... + def _resolve_column( + self, *, raiseerr: Literal[True] = ... + ) -> Column[Any]: ... @overload def _resolve_column( self, *, raiseerr: bool = ... - ) -> Optional[Column[Any]]: - ... + ) -> Optional[Column[Any]]: ... def _resolve_column( self, *, raiseerr: bool = True @@ -3309,18 +3314,15 @@ if TYPE_CHECKING: def default_is_sequence( obj: Optional[DefaultGenerator], - ) -> TypeGuard[Sequence]: - ... + ) -> TypeGuard[Sequence]: ... def default_is_clause_element( obj: Optional[DefaultGenerator], - ) -> TypeGuard[ColumnElementColumnDefault]: - ... + ) -> TypeGuard[ColumnElementColumnDefault]: ... def default_is_scalar( obj: Optional[DefaultGenerator], - ) -> TypeGuard[ScalarElementColumnDefault]: - ... + ) -> TypeGuard[ScalarElementColumnDefault]: ... else: default_is_sequence = operator.attrgetter("is_sequence") @@ -3420,21 +3422,18 @@ class ColumnDefault(DefaultGenerator, ABC): @overload def __new__( cls, arg: Callable[..., Any], for_update: bool = ... - ) -> CallableColumnDefault: - ... + ) -> CallableColumnDefault: ... @overload def __new__( cls, arg: ColumnElement[Any], for_update: bool = ... - ) -> ColumnElementColumnDefault: - ... + ) -> ColumnElementColumnDefault: ... # if I return ScalarElementColumnDefault here, which is what's actually # returned, mypy complains that # overloads overlap w/ incompatible return types. @overload - def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault: - ... + def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault: ... def __new__( cls, arg: Any = None, for_update: bool = False @@ -3576,8 +3575,7 @@ class ColumnElementColumnDefault(ColumnDefault): class _CallableColumnDefaultProtocol(Protocol): - def __call__(self, context: ExecutionContext) -> Any: - ... + def __call__(self, context: ExecutionContext) -> Any: ... class CallableColumnDefault(ColumnDefault): @@ -4247,8 +4245,7 @@ class ColumnCollectionMixin: def _set_parent_with_dispatch( self, parent: SchemaEventTarget, **kw: Any - ) -> None: - ... + ) -> None: ... def __init__( self, @@ -4461,9 +4458,9 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): dialect_option_key, dialect_option_value, ) in dialect_options.items(): - constraint_kwargs[ - dialect_name + "_" + dialect_option_key - ] = dialect_option_value + constraint_kwargs[dialect_name + "_" + dialect_option_key] = ( + dialect_option_value + ) assert isinstance(self.parent, Table) c = self.__class__( @@ -4886,11 +4883,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): [ x._get_colspec( schema=schema, - table_name=target_table.name - if target_table is not None - and x._table_key_within_construction() - == x.parent.table.key - else None, + table_name=( + target_table.name + if target_table is not None + and x._table_key_within_construction() + == x.parent.table.key + else None + ), _is_copy=True, ) for x in self.elements @@ -5554,9 +5553,9 @@ class MetaData(HasSchemaAttr): self.info = info self._schemas: Set[str] = set() self._sequences: Dict[str, Sequence] = {} - self._fk_memos: Dict[ - Tuple[str, Optional[str]], List[ForeignKey] - ] = collections.defaultdict(list) + self._fk_memos: Dict[Tuple[str, Optional[str]], List[ForeignKey]] = ( + collections.defaultdict(list) + ) tables: util.FacadeDict[str, Table] """A dictionary of :class:`_schema.Table` diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index ae52e5db45..4ae60b7724 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -159,12 +159,10 @@ _LabelConventionCallable = Callable[ class _JoinTargetProtocol(Protocol): @util.ro_non_memoized_property - def _from_objects(self) -> List[FromClause]: - ... + def _from_objects(self) -> List[FromClause]: ... @util.ro_non_memoized_property - def entity_namespace(self) -> _EntityNamespace: - ... + def entity_namespace(self) -> _EntityNamespace: ... _JoinTargetElement = Union["FromClause", _JoinTargetProtocol] @@ -470,9 +468,9 @@ class HasSuffixes: class HasHints: - _hints: util.immutabledict[ - Tuple[FromClause, str], str - ] = util.immutabledict() + _hints: util.immutabledict[Tuple[FromClause, str], str] = ( + util.immutabledict() + ) _statement_hints: Tuple[Tuple[str, str], ...] = () _has_hints_traverse_internals: _TraverseInternalsType = [ @@ -993,8 +991,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): def self_group( self, against: Optional[OperatorType] = None - ) -> Union[FromGrouping, Self]: - ... + ) -> Union[FromGrouping, Self]: ... class NamedFromClause(FromClause): @@ -2261,9 +2258,9 @@ class SelectsRows(ReturnsRows): repeated = False if not c._render_label_in_columns_clause: - effective_name = ( - required_label_name - ) = fallback_label_name = None + effective_name = required_label_name = fallback_label_name = ( + None + ) elif label_style_none: if TYPE_CHECKING: assert is_column_element(c) @@ -2275,9 +2272,9 @@ class SelectsRows(ReturnsRows): assert is_column_element(c) if table_qualified: - required_label_name = ( - effective_name - ) = fallback_label_name = c._tq_label + required_label_name = effective_name = ( + fallback_label_name + ) = c._tq_label else: effective_name = fallback_label_name = c._non_anon_label required_label_name = None @@ -2308,9 +2305,9 @@ class SelectsRows(ReturnsRows): else: fallback_label_name = c._anon_name_label else: - required_label_name = ( - effective_name - ) = fallback_label_name = expr_label + required_label_name = effective_name = ( + fallback_label_name + ) = expr_label if effective_name is not None: if TYPE_CHECKING: @@ -2324,13 +2321,13 @@ class SelectsRows(ReturnsRows): # different column under the same name. apply # disambiguating label if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._anon_tq_label + required_label_name = fallback_label_name = ( + c._anon_tq_label + ) else: - required_label_name = ( - fallback_label_name - ) = c._anon_name_label + required_label_name = fallback_label_name = ( + c._anon_name_label + ) if anon_for_dupe_key and required_label_name in names: # here, c._anon_tq_label is definitely unique to @@ -2345,14 +2342,14 @@ class SelectsRows(ReturnsRows): # subsequent occurrences of the column so that the # original stays non-ambiguous if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_tq_label_idx(dedupe_hash) + ) dedupe_hash += 1 else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_label_idx(dedupe_hash) + ) dedupe_hash += 1 repeated = True else: @@ -2361,14 +2358,14 @@ class SelectsRows(ReturnsRows): # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous if table_qualified: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_tq_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_tq_label_idx(dedupe_hash) + ) dedupe_hash += 1 else: - required_label_name = ( - fallback_label_name - ) = c._dedupe_anon_label_idx(dedupe_hash) + required_label_name = fallback_label_name = ( + c._dedupe_anon_label_idx(dedupe_hash) + ) dedupe_hash += 1 repeated = True else: @@ -2985,12 +2982,12 @@ class TableClause(roles.DMLTableRole, Immutable, NamedFromClause): if TYPE_CHECKING: @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... + def columns( + self, + ) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: - ... + def c(self) -> ReadOnlyColumnCollection[str, ColumnClause[Any]]: ... def __str__(self) -> str: if self.schema is not None: @@ -3697,8 +3694,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase, Generic[_SB]): if TYPE_CHECKING: - def _ungroup(self) -> _SB: - ... + def _ungroup(self) -> _SB: ... # def _generate_columns_plus_names( # self, anon_for_dupe_key: bool @@ -3918,14 +3914,12 @@ class GenerativeSelect(SelectBase, Generative): @overload def _offset_or_limit_clause_asint( self, clause: ColumnElement[Any], attrname: str - ) -> NoReturn: - ... + ) -> NoReturn: ... @overload def _offset_or_limit_clause_asint( self, clause: Optional[_OffsetLimitParam], attrname: str - ) -> Optional[int]: - ... + ) -> Optional[int]: ... def _offset_or_limit_clause_asint( self, clause: Optional[ColumnElement[Any]], attrname: str @@ -4492,8 +4486,9 @@ class SelectState(util.MemoizedSlots, CompileState): if TYPE_CHECKING: @classmethod - def get_plugin_class(cls, statement: Executable) -> Type[SelectState]: - ... + def get_plugin_class( + cls, statement: Executable + ) -> Type[SelectState]: ... def __init__( self, @@ -5192,21 +5187,17 @@ class Select( @overload def scalar_subquery( self: Select[_MAYBE_ENTITY], - ) -> ScalarSelect[Any]: - ... + ) -> ScalarSelect[Any]: ... @overload def scalar_subquery( self: Select[_NOT_ENTITY], - ) -> ScalarSelect[_NOT_ENTITY]: - ... + ) -> ScalarSelect[_NOT_ENTITY]: ... @overload - def scalar_subquery(self) -> ScalarSelect[Any]: - ... + def scalar_subquery(self) -> ScalarSelect[Any]: ... - def scalar_subquery(self) -> ScalarSelect[Any]: - ... + def scalar_subquery(self) -> ScalarSelect[Any]: ... def filter_by(self, **kwargs: Any) -> Self: r"""apply the given filtering criterion as a WHERE clause @@ -5789,20 +5780,17 @@ class Select( # statically generated** by tools/generate_sel_v1_overloads.py @overload - def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[_T0]: - ... + def with_only_columns(self, __ent0: _TCCA[_T0]) -> Select[_T0]: ... @overload def with_only_columns( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1] - ) -> Select[_T0, _T1]: - ... + ) -> Select[_T0, _T1]: ... @overload def with_only_columns( self, __ent0: _TCCA[_T0], __ent1: _TCCA[_T1], __ent2: _TCCA[_T2] - ) -> Select[_T0, _T1, _T2]: - ... + ) -> Select[_T0, _T1, _T2]: ... @overload def with_only_columns( @@ -5811,8 +5799,7 @@ class Select( __ent1: _TCCA[_T1], __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], - ) -> Select[_T0, _T1, _T2, _T3]: - ... + ) -> Select[_T0, _T1, _T2, _T3]: ... @overload def with_only_columns( @@ -5822,8 +5809,7 @@ class Select( __ent2: _TCCA[_T2], __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], - ) -> Select[_T0, _T1, _T2, _T3, _T4]: - ... + ) -> Select[_T0, _T1, _T2, _T3, _T4]: ... @overload def with_only_columns( @@ -5834,8 +5820,7 @@ class Select( __ent3: _TCCA[_T3], __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], - ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5]: - ... + ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5]: ... @overload def with_only_columns( @@ -5847,8 +5832,7 @@ class Select( __ent4: _TCCA[_T4], __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], - ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: - ... + ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6]: ... @overload def with_only_columns( @@ -5861,8 +5845,7 @@ class Select( __ent5: _TCCA[_T5], __ent6: _TCCA[_T6], __ent7: _TCCA[_T7], - ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: - ... + ) -> Select[_T0, _T1, _T2, _T3, _T4, _T5, _T6, _T7]: ... # END OVERLOADED FUNCTIONS self.with_only_columns @@ -5872,8 +5855,7 @@ class Select( *entities: _ColumnsClauseArgument[Any], maintain_column_froms: bool = False, **__kw: Any, - ) -> Select[Unpack[TupleAny]]: - ... + ) -> Select[Unpack[TupleAny]]: ... @_generative def with_only_columns( @@ -6542,14 +6524,12 @@ class ScalarSelect( @overload def self_group( self: ScalarSelect[Any], against: Optional[OperatorType] = None - ) -> ScalarSelect[Any]: - ... + ) -> ScalarSelect[Any]: ... @overload def self_group( self: ColumnElement[Any], against: Optional[OperatorType] = None - ) -> ColumnElement[Any]: - ... + ) -> ColumnElement[Any]: ... def self_group( self, against: Optional[OperatorType] = None @@ -6558,8 +6538,7 @@ class ScalarSelect( if TYPE_CHECKING: - def _ungroup(self) -> Select[Unpack[TupleAny]]: - ... + def _ungroup(self) -> Select[Unpack[TupleAny]]: ... @_generative def correlate( diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index a9e0084995..42bce99a82 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -81,7 +81,6 @@ _TE = TypeVar("_TE", bound="TypeEngine[Any]") class HasExpressionLookup(TypeEngineMixin): - """Mixin expression adaptations based on lookup tables. These rules are currently used by the numeric, integer and date types @@ -120,7 +119,6 @@ class HasExpressionLookup(TypeEngineMixin): class Concatenable(TypeEngineMixin): - """A mixin that marks a type as supporting 'concatenation', typically strings.""" @@ -169,7 +167,6 @@ class Indexable(TypeEngineMixin): class String(Concatenable, TypeEngine[str]): - """The base for all string and character types. In SQL, corresponds to VARCHAR. @@ -256,7 +253,6 @@ class String(Concatenable, TypeEngine[str]): class Text(String): - """A variably sized string type. In SQL, usually corresponds to CLOB or TEXT. In general, TEXT objects @@ -269,7 +265,6 @@ class Text(String): class Unicode(String): - """A variable length Unicode string type. The :class:`.Unicode` type is a :class:`.String` subclass that assumes @@ -323,7 +318,6 @@ class Unicode(String): class UnicodeText(Text): - """An unbounded-length Unicode string type. See :class:`.Unicode` for details on the unicode @@ -348,7 +342,6 @@ class UnicodeText(Text): class Integer(HasExpressionLookup, TypeEngine[int]): - """A type for ``int`` integers.""" __visit_name__ = "integer" @@ -356,8 +349,7 @@ class Integer(HasExpressionLookup, TypeEngine[int]): if TYPE_CHECKING: @util.ro_memoized_property - def _type_affinity(self) -> Type[Integer]: - ... + def _type_affinity(self) -> Type[Integer]: ... def get_dbapi_type(self, dbapi): return dbapi.NUMBER @@ -398,7 +390,6 @@ class Integer(HasExpressionLookup, TypeEngine[int]): class SmallInteger(Integer): - """A type for smaller ``int`` integers. Typically generates a ``SMALLINT`` in DDL, and otherwise acts like @@ -410,7 +401,6 @@ class SmallInteger(Integer): class BigInteger(Integer): - """A type for bigger ``int`` integers. Typically generates a ``BIGINT`` in DDL, and otherwise acts like @@ -425,7 +415,6 @@ _N = TypeVar("_N", bound=Union[decimal.Decimal, float]) class Numeric(HasExpressionLookup, TypeEngine[_N]): - """Base for non-integer numeric types, such as ``NUMERIC``, ``FLOAT``, ``DECIMAL``, and other variants. @@ -462,8 +451,7 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): if TYPE_CHECKING: @util.ro_memoized_property - def _type_affinity(self) -> Type[Numeric[_N]]: - ... + def _type_affinity(self) -> Type[Numeric[_N]]: ... _default_decimal_return_scale = 10 @@ -474,8 +462,7 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): scale: Optional[int] = ..., decimal_return_scale: Optional[int] = ..., asdecimal: Literal[True] = ..., - ): - ... + ): ... @overload def __init__( @@ -484,8 +471,7 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): scale: Optional[int] = ..., decimal_return_scale: Optional[int] = ..., asdecimal: Literal[False] = ..., - ): - ... + ): ... def __init__( self, @@ -581,9 +567,11 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): # we're a "numeric", DBAPI returns floats, convert. return processors.to_decimal_processor_factory( decimal.Decimal, - self.scale - if self.scale is not None - else self._default_decimal_return_scale, + ( + self.scale + if self.scale is not None + else self._default_decimal_return_scale + ), ) else: if dialect.supports_native_decimal: @@ -636,8 +624,7 @@ class Float(Numeric[_N]): precision: Optional[int] = ..., asdecimal: Literal[False] = ..., decimal_return_scale: Optional[int] = ..., - ): - ... + ): ... @overload def __init__( @@ -645,8 +632,7 @@ class Float(Numeric[_N]): precision: Optional[int] = ..., asdecimal: Literal[True] = ..., decimal_return_scale: Optional[int] = ..., - ): - ... + ): ... def __init__( self: Float[_N], @@ -754,7 +740,6 @@ class _RenderISO8601NoT: class DateTime( _RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.datetime] ): - """A type for ``datetime.datetime()`` objects. Date and time types return objects from the Python ``datetime`` @@ -818,7 +803,6 @@ class DateTime( class Date(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.date]): - """A type for ``datetime.date()`` objects.""" __visit_name__ = "date" @@ -859,7 +843,6 @@ class Date(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.date]): class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]): - """A type for ``datetime.time()`` objects.""" __visit_name__ = "time" @@ -896,7 +879,6 @@ class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]): class _Binary(TypeEngine[bytes]): - """Define base behavior for binary types.""" def __init__(self, length: Optional[int] = None): @@ -960,7 +942,6 @@ class _Binary(TypeEngine[bytes]): class LargeBinary(_Binary): - """A type for large binary byte data. The :class:`.LargeBinary` type corresponds to a large and/or unlengthed @@ -984,7 +965,6 @@ class LargeBinary(_Binary): class SchemaType(SchemaEventTarget, TypeEngineMixin): - """Add capabilities to a type which allow for schema-level DDL to be associated with a type. @@ -1122,12 +1102,12 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): ) @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload - def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: - ... + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any @@ -1887,7 +1867,6 @@ class PickleType(TypeDecorator[object]): class Boolean(SchemaType, Emulated, TypeEngine[bool]): - """A bool datatype. :class:`.Boolean` typically uses BOOLEAN or SMALLINT on the DDL side, @@ -2045,7 +2024,6 @@ class _AbstractInterval(HasExpressionLookup, TypeEngine[dt.timedelta]): class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): - """A type for ``datetime.timedelta()`` objects. The Interval type deals with ``datetime.timedelta`` objects. In @@ -2546,9 +2524,11 @@ class JSON(Indexable, TypeEngine[Any]): index, expr=self.expr, operator=operators.json_getitem_op, - bindparam_type=JSON.JSONIntIndexType - if isinstance(index, int) - else JSON.JSONStrIndexType, + bindparam_type=( + JSON.JSONIntIndexType + if isinstance(index, int) + else JSON.JSONStrIndexType + ), ) operator = operators.json_getitem_op @@ -2870,7 +2850,6 @@ class ARRAY( Indexable.Comparator[Sequence[Any]], Concatenable.Comparator[Sequence[Any]], ): - """Define comparison operations for :class:`_types.ARRAY`. More operators are available on the dialect-specific form @@ -3145,14 +3124,16 @@ class ARRAY( return collection_callable(arr) else: return collection_callable( - self._apply_item_processor( - x, - itemproc, - dim - 1 if dim is not None else None, - collection_callable, + ( + self._apply_item_processor( + x, + itemproc, + dim - 1 if dim is not None else None, + collection_callable, + ) + if x is not None + else None ) - if x is not None - else None for x in arr ) @@ -3203,7 +3184,6 @@ class TupleType(TypeEngine[TupleAny]): class REAL(Float[_N]): - """The SQL REAL type. .. seealso:: @@ -3216,7 +3196,6 @@ class REAL(Float[_N]): class FLOAT(Float[_N]): - """The SQL FLOAT type. .. seealso:: @@ -3257,7 +3236,6 @@ class DOUBLE_PRECISION(Double[_N]): class NUMERIC(Numeric[_N]): - """The SQL NUMERIC type. .. seealso:: @@ -3270,7 +3248,6 @@ class NUMERIC(Numeric[_N]): class DECIMAL(Numeric[_N]): - """The SQL DECIMAL type. .. seealso:: @@ -3283,7 +3260,6 @@ class DECIMAL(Numeric[_N]): class INTEGER(Integer): - """The SQL INT or INTEGER type. .. seealso:: @@ -3299,7 +3275,6 @@ INT = INTEGER class SMALLINT(SmallInteger): - """The SQL SMALLINT type. .. seealso:: @@ -3312,7 +3287,6 @@ class SMALLINT(SmallInteger): class BIGINT(BigInteger): - """The SQL BIGINT type. .. seealso:: @@ -3325,7 +3299,6 @@ class BIGINT(BigInteger): class TIMESTAMP(DateTime): - """The SQL TIMESTAMP type. :class:`_types.TIMESTAMP` datatypes have support for timezone @@ -3355,35 +3328,30 @@ class TIMESTAMP(DateTime): class DATETIME(DateTime): - """The SQL DATETIME type.""" __visit_name__ = "DATETIME" class DATE(Date): - """The SQL DATE type.""" __visit_name__ = "DATE" class TIME(Time): - """The SQL TIME type.""" __visit_name__ = "TIME" class TEXT(Text): - """The SQL TEXT type.""" __visit_name__ = "TEXT" class CLOB(Text): - """The CLOB type. This type is found in Oracle and Informix. @@ -3393,63 +3361,54 @@ class CLOB(Text): class VARCHAR(String): - """The SQL VARCHAR type.""" __visit_name__ = "VARCHAR" class NVARCHAR(Unicode): - """The SQL NVARCHAR type.""" __visit_name__ = "NVARCHAR" class CHAR(String): - """The SQL CHAR type.""" __visit_name__ = "CHAR" class NCHAR(Unicode): - """The SQL NCHAR type.""" __visit_name__ = "NCHAR" class BLOB(LargeBinary): - """The SQL BLOB type.""" __visit_name__ = "BLOB" class BINARY(_Binary): - """The SQL BINARY type.""" __visit_name__ = "BINARY" class VARBINARY(_Binary): - """The SQL VARBINARY type.""" __visit_name__ = "VARBINARY" class BOOLEAN(Boolean): - """The SQL BOOLEAN type.""" __visit_name__ = "BOOLEAN" class NullType(TypeEngine[None]): - """An unknown type. :class:`.NullType` is used as a default type for those cases where @@ -3534,7 +3493,6 @@ _UUID_RETURN = TypeVar("_UUID_RETURN", str, _python_UUID) class Uuid(Emulated, TypeEngine[_UUID_RETURN]): - """Represent a database agnostic UUID datatype. For backends that have no "native" UUID datatype, the value will @@ -3594,16 +3552,14 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): self: Uuid[_python_UUID], as_uuid: Literal[True] = ..., native_uuid: bool = ..., - ): - ... + ): ... @overload def __init__( self: Uuid[str], as_uuid: Literal[False] = ..., native_uuid: bool = ..., - ): - ... + ): ... def __init__(self, as_uuid: bool = True, native_uuid: bool = True): """Construct a :class:`_sqltypes.Uuid` type. @@ -3726,7 +3682,6 @@ class Uuid(Emulated, TypeEngine[_UUID_RETURN]): class UUID(Uuid[_UUID_RETURN], type_api.NativeForEmulated): - """Represent the SQL UUID type. This is the SQL-native form of the :class:`_types.Uuid` database agnostic @@ -3750,12 +3705,10 @@ class UUID(Uuid[_UUID_RETURN], type_api.NativeForEmulated): __visit_name__ = "UUID" @overload - def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...): - ... + def __init__(self: UUID[_python_UUID], as_uuid: Literal[True] = ...): ... @overload - def __init__(self: UUID[str], as_uuid: Literal[False] = ...): - ... + def __init__(self: UUID[str], as_uuid: Literal[False] = ...): ... def __init__(self, as_uuid: bool = True): """Construct a :class:`_sqltypes.UUID` type. diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 6c44d52175..3ca3caf9e2 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -80,16 +80,13 @@ class HasShallowCopy(HasTraverseInternals): if typing.TYPE_CHECKING: - def _generated_shallow_copy_traversal(self, other: Self) -> None: - ... + def _generated_shallow_copy_traversal(self, other: Self) -> None: ... def _generated_shallow_from_dict_traversal( self, d: Dict[str, Any] - ) -> None: - ... + ) -> None: ... - def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: - ... + def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ... @classmethod def _generate_shallow_copy( @@ -312,9 +309,11 @@ class _CopyInternalsTraversal(HasTraversalDispatch): # sequence of 2-tuples return [ ( - clone(key, **kw) - if hasattr(key, "__clause_element__") - else key, + ( + clone(key, **kw) + if hasattr(key, "__clause_element__") + else key + ), clone(value, **kw), ) for key, value in element @@ -336,9 +335,11 @@ class _CopyInternalsTraversal(HasTraversalDispatch): def copy(elem): if isinstance(elem, (list, tuple)): return [ - clone(value, **kw) - if hasattr(value, "__clause_element__") - else value + ( + clone(value, **kw) + if hasattr(value, "__clause_element__") + else value + ) for value in elem ] elif isinstance(elem, dict): diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 8b79a2a749..a56911fb9a 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -83,23 +83,19 @@ _NO_VALUE_IN_LIST = _NoValueInList.NO_VALUE_IN_LIST class _LiteralProcessorType(Protocol[_T_co]): - def __call__(self, value: Any) -> str: - ... + def __call__(self, value: Any) -> str: ... class _BindProcessorType(Protocol[_T_con]): - def __call__(self, value: Optional[_T_con]) -> Any: - ... + def __call__(self, value: Optional[_T_con]) -> Any: ... class _ResultProcessorType(Protocol[_T_co]): - def __call__(self, value: Any) -> Optional[_T_co]: - ... + def __call__(self, value: Any) -> Optional[_T_co]: ... class _SentinelProcessorType(Protocol[_T_co]): - def __call__(self, value: Any) -> Optional[_T_co]: - ... + def __call__(self, value: Any) -> Optional[_T_co]: ... class _BaseTypeMemoDict(TypedDict): @@ -115,8 +111,9 @@ class _TypeMemoDict(_BaseTypeMemoDict, total=False): class _ComparatorFactory(Protocol[_T]): - def __call__(self, expr: ColumnElement[_T]) -> TypeEngine.Comparator[_T]: - ... + def __call__( + self, expr: ColumnElement[_T] + ) -> TypeEngine.Comparator[_T]: ... class TypeEngine(Visitable, Generic[_T]): @@ -300,9 +297,9 @@ class TypeEngine(Visitable, Generic[_T]): """ - _variant_mapping: util.immutabledict[ - str, TypeEngine[Any] - ] = util.EMPTY_DICT + _variant_mapping: util.immutabledict[str, TypeEngine[Any]] = ( + util.EMPTY_DICT + ) def evaluates_none(self) -> Self: """Return a copy of this type which has the @@ -1002,9 +999,11 @@ class TypeEngine(Visitable, Generic[_T]): return (self.__class__,) + tuple( ( k, - self.__dict__[k]._static_cache_key - if isinstance(self.__dict__[k], TypeEngine) - else self.__dict__[k], + ( + self.__dict__[k]._static_cache_key + if isinstance(self.__dict__[k], TypeEngine) + else self.__dict__[k] + ), ) for k in names if k in self.__dict__ @@ -1013,12 +1012,12 @@ class TypeEngine(Visitable, Generic[_T]): ) @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload - def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: - ... + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any @@ -1111,26 +1110,21 @@ class TypeEngineMixin: @util.memoized_property def _static_cache_key( self, - ) -> Union[CacheConst, Tuple[Any, ...]]: - ... + ) -> Union[CacheConst, Tuple[Any, ...]]: ... @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload def adapt( self, cls: Type[TypeEngineMixin], **kw: Any - ) -> TypeEngine[Any]: - ... + ) -> TypeEngine[Any]: ... def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any - ) -> TypeEngine[Any]: - ... + ) -> TypeEngine[Any]: ... - def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: - ... + def dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ... class ExternalType(TypeEngineMixin): @@ -1432,12 +1426,12 @@ class Emulated(TypeEngineMixin): return super().adapt(impltype, **kw) @overload - def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: - ... + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: ... @overload - def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: - ... + def adapt( + self, cls: Type[TypeEngineMixin], **kw: Any + ) -> TypeEngine[Any]: ... def adapt( self, cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], **kw: Any @@ -2283,13 +2277,13 @@ class Variant(TypeDecorator[_T]): @overload -def to_instance(typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any) -> _TE: - ... +def to_instance( + typeobj: Union[Type[_TE], _TE], *arg: Any, **kw: Any +) -> _TE: ... @overload -def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]: - ... +def to_instance(typeobj: None, *arg: Any, **kw: Any) -> TypeEngine[None]: ... def to_instance( diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 53e5726722..737ee6822d 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -351,9 +351,9 @@ def find_tables( ] = _visitors["lateral"] = tables.append if include_crud: - _visitors["insert"] = _visitors["update"] = _visitors[ - "delete" - ] = lambda ent: tables.append(ent.table) + _visitors["insert"] = _visitors["update"] = _visitors["delete"] = ( + lambda ent: tables.append(ent.table) + ) if check_columns: @@ -881,8 +881,7 @@ def reduce_columns( columns: Iterable[ColumnElement[Any]], *clauses: Optional[ClauseElement], **kw: bool, -) -> Sequence[ColumnElement[Any]]: - ... +) -> Sequence[ColumnElement[Any]]: ... @overload @@ -890,8 +889,7 @@ def reduce_columns( columns: _SelectIterable, *clauses: Optional[ClauseElement], **kw: bool, -) -> Sequence[Union[ColumnElement[Any], TextClause]]: - ... +) -> Sequence[Union[ColumnElement[Any], TextClause]]: ... def reduce_columns( @@ -1102,8 +1100,7 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): if TYPE_CHECKING: @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... # note this specializes the ReplacingExternalTraversal.traverse() # method to state @@ -1114,13 +1111,11 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): # FromClause but Mypy is not accepting those as compatible with # the base ReplacingExternalTraversal @overload - def traverse(self, obj: _ET) -> _ET: - ... + def traverse(self, obj: _ET) -> _ET: ... def traverse( self, obj: Optional[ExternallyTraversible] - ) -> Optional[ExternallyTraversible]: - ... + ) -> Optional[ExternallyTraversible]: ... def _corresponding_column( self, col, require_embedded, _seen=util.EMPTY_SET @@ -1222,23 +1217,18 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): class _ColumnLookup(Protocol): @overload - def __getitem__(self, key: None) -> None: - ... + def __getitem__(self, key: None) -> None: ... @overload - def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: - ... + def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ... @overload - def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: - ... + def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ... @overload - def __getitem__(self, key: _ET) -> _ET: - ... + def __getitem__(self, key: _ET) -> _ET: ... - def __getitem__(self, key: Any) -> Any: - ... + def __getitem__(self, key: Any) -> Any: ... class ColumnAdapter(ClauseAdapter): @@ -1336,12 +1326,10 @@ class ColumnAdapter(ClauseAdapter): return ac @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: _ET) -> _ET: - ... + def traverse(self, obj: _ET) -> _ET: ... def traverse( self, obj: Optional[ExternallyTraversible] @@ -1356,8 +1344,7 @@ class ColumnAdapter(ClauseAdapter): if TYPE_CHECKING: @property - def visitor_iterator(self) -> Iterator[ColumnAdapter]: - ... + def visitor_iterator(self) -> Iterator[ColumnAdapter]: ... adapt_clause = traverse adapt_list = ClauseAdapter.copy_and_process diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 5d77d51082..05025909a4 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -72,8 +72,7 @@ __all__ = [ class _CompilerDispatchType(Protocol): - def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any: - ... + def __call__(_self, self: Visitable, visitor: Any, **kw: Any) -> Any: ... class Visitable: @@ -100,8 +99,7 @@ class Visitable: if typing.TYPE_CHECKING: - def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str: - ... + def _compiler_dispatch(self, visitor: Any, **kw: Any) -> str: ... def __init_subclass__(cls) -> None: if "__visit_name__" in cls.__dict__: @@ -493,8 +491,7 @@ class HasTraverseInternals: class _InternalTraversalDispatchType(Protocol): - def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any: - ... + def __call__(s, self: object, visitor: HasTraversalDispatch) -> Any: ... class HasTraversalDispatch: @@ -602,13 +599,11 @@ class ExternallyTraversible(HasTraverseInternals, Visitable): if typing.TYPE_CHECKING: - def _annotate(self, values: _AnnotationDict) -> Self: - ... + def _annotate(self, values: _AnnotationDict) -> Self: ... def get_children( self, *, omit_attrs: Tuple[str, ...] = (), **kw: Any - ) -> Iterable[ExternallyTraversible]: - ... + ) -> Iterable[ExternallyTraversible]: ... def _clone(self, **kw: Any) -> Self: """clone this element""" @@ -638,13 +633,11 @@ _TraverseCallableType = Callable[[_ET], None] class _CloneCallableType(Protocol): - def __call__(self, element: _ET, **kw: Any) -> _ET: - ... + def __call__(self, element: _ET, **kw: Any) -> _ET: ... class _TraverseTransformCallableType(Protocol[_ET]): - def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]: - ... + def __call__(self, element: _ET, **kw: Any) -> Optional[_ET]: ... _ExtT = TypeVar("_ExtT", bound="ExternalTraversal") @@ -680,12 +673,12 @@ class ExternalTraversal(util.MemoizedSlots): return iterate(obj, self.__traverse_options__) @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: - ... + def traverse( + self, obj: ExternallyTraversible + ) -> ExternallyTraversible: ... def traverse( self, obj: Optional[ExternallyTraversible] @@ -746,12 +739,12 @@ class CloningExternalTraversal(ExternalTraversal): return [self.traverse(x) for x in list_] @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: - ... + def traverse( + self, obj: ExternallyTraversible + ) -> ExternallyTraversible: ... def traverse( self, obj: Optional[ExternallyTraversible] @@ -786,12 +779,12 @@ class ReplacingExternalTraversal(CloningExternalTraversal): return None @overload - def traverse(self, obj: Literal[None]) -> None: - ... + def traverse(self, obj: Literal[None]) -> None: ... @overload - def traverse(self, obj: ExternallyTraversible) -> ExternallyTraversible: - ... + def traverse( + self, obj: ExternallyTraversible + ) -> ExternallyTraversible: ... def traverse( self, obj: Optional[ExternallyTraversible] @@ -866,8 +859,7 @@ def traverse_using( iterator: Iterable[ExternallyTraversible], obj: Literal[None], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> None: - ... +) -> None: ... @overload @@ -875,8 +867,7 @@ def traverse_using( iterator: Iterable[ExternallyTraversible], obj: ExternallyTraversible, visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: - ... +) -> ExternallyTraversible: ... def traverse_using( @@ -920,8 +911,7 @@ def traverse( obj: Literal[None], opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> None: - ... +) -> None: ... @overload @@ -929,8 +919,7 @@ def traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> ExternallyTraversible: - ... +) -> ExternallyTraversible: ... def traverse( @@ -975,8 +964,7 @@ def cloned_traverse( obj: Literal[None], opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> None: - ... +) -> None: ... # a bit of controversy here, as the clone of the lead element @@ -988,8 +976,7 @@ def cloned_traverse( obj: _ET, opts: Mapping[str, Any], visitors: Mapping[str, _TraverseCallableType[Any]], -) -> _ET: - ... +) -> _ET: ... def cloned_traverse( @@ -1088,8 +1075,7 @@ def replacement_traverse( obj: Literal[None], opts: Mapping[str, Any], replace: _TraverseTransformCallableType[Any], -) -> None: - ... +) -> None: ... @overload @@ -1097,8 +1083,7 @@ def replacement_traverse( obj: _CE, opts: Mapping[str, Any], replace: _TraverseTransformCallableType[Any], -) -> _CE: - ... +) -> _CE: ... @overload @@ -1106,8 +1091,7 @@ def replacement_traverse( obj: ExternallyTraversible, opts: Mapping[str, Any], replace: _TraverseTransformCallableType[Any], -) -> ExternallyTraversible: - ... +) -> ExternallyTraversible: ... def replacement_traverse( diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e061f269a8..ae4d335a96 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -88,9 +88,9 @@ class CompiledSQL(SQLMatchRule): dialect.supports_default_metavalue = True if self.enable_returning: - dialect.insert_returning = ( - dialect.update_returning - ) = dialect.delete_returning = True + dialect.insert_returning = dialect.update_returning = ( + dialect.delete_returning + ) = True dialect.use_insertmanyvalues = True dialect.supports_multivalues_insert = True dialect.update_returning_multifrom = True diff --git a/lib/sqlalchemy/testing/config.py b/lib/sqlalchemy/testing/config.py index 19e1e4bcc2..f2292224e8 100644 --- a/lib/sqlalchemy/testing/config.py +++ b/lib/sqlalchemy/testing/config.py @@ -176,8 +176,7 @@ class Variation: if typing.TYPE_CHECKING: - def __getattr__(self, key: str) -> bool: - ... + def __getattr__(self, key: str) -> bool: ... @property def name(self): @@ -268,9 +267,11 @@ def variation(argname_or_fn, cases=None): else: argname = argname_or_fn cases_plus_limitations = [ - entry - if (isinstance(entry, tuple) and len(entry) == 2) - else (entry, None) + ( + entry + if (isinstance(entry, tuple) and len(entry) == 2) + else (entry, None) + ) for entry in cases ] @@ -279,9 +280,11 @@ def variation(argname_or_fn, cases=None): ) return combinations( *[ - (variation._name, variation, limitation) - if limitation is not None - else (variation._name, variation) + ( + (variation._name, variation, limitation) + if limitation is not None + else (variation._name, variation) + ) for variation, (case, limitation) in zip( variations, cases_plus_limitations ) diff --git a/lib/sqlalchemy/testing/engines.py b/lib/sqlalchemy/testing/engines.py index 7e06366836..6b3f32c2b7 100644 --- a/lib/sqlalchemy/testing/engines.py +++ b/lib/sqlalchemy/testing/engines.py @@ -289,8 +289,7 @@ def testing_engine( options: Optional[Dict[str, Any]] = None, asyncio: Literal[False] = False, transfer_staticpool: bool = False, -) -> Engine: - ... +) -> Engine: ... @typing.overload @@ -299,8 +298,7 @@ def testing_engine( options: Optional[Dict[str, Any]] = None, asyncio: Literal[True] = True, transfer_staticpool: bool = False, -) -> AsyncEngine: - ... +) -> AsyncEngine: ... def testing_engine( diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 7dca583f8e..addc4b7594 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -205,12 +205,12 @@ class Predicate: if negate: bool_ = not negate return self.description % { - "driver": config.db.url.get_driver_name() - if config - else "", - "database": config.db.url.get_backend_name() - if config - else "", + "driver": ( + config.db.url.get_driver_name() if config else "" + ), + "database": ( + config.db.url.get_backend_name() if config else "" + ), "doesnt_support": "doesn't support" if bool_ else "does support", "does_support": "does support" if bool_ else "doesn't support", } diff --git a/lib/sqlalchemy/testing/fixtures/mypy.py b/lib/sqlalchemy/testing/fixtures/mypy.py index 730c7bdc23..149df9f7d4 100644 --- a/lib/sqlalchemy/testing/fixtures/mypy.py +++ b/lib/sqlalchemy/testing/fixtures/mypy.py @@ -86,9 +86,11 @@ class MypyTest(TestBase): "--config-file", os.path.join( use_cachedir, - "sqla_mypy_config.cfg" - if use_plugin - else "plain_mypy_config.cfg", + ( + "sqla_mypy_config.cfg" + if use_plugin + else "plain_mypy_config.cfg" + ), ), ] @@ -208,9 +210,11 @@ class MypyTest(TestBase): # skip first character which could be capitalized # "List item x not found" type of message expected_msg = expected_msg[0] + re.sub( - r"\b(List|Tuple|Dict|Set)\b" - if is_type - else r"\b(List|Tuple|Dict|Set|Type)\b", + ( + r"\b(List|Tuple|Dict|Set)\b" + if is_type + else r"\b(List|Tuple|Dict|Set|Type)\b" + ), lambda m: m.group(1).lower(), expected_msg[1:], ) diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index a7cb4069d0..1a4d4bb30a 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -675,9 +675,9 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): "i": lambda obj: obj, "r": repr, "s": str, - "n": lambda obj: obj.__name__ - if hasattr(obj, "__name__") - else type(obj).__name__, + "n": lambda obj: ( + obj.__name__ if hasattr(obj, "__name__") else type(obj).__name__ + ), } def combinations(self, *arg_sets, **kw): diff --git a/lib/sqlalchemy/testing/suite/test_insert.py b/lib/sqlalchemy/testing/suite/test_insert.py index cc30945cab..8de60e43dc 100644 --- a/lib/sqlalchemy/testing/suite/test_insert.py +++ b/lib/sqlalchemy/testing/suite/test_insert.py @@ -492,9 +492,11 @@ class ReturningTest(fixtures.TablesTest): t.c.value, sort_by_parameter_order=bool(sort_by_parameter_order), ), - [{"value": value} for i in range(10)] - if multiple_rows - else {"value": value}, + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), ) if multiple_rows: @@ -596,9 +598,11 @@ class ReturningTest(fixtures.TablesTest): t.c.value, sort_by_parameter_order=bool(sort_by_parameter_order), ), - [{"value": value} for i in range(10)] - if multiple_rows - else {"value": value}, + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), ) if multiple_rows: diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index f0d4dca1c2..f257d2fcbc 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -1090,9 +1090,9 @@ class ComponentReflectionTest(ComparesTables, OneConnectionTablesTest): "referred_columns": ref_col, "name": name, "options": mock.ANY, - "referred_schema": ref_schema - if ref_schema is not None - else tt(), + "referred_schema": ( + ref_schema if ref_schema is not None else tt() + ), "referred_table": ref_table, "comment": comment, } diff --git a/lib/sqlalchemy/testing/suite/test_update_delete.py b/lib/sqlalchemy/testing/suite/test_update_delete.py index a46d8fad87..fd4757f9a4 100644 --- a/lib/sqlalchemy/testing/suite/test_update_delete.py +++ b/lib/sqlalchemy/testing/suite/test_update_delete.py @@ -93,9 +93,11 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): eq_( connection.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, "d1"), (2, "d2_new"), (3, "d3")] - if criteria.rows - else [(1, "d1"), (2, "d2"), (3, "d3")], + ( + [(1, "d1"), (2, "d2_new"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), ) @testing.variation("criteria", ["rows", "norows", "emptyin"]) @@ -126,9 +128,11 @@ class SimpleUpdateDeleteTest(fixtures.TablesTest): eq_( connection.execute(t.select().order_by(t.c.id)).fetchall(), - [(1, "d1"), (3, "d3")] - if criteria.rows - else [(1, "d1"), (2, "d2"), (3, "d3")], + ( + [(1, "d1"), (3, "d3")] + if criteria.rows + else [(1, "d1"), (2, "d2"), (3, "d3")] + ), ) diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index aea6439c25..5dd0179505 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -227,12 +227,10 @@ class Properties(Generic[_T]): self._data.update(value) @overload - def get(self, key: str) -> Optional[_T]: - ... + def get(self, key: str) -> Optional[_T]: ... @overload - def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: - ... + def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]: ... def get( self, key: str, default: Optional[Union[_DT, _T]] = None @@ -520,12 +518,10 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): return self._counter @overload - def get(self, key: _KT) -> Optional[_VT]: - ... + def get(self, key: _KT) -> Optional[_VT]: ... @overload - def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: - ... + def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]: ... def get( self, key: _KT, default: Optional[Union[_VT, _T]] = None @@ -587,13 +583,11 @@ class LRUCache(typing.MutableMapping[_KT, _VT]): class _CreateFuncType(Protocol[_T_co]): - def __call__(self) -> _T_co: - ... + def __call__(self) -> _T_co: ... class _ScopeFuncType(Protocol): - def __call__(self) -> Any: - ... + def __call__(self) -> Any: ... class ScopedRegistry(Generic[_T]): diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 010d90e62e..e05626eaf7 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -59,11 +59,9 @@ class ReadOnlyContainer: class ImmutableDictBase(ReadOnlyContainer, Dict[_KT, _VT]): if TYPE_CHECKING: - def __new__(cls, *args: Any) -> Self: - ... + def __new__(cls, *args: Any) -> Self: ... - def __init__(cls, *args: Any): - ... + def __init__(cls, *args: Any): ... def _readonly(self, *arg: Any, **kw: Any) -> NoReturn: self._immutable() diff --git a/lib/sqlalchemy/util/concurrency.py b/lib/sqlalchemy/util/concurrency.py index 53490f23c8..25ea27ea8c 100644 --- a/lib/sqlalchemy/util/concurrency.py +++ b/lib/sqlalchemy/util/concurrency.py @@ -123,8 +123,7 @@ if TYPE_CHECKING: def iscoroutine( awaitable: Awaitable[_T_co], - ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: - ... + ) -> TypeGuard[Coroutine[Any, Any, _T_co]]: ... else: iscoroutine = asyncio.iscoroutine diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 91d9562aae..6c7aead0a2 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -411,15 +411,13 @@ def get_cls_kwargs( *, _set: Optional[Set[str]] = None, raiseerr: Literal[True] = ..., -) -> Set[str]: - ... +) -> Set[str]: ... @overload def get_cls_kwargs( cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False -) -> Optional[Set[str]]: - ... +) -> Optional[Set[str]]: ... def get_cls_kwargs( @@ -1092,23 +1090,19 @@ class generic_fn_descriptor(Generic[_T_co]): self.__name__ = fget.__name__ @overload - def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: - ... + def __get__(self: _GFD, obj: None, cls: Any) -> _GFD: ... @overload - def __get__(self, obj: object, cls: Any) -> _T_co: - ... + def __get__(self, obj: object, cls: Any) -> _T_co: ... def __get__(self: _GFD, obj: Any, cls: Any) -> Union[_GFD, _T_co]: raise NotImplementedError() if TYPE_CHECKING: - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... def _reset(self, obj: Any) -> None: raise NotImplementedError() @@ -1247,12 +1241,10 @@ class HasMemoized: self.__name__ = fget.__name__ @overload - def __get__(self: _MA, obj: None, cls: Any) -> _MA: - ... + def __get__(self: _MA, obj: None, cls: Any) -> _MA: ... @overload - def __get__(self, obj: Any, cls: Any) -> _T: - ... + def __get__(self, obj: Any, cls: Any) -> _T: ... def __get__(self, obj, cls): if obj is None: diff --git a/lib/sqlalchemy/util/queue.py b/lib/sqlalchemy/util/queue.py index 3545afef38..149629dc2c 100644 --- a/lib/sqlalchemy/util/queue.py +++ b/lib/sqlalchemy/util/queue.py @@ -54,8 +54,7 @@ class QueueCommon(Generic[_T]): maxsize: int use_lifo: bool - def __init__(self, maxsize: int = 0, use_lifo: bool = False): - ... + def __init__(self, maxsize: int = 0, use_lifo: bool = False): ... def empty(self) -> bool: raise NotImplementedError() diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index a3e9397640..3a869752b2 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -114,11 +114,9 @@ class GenericProtocol(Protocol[_T]): # copied from TypeShed, required in order to implement # MutableMapping.update() class SupportsKeysAndGetItem(Protocol[_KT, _VT_co]): - def keys(self) -> Iterable[_KT]: - ... + def keys(self) -> Iterable[_KT]: ... - def __getitem__(self, __k: _KT) -> _VT_co: - ... + def __getitem__(self, __k: _KT) -> _VT_co: ... # work around https://github.com/microsoft/pyright/issues/3025 @@ -344,20 +342,17 @@ def is_fwd_ref( @overload -def de_optionalize_union_types(type_: str) -> str: - ... +def de_optionalize_union_types(type_: str) -> str: ... @overload -def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: - ... +def de_optionalize_union_types(type_: Type[Any]) -> Type[Any]: ... @overload def de_optionalize_union_types( type_: _AnnotationScanType, -) -> _AnnotationScanType: - ... +) -> _AnnotationScanType: ... def de_optionalize_union_types( @@ -501,14 +496,11 @@ def _get_type_name(type_: Type[Any]) -> str: class DescriptorProto(Protocol): - def __get__(self, instance: object, owner: Any) -> Any: - ... + def __get__(self, instance: object, owner: Any) -> Any: ... - def __set__(self, instance: Any, value: Any) -> None: - ... + def __set__(self, instance: Any, value: Any) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... _DESC = TypeVar("_DESC", bound=DescriptorProto) @@ -527,14 +519,11 @@ class DescriptorReference(Generic[_DESC]): if TYPE_CHECKING: - def __get__(self, instance: object, owner: Any) -> _DESC: - ... + def __get__(self, instance: object, owner: Any) -> _DESC: ... - def __set__(self, instance: Any, value: _DESC) -> None: - ... + def __set__(self, instance: Any, value: _DESC) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... _DESC_co = TypeVar("_DESC_co", bound=DescriptorProto, covariant=True) @@ -550,14 +539,11 @@ class RODescriptorReference(Generic[_DESC_co]): if TYPE_CHECKING: - def __get__(self, instance: object, owner: Any) -> _DESC_co: - ... + def __get__(self, instance: object, owner: Any) -> _DESC_co: ... - def __set__(self, instance: Any, value: Any) -> NoReturn: - ... + def __set__(self, instance: Any, value: Any) -> NoReturn: ... - def __delete__(self, instance: Any) -> NoReturn: - ... + def __delete__(self, instance: Any) -> NoReturn: ... _FN = TypeVar("_FN", bound=Optional[Callable[..., Any]]) @@ -574,14 +560,11 @@ class CallableReference(Generic[_FN]): if TYPE_CHECKING: - def __get__(self, instance: object, owner: Any) -> _FN: - ... + def __get__(self, instance: object, owner: Any) -> _FN: ... - def __set__(self, instance: Any, value: _FN) -> None: - ... + def __set__(self, instance: Any, value: _FN) -> None: ... - def __delete__(self, instance: Any) -> None: - ... + def __delete__(self, instance: Any) -> None: ... # $def ro_descriptor_reference(fn: Callable[]) diff --git a/setup.cfg b/setup.cfg index 2ff94822c6..dfeed37721 100644 --- a/setup.cfg +++ b/setup.cfg @@ -104,7 +104,7 @@ enable-extensions = G ignore = A003, D, - E203,E305,E711,E712,E721,E722,E741, + E203,E305,E701,E704,E711,E712,E721,E722,E741, N801,N802,N806, RST304,RST303,RST299,RST399, W503,W504,W601 diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 3a5a200d80..8bf2bfa180 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -142,7 +142,6 @@ class MergeTest(NoCache, fixtures.MappedTest): class LoadManyToOneFromIdentityTest(fixtures.MappedTest): - """test overhead associated with many-to-one fetches. Prior to the refactor of LoadLazyAttribute and diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 74867ccbe2..b5ea40b120 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -702,9 +702,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select(tbl), "SELECT %(name)s.test.id FROM %(name)s.test" % {"name": rendered_schema}, - schema_translate_map={None: schemaname} - if use_schema_translate - else None, + schema_translate_map=( + {None: schemaname} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -777,16 +777,20 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "test", metadata, Column("id", Integer, primary_key=True), - schema=quoted_name("Foo.dbo", True) - if not use_schema_translate - else None, + schema=( + quoted_name("Foo.dbo", True) + if not use_schema_translate + else None + ), ) self.assert_compile( select(tbl), "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test", - schema_translate_map={None: quoted_name("Foo.dbo", True)} - if use_schema_translate - else None, + schema_translate_map=( + {None: quoted_name("Foo.dbo", True)} + if use_schema_translate + else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -804,9 +808,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select(tbl), "SELECT [Foo.dbo].test.id FROM [Foo.dbo].test", - schema_translate_map={None: "[Foo.dbo]"} - if use_schema_translate - else None, + schema_translate_map=( + {None: "[Foo.dbo]"} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -824,9 +828,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select(tbl), "SELECT foo.dbo.test.id FROM foo.dbo.test", - schema_translate_map={None: "foo.dbo"} - if use_schema_translate - else None, + schema_translate_map=( + {None: "foo.dbo"} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) @@ -842,9 +846,9 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select(tbl), "SELECT [Foo].dbo.test.id FROM [Foo].dbo.test", - schema_translate_map={None: "Foo.dbo"} - if use_schema_translate - else None, + schema_translate_map=( + {None: "Foo.dbo"} if use_schema_translate else None + ), render_schema_translate=True if use_schema_translate else False, ) diff --git a/test/dialect/mssql/test_reflection.py b/test/dialect/mssql/test_reflection.py index ae2b7662ef..7222ba47ae 100644 --- a/test/dialect/mssql/test_reflection.py +++ b/test/dialect/mssql/test_reflection.py @@ -1028,10 +1028,13 @@ class ReflectHugeViewTest(fixtures.TablesTest): for i in range(col_num) ], ) - cls.view_str = ( - view_str - ) = "CREATE VIEW huge_named_view AS SELECT %s FROM base_table" % ( - ",".join("long_named_column_number_%d" % i for i in range(col_num)) + cls.view_str = view_str = ( + "CREATE VIEW huge_named_view AS SELECT %s FROM base_table" + % ( + ",".join( + "long_named_column_number_%d" % i for i in range(col_num) + ) + ) ) assert len(view_str) > 4000 diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index b2e05d951d..05b4b68542 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -567,7 +567,6 @@ class CompileTest(ReservedWordFixture, fixtures.TestBase, AssertsCompiledSQL): class SQLTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests MySQL-dialect specific compilation.""" __dialect__ = mysql.dialect() diff --git a/test/dialect/mysql/test_for_update.py b/test/dialect/mysql/test_for_update.py index 5717a32997..0895a098d1 100644 --- a/test/dialect/mysql/test_for_update.py +++ b/test/dialect/mysql/test_for_update.py @@ -3,6 +3,7 @@ See #4246 """ + import contextlib from sqlalchemy import Column diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index f890b7ba9c..005e60eaa1 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -3228,7 +3228,6 @@ class InsertOnConflictTest(fixtures.TablesTest, AssertsCompiledSQL): class DistinctOnTest(fixtures.MappedTest, AssertsCompiledSQL): - """Test 'DISTINCT' with SQL expression language and orm.Query with an emphasis on PG's 'DISTINCT ON' syntax. @@ -3382,7 +3381,6 @@ class DistinctOnTest(fixtures.MappedTest, AssertsCompiledSQL): class FullTextSearchTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests for full text searching""" __dialect__ = postgresql.dialect() diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index db2d5e73dc..919842a49c 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -1219,9 +1219,9 @@ class MiscBackendTest( def test_autocommit_pre_ping(self, testing_engine, autocommit): engine = testing_engine( options={ - "isolation_level": "AUTOCOMMIT" - if autocommit - else "SERIALIZABLE", + "isolation_level": ( + "AUTOCOMMIT" if autocommit else "SERIALIZABLE" + ), "pool_pre_ping": True, } ) @@ -1239,9 +1239,9 @@ class MiscBackendTest( engine = testing_engine( options={ - "isolation_level": "AUTOCOMMIT" - if autocommit - else "SERIALIZABLE", + "isolation_level": ( + "AUTOCOMMIT" if autocommit else "SERIALIZABLE" + ), "pool_pre_ping": True, } ) diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py index 8d8d9a7ec9..9822b3e60b 100644 --- a/test/dialect/postgresql/test_query.py +++ b/test/dialect/postgresql/test_query.py @@ -1238,7 +1238,6 @@ class TupleTest(fixtures.TestBase): class ExtractTest(fixtures.TablesTest): - """The rationale behind this test is that for many years we've had a system of embedding type casts into the expressions rendered by visit_extract() on the postgreql platform. The reason for this cast is not clear. diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 0a98ef5045..2088436eeb 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1155,7 +1155,7 @@ class NamedTypeTest( "one", "two", "three", - native_enum=True # make sure this is True because + native_enum=True, # make sure this is True because # it should *not* take effect due to # the variant ).with_variant( @@ -3234,7 +3234,6 @@ class SpecialTypesCompileTest(fixtures.TestBase, AssertsCompiledSQL): class SpecialTypesTest(fixtures.TablesTest, ComparesTables): - """test DDL and reflection of PG-specific types""" __only_on__ = ("postgresql >= 8.3.0",) @@ -3325,7 +3324,6 @@ class SpecialTypesTest(fixtures.TablesTest, ComparesTables): class UUIDTest(fixtures.TestBase): - """Test postgresql-specific UUID cases. See also generic UUID tests in testing/suite/test_types @@ -3969,9 +3967,11 @@ class _RangeTypeCompilation( self._test_clause( fn(self.col, self._data_str()), f"data_table.range {op} %(range_1)s", - self.col.type - if op in self._not_compare_op - else sqltypes.BOOLEANTYPE, + ( + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE + ), ) @testing.combinations(*_all_fns, id_="as") @@ -3979,9 +3979,11 @@ class _RangeTypeCompilation( self._test_clause( fn(self.col, self._data_obj()), f"data_table.range {op} %(range_1)s::{self._col_str}", - self.col.type - if op in self._not_compare_op - else sqltypes.BOOLEANTYPE, + ( + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE + ), ) @testing.combinations(*_comparisons, id_="as") @@ -3989,9 +3991,11 @@ class _RangeTypeCompilation( self._test_clause( fn(self.col, any_(array([self._data_str()]))), f"data_table.range {op} ANY (ARRAY[%(param_1)s])", - self.col.type - if op in self._not_compare_op - else sqltypes.BOOLEANTYPE, + ( + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE + ), ) def test_where_is_null(self): @@ -6279,9 +6283,11 @@ class PGInsertManyValuesTest(fixtures.TestBase): t.c.value, sort_by_parameter_order=bool(sort_by_parameter_order), ), - [{"value": value} for i in range(10)] - if multiple_rows - else {"value": value}, + ( + [{"value": value} for i in range(10)] + if multiple_rows + else {"value": value} + ), ) if multiple_rows: diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index 701635d90d..202e23556c 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -1,4 +1,5 @@ """SQLite-specific tests.""" + import datetime import json import os @@ -912,7 +913,6 @@ class AttachedDBTest(fixtures.TablesTest): class SQLTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests SQLite-dialect specific compilation.""" __dialect__ = sqlite.dialect() @@ -1314,7 +1314,6 @@ class OnConflictDDLTest(fixtures.TestBase, AssertsCompiledSQL): class InsertTest(fixtures.TestBase, AssertsExecutionResults): - """Tests inserts and autoincrement.""" __only_on__ = "sqlite" @@ -2508,7 +2507,6 @@ class ConstraintReflectionTest(fixtures.TestBase): class SavepointTest(fixtures.TablesTest): - """test that savepoints work when we use the correct event setup""" __only_on__ = "sqlite" diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index 6080f3dc6d..4618dfff8d 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -3654,12 +3654,12 @@ class DialectEventTest(fixtures.TestBase): arg[-1].get_result_proxy = Mock(return_value=Mock(context=arg[-1])) return retval - m1.real_do_execute.side_effect = ( - m1.do_execute.side_effect - ) = mock_the_cursor - m1.real_do_executemany.side_effect = ( - m1.do_executemany.side_effect - ) = mock_the_cursor + m1.real_do_execute.side_effect = m1.do_execute.side_effect = ( + mock_the_cursor + ) + m1.real_do_executemany.side_effect = m1.do_executemany.side_effect = ( + mock_the_cursor + ) m1.real_do_execute_no_params.side_effect = ( m1.do_execute_no_params.side_effect ) = mock_the_cursor diff --git a/test/engine/test_reconnect.py b/test/engine/test_reconnect.py index a7883efa2f..e1515a23a8 100644 --- a/test/engine/test_reconnect.py +++ b/test/engine/test_reconnect.py @@ -1581,9 +1581,9 @@ class ReconnectRecipeTest(fixtures.TestBase): connection.rollback() time.sleep(retry_interval) - context.cursor = ( - cursor - ) = connection.connection.cursor() + context.cursor = cursor = ( + connection.connection.cursor() + ) else: raise else: diff --git a/test/ext/declarative/test_inheritance.py b/test/ext/declarative/test_inheritance.py index d6d059cbef..e21881b333 100644 --- a/test/ext/declarative/test_inheritance.py +++ b/test/ext/declarative/test_inheritance.py @@ -934,22 +934,25 @@ class ConcreteExtensionConfigTest( self.assert_compile( session.query(Document), - "SELECT pjoin.id AS pjoin_id, pjoin.doctype AS pjoin_doctype, " - "pjoin.type AS pjoin_type, pjoin.send_method AS pjoin_send_method " - "FROM " - "(SELECT actual_documents.id AS id, " - "actual_documents.send_method AS send_method, " - "actual_documents.doctype AS doctype, " - "'actual' AS type FROM actual_documents) AS pjoin" - if use_strict_attrs - else "SELECT pjoin.id AS pjoin_id, pjoin.send_method AS " - "pjoin_send_method, pjoin.doctype AS pjoin_doctype, " - "pjoin.type AS pjoin_type " - "FROM " - "(SELECT actual_documents.id AS id, " - "actual_documents.send_method AS send_method, " - "actual_documents.doctype AS doctype, " - "'actual' AS type FROM actual_documents) AS pjoin", + ( + "SELECT pjoin.id AS pjoin_id, pjoin.doctype AS pjoin_doctype, " + "pjoin.type AS pjoin_type, " + "pjoin.send_method AS pjoin_send_method " + "FROM " + "(SELECT actual_documents.id AS id, " + "actual_documents.send_method AS send_method, " + "actual_documents.doctype AS doctype, " + "'actual' AS type FROM actual_documents) AS pjoin" + if use_strict_attrs + else "SELECT pjoin.id AS pjoin_id, pjoin.send_method AS " + "pjoin_send_method, pjoin.doctype AS pjoin_doctype, " + "pjoin.type AS pjoin_type " + "FROM " + "(SELECT actual_documents.id AS id, " + "actual_documents.send_method AS send_method, " + "actual_documents.doctype AS doctype, " + "'actual' AS type FROM actual_documents) AS pjoin" + ), ) @testing.combinations(True, False) diff --git a/test/ext/mypy/plugin_files/mapped_attr_assign.py b/test/ext/mypy/plugin_files/mapped_attr_assign.py index 06bc24d9eb..c7244c27a6 100644 --- a/test/ext/mypy/plugin_files/mapped_attr_assign.py +++ b/test/ext/mypy/plugin_files/mapped_attr_assign.py @@ -3,6 +3,7 @@ after the mapping is complete """ + from typing import Optional from sqlalchemy import Column diff --git a/test/ext/mypy/plugin_files/typing_err3.py b/test/ext/mypy/plugin_files/typing_err3.py index cbdbf009a0..146b96b2a7 100644 --- a/test/ext/mypy/plugin_files/typing_err3.py +++ b/test/ext/mypy/plugin_files/typing_err3.py @@ -2,6 +2,7 @@ type checked. """ + from typing import List from sqlalchemy import Column diff --git a/test/ext/test_associationproxy.py b/test/ext/test_associationproxy.py index 87812c9ac6..7e2b31a9b5 100644 --- a/test/ext/test_associationproxy.py +++ b/test/ext/test_associationproxy.py @@ -3830,11 +3830,11 @@ class DeclOrmForms(fixtures.TestBase): id: Mapped[int] = mapped_column(primary_key=True) - user_keyword_associations: Mapped[ - List[UserKeywordAssociation] - ] = relationship( - back_populates="user", - cascade="all, delete-orphan", + user_keyword_associations: Mapped[List[UserKeywordAssociation]] = ( + relationship( + back_populates="user", + cascade="all, delete-orphan", + ) ) keywords: AssociationProxy[list[str]] = association_proxy( @@ -3886,12 +3886,12 @@ class DeclOrmForms(fixtures.TestBase): primary_key=True, repr=True, init=False ) - user_keyword_associations: Mapped[ - List[UserKeywordAssociation] - ] = relationship( - back_populates="user", - cascade="all, delete-orphan", - init=False, + user_keyword_associations: Mapped[List[UserKeywordAssociation]] = ( + relationship( + back_populates="user", + cascade="all, delete-orphan", + init=False, + ) ) if embed_in_field: diff --git a/test/ext/test_automap.py b/test/ext/test_automap.py index c84bc1c78e..a3ba1189b3 100644 --- a/test/ext/test_automap.py +++ b/test/ext/test_automap.py @@ -667,11 +667,14 @@ class ConcurrentAutomapTest(fixtures.TestBase): m, Column("id", Integer, primary_key=True), Column("data", String(50)), - Column( - "t_%d_id" % (i - 1), ForeignKey("table_%d.id" % (i - 1)) - ) - if i > 4 - else None, + ( + Column( + "t_%d_id" % (i - 1), + ForeignKey("table_%d.id" % (i - 1)), + ) + if i > 4 + else None + ), ) m.drop_all(e) m.create_all(e) diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index aa03dabc90..707e02dac1 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -209,9 +209,11 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( stmt, - "SELECT my_function(t1.q) AS my_function_1 FROM t1" - if named - else "SELECT my_function(t1.q) AS anon_1 FROM t1", + ( + "SELECT my_function(t1.q) AS my_function_1 FROM t1" + if named + else "SELECT my_function(t1.q) AS anon_1 FROM t1" + ), dialect="sqlite", ) diff --git a/test/ext/test_extendedattr.py b/test/ext/test_extendedattr.py index dd5b715829..41637c358e 100644 --- a/test/ext/test_extendedattr.py +++ b/test/ext/test_extendedattr.py @@ -760,7 +760,6 @@ class InstrumentationCollisionTest(_ExtBase, fixtures.ORMTest): class ExtendedEventsTest(_ExtBase, fixtures.ORMTest): - """Allow custom Events implementations.""" @modifies_instrumentation_finders diff --git a/test/orm/declarative/test_abs_import_only.py b/test/orm/declarative/test_abs_import_only.py index e1447364e6..287240575c 100644 --- a/test/orm/declarative/test_abs_import_only.py +++ b/test/orm/declarative/test_abs_import_only.py @@ -64,9 +64,9 @@ class MappedColumnTest( if construct.Mapped: bars: orm.Mapped[typing.List[Bar]] = orm.relationship() elif construct.WriteOnlyMapped: - bars: orm.WriteOnlyMapped[ - typing.List[Bar] - ] = orm.relationship() + bars: orm.WriteOnlyMapped[typing.List[Bar]] = ( + orm.relationship() + ) elif construct.DynamicMapped: bars: orm.DynamicMapped[typing.List[Bar]] = orm.relationship() else: diff --git a/test/orm/declarative/test_dc_transforms.py b/test/orm/declarative/test_dc_transforms.py index cbe08f30e1..8408f69617 100644 --- a/test/orm/declarative/test_dc_transforms.py +++ b/test/orm/declarative/test_dc_transforms.py @@ -179,9 +179,9 @@ class DCTransformsTest(AssertsCompiledSQL, fixtures.TestBase): JSON, init=True, default_factory=lambda: {} ) - new_instance: GenericSetting[ # noqa: F841 - Dict[str, Any] - ] = GenericSetting(key="x", value={"foo": "bar"}) + new_instance: GenericSetting[Dict[str, Any]] = ( # noqa: F841 + GenericSetting(key="x", value={"foo": "bar"}) + ) def test_no_anno_doesnt_go_into_dc( self, dc_decl_base: Type[MappedAsDataclass] diff --git a/test/orm/declarative/test_inheritance.py b/test/orm/declarative/test_inheritance.py index c5b908cd82..1b633d1bcf 100644 --- a/test/orm/declarative/test_inheritance.py +++ b/test/orm/declarative/test_inheritance.py @@ -1067,7 +1067,6 @@ class DeclarativeInheritanceTest( target_id = Column(Integer, primary_key=True) class Engineer(Person): - """single table inheritance""" if decl_type.legacy: @@ -1084,7 +1083,6 @@ class DeclarativeInheritanceTest( ) class Manager(Person): - """single table inheritance""" if decl_type.legacy: @@ -1468,7 +1466,6 @@ class DeclarativeInheritanceTest( class OverlapColPrecedenceTest(DeclarativeTestBase): - """test #1892 cases when declarative does column precedence.""" def _run_test(self, Engineer, e_id, p_id): diff --git a/test/orm/declarative/test_mixin.py b/test/orm/declarative/test_mixin.py index 900133df59..32f737484e 100644 --- a/test/orm/declarative/test_mixin.py +++ b/test/orm/declarative/test_mixin.py @@ -672,11 +672,9 @@ class DeclarativeMixinTest(DeclarativeTestBase): return relationship("Other") class Engineer(Mixin, Person): - """single table inheritance""" class Manager(Mixin, Person): - """single table inheritance""" class Other(Base): diff --git a/test/orm/declarative/test_tm_future_annotations_sync.py b/test/orm/declarative/test_tm_future_annotations_sync.py index d2f2a0261f..33e3223e53 100644 --- a/test/orm/declarative/test_tm_future_annotations_sync.py +++ b/test/orm/declarative/test_tm_future_annotations_sync.py @@ -1517,20 +1517,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): data: Mapped[Union[float, Decimal]] = mapped_column() reverse_data: Mapped[Union[Decimal, float]] = mapped_column() - optional_data: Mapped[ - Optional[Union[float, Decimal]] - ] = mapped_column() + optional_data: Mapped[Optional[Union[float, Decimal]]] = ( + mapped_column() + ) # use Optional directly - reverse_optional_data: Mapped[ - Optional[Union[Decimal, float]] - ] = mapped_column() + reverse_optional_data: Mapped[Optional[Union[Decimal, float]]] = ( + mapped_column() + ) # use Union with None, same as Optional but presents differently # (Optional object with __origin__ Union vs. Union) - reverse_u_optional_data: Mapped[ - Union[Decimal, float, None] - ] = mapped_column() + reverse_u_optional_data: Mapped[Union[Decimal, float, None]] = ( + mapped_column() + ) float_data: Mapped[float] = mapped_column() decimal_data: Mapped[Decimal] = mapped_column() @@ -1538,14 +1538,14 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if compat.py310: pep604_data: Mapped[float | Decimal] = mapped_column() pep604_reverse: Mapped[Decimal | float] = mapped_column() - pep604_optional: Mapped[ - Decimal | float | None - ] = mapped_column() + pep604_optional: Mapped[Decimal | float | None] = ( + mapped_column() + ) pep604_data_fwd: Mapped["float | Decimal"] = mapped_column() pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column() - pep604_optional_fwd: Mapped[ - "Decimal | float | None" - ] = mapped_column() + pep604_optional_fwd: Mapped["Decimal | float | None"] = ( + mapped_column() + ) is_(User.__table__.c.data.type, our_type) is_false(User.__table__.c.data.nullable) @@ -2508,9 +2508,9 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): collection_class=list ) elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship(collection_class=list) + bs: Mapped[collections.abc.MutableSequence[B]] = ( + relationship(collection_class=list) + ) else: datatype.fail() @@ -2537,15 +2537,15 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): if datatype.typing_sequence: bs: Mapped[typing.Sequence[B]] = relationship() elif datatype.collections_sequence: - bs: Mapped[ - collections.abc.Sequence[B] - ] = relationship() + bs: Mapped[collections.abc.Sequence[B]] = ( + relationship() + ) elif datatype.typing_mutable_sequence: bs: Mapped[typing.MutableSequence[B]] = relationship() elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship() + bs: Mapped[collections.abc.MutableSequence[B]] = ( + relationship() + ) else: datatype.fail() diff --git a/test/orm/declarative/test_typed_mapping.py b/test/orm/declarative/test_typed_mapping.py index 37aa216d54..95d97382ee 100644 --- a/test/orm/declarative/test_typed_mapping.py +++ b/test/orm/declarative/test_typed_mapping.py @@ -1508,20 +1508,20 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): data: Mapped[Union[float, Decimal]] = mapped_column() reverse_data: Mapped[Union[Decimal, float]] = mapped_column() - optional_data: Mapped[ - Optional[Union[float, Decimal]] - ] = mapped_column() + optional_data: Mapped[Optional[Union[float, Decimal]]] = ( + mapped_column() + ) # use Optional directly - reverse_optional_data: Mapped[ - Optional[Union[Decimal, float]] - ] = mapped_column() + reverse_optional_data: Mapped[Optional[Union[Decimal, float]]] = ( + mapped_column() + ) # use Union with None, same as Optional but presents differently # (Optional object with __origin__ Union vs. Union) - reverse_u_optional_data: Mapped[ - Union[Decimal, float, None] - ] = mapped_column() + reverse_u_optional_data: Mapped[Union[Decimal, float, None]] = ( + mapped_column() + ) float_data: Mapped[float] = mapped_column() decimal_data: Mapped[Decimal] = mapped_column() @@ -1529,14 +1529,14 @@ class MappedColumnTest(fixtures.TestBase, testing.AssertsCompiledSQL): if compat.py310: pep604_data: Mapped[float | Decimal] = mapped_column() pep604_reverse: Mapped[Decimal | float] = mapped_column() - pep604_optional: Mapped[ - Decimal | float | None - ] = mapped_column() + pep604_optional: Mapped[Decimal | float | None] = ( + mapped_column() + ) pep604_data_fwd: Mapped["float | Decimal"] = mapped_column() pep604_reverse_fwd: Mapped["Decimal | float"] = mapped_column() - pep604_optional_fwd: Mapped[ - "Decimal | float | None" - ] = mapped_column() + pep604_optional_fwd: Mapped["Decimal | float | None"] = ( + mapped_column() + ) is_(User.__table__.c.data.type, our_type) is_false(User.__table__.c.data.nullable) @@ -2499,9 +2499,9 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): collection_class=list ) elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship(collection_class=list) + bs: Mapped[collections.abc.MutableSequence[B]] = ( + relationship(collection_class=list) + ) else: datatype.fail() @@ -2528,15 +2528,15 @@ class RelationshipLHSTest(fixtures.TestBase, testing.AssertsCompiledSQL): if datatype.typing_sequence: bs: Mapped[typing.Sequence[B]] = relationship() elif datatype.collections_sequence: - bs: Mapped[ - collections.abc.Sequence[B] - ] = relationship() + bs: Mapped[collections.abc.Sequence[B]] = ( + relationship() + ) elif datatype.typing_mutable_sequence: bs: Mapped[typing.MutableSequence[B]] = relationship() elif datatype.collections_mutable_sequence: - bs: Mapped[ - collections.abc.MutableSequence[B] - ] = relationship() + bs: Mapped[collections.abc.MutableSequence[B]] = ( + relationship() + ) else: datatype.fail() diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index 0f9a623bda..49d90f6c43 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -2476,9 +2476,9 @@ class Issue8168Test(AssertsCompiledSQL, fixtures.TestBase): __mapper_args__ = { "polymorphic_identity": "retailer", - "polymorphic_load": "inline" - if use_poly_on_retailer - else None, + "polymorphic_load": ( + "inline" if use_poly_on_retailer else None + ), } return Customer, Store, Retailer diff --git a/test/orm/inheritance/test_basic.py b/test/orm/inheritance/test_basic.py index abd6c86b57..a76f563f81 100644 --- a/test/orm/inheritance/test_basic.py +++ b/test/orm/inheritance/test_basic.py @@ -1933,7 +1933,7 @@ class OptimizedGetOnDeferredTest(fixtures.MappedTest): # a.id is not included in the SELECT list "SELECT b.data FROM a JOIN b ON a.id = b.id " "WHERE a.id = :pk_1", - [{"pk_1": pk}] + [{"pk_1": pk}], # if we used load_scalar_attributes(), it would look like # this # "SELECT b.data AS b_data FROM b WHERE :param_1 = b.id", diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index daaf937b91..be42dc6090 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -2896,9 +2896,11 @@ class BetweenSubclassJoinWExtraJoinedLoad( m1 = aliased(Manager, flat=True) q = sess.query(Engineer, m1).join(Engineer.manager.of_type(m1)) - with _aliased_join_warning( - r"Manager\(managers\)" - ) if autoalias else nullcontext(): + with ( + _aliased_join_warning(r"Manager\(managers\)") + if autoalias + else nullcontext() + ): self.assert_compile( q, "SELECT engineers.id AS " diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index 52f3cf9c9f..f45194f29c 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -1909,9 +1909,11 @@ class SingleFromPolySelectableTest( e1 = aliased(Engineer, flat=True) q = s.query(Boss).join(e1, e1.manager_id == Boss.id) - with _aliased_join_warning( - r"Mapper\[Engineer\(engineer\)\]" - ) if autoalias else nullcontext(): + with ( + _aliased_join_warning(r"Mapper\[Engineer\(engineer\)\]") + if autoalias + else nullcontext() + ): self.assert_compile( q, "SELECT manager.id AS manager_id, employee.id AS employee_id, " @@ -1974,9 +1976,11 @@ class SingleFromPolySelectableTest( b1 = aliased(Boss, flat=True) q = s.query(Engineer).join(b1, Engineer.manager_id == b1.id) - with _aliased_join_warning( - r"Mapper\[Boss\(manager\)\]" - ) if autoalias else nullcontext(): + with ( + _aliased_join_warning(r"Mapper\[Boss\(manager\)\]") + if autoalias + else nullcontext() + ): self.assert_compile( q, "SELECT engineer.id AS engineer_id, " diff --git a/test/orm/test_assorted_eager.py b/test/orm/test_assorted_eager.py index 677f8f2073..f14cdda5b6 100644 --- a/test/orm/test_assorted_eager.py +++ b/test/orm/test_assorted_eager.py @@ -6,6 +6,7 @@ These are generally very old 0.1-era tests and at some point should be cleaned up and modernized. """ + import datetime import sqlalchemy as sa diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index ded2c25db7..f9a1ba3865 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -411,11 +411,11 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): assert_data = [ { "start": d["start"] if "start" in d else None, - "end": d["end"] - if "end" in d - else Point(d["x2"], d["y2"]) - if "x2" in d - else None, + "end": ( + d["end"] + if "end" in d + else Point(d["x2"], d["y2"]) if "x2" in d else None + ), "graph_id": d["graph_id"], } for d in data @@ -916,9 +916,11 @@ class EventsEtcTest(fixtures.MappedTest): mock.call( e1, Point(5, 6), - LoaderCallableStatus.NO_VALUE - if not active_history - else None, + ( + LoaderCallableStatus.NO_VALUE + if not active_history + else None + ), Edge.start.impl, ) ], @@ -965,9 +967,11 @@ class EventsEtcTest(fixtures.MappedTest): mock.call( e1, Point(7, 8), - LoaderCallableStatus.NO_VALUE - if not active_history - else Point(5, 6), + ( + LoaderCallableStatus.NO_VALUE + if not active_history + else Point(5, 6) + ), Edge.start.impl, ) ], @@ -1019,9 +1023,11 @@ class EventsEtcTest(fixtures.MappedTest): [ mock.call( e1, - LoaderCallableStatus.NO_VALUE - if not active_history - else Point(5, 6), + ( + LoaderCallableStatus.NO_VALUE + if not active_history + else Point(5, 6) + ), Edge.start.impl, ) ], diff --git a/test/orm/test_cycles.py b/test/orm/test_cycles.py index 7f0f504b56..cffde9bdab 100644 --- a/test/orm/test_cycles.py +++ b/test/orm/test_cycles.py @@ -5,6 +5,7 @@ T1<->T2, with o2m or m2o between them, and a third T3 with o2m/m2o to one/both T1/T2. """ + from itertools import count from sqlalchemy import bindparam diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 5d6bc9a686..b748779693 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -2204,11 +2204,13 @@ class BindSensitiveStringifyTest(fixtures.MappedTest): eq_ignore_whitespace( str(q), - "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = ?" - if expect_bound - else "SELECT users.id AS users_id, users.name AS users_name " - "FROM users WHERE users.id = :id_1", + ( + "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = ?" + if expect_bound + else "SELECT users.id AS users_id, users.name AS users_name " + "FROM users WHERE users.id = :id_1" + ), ) def test_query_bound_session(self): @@ -2242,7 +2244,6 @@ class DeprecationScopedSessionTest(fixtures.MappedTest): class RequirementsTest(fixtures.MappedTest): - """Tests the contract for user classes.""" @classmethod diff --git a/test/orm/test_dynamic.py b/test/orm/test_dynamic.py index 83f3101f20..cce3f8c18a 100644 --- a/test/orm/test_dynamic.py +++ b/test/orm/test_dynamic.py @@ -1444,9 +1444,11 @@ class DynamicUOWTest( addresses_args={ "order_by": addresses.c.id, "backref": "user", - "cascade": "save-update" - if not delete_cascade_configured - else "all, delete", + "cascade": ( + "save-update" + if not delete_cascade_configured + else "all, delete" + ), } ) @@ -1519,9 +1521,11 @@ class WriteOnlyUOWTest( data: Mapped[str] bs: WriteOnlyMapped["B"] = relationship( # noqa: F821 passive_deletes=passive_deletes, - cascade="all, delete-orphan" - if cascade_deletes - else "save-update, merge", + cascade=( + "all, delete-orphan" + if cascade_deletes + else "save-update, merge" + ), order_by="B.id", ) @@ -1986,9 +1990,11 @@ class _HistoryTest: attributes.get_history( obj, attrname, - PassiveFlag.PASSIVE_NO_FETCH - if self.lazy == "write_only" - else PassiveFlag.PASSIVE_OFF, + ( + PassiveFlag.PASSIVE_NO_FETCH + if self.lazy == "write_only" + else PassiveFlag.PASSIVE_OFF + ), ), compare, ) diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index b1b6e86b79..2e762c2d3c 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -3697,7 +3697,6 @@ class InnerJoinSplicingWSecondaryTest( class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): - """test #2188""" __dialect__ = "default" @@ -3892,7 +3891,6 @@ class SubqueryAliasingTest(fixtures.MappedTest, testing.AssertsCompiledSQL): class LoadOnExistingTest(_fixtures.FixtureTest): - """test that loaders from a base Query fully populate.""" run_inserts = "once" @@ -5309,7 +5307,6 @@ class SubqueryTest(fixtures.MappedTest): class CorrelatedSubqueryTest(fixtures.MappedTest): - """tests for #946, #947, #948. The "users" table is joined to "stuff", and the relationship @@ -6633,7 +6630,6 @@ class DeepOptionsTest(_fixtures.FixtureTest): class SecondaryOptionsTest(fixtures.MappedTest): - """test that the contains_eager() option doesn't bleed into a secondary load.""" diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 56d16dfcd7..02e00fe947 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -390,9 +390,9 @@ class ORMExecuteTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): is_orm_statement=ctx.is_orm_statement, is_relationship_load=ctx.is_relationship_load, is_column_load=ctx.is_column_load, - lazy_loaded_from=ctx.lazy_loaded_from - if ctx.is_select - else None, + lazy_loaded_from=( + ctx.lazy_loaded_from if ctx.is_select else None + ), ) return canary @@ -1545,9 +1545,11 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): ( lambda session: session, "loaded_as_persistent", - lambda session, instance: instance.unloaded - if instance.__class__.__name__ == "A" - else None, + lambda session, instance: ( + instance.unloaded + if instance.__class__.__name__ == "A" + else None + ), ), argnames="target, event_name, fn", )(fn) @@ -1669,7 +1671,6 @@ class DeclarativeEventListenTest( class DeferredMapperEventsTest(RemoveORMEventsGlobally, _fixtures.FixtureTest): - """ "test event listeners against unmapped classes. This incurs special logic. Note if we ever do the "remove" case, diff --git a/test/orm/test_hasparent.py b/test/orm/test_hasparent.py index 8f61c11970..72c90b6d5c 100644 --- a/test/orm/test_hasparent.py +++ b/test/orm/test_hasparent.py @@ -1,4 +1,5 @@ """test the current state of the hasparent() flag.""" + from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import testing diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index 4ab9617123..64c86853d2 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -993,7 +993,6 @@ class LazyTest(_fixtures.FixtureTest): class GetterStateTest(_fixtures.FixtureTest): - """test lazyloader on non-existent attribute returns expected attribute symbols, maintain expected state""" @@ -1080,11 +1079,13 @@ class GetterStateTest(_fixtures.FixtureTest): properties={ "user": relationship( User, - primaryjoin=and_( - users.c.id == addresses.c.user_id, users.c.id != 27 - ) - if dont_use_get - else None, + primaryjoin=( + and_( + users.c.id == addresses.c.user_id, users.c.id != 27 + ) + if dont_use_get + else None + ), back_populates="addresses", ) }, diff --git a/test/orm/test_mapper.py b/test/orm/test_mapper.py index f90803d6e4..f93c18d216 100644 --- a/test/orm/test_mapper.py +++ b/test/orm/test_mapper.py @@ -2555,7 +2555,6 @@ class MapperTest(_fixtures.FixtureTest, AssertsCompiledSQL): class RequirementsTest(fixtures.MappedTest): - """Tests the contract for user classes.""" @classmethod diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index 0c8e2651cd..c313c4b33d 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -1476,9 +1476,7 @@ class MergeTest(_fixtures.FixtureTest): CountStatements( 0 if load.noload - else 1 - if merge_persistent.merge_persistent - else 2 + else 1 if merge_persistent.merge_persistent else 2 ) ) diff --git a/test/orm/test_options.py b/test/orm/test_options.py index 7c96539583..9362d52470 100644 --- a/test/orm/test_options.py +++ b/test/orm/test_options.py @@ -976,9 +976,11 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Keyword = self.classes.Keyword self._assert_eager_with_entity_exception( [Item], - lambda: (joinedload(Keyword),) - if first_element - else (Load(Item).joinedload(Keyword),), + lambda: ( + (joinedload(Keyword),) + if first_element + else (Load(Item).joinedload(Keyword),) + ), "expected ORM mapped attribute for loader " "strategy argument", ) @@ -990,9 +992,11 @@ class OptionsNoPropTest(_fixtures.FixtureTest): Item = self.classes.Item self._assert_eager_with_entity_exception( [Item], - lambda: (joinedload(rando),) - if first_element - else (Load(Item).joinedload(rando)), + lambda: ( + (joinedload(rando),) + if first_element + else (Load(Item).joinedload(rando)) + ), "expected ORM mapped attribute for loader strategy argument", ) @@ -1002,9 +1006,11 @@ class OptionsNoPropTest(_fixtures.FixtureTest): self._assert_eager_with_entity_exception( [OrderWProp], - lambda: (joinedload(OrderWProp.some_attr),) - if first_element - else (Load(OrderWProp).joinedload(OrderWProp.some_attr),), + lambda: ( + (joinedload(OrderWProp.some_attr),) + if first_element + else (Load(OrderWProp).joinedload(OrderWProp.some_attr),) + ), "expected ORM mapped attribute for loader strategy argument", ) diff --git a/test/orm/test_relationship_criteria.py b/test/orm/test_relationship_criteria.py index aebdf6922a..69279f6004 100644 --- a/test/orm/test_relationship_criteria.py +++ b/test/orm/test_relationship_criteria.py @@ -1908,9 +1908,11 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -1976,9 +1978,11 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -2033,9 +2037,11 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -2129,9 +2135,11 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): eq_( result, - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) @testing.combinations((True,), (False,), argnames="use_compiled_cache") @@ -2237,9 +2245,11 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( @@ -2309,9 +2319,11 @@ class RelationshipCriteriaTest(_Fixtures, testing.AssertsCompiledSQL): eq_( result.scalars().unique().all(), - self._user_minus_edwood(*user_address_fixture) - if value == "ed@wood.com" - else self._user_minus_edlala(*user_address_fixture), + ( + self._user_minus_edwood(*user_address_fixture) + if value == "ed@wood.com" + else self._user_minus_edlala(*user_address_fixture) + ), ) asserter.assert_( diff --git a/test/orm/test_relationships.py b/test/orm/test_relationships.py index d644d26793..db1e90dad2 100644 --- a/test/orm/test_relationships.py +++ b/test/orm/test_relationships.py @@ -183,7 +183,6 @@ class _RelationshipErrors: class DependencyTwoParentTest(fixtures.MappedTest): - """Test flush() when a mapper is dependent on multiple relationships""" run_setup_mappers = "once" @@ -430,7 +429,6 @@ class M2ODontOverwriteFKTest(fixtures.MappedTest): class DirectSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): - """Tests the ultimate join condition, a single column that points to itself, e.g. within a SQL function or similar. The test is against a materialized path setup. @@ -1022,7 +1020,6 @@ class OverlappingFksSiblingTest(fixtures.MappedTest): class CompositeSelfRefFKTest(fixtures.MappedTest, AssertsCompiledSQL): - """Tests a composite FK where, in the relationship(), one col points to itself in the same table. @@ -1506,7 +1503,6 @@ class CompositeJoinPartialFK(fixtures.MappedTest, AssertsCompiledSQL): class SynonymsAsFKsTest(fixtures.MappedTest): - """Syncrules on foreign keys that are also primary""" @classmethod @@ -1578,7 +1574,6 @@ class SynonymsAsFKsTest(fixtures.MappedTest): class FKsAsPksTest(fixtures.MappedTest): - """Syncrules on foreign keys that are also primary""" @classmethod @@ -1863,7 +1858,6 @@ class FKsAsPksTest(fixtures.MappedTest): class UniqueColReferenceSwitchTest(fixtures.MappedTest): - """test a relationship based on a primary join against a unique non-pk column""" @@ -1928,7 +1922,6 @@ class UniqueColReferenceSwitchTest(fixtures.MappedTest): class RelationshipToSelectableTest(fixtures.MappedTest): - """Test a map to a select that relates to a map to the table.""" @classmethod @@ -2022,7 +2015,6 @@ class RelationshipToSelectableTest(fixtures.MappedTest): class FKEquatedToConstantTest(fixtures.MappedTest): - """test a relationship with a non-column entity in the primary join, is not viewonly, and also has the non-column's clause mentioned in the foreign keys list. @@ -2159,7 +2151,6 @@ class BackrefPropagatesForwardsArgs(fixtures.MappedTest): class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): - """test ambiguous joins due to FKs on both sides treated as self-referential. @@ -2254,7 +2245,6 @@ class AmbiguousJoinInterpretedAsSelfRef(fixtures.MappedTest): class ManualBackrefTest(_fixtures.FixtureTest): - """Test explicit relationships that are backrefs to each other.""" run_inserts = None @@ -2485,7 +2475,6 @@ class ManualBackrefTest(_fixtures.FixtureTest): class NoLoadBackPopulates(_fixtures.FixtureTest): - """test the noload stratgegy which unlike others doesn't use lazyloader to set up instrumentation""" @@ -2732,7 +2721,6 @@ class JoinConditionErrorTest(fixtures.TestBase): class TypeMatchTest(fixtures.MappedTest): - """test errors raised when trying to add items whose type is not handled by a relationship""" @@ -3000,7 +2988,6 @@ class TypedAssociationTable(fixtures.MappedTest): class CustomOperatorTest(fixtures.MappedTest, AssertsCompiledSQL): - """test op() in conjunction with join conditions""" run_create_tables = run_deletes = None @@ -3278,7 +3265,6 @@ class ViewOnlyM2MBackrefTest(fixtures.MappedTest): class ViewOnlyOverlappingNames(fixtures.MappedTest): - """'viewonly' mappings with overlapping PK column names.""" @classmethod @@ -3534,7 +3520,6 @@ class ViewOnlySyncBackref(fixtures.MappedTest): class ViewOnlyUniqueNames(fixtures.MappedTest): - """'viewonly' mappings with unique PK column names.""" @classmethod @@ -3636,7 +3621,6 @@ class ViewOnlyUniqueNames(fixtures.MappedTest): class ViewOnlyLocalRemoteM2M(fixtures.TestBase): - """test that local-remote is correctly determined for m2m""" def test_local_remote(self, registry): @@ -3675,7 +3659,6 @@ class ViewOnlyLocalRemoteM2M(fixtures.TestBase): class ViewOnlyNonEquijoin(fixtures.MappedTest): - """'viewonly' mappings based on non-equijoins.""" @classmethod @@ -3737,7 +3720,6 @@ class ViewOnlyNonEquijoin(fixtures.MappedTest): class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): - """'viewonly' mappings that contain the same 'remote' column twice""" @classmethod @@ -3811,7 +3793,6 @@ class ViewOnlyRepeatedRemoteColumn(fixtures.MappedTest): class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): - """'viewonly' mappings that contain the same 'local' column twice""" @classmethod @@ -3886,7 +3867,6 @@ class ViewOnlyRepeatedLocalColumn(fixtures.MappedTest): class ViewOnlyComplexJoin(_RelationshipErrors, fixtures.MappedTest): - """'viewonly' mappings with a complex join condition.""" @classmethod @@ -4088,7 +4068,6 @@ class FunctionAsPrimaryJoinTest(fixtures.DeclarativeMappedTest): class RemoteForeignBetweenColsTest(fixtures.DeclarativeMappedTest): - """test a complex annotation using between(). Using declarative here as an integration test for the local() @@ -4705,7 +4684,6 @@ class SecondaryArgTest(fixtures.TestBase): class SecondaryNestedJoinTest( fixtures.MappedTest, AssertsCompiledSQL, testing.AssertsExecutionResults ): - """test support for a relationship where the 'secondary' table is a compound join(). @@ -6473,7 +6451,6 @@ class RaiseLoadTest(_fixtures.FixtureTest): class RelationDeprecationTest(fixtures.MappedTest): - """test usage of the old 'relation' function.""" run_inserts = "once" diff --git a/test/orm/test_selectable.py b/test/orm/test_selectable.py index 3a7029110e..d4ea0e2919 100644 --- a/test/orm/test_selectable.py +++ b/test/orm/test_selectable.py @@ -1,4 +1,5 @@ """Generic mapping to Select statements""" + import sqlalchemy as sa from sqlalchemy import column from sqlalchemy import Integer diff --git a/test/orm/test_transaction.py b/test/orm/test_transaction.py index d6f22622ea..e502a88833 100644 --- a/test/orm/test_transaction.py +++ b/test/orm/test_transaction.py @@ -1285,7 +1285,6 @@ class FixtureDataTest(_LocalFixture): class CleanSavepointTest(FixtureTest): - """test the behavior for [ticket:2452] - rollback on begin_nested() only expires objects tracked as being modified in that transaction. @@ -2625,12 +2624,14 @@ class ReallyNewJoinIntoAnExternalTransactionTest( self.session = Session( self.connection, - join_transaction_mode="create_savepoint" - if ( - self.join_mode.create_savepoint - or self.join_mode.create_savepoint_w_savepoint - ) - else "conditional_savepoint", + join_transaction_mode=( + "create_savepoint" + if ( + self.join_mode.create_savepoint + or self.join_mode.create_savepoint_w_savepoint + ) + else "conditional_savepoint" + ), ) def teardown_session(self): diff --git a/test/orm/test_unitofwork.py b/test/orm/test_unitofwork.py index 0937c354f9..3b3175e10e 100644 --- a/test/orm/test_unitofwork.py +++ b/test/orm/test_unitofwork.py @@ -1149,9 +1149,9 @@ class DefaultTest(fixtures.MappedTest): mp = self.mapper_registry.map_imperatively( Hoho, default_t, - eager_defaults="auto" - if eager_defaults.auto - else bool(eager_defaults), + eager_defaults=( + "auto" if eager_defaults.auto else bool(eager_defaults) + ), ) h1 = Hoho(hoho=althohoval) diff --git a/test/orm/test_unitofworkv2.py b/test/orm/test_unitofworkv2.py index 1a5b697b8e..e01220d115 100644 --- a/test/orm/test_unitofworkv2.py +++ b/test/orm/test_unitofworkv2.py @@ -2171,7 +2171,6 @@ class BatchInsertsTest(fixtures.MappedTest, testing.AssertsExecutionResults): class LoadersUsingCommittedTest(UOWTest): - """Test that events which occur within a flush() get the same attribute loading behavior as on the outside of the flush, and that the unit of work itself uses the @@ -2260,7 +2259,6 @@ class LoadersUsingCommittedTest(UOWTest): Address, User = self.classes.Address, self.classes.User class AvoidReferencialError(Exception): - """the test here would require ON UPDATE CASCADE on FKs for the flush to fully succeed; this exception is used to cancel the flush before we get that far. diff --git a/test/perf/many_table_reflection.py b/test/perf/many_table_reflection.py index d65c272430..4fa768a74e 100644 --- a/test/perf/many_table_reflection.py +++ b/test/perf/many_table_reflection.py @@ -41,9 +41,9 @@ def generate_table(meta: sa.MetaData, min_cols, max_cols, dialect_name): f"table_{table_num}_col_{i + 1}", *args, primary_key=i == 0, - comment=f"primary key of table_{table_num}" - if i == 0 - else None, + comment=( + f"primary key of table_{table_num}" if i == 0 else None + ), index=random.random() > 0.97 and i > 0, unique=random.random() > 0.97 and i > 0, ) diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py index d6bc098964..5756bb6927 100644 --- a/test/sql/test_compiler.py +++ b/test/sql/test_compiler.py @@ -7525,7 +7525,6 @@ class CoercionTest(fixtures.TestBase, AssertsCompiledSQL): class ResultMapTest(fixtures.TestBase): - """test the behavior of the 'entry stack' and the determination when the result_map needs to be populated. @@ -7740,9 +7739,9 @@ class ResultMapTest(fixtures.TestBase): with mock.patch.object( dialect.statement_compiler, "translate_select_structure", - lambda self, to_translate, **kw: wrapped_again - if to_translate is stmt - else to_translate, + lambda self, to_translate, **kw: ( + wrapped_again if to_translate is stmt else to_translate + ), ): compiled = stmt.compile(dialect=dialect) @@ -7799,9 +7798,9 @@ class ResultMapTest(fixtures.TestBase): with mock.patch.object( dialect.statement_compiler, "translate_select_structure", - lambda self, to_translate, **kw: wrapped_again - if to_translate is stmt - else to_translate, + lambda self, to_translate, **kw: ( + wrapped_again if to_translate is stmt else to_translate + ), ): compiled = stmt.compile(dialect=dialect) diff --git a/test/sql/test_cte.py b/test/sql/test_cte.py index 23ac87a214..0b665b84da 100644 --- a/test/sql/test_cte.py +++ b/test/sql/test_cte.py @@ -613,7 +613,7 @@ class CTETest(fixtures.TestBase, AssertsCompiledSQL): stmt, "WITH anon_1 AS (SELECT test.a AS b FROM test %s b) " "SELECT (SELECT anon_1.b FROM anon_1) AS c" - % ("ORDER BY" if order_by == "order_by" else "GROUP BY") + % ("ORDER BY" if order_by == "order_by" else "GROUP BY"), # prior to the fix, the use_object version came out as: # "WITH anon_1 AS (SELECT test.a AS b FROM test " # "ORDER BY test.a) " diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index bbfb3b0778..bcfdfcdb9c 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -1234,7 +1234,6 @@ class AutoIncrementTest(fixtures.TestBase): class SpecialTypePKTest(fixtures.TestBase): - """test process_result_value in conjunction with primary key columns. Also tests that "autoincrement" checks are against diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index e474e75d75..0204d6e6fc 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -54,7 +54,6 @@ A = B = t1 = t2 = t3 = table1 = table2 = table3 = table4 = None class TraversalTest( fixtures.TestBase, AssertsExecutionResults, AssertsCompiledSQL ): - """test ClauseVisitor's traversal, particularly its ability to copy and modify a ClauseElement in place.""" @@ -362,7 +361,6 @@ class TraversalTest( class BinaryEndpointTraversalTest(fixtures.TestBase): - """test the special binary product visit""" def _assert_traversal(self, expr, expected): @@ -443,7 +441,6 @@ class BinaryEndpointTraversalTest(fixtures.TestBase): class ClauseTest(fixtures.TestBase, AssertsCompiledSQL): - """test copy-in-place behavior of various ClauseElements.""" __dialect__ = "default" @@ -2716,7 +2713,6 @@ class SpliceJoinsTest(fixtures.TestBase, AssertsCompiledSQL): class SelectTest(fixtures.TestBase, AssertsCompiledSQL): - """tests the generative capability of Select""" __dialect__ = "default" @@ -2811,7 +2807,6 @@ class SelectTest(fixtures.TestBase, AssertsCompiledSQL): class ValuesBaseTest(fixtures.TestBase, AssertsCompiledSQL): - """Tests the generative capability of Insert, Update""" __dialect__ = "default" diff --git a/test/sql/test_insert_exec.py b/test/sql/test_insert_exec.py index e9eda0e5bd..4c6c5407b5 100644 --- a/test/sql/test_insert_exec.py +++ b/test/sql/test_insert_exec.py @@ -472,7 +472,6 @@ class InsertExecTest(fixtures.TablesTest): class TableInsertTest(fixtures.TablesTest): - """test for consistent insert behavior across dialects regarding the inline() method, values() method, lower-case 't' tables. @@ -1766,9 +1765,11 @@ class IMVSentinelTest(fixtures.TestBase): Column( "id", Uuid(), - server_default=func.gen_random_uuid() - if default_type.server_side - else None, + server_default=( + func.gen_random_uuid() + if default_type.server_side + else None + ), default=uuid.uuid4 if default_type.client_side else None, primary_key=True, insert_sentinel=bool(add_insert_sentinel), diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index eed861fe17..627310d8f1 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -413,9 +413,11 @@ class LambdaElementTest( stmt = lambda_stmt(lambda: select(tab)) stmt = stmt.add_criteria( - lambda s: s.where(tab.c.col > parameter) - if add_criteria - else s.where(tab.c.col == parameter), + lambda s: ( + s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter) + ), ) stmt += lambda s: s.order_by(tab.c.id) @@ -437,9 +439,11 @@ class LambdaElementTest( stmt = lambda_stmt(lambda: select(tab)) stmt = stmt.add_criteria( - lambda s: s.where(tab.c.col > parameter) - if add_criteria - else s.where(tab.c.col == parameter), + lambda s: ( + s.where(tab.c.col > parameter) + if add_criteria + else s.where(tab.c.col == parameter) + ), track_on=[add_criteria], ) @@ -1945,9 +1949,9 @@ class DeferredLambdaElementTest( # lambda produces either "t1 IN vv" or "t2 IN qq" based on the # argument. will not produce a consistent cache key elem = lambdas.DeferredLambdaElement( - lambda tab: tab.c.q.in_(vv) - if tab.name == "t1" - else tab.c.q.in_(qq), + lambda tab: ( + tab.c.q.in_(vv) if tab.name == "t1" else tab.c.q.in_(qq) + ), roles.WhereHavingRole, lambda_args=(t1,), opts=lambdas.LambdaOptions(track_closure_variables=False), diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index aa3cec3dad..3592bc6f00 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -4146,7 +4146,6 @@ class ConstraintTest(fixtures.TestBase): class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): - """Test Column() construction.""" __dialect__ = "default" @@ -4562,7 +4561,6 @@ class ColumnDefinitionTest(AssertsCompiledSQL, fixtures.TestBase): class ColumnDefaultsTest(fixtures.TestBase): - """test assignment of default fixures to columns""" def _fixture(self, *arg, **kw): @@ -5792,9 +5790,11 @@ class NamingConventionTest(fixtures.TestBase, AssertsCompiledSQL): "b", metadata, Column("id", Integer, primary_key=True), - Column("aid", ForeignKey("a.id")) - if not col_has_type - else Column("aid", Integer, ForeignKey("a.id")), + ( + Column("aid", ForeignKey("a.id")) + if not col_has_type + else Column("aid", Integer, ForeignKey("a.id")) + ), ) fks = list( c for c in b.constraints if isinstance(c, ForeignKeyConstraint) diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 640e70a0a6..c0b5cb47d6 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -483,19 +483,24 @@ class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL): if negate: self.assert_compile( select(~expr), - f"SELECT NOT (t.q{opstring}t.p{opstring}{exprs}) " - "AS anon_1 FROM t" - if not reverse - else f"SELECT NOT ({exprs}{opstring}t.q{opstring}t.p) " - "AS anon_1 FROM t", + ( + f"SELECT NOT (t.q{opstring}t.p{opstring}{exprs}) " + "AS anon_1 FROM t" + if not reverse + else f"SELECT NOT ({exprs}{opstring}t.q{opstring}t.p) " + "AS anon_1 FROM t" + ), ) else: self.assert_compile( select(expr), - f"SELECT t.q{opstring}t.p{opstring}{exprs} AS anon_1 FROM t" - if not reverse - else f"SELECT {exprs}{opstring}t.q{opstring}t.p " - f"AS anon_1 FROM t", + ( + f"SELECT t.q{opstring}t.p{opstring}{exprs} " + "AS anon_1 FROM t" + if not reverse + else f"SELECT {exprs}{opstring}t.q{opstring}t.p " + "AS anon_1 FROM t" + ), ) @testing.combinations( @@ -565,9 +570,11 @@ class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( select(~expr), - f"SELECT {str_expr} AS anon_1 FROM t" - if not reverse - else f"SELECT {str_expr} AS anon_1 FROM t", + ( + f"SELECT {str_expr} AS anon_1 FROM t" + if not reverse + else f"SELECT {str_expr} AS anon_1 FROM t" + ), ) else: if reverse: @@ -583,9 +590,11 @@ class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile( select(expr), - f"SELECT {str_expr} AS anon_1 FROM t" - if not reverse - else f"SELECT {str_expr} AS anon_1 FROM t", + ( + f"SELECT {str_expr} AS anon_1 FROM t" + if not reverse + else f"SELECT {str_expr} AS anon_1 FROM t" + ), ) @@ -650,9 +659,11 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): col = column("somecol", modulus()) self.assert_compile( col.modulus(), - "somecol %%" - if paramstyle in ("format", "pyformat") - else "somecol %", + ( + "somecol %%" + if paramstyle in ("format", "pyformat") + else "somecol %" + ), dialect=default.DefaultDialect(paramstyle=paramstyle), ) @@ -667,9 +678,11 @@ class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): col = column("somecol", modulus()) self.assert_compile( col.modulus_prefix(), - "%% somecol" - if paramstyle in ("format", "pyformat") - else "% somecol", + ( + "%% somecol" + if paramstyle in ("format", "pyformat") + else "% somecol" + ), dialect=default.DefaultDialect(paramstyle=paramstyle), ) @@ -1272,7 +1285,6 @@ class ArrayIndexOpTest(fixtures.TestBase, testing.AssertsCompiledSQL): class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): - """test standalone booleans being wrapped in an AsBoolean, as well as true/false compilation.""" @@ -1433,7 +1445,6 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): class ConjunctionTest(fixtures.TestBase, testing.AssertsCompiledSQL): - """test interaction of and_()/or_() with boolean , null constants""" __dialect__ = default.DefaultDialect(supports_native_boolean=True) diff --git a/test/sql/test_query.py b/test/sql/test_query.py index 54943897e1..5d7788fcf1 100644 --- a/test/sql/test_query.py +++ b/test/sql/test_query.py @@ -1076,7 +1076,6 @@ class LimitTest(fixtures.TablesTest): class CompoundTest(fixtures.TablesTest): - """test compound statements like UNION, INTERSECT, particularly their ability to nest on different databases.""" @@ -1463,7 +1462,6 @@ class CompoundTest(fixtures.TablesTest): class JoinTest(fixtures.TablesTest): - """Tests join execution. The compiled SQL emitted by the dialect might be ANSI joins or diff --git a/test/sql/test_quote.py b/test/sql/test_quote.py index 08c9c4207e..51382b19b4 100644 --- a/test/sql/test_quote.py +++ b/test/sql/test_quote.py @@ -858,7 +858,6 @@ class QuoteTest(fixtures.TestBase, AssertsCompiledSQL): class PreparerTest(fixtures.TestBase): - """Test the db-agnostic quoting services of IdentifierPreparer.""" def test_unformat(self): diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 1848f7bdd3..cad58f8b0c 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -1303,11 +1303,15 @@ class CursorResultTest(fixtures.TablesTest): stmt = select( *[ - text("*") - if colname == "*" - else users.c.user_name.label("name_label") - if colname == "name_label" - else users.c[colname] + ( + text("*") + if colname == "*" + else ( + users.c.user_name.label("name_label") + if colname == "name_label" + else users.c[colname] + ) + ) for colname in cols ] ) diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index 4d55c435db..6cccd01d4a 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -690,7 +690,6 @@ class SequenceReturningTest(fixtures.TablesTest): class KeyReturningTest(fixtures.TablesTest, AssertsExecutionResults): - """test returning() works with columns that define 'key'.""" __requires__ = ("insert_returning",) @@ -1561,9 +1560,11 @@ class InsertManyReturningTest(fixtures.TablesTest): config, t1, (t1.c.id, t1.c.insdef, t1.c.data), - set_lambda=(lambda excluded: {"data": excluded.data + " excluded"}) - if update_cols - else None, + set_lambda=( + (lambda excluded: {"data": excluded.data + " excluded"}) + if update_cols + else None + ), ) upserted_rows = connection.execute( diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index d3b7b47841..0c0c23b870 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -1,4 +1,5 @@ """Test various algorithmic properties of selectables.""" + from itertools import zip_longest from sqlalchemy import and_ @@ -1962,7 +1963,6 @@ class RefreshForNewColTest(fixtures.TestBase): class AnonLabelTest(fixtures.TestBase): - """Test behaviors fixed by [ticket:2168].""" def test_anon_labels_named_column(self): diff --git a/test/sql/test_text.py b/test/sql/test_text.py index de40c8f429..301ad9ffdf 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -71,7 +71,6 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL): - """test the usage of text() implicit within the select() construct when strings are passed.""" diff --git a/test/sql/test_types.py b/test/sql/test_types.py index eb91d9c4cd..76249f5617 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -1417,9 +1417,11 @@ class TypeCoerceCastTest(fixtures.TablesTest): # on the way in here eq_( conn.execute(new_stmt).fetchall(), - [("x", "BIND_INxBIND_OUT")] - if coerce_fn is type_coerce - else [("x", "xBIND_OUT")], + ( + [("x", "BIND_INxBIND_OUT")] + if coerce_fn is type_coerce + else [("x", "xBIND_OUT")] + ), ) def test_cast_bind(self, connection): @@ -1441,9 +1443,11 @@ class TypeCoerceCastTest(fixtures.TablesTest): eq_( conn.execute(stmt).fetchall(), - [("x", "BIND_INxBIND_OUT")] - if coerce_fn is type_coerce - else [("x", "xBIND_OUT")], + ( + [("x", "BIND_INxBIND_OUT")] + if coerce_fn is type_coerce + else [("x", "xBIND_OUT")] + ), ) def test_cast_existing_typed(self, connection): @@ -3876,7 +3880,6 @@ class TestKWArgPassThru(AssertsCompiledSQL, fixtures.TestBase): class NumericRawSQLTest(fixtures.TestBase): - """Test what DBAPIs and dialects return without any typing information supplied at the SQLA level. @@ -4007,7 +4010,6 @@ class IntegerTest(fixtures.TestBase): class BooleanTest( fixtures.TablesTest, AssertsExecutionResults, AssertsCompiledSQL ): - """test edge cases for booleans. Note that the main boolean test suite is now in testing/suite/test_types.py diff --git a/test/typing/plain_files/ext/asyncio/async_sessionmaker.py b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py index 664ff0411d..d9997141a1 100644 --- a/test/typing/plain_files/ext/asyncio/async_sessionmaker.py +++ b/test/typing/plain_files/ext/asyncio/async_sessionmaker.py @@ -2,6 +2,7 @@ for asynchronous ORM use. """ + from __future__ import annotations import asyncio diff --git a/test/typing/plain_files/orm/issue_9340.py b/test/typing/plain_files/orm/issue_9340.py index 20bc424ce2..6ccd2eed31 100644 --- a/test/typing/plain_files/orm/issue_9340.py +++ b/test/typing/plain_files/orm/issue_9340.py @@ -10,8 +10,7 @@ from sqlalchemy.orm import Session from sqlalchemy.orm import with_polymorphic -class Base(DeclarativeBase): - ... +class Base(DeclarativeBase): ... class Message(Base): diff --git a/test/typing/plain_files/orm/mapped_covariant.py b/test/typing/plain_files/orm/mapped_covariant.py index 1a17ee3848..9f964021b3 100644 --- a/test/typing/plain_files/orm/mapped_covariant.py +++ b/test/typing/plain_files/orm/mapped_covariant.py @@ -24,8 +24,7 @@ class ChildProtocol(Protocol): # Read-only for simplicity, mutable protocol members are complicated, # see https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected @property - def parent(self) -> Mapped[ParentProtocol]: - ... + def parent(self) -> Mapped[ParentProtocol]: ... def get_parent_name(child: ChildProtocol) -> str: diff --git a/test/typing/plain_files/orm/relationship.py b/test/typing/plain_files/orm/relationship.py index d0ab35249d..6bfe19cc4e 100644 --- a/test/typing/plain_files/orm/relationship.py +++ b/test/typing/plain_files/orm/relationship.py @@ -1,6 +1,7 @@ """this suite experiments with other kinds of relationship syntaxes. """ + from __future__ import annotations import typing diff --git a/test/typing/plain_files/orm/trad_relationship_uselist.py b/test/typing/plain_files/orm/trad_relationship_uselist.py index 8d7d7e71a2..9282181f01 100644 --- a/test/typing/plain_files/orm/trad_relationship_uselist.py +++ b/test/typing/plain_files/orm/trad_relationship_uselist.py @@ -2,6 +2,7 @@ """ + import typing from typing import cast from typing import Dict diff --git a/test/typing/plain_files/orm/traditional_relationship.py b/test/typing/plain_files/orm/traditional_relationship.py index 02afc7c801..bd6bada528 100644 --- a/test/typing/plain_files/orm/traditional_relationship.py +++ b/test/typing/plain_files/orm/traditional_relationship.py @@ -5,6 +5,7 @@ This requires that the return type of relationship is based on Any, if no uselists are present. """ + import typing from typing import List from typing import Set diff --git a/test/typing/plain_files/sql/common_sql_element.py b/test/typing/plain_files/sql/common_sql_element.py index 730d99bc15..89c0c4d2ef 100644 --- a/test/typing/plain_files/sql/common_sql_element.py +++ b/test/typing/plain_files/sql/common_sql_element.py @@ -6,7 +6,6 @@ unions. """ - from __future__ import annotations from sqlalchemy import asc diff --git a/tox.ini b/tox.ini index dbffc9e206..900165fd7e 100644 --- a/tox.ini +++ b/tox.ini @@ -227,7 +227,7 @@ deps= # in case it requires a version pin pydocstyle pygments - black==23.3.0 + black==24.1.1 slotscheck>=0.17.0 # required by generate_tuple_map_overloads