From: Mike Bayer Date: Wed, 13 Apr 2022 13:45:29 +0000 (-0400) Subject: pep484: schema API X-Git-Tag: rel_2_0_0b1~347^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c932123bacad9bf047d160b85e3f95d396c513ae;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git pep484: schema API implement strict typing for schema.py this module has lots of public API, lots of old decisions and very hard to follow construction sequences in many cases, and is also where we get a lot of new feature requests, so strict typing should help keep things clean. among improvements here, fixed the pool .info getters and also figured out how to get ColumnCollection and related to be covariant so that we may set them up as returning Column or ColumnClause without any conflicts. DDL was affected, noting that superclasses of DDLElement (_DDLCompiles, added recently) can now be passed into "ddl_if" callables; reorganized ddl into ExecutableDDLElement as a new name for DDLElement and _DDLCompiles renamed to BaseDDLElement. setting up strict also located an API use case that is completely broken, which is connection.execute(some_default) returns a scalar value. This case has been deprecated and new paths have been set up so that connection.scalar() may be used. This likely wasn't possible in previous versions because scalar() would assume a CursorResult. The scalar() change also impacts Session as we have explicit support (since someone had reported it as a regression) for session.execute(Sequence()) to work. They will get the same deprecation message (which omits the word "Connection", just uses ".execute()" and ".scalar()") and they can then use Session.scalar() as well. Getting this to type correctly while still supporting ORM use cases required some refactoring, and I also set up a keyword only delimeter for Session.execute() and related as execution_options / bind_arguments should always be keyword only, applied these changes to AsyncSession as well. Additionally simpify Table __init__ now that we are Python 3 only, we can have positional plus explicit kwargs finally. Simplify Column.__init__ as well again taking advantage of kw only arguments. Fill in most/all __init__ methods in sqltypes.py as the constructor for types is most of the API. should likely do this for dialect-specific types as well. Apply _InfoType for all info attributes as should have been done originally and update descriptor decorators. Change-Id: I3f9f8ff3f1c8858471ff4545ac83d68c88107527 --- diff --git a/doc/build/changelog/unreleased_20/7631.rst b/doc/build/changelog/unreleased_20/7631.rst index d6c69f5d53..d2e0992ab2 100644 --- a/doc/build/changelog/unreleased_20/7631.rst +++ b/doc/build/changelog/unreleased_20/7631.rst @@ -3,7 +3,8 @@ :tickets: 7631 Expanded on the "conditional DDL" system implemented by the - :class:`_schema.DDLElement` class to be directly available on + :class:`_schema.ExecutableDDLElement` class (renamed from + :class:`_schema.DDLElement`) to be directly available on :class:`_schema.SchemaItem` constructs such as :class:`_schema.Index`, :class:`_schema.ForeignKeyConstraint`, etc. such that the conditional logic for generating these elements is included within the default DDL emitting diff --git a/doc/build/changelog/unreleased_20/exec_default.rst b/doc/build/changelog/unreleased_20/exec_default.rst new file mode 100644 index 0000000000..05ff5862b0 --- /dev/null +++ b/doc/build/changelog/unreleased_20/exec_default.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, engine + + Passing a :class:`.DefaultGenerator` object such as a :class:`.Sequence` to + the :meth:`.Connection.execute` method is deprecated, as this method is + typed as returning a :class:`.CursorResult` object, and not a plain scalar + value. The :meth:`.Connection.scalar` method should be used instead, which + has been reworked with new internal codepaths to suit invoking a SELECT for + default generation objects without going through the + :meth:`.Connection.execute` method. \ No newline at end of file diff --git a/doc/build/core/constraints.rst b/doc/build/core/constraints.rst index ea84c15a3e..97a043aabc 100644 --- a/doc/build/core/constraints.rst +++ b/doc/build/core/constraints.rst @@ -791,6 +791,10 @@ Constraints API :members: :inherited-members: +.. autoclass:: HasConditionalDDL + :members: + :inherited-members: + .. autoclass:: PrimaryKeyConstraint :members: :inherited-members: diff --git a/doc/build/core/ddl.rst b/doc/build/core/ddl.rst index 6bbd494246..c34a4e1a33 100644 --- a/doc/build/core/ddl.rst +++ b/doc/build/core/ddl.rst @@ -49,7 +49,7 @@ Controlling DDL Sequences The :class:`_schema.DDL` construct introduced previously also has the ability to be invoked conditionally based on inspection of the -database. This feature is available using the :meth:`.DDLElement.execute_if` +database. This feature is available using the :meth:`.ExecutableDDLElement.execute_if` method. For example, if we wanted to create a trigger but only on the PostgreSQL backend, we could invoke this as:: @@ -85,7 +85,7 @@ the PostgreSQL backend, we could invoke this as:: trigger.execute_if(dialect='postgresql') ) -The :paramref:`.DDLElement.execute_if.dialect` keyword also accepts a tuple +The :paramref:`.ExecutableDDLElement.execute_if.dialect` keyword also accepts a tuple of string dialect names:: event.listen( @@ -99,7 +99,7 @@ of string dialect names:: trigger.execute_if(dialect=('postgresql', 'mysql')) ) -The :meth:`.DDLElement.execute_if` method can also work against a callable +The :meth:`.ExecutableDDLElement.execute_if` method can also work against a callable function that will receive the database connection in use. In the example below, we use this to conditionally create a CHECK constraint, first looking within the PostgreSQL catalogs to see if it exists: @@ -151,7 +151,7 @@ Using the built-in DDLElement Classes The ``sqlalchemy.schema`` package contains SQL expression constructs that provide DDL expressions, all of which extend from the common base -:class:`.DDLElement`. For example, to produce a ``CREATE TABLE`` statement, +:class:`.ExecutableDDLElement`. For example, to produce a ``CREATE TABLE`` statement, one can use the :class:`.CreateTable` construct: .. sourcecode:: python+sql @@ -171,13 +171,13 @@ one can use the :class:`.CreateTable` construct: Above, the :class:`~sqlalchemy.schema.CreateTable` construct works like any other expression construct (such as ``select()``, ``table.insert()``, etc.). All of SQLAlchemy's DDL oriented constructs are subclasses of -the :class:`.DDLElement` base class; this is the base of all the +the :class:`.ExecutableDDLElement` base class; this is the base of all the objects corresponding to CREATE and DROP as well as ALTER, not only in SQLAlchemy but in Alembic Migrations as well. A full reference of available constructs is in :ref:`schema_api_ddl`. User-defined DDL constructs may also be created as subclasses of -:class:`.DDLElement` itself. The documentation in +:class:`.ExecutableDDLElement` itself. The documentation in :ref:`sqlalchemy.ext.compiler_toplevel` has several examples of this. .. _schema_ddl_ddl_if: @@ -187,7 +187,7 @@ Controlling DDL Generation of Constraints and Indexes .. versionadded:: 2.0 -While the previously mentioned :meth:`.DDLElement.execute_if` method is +While the previously mentioned :meth:`.ExecutableDDLElement.execute_if` method is useful for custom :class:`.DDL` classes which need to invoke conditionally, there is also a common need for elements that are typically related to a particular :class:`.Table`, namely constraints and indexes, to also be @@ -196,7 +196,7 @@ that are specific to a particular backend such as PostgreSQL or SQL Server. For this use case, the :meth:`.Constraint.ddl_if` and :meth:`.Index.ddl_if` methods may be used against constructs such as :class:`.CheckConstraint`, :class:`.UniqueConstraint` and :class:`.Index`, accepting the same -arguments as the :meth:`.DDLElement.execute_if` method in order to control +arguments as the :meth:`.ExecutableDDLElement.execute_if` method in order to control whether or not their DDL will be emitted in terms of their parent :class:`.Table` object. These methods may be used inline when creating the definition for a :class:`.Table` @@ -271,7 +271,7 @@ statement emitted for the index: The :meth:`.Constraint.ddl_if` and :meth:`.Index.ddl_if` methods create an event hook that may be consulted not just at DDL execution time, as is the -behavior with :meth:`.DDLElement.execute_if`, but also within the SQL compilation +behavior with :meth:`.ExecutableDDLElement.execute_if`, but also within the SQL compilation phase of the :class:`.CreateTable` object, which is responsible for rendering the ``CHECK (num > 5)`` DDL inline within the CREATE TABLE statement. As such, the event hook that is received by the :meth:`.Constraint.ddl_if.callable_` @@ -317,67 +317,58 @@ DDL Expression Constructs API .. autofunction:: sort_tables_and_constraints -.. autoclass:: DDLElement +.. autoclass:: BaseDDLElement :members: - :undoc-members: +.. autoattr:: DDLElement + +.. autoclass:: ExecutableDDLElement + :members: .. autoclass:: DDL :members: - :undoc-members: .. autoclass:: _CreateDropBase .. autoclass:: CreateTable :members: - :undoc-members: .. autoclass:: DropTable :members: - :undoc-members: .. autoclass:: CreateColumn :members: - :undoc-members: .. autoclass:: CreateSequence :members: - :undoc-members: .. autoclass:: DropSequence :members: - :undoc-members: .. autoclass:: CreateIndex :members: - :undoc-members: .. autoclass:: DropIndex :members: - :undoc-members: .. autoclass:: AddConstraint :members: - :undoc-members: .. autoclass:: DropConstraint :members: - :undoc-members: .. autoclass:: CreateSchema :members: - :undoc-members: .. autoclass:: DropSchema :members: - :undoc-members: diff --git a/doc/build/core/metadata.rst b/doc/build/core/metadata.rst index e4f06f1a58..551fe918c1 100644 --- a/doc/build/core/metadata.rst +++ b/doc/build/core/metadata.rst @@ -558,27 +558,11 @@ Column, Table, MetaData API .. attribute:: sqlalchemy.schema.BLANK_SCHEMA - Symbol indicating that a :class:`_schema.Table` or :class:`.Sequence` - should have 'None' for its schema, even if the parent - :class:`_schema.MetaData` has specified a schema. - - .. seealso:: - - :paramref:`_schema.MetaData.schema` - - :paramref:`_schema.Table.schema` - - :paramref:`.Sequence.schema` - - .. versionadded:: 1.0.14 + Refers to :attr:`.SchemaConst.BLANK_SCHEMA`. .. attribute:: sqlalchemy.schema.RETAIN_SCHEMA - Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence` - or in some cases a :class:`_schema.ForeignKey` object, in situations - where the object is being copied for a :meth:`.Table.to_metadata` - operation, should retain the schema name that it already has. - + Refers to :attr:`.SchemaConst.RETAIN_SCHEMA` .. autoclass:: Column @@ -589,6 +573,8 @@ Column, Table, MetaData API .. autoclass:: MetaData :members: +.. autoclass:: SchemaConst + :members: .. autoclass:: SchemaItem :members: diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py index 4a6ae08b25..96189c7fde 100644 --- a/lib/sqlalchemy/__init__.py +++ b/lib/sqlalchemy/__init__.py @@ -55,6 +55,7 @@ from .pool import PoolProxiedConnection as PoolProxiedConnection from .pool import QueuePool as QueuePool from .pool import SingletonThreadPool as SingleonThreadPool from .pool import StaticPool as StaticPool +from .schema import BaseDDLElement as BaseDDLElement from .schema import BLANK_SCHEMA as BLANK_SCHEMA from .schema import CheckConstraint as CheckConstraint from .schema import Column as Column @@ -62,7 +63,9 @@ from .schema import ColumnDefault as ColumnDefault from .schema import Computed as Computed from .schema import Constraint as Constraint from .schema import DDL as DDL +from .schema import DDLElement as DDLElement from .schema import DefaultClause as DefaultClause +from .schema import ExecutableDDLElement as ExecutableDDLElement from .schema import FetchedValue as FetchedValue from .schema import ForeignKey as ForeignKey from .schema import ForeignKeyConstraint as ForeignKeyConstraint diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 503bcc03bc..6c3bc4e7cb 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1448,7 +1448,7 @@ from ...sql import expression from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util -from ...sql.ddl import DDLBase +from ...sql.ddl import InvokeDDLBase from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -2014,7 +2014,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) - class EnumGenerator(DDLBase): + class EnumGenerator(InvokeDDLBase): def __init__(self, dialect, connection, checkfirst=False, **kwargs): super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) self.checkfirst = checkfirst @@ -2035,7 +2035,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): self.connection.execute(CreateEnumType(enum)) - class EnumDropper(DDLBase): + class EnumDropper(InvokeDDLBase): def __init__(self, dialect, connection, checkfirst=False, **kwargs): super(ENUM.EnumDropper, self).__init__(connection, **kwargs) self.checkfirst = checkfirst diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 594a193446..a325da929b 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -12,7 +12,6 @@ import typing from typing import Any from typing import Callable from typing import cast -from typing import Dict from typing import Iterator from typing import List from typing import Mapping @@ -65,13 +64,15 @@ if typing.TYPE_CHECKING: from ..pool import Pool from ..pool import PoolProxiedConnection from ..sql import Executable + from ..sql._typing import _InfoType from ..sql.base import SchemaVisitor from ..sql.compiler import Compiled - from ..sql.ddl import DDLElement + from ..sql.ddl import ExecutableDDLElement from ..sql.ddl import SchemaDropper from ..sql.ddl import SchemaGenerator from ..sql.functions import FunctionElement from ..sql.schema import ColumnDefault + from ..sql.schema import DefaultGenerator from ..sql.schema import HasSchemaAttr from ..sql.schema import SchemaItem @@ -561,7 +562,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): raise exc.ResourceClosedError("This Connection is closed") @property - def info(self) -> Dict[str, Any]: + def info(self) -> _InfoType: """Info dictionary associated with the underlying DBAPI connection referred to by this :class:`_engine.Connection`, allowing user-defined data to be associated with the connection. @@ -1157,7 +1158,17 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): first row returned. """ - return self.execute(statement, parameters, execution_options).scalar() + distilled_parameters = _distill_params_20(parameters) + try: + meth = statement._execute_on_scalar + except AttributeError as err: + raise exc.ObjectNotExecutableError(statement) from err + else: + return meth( + self, + distilled_parameters, + execution_options or NO_OPTIONS, + ) def scalars( self, @@ -1200,7 +1211,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): * :class:`_expression.TextClause` and :class:`_expression.TextualSelect` * :class:`_schema.DDL` and objects which inherit from - :class:`_schema.DDLElement` + :class:`_schema.ExecutableDDLElement` :param parameters: parameters which will be bound into the statement. This may be either a dictionary of parameter names to values, @@ -1244,7 +1255,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _execute_default( self, - default: ColumnDefault, + default: DefaultGenerator, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, ) -> Any: @@ -1303,7 +1314,7 @@ class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]): def _execute_ddl( self, - ddl: DDLElement, + ddl: ExecutableDDLElement, distilled_parameters: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, ) -> CursorResult: diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index b30cf9c088..c6571f68bb 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -829,7 +829,8 @@ class DefaultExecutionContext(ExecutionContext): execution_options: _ExecuteOptions, compiled_ddl: DDLCompiler, ) -> ExecutionContext: - """Initialize execution context for a DDLElement construct.""" + """Initialize execution context for an ExecutableDDLElement + construct.""" self = cls.__new__(cls) self.root_connection = connection diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index c94dd1032e..2d5707b536 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -28,7 +28,7 @@ if typing.TYPE_CHECKING: from .interfaces import Dialect from .url import URL from ..sql.base import Executable - from ..sql.ddl import DDLElement + from ..sql.ddl import ExecutableDDLElement from ..sql.ddl import SchemaDropper from ..sql.ddl import SchemaGenerator from ..sql.schema import HasSchemaAttr @@ -101,8 +101,8 @@ def create_mock_engine(url: URL, executor: Any, **kw: Any) -> MockConnection: :param executor: a callable which receives the arguments ``sql``, ``*multiparams`` and ``**params``. The ``sql`` parameter is typically - an instance of :class:`.DDLElement`, which can then be compiled into a - string using :meth:`.DDLElement.compile`. + an instance of :class:`.ExecutableDDLElement`, which can then be compiled + into a string using :meth:`.ExecutableDDLElement.compile`. .. versionadded:: 1.4 - the :func:`.create_mock_engine` function replaces the previous "mock" engine strategy used with diff --git a/lib/sqlalchemy/ext/associationproxy.py b/lib/sqlalchemy/ext/associationproxy.py index 5ea268a2dd..2e6c6b422a 100644 --- a/lib/sqlalchemy/ext/associationproxy.py +++ b/lib/sqlalchemy/ext/associationproxy.py @@ -65,6 +65,7 @@ if typing.TYPE_CHECKING: from ..orm.interfaces import MapperProperty from ..orm.interfaces import PropComparator from ..orm.mapper import Mapper + from ..sql._typing import _InfoType _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) @@ -221,8 +222,8 @@ class _AssociationProxyProtocol(Protocol[_T]): proxy_factory: Optional[_ProxyFactoryProtocol] proxy_bulk_set: Optional[_ProxyBulkSetProtocol] - @util.memoized_property - def info(self) -> Dict[Any, Any]: + @util.ro_memoized_property + def info(self) -> _InfoType: ... def for_class( @@ -259,7 +260,7 @@ class AssociationProxy( getset_factory: Optional[_GetSetFactoryProtocol] = None, proxy_factory: Optional[_ProxyFactoryProtocol] = None, proxy_bulk_set: Optional[_ProxyBulkSetProtocol] = None, - info: Optional[Dict[Any, Any]] = None, + info: Optional[_InfoType] = None, cascade_scalar_deletes: bool = False, ): """Construct a new :class:`.AssociationProxy`. @@ -338,7 +339,7 @@ class AssociationProxy( id(self), ) if info: - self.info = info + self.info = info # type: ignore @overload def __get__( @@ -777,8 +778,8 @@ class AssociationProxyInstance(SQLORMOperations[_T]): return getter, plain_setter - @property - def info(self) -> Dict[Any, Any]: + @util.ro_non_memoized_property + def info(self) -> _InfoType: return self.parent.info @overload diff --git a/lib/sqlalchemy/ext/asyncio/engine.py b/lib/sqlalchemy/ext/asyncio/engine.py index bb51a4d225..fb05f512e4 100644 --- a/lib/sqlalchemy/ext/asyncio/engine.py +++ b/lib/sqlalchemy/ext/asyncio/engine.py @@ -48,6 +48,7 @@ if TYPE_CHECKING: from ...engine.url import URL from ...pool import Pool from ...pool import PoolProxiedConnection + from ...sql._typing import _InfoType from ...sql.base import Executable @@ -241,8 +242,8 @@ class AsyncConnection( return await greenlet_spawn(getattr, self._proxied, "connection") - @property - def info(self) -> Dict[str, Any]: + @util.ro_non_memoized_property + def info(self) -> _InfoType: """Return the :attr:`_engine.Connection.info` dictionary of the underlying :class:`_engine.Connection`. @@ -464,7 +465,7 @@ class AsyncConnection( * :class:`_expression.TextClause` and :class:`_expression.TextualSelect` * :class:`_schema.DDL` and objects which inherit from - :class:`_schema.DDLElement` + :class:`_schema.ExecutableDDLElement` :param parameters: parameters which will be bound into the statement. This may be either a dictionary of parameter names to values, diff --git a/lib/sqlalchemy/ext/asyncio/result.py b/lib/sqlalchemy/ext/asyncio/result.py index a9db822a6a..d0337554cf 100644 --- a/lib/sqlalchemy/ext/asyncio/result.py +++ b/lib/sqlalchemy/ext/asyncio/result.py @@ -710,7 +710,14 @@ _RT = TypeVar("_RT", bound="Result") async def _ensure_sync_result(result: _RT, calling_method: Any) -> _RT: cursor_result: CursorResult - if not result._is_cursor: + + try: + is_cursor = result._is_cursor + except AttributeError: + # legacy execute(DefaultGenerator) case + return result + + if not is_cursor: cursor_result = getattr(result, "raw", None) # type: ignore else: cursor_result = result # type: ignore diff --git a/lib/sqlalchemy/ext/asyncio/scoping.py b/lib/sqlalchemy/ext/asyncio/scoping.py index 0d6ae92b41..33cf3f745a 100644 --- a/lib/sqlalchemy/ext/asyncio/scoping.py +++ b/lib/sqlalchemy/ext/asyncio/scoping.py @@ -442,6 +442,7 @@ class async_scoped_session: self, statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -914,6 +915,7 @@ class async_scoped_session: self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -944,6 +946,7 @@ class async_scoped_session: self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -980,6 +983,7 @@ class async_scoped_session: self, statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -1007,6 +1011,7 @@ class async_scoped_session: self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 7d63b084c2..1422f99a39 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -55,6 +55,7 @@ if TYPE_CHECKING: from ...orm.session import _PKIdentityArgument from ...orm.session import _SessionBind from ...orm.session import _SessionBindKey + from ...sql._typing import _InfoType from ...sql.base import Executable from ...sql.elements import ClauseElement from ...sql.selectable import ForUpdateArg @@ -260,6 +261,7 @@ class AsyncSession(ReversibleProxy[Session]): self, statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -294,6 +296,7 @@ class AsyncSession(ReversibleProxy[Session]): self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -306,19 +309,28 @@ class AsyncSession(ReversibleProxy[Session]): """ - result = await self.execute( + if execution_options: + execution_options = util.immutabledict(execution_options).union( + _EXECUTE_OPTIONS + ) + else: + execution_options = _EXECUTE_OPTIONS + + result = await greenlet_spawn( + self.sync_session.scalar, statement, params=params, execution_options=execution_options, bind_arguments=bind_arguments, **kw, ) - return result.scalar() + return result async def scalars( self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -383,6 +395,7 @@ class AsyncSession(ReversibleProxy[Session]): self, statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -414,6 +427,7 @@ class AsyncSession(ReversibleProxy[Session]): self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -1277,7 +1291,7 @@ class async_sessionmaker: class_: Type[AsyncSession] = AsyncSession, autoflush: bool = True, expire_on_commit: bool = True, - info: Optional[Dict[Any, Any]] = None, + info: Optional[_InfoType] = None, **kw: Any, ): r"""Construct a new :class:`.async_sessionmaker`. diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index b8265e88e7..b74761fe4c 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -229,18 +229,19 @@ A synopsis is as follows: raise TypeError("coalesce only supports two arguments on Oracle") return "nvl(%s)" % compiler.process(element.clauses, **kw) -* :class:`.DDLElement` - The root of all DDL expressions, - like CREATE TABLE, ALTER TABLE, etc. Compilation of :class:`.DDLElement` - subclasses is issued by a :class:`.DDLCompiler` instead of a - :class:`.SQLCompiler`. :class:`.DDLElement` can also be used as an event hook - in conjunction with event hooks like :meth:`.DDLEvents.before_create` and +* :class:`.ExecutableDDLElement` - The root of all DDL expressions, + like CREATE TABLE, ALTER TABLE, etc. Compilation of + :class:`.ExecutableDDLElement` subclasses is issued by a + :class:`.DDLCompiler` instead of a :class:`.SQLCompiler`. + :class:`.ExecutableDDLElement` can also be used as an event hook in + conjunction with event hooks like :meth:`.DDLEvents.before_create` and :meth:`.DDLEvents.after_create`, allowing the construct to be invoked automatically during CREATE TABLE and DROP TABLE sequences. .. seealso:: :ref:`metadata_ddl_toplevel` - contains examples of associating - :class:`.DDL` objects (which are themselves :class:`.DDLElement` + :class:`.DDL` objects (which are themselves :class:`.ExecutableDDLElement` instances) with :class:`.DDLEvents` event hooks. * :class:`~sqlalchemy.sql.expression.Executable` - This is a mixin which diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index a0c7905d84..be872804e2 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -808,7 +808,6 @@ from __future__ import annotations from typing import Any from typing import Callable from typing import cast -from typing import Dict from typing import Generic from typing import List from typing import Optional @@ -837,9 +836,11 @@ if TYPE_CHECKING: from ..sql._typing import _ColumnExpressionArgument from ..sql._typing import _DMLColumnArgument from ..sql._typing import _HasClauseElement + from ..sql._typing import _InfoType from ..sql.operators import OperatorType from ..sql.roles import ColumnsClauseRole + _T = TypeVar("_T", bound=Any) _T_co = TypeVar("_T_co", bound=Any, covariant=True) _T_con = TypeVar("_T_con", bound=Any, contravariant=True) @@ -1325,7 +1326,7 @@ class ExprComparator(Comparator[_T]): return getattr(self.expression, key) @util.non_memoized_property - def info(self) -> Dict[Any, Any]: + def info(self) -> _InfoType: return self.hybrid.info def _bulk_update_tuples( diff --git a/lib/sqlalchemy/orm/base.py b/lib/sqlalchemy/orm/base.py index c5348c2373..3fa855a4bd 100644 --- a/lib/sqlalchemy/orm/base.py +++ b/lib/sqlalchemy/orm/base.py @@ -42,6 +42,7 @@ if typing.TYPE_CHECKING: from .attributes import InstrumentedAttribute from .mapper import Mapper from .state import InstanceState + from ..sql._typing import _InfoType _T = TypeVar("_T", bound=Any) @@ -587,8 +588,8 @@ class InspectionAttrInfo(InspectionAttr): __slots__ = () - @util.memoized_property - def info(self) -> Dict[Any, Any]: + @util.ro_memoized_property + def info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.InspectionAttr`. diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 32c69a7446..8beac472e3 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -19,7 +19,6 @@ import operator import typing from typing import Any from typing import Callable -from typing import Dict from typing import List from typing import Optional from typing import Tuple @@ -49,6 +48,7 @@ if typing.TYPE_CHECKING: from .attributes import InstrumentedAttribute from .properties import MappedColumn from ..sql._typing import _ColumnExpressionArgument + from ..sql._typing import _InfoType from ..sql.schema import Column _T = TypeVar("_T", bound=Any) @@ -158,7 +158,7 @@ class Composite( deferred: bool = False, group: Optional[str] = None, comparator_factory: Optional[Type[Comparator]] = None, - info: Optional[Dict[Any, Any]] = None, + info: Optional[_InfoType] = None, ): super().__init__() diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 7be7ce32b4..abc1300d8d 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -62,6 +62,7 @@ if typing.TYPE_CHECKING: from .decl_api import RegistryType from ..sql._typing import _ColumnsClauseArgument from ..sql._typing import _DMLColumnArgument + from ..sql._typing import _InfoType _T = TypeVar("_T", bound=Any) @@ -192,7 +193,7 @@ class MapperProperty( """ raise NotImplementedError() - def _memoized_attr_info(self): + def _memoized_attr_info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.InspectionAttr`. @@ -522,7 +523,7 @@ class PropComparator(SQLORMOperations[_T]): return self._adapt_to_entity._adapt_element @util.non_memoized_property - def info(self): + def info(self) -> _InfoType: return self.property.info @staticmethod diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 93b49ab254..18aa9945f5 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -2101,8 +2101,8 @@ class BulkUDCompileState(CompileState): result = session.execute( select_stmt, params, - execution_options, - bind_arguments, + execution_options=execution_options, + bind_arguments=bind_arguments, _add_event=skip_for_full_returning, ) matched_rows = result.fetchall() diff --git a/lib/sqlalchemy/orm/scoping.py b/lib/sqlalchemy/orm/scoping.py index 1dd7a69523..93d18b8d79 100644 --- a/lib/sqlalchemy/orm/scoping.py +++ b/lib/sqlalchemy/orm/scoping.py @@ -552,6 +552,7 @@ class scoped_session: self, statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, @@ -1562,6 +1563,7 @@ class scoped_session: self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -1592,6 +1594,7 @@ class scoped_session: self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index a26c55a248..5b1d0bb087 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -21,6 +21,7 @@ from typing import Iterator from typing import List from typing import NoReturn from typing import Optional +from typing import overload from typing import Sequence from typing import Set from typing import Tuple @@ -98,6 +99,7 @@ if typing.TYPE_CHECKING: from ..engine.result import ScalarResult from ..event import _InstanceLevelDispatch from ..sql._typing import _ColumnsClauseArgument + from ..sql._typing import _InfoType from ..sql.base import Executable from ..sql.elements import ClauseElement from ..sql.schema import Table @@ -380,11 +382,11 @@ class ORMExecuteState(util.MemoizedSlots): if execution_options: _execution_options = _execution_options.union(execution_options) - return self.session.execute( + return self.session._execute_internal( statement, _params, - _execution_options, - _bind_arguments, + execution_options=_execution_options, + bind_arguments=_bind_arguments, _parent_execute_state=self, ) @@ -1232,7 +1234,7 @@ class Session(_SessionClassMethods, EventTarget): twophase: bool = False, binds: Optional[Dict[_SessionBindKey, _SessionBind]] = None, enable_baked_queries: bool = True, - info: Optional[Dict[Any, Any]] = None, + info: Optional[_InfoType] = None, query_cls: Optional[Type[Query[Any]]] = None, autocommit: Literal[False] = False, ): @@ -1452,7 +1454,7 @@ class Session(_SessionClassMethods, EventTarget): return self._nested_transaction @util.memoized_property - def info(self) -> Dict[Any, Any]: + def info(self) -> _InfoType: """A user-modifiable dictionary. The initial value of this dictionary can be populated using the @@ -1686,66 +1688,45 @@ class Session(_SessionClassMethods, EventTarget): trans = self._autobegin_t() return trans._connection_for_bind(engine, execution_options) - def execute( + @overload + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreSingleExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: Literal[True] = ..., + ) -> Any: + ... + + @overload + def _execute_internal( self, statement: Executable, params: Optional[_CoreAnyExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, _parent_execute_state: Optional[Any] = None, _add_event: Optional[Any] = None, + _scalar_result: bool = ..., ) -> Result: - r"""Execute a SQL expression construct. - - Returns a :class:`_engine.Result` object representing - results of the statement execution. - - E.g.:: - - from sqlalchemy import select - result = session.execute( - select(User).where(User.id == 5) - ) - - The API contract of :meth:`_orm.Session.execute` is similar to that - of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version - of :class:`_engine.Connection`. - - .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is - now the primary point of ORM statement execution when using - :term:`2.0 style` ORM usage. - - :param statement: - An executable statement (i.e. an :class:`.Executable` expression - such as :func:`_expression.select`). - - :param params: - Optional dictionary, or list of dictionaries, containing - bound parameter values. If a single dictionary, single-row - execution occurs; if a list of dictionaries, an - "executemany" will be invoked. The keys in each dictionary - must correspond to parameter names present in the statement. - - :param execution_options: optional dictionary of execution options, - which will be associated with the statement execution. This - dictionary can provide a subset of the options that are accepted - by :meth:`_engine.Connection.execution_options`, and may also - provide additional options understood only in an ORM context. - - .. seealso:: - - :ref:`orm_queryguide_execution_options` - ORM-specific execution - options - - :param bind_arguments: dictionary of additional arguments to determine - the bind. May include "mapper", "bind", or other custom arguments. - Contents of this dictionary are passed to the - :meth:`.Session.get_bind` method. - - :return: a :class:`_engine.Result` object. - + ... - """ + def _execute_internal( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + _scalar_result: bool = False, + ) -> Any: statement = coercions.expect(roles.StatementRole, statement) if not bind_arguments: @@ -1805,7 +1786,10 @@ class Session(_SessionClassMethods, EventTarget): orm_exec_state._starting_event_idx = idx fn_result: Optional[Result] = fn(orm_exec_state) if fn_result: - return fn_result + if _scalar_result: + return fn_result.scalar() + else: + return fn_result statement = orm_exec_state.statement execution_options = orm_exec_state.local_execution_options @@ -1813,6 +1797,12 @@ class Session(_SessionClassMethods, EventTarget): bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind) + + if _scalar_result and not compile_state_cls: + if TYPE_CHECKING: + params = cast(_CoreSingleExecuteParams, params) + return conn.scalar(statement, params or {}, execution_options) + result: Result = conn.execute( statement, params or {}, execution_options ) @@ -1827,12 +1817,86 @@ class Session(_SessionClassMethods, EventTarget): result, ) - return result + if _scalar_result: + return result.scalar() + else: + return result + + def execute( + self, + statement: Executable, + params: Optional[_CoreAnyExecuteParams] = None, + *, + execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, + _parent_execute_state: Optional[Any] = None, + _add_event: Optional[Any] = None, + ) -> Result: + r"""Execute a SQL expression construct. + + Returns a :class:`_engine.Result` object representing + results of the statement execution. + + E.g.:: + + from sqlalchemy import select + result = session.execute( + select(User).where(User.id == 5) + ) + + The API contract of :meth:`_orm.Session.execute` is similar to that + of :meth:`_engine.Connection.execute`, the :term:`2.0 style` version + of :class:`_engine.Connection`. + + .. versionchanged:: 1.4 the :meth:`_orm.Session.execute` method is + now the primary point of ORM statement execution when using + :term:`2.0 style` ORM usage. + + :param statement: + An executable statement (i.e. an :class:`.Executable` expression + such as :func:`_expression.select`). + + :param params: + Optional dictionary, or list of dictionaries, containing + bound parameter values. If a single dictionary, single-row + execution occurs; if a list of dictionaries, an + "executemany" will be invoked. The keys in each dictionary + must correspond to parameter names present in the statement. + + :param execution_options: optional dictionary of execution options, + which will be associated with the statement execution. This + dictionary can provide a subset of the options that are accepted + by :meth:`_engine.Connection.execution_options`, and may also + provide additional options understood only in an ORM context. + + .. seealso:: + + :ref:`orm_queryguide_execution_options` - ORM-specific execution + options + + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + :return: a :class:`_engine.Result` object. + + + """ + return self._execute_internal( + statement, + params, + execution_options=execution_options, + bind_arguments=bind_arguments, + _parent_execute_state=_parent_execute_state, + _add_event=_add_event, + ) def scalar( self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -1845,18 +1909,20 @@ class Session(_SessionClassMethods, EventTarget): """ - return self.execute( + return self._execute_internal( statement, - params=params, + params, execution_options=execution_options, bind_arguments=bind_arguments, + _scalar_result=True, **kw, - ).scalar() + ) def scalars( self, statement: Executable, params: Optional[_CoreSingleExecuteParams] = None, + *, execution_options: _ExecuteOptionsParameter = util.EMPTY_DICT, bind_arguments: Optional[_BindArguments] = None, **kw: Any, @@ -1874,11 +1940,12 @@ class Session(_SessionClassMethods, EventTarget): """ - return self.execute( + return self._execute_internal( statement, params=params, execution_options=execution_options, bind_arguments=bind_arguments, + _scalar_result=False, # mypy appreciates this **kw, ).scalars() @@ -4220,7 +4287,7 @@ class sessionmaker(_SessionClassMethods): class_: Type[Session] = Session, autoflush: bool = True, expire_on_commit: bool = True, - info: Optional[Dict[Any, Any]] = None, + info: Optional[_InfoType] = None, **kw: Any, ): r"""Construct a new :class:`.sessionmaker`. diff --git a/lib/sqlalchemy/pool/base.py b/lib/sqlalchemy/pool/base.py index e22848fd25..934423e2f6 100644 --- a/lib/sqlalchemy/pool/base.py +++ b/lib/sqlalchemy/pool/base.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from ..event import _DispatchCommon from ..event import _ListenerFnType from ..event import dispatcher + from ..sql._typing import _InfoType class ResetStyle(Enum): @@ -461,51 +462,55 @@ class ManagesConnection: """ - info: Dict[str, Any] - """Info dictionary associated with the underlying DBAPI connection - referred to by this :class:`.ManagesConnection` instance, allowing - user-defined data to be associated with the connection. + @util.ro_memoized_property + def info(self) -> _InfoType: + """Info dictionary associated with the underlying DBAPI connection + referred to by this :class:`.ManagesConnection` instance, allowing + user-defined data to be associated with the connection. - The data in this dictionary is persistent for the lifespan - of the DBAPI connection itself, including across pool checkins - and checkouts. When the connection is invalidated - and replaced with a new one, this dictionary is cleared. + The data in this dictionary is persistent for the lifespan + of the DBAPI connection itself, including across pool checkins + and checkouts. When the connection is invalidated + and replaced with a new one, this dictionary is cleared. - For a :class:`.PoolProxiedConnection` instance that's not associated - with a :class:`.ConnectionPoolEntry`, such as if it were detached, the - attribute returns a dictionary that is local to that - :class:`.ConnectionPoolEntry`. Therefore the - :attr:`.ManagesConnection.info` attribute will always provide a Python - dictionary. + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns a dictionary that is local to that + :class:`.ConnectionPoolEntry`. Therefore the + :attr:`.ManagesConnection.info` attribute will always provide a Python + dictionary. - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.record_info` + :attr:`.ManagesConnection.record_info` - """ + """ + raise NotImplementedError() - record_info: Optional[Dict[str, Any]] - """Persistent info dictionary associated with this - :class:`.ManagesConnection`. + @util.ro_memoized_property + def record_info(self) -> Optional[_InfoType]: + """Persistent info dictionary associated with this + :class:`.ManagesConnection`. - Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan - of this dictionary is that of the :class:`.ConnectionPoolEntry` - which owns it; therefore this dictionary will persist across - reconnects and connection invalidation for a particular entry - in the connection pool. + Unlike the :attr:`.ManagesConnection.info` dictionary, the lifespan + of this dictionary is that of the :class:`.ConnectionPoolEntry` + which owns it; therefore this dictionary will persist across + reconnects and connection invalidation for a particular entry + in the connection pool. - For a :class:`.PoolProxiedConnection` instance that's not associated - with a :class:`.ConnectionPoolEntry`, such as if it were detached, the - attribute returns None. Contrast to the :attr:`.ManagesConnection.info` - dictionary which is never None. + For a :class:`.PoolProxiedConnection` instance that's not associated + with a :class:`.ConnectionPoolEntry`, such as if it were detached, the + attribute returns None. Contrast to the :attr:`.ManagesConnection.info` + dictionary which is never None. - .. seealso:: + .. seealso:: - :attr:`.ManagesConnection.info` + :attr:`.ManagesConnection.info` - """ + """ + raise NotImplementedError() def invalidate( self, e: Optional[BaseException] = None, soft: bool = False @@ -627,12 +632,12 @@ class _ConnectionRecord(ConnectionPoolEntry): _soft_invalidate_time: float = 0 - @util.memoized_property - def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125 + @util.ro_memoized_property + def info(self) -> _InfoType: return {} - @util.memoized_property - def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa: E501 + @util.ro_memoized_property + def record_info(self) -> Optional[_InfoType]: return {} @classmethod @@ -1080,8 +1085,8 @@ class _AdhocProxiedConnection(PoolProxiedConnection): ) -> None: self._is_valid = False - @property - def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa: E501 + @util.ro_non_memoized_property + def record_info(self) -> Optional[_InfoType]: return self._connection_record.record_info def cursor(self, *args: Any, **kwargs: Any) -> DBAPICursor: @@ -1314,15 +1319,15 @@ class _ConnectionFairy(PoolProxiedConnection): def is_detached(self) -> bool: return self._connection_record is None - @util.memoized_property - def info(self) -> Dict[str, Any]: # type: ignore[override] # mypy#4125 + @util.ro_memoized_property + def info(self) -> _InfoType: if self._connection_record is None: return {} else: return self._connection_record.info - @property - def record_info(self) -> Optional[Dict[str, Any]]: # type: ignore[override] # mypy#4125 # noqa: E501 + @util.ro_non_memoized_property + def record_info(self) -> Optional[_InfoType]: if self._connection_record is None: return None else: diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 70a982ce24..86166f9f6c 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -13,16 +13,15 @@ from __future__ import annotations from .sql.base import SchemaVisitor as SchemaVisitor from .sql.ddl import _CreateDropBase as _CreateDropBase -from .sql.ddl import _DDLCompiles as _DDLCompiles from .sql.ddl import _DropView as _DropView from .sql.ddl import AddConstraint as AddConstraint +from .sql.ddl import BaseDDLElement as BaseDDLElement from .sql.ddl import CreateColumn as CreateColumn from .sql.ddl import CreateIndex as CreateIndex from .sql.ddl import CreateSchema as CreateSchema from .sql.ddl import CreateSequence as CreateSequence from .sql.ddl import CreateTable as CreateTable from .sql.ddl import DDL as DDL -from .sql.ddl import DDLBase as DDLBase from .sql.ddl import DDLElement as DDLElement from .sql.ddl import DropColumnComment as DropColumnComment from .sql.ddl import DropConstraint as DropConstraint @@ -31,6 +30,8 @@ from .sql.ddl import DropSchema as DropSchema from .sql.ddl import DropSequence as DropSequence from .sql.ddl import DropTable as DropTable from .sql.ddl import DropTableComment as DropTableComment +from .sql.ddl import ExecutableDDLElement as ExecutableDDLElement +from .sql.ddl import InvokeDDLBase as InvokeDDLBase from .sql.ddl import SetColumnComment as SetColumnComment from .sql.ddl import SetTableComment as SetTableComment from .sql.ddl import sort_tables as sort_tables diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 2e766f9766..84913225d7 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -11,6 +11,10 @@ from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS from .compiler import FROM_LINTING as FROM_LINTING from .compiler import NO_LINTING as NO_LINTING from .compiler import WARN_LINTING as WARN_LINTING +from .ddl import BaseDDLElement as BaseDDLElement +from .ddl import DDL as DDL +from .ddl import DDLElement as DDLElement +from .ddl import ExecutableDDLElement as ExecutableDDLElement from .expression import Alias as Alias from .expression import alias as alias from .expression import all_ as all_ diff --git a/lib/sqlalchemy/sql/_typing.py b/lib/sqlalchemy/sql/_typing.py index 7e3a1c4e8d..b0a717a1a3 100644 --- a/lib/sqlalchemy/sql/_typing.py +++ b/lib/sqlalchemy/sql/_typing.py @@ -2,6 +2,7 @@ from __future__ import annotations import operator from typing import Any +from typing import Dict from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -28,6 +29,7 @@ if TYPE_CHECKING: from .elements import TextClause from .roles import ColumnsClauseRole from .roles import FromClauseRole + from .schema import Column from .schema import DefaultGenerator from .schema import Sequence from .selectable import Alias @@ -101,6 +103,8 @@ overall which brings in the TextClause object also. """ +_InfoType = Dict[Any, Any] +"""the .info dictionary accepted and used throughout Core /ORM""" _FromClauseArgument = Union[ roles.FromClauseRole, @@ -145,6 +149,13 @@ the DMLColumnRole to be able to accommodate. """ +_DDLColumnArgument = Union[str, "Column[Any]", roles.DDLConstraintColumnRole] +"""DDL column. + +used for :class:`.PrimaryKeyConstraint`, :class:`.UniqueConstraint`, etc. + +""" + _DMLTableArgument = Union[ "TableClause", "Join", diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index bb51693cfe..629e88a326 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -948,6 +948,7 @@ class Executable(roles.StatementRole, Generative): supports_execution: bool = True _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT + _is_default_generator = False _with_options: Tuple[ExecutableOption, ...] = () _with_context_options: Tuple[ Tuple[Callable[[CompileState], None], Any], ... @@ -993,10 +994,17 @@ class Executable(roles.StatementRole, Generative): connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptionsParameter, - _force: bool = False, ) -> CursorResult: ... + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + ) -> Any: + ... + @util.ro_non_memoized_property def _all_selected_columns(self): raise NotImplementedError() @@ -1243,10 +1251,12 @@ class SchemaVisitor(ClauseVisitor): _COLKEY = TypeVar("_COLKEY", Union[None, str], str) + +_COL_co = TypeVar("_COL_co", bound="ColumnElement[Any]", covariant=True) _COL = TypeVar("_COL", bound="ColumnElement[Any]") -class ColumnCollection(Generic[_COLKEY, _COL]): +class ColumnCollection(Generic[_COLKEY, _COL_co]): """Collection of :class:`_expression.ColumnElement` instances, typically for :class:`_sql.FromClause` objects. @@ -1357,12 +1367,12 @@ class ColumnCollection(Generic[_COLKEY, _COL]): __slots__ = "_collection", "_index", "_colset" - _collection: List[Tuple[_COLKEY, _COL]] - _index: Dict[Union[None, str, int], _COL] - _colset: Set[_COL] + _collection: List[Tuple[_COLKEY, _COL_co]] + _index: Dict[Union[None, str, int], _COL_co] + _colset: Set[_COL_co] def __init__( - self, columns: Optional[Iterable[Tuple[_COLKEY, _COL]]] = None + self, columns: Optional[Iterable[Tuple[_COLKEY, _COL_co]]] = None ): object.__setattr__(self, "_colset", set()) object.__setattr__(self, "_index", {}) @@ -1370,11 +1380,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]): if columns: self._initial_populate(columns) - def _initial_populate(self, iter_: Iterable[Tuple[_COLKEY, _COL]]) -> None: + def _initial_populate( + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] + ) -> None: self._populate_separate_keys(iter_) @property - def _all_columns(self) -> List[_COL]: + def _all_columns(self) -> List[_COL_co]: return [col for (k, col) in self._collection] def keys(self) -> List[_COLKEY]: @@ -1382,13 +1394,13 @@ class ColumnCollection(Generic[_COLKEY, _COL]): collection.""" return [k for (k, col) in self._collection] - def values(self) -> List[_COL]: + def values(self) -> List[_COL_co]: """Return a sequence of :class:`_sql.ColumnClause` or :class:`_schema.Column` objects for all columns in this collection.""" return [col for (k, col) in self._collection] - def items(self) -> List[Tuple[_COLKEY, _COL]]: + def items(self) -> List[Tuple[_COLKEY, _COL_co]]: """Return a sequence of (key, column) tuples for all columns in this collection each consisting of a string key name and a :class:`_sql.ColumnClause` or @@ -1403,11 +1415,11 @@ class ColumnCollection(Generic[_COLKEY, _COL]): def __len__(self) -> int: return len(self._collection) - def __iter__(self) -> Iterator[_COL]: + def __iter__(self) -> Iterator[_COL_co]: # turn to a list first to maintain over a course of changes return iter([col for k, col in self._collection]) - def __getitem__(self, key: Union[str, int]) -> _COL: + def __getitem__(self, key: Union[str, int]) -> _COL_co: try: return self._index[key] except KeyError as err: @@ -1416,7 +1428,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): else: raise - def __getattr__(self, key: str) -> _COL: + def __getattr__(self, key: str) -> _COL_co: try: return self._index[key] except KeyError as err: @@ -1445,7 +1457,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]): def __eq__(self, other: Any) -> bool: return self.compare(other) - def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]: + def get( + self, key: str, default: Optional[_COL_co] = None + ) -> Optional[_COL_co]: """Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object based on a string key name from this :class:`_expression.ColumnCollection`.""" @@ -1487,7 +1501,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): __hash__ = None # type: ignore def _populate_separate_keys( - self, iter_: Iterable[Tuple[_COLKEY, _COL]] + self, iter_: Iterable[Tuple[_COLKEY, _COL_co]] ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1498,7 +1512,9 @@ class ColumnCollection(Generic[_COLKEY, _COL]): ) self._index.update({k: col for k, col in reversed(self._collection)}) - def add(self, column: _COL, key: Optional[_COLKEY] = None) -> None: + def add( + self, column: ColumnElement[Any], key: Optional[_COLKEY] = None + ) -> None: """Add a column to this :class:`_sql.ColumnCollection`. .. note:: @@ -1518,11 +1534,17 @@ class ColumnCollection(Generic[_COLKEY, _COL]): colkey = key l = len(self._collection) - self._collection.append((colkey, column)) - self._colset.add(column) - self._index[l] = column + + # don't really know how this part is supposed to work w/ the + # covariant thing + + _column = cast(_COL_co, column) + + self._collection.append((colkey, _column)) + self._colset.add(_column) + self._index[l] = _column if colkey not in self._index: - self._index[colkey] = column + self._index[colkey] = _column def __getstate__(self) -> Dict[str, Any]: return {"_collection": self._collection, "_index": self._index} @@ -1534,7 +1556,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): self, "_colset", {col for k, col in self._collection} ) - def contains_column(self, col: _COL) -> bool: + def contains_column(self, col: ColumnElement[Any]) -> bool: """Checks if a column object exists in this collection""" if col not in self._colset: if isinstance(col, str): @@ -1546,7 +1568,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): else: return True - def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL]: + def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL_co]: """Return a "read only" form of this :class:`_sql.ColumnCollection`.""" @@ -1554,7 +1576,7 @@ class ColumnCollection(Generic[_COLKEY, _COL]): def corresponding_column( self, column: _COL, require_embedded: bool = False - ) -> Optional[_COL]: + ) -> Optional[Union[_COL, _COL_co]]: """Given a :class:`_expression.ColumnElement`, return the exported :class:`_expression.ColumnElement` object from this :class:`_expression.ColumnCollection` @@ -1670,14 +1692,16 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """ - def add(self, column: _NAMEDCOL, key: Optional[str] = None) -> None: - - if key is not None and column.key != key: + def add( + self, column: ColumnElement[Any], key: Optional[str] = None + ) -> None: + named_column = cast(_NAMEDCOL, column) + if key is not None and named_column.key != key: raise exc.ArgumentError( "DedupeColumnCollection requires columns be under " "the same key as their .key" ) - key = column.key + key = named_column.key if key is None: raise exc.ArgumentError( @@ -1688,21 +1712,21 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): existing = self._index[key] - if existing is column: + if existing is named_column: return - self.replace(column) + self.replace(named_column) # pop out memoized proxy_set as this # operation may very well be occurring # in a _make_proxy operation - util.memoized_property.reset(column, "proxy_set") + util.memoized_property.reset(named_column, "proxy_set") else: l = len(self._collection) - self._collection.append((key, column)) - self._colset.add(column) - self._index[l] = column - self._index[key] = column + self._collection.append((key, named_column)) + self._colset.add(named_column) + self._index[l] = named_column + self._index[key] = named_column def _populate_separate_keys( self, iter_: Iterable[Tuple[str, _NAMEDCOL]] @@ -1805,7 +1829,7 @@ class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): class ReadOnlyColumnCollection( - util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL] + util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL_co] ): __slots__ = ("_parent",) diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index 1f8b9c19e8..15fbc2afb9 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -102,8 +102,8 @@ class HasCacheKey: """private attribute which may be set to False to prevent the inherit_cache warning from being emitted for a hierarchy of subclasses. - Currently applies to the DDLElement hierarchy which does not implement - caching. + Currently applies to the :class:`.ExecutableDDLElement` hierarchy which + does not implement caching. """ diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 623bb0be2e..4bf45da9cb 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -14,10 +14,13 @@ import typing from typing import Any from typing import Callable from typing import Dict +from typing import Iterable +from typing import Iterator from typing import List from typing import NoReturn from typing import Optional from typing import overload +from typing import Tuple from typing import Type from typing import TYPE_CHECKING from typing import TypeVar @@ -50,6 +53,7 @@ if typing.TYPE_CHECKING: from . import traversals from ._typing import _ColumnExpressionArgument from ._typing import _ColumnsClauseArgument + from ._typing import _DDLColumnArgument from ._typing import _DMLTableArgument from ._typing import _FromClauseArgument from .dml import _DMLTableElement @@ -166,19 +170,28 @@ def expect( @overload def expect( - role: Type[roles.StatementOptionRole], + role: Type[roles.DDLReferredColumnRole], element: Any, **kw: Any, -) -> DQLDMLClauseElement: +) -> Column[Any]: ... @overload def expect( - role: Type[roles.DDLReferredColumnRole], + role: Type[roles.DDLConstraintColumnRole], element: Any, **kw: Any, -) -> Column[Any]: +) -> Union[Column[Any], str]: + ... + + +@overload +def expect( + role: Type[roles.StatementOptionRole], + element: Any, + **kw: Any, +) -> DQLDMLClauseElement: ... @@ -398,21 +411,33 @@ def expect_as_key(role, element, **kw): return expect(role, element, **kw) -def expect_col_expression_collection(role, expressions): +def expect_col_expression_collection( + role: Type[roles.DDLConstraintColumnRole], + expressions: Iterable[_DDLColumnArgument], +) -> Iterator[ + Tuple[ + Union[str, Column[Any]], + Optional[ColumnClause[Any]], + Optional[str], + Optional[Union[Column[Any], str]], + ] +]: for expr in expressions: strname = None column = None - resolved = expect(role, expr) + resolved: Union[Column[Any], str] = expect(role, expr) if isinstance(resolved, str): + assert isinstance(expr, str) strname = resolved = expr else: - cols: List[ColumnClause[Any]] = [] - col_append: _TraverseCallableType[ColumnClause[Any]] = cols.append + cols: List[Column[Any]] = [] + col_append: _TraverseCallableType[Column[Any]] = cols.append visitors.traverse(resolved, {}, {"column": col_append}) if cols: column = cols[0] add_element = column if column is not None else strname + yield resolved, column, strname, add_element diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 522a0bd4a0..1498c83418 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -85,7 +85,7 @@ if typing.TYPE_CHECKING: from .base import _AmbiguousTableNameMap from .base import CompileState from .cache_key import CacheKey - from .ddl import DDLElement + from .ddl import ExecutableDDLElement from .dml import Insert from .dml import UpdateBase from .dml import ValuesBase @@ -4798,7 +4798,7 @@ class DDLCompiler(Compiled): def __init__( self, dialect: Dialect, - statement: DDLElement, + statement: ExecutableDDLElement, schema_translate_map: Optional[_SchemaTranslateMapType] = ..., render_schema_translate: bool = ..., compile_kwargs: Mapping[str, Any] = ..., diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 131ae9ef11..6ac7c24483 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -33,6 +33,7 @@ if typing.TYPE_CHECKING: from .compiler import Compiled from .compiler import DDLCompiler from .elements import BindParameter + from .schema import Constraint from .schema import ForeignKeyConstraint from .schema import SchemaItem from .schema import Table @@ -43,7 +44,14 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import Dialect -class _DDLCompiles(ClauseElement): +class BaseDDLElement(ClauseElement): + """The root of DDL constructs, including those that are sub-elements + within the "create table" and other processes. + + .. versionadded:: 2.0 + + """ + _hierarchy_supports_caching = False """disable cache warnings for all _DDLCompiles subclasses. """ @@ -71,10 +79,10 @@ class _DDLCompiles(ClauseElement): class DDLIfCallable(Protocol): def __call__( self, - ddl: "DDLElement", - target: "SchemaItem", - bind: Optional["Connection"], - tables: Optional[List["Table"]] = None, + ddl: BaseDDLElement, + target: SchemaItem, + bind: Optional[Connection], + tables: Optional[List[Table]] = None, state: Optional[Any] = None, *, dialect: Dialect, @@ -89,7 +97,14 @@ class DDLIf(typing.NamedTuple): callable_: Optional[DDLIfCallable] state: Optional[Any] - def _should_execute(self, ddl, target, bind, compiler=None, **kw): + def _should_execute( + self, + ddl: BaseDDLElement, + target: SchemaItem, + bind: Optional[Connection], + compiler: Optional[DDLCompiler] = None, + **kw: Any, + ) -> bool: if bind is not None: dialect = bind.dialect elif compiler is not None: @@ -117,18 +132,23 @@ class DDLIf(typing.NamedTuple): return True -SelfDDLElement = typing.TypeVar("SelfDDLElement", bound="DDLElement") +SelfExecutableDDLElement = typing.TypeVar( + "SelfExecutableDDLElement", bound="ExecutableDDLElement" +) -class DDLElement(roles.DDLRole, Executable, _DDLCompiles): - """Base class for DDL expression constructs. +class ExecutableDDLElement(roles.DDLRole, Executable, BaseDDLElement): + """Base class for standalone executable DDL expression constructs. This class is the base for the general purpose :class:`.DDL` class, as well as the various create/drop clause constructs such as :class:`.CreateTable`, :class:`.DropTable`, :class:`.AddConstraint`, etc. - :class:`.DDLElement` integrates closely with SQLAlchemy events, + .. versionchanged:: 2.0 :class:`.ExecutableDDLElement` is renamed from + :class:`.DDLElement`, which still exists for backwards compatibility. + + :class:`.ExecutableDDLElement` integrates closely with SQLAlchemy events, introduced in :ref:`event_toplevel`. An instance of one is itself an event receiving callable:: @@ -161,29 +181,31 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): ) @_generative - def against(self: SelfDDLElement, target: SchemaItem) -> SelfDDLElement: - """Return a copy of this :class:`_schema.DDLElement` which will include - the given target. - - This essentially applies the given item to the ``.target`` attribute - of the returned :class:`_schema.DDLElement` object. This target + def against( + self: SelfExecutableDDLElement, target: SchemaItem + ) -> SelfExecutableDDLElement: + """Return a copy of this :class:`_schema.ExecutableDDLElement` which + will include the given target. + + This essentially applies the given item to the ``.target`` attribute of + the returned :class:`_schema.ExecutableDDLElement` object. This target is then usable by event handlers and compilation routines in order to provide services such as tokenization of a DDL string in terms of a particular :class:`_schema.Table`. - When a :class:`_schema.DDLElement` object is established as an event - handler for the :meth:`_events.DDLEvents.before_create` or - :meth:`_events.DDLEvents.after_create` events, and the event - then occurs for a given target such as a :class:`_schema.Constraint` - or :class:`_schema.Table`, that target is established with a copy - of the :class:`_schema.DDLElement` object using this method, which - then proceeds to the :meth:`_schema.DDLElement.execute` method - in order to invoke the actual DDL instruction. + When a :class:`_schema.ExecutableDDLElement` object is established as + an event handler for the :meth:`_events.DDLEvents.before_create` or + :meth:`_events.DDLEvents.after_create` events, and the event then + occurs for a given target such as a :class:`_schema.Constraint` or + :class:`_schema.Table`, that target is established with a copy of the + :class:`_schema.ExecutableDDLElement` object using this method, which + then proceeds to the :meth:`_schema.ExecutableDDLElement.execute` + method in order to invoke the actual DDL instruction. :param target: a :class:`_schema.SchemaItem` that will be the subject of a DDL operation. - :return: a copy of this :class:`_schema.DDLElement` with the + :return: a copy of this :class:`_schema.ExecutableDDLElement` with the ``.target`` attribute assigned to the given :class:`_schema.SchemaItem`. @@ -198,13 +220,14 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): @_generative def execute_if( - self: SelfDDLElement, + self: SelfExecutableDDLElement, dialect: Optional[str] = None, callable_: Optional[DDLIfCallable] = None, state: Optional[Any] = None, - ) -> SelfDDLElement: + ) -> SelfExecutableDDLElement: r"""Return a callable that will execute this - :class:`_ddl.DDLElement` conditionally within an event handler. + :class:`_ddl.ExecutableDDLElement` conditionally within an event + handler. Used to provide a wrapper for event listening:: @@ -302,7 +325,11 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles): return s -class DDL(DDLElement): +DDLElement = ExecutableDDLElement +""":class:`.DDLElement` is renamed to :class:`.ExecutableDDLElement`.""" + + +class DDL(ExecutableDDLElement): """A literal DDL statement. Specifies literal SQL DDL to be executed by the database. DDL objects @@ -390,7 +417,7 @@ class DDL(DDLElement): ) -class _CreateDropBase(DDLElement): +class _CreateDropBase(ExecutableDDLElement): """Base class for DDL constructs that represent CREATE and DROP or equivalents. @@ -484,9 +511,11 @@ class CreateTable(_CreateDropBase): def __init__( self, - element, - include_foreign_key_constraints=None, - if_not_exists=False, + element: Table, + include_foreign_key_constraints: Optional[ + typing_Sequence[ForeignKeyConstraint] + ] = None, + if_not_exists: bool = False, ): """Create a :class:`.CreateTable` construct. @@ -522,12 +551,12 @@ class _DropView(_CreateDropBase): __visit_name__ = "drop_view" -class CreateConstraint(_DDLCompiles): - def __init__(self, element): +class CreateConstraint(BaseDDLElement): + def __init__(self, element: Constraint): self.element = element -class CreateColumn(_DDLCompiles): +class CreateColumn(BaseDDLElement): """Represent a :class:`_schema.Column` as rendered in a CREATE TABLE statement, via the :class:`.CreateTable` construct. @@ -641,7 +670,7 @@ class DropTable(_CreateDropBase): __visit_name__ = "drop_table" - def __init__(self, element, if_exists=False): + def __init__(self, element: Table, if_exists: bool = False): """Create a :class:`.DropTable` construct. :param element: a :class:`_schema.Table` that's the subject @@ -761,12 +790,12 @@ class DropColumnComment(_CreateDropBase): __visit_name__ = "drop_column_comment" -class DDLBase(SchemaVisitor): +class InvokeDDLBase(SchemaVisitor): def __init__(self, connection): self.connection = connection -class SchemaGenerator(DDLBase): +class SchemaGenerator(InvokeDDLBase): def __init__( self, dialect, connection, checkfirst=False, tables=None, **kwargs ): @@ -925,7 +954,7 @@ class SchemaGenerator(DDLBase): CreateIndex(index)._invoke_with(self.connection) -class SchemaDropper(DDLBase): +class SchemaDropper(InvokeDDLBase): def __init__( self, dialect, connection, checkfirst=False, tables=None, **kwargs ): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8057582835..20938fd5a1 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -77,6 +77,7 @@ from ..util.typing import Literal if typing.TYPE_CHECKING: from ._typing import _ColumnExpressionArgument + from ._typing import _InfoType from ._typing import _PropagateAttrsType from ._typing import _TypeEngineArgument from .cache_key import _CacheKeyTraversalType @@ -85,6 +86,7 @@ if typing.TYPE_CHECKING: from .compiler import SQLCompiler from .functions import FunctionElement from .operators import OperatorType + from .schema import _ServerDefaultType from .schema import Column from .schema import DefaultGenerator from .schema import FetchedValue @@ -444,9 +446,8 @@ class ClauseElement( connection: Connection, distilled_params: _CoreMultiExecuteParams, execution_options: _ExecuteOptions, - _force: bool = False, ) -> Result: - if _force or self.supports_execution: + if self.supports_execution: if TYPE_CHECKING: assert isinstance(self, Executable) return connection._execute_clauseelement( @@ -455,6 +456,22 @@ class ClauseElement( else: raise exc.ObjectNotExecutableError(self) + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptions, + ) -> Any: + """an additional hook for subclasses to provide a different + implementation for connection.scalar() vs. connection.execute(). + + .. versionadded:: 2.0 + + """ + return self._execute_on_connection( + connection, distilled_params, execution_options + ).scalar() + def unique_params( self: SelfClauseElement, __optionaldict: Optional[Dict[str, Any]] = None, @@ -1481,6 +1498,7 @@ class ColumnElement( def _make_proxy( self, selectable: FromClause, + *, name: Optional[str] = None, key: Optional[str] = None, name_is_truncatable: bool = False, @@ -4032,12 +4050,14 @@ class NamedColumn(ColumnElement[_T]): def _make_proxy( self, - selectable, - name=None, - name_is_truncatable=False, - disallow_is_literal=False, - **kw, - ): + selectable: FromClause, + *, + name: Optional[str] = None, + key: Optional[str] = None, + name_is_truncatable: bool = False, + disallow_is_literal: bool = False, + **kw: Any, + ) -> typing_Tuple[str, ColumnClause[_T]]: c = ColumnClause( coercions.expect(roles.TruncatedLabelRole, name or self.name) if name_is_truncatable @@ -4188,7 +4208,13 @@ class Label(roles.LabeledColumnExprRole[_T], NamedColumn[_T]): def _from_objects(self) -> List[FromClause]: return self.element._from_objects - def _make_proxy(self, selectable, name=None, **kw): + def _make_proxy( + self, + selectable: FromClause, + *, + name: Optional[str] = None, + **kw: Any, + ) -> typing_Tuple[str, ColumnClause[_T]]: name = self.name if not name else name key, e = self.element._make_proxy( @@ -4279,7 +4305,7 @@ class ColumnClause( onupdate: Optional[DefaultGenerator] = None default: Optional[DefaultGenerator] = None - server_default: Optional[FetchedValue] = None + server_default: Optional[_ServerDefaultType] = None server_onupdate: Optional[FetchedValue] = None _is_multiparam_column = False @@ -4422,12 +4448,14 @@ class ColumnClause( def _make_proxy( self, - selectable, - name=None, - name_is_truncatable=False, - disallow_is_literal=False, - **kw, - ): + selectable: FromClause, + *, + name: Optional[str] = None, + key: Optional[str] = None, + name_is_truncatable: bool = False, + disallow_is_literal: bool = False, + **kw: Any, + ) -> typing_Tuple[str, ColumnClause[_T]]: # the "is_literal" flag normally should never be propagated; a proxied # column is always a SQL identifier and never the actual expression # being evaluated. however, there is a case where the "is_literal" flag @@ -4699,7 +4727,9 @@ class AnnotatedColumnElement(Annotated): return self._Annotated__element.key @util.memoized_property - def info(self): + def info(self) -> _InfoType: + if TYPE_CHECKING: + assert isinstance(self._Annotated__element, Column) return self._Annotated__element.info @util.memoized_property diff --git a/lib/sqlalchemy/sql/events.py b/lib/sqlalchemy/sql/events.py index 0d74e2e4c1..651a8673d9 100644 --- a/lib/sqlalchemy/sql/events.py +++ b/lib/sqlalchemy/sql/events.py @@ -57,7 +57,7 @@ class DDLEvents(event.Events[SchemaEventTarget]): event.listen(some_table, "after_create", after_create) DDL events integrate closely with the - :class:`.DDL` class and the :class:`.DDLElement` hierarchy + :class:`.DDL` class and the :class:`.ExecutableDDLElement` hierarchy of DDL clause constructs, which are themselves appropriate as listener callables:: @@ -94,7 +94,7 @@ class DDLEvents(event.Events[SchemaEventTarget]): :ref:`event_toplevel` - :class:`.DDLElement` + :class:`.ExecutableDDLElement` :class:`.DDL` diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py index a66a1eb929..6481682355 100644 --- a/lib/sqlalchemy/sql/functions.py +++ b/lib/sqlalchemy/sql/functions.py @@ -62,6 +62,10 @@ from .. import util if TYPE_CHECKING: from ._typing import _TypeEngineArgument + from ..engine.base import Connection + from ..engine.cursor import CursorResult + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter _T = TypeVar("_T", bound=Any) @@ -167,8 +171,11 @@ class FunctionElement(Executable, ColumnElement[_T], FromClause, Generative): ) def _execute_on_connection( - self, connection, distilled_params, execution_options - ): + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + ) -> CursorResult: return connection._execute_function( self, distilled_params, execution_options ) diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index c9b67caca3..92b9cc62c2 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -32,21 +32,22 @@ from __future__ import annotations from abc import ABC import collections +from enum import Enum import operator import typing from typing import Any from typing import Callable from typing import cast from typing import Dict +from typing import Iterable from typing import Iterator from typing import List -from typing import MutableMapping +from typing import NoReturn from typing import Optional from typing import overload from typing import Sequence as _typing_Sequence from typing import Set from typing import Tuple -from typing import Type from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -65,7 +66,6 @@ from .elements import ClauseElement from .elements import ColumnClause from .elements import ColumnElement from .elements import quoted_name -from .elements import SQLCoreOperations from .elements import TextClause from .selectable import TableClause from .type_api import to_instance @@ -75,53 +75,91 @@ from .. import event from .. import exc from .. import inspection from .. import util +from ..util.typing import Final from ..util.typing import Literal from ..util.typing import Protocol +from ..util.typing import Self from ..util.typing import TypeGuard if typing.TYPE_CHECKING: + from ._typing import _DDLColumnArgument + from ._typing import _InfoType + from ._typing import _TextCoercedExpressionArgument + from ._typing import _TypeEngineArgument + from .base import ColumnCollection + from .base import DedupeColumnCollection from .base import ReadOnlyColumnCollection + from .compiler import DDLCompiler + from .elements import BindParameter + from .functions import Function from .type_api import TypeEngine + from .visitors import _TraverseInternalsType + from .visitors import anon_map from ..engine import Connection from ..engine import Engine + from ..engine.cursor import CursorResult + from ..engine.interfaces import _CoreMultiExecuteParams + from ..engine.interfaces import _CoreSingleExecuteParams + from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import ExecutionContext from ..engine.mock import MockConnection + from ..sql.selectable import FromClause + _T = TypeVar("_T", bound="Any") +_SI = TypeVar("_SI", bound="SchemaItem") _ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement] _TAB = TypeVar("_TAB", bound="Table") -RETAIN_SCHEMA = util.symbol( - "retain_schema" + +_CreateDropBind = Union["Engine", "Connection", "MockConnection"] + + +class SchemaConst(Enum): + + RETAIN_SCHEMA = 1 """Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence` or in some cases a :class:`_schema.ForeignKey` object, in situations where the object is being copied for a :meth:`.Table.to_metadata` operation, should retain the schema name that it already has. """ -) -BLANK_SCHEMA = util.symbol( - "blank_schema", - """Symbol indicating that a :class:`_schema.Table`, :class:`.Sequence` - or in some cases a :class:`_schema.ForeignKey` object + BLANK_SCHEMA = 2 + """Symbol indicating that a :class:`_schema.Table` or :class:`.Sequence` should have 'None' for its schema, even if the parent :class:`_schema.MetaData` has specified a schema. + .. seealso:: + + :paramref:`_schema.MetaData.schema` + + :paramref:`_schema.Table.schema` + + :paramref:`.Sequence.schema` + .. versionadded:: 1.0.14 - """, -) + """ -NULL_UNSPECIFIED = util.symbol( - "NULL_UNSPECIFIED", + NULL_UNSPECIFIED = 3 """Symbol indicating the "nullable" keyword was not passed to a Column. Normally we would expect None to be acceptable for this but some backends such as that of SQL Server place special signficance on a "nullability" value of None. - """, -) + """ + + +RETAIN_SCHEMA: Final[ + Literal[SchemaConst.RETAIN_SCHEMA] +] = SchemaConst.RETAIN_SCHEMA +BLANK_SCHEMA: Final[ + Literal[SchemaConst.BLANK_SCHEMA] +] = SchemaConst.BLANK_SCHEMA +NULL_UNSPECIFIED: Final[ + Literal[SchemaConst.NULL_UNSPECIFIED] +] = SchemaConst.NULL_UNSPECIFIED def _get_table_key(name: str, schema: Optional[str]) -> str: @@ -170,7 +208,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): create_drop_stringify_dialect = "default" - def _init_items(self, *args, **kw): + def _init_items(self, *args: SchemaItem, **kw: Any) -> None: """Initialize the list of child items for this SchemaItem.""" for item in args: if item is not None: @@ -184,11 +222,11 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): else: spwd(self, **kw) - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr(self, omit_kwarg=["info"]) @util.memoized_property - def info(self): + def info(self) -> _InfoType: """Info dictionary associated with the object, allowing user-defined data to be associated with this :class:`.SchemaItem`. @@ -199,7 +237,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): """ return {} - def _schema_item_copy(self, schema_item): + def _schema_item_copy(self, schema_item: _SI) -> _SI: if "info" in self.__dict__: schema_item.info = self.info.copy() schema_item.dispatch._update(self.dispatch) @@ -235,9 +273,9 @@ class HasConditionalDDL: r"""apply a conditional DDL rule to this schema item. These rules work in a similar manner to the - :meth:`.DDLElement.execute_if` callable, with the added feature that - the criteria may be checked within the DDL compilation phase for a - construct such as :class:`.CreateTable`. + :meth:`.ExecutableDDLElement.execute_if` callable, with the added + feature that the criteria may be checked within the DDL compilation + phase for a construct such as :class:`.CreateTable`. :meth:`.HasConditionalDDL.ddl_if` currently applies towards the :class:`.Index` construct as well as all :class:`.Constraint` constructs. @@ -246,7 +284,8 @@ class HasConditionalDDL: to indicate multiple dialect types. :param callable\_: a callable that is constructed using the same form - as that described in :paramref:`.DDLElement.execute_if.callable_`. + as that described in + :paramref:`.ExecutableDDLElement.execute_if.callable_`. :param state: any arbitrary object that will be passed to the callable, if present. @@ -306,6 +345,8 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): def foreign_keys(self) -> Set[ForeignKey]: ... + _columns: DedupeColumnCollection[Column[Any]] + constraints: Set[Constraint] """A collection of all :class:`_schema.Constraint` objects associated with this :class:`_schema.Table`. @@ -344,25 +385,30 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): """ - _traverse_internals = TableClause._traverse_internals + [ - ("schema", InternalTraversal.dp_string) - ] + _traverse_internals: _TraverseInternalsType = ( + TableClause._traverse_internals + + [("schema", InternalTraversal.dp_string)] + ) if TYPE_CHECKING: - # we are upgrading .c and .columns to return Column, not - # ColumnClause. mypy typically sees this as incompatible because - # the contract of TableClause is that we can put a ColumnClause - # into this collection. does not recognize its immutability - # for the moment. + + @util.ro_non_memoized_property + def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: + ... + @util.ro_non_memoized_property - def columns(self) -> ReadOnlyColumnCollection[str, Column[Any]]: # type: ignore # noqa: E501 + def exported_columns( + self, + ) -> ReadOnlyColumnCollection[str, Column[Any]]: ... @util.ro_non_memoized_property - def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: # type: ignore # noqa: E501 + def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: ... - def _gen_cache_key(self, anon_map, bindparams): + def _gen_cache_key( + self, anon_map: anon_map, bindparams: List[BindParameter[Any]] + ) -> Tuple[Any, ...]: if self._annotations: return (self,) + self._annotations_cache_key else: @@ -382,7 +428,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): return cls._new(*args, **kw) @classmethod - def _new(cls, *args, **kw): + def _new(cls, *args: Any, **kw: Any) -> Any: if not args and not kw: # python3k pickle seems to call this return object.__new__(cls) @@ -429,7 +475,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): table.dispatch.before_parent_attach(table, metadata) metadata._add_table(name, schema, table) try: - table._init(name, metadata, *args, **kw) + table.__init__(name, metadata, *args, _no_init=False, **kw) table.dispatch.after_parent_attach(table, metadata) return table except Exception: @@ -439,10 +485,31 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): def __init__( self, name: str, - metadata: "MetaData", + metadata: MetaData, *args: SchemaItem, + schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None, + quote: Optional[bool] = None, + quote_schema: Optional[bool] = None, + autoload_with: Optional[Union[Engine, Connection]] = None, + autoload_replace: bool = True, + keep_existing: bool = False, + extend_existing: bool = False, + resolve_fks: bool = True, + include_columns: Optional[Iterable[str]] = None, + implicit_returning: bool = True, + comment: Optional[str] = None, + info: Optional[Dict[Any, Any]] = None, + listeners: Optional[ + _typing_Sequence[Tuple[str, Callable[..., Any]]] + ] = None, + prefixes: Optional[_typing_Sequence[str]] = None, + # used internally in the metadata.reflect() process + _extend_on: Optional[Set[Table]] = None, + # used by __new__ to bypass __init__ + _no_init: bool = True, + # dialect-specific keyword args **kw: Any, - ): + ) -> None: r"""Constructor for :class:`_schema.Table`. @@ -731,24 +798,22 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): :ref:`dialect_toplevel` for detail on documented arguments. """ # noqa: E501 + if _no_init: + # don't run __init__ from __new__ by default; + # __new__ has a specific place that __init__ is called + return - # __init__ is overridden to prevent __new__ from - # calling the superclass constructor. - - def _init(self, name, metadata, *args, **kwargs): - super(Table, self).__init__( - quoted_name(name, kwargs.pop("quote", None)) - ) + super().__init__(quoted_name(name, quote)) self.metadata = metadata - self.schema = kwargs.pop("schema", None) - if self.schema is None: + if schema is None: self.schema = metadata.schema - elif self.schema is BLANK_SCHEMA: + elif schema is BLANK_SCHEMA: self.schema = None else: - quote_schema = kwargs.pop("quote_schema", None) - self.schema = quoted_name(self.schema, quote_schema) + quote_schema = quote_schema + assert isinstance(schema, str) + self.schema = quoted_name(schema, quote_schema) self.indexes = set() self.constraints = set() @@ -756,42 +821,31 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): _implicit_generated=True )._set_parent_with_dispatch(self) self.foreign_keys = set() # type: ignore - self._extra_dependencies = set() + self._extra_dependencies: Set[Table] = set() if self.schema is not None: self.fullname = "%s.%s" % (self.schema, self.name) else: self.fullname = self.name - autoload_with = kwargs.pop("autoload_with", None) - autoload = autoload_with is not None - # this argument is only used with _init_existing() - kwargs.pop("autoload_replace", True) - keep_existing = kwargs.pop("keep_existing", False) - extend_existing = kwargs.pop("extend_existing", False) - _extend_on = kwargs.pop("_extend_on", None) - - resolve_fks = kwargs.pop("resolve_fks", True) - include_columns = kwargs.pop("include_columns", None) + self.implicit_returning = implicit_returning - self.implicit_returning = kwargs.pop("implicit_returning", True) + self.comment = comment - self.comment = kwargs.pop("comment", None) + if info is not None: + self.info = info - if "info" in kwargs: - self.info = kwargs.pop("info") - if "listeners" in kwargs: - listeners = kwargs.pop("listeners") + if listeners is not None: for evt, fn in listeners: event.listen(self, evt, fn) - self._prefixes = kwargs.pop("prefixes", None) or [] + self._prefixes = prefixes if prefixes else [] - self._extra_kwargs(**kwargs) + self._extra_kwargs(**kw) # load column definitions from the database if 'autoload' is defined # we do it after the table is in the singleton dictionary to support # circular foreign keys - if autoload: + if autoload_with is not None: self._autoload( metadata, autoload_with, @@ -805,18 +859,20 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): self._init_items( *args, - allow_replacements=extend_existing or keep_existing or autoload, + allow_replacements=extend_existing + or keep_existing + or autoload_with, ) def _autoload( self, - metadata, - autoload_with, - include_columns, - exclude_columns=(), - resolve_fks=True, - _extend_on=None, - ): + metadata: MetaData, + autoload_with: Union[Engine, Connection], + include_columns: Optional[Iterable[str]], + exclude_columns: Iterable[str] = (), + resolve_fks: bool = True, + _extend_on: Optional[Set[Table]] = None, + ) -> None: insp = inspection.inspect(autoload_with) with insp._inspection_context() as conn_insp: conn_insp.reflect_table( @@ -837,7 +893,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): return sorted(self.constraints, key=lambda c: c._creation_order) @property - def foreign_key_constraints(self): + def foreign_key_constraints(self) -> Set[ForeignKeyConstraint]: """:class:`_schema.ForeignKeyConstraint` objects referred to by this :class:`_schema.Table`. @@ -855,9 +911,13 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): :attr:`_schema.Table.indexes` """ - return set(fkc.constraint for fkc in self.foreign_keys) + return set( + fkc.constraint + for fkc in self.foreign_keys + if fkc.constraint is not None + ) - def _init_existing(self, *args, **kwargs): + def _init_existing(self, *args: Any, **kwargs: Any) -> None: autoload_with = kwargs.pop("autoload_with", None) autoload = kwargs.pop("autoload", autoload_with is not None) autoload_replace = kwargs.pop("autoload_replace", True) @@ -916,13 +976,13 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): self._extra_kwargs(**kwargs) self._init_items(*args) - def _extra_kwargs(self, **kwargs): + def _extra_kwargs(self, **kwargs: Any) -> None: self._validate_dialect_kwargs(kwargs) - def _init_collections(self): + def _init_collections(self) -> None: pass - def _reset_exported(self): + def _reset_exported(self) -> None: pass @util.ro_non_memoized_property @@ -930,7 +990,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): return self.primary_key._autoincrement_column @property - def key(self): + def key(self) -> str: """Return the 'key' for this :class:`_schema.Table`. This value is used as the dictionary key within the @@ -943,7 +1003,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): """ return _get_table_key(self.name, self.schema) - def __repr__(self): + def __repr__(self) -> str: return "Table(%s)" % ", ".join( [repr(self.name)] + [repr(self.metadata)] @@ -951,10 +1011,10 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]] ) - def __str__(self): + def __str__(self) -> str: return _get_table_key(self.description, self.schema) - def add_is_dependent_on(self, table): + def add_is_dependent_on(self, table: Table) -> None: """Add a 'dependency' for this Table. This is another Table object which must be created @@ -968,7 +1028,9 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): """ self._extra_dependencies.add(table) - def append_column(self, column, replace_existing=False): + def append_column( + self, column: ColumnClause[Any], replace_existing: bool = False + ) -> None: """Append a :class:`_schema.Column` to this :class:`_schema.Table`. The "key" of the newly added :class:`_schema.Column`, i.e. the @@ -998,7 +1060,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): self, allow_replacements=replace_existing ) - def append_constraint(self, constraint): + def append_constraint(self, constraint: Union[Index, Constraint]) -> None: """Append a :class:`_schema.Constraint` to this :class:`_schema.Table`. @@ -1019,11 +1081,13 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): constraint._set_parent_with_dispatch(self) - def _set_parent(self, metadata, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + metadata = parent + assert isinstance(metadata, MetaData) metadata._add_table(self.name, self.schema, self) self.metadata = metadata - def create(self, bind, checkfirst=False): + def create(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue a ``CREATE`` statement for this :class:`_schema.Table`, using the given :class:`.Connection` or :class:`.Engine` @@ -1037,7 +1101,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst) - def drop(self, bind, checkfirst=False): + def drop(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue a ``DROP`` statement for this :class:`_schema.Table`, using the given :class:`.Connection` or :class:`.Engine` for connectivity. @@ -1056,11 +1120,16 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): ) def tometadata( self, - metadata, - schema=RETAIN_SCHEMA, - referred_schema_fn=None, - name=None, - ): + metadata: MetaData, + schema: Union[str, Literal[SchemaConst.RETAIN_SCHEMA]] = RETAIN_SCHEMA, + referred_schema_fn: Optional[ + Callable[ + [Table, Optional[str], ForeignKeyConstraint, Optional[str]], + Optional[str], + ] + ] = None, + name: Optional[str] = None, + ) -> Table: """Return a copy of this :class:`_schema.Table` associated with a different :class:`_schema.MetaData`. @@ -1077,11 +1146,16 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): def to_metadata( self, - metadata, - schema=RETAIN_SCHEMA, - referred_schema_fn=None, - name=None, - ): + metadata: MetaData, + schema: Union[str, Literal[SchemaConst.RETAIN_SCHEMA]] = RETAIN_SCHEMA, + referred_schema_fn: Optional[ + Callable[ + [Table, Optional[str], ForeignKeyConstraint, Optional[str]], + Optional[str], + ] + ] = None, + name: Optional[str] = None, + ) -> Table: """Return a copy of this :class:`_schema.Table` associated with a different :class:`_schema.MetaData`. @@ -1163,11 +1237,16 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): """ if name is None: name = self.name + + actual_schema: Optional[str] + if schema is RETAIN_SCHEMA: - schema = self.schema + actual_schema = self.schema elif schema is None: - schema = metadata.schema - key = _get_table_key(name, schema) + actual_schema = metadata.schema + else: + actual_schema = schema # type: ignore + key = _get_table_key(name, actual_schema) if key in metadata.tables: util.warn( "Table '%s' already exists within the given " @@ -1177,11 +1256,11 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): args = [] for col in self.columns: - args.append(col._copy(schema=schema)) + args.append(col._copy(schema=actual_schema)) table = Table( name, metadata, - schema=schema, + schema=actual_schema, comment=self.comment, *args, **self.kwargs, @@ -1191,11 +1270,13 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): referred_schema = const._referred_schema if referred_schema_fn: fk_constraint_schema = referred_schema_fn( - self, schema, const, referred_schema + self, actual_schema, const, referred_schema ) else: fk_constraint_schema = ( - schema if referred_schema == self.schema else None + actual_schema + if referred_schema == self.schema + else None ) table.append_constraint( const._copy( @@ -1209,7 +1290,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): continue table.append_constraint( - const._copy(schema=schema, target_table=table) + const._copy(schema=actual_schema, target_table=table) ) for index in self.indexes: # skip indexes that would be generated @@ -1221,7 +1302,7 @@ class Table(DialectKWArgs, HasSchemaAttr, TableClause): unique=index.unique, *[ _copy_expression(expr, self, table) - for expr in index.expressions + for expr in index._table_bound_expressions ], _table=table, **index.kwargs, @@ -1239,101 +1320,137 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): @overload def __init__( - self: "Column[None]", - __name: str, + self, *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = ..., - default: Optional[Any] = ..., - doc: Optional[str] = ..., - key: Optional[str] = ..., - index: Optional[bool] = ..., - info: MutableMapping[Any, Any] = ..., - nullable: bool = ..., - onupdate: Optional[Any] = ..., - primary_key: bool = ..., - server_default: Optional[_ServerDefaultType] = ..., - server_onupdate: Optional["FetchedValue"] = ..., - quote: Optional[bool] = ..., - unique: Optional[bool] = ..., - system: bool = ..., - comment: Optional[str] = ..., - **kwargs: Any, - ) -> None: + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = NULL_UNSPECIFIED, + onupdate: Optional[Any] = None, + primary_key: bool = False, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + _proxies: Optional[Any] = None, + **dialect_kwargs: Any, + ): ... @overload def __init__( - self: "Column[None]", + self, + __name: str, *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = ..., - default: Optional[Any] = ..., - doc: Optional[str] = ..., - key: Optional[str] = ..., - index: Optional[bool] = ..., - info: MutableMapping[Any, Any] = ..., - nullable: bool = ..., - onupdate: Optional[Any] = ..., - primary_key: bool = ..., - server_default: Optional[_ServerDefaultType] = ..., - server_onupdate: Optional["FetchedValue"] = ..., - quote: Optional[bool] = ..., - unique: Optional[bool] = ..., - system: bool = ..., - comment: Optional[str] = ..., - **kwargs: Any, - ) -> None: + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = NULL_UNSPECIFIED, + onupdate: Optional[Any] = None, + primary_key: bool = False, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + _proxies: Optional[Any] = None, + **dialect_kwargs: Any, + ): ... @overload def __init__( self, - __name: str, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __type: _TypeEngineArgument[_T], *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = ..., - default: Optional[Any] = ..., - doc: Optional[str] = ..., - key: Optional[str] = ..., - index: Optional[bool] = ..., - info: MutableMapping[Any, Any] = ..., - nullable: bool = ..., - onupdate: Optional[Any] = ..., - primary_key: bool = ..., - server_default: Optional[_ServerDefaultType] = ..., - server_onupdate: Optional["FetchedValue"] = ..., - quote: Optional[bool] = ..., - unique: Optional[bool] = ..., - system: bool = ..., - comment: Optional[str] = ..., - **kwargs: Any, - ) -> None: + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = NULL_UNSPECIFIED, + onupdate: Optional[Any] = None, + primary_key: bool = False, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + _proxies: Optional[Any] = None, + **dialect_kwargs: Any, + ): ... @overload def __init__( self, - __type: Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"], + __name: str, + __type: _TypeEngineArgument[_T], *args: SchemaEventTarget, - autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = ..., - default: Optional[Any] = ..., - doc: Optional[str] = ..., - key: Optional[str] = ..., - index: Optional[bool] = ..., - info: MutableMapping[Any, Any] = ..., - nullable: bool = ..., - onupdate: Optional[Any] = ..., - primary_key: bool = ..., - server_default: Optional[_ServerDefaultType] = ..., - server_onupdate: Optional["FetchedValue"] = ..., - quote: Optional[bool] = ..., - unique: Optional[bool] = ..., - system: bool = ..., - comment: Optional[str] = ..., - **kwargs: Any, - ) -> None: + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = NULL_UNSPECIFIED, + onupdate: Optional[Any] = None, + primary_key: bool = False, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + _proxies: Optional[Any] = None, + **dialect_kwargs: Any, + ): ... - def __init__(self, *args: Any, **kwargs: Any): + def __init__( + self, + *args: Union[str, _TypeEngineArgument[_T], SchemaEventTarget], + name: Optional[str] = None, + type_: Optional[_TypeEngineArgument[_T]] = None, + autoincrement: Union[bool, Literal["auto", "ignore_fk"]] = "auto", + default: Optional[Any] = None, + doc: Optional[str] = None, + key: Optional[str] = None, + index: Optional[bool] = None, + unique: Optional[bool] = None, + info: Optional[_InfoType] = None, + nullable: Optional[ + Union[bool, Literal[SchemaConst.NULL_UNSPECIFIED]] + ] = NULL_UNSPECIFIED, + onupdate: Optional[Any] = None, + primary_key: bool = False, + server_default: Optional[_ServerDefaultType] = None, + server_onupdate: Optional[FetchedValue] = None, + quote: Optional[bool] = None, + system: bool = False, + comment: Optional[str] = None, + _proxies: Optional[Any] = None, + **dialect_kwargs: Any, + ): r""" Construct a new ``Column`` object. @@ -1836,8 +1953,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ # noqa: E501, RST201, RST202 - name = kwargs.pop("name", None) - type_ = kwargs.pop("type_", None) l_args = list(args) del args @@ -1847,7 +1962,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): raise exc.ArgumentError( "May not pass name positionally and as a keyword." ) - name = l_args.pop(0) + name = l_args.pop(0) # type: ignore if l_args: coltype = l_args[0] @@ -1856,52 +1971,49 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): raise exc.ArgumentError( "May not pass type_ positionally and as a keyword." ) - type_ = l_args.pop(0) + type_ = l_args.pop(0) # type: ignore if name is not None: - name = quoted_name(name, kwargs.pop("quote", None)) - elif "quote" in kwargs: + name = quoted_name(name, quote) + elif quote is not None: raise exc.ArgumentError( "Explicit 'name' is required when " "sending 'quote' argument" ) - super(Column, self).__init__(name, type_) - self.key = kwargs.pop("key", name) - self.primary_key = primary_key = kwargs.pop("primary_key", False) + # name = None is expected to be an interim state + # note this use case is legacy now that ORM declarative has a + # dedicated "column" construct local to the ORM + super(Column, self).__init__(name, type_) # type: ignore - self._user_defined_nullable = udn = kwargs.pop( - "nullable", NULL_UNSPECIFIED - ) + self.key = key if key is not None else name # type: ignore + self.primary_key = primary_key + + self._user_defined_nullable = udn = nullable if udn is not NULL_UNSPECIFIED: self.nullable = udn else: self.nullable = not primary_key - default = kwargs.pop("default", None) - onupdate = kwargs.pop("onupdate", None) - - self.server_default = kwargs.pop("server_default", None) - self.server_onupdate = kwargs.pop("server_onupdate", None) - # these default to None because .index and .unique is *not* # an informational flag about Column - there can still be an # Index or UniqueConstraint referring to this Column. - self.index = kwargs.pop("index", None) - self.unique = kwargs.pop("unique", None) + self.index = index + self.unique = unique - self.system = kwargs.pop("system", False) - self.doc = kwargs.pop("doc", None) - self.autoincrement = kwargs.pop("autoincrement", "auto") + self.system = system + self.doc = doc + self.autoincrement = autoincrement self.constraints = set() self.foreign_keys = set() - self.comment = kwargs.pop("comment", None) + self.comment = comment self.computed = None self.identity = None # check if this Column is proxying another column - if "_proxies" in kwargs: - self._proxies = kwargs.pop("_proxies") + + if _proxies is not None: + self._proxies = _proxies else: # otherwise, add DDL-related events if isinstance(self.type, SchemaEventTarget): @@ -1928,6 +2040,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): else: self.onpudate = None + self.server_default = server_default + self.server_onupdate = server_onupdate + if self.server_default is not None: if isinstance(self.server_default, FetchedValue): l_args.append(self.server_default._as_for_update(False)) @@ -1941,14 +2056,14 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): l_args.append( DefaultClause(self.server_onupdate, for_update=True) ) - self._init_items(*l_args) + self._init_items(*cast(_typing_Sequence[SchemaItem], l_args)) util.set_creation_order(self) - if "info" in kwargs: - self.info = kwargs.pop("info") + if info is not None: + self.info = info - self._extra_kwargs(**kwargs) + self._extra_kwargs(**dialect_kwargs) table: Table @@ -1967,7 +2082,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ - index: bool + index: Optional[bool] """The value of the :paramref:`_schema.Column.index` parameter. Does not indicate if this :class:`_schema.Column` is actually indexed @@ -1978,7 +2093,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): :attr:`_schema.Table.indexes` """ - unique: bool + unique: Optional[bool] """The value of the :paramref:`_schema.Column.unique` parameter. Does not indicate if this :class:`_schema.Column` is actually subject to @@ -1993,10 +2108,14 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ - def _extra_kwargs(self, **kwargs): + computed: Optional[Computed] + + identity: Optional[Identity] + + def _extra_kwargs(self, **kwargs: Any) -> None: self._validate_dialect_kwargs(kwargs) - def __str__(self): + def __str__(self) -> str: if self.name is None: return "(no name)" elif self.table is not None: @@ -2007,7 +2126,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): else: return self.description - def references(self, column): + def references(self, column: Column[Any]) -> bool: """Return True if this Column references the given column via foreign key.""" @@ -2017,10 +2136,10 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): else: return False - def append_foreign_key(self, fk): + def append_foreign_key(self, fk: ForeignKey) -> None: fk._set_parent_with_dispatch(self) - def __repr__(self): + def __repr__(self) -> str: kwarg = [] if self.key != self.name: kwarg.append("key") @@ -2051,7 +2170,14 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg] ) - def _set_parent(self, table, allow_replacements=True): + def _set_parent( + self, + parent: SchemaEventTarget, + allow_replacements: bool = True, + **kw: Any, + ) -> None: + table = parent + assert isinstance(table, Table) if not self.name: raise exc.ArgumentError( "Column must be constructed with a non-blank name or " @@ -2071,7 +2197,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): ) if self.key in table._columns: - col = table._columns.get(self.key) + col = table._columns[self.key] if col is not self: if not allow_replacements: util.warn_deprecated( @@ -2139,7 +2265,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "An column cannot specify both Identity and Sequence." ) - def _setup_on_memoized_fks(self, fn): + def _setup_on_memoized_fks(self, fn: Callable[..., Any]) -> None: fk_keys = [ ((self.table.key, self.key), False), ((self.table.key, self.name), True), @@ -2150,7 +2276,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): if fk.link_to_name is link_to_name: fn(fk) - def _on_table_attach(self, fn): + def _on_table_attach(self, fn: Callable[..., Any]) -> None: if self.table is not None: fn(self, self.table) else: @@ -2161,10 +2287,10 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): "The :meth:`_schema.Column.copy` method is deprecated " "and will be removed in a future release.", ) - def copy(self, **kw): + def copy(self, **kw: Any) -> Column[Any]: return self._copy(**kw) - def _copy(self, **kw): + def _copy(self, **kw: Any) -> Column[Any]: """Create a copy of this ``Column``, uninitialized. This is used in :meth:`_schema.Table.to_metadata`. @@ -2172,9 +2298,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): """ # Constraint objects plus non-constraint-bound ForeignKey objects - args = [ - c._copy(**kw) for c in self.constraints if not c._type_bound - ] + [c._copy(**kw) for c in self.foreign_keys if not c.constraint] + args: List[SchemaItem] = [ + c._copy(**kw) + for c in self.constraints + if not c._type_bound # type: ignore + ] + [ + c._copy(**kw) # type: ignore + for c in self.foreign_keys + if not c.constraint + ] # ticket #5276 column_kwargs = {} @@ -2223,8 +2355,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): return self._schema_item_copy(c) def _make_proxy( - self, selectable, name=None, key=None, name_is_truncatable=False, **kw - ): + self, + selectable: FromClause, + name: Optional[str] = None, + key: Optional[str] = None, + name_is_truncatable: bool = False, + **kw: Any, + ) -> Tuple[str, ColumnClause[_T]]: """Create a *proxy* for this column. This is a copy of this ``Column`` referenced by a different parent @@ -2272,9 +2409,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause[_T]): if selectable._is_clone_of is not None: c._is_clone_of = selectable._is_clone_of.columns.get(c.key) if self.primary_key: - selectable.primary_key.add(c) + selectable.primary_key.add(c) # type: ignore if fk: - selectable.foreign_keys.update(fk) + selectable.foreign_keys.update(fk) # type: ignore return c.key, c @@ -2326,17 +2463,17 @@ class ForeignKey(DialectKWArgs, SchemaItem): def __init__( self, - column: Union[str, Column[Any], SQLCoreOperations[Any]], - _constraint: Optional["ForeignKeyConstraint"] = None, + column: _DDLColumnArgument, + _constraint: Optional[ForeignKeyConstraint] = None, use_alter: bool = False, name: Optional[str] = None, onupdate: Optional[str] = None, ondelete: Optional[str] = None, deferrable: Optional[bool] = None, - initially: Optional[bool] = None, + initially: Optional[str] = None, link_to_name: bool = False, match: Optional[str] = None, - info: Optional[Dict[Any, Any]] = None, + info: Optional[_InfoType] = None, **dialect_kw: Any, ): r""" @@ -2446,7 +2583,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): self.info = info self._unvalidated_dialect_kw = dialect_kw - def __repr__(self): + def __repr__(self) -> str: return "ForeignKey(%r)" % self._get_colspec() @util.deprecated( @@ -2454,10 +2591,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): "The :meth:`_schema.ForeignKey.copy` method is deprecated " "and will be removed in a future release.", ) - def copy(self, schema=None, **kw): + def copy(self, *, schema: Optional[str] = None, **kw: Any) -> ForeignKey: return self._copy(schema=schema, **kw) - def _copy(self, schema=None, **kw): + def _copy(self, *, schema: Optional[str] = None, **kw: Any) -> ForeignKey: """Produce a copy of this :class:`_schema.ForeignKey` object. The new :class:`_schema.ForeignKey` will not be bound @@ -2487,7 +2624,17 @@ class ForeignKey(DialectKWArgs, SchemaItem): ) return self._schema_item_copy(fk) - def _get_colspec(self, schema=None, table_name=None, _is_copy=False): + def _get_colspec( + self, + schema: Optional[ + Union[ + str, + Literal[SchemaConst.RETAIN_SCHEMA, SchemaConst.BLANK_SCHEMA], + ] + ] = None, + table_name: Optional[str] = None, + _is_copy: bool = False, + ) -> str: """Return a string based 'column specification' for this :class:`_schema.ForeignKey`. @@ -2523,13 +2670,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): self._table_column.key, ) else: + assert isinstance(self._colspec, str) return self._colspec @property - def _referred_schema(self): + def _referred_schema(self) -> Optional[str]: return self._column_tokens[0] - def _table_key(self): + def _table_key(self) -> Any: if self._table_column is not None: if self._table_column.table is None: return None @@ -2541,16 +2689,16 @@ class ForeignKey(DialectKWArgs, SchemaItem): target_fullname = property(_get_colspec) - def references(self, table): + def references(self, table: Table) -> bool: """Return True if the given :class:`_schema.Table` is referenced by this :class:`_schema.ForeignKey`.""" return table.corresponding_column(self.column) is not None - def get_referent(self, table): + def get_referent(self, table: FromClause) -> Optional[Column[Any]]: """Return the :class:`_schema.Column` in the given - :class:`_schema.Table` + :class:`_schema.Table` (or any :class:`.FromClause`) referenced by this :class:`_schema.ForeignKey`. Returns None if this :class:`_schema.ForeignKey` @@ -2559,10 +2707,10 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ - return table.corresponding_column(self.column) + return table.columns.corresponding_column(self.column) @util.memoized_property - def _column_tokens(self): + def _column_tokens(self) -> Tuple[Optional[str], str, Optional[str]]: """parse a string-based _colspec into its component parts.""" m = self._get_colspec().split(".") @@ -2592,7 +2740,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): schema = None return schema, tname, colname - def _resolve_col_tokens(self): + def _resolve_col_tokens(self) -> Tuple[Table, str, Optional[str]]: if self.parent is None: raise exc.InvalidRequestError( "this ForeignKey object does not yet have a " @@ -2627,7 +2775,9 @@ class ForeignKey(DialectKWArgs, SchemaItem): tablekey = _get_table_key(tname, schema) return parenttable, tablekey, colname - def _link_to_col_by_colstring(self, parenttable, table, colname): + def _link_to_col_by_colstring( + self, parenttable: Table, table: Table, colname: Optional[str] + ) -> Column[Any]: _column = None if colname is None: # colname is None in the case that ForeignKey argument @@ -2661,7 +2811,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): return _column - def _set_target_column(self, column): + def _set_target_column(self, column: Column[Any]) -> None: assert self.parent is not None # propagate TypeEngine to parent if it didn't have one @@ -2671,16 +2821,16 @@ class ForeignKey(DialectKWArgs, SchemaItem): # super-edgy case, if other FKs point to our column, # they'd get the type propagated out also. - def set_type(fk): + def set_type(fk: ForeignKey) -> None: if fk.parent.type._isnull: fk.parent.type = column.type self.parent._setup_on_memoized_fks(set_type) - self.column = column + self.column = column # type: ignore - @util.memoized_property - def column(self): + @util.ro_memoized_property + def column(self) -> Column[Any]: """Return the target :class:`_schema.Column` referenced by this :class:`_schema.ForeignKey`. @@ -2689,6 +2839,8 @@ class ForeignKey(DialectKWArgs, SchemaItem): """ + _column: Column[Any] + if isinstance(self._colspec, str): parenttable, tablekey, colname = self._resolve_col_tokens() @@ -2730,14 +2882,14 @@ class ForeignKey(DialectKWArgs, SchemaItem): self.parent.foreign_keys.add(self) self.parent._on_table_attach(self._set_table) - def _set_remote_table(self, table): + def _set_remote_table(self, table: Table) -> None: parenttable, _, colname = self._resolve_col_tokens() _column = self._link_to_col_by_colstring(parenttable, table, colname) self._set_target_column(_column) assert self.constraint is not None self.constraint._validate_dest_table(table) - def _remove_from_metadata(self, metadata): + def _remove_from_metadata(self, metadata: MetaData) -> None: parenttable, table_key, colname = self._resolve_col_tokens() fk_key = (table_key, colname) @@ -2745,7 +2897,7 @@ class ForeignKey(DialectKWArgs, SchemaItem): # TODO: no test coverage for self not in memos metadata._fk_memos[fk_key].remove(self) - def _set_table(self, column, table): + def _set_table(self, column: Column[Any], table: Table) -> None: # standalone ForeignKey - create ForeignKeyConstraint # on the hosting Table when attached to the Table. assert isinstance(table, Table) @@ -2821,6 +2973,7 @@ class DefaultGenerator(Executable, SchemaItem): __visit_name__ = "default_generator" + _is_default_generator = True is_sequence = False is_server_default = False is_clause_element = False @@ -2828,7 +2981,7 @@ class DefaultGenerator(Executable, SchemaItem): is_scalar = False column: Optional[Column[Any]] - def __init__(self, for_update=False): + def __init__(self, for_update: bool = False) -> None: self.for_update = for_update def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: @@ -2841,8 +2994,27 @@ class DefaultGenerator(Executable, SchemaItem): self.column.default = self def _execute_on_connection( - self, connection, distilled_params, execution_options - ): + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + ) -> Any: + util.warn_deprecated( + "Using the .execute() method to invoke a " + "DefaultGenerator object is deprecated; please use " + "the .scalar() method.", + "2.0", + ) + return self._execute_on_scalar( + connection, distilled_params, execution_options + ) + + def _execute_on_scalar( + self, + connection: Connection, + distilled_params: _CoreMultiExecuteParams, + execution_options: _ExecuteOptionsParameter, + ) -> Any: return connection._execute_default( self, distilled_params, execution_options ) @@ -2933,7 +3105,7 @@ class ColumnDefault(DefaultGenerator, ABC): return object.__new__(cls) - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.arg!r})" @@ -2946,7 +3118,7 @@ class ScalarElementColumnDefault(ColumnDefault): is_scalar = True - def __init__(self, arg: Any, for_update: bool = False): + def __init__(self, arg: Any, for_update: bool = False) -> None: self.for_update = for_update self.arg = arg @@ -2970,13 +3142,13 @@ class ColumnElementColumnDefault(ColumnDefault): self, arg: _SQLExprDefault, for_update: bool = False, - ): + ) -> None: self.for_update = for_update self.arg = arg @util.memoized_property @util.preload_module("sqlalchemy.sql.sqltypes") - def _arg_is_typed(self): + def _arg_is_typed(self) -> bool: sqltypes = util.preloaded.sql_sqltypes return not isinstance(self.arg.type, sqltypes.NullType) @@ -3001,7 +3173,7 @@ class CallableColumnDefault(ColumnDefault): self, arg: Union[_CallableColumnDefaultProtocol, Callable[[], Any]], for_update: bool = False, - ): + ) -> None: self.for_update = for_update self.arg = self._maybe_wrap_callable(arg) @@ -3048,16 +3220,16 @@ class IdentityOptions: def __init__( self, - start=None, - increment=None, - minvalue=None, - maxvalue=None, - nominvalue=None, - nomaxvalue=None, - cycle=None, - cache=None, - order=None, - ): + start: Optional[int] = None, + increment: Optional[int] = None, + minvalue: Optional[int] = None, + maxvalue: Optional[int] = None, + nominvalue: Optional[bool] = None, + nomaxvalue: Optional[bool] = None, + cycle: Optional[bool] = None, + cache: Optional[bool] = None, + order: Optional[bool] = None, + ) -> None: """Construct a :class:`.IdentityOptions` object. See the :class:`.Sequence` documentation for a complete description @@ -3125,28 +3297,29 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): is_sequence = True - column: Optional[Column[Any]] = None + column: Optional[Column[Any]] + data_type: Optional[TypeEngine[int]] def __init__( self, - name, - start=None, - increment=None, - minvalue=None, - maxvalue=None, - nominvalue=None, - nomaxvalue=None, - cycle=None, - schema=None, - cache=None, - order=None, - data_type=None, - optional=False, - quote=None, - metadata=None, - quote_schema=None, - for_update=False, - ): + name: str, + start: Optional[int] = None, + increment: Optional[int] = None, + minvalue: Optional[int] = None, + maxvalue: Optional[int] = None, + nominvalue: Optional[bool] = None, + nomaxvalue: Optional[bool] = None, + cycle: Optional[bool] = None, + schema: Optional[Union[str, Literal[SchemaConst.BLANK_SCHEMA]]] = None, + cache: Optional[bool] = None, + order: Optional[bool] = None, + data_type: Optional[_TypeEngineArgument[int]] = None, + optional: bool = False, + quote: Optional[bool] = None, + metadata: Optional[MetaData] = None, + quote_schema: Optional[bool] = None, + for_update: bool = False, + ) -> None: """Construct a :class:`.Sequence` object. :param name: the name of the sequence. @@ -3298,6 +3471,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): cache=cache, order=order, ) + self.column = None self.name = quoted_name(name, quote) self.optional = optional if schema is BLANK_SCHEMA: @@ -3316,7 +3490,7 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): self.data_type = None @util.preload_module("sqlalchemy.sql.functions") - def next_value(self): + def next_value(self) -> Function[int]: """Return a :class:`.next_value` function element which will render the appropriate increment function for this :class:`.Sequence` within any SQL expression. @@ -3324,28 +3498,30 @@ class Sequence(HasSchemaAttr, IdentityOptions, DefaultGenerator): """ return util.preloaded.sql_functions.func.next_value(self) - def _set_parent(self, column, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + column = parent + assert isinstance(column, Column) super(Sequence, self)._set_parent(column) column._on_table_attach(self._set_table) - def _set_table(self, column, table): + def _set_table(self, column: Column[Any], table: Table) -> None: self._set_metadata(table.metadata) - def _set_metadata(self, metadata): + def _set_metadata(self, metadata: MetaData) -> None: self.metadata = metadata self.metadata._sequences[self._key] = self - def create(self, bind, checkfirst=True): + def create(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Creates this sequence in the database.""" bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst) - def drop(self, bind, checkfirst=True): + def drop(self, bind: _CreateDropBind, checkfirst: bool = True) -> None: """Drops this sequence from the database.""" bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst) - def _not_a_column_expr(self): + def _not_a_column_expr(self) -> NoReturn: raise exc.InvalidRequestError( "This %s cannot be used directly " "as a column expression. Use func.next_value(sequence) " @@ -3380,30 +3556,34 @@ class FetchedValue(SchemaEventTarget): has_argument = False is_clause_element = False - def __init__(self, for_update=False): + column: Optional[Column[Any]] + + def __init__(self, for_update: bool = False) -> None: self.for_update = for_update - def _as_for_update(self, for_update): + def _as_for_update(self, for_update: bool) -> FetchedValue: if for_update == self.for_update: return self else: - return self._clone(for_update) + return self._clone(for_update) # type: ignore - def _clone(self, for_update): + def _clone(self, for_update: bool) -> Any: n = self.__class__.__new__(self.__class__) n.__dict__.update(self.__dict__) n.__dict__.pop("column", None) n.for_update = for_update return n - def _set_parent(self, column, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + column = parent + assert isinstance(column, Column) self.column = column if self.for_update: self.column.server_onupdate = self else: self.column.server_default = self - def __repr__(self): + def __repr__(self) -> str: return util.generic_repr(self) @@ -3431,13 +3611,18 @@ class DefaultClause(FetchedValue): has_argument = True - def __init__(self, arg, for_update=False, _reflected=False): + def __init__( + self, + arg: Union[str, ClauseElement, TextClause], + for_update: bool = False, + _reflected: bool = False, + ) -> None: util.assert_arg_type(arg, (str, ClauseElement, TextClause), "arg") super(DefaultClause, self).__init__(for_update) self.arg = arg self.reflected = _reflected - def __repr__(self): + def __repr__(self) -> str: return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update) @@ -3460,14 +3645,14 @@ class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem): def __init__( self, - name=None, - deferrable=None, - initially=None, - _create_rule=None, - info=None, - _type_bound=False, - **dialect_kw, - ): + name: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[str] = None, + info: Optional[_InfoType] = None, + _create_rule: Optional[Any] = None, + _type_bound: bool = False, + **dialect_kw: Any, + ) -> None: r"""Create a SQL constraint. :param name: @@ -3510,7 +3695,9 @@ class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem): util.set_creation_order(self) self._validate_dialect_kwargs(dialect_kw) - def _should_create_for_compiler(self, compiler, **kw): + def _should_create_for_compiler( + self, compiler: DDLCompiler, **kw: Any + ) -> bool: if self._create_rule is not None and not self._create_rule(compiler): return False elif self._ddl_if is not None: @@ -3521,7 +3708,7 @@ class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem): return True @property - def table(self): + def table(self) -> Table: try: if isinstance(self.parent, Table): return self.parent @@ -3532,7 +3719,8 @@ class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem): "mean to call table.append_constraint(constraint) ?" ) - def _set_parent(self, parent, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + assert isinstance(parent, (Table, Column)) self.parent = parent parent.constraints.add(self) @@ -3541,10 +3729,10 @@ class Constraint(DialectKWArgs, HasConditionalDDL, SchemaItem): "The :meth:`_schema.Constraint.copy` method is deprecated " "and will be removed in a future release.", ) - def copy(self, **kw): - return self._copy(**kw) + def copy(self: Self, **kw: Any) -> Self: + return self._copy(**kw) # type: ignore - def _copy(self, **kw): + def _copy(self: Self, **kw: Any) -> Self: raise NotImplementedError() @@ -3561,6 +3749,8 @@ class ColumnCollectionMixin: _allow_multiple_tables = False + _pending_colargs: List[Optional[Union[str, Column[Any]]]] + if TYPE_CHECKING: def _set_parent_with_dispatch( @@ -3568,18 +3758,28 @@ class ColumnCollectionMixin: ) -> None: ... - def __init__(self, *columns, **kw): - _autoattach = kw.pop("_autoattach", True) - self._column_flag = kw.pop("_column_flag", False) + def __init__( + self, + *columns: _DDLColumnArgument, + _autoattach: bool = True, + _column_flag: bool = False, + _gather_expressions: Optional[ + List[Union[str, ColumnElement[Any]]] + ] = None, + ) -> None: + self._column_flag = _column_flag self._columns = DedupeColumnCollection() - processed_expressions = kw.pop("_gather_expressions", None) + processed_expressions: Optional[ + List[Union[ColumnElement[Any], str]] + ] = _gather_expressions + if processed_expressions is not None: self._pending_colargs = [] for ( expr, - column, - strname, + _, + _, add_element, ) in coercions.expect_col_expression_collection( roles.DDLConstraintColumnRole, columns @@ -3595,7 +3795,7 @@ class ColumnCollectionMixin: if _autoattach and self._pending_colargs: self._check_attach() - def _check_attach(self, evt=False): + def _check_attach(self, evt: bool = False) -> None: col_objs = [c for c in self._pending_colargs if isinstance(c, Column)] cols_w_table = [c for c in col_objs if isinstance(c.table, Table)] @@ -3613,7 +3813,7 @@ class ColumnCollectionMixin: ).difference(col_objs) if not has_string_cols: - def _col_attached(column, table): + def _col_attached(column: Column[Any], table: Table) -> None: # this isinstance() corresponds with the # isinstance() above; only want to count Table-bound # columns @@ -3652,15 +3852,24 @@ class ColumnCollectionMixin: def c(self) -> ReadOnlyColumnCollection[str, Column[Any]]: return self._columns.as_readonly() - def _col_expressions(self, table: Table) -> List[Column[Any]]: - return [ - table.c[col] if isinstance(col, str) else col - for col in self._pending_colargs - ] + def _col_expressions( + self, parent: Union[Table, Column[Any]] + ) -> List[Optional[Column[Any]]]: + if isinstance(parent, Column): + result: List[Optional[Column[Any]]] = [ + c for c in self._pending_colargs if isinstance(c, Column) + ] + assert len(result) == len(self._pending_colargs) + return result + else: + return [ + parent.c[col] if isinstance(col, str) else col + for col in self._pending_colargs + ] def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: - if TYPE_CHECKING: - assert isinstance(parent, Table) + assert isinstance(parent, (Table, Column)) + for col in self._col_expressions(parent): if col is not None: self._columns.add(col) @@ -3669,7 +3878,18 @@ class ColumnCollectionMixin: class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): """A constraint that proxies a ColumnCollection.""" - def __init__(self, *columns, **kw): + def __init__( + self, + *columns: _DDLColumnArgument, + name: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[str] = None, + info: Optional[_InfoType] = None, + _autoattach: bool = True, + _column_flag: bool = False, + _gather_expressions: Optional[List[_DDLColumnArgument]] = None, + **dialect_kw: Any, + ) -> None: r""" :param \*columns: A sequence of column names or Column objects. @@ -3685,13 +3905,19 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): Optional string. If set, emit INITIALLY when issuing DDL for this constraint. - :param \**kw: other keyword arguments including dialect-specific - arguments are propagated to the :class:`.Constraint` superclass. + :param \**dialect_kw: other keyword arguments including + dialect-specific arguments are propagated to the :class:`.Constraint` + superclass. """ - _autoattach = kw.pop("_autoattach", True) - _column_flag = kw.pop("_column_flag", False) - Constraint.__init__(self, **kw) + Constraint.__init__( + self, + name=name, + deferrable=deferrable, + initially=initially, + info=info, + **dialect_kw, + ) ColumnCollectionMixin.__init__( self, *columns, _autoattach=_autoattach, _column_flag=_column_flag ) @@ -3702,11 +3928,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): """ - def _set_parent(self, table, **kw): - Constraint._set_parent(self, table) - ColumnCollectionMixin._set_parent(self, table) + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + assert isinstance(parent, (Column, Table)) + Constraint._set_parent(self, parent) + ColumnCollectionMixin._set_parent(self, parent) - def __contains__(self, x): + def __contains__(self, x: Any) -> bool: return x in self._columns @util.deprecated( @@ -3714,10 +3941,20 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): "The :meth:`_schema.ColumnCollectionConstraint.copy` method " "is deprecated and will be removed in a future release.", ) - def copy(self, target_table=None, **kw): + def copy( + self, + *, + target_table: Optional[Table] = None, + **kw: Any, + ) -> ColumnCollectionConstraint: return self._copy(target_table=target_table, **kw) - def _copy(self, target_table=None, **kw): + def _copy( + self, + *, + target_table: Optional[Table] = None, + **kw: Any, + ) -> ColumnCollectionConstraint: # ticket #5276 constraint_kwargs = {} for dialect_name in self.dialect_options: @@ -3730,6 +3967,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): dialect_name + "_" + dialect_option_key ] = dialect_option_value + assert isinstance(self.parent, Table) c = self.__class__( name=self.name, deferrable=self.deferrable, @@ -3742,7 +3980,7 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint): ) return self._schema_item_copy(c) - def contains_column(self, col): + def contains_column(self, col: Column[Any]) -> bool: """Return True if this constraint contains the given column. Note that this object also contains an attribute ``.columns`` @@ -3777,17 +4015,17 @@ class CheckConstraint(ColumnCollectionConstraint): ) def __init__( self, - sqltext, - name=None, - deferrable=None, - initially=None, - table=None, - info=None, - _create_rule=None, - _autoattach=True, - _type_bound=False, - **kw, - ): + sqltext: _TextCoercedExpressionArgument[Any], + name: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[str] = None, + table: Optional[Table] = None, + info: Optional[_InfoType] = None, + _create_rule: Optional[Any] = None, + _autoattach: bool = True, + _type_bound: bool = False, + **dialect_kw: Any, + ) -> None: r"""Construct a CHECK constraint. :param sqltext: @@ -3821,7 +4059,7 @@ class CheckConstraint(ColumnCollectionConstraint): columns: List[Column[Any]] = [] visitors.traverse(self.sqltext, {}, {"column": columns.append}) - super(CheckConstraint, self).__init__( + super().__init__( name=name, deferrable=deferrable, initially=initially, @@ -3830,13 +4068,13 @@ class CheckConstraint(ColumnCollectionConstraint): _type_bound=_type_bound, _autoattach=_autoattach, *columns, - **kw, + **dialect_kw, ) if table is not None: self._set_parent_with_dispatch(table) @property - def is_column_level(self): + def is_column_level(self) -> bool: return not isinstance(self.parent, Table) @util.deprecated( @@ -3844,10 +4082,14 @@ class CheckConstraint(ColumnCollectionConstraint): "The :meth:`_schema.CheckConstraint.copy` method is deprecated " "and will be removed in a future release.", ) - def copy(self, target_table=None, **kw): + def copy( + self, *, target_table: Optional[Table] = None, **kw: Any + ) -> CheckConstraint: return self._copy(target_table=target_table, **kw) - def _copy(self, target_table=None, **kw): + def _copy( + self, *, target_table: Optional[Table] = None, **kw: Any + ) -> CheckConstraint: if target_table is not None: # note that target_table is None for the copy process of # a column-bound CheckConstraint, so this path is not reached @@ -3886,20 +4128,20 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): def __init__( self, - columns, - refcolumns, - name=None, - onupdate=None, - ondelete=None, - deferrable=None, - initially=None, - use_alter=False, - link_to_name=False, - match=None, - table=None, - info=None, - **dialect_kw, - ): + columns: _typing_Sequence[_DDLColumnArgument], + refcolumns: _typing_Sequence[_DDLColumnArgument], + name: Optional[str] = None, + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[str] = None, + use_alter: bool = False, + link_to_name: bool = False, + match: Optional[str] = None, + table: Optional[Table] = None, + info: Optional[_InfoType] = None, + **dialect_kw: Any, + ) -> None: r"""Construct a composite-capable FOREIGN KEY. :param columns: A sequence of local column names. The named columns @@ -4051,19 +4293,19 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): """ @property - def _elements(self): + def _elements(self) -> util.OrderedDict[str, ForeignKey]: # legacy - provide a dictionary view of (column_key, fk) return util.OrderedDict(zip(self.column_keys, self.elements)) @property - def _referred_schema(self): + def _referred_schema(self) -> Optional[str]: for elem in self.elements: return elem._referred_schema else: return None @property - def referred_table(self): + def referred_table(self) -> Table: """The :class:`_schema.Table` object to which this :class:`_schema.ForeignKeyConstraint` references. @@ -4076,7 +4318,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): """ return self.elements[0].column.table - def _validate_dest_table(self, table): + def _validate_dest_table(self, table: Table) -> None: table_keys = set([elem._table_key() for elem in self.elements]) if None not in table_keys and len(table_keys) > 1: elem0, elem1 = sorted(table_keys)[0:2] @@ -4087,7 +4329,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): ) @property - def column_keys(self): + def column_keys(self) -> _typing_Sequence[str]: """Return a list of string keys representing the local columns in this :class:`_schema.ForeignKeyConstraint`. @@ -4108,10 +4350,12 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): ] @property - def _col_description(self): + def _col_description(self) -> str: return ", ".join(self.column_keys) - def _set_parent(self, table, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + table = parent + assert isinstance(table, Table) Constraint._set_parent(self, table) try: @@ -4134,10 +4378,22 @@ class ForeignKeyConstraint(ColumnCollectionConstraint): "The :meth:`_schema.ForeignKeyConstraint.copy` method is deprecated " "and will be removed in a future release.", ) - def copy(self, schema=None, target_table=None, **kw): + def copy( + self, + *, + schema: Optional[str] = None, + target_table: Optional[Table] = None, + **kw: Any, + ) -> ForeignKeyConstraint: return self._copy(schema=schema, target_table=target_table, **kw) - def _copy(self, schema=None, target_table=None, **kw): + def _copy( + self, + *, + schema: Optional[str] = None, + target_table: Optional[Table] = None, + **kw: Any, + ) -> ForeignKeyConstraint: fkc = ForeignKeyConstraint( [x.parent.key for x in self.elements], [ @@ -4241,16 +4497,34 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): __visit_name__ = "primary_key_constraint" - def __init__(self, *columns, **kw): - self._implicit_generated = kw.pop("_implicit_generated", False) - super(PrimaryKeyConstraint, self).__init__(*columns, **kw) + def __init__( + self, + *columns: _DDLColumnArgument, + name: Optional[str] = None, + deferrable: Optional[bool] = None, + initially: Optional[str] = None, + info: Optional[_InfoType] = None, + _implicit_generated: bool = False, + **dialect_kw: Any, + ) -> None: + self._implicit_generated = _implicit_generated + super(PrimaryKeyConstraint, self).__init__( + *columns, + name=name, + deferrable=deferrable, + initially=initially, + info=info, + **dialect_kw, + ) - def _set_parent(self, table, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + table = parent + assert isinstance(table, Table) super(PrimaryKeyConstraint, self)._set_parent(table) if table.primary_key is not self: table.constraints.discard(table.primary_key) - table.primary_key = self + table.primary_key = self # type: ignore table.constraints.add(self) table_pks = [c for c in table.c if c.primary_key] @@ -4280,7 +4554,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): if table_pks: self._columns.extend(table_pks) - def _reload(self, columns): + def _reload(self, columns: Iterable[Column[Any]]) -> None: """repopulate this :class:`.PrimaryKeyConstraint` given a set of columns. @@ -4309,14 +4583,14 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore self._set_parent_with_dispatch(self.table) - def _replace(self, col): + def _replace(self, col: Column[Any]) -> None: PrimaryKeyConstraint._autoincrement_column._reset(self) # type: ignore self._columns.replace(col) self.dispatch._sa_event_column_added_to_pk_constraint(self, col) @property - def columns_autoinc_first(self): + def columns_autoinc_first(self) -> List[Column[Any]]: autoinc = self._autoincrement_column if autoinc is not None: @@ -4326,7 +4600,7 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint): @util.ro_memoized_property def _autoincrement_column(self) -> Optional[Column[Any]]: - def _validate_autoinc(col, autoinc_true): + def _validate_autoinc(col: Column[Any], autoinc_true: bool) -> bool: if col.type._type_affinity is None or not issubclass( col.type._type_affinity, type_api.INTEGERTYPE._type_affinity ): @@ -4478,7 +4752,21 @@ class Index( __visit_name__ = "index" - def __init__(self, name, *expressions, **kw): + table: Optional[Table] + expressions: _typing_Sequence[Union[str, ColumnElement[Any]]] + _table_bound_expressions: _typing_Sequence[ColumnElement[Any]] + + def __init__( + self, + name: Optional[str], + *expressions: _DDLColumnArgument, + unique: bool = False, + quote: Optional[bool] = None, + info: Optional[_InfoType] = None, + _table: Optional[Table] = None, + _column_flag: bool = False, + **dialect_kw: Any, + ) -> None: r"""Construct an index object. :param name: @@ -4503,8 +4791,8 @@ class Index( .. versionadded:: 1.0.0 - :param \**kw: Additional keyword arguments not mentioned above are - dialect specific, and passed in the form + :param \**dialect_kw: Additional keyword arguments not mentioned above + are dialect specific, and passed in the form ``_``. See the documentation regarding an individual dialect at :ref:`dialect_toplevel` for detail on documented arguments. @@ -4512,20 +4800,19 @@ class Index( """ self.table = table = None - self.name = quoted_name.construct(name, kw.pop("quote", None)) - self.unique = kw.pop("unique", False) - _column_flag = kw.pop("_column_flag", False) - if "info" in kw: - self.info = kw.pop("info") + self.name = quoted_name.construct(name, quote) + self.unique = unique + if info is not None: + self.info = info # TODO: consider "table" argument being public, but for # the purpose of the fix here, it starts as private. - if "_table" in kw: - table = kw.pop("_table") + if _table is not None: + table = _table - self._validate_dialect_kwargs(kw) + self._validate_dialect_kwargs(dialect_kw) - self.expressions: List[ColumnElement[Any]] = [] + self.expressions = [] # will call _set_parent() if table-bound column # objects are present ColumnCollectionMixin.__init__( @@ -4534,11 +4821,12 @@ class Index( _column_flag=_column_flag, _gather_expressions=self.expressions, ) - if table is not None: self._set_parent(table) - def _set_parent(self, table, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + table = parent + assert isinstance(table, Table) ColumnCollectionMixin._set_parent(self, table) if self.table is not None and table is not self.table: @@ -4553,12 +4841,18 @@ class Index( expressions = self.expressions col_expressions = self._col_expressions(table) assert len(expressions) == len(col_expressions) - self.expressions = [ - expr if isinstance(expr, ClauseElement) else colexpr - for expr, colexpr in zip(expressions, col_expressions) - ] - def create(self, bind, checkfirst=False): + exprs = [] + for expr, colexpr in zip(expressions, col_expressions): + if isinstance(expr, ClauseElement): + exprs.append(expr) + elif colexpr is not None: + exprs.append(colexpr) + else: + assert False + self.expressions = self._table_bound_expressions = exprs + + def create(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue a ``CREATE`` statement for this :class:`.Index`, using the given :class:`.Connection` or :class:`.Engine`` for connectivity. @@ -4569,9 +4863,8 @@ class Index( """ bind._run_ddl_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst) - return self - def drop(self, bind, checkfirst=False): + def drop(self, bind: _CreateDropBind, checkfirst: bool = False) -> None: """Issue a ``DROP`` statement for this :class:`.Index`, using the given :class:`.Connection` or :class:`.Engine` for connectivity. @@ -4583,7 +4876,9 @@ class Index( """ bind._run_ddl_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst) - def __repr__(self): + def __repr__(self) -> str: + exprs: _typing_Sequence[Any] + return "Index(%s)" % ( ", ".join( [repr(self.name)] @@ -4593,7 +4888,9 @@ class Index( ) -DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"}) +DEFAULT_NAMING_CONVENTION: util.immutabledict[str, str] = util.immutabledict( + {"ix": "ix_%(column_0_label)s"} +) class MetaData(HasSchemaAttr): @@ -4628,8 +4925,8 @@ class MetaData(HasSchemaAttr): schema: Optional[str] = None, quote_schema: Optional[bool] = None, naming_convention: Optional[Dict[str, str]] = None, - info: Optional[Dict[Any, Any]] = None, - ): + info: Optional[_InfoType] = None, + ) -> None: """Create a new MetaData object. :param schema: @@ -4758,7 +5055,7 @@ class MetaData(HasSchemaAttr): self._schemas: Set[str] = set() self._sequences: Dict[str, Sequence] = {} self._fk_memos: Dict[ - Tuple[str, str], List[ForeignKey] + Tuple[str, Optional[str]], List[ForeignKey] ] = collections.defaultdict(list) tables: util.FacadeDict[str, Table] @@ -4787,13 +5084,15 @@ class MetaData(HasSchemaAttr): table_or_key = table_or_key.key return table_or_key in self.tables - def _add_table(self, name, schema, table): + def _add_table( + self, name: str, schema: Optional[str], table: Table + ) -> None: key = _get_table_key(name, schema) self.tables._insert_item(key, table) if schema: self._schemas.add(schema) - def _remove_table(self, name, schema): + def _remove_table(self, name: str, schema: Optional[str]) -> None: key = _get_table_key(name, schema) removed = dict.pop(self.tables, key, None) # type: ignore if removed is not None: @@ -4808,7 +5107,7 @@ class MetaData(HasSchemaAttr): ] ) - def __getstate__(self): + def __getstate__(self) -> Dict[str, Any]: return { "tables": self.tables, "schema": self.schema, @@ -4818,7 +5117,7 @@ class MetaData(HasSchemaAttr): "naming_convention": self.naming_convention, } - def __setstate__(self, state): + def __setstate__(self, state: Dict[str, Any]) -> None: self.tables = state["tables"] self.schema = state["schema"] self.naming_convention = state["naming_convention"] @@ -5054,7 +5353,7 @@ class MetaData(HasSchemaAttr): def create_all( self, - bind: Union[Engine, Connection, MockConnection], + bind: _CreateDropBind, tables: Optional[_typing_Sequence[Table]] = None, checkfirst: bool = True, ) -> None: @@ -5082,7 +5381,7 @@ class MetaData(HasSchemaAttr): def drop_all( self, - bind: Union[Engine, Connection, MockConnection], + bind: _CreateDropBind, tables: Optional[_typing_Sequence[Table]] = None, checkfirst: bool = True, ) -> None: @@ -5134,10 +5433,14 @@ class Computed(FetchedValue, SchemaItem): __visit_name__ = "computed_column" + column: Optional[Column[Any]] + @_document_text_coercion( "sqltext", ":class:`.Computed`", ":paramref:`.Computed.sqltext`" ) - def __init__(self, sqltext, persisted=None): + def __init__( + self, sqltext: _DDLColumnArgument, persisted: Optional[bool] = None + ) -> None: """Construct a GENERATED ALWAYS AS DDL construct to accompany a :class:`_schema.Column`. @@ -5170,7 +5473,9 @@ class Computed(FetchedValue, SchemaItem): self.persisted = persisted self.column = None - def _set_parent(self, parent, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + assert isinstance(parent, Column) + if not isinstance( parent.server_default, (type(None), Computed) ) or not isinstance(parent.server_onupdate, (type(None), Computed)): @@ -5183,7 +5488,7 @@ class Computed(FetchedValue, SchemaItem): self.column.server_onupdate = self self.column.server_default = self - def _as_for_update(self, for_update): + def _as_for_update(self, for_update: bool) -> FetchedValue: return self @util.deprecated( @@ -5191,10 +5496,14 @@ class Computed(FetchedValue, SchemaItem): "The :meth:`_schema.Computed.copy` method is deprecated " "and will be removed in a future release.", ) - def copy(self, target_table=None, **kw): - return self._copy(target_table, **kw) + def copy( + self, *, target_table: Optional[Table] = None, **kw: Any + ) -> Computed: + return self._copy(target_table=target_table, **kw) - def _copy(self, target_table=None, **kw): + def _copy( + self, *, target_table: Optional[Table] = None, **kw: Any + ) -> Computed: sqltext = _copy_expression( self.sqltext, self.column.table if self.column is not None else None, @@ -5233,18 +5542,18 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): def __init__( self, - always=False, - on_null=None, - start=None, - increment=None, - minvalue=None, - maxvalue=None, - nominvalue=None, - nomaxvalue=None, - cycle=None, - cache=None, - order=None, - ): + always: bool = False, + on_null: Optional[bool] = None, + start: Optional[int] = None, + increment: Optional[int] = None, + minvalue: Optional[int] = None, + maxvalue: Optional[int] = None, + nominvalue: Optional[bool] = None, + nomaxvalue: Optional[bool] = None, + cycle: Optional[bool] = None, + cache: Optional[bool] = None, + order: Optional[bool] = None, + ) -> None: """Construct a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY DDL construct to accompany a :class:`_schema.Column`. @@ -5306,7 +5615,8 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): self.on_null = on_null self.column = None - def _set_parent(self, parent, **kw): + def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None: + assert isinstance(parent, Column) if not isinstance( parent.server_default, (type(None), Identity) ) or not isinstance(parent.server_onupdate, type(None)): @@ -5327,7 +5637,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): parent.server_default = self - def _as_for_update(self, for_update): + def _as_for_update(self, for_update: bool) -> FetchedValue: return self @util.deprecated( @@ -5335,10 +5645,10 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): "The :meth:`_schema.Identity.copy` method is deprecated " "and will be removed in a future release.", ) - def copy(self, **kw): + def copy(self, **kw: Any) -> Identity: return self._copy(**kw) - def _copy(self, **kw): + def _copy(self, **kw: Any) -> Identity: i = Identity( always=self.always, on_null=self.on_null, diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 99a6baa890..aab3c678c5 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -794,7 +794,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): col._make_proxy(fromclause) for col in self.c ) - @property + @util.ro_non_memoized_property def exported_columns(self) -> ReadOnlyColumnCollection[str, Any]: """A :class:`_expression.ColumnCollection` that represents the "exported" @@ -2779,22 +2779,6 @@ class Subquery(AliasedReturnsRows): def as_scalar(self): return self.element.set_label_style(LABEL_STYLE_NONE).scalar_subquery() - def _execute_on_connection( - self, - connection, - distilled_params, - execution_options, - ): - util.warn_deprecated( - "Executing a subquery object is deprecated and will raise " - "ObjectNotExecutableError in an upcoming release. Please " - "execute the underlying select() statement directly.", - "1.4", - ) - return self.element._execute_on_connection( - connection, distilled_params, execution_options, _force=True - ) - class FromGrouping(GroupedElement, FromClause): """Represent a grouping of a FROM clause""" diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 803e856548..72e658db89 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -17,6 +17,7 @@ import enum import json import pickle from typing import Any +from typing import Callable from typing import cast from typing import Dict from typing import List @@ -59,7 +60,10 @@ from ..util import OrderedDict from ..util.typing import Literal if TYPE_CHECKING: + from ._typing import _ColumnExpressionArgument + from ._typing import _TypeEngineArgument from .operators import OperatorType + from .schema import MetaData from .type_api import _BindProcessorType from .type_api import _ComparatorFactory from .type_api import _ResultProcessorType @@ -156,7 +160,7 @@ class Indexable(TypeEngineMixin): adjusted_op, adjusted_right_expr, result_type=result_type ) - comparator_factory = Comparator + comparator_factory: _ComparatorFactory[Any] = Comparator class String(Concatenable, TypeEngine[str]): @@ -178,8 +182,8 @@ class String(Concatenable, TypeEngine[str]): # for the _T type to be correctly recognized when we send the # class as the argument, e.g. `column("somecol", String)` self, - length=None, - collation=None, + length: Optional[int] = None, + collation: Optional[str] = None, ): """ Create a string-holding type. @@ -456,10 +460,10 @@ class Numeric(HasExpressionLookup, TypeEngine[_N]): def __init__( self, - precision=None, - scale=None, - decimal_return_scale=None, - asdecimal=True, + precision: Optional[int] = None, + scale: Optional[int] = None, + decimal_return_scale: Optional[int] = None, + asdecimal: bool = True, ): """ Construct a Numeric. @@ -733,7 +737,7 @@ class DateTime( __visit_name__ = "datetime" - def __init__(self, timezone=False): + def __init__(self, timezone: bool = False): """Construct a new :class:`.DateTime`. :param timezone: boolean. Indicates that the datetime type should @@ -818,7 +822,7 @@ class Time(_RenderISO8601NoT, HasExpressionLookup, TypeEngine[dt.time]): __visit_name__ = "time" - def __init__(self, timezone=False): + def __init__(self, timezone: bool = False): self.timezone = timezone def get_dbapi_type(self, dbapi): @@ -850,7 +854,7 @@ class _Binary(TypeEngine[bytes]): """Define base behavior for binary types.""" - def __init__(self, length=None): + def __init__(self, length: Optional[int] = None): self.length = length def literal_processor(self, dialect): @@ -919,7 +923,7 @@ class LargeBinary(_Binary): __visit_name__ = "large_binary" - def __init__(self, length=None): + def __init__(self, length: Optional[int] = None): """ Construct a LargeBinary type. @@ -961,12 +965,12 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): def __init__( self, - name=None, - schema=None, - metadata=None, - inherit_schema=False, - quote=None, - _create_events=True, + name: Optional[str] = None, + schema: Optional[str] = None, + metadata: Optional[MetaData] = None, + inherit_schema: bool = False, + quote: Optional[bool] = None, + _create_events: bool = True, ): if name is not None: self.name = quoted_name(name, quote) @@ -1144,7 +1148,9 @@ class SchemaType(SchemaEventTarget, TypeEngineMixin): # be integration tested by PG-specific tests def _we_are_the_impl(typ): return ( - typ is self or isinstance(typ, ARRAY) and typ.item_type is self + typ is self + or isinstance(typ, ARRAY) + and typ.item_type is self # type: ignore[comparison-overlap] ) if dialect.name in variant_mapping and _we_are_the_impl( @@ -1233,7 +1239,7 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): __visit_name__ = "enum" - def __init__(self, *enums, **kw): + def __init__(self, *enums: object, **kw: Any): r"""Construct an enum. Keyword arguments which don't apply to a specific backend are ignored @@ -1675,10 +1681,10 @@ class PickleType(TypeDecorator[object]): def __init__( self, - protocol=pickle.HIGHEST_PROTOCOL, - pickler=None, - comparator=None, - impl=None, + protocol: int = pickle.HIGHEST_PROTOCOL, + pickler: Any = None, + comparator: Optional[Callable[[Any, Any], bool]] = None, + impl: Optional[_TypeEngineArgument[Any]] = None, ): """ Construct a PickleType. @@ -1706,7 +1712,9 @@ class PickleType(TypeDecorator[object]): super(PickleType, self).__init__() if impl: - self.impl = to_instance(impl) + # custom impl is not necessarily a LargeBinary subclass. + # make an exception to typing for this + self.impl = to_instance(impl) # type: ignore def __reduce__(self): return PickleType, (self.protocol, None, self.comparator) @@ -1785,9 +1793,9 @@ class Boolean(SchemaType, Emulated, TypeEngine[bool]): def __init__( self, - create_constraint=False, - name=None, - _create_events=True, + create_constraint: bool = False, + name: Optional[str] = None, + _create_events: bool = True, ): """Construct a Boolean. @@ -1937,7 +1945,12 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): epoch = dt.datetime.utcfromtimestamp(0) cache_ok = True - def __init__(self, native=True, second_precision=None, day_precision=None): + def __init__( + self, + native: bool = True, + second_precision: Optional[int] = None, + day_precision: Optional[int] = None, + ): """Construct an Interval object. :param native: when True, use the actual @@ -2277,7 +2290,7 @@ class JSON(Indexable, TypeEngine[Any]): """ - def __init__(self, none_as_null=False): + def __init__(self, none_as_null: bool = False): """Construct a :class:`_types.JSON` type. :param none_as_null=False: if True, persist the value ``None`` as a @@ -2701,7 +2714,10 @@ class ARRAY( """If True, Python zero-based indexes should be interpreted as one-based on the SQL expression side.""" - class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]): + class Comparator( + Indexable.Comparator[Sequence[Any]], + Concatenable.Comparator[Sequence[Any]], + ): """Define comparison operations for :class:`_types.ARRAY`. @@ -2714,6 +2730,8 @@ class ARRAY( arr_type = cast(ARRAY, self.type) + return_type: TypeEngine[Any] + if isinstance(index, slice): return_type = arr_type if arr_type.zero_indexes: @@ -2832,7 +2850,11 @@ class ARRAY( comparator_factory = Comparator def __init__( - self, item_type, as_tuple=False, dimensions=None, zero_indexes=False + self, + item_type: _TypeEngineArgument[Any], + as_tuple: bool = False, + dimensions: Optional[int] = None, + zero_indexes: bool = False, ): """Construct an :class:`_types.ARRAY`. @@ -2910,7 +2932,7 @@ class TupleType(TypeEngine[Tuple[Any, ...]]): types: List[TypeEngine[Any]] - def __init__(self, *types): + def __init__(self, *types: _TypeEngineArgument[Any]): self._fully_typed = NULLTYPE not in types self.types = [ item_type() if isinstance(item_type, type) else item_type @@ -3070,7 +3092,7 @@ class TIMESTAMP(DateTime): __visit_name__ = "TIMESTAMP" - def __init__(self, timezone=False): + def __init__(self, timezone: bool = False): """Construct a new :class:`_types.TIMESTAMP`. :param timezone: boolean. Indicates that the TIMESTAMP type should @@ -3245,7 +3267,7 @@ class TableValueType(HasCacheKey, TypeEngine[Any]): ("_elements", InternalTraversal.dp_clauseelement_list), ] - def __init__(self, *elements): + def __init__(self, *elements: Union[str, _ColumnExpressionArgument[Any]]): self._elements = [ coercions.expect(roles.StrAsPlainColumnRole, elem) for elem in elements diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index 2790bf3735..2843431549 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -1037,6 +1037,9 @@ class ClauseAdapter(visitors.ReplacingExternalTraversal): if "adapt_column" in col._annotations: col = col._annotations["adapt_column"] + if TYPE_CHECKING: + assert isinstance(col, ColumnElement) + if self.adapt_from_selectables and col not in self.equivalents: for adp in self.adapt_from_selectables: if adp.c.corresponding_column(col, False) is not None: diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py index e6c00de097..0c539baab9 100644 --- a/lib/sqlalchemy/testing/assertsql.py +++ b/lib/sqlalchemy/testing/assertsql.py @@ -14,7 +14,7 @@ import re from .. import event from ..engine import url from ..engine.default import DefaultDialect -from ..schema import _DDLCompiles +from ..schema import BaseDDLElement class AssertRule: @@ -110,7 +110,7 @@ class CompiledSQL(SQLMatchRule): else: map_ = None - if isinstance(execute_observed.clauseelement, _DDLCompiles): + if isinstance(execute_observed.clauseelement, BaseDDLElement): compiled = execute_observed.clauseelement.compile( dialect=compare_dialect, diff --git a/lib/sqlalchemy/util/_collections.py b/lib/sqlalchemy/util/_collections.py index eea76f60b0..086b008de6 100644 --- a/lib/sqlalchemy/util/_collections.py +++ b/lib/sqlalchemy/util/_collections.py @@ -23,6 +23,7 @@ from typing import Iterable from typing import Iterator from typing import List from typing import Mapping +from typing import NoReturn from typing import Optional from typing import overload from typing import Set @@ -139,24 +140,24 @@ EMPTY_DICT: immutabledict[Any, Any] = immutabledict() class FacadeDict(ImmutableDictBase[_KT, _VT]): """A dictionary that is not publicly mutable.""" - def __new__(cls, *args): + def __new__(cls, *args: Any) -> FacadeDict[Any, Any]: new = dict.__new__(cls) return new - def copy(self): + def copy(self) -> NoReturn: raise NotImplementedError( "an immutabledict shouldn't need to be copied. use dict(d) " "if you need a mutable dictionary." ) - def __reduce__(self): + def __reduce__(self) -> Any: return FacadeDict, (dict(self),) - def _insert_item(self, key, value): + def _insert_item(self, key: _KT, value: _VT) -> None: """insert an item into the dictionary directly.""" dict.__setitem__(self, key, value) - def __repr__(self): + def __repr__(self) -> str: return "FacadeDict(%s)" % dict.__repr__(self) diff --git a/lib/sqlalchemy/util/_py_collections.py b/lib/sqlalchemy/util/_py_collections.py index 725f6930ee..40f5156718 100644 --- a/lib/sqlalchemy/util/_py_collections.py +++ b/lib/sqlalchemy/util/_py_collections.py @@ -136,7 +136,7 @@ class OrderedSet(Set[_T]): _list: List[_T] - def __init__(self, d=None): + def __init__(self, d: Optional[Iterable[_T]] = None) -> None: if d is not None: self._list = unique_list(d) super().update(self._list) diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py index 2cb9c45d6b..4e161c80c2 100644 --- a/lib/sqlalchemy/util/langhelpers.py +++ b/lib/sqlalchemy/util/langhelpers.py @@ -1677,7 +1677,7 @@ def parse_user_argument_for_enum( _creation_order = 1 -def set_creation_order(instance): +def set_creation_order(instance: Any) -> None: """Assign a '_creation_order' sequence to the given instance. This allows multiple instances to be sorted in order of creation diff --git a/lib/sqlalchemy/util/preloaded.py b/lib/sqlalchemy/util/preloaded.py index 907c510649..260250b2cf 100644 --- a/lib/sqlalchemy/util/preloaded.py +++ b/lib/sqlalchemy/util/preloaded.py @@ -26,6 +26,7 @@ if TYPE_CHECKING: from sqlalchemy.orm import session as orm_session from sqlalchemy.orm import util as orm_util from sqlalchemy.sql import dml as sql_dml + from sqlalchemy.sql import functions as sql_functions from sqlalchemy.sql import util as sql_util diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index dd574f3b0f..b3f3b93870 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -55,10 +55,12 @@ if typing.TYPE_CHECKING or compat.py38: from typing import Literal as Literal from typing import Protocol as Protocol from typing import TypedDict as TypedDict + from typing import Final as Final else: from typing_extensions import Literal as Literal # noqa: F401 from typing_extensions import Protocol as Protocol # noqa: F401 from typing_extensions import TypedDict as TypedDict # noqa: F401 + from typing_extensions import Final as Final # noqa: F401 # copied from TypeShed, required in order to implement # MutableMapping.update() diff --git a/pyproject.toml b/pyproject.toml index 46a4eb0d85..e727ee1e49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,7 +131,6 @@ module = [ "sqlalchemy.sql.lambdas", "sqlalchemy.sql.naming", "sqlalchemy.sql.selectable", # would be nice as strict - "sqlalchemy.sql.schema", # would be nice as strict "sqlalchemy.sql.sqltypes", # would be nice as strict "sqlalchemy.sql.traversals", "sqlalchemy.sql.util", diff --git a/test/dialect/oracle/test_dialect.py b/test/dialect/oracle/test_dialect.py index 60cdb2577a..ad361b8794 100644 --- a/test/dialect/oracle/test_dialect.py +++ b/test/dialect/oracle/test_dialect.py @@ -778,7 +778,7 @@ class ExecuteTest(fixtures.TestBase): seq = Sequence("foo_seq") seq.create(connection) try: - val = connection.execute(seq) + val = connection.scalar(seq) eq_(val, 1) assert type(val) is int finally: diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index b561db99e0..bd9ac0e3d6 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -53,7 +53,6 @@ from sqlalchemy.testing import is_false from sqlalchemy.testing import is_not from sqlalchemy.testing import is_true from sqlalchemy.testing import mock -from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table @@ -490,15 +489,6 @@ class ExecuteTest(fixtures.TablesTest): obj, ) - def test_subquery_exec_warning(self): - for obj in (select(1).alias(), select(1).subquery()): - with testing.db.connect() as conn: - with expect_deprecated( - "Executing a subquery object is deprecated and will " - "raise ObjectNotExecutableError" - ): - eq_(conn.execute(obj).scalar(), 1) - def test_stmt_exception_bytestring_raised(self): name = "méil" users = self.tables.users diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index ce38de5114..48c5e60abc 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -6,6 +6,7 @@ from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import select +from sqlalchemy import Sequence from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import update @@ -25,6 +26,7 @@ from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing import mock +from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertions import is_false from .test_engine_py3k import AsyncFixture as _AsyncFixture from ...orm import _fixtures @@ -68,6 +70,34 @@ class AsyncSessionTest(AsyncFixture): ss = AsyncSession(binds=binds) is_(ss.binds, binds) + @async_test + @testing.combinations((True,), (False,), argnames="use_scalar") + @testing.requires.sequences + async def test_sequence_execute( + self, async_session: AsyncSession, metadata, use_scalar + ): + seq = Sequence("some_sequence", metadata=metadata) + + sync_connection = (await async_session.connection()).sync_connection + + await (await async_session.connection()).run_sync(metadata.create_all) + + if use_scalar: + eq_( + await async_session.scalar(seq), + sync_connection.dialect.default_sequence_base, + ) + else: + with expect_deprecated( + r"Using the .execute\(\) method to invoke a " + r"DefaultGenerator object is deprecated; please use " + r"the .scalar\(\) method." + ): + eq_( + await async_session.execute(seq), + sync_connection.dialect.default_sequence_base, + ) + class AsyncSessionQueryTest(AsyncFixture): @async_test diff --git a/test/ext/mypy/plain_files/core_ddl.py b/test/ext/mypy/plain_files/core_ddl.py new file mode 100644 index 0000000000..673a90e943 --- /dev/null +++ b/test/ext/mypy/plain_files/core_ddl.py @@ -0,0 +1,43 @@ +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy import DateTime +from sqlalchemy import ForeignKey +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy import String +from sqlalchemy import Table + + +m = MetaData() + + +t1 = Table( + "t1", + m, + Column("id", Integer, primary_key=True), + Column("data", String), + Column("data2", String(50)), + Column("timestamp", DateTime()), + Index(None, "data2"), +) + +t2 = Table( + "t2", + m, + Column("t1id", ForeignKey("t1.id")), + Column("q", Integer, CheckConstraint("q > 5")), +) + +t3 = Table( + "t3", + m, + Column("x", Integer), + Column("y", Integer), + Column("t1id", ForeignKey(t1.c.id)), + PrimaryKeyConstraint("x", "y"), +) + +# cols w/ no name or type, used by declarative +c1: Column[int] = Column(ForeignKey(t3.c.x)) diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index 9967971220..7067d24c16 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -16,7 +16,7 @@ from sqlalchemy.ext.compiler import deregister from sqlalchemy.orm import Session from sqlalchemy.schema import CreateColumn from sqlalchemy.schema import CreateTable -from sqlalchemy.schema import DDLElement +from sqlalchemy.schema import ExecutableDDLElement from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.expression import BindParameter from sqlalchemy.sql.expression import ClauseElement @@ -275,10 +275,10 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): del Select._compiler_dispatcher def test_dialect_specific(self): - class AddThingy(DDLElement): + class AddThingy(ExecutableDDLElement): __visit_name__ = "add_thingy" - class DropThingy(DDLElement): + class DropThingy(ExecutableDDLElement): __visit_name__ = "drop_thingy" @compiles(AddThingy, "sqlite") diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 4cecac0de4..be1919614d 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -3,6 +3,7 @@ from unittest.mock import call from unittest.mock import Mock import sqlalchemy as sa +from sqlalchemy import bindparam from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import exc as sa_exc @@ -239,6 +240,37 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): ), ) + def test_override_parameters_scalar(self): + """test that session.scalar() maintains the 'scalar-ness' of the + result when using re-execute events. + + This got more complicated when the session.scalar(Sequence("my_seq")) + use case needed to keep working and returning a scalar result. + + """ + User = self.classes.User + + sess = Session(testing.db, future=True) + + @event.listens_for(sess, "do_orm_execute") + def one(ctx): + return ctx.invoke_statement(params={"id": 7}) + + orig_params = {"id": 18} + with self.sql_execution_asserter() as asserter: + result = sess.scalar( + select(User).where(User.id == bindparam("id")), orig_params + ) + asserter.assert_( + CompiledSQL( + "SELECT users.id, users.name FROM users WHERE users.id = :id", + [{"id": 7}], + ) + ) + eq_(result, User(id=7)) + # orig params weren't mutated + eq_(orig_params, {"id": 18}) + def test_override_parameters_executesingle(self): User = self.classes.User diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 8e568aef05..49f769f813 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -56,9 +56,10 @@ class ExecutionTest(_fixtures.FixtureTest): @testing.combinations( (True,), (False,), argnames="add_do_orm_execute_event" ) + @testing.combinations((True,), (False,), argnames="use_scalar") @testing.requires.sequences def test_sequence_execute( - self, connection, metadata, add_do_orm_execute_event + self, connection, metadata, add_do_orm_execute_event, use_scalar ): seq = Sequence("some_sequence", metadata=metadata) metadata.create_all(connection) @@ -69,7 +70,18 @@ class ExecutionTest(_fixtures.FixtureTest): event.listen( sess, "do_orm_execute", lambda ctx: evt(ctx.statement) ) - eq_(sess.execute(seq), connection.dialect.default_sequence_base) + + if use_scalar: + eq_(sess.scalar(seq), connection.dialect.default_sequence_base) + else: + with assertions.expect_deprecated( + r"Using the .execute\(\) method to invoke a " + r"DefaultGenerator object is deprecated; please use " + r"the .scalar\(\) method." + ): + eq_( + sess.execute(seq), connection.dialect.default_sequence_base + ) if add_do_orm_execute_event: eq_(evt.mock_calls, [mock.call(seq)]) @@ -1994,7 +2006,7 @@ class SessionInterface(fixtures.MappedTest): if name in blocklist: continue spec = inspect_getfullargspec(getattr(Session, name)) - if len(spec[0]) > 1 or spec[1]: + if len(spec.args) > 1 or spec.varargs or spec.kwonlyargs: ok.add(name) return ok @@ -2051,7 +2063,7 @@ class SessionInterface(fixtures.MappedTest): s = fixture_session() s.add(OK()) - x_raises_(s, "flush", (user_arg,)) + x_raises_(s, "flush", objects=(user_arg,)) _() diff --git a/test/profiles.txt b/test/profiles.txt index 7b4f377349..38b7290d67 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -204,7 +204,8 @@ test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results x86_64_linux_ # TEST: test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated -test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 30373,1014,96450 +# wow first time ever decreasing a value, woop. not sure why though +test.aaa_profiling.test_orm.JoinedEagerLoadTest.test_fetch_results_integrated x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 28587,1014,96450 # TEST: test.aaa_profiling.test_orm.LoadManyToOneFromIdentityTest.test_many_to_one_load_identity diff --git a/test/sql/test_constraints.py b/test/sql/test_constraints.py index 2661e6c8fd..c417255af6 100644 --- a/test/sql/test_constraints.py +++ b/test/sql/test_constraints.py @@ -863,6 +863,8 @@ class ConstraintCompilationTest(fixtures.TestBase, AssertsCompiledSQL): x.append_constraint(idx) self.assert_compile(schema.CreateIndex(idx), ddl) + x.to_metadata(MetaData()) + def test_index_against_text_separate(self): metadata = MetaData() idx = Index("y", text("some_function(q)")) diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index 52c7799695..3742aa174c 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -23,6 +23,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock +from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table from sqlalchemy.types import TypeDecorator @@ -503,15 +504,45 @@ class DefaultRoundTripTest(fixtures.TablesTest): def teardown_test(self): self.default_generator["x"] = 50 - def test_standalone(self, connection): + def test_standalone_via_exec_removed(self, connection): t = self.tables.default_test - x = connection.execute(t.c.col1.default) - y = connection.execute(t.c.col2.default) - z = connection.execute(t.c.col3.default) + + with expect_deprecated( + r"Using the .execute\(\) method to invoke a " + r"DefaultGenerator object is deprecated; please use " + r"the .scalar\(\) method." + ): + x = connection.execute(t.c.col1.default) + with expect_deprecated( + r"Using the .execute\(\) method to invoke a " + r"DefaultGenerator object is deprecated; please use " + r"the .scalar\(\) method." + ): + y = connection.execute(t.c.col2.default) + with expect_deprecated( + r"Using the .execute\(\) method to invoke a " + r"DefaultGenerator object is deprecated; please use " + r"the .scalar\(\) method." + ): + z = connection.execute(t.c.col3.default) + + def test_standalone_default_scalar(self, connection): + t = self.tables.default_test + x = connection.scalar(t.c.col1.default) + y = connection.scalar(t.c.col2.default) + z = connection.scalar(t.c.col3.default) assert 50 <= x <= 57 eq_(y, "imthedefault") eq_(z, self.f) + def test_standalone_function_execute(self, connection): + ctexec = connection.execute(self.currenttime) + assert isinstance(ctexec.scalar(), datetime.date) + + def test_standalone_function_scalar(self, connection): + ctexec = connection.scalar(self.currenttime) + assert isinstance(ctexec, datetime.date) + def test_insert(self, connection): t = self.tables.default_test diff --git a/test/sql/test_returning.py b/test/sql/test_returning.py index ffbab32237..bacdbaf3fb 100644 --- a/test/sql/test_returning.py +++ b/test/sql/test_returning.py @@ -543,7 +543,7 @@ class SequenceReturningTest(fixtures.TablesTest): ) eq_(r.first(), tuple([testing.db.dialect.default_sequence_base])) eq_( - connection.execute(self.sequences.tid_seq), + connection.scalar(self.sequences.tid_seq), testing.db.dialect.default_sequence_base + 1, ) diff --git a/test/sql/test_sequences.py b/test/sql/test_sequences.py index d11961862a..be74153cec 100644 --- a/test/sql/test_sequences.py +++ b/test/sql/test_sequences.py @@ -14,6 +14,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true +from sqlalchemy.testing.assertions import expect_deprecated from sqlalchemy.testing.assertsql import AllOf from sqlalchemy.testing.assertsql import CompiledSQL from sqlalchemy.testing.assertsql import EachOf @@ -117,14 +118,25 @@ class SequenceExecTest(fixtures.TestBase): def test_execute(self, connection): s = Sequence("my_sequence") - self._assert_seq_result(connection.execute(s)) + self._assert_seq_result(connection.scalar(s)) - def test_execute_optional(self, connection): + def test_execute_deprecated(self, connection): + + s = Sequence("my_sequence", optional=True) + + with expect_deprecated( + r"Using the .execute\(\) method to invoke a " + r"DefaultGenerator object is deprecated; please use " + r"the .scalar\(\) method." + ): + self._assert_seq_result(connection.execute(s)) + + def test_scalar_optional(self, connection): """test dialect executes a Sequence, returns nextval, whether or not "optional" is set""" s = Sequence("my_sequence", optional=True) - self._assert_seq_result(connection.execute(s)) + self._assert_seq_result(connection.scalar(s)) def test_execute_next_value(self, connection): """test func.next_value().execute()/.scalar() works @@ -341,7 +353,7 @@ class SequenceTest(fixtures.TestBase, testing.AssertsCompiledSQL): seq.create(testing.db) try: with testing.db.connect() as conn: - values = [conn.execute(seq) for i in range(3)] + values = [conn.scalar(seq) for i in range(3)] start = seq.start or testing.db.dialect.default_sequence_base inc = seq.increment or 1 eq_(values, list(range(start, start + inc * 3, inc)))