from __future__ import annotations
+from typing import Any
+
from . import util as _util
from .engine import AdaptedConnection as AdaptedConnection
from .engine import BaseRow as BaseRow
from .sql.expression import type_coerce as type_coerce
from .sql.expression import TypeClause as TypeClause
from .sql.expression import TypeCoerce as TypeCoerce
-from .sql.expression import typing as typing
from .sql.expression import UnaryExpression as UnaryExpression
from .sql.expression import union as union
from .sql.expression import union_all as union_all
__version__ = "2.0.0b1"
-def __go(lcls):
+def __go(lcls: Any) -> None:
from . import util as _sa_util
_sa_util.preloaded.import_prefix("sqlalchemy")
cdef list _list
+ @classmethod
+ def __class_getitem__(cls, key):
+ return cls
+
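+ # allows the class to be subscripted in annotations (e.g. SomeSet[int]);
+ # at runtime the subscript simply returns the class unchanged, and only
+ # static type checkers interpret the parameter.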
def __init__(self, d=None):
set.__init__(self)
if d is not None:
from ..sql.functions import FunctionElement
from ..sql.schema import ColumnDefault
from ..sql.schema import HasSchemaAttr
+ from ..sql.schema import SchemaItem
"""Defines :class:`_engine.Connection` and :class:`_engine.Engine`.
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
- element: DDLElement,
+ element: SchemaItem,
**kwargs: Any,
) -> None:
"""run a DDL visitor.
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
- element: DDLElement,
+ element: SchemaItem,
**kwargs: Any,
) -> None:
with self.begin() as conn:
from ..sql._typing import is_tuple_type
from ..sql.compiler import DDLCompiler
from ..sql.compiler import SQLCompiler
+from ..sql.elements import ColumnClause
from ..sql.elements import quoted_name
+from ..sql.schema import default_is_scalar
if typing.TYPE_CHECKING:
from types import ModuleType
return ()
@util.memoized_property
- def returning_cols(self) -> Optional[Sequence[Column[Any]]]:
+ def returning_cols(self) -> Optional[Sequence[ColumnClause[Any]]]:
if TYPE_CHECKING:
assert isinstance(self.compiled, SQLCompiler)
return self.compiled.returning
# to avoid many calls of get_insert_default()/
# get_update_default()
for c in insert_prefetch:
- if c.default and not c.default.is_sequence and c.default.is_scalar:
- if TYPE_CHECKING:
- assert isinstance(c.default, ColumnDefault)
+ if c.default and default_is_scalar(c.default):
scalar_defaults[c] = c.default.arg
for c in update_prefetch:
- if c.onupdate and c.onupdate.is_scalar:
- if TYPE_CHECKING:
- assert isinstance(c.onupdate, ColumnDefault)
+ if c.onupdate and default_is_scalar(c.onupdate):
scalar_defaults[c] = c.onupdate.arg
for param in self.compiled_parameters:
) = self.compiled_parameters[0]
for c in compiled.insert_prefetch:
- if c.default and not c.default.is_sequence and c.default.is_scalar:
- if TYPE_CHECKING:
- assert isinstance(c.default, ColumnDefault)
+ if c.default and default_is_scalar(c.default):
val = c.default.arg
else:
val = self.get_insert_default(c)
from ..sql.ddl import SchemaDropper
from ..sql.ddl import SchemaGenerator
from ..sql.schema import HasSchemaAttr
+ from ..sql.schema import SchemaItem
class MockConnection:
def _run_ddl_visitor(
self,
visitorcallable: Type[Union[SchemaGenerator, SchemaDropper]],
- element: DDLElement,
+ element: SchemaItem,
**kwargs: Any,
) -> None:
kwargs["checkfirst"] = False
with an already-given :class:`_schema.MetaData`.
"""
-
with self._operation_context() as conn:
tnames = self.dialect.get_table_names(
conn, schema, info_cache=self.info_cache
return cls(typ=typ, info=info, **data)
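+# e.g. name_is_dunder("__init__") is True; name_is_dunder("_private") is False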
-def name_is_dunder(name):
+def name_is_dunder(name: str) -> bool:
return bool(re.match(r"^__.+?__$", name))
else:
name = _qual_logger_name_for_cls(instance.__class__)
- instance._echo = echoflag
+ instance._echo = echoflag # type: ignore
logger: Union[logging.Logger, InstanceLogger]
# levels by calling logger._log()
logger = InstanceLogger(echoflag, name)
- instance.logger = logger
+ instance.logger = logger # type: ignore
class echo_property:
"""
- __hash__ = None
+ # https://github.com/python/mypy/issues/4266
+ __hash__ = None # type: ignore
@util.memoized_property
def clauses(self):
"the set of foreign key values."
)
- __hash__ = None
+ # https://github.com/python/mypy/issues/4266
+ __hash__ = None # type: ignore
def __eq__(self, other):
"""Implement the ``==`` operator.
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import Any
from .base import Executable as Executable
from .compiler import COLLECT_CARTESIAN_PRODUCTS as COLLECT_CARTESIAN_PRODUCTS
from .visitors import ClauseVisitor as ClauseVisitor
-def __go(lcls):
+def __go(lcls: Any) -> None:
from .. import util as _sa_util
from . import base
from __future__ import annotations
from typing import Any
-from typing import Union
+from typing import Optional
from . import coercions
from . import roles
from .selectable import TableClause
from .selectable import TableSample
from .selectable import Values
-from ..util.typing import _LiteralStar
-from ..util.typing import Literal
def alias(selectable, name=None, flat=False):
return Join(left, right, onclause, isouter=True, full=full)
-def select(
- *entities: Union[_LiteralStar, Literal[1], _ColumnsClauseElement]
-) -> "Select":
+def select(*entities: _ColumnsClauseElement) -> Select:
r"""Construct a new :class:`_expression.Select`.
return Select(*entities)
-def table(name: str, *columns: ColumnClause, **kw: Any) -> "TableClause":
+def table(name: str, *columns: ColumnClause[Any], **kw: Any) -> TableClause:
"""Produce a new :class:`_expression.TableClause`.
The object returned is an instance of
return CompoundSelect._create_union_all(*selects)
-def values(*columns, name=None, literal_binds=False) -> "Values":
+def values(
+ *columns: ColumnClause[Any],
+ name: Optional[str] = None,
+ literal_binds: bool = False,
+) -> Values:
r"""Construct a :class:`_expression.Values` construct.
The column expressions and the actual data for
from . import roles
from .. import util
from ..inspection import Inspectable
+from ..util.typing import Literal
if TYPE_CHECKING:
from .elements import quoted_name
_T = TypeVar("_T", bound=Any)
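+# the literals "*" and 1 are accepted here so that select("*") and select(1)
+# type-check as column-clause entities.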
_ColumnsClauseElement = Union[
+ Literal["*", 1],
roles.ColumnsClauseRole,
- Type,
+ Type[Any],
Inspectable[roles.HasColumnElementClauseElement],
]
_FromClauseElement = Union[
- roles.FromClauseRole, Type, Inspectable[roles.HasFromClauseElement]
+ roles.FromClauseRole, Type[Any], Inspectable[roles.HasFromClauseElement]
]
_ColumnExpression = Union[
from __future__ import annotations
-import collections.abc as collections_abc
from enum import Enum
from functools import reduce
import itertools
from itertools import zip_longest
import operator
import re
-import typing
from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
+from typing import Generic
from typing import Iterable
+from typing import Iterator
from typing import List
+from typing import Mapping
from typing import MutableMapping
+from typing import NoReturn
from typing import Optional
from typing import Sequence
+from typing import Set
from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
+from typing import Union
from . import roles
from . import visitors
from .traversals import HasCopyInternals # noqa
from .visitors import ClauseVisitor
from .visitors import ExtendedInternalTraversal
+from .visitors import ExternallyTraversible
from .visitors import InternalTraversal
+from .. import event
from .. import exc
from .. import util
from ..util import HasMemoized as HasMemoized
from ..util import hybridmethod
from ..util import typing as compat_typing
+from ..util.typing import Protocol
from ..util.typing import Self
+from ..util.typing import TypeGuard
-if typing.TYPE_CHECKING:
+if TYPE_CHECKING:
+ from . import coercions
+ from . import elements
+ from . import type_api
from .elements import BindParameter
+ from .elements import ColumnClause
from .elements import ColumnElement
+ from .elements import SQLCoreOperations
from ..engine import Connection
from ..engine import Result
from ..engine.base import _CompiledCacheType
from ..engine.interfaces import CacheStats
from ..engine.interfaces import Compiled
from ..engine.interfaces import Dialect
+ from ..event import dispatcher
-coercions = None
-elements = None
-type_api = None
+if not TYPE_CHECKING:
+ coercions = None # noqa
+ elements = None # noqa
+ type_api = None # noqa
class _NoArg(Enum):
NO_ARG = _NoArg.NO_ARG
-# if I use sqlalchemy.util.typing, which has the exact same
-# symbols, mypy reports: "error: _Fn? not callable"
-_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Fn = TypeVar("_Fn", bound=Callable[..., Any])
_AmbiguousTableNameMap = MutableMapping[str, str]
+class _EntityNamespace(Protocol):
+ def __getattr__(self, key: str) -> SQLCoreOperations[Any]:
+ ...
+
+
+class _HasEntityNamespace(Protocol):
+ entity_namespace: _EntityNamespace
+
+
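+# TypeGuard: a True return narrows ``element`` to _HasEntityNamespace for the
+# type checker, so ``element.entity_namespace`` access is considered valid.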
+def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]:
+ return hasattr(element, "entity_namespace")
+
+
class Immutable:
"""mark a ClauseElement as 'immutable' when expressions are cloned."""
def __new__(cls, *arg, **kw):
return cls._singleton
+ @util.non_memoized_property
+ def proxy_set(self) -> FrozenSet[ColumnElement[Any]]:
+ raise NotImplementedError()
+
@classmethod
def _create_singleton(cls):
obj = object.__new__(cls)
- obj.__init__()
+ obj.__init__() # type: ignore
# for a long time this was an empty frozenset, meaning
# a SingletonConstant would never be a "corresponding column" in
)
-_Self = typing.TypeVar("_Self", bound="_GenerativeType")
-_Args = compat_typing.ParamSpec("_Args")
+_SelfGenerativeType = TypeVar("_SelfGenerativeType", bound="_GenerativeType")
class _GenerativeType(compat_typing.Protocol):
- def _generate(self: "_Self") -> "_Self":
+ def _generate(self: _SelfGenerativeType) -> _SelfGenerativeType:
...
@util.decorator
def _generative(
- fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
- ) -> _Self:
+ fn: _Fn, self: _SelfGenerativeType, *args: Any, **kw: Any
+ ) -> _SelfGenerativeType:
"""Mark a method as generative."""
self = self._generate()
assert x is self, "generative methods must return self"
return self
- decorated = _generative(fn)
- decorated.non_generative = fn
- return decorated
+ decorated = _generative(fn) # type: ignore
+ decorated.non_generative = fn # type: ignore
+ return decorated # type: ignore
def _exclusive_against(*names, **kw):
)
-class _DialectArgView(collections_abc.MutableMapping):
+class _DialectArgView(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments in the form
<dialectname>_<argument_name>.
)
-class _DialectArgDict(collections_abc.MutableMapping):
+class _DialectArgDict(MutableMapping[str, Any]):
"""A dictionary view of dialect-level arguments for a specific
dialect.
"""
+ __slots__ = ()
+
_dialect_kwargs_traverse_internals = [
("dialect_options", InternalTraversal.dp_dialect_options)
]
__slots__ = ("statement", "_ambiguous_table_name_map")
- plugins = {}
+ plugins: Dict[Tuple[str, str], Type[CompileState]] = {}
_ambiguous_table_name_map: Optional[_AmbiguousTableNameMap]
class HasCompileState(Generative):
"""A class that has a :class:`.CompileState` associated with it."""
- _compile_state_plugin = None
+ _compile_state_plugin: Optional[Type[CompileState]] = None
- _attributes = util.immutabledict()
+ _attributes: util.immutabledict[str, Any] = util.EMPTY_DICT
_compile_state_factory = CompileState.create_for_statement
"""
+ _cache_attrs: Tuple[str, ...]
+
def __add__(self, other):
o1 = self()
__slots__ = ()
+ _cache_attrs: Tuple[str, ...]
+
def __init_subclass__(cls) -> None:
dict_ = cls.__dict__
cls._cache_attrs = tuple(
return self + {name: getattr(self, name) + value}
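+ # hybridmethod: attribute access on an instance uses the *_inst function
+ # below, while access on the class uses the @...classlevel variant; the
+ # rename gives each name a single definition for the type checker.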
@hybridmethod
- def _state_dict(self):
+ def _state_dict_inst(self) -> Mapping[str, Any]:
return self.__dict__
- _state_dict_const = util.immutabledict()
+ _state_dict_const: util.immutabledict[str, Any] = util.EMPTY_DICT
- @_state_dict.classlevel
- def _state_dict(cls):
+ @_state_dict_inst.classlevel
+ def _state_dict(cls) -> Mapping[str, Any]:
return cls._state_dict_const
@classmethod
__slots__ = ()
@hybridmethod
- def _gen_cache_key(self, anon_map, bindparams):
+ def _gen_cache_key_inst(self, anon_map, bindparams):
return HasCacheKey._gen_cache_key(self, anon_map, bindparams)
- @_gen_cache_key.classlevel
+ @_gen_cache_key_inst.classlevel
def _gen_cache_key(cls, anon_map, bindparams):
return (cls, ())
def _clone(self, **kw):
"""Create a shallow copy of this ExecutableOption."""
c = self.__class__.__new__(self.__class__)
- c.__dict__ = dict(self.__dict__)
+ c.__dict__ = dict(self.__dict__) # type: ignore
return c
-SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable")
+SelfExecutable = TypeVar("SelfExecutable", bound="Executable")
class Executable(roles.StatementRole, Generative):
"""
supports_execution: bool = True
- _execution_options: _ImmutableExecuteOptions = util.immutabledict()
- _with_options = ()
- _with_context_options = ()
+ _execution_options: _ImmutableExecuteOptions = util.EMPTY_DICT
+ _with_options: Tuple[ExecutableOption, ...] = ()
+ _with_context_options: Tuple[
+ Tuple[Callable[[CompileState], None], Any], ...
+ ] = ()
+ _compile_options: Optional[CacheableOptions]
_executable_traverse_internals = [
("_with_options", InternalTraversal.dp_executable_options),
is_delete = False
is_dml = False
- if typing.TYPE_CHECKING:
+ if TYPE_CHECKING:
+
+ __visit_name__: str
def _compile_w_cache(
self,
raise NotImplementedError()
@property
- def _effective_plugin_target(self):
+ def _effective_plugin_target(self) -> str:
return self.__visit_name__
@_generative
- def options(self: SelfExecutable, *options) -> SelfExecutable:
+ def options(
+ self: SelfExecutable, *options: ExecutableOption
+ ) -> SelfExecutable:
"""Apply options to this statement.
In the general sense, options are any kind of Python object
@_generative
def _set_compile_options(
- self: SelfExecutable, compile_options
+ self: SelfExecutable, compile_options: CacheableOptions
) -> SelfExecutable:
"""Assign the compile options to a new value.
@_generative
def _update_compile_options(
- self: SelfExecutable, options
+ self: SelfExecutable, options: CacheableOptions
) -> SelfExecutable:
"""update the _compile_options with new keys."""
+ assert self._compile_options is not None
self._compile_options += options
return self
@_generative
def _add_context_option(
- self: SelfExecutable, callable_, cache_args
+ self: SelfExecutable,
+ callable_: Callable[[CompileState], None],
+ cache_args: Any,
) -> SelfExecutable:
"""Add a context option to this statement.
return self
@_generative
- def execution_options(self: SelfExecutable, **kw) -> SelfExecutable:
+ def execution_options(self: SelfExecutable, **kw: Any) -> SelfExecutable:
"""Set non-SQL options for the statement which take effect during
execution.
self._execution_options = self._execution_options.union(kw)
return self
- def get_execution_options(self):
+ def get_execution_options(self) -> _ExecuteOptions:
"""Get the non-SQL options which will take effect during execution.
.. versionadded:: 1.3
return self._execution_options
-class SchemaEventTarget:
+class SchemaEventTarget(event.EventTarget):
"""Base class for elements that are the targets of :class:`.DDLEvents`
events.
"""
+ dispatch: dispatcher[SchemaEventTarget]
+
def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
"""Associate with this SchemaEvent's parent object."""
__traverse_options__ = {"schema_visitor": True}
-class ColumnCollection:
+_COL = TypeVar("_COL", bound="ColumnClause[Any]")
+
+
+class ColumnCollection(Generic[_COL]):
"""Collection of :class:`_expression.ColumnElement` instances,
typically for
:class:`_sql.FromClause` objects.
__slots__ = "_collection", "_index", "_colset"
- def __init__(self, columns=None):
+ _collection: List[Tuple[str, _COL]]
+ _index: Dict[Union[str, int], _COL]
+ _colset: Set[_COL]
+
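+ # _collection keeps (key, column) pairs in insertion order; _index maps
+ # string keys (and integer positions) to columns; _colset holds the column
+ # objects themselves for fast membership tests.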
+ def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None):
object.__setattr__(self, "_colset", set())
object.__setattr__(self, "_index", {})
object.__setattr__(self, "_collection", [])
if columns:
self._initial_populate(columns)
- def _initial_populate(self, iter_):
+ def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None:
self._populate_separate_keys(iter_)
@property
- def _all_columns(self):
+ def _all_columns(self) -> List[_COL]:
return [col for (k, col) in self._collection]
- def keys(self):
+ def keys(self) -> List[str]:
"""Return a sequence of string key names for all columns in this
collection."""
return [k for (k, col) in self._collection]
- def values(self):
+ def values(self) -> List[_COL]:
"""Return a sequence of :class:`_sql.ColumnClause` or
:class:`_schema.Column` objects for all columns in this
collection."""
return [col for (k, col) in self._collection]
- def items(self):
+ def items(self) -> List[Tuple[str, _COL]]:
"""Return a sequence of (key, column) tuples for all columns in this
collection each consisting of a string key name and a
:class:`_sql.ColumnClause` or
return list(self._collection)
- def __bool__(self):
+ def __bool__(self) -> bool:
return bool(self._collection)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._collection)
- def __iter__(self):
+ def __iter__(self) -> Iterator[_COL]:
# turn to a list first to maintain over a course of changes
return iter([col for k, col in self._collection])
- def __getitem__(self, key):
+ def __getitem__(self, key: Union[str, int]) -> _COL:
try:
return self._index[key]
except KeyError as err:
else:
raise
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> _COL:
try:
return self._index[key]
except KeyError as err:
raise AttributeError(key) from err
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
if key not in self._index:
if not isinstance(key, str):
raise exc.ArgumentError(
else:
return True
- def compare(self, other):
+ def compare(self, other: ColumnCollection[Any]) -> bool:
"""Compare this :class:`_expression.ColumnCollection` to another
based on the names of the keys"""
else:
return True
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
return self.compare(other)
- def get(self, key, default=None):
+ def get(self, key: str, default: Optional[_COL] = None) -> Optional[_COL]:
"""Get a :class:`_sql.ColumnClause` or :class:`_schema.Column` object
based on a string key name from this
:class:`_expression.ColumnCollection`."""
else:
return default
- def __str__(self):
+ def __str__(self) -> str:
return "%s(%s)" % (
self.__class__.__name__,
", ".join(str(c) for c in self),
)
- def __setitem__(self, key, value):
+ def __setitem__(self, key: str, value: Any) -> NoReturn:
raise NotImplementedError()
- def __delitem__(self, key):
+ def __delitem__(self, key: str) -> NoReturn:
raise NotImplementedError()
- def __setattr__(self, key, obj):
+ def __setattr__(self, key: str, obj: Any) -> NoReturn:
raise NotImplementedError()
- def clear(self):
+ def clear(self) -> NoReturn:
"""Dictionary clear() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
- def remove(self, column):
- """Dictionary remove() is not implemented for
- :class:`_sql.ColumnCollection`."""
+ def remove(self, column: Any) -> None:
raise NotImplementedError()
- def update(self, iter_):
+ def update(self, iter_: Any) -> NoReturn:
"""Dictionary update() is not implemented for
:class:`_sql.ColumnCollection`."""
raise NotImplementedError()
- __hash__ = None
+ # https://github.com/python/mypy/issues/4266
+ __hash__ = None # type: ignore
- def _populate_separate_keys(self, iter_):
+ def _populate_separate_keys(
+ self, iter_: Iterable[Tuple[str, _COL]]
+ ) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
self._collection[:] = cols
)
self._index.update({k: col for k, col in reversed(self._collection)})
- def add(self, column, key=None):
+ def add(self, column: _COL, key: Optional[str] = None) -> None:
"""Add a column to this :class:`_sql.ColumnCollection`.
.. note::
if key not in self._index:
self._index[key] = column
- def __getstate__(self):
+ def __getstate__(self) -> Dict[str, Any]:
return {"_collection": self._collection, "_index": self._index}
- def __setstate__(self, state):
+ def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "_index", state["_index"])
object.__setattr__(self, "_collection", state["_collection"])
object.__setattr__(
self, "_colset", {col for k, col in self._collection}
)
- def contains_column(self, col):
+ def contains_column(self, col: _COL) -> bool:
"""Checks if a column object exists in this collection"""
if col not in self._colset:
if isinstance(col, str):
else:
return True
- def as_immutable(self):
+ def as_immutable(self) -> ImmutableColumnCollection[_COL]:
"""Return an "immutable" form of this
:class:`_sql.ColumnCollection`."""
return ImmutableColumnCollection(self)
- def corresponding_column(self, column, require_embedded=False):
+ def corresponding_column(
+ self, column: _COL, require_embedded: bool = False
+ ) -> Optional[_COL]:
"""Given a :class:`_expression.ColumnElement`, return the exported
:class:`_expression.ColumnElement` object from this
:class:`_expression.ColumnCollection`
not require_embedded
or embedded(expanded_proxy_set, target_set)
):
- if col is None:
+ if col is None or intersect is None:
# no corresponding column yet, pick this one.
return col
-class DedupeColumnCollection(ColumnCollection):
+class DedupeColumnCollection(ColumnCollection[_COL]):
"""A :class:`_expression.ColumnCollection`
that maintains deduplicating behavior.
"""
- def add(self, column, key=None):
+ def add(self, column: _COL, key: Optional[str] = None) -> None:
if key is not None and column.key != key:
raise exc.ArgumentError(
self._index[l] = column
self._index[key] = column
- def _populate_separate_keys(self, iter_):
+ def _populate_separate_keys(
+ self, iter_: Iterable[Tuple[str, _COL]]
+ ) -> None:
"""populate from an iterator of (key, column)"""
cols = list(iter_)
for col in replace_col:
self.replace(col)
- def extend(self, iter_):
+ def extend(self, iter_: Iterable[_COL]) -> None:
self._populate_separate_keys((col.key, col) for col in iter_)
- def remove(self, column):
+ def remove(self, column: _COL) -> None:
if column not in self._colset:
raise ValueError(
"Can't remove column %r; column is not in this collection"
# delete higher index
del self._index[len(self._collection)]
- def replace(self, column):
+ def replace(self, column: _COL) -> None:
"""add the given column to this collection, removing unaliased
versions of this column as well as existing columns with the
same key.
self._index.update(self._collection)
-class ImmutableColumnCollection(util.ImmutableContainer, ColumnCollection):
+class ImmutableColumnCollection(
+ util.ImmutableContainer, ColumnCollection[_COL]
+):
__slots__ = ("_parent",)
def __init__(self, collection):
def __setstate__(self, state):
parent = state["_parent"]
- self.__init__(parent)
+ self.__init__(parent) # type: ignore
- add = extend = remove = util.ImmutableContainer._immutable
+ def add(self, column: Any, key: Any = ...) -> Any:
+ self._immutable()
+ def extend(self, elements: Any) -> None:
+ self._immutable()
-class ColumnSet(util.ordered_column_set):
+ def remove(self, item: Any) -> None:
+ self._immutable()
+
+
+class ColumnSet(util.OrderedSet["ColumnClause[Any]"]):
def contains_column(self, col):
return col in self
for col in cols:
self.add(col)
- def __add__(self, other):
- return list(self) + list(other)
-
def __eq__(self, other):
l = []
for c in other:
return hash(tuple(x for x in self))
-def _entity_namespace(entity):
+def _entity_namespace(
+ entity: Union[_HasEntityNamespace, ExternallyTraversible]
+) -> _EntityNamespace:
"""Return the nearest .entity_namespace for the given entity.
If not immediately available, does an iterate to find a sub-element
"""
try:
- return entity.entity_namespace
+ return cast(_HasEntityNamespace, entity).entity_namespace
except AttributeError:
- for elem in visitors.iterate(entity):
- if hasattr(elem, "entity_namespace"):
+ for elem in visitors.iterate(cast(ExternallyTraversible, entity)):
+ if _is_has_entity_namespace(elem):
return elem.entity_namespace
else:
raise
-def _entity_namespace_key(entity, key, default=NO_ARG):
+def _entity_namespace_key(
+ entity: Union[_HasEntityNamespace, ExternallyTraversible],
+ key: str,
+ default: Union[SQLCoreOperations[Any], _NoArg] = NO_ARG,
+) -> SQLCoreOperations[Any]:
"""Return an entry from an entity_namespace.
if default is not NO_ARG:
return getattr(ns, key, default)
else:
- return getattr(ns, key)
+ return getattr(ns, key) # type: ignore
except AttributeError as err:
raise exc.InvalidRequestError(
'Entity namespace for "%s" has no property "%s"' % (entity, key)
from .sqltypes import TupleType
from .type_api import TypeEngine
from .visitors import prefix_anon_map
+from .visitors import Visitable
from .. import exc
from .. import util
from ..util.typing import Literal
raise NotImplementedError()
- def process(self, obj, **kwargs):
+ def process(self, obj: Visitable, **kwargs: Any) -> str:
return obj._compiler_dispatch(self, **kwargs)
- def __str__(self):
+ def __str__(self) -> str:
"""Return the string text of the generated SQL or DDL."""
return self.string or ""
"""list of columns for which onupdate default values should be evaluated
before an UPDATE takes place"""
- returning: Optional[List[Column[Any]]]
+ returning: Optional[List[ColumnClause[Any]]]
"""list of columns that will be delivered to cursor.description or
dialect equivalent via the RETURNING clause on an INSERT, UPDATE, or DELETE
self._result_columns
)
- _key_getters_for_crud_column: Tuple[
- Callable[[Union[str, Column[Any]]], str],
- Callable[[Column[Any]], str],
- Callable[[Column[Any]], str],
- ]
+ # assigned by crud.py for insert/update statements
+ _get_bind_name_for_col: _BindNameForColProtocol
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
- getter = self._key_getters_for_crud_column[2]
+ getter = self._get_bind_name_for_col
if self.escaped_bind_names:
def _get(obj):
def for_update_clause(self, select, **kw):
return " FOR UPDATE"
- def returning_clause(self, stmt, returning_cols):
+ def returning_clause(
+ self, stmt: UpdateBase, returning_cols: List[ColumnClause[Any]]
+ ) -> str:
raise exc.CompileError(
"RETURNING is not supported by this "
"dialect's statement compiler."
}
)
- crud_params = crud._get_crud_params(
+ crud_params_struct = crud._get_crud_params(
self, insert_stmt, compile_state, **kw
)
+ crud_params_single = crud_params_struct.single_params
if (
- not crud_params
+ not crud_params_single
and not self.dialect.supports_default_values
and not self.dialect.supports_default_metavalue
and not self.dialect.supports_empty_insert
"version settings does not support "
"in-place multirow inserts." % self.dialect.name
)
- crud_params_single = crud_params[0]
+ crud_params_single = crud_params_struct.single_params
else:
- crud_params_single = crud_params
+ crud_params_single = crud_params_struct.single_params
preparer = self.preparer
supports_default_values = self.dialect.supports_default_values
if crud_params_single or not supports_default_values:
text += " (%s)" % ", ".join(
- [expr for c, expr, value in crud_params_single]
+ [expr for _, expr, _ in crud_params_single]
)
if self.returning or insert_stmt._returning:
)
else:
text += " %s" % select_text
- elif not crud_params and supports_default_values:
+ elif not crud_params_single and supports_default_values:
text += " DEFAULT VALUES"
elif compile_state._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
"(%s)"
- % (", ".join(value for c, expr, value in crud_param_set))
- for crud_param_set in crud_params
+ % (", ".join(value for _, _, value in crud_param_set))
+ for crud_param_set in crud_params_struct.all_multi_params
)
)
else:
insert_single_values_expr = ", ".join(
- [value for c, expr, value in crud_params]
+ [
+ value
+ for _, _, value in cast(
+ "List[Tuple[Any, Any, str]]", crud_params_single
+ )
+ ]
)
text += " VALUES (%s)" % insert_single_values_expr
if toplevel and insert_stmt._post_values_clause is None:
table_text = self.update_tables_clause(
update_stmt, update_stmt.table, render_extra_froms, **kw
)
- crud_params = crud._get_crud_params(
+ crud_params_struct = crud._get_crud_params(
self, update_stmt, compile_state, **kw
)
+ crud_params = crud_params_struct.single_params
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
text += table_text
text += " SET "
- text += ", ".join(expr + "=" + value for c, expr, value in crud_params)
+ text += ", ".join(
+ expr + "=" + value
+ for _, expr, value in cast(
+ "List[Tuple[Any, str, str]]", crud_params
+ )
+ )
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
...
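+# callback protocol for the bind-name getter that crud._get_crud_params()
+# assigns to the compiler as _get_bind_name_for_col.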
+class _BindNameForColProtocol(Protocol):
+ def __call__(self, col: ColumnClause[Any]) -> str:
+ ...
+
+
class IdentifierPreparer:
"""Handle quoting and case-folding of identifiers based on options."""
import functools
import operator
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import List
+from typing import MutableMapping
+from typing import NamedTuple
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import Union
from . import coercions
from . import dml
from . import elements
from . import roles
+from .schema import default_is_clause_element
+from .schema import default_is_sequence
from .. import exc
from .. import util
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+ from .compiler import _BindNameForColProtocol
+ from .compiler import SQLCompiler
+ from .dml import DMLState
+ from .dml import Insert
+ from .dml import Update
+ from .dml import UpdateDMLState
+ from .dml import ValuesBase
+ from .elements import ClauseElement
+ from .elements import ColumnClause
+ from .elements import ColumnElement
+ from .elements import TextClause
+ from .schema import _SQLExprDefault
+ from .schema import Column
+ from .selectable import TableClause
REQUIRED = util.symbol(
"REQUIRED",
)
-def _get_crud_params(compiler, stmt, compile_state, **kw):
+class _CrudParams(NamedTuple):
+ single_params: List[
+ Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ ]
+ all_multi_params: List[
+ List[
+ Tuple[
+ ColumnClause[Any],
+ str,
+ str,
+ ]
+ ]
+ ]
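+ # single_params: one (column, rendered column name, rendered value or SQL
+ # default) entry per column; all_multi_params: one such list per parameter
+ # set when a multi-VALUES INSERT is compiled, empty otherwise.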
+
+
+def _get_crud_params(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ compile_state: DMLState,
+ **kw: Any,
+) -> _CrudParams:
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
_column_as_key,
_getattr_col_key,
_col_bind_name,
- ) = getters = _key_getters_for_crud_column(compiler, stmt, compile_state)
+ ) = _key_getters_for_crud_column(compiler, stmt, compile_state)
- compiler._key_getters_for_crud_column = getters
+ compiler._get_bind_name_for_col = _col_bind_name
# no parameters in the statement, no parameters in the
# compiled params - return binds for all columns
if compiler.column_keys is None and compile_state._no_parameters:
- return [
- (
- c,
- compiler.preparer.format_column(c),
- _create_bind_param(compiler, c, None, required=True),
- )
- for c in stmt.table.columns
- ]
+ return _CrudParams(
+ [
+ (
+ c,
+ compiler.preparer.format_column(c),
+ _create_bind_param(compiler, c, None, required=True),
+ )
+ for c in stmt.table.columns
+ ],
+ [],
+ )
+
+ stmt_parameter_tuples: Optional[List[Any]]
+ spd: Optional[MutableMapping[str, Any]]
if compile_state._has_multi_parameters:
- spd = compile_state._multi_parameters[0]
+ mp = compile_state._multi_parameters
+ assert mp is not None
+ spd = mp[0]
stmt_parameter_tuples = list(spd.items())
elif compile_state._ordered_values:
spd = compile_state._dict_parameters
if compiler.column_keys is None:
parameters = {}
elif stmt_parameter_tuples:
+ assert spd is not None
parameters = dict(
(_column_as_key(key), REQUIRED)
for key in compiler.column_keys
)
# create a list of column assignment clauses as tuples
- values = []
+ values: List[
+ Tuple[ColumnClause[Any], str, Optional[Union[str, _SQLExprDefault]]]
+ ] = []
if stmt_parameter_tuples is not None:
_get_stmt_parameter_tuples_params(
kw,
)
- check_columns = {}
+ check_columns: Dict[str, ColumnClause[Any]] = {}
# special logic that only occurs for multi-table UPDATE
# statements
- if compile_state.isupdate and compile_state.is_multitable:
+ if dml.isupdate(compile_state) and compile_state.is_multitable:
_get_update_multitable_params(
compiler,
stmt,
)
if compile_state.isinsert and stmt._select_names:
+ # is an insert from select, is not a multiparams
+
+ assert not compile_state._has_multi_parameters
+
_scan_insert_from_select_cols(
compiler,
stmt,
)
if compile_state._has_multi_parameters:
- values = _extend_values_for_multiparams(
+ # is a multiparams, is not an insert from a select
+ assert not stmt._select_names
+ multi_extended_values = _extend_values_for_multiparams(
compiler,
stmt,
compile_state,
- values,
- _column_as_key,
+ cast("List[Tuple[ColumnClause[Any], str, str]]", values),
+ cast("Callable[..., str]", _column_as_key),
kw,
)
+ return _CrudParams(values, multi_extended_values)
elif (
not values
and compiler.for_executemany
)
]
- return values
+ return _CrudParams(values, [])
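+# the overloads below make the common call sites (process=True, the default)
+# type as returning the rendered bind text (str); the implementation also
+# accepts process=False, in which case the BindParameter itself is returned.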
+@overload
def _create_bind_param(
- compiler, col, value, process=True, required=False, name=None, **kw
-):
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ process: Literal[True] = ...,
+ required: bool = False,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_bind_param(
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ **kw: Any,
+) -> str:
+ ...
+
+
+def _create_bind_param(
+ compiler: SQLCompiler,
+ col: ColumnElement[Any],
+ value: Any,
+ process: bool = True,
+ required: bool = False,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[str, elements.BindParameter[Any]]:
if name is None:
name = col.key
bindparam = elements.BindParameter(
)
bindparam._is_crud = True
if process:
- bindparam = bindparam._compiler_dispatch(compiler, **kw)
- return bindparam
+ return bindparam._compiler_dispatch(compiler, **kw)
+ else:
+ return bindparam
def _handle_values_anonymous_param(compiler, col, value, name, **kw):
return value._compiler_dispatch(compiler, **kw)
-def _key_getters_for_crud_column(compiler, stmt, compile_state):
- if compile_state.isupdate and compile_state._extra_froms:
+def _key_getters_for_crud_column(
+ compiler: SQLCompiler, stmt: ValuesBase, compile_state: DMLState
+) -> Tuple[
+ Callable[[Union[str, Column[Any]]], Union[str, Tuple[str, str]]],
+ Callable[[Column[Any]], Union[str, Tuple[str, str]]],
+ _BindNameForColProtocol,
+]:
+ if dml.isupdate(compile_state) and compile_state._extra_froms:
# when extra tables are present, refer to the columns
# in those extra tables as table-qualified, including in
# dictionaries and when rendering bind param names.
coercions.expect_as_key, roles.DMLColumnRole
)
- def _column_as_key(key):
+ def _column_as_key(
+ key: Union[ColumnClause[Any], str]
+ ) -> Union[str, Tuple[str, str]]:
str_key = c_key_role(key)
- if hasattr(key, "table") and key.table in _et:
- return (key.table.name, str_key)
+ if hasattr(key, "table") and key.table in _et: # type: ignore
+ return (key.table.name, str_key) # type: ignore
else:
- return str_key
+ return str_key # type: ignore
- def _getattr_col_key(col):
+ def _getattr_col_key(
+ col: ColumnClause[Any],
+ ) -> Union[str, Tuple[str, str]]:
if col.table in _et:
- return (col.table.name, col.key)
+ return (col.table.name, col.key) # type: ignore
else:
return col.key
- def _col_bind_name(col):
+ def _col_bind_name(col: ColumnClause[Any]) -> str:
if col.table in _et:
+ if TYPE_CHECKING:
+ assert isinstance(col.table, TableClause)
return "%s_%s" % (col.table.name, col.key)
else:
return col.key
else:
- _column_as_key = functools.partial(
+ _column_as_key = functools.partial( # type: ignore
coercions.expect_as_key, roles.DMLColumnRole
)
- _getattr_col_key = _col_bind_name = operator.attrgetter("key")
+ _getattr_col_key = _col_bind_name = operator.attrgetter("key") # type: ignore # noqa: E501
return _column_as_key, _getattr_col_key, _col_bind_name
compiler.stack[-1]["insert_from_select"] = stmt.select
- add_select_cols = []
+ add_select_cols: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]] = []
if stmt.include_insert_from_select_defaults:
col_set = set(cols)
for col in stmt.table.columns:
)
-def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ c: ColumnClause[Any],
+ values: List[Tuple[ColumnClause[Any], str, _SQLExprDefault]],
+ kw: Dict[str, Any],
+) -> None:
- if c.default.is_sequence:
+ if default_is_sequence(c.default):
if compiler.dialect.supports_sequences and (
not c.default.optional or not compiler.dialect.sequences_optional
):
values.append(
(c, compiler.preparer.format_column(c), c.default.next_value())
)
- elif c.default.is_clause_element:
+ elif default_is_clause_element(c.default):
values.append(
(c, compiler.preparer.format_column(c), c.default.arg.self_group())
)
compiler.returning.append(c)
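+# Literal[True]/Literal[False] overloads: process=True (the default) yields
+# the rendered bind string, process=False yields the BindParameter object.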
+@overload
def _create_insert_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[True] = ...,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_insert_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[False],
+ **kw: Any,
+) -> elements.BindParameter[Any]:
+ ...
+
+
+def _create_insert_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: bool = True,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
- compiler.insert_prefetch.append(c)
+ compiler.insert_prefetch.append(c) # type: ignore
return param
+@overload
def _create_update_prefetch_bind_param(
- compiler, c, process=True, name=None, **kw
-):
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[True] = ...,
+ **kw: Any,
+) -> str:
+ ...
+
+
+@overload
+def _create_update_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: Literal[False],
+ **kw: Any,
+) -> elements.BindParameter[Any]:
+ ...
+
+
+def _create_update_prefetch_bind_param(
+ compiler: SQLCompiler,
+ c: ColumnElement[Any],
+ process: bool = True,
+ name: Optional[str] = None,
+ **kw: Any,
+) -> Union[elements.BindParameter[Any], str]:
param = _create_bind_param(
compiler, c, None, process=process, name=name, **kw
)
- compiler.update_prefetch.append(c)
+ compiler.update_prefetch.append(c) # type: ignore
return param
-class _multiparam_column(elements.ColumnElement):
+class _multiparam_column(elements.ColumnElement[Any]):
_is_multiparam_column = True
def __init__(self, original, index):
)
-def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
+def _process_multiparam_default_bind(
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ c: ColumnClause[Any],
+ index: int,
+ kw: Dict[str, Any],
+) -> str:
if not c.default:
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
"a Python-side value or SQL expression is required" % c
)
- elif c.default.is_clause_element:
+ elif default_is_clause_element(c.default):
return compiler.process(c.default.arg.self_group(), **kw)
elif c.default.is_sequence:
# these conditions would have been established
else:
col = _multiparam_column(c, index)
if isinstance(stmt, dml.Insert):
- return _create_insert_prefetch_bind_param(compiler, col, **kw)
+ return _create_insert_prefetch_bind_param(
+ compiler, col, process=True, **kw
+ )
else:
- return _create_update_prefetch_bind_param(compiler, col, **kw)
+ return _create_update_prefetch_bind_param(
+ compiler, col, process=True, **kw
+ )
def _get_update_multitable_params(
def _extend_values_for_multiparams(
- compiler,
- stmt,
- compile_state,
- values,
- _column_as_key,
- kw,
-):
- values_0 = values
- values = [values]
-
- for i, row in enumerate(compile_state._multi_parameters[1:]):
- extension = []
+ compiler: SQLCompiler,
+ stmt: ValuesBase,
+ compile_state: DMLState,
+ initial_values: List[Tuple[ColumnClause[Any], str, str]],
+ _column_as_key: Callable[..., str],
+ kw: Dict[str, Any],
+) -> List[List[Tuple[ColumnClause[Any], str, str]]]:
+ values_0 = initial_values
+ values = [initial_values]
+
+ mp = compile_state._multi_parameters
+ assert mp is not None
+ for i, row in enumerate(mp[1:]):
+ extension: List[
+ Tuple[
+ ColumnClause[Any],
+ str,
+ str,
+ ]
+ ] = []
row = {_column_as_key(key): v for key, v in row.items()}
from . import type_api
from .elements import and_
from .elements import BinaryExpression
+from .elements import ClauseElement
from .elements import ClauseList
from .elements import CollationClause
from .elements import CollectionAggregate
if typing.TYPE_CHECKING:
from .elements import ColumnElement
from .operators import custom_op
- from .sqltypes import TypeEngine
+ from .type_api import TypeEngine
def _boolean_compare(
*,
negate_op: Optional[OperatorType] = None,
reverse: bool = False,
- _python_is_types=(util.NoneType, bool),
- _any_all_expr=False,
+ _python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
+ _any_all_expr: bool = False,
result_type: Optional[
- Union[Type["TypeEngine[bool]"], "TypeEngine[bool]"]
+ Union[Type[TypeEngine[bool]], TypeEngine[bool]]
] = None,
**kwargs: Any,
) -> BinaryExpression[bool]:
def _binary_operate(
expr: ColumnElement[Any],
op: OperatorType,
- obj: roles.BinaryElementRole,
+ obj: roles.BinaryElementRole[Any],
*,
reverse: bool = False,
result_type: Optional[
def _conjunction_operate(
- expr: ColumnElement[Any], op: OperatorType, other, **kw
+ expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
if op is operators.and_:
return and_(expr, other)
def _scalar(
- expr: ColumnElement[Any], op: OperatorType, fn, **kw
+ expr: ColumnElement[Any],
+ op: OperatorType,
+ fn: Callable[[ColumnElement[Any]], ColumnElement[Any]],
+ **kw: Any,
) -> ColumnElement[Any]:
return fn(expr)
def _in_impl(
expr: ColumnElement[Any],
op: OperatorType,
- seq_or_selectable,
+ seq_or_selectable: ClauseElement,
negate_op: OperatorType,
- **kw,
+ **kw: Any,
) -> ColumnElement[Any]:
seq_or_selectable = coercions.expect(
roles.InElementRole, seq_or_selectable, expr=expr, operator=op
def _getitem_impl(
- expr: ColumnElement[Any], op: OperatorType, other, **kw
+ expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
if isinstance(expr.type, type_api.INDEXABLE):
other = coercions.expect(
def _unsupported_impl(
- expr: ColumnElement[Any], op: OperatorType, *arg, **kw
+ expr: ColumnElement[Any], op: OperatorType, *arg: Any, **kw: Any
) -> NoReturn:
raise NotImplementedError(
"Operator '%s' is not supported on " "this expression" % op.__name__
def _inv_impl(
- expr: ColumnElement[Any], op: OperatorType, **kw
+ expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__inv__`."""
def _neg_impl(
- expr: ColumnElement[Any], op: OperatorType, **kw
+ expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.__neg__`."""
return UnaryExpression(expr, operator=operators.neg, type_=expr.type)
def _match_impl(
- expr: ColumnElement[Any], op: OperatorType, other, **kw
+ expr: ColumnElement[Any], op: OperatorType, other: Any, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.match`."""
def _distinct_impl(
- expr: ColumnElement[Any], op: OperatorType, **kw
+ expr: ColumnElement[Any], op: OperatorType, **kw: Any
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.distinct`."""
return UnaryExpression(
def _between_impl(
- expr: ColumnElement[Any], op: OperatorType, cleft, cright, **kw
+ expr: ColumnElement[Any],
+ op: OperatorType,
+ cleft: Any,
+ cright: Any,
+ **kw: Any,
) -> ColumnElement[Any]:
"""See :meth:`.ColumnOperators.between`."""
return BinaryExpression(
def _collate_impl(
- expr: ColumnElement[Any], op: OperatorType, collation, **kw
-) -> ColumnElement[Any]:
+ expr: ColumnElement[str], op: OperatorType, collation: str, **kw: Any
+) -> ColumnElement[str]:
return CollationClause._create_collation_expression(expr, collation)
def _regexp_match_impl(
- expr: ColumnElement[Any], op: OperatorType, pattern, flags, **kw
+ expr: ColumnElement[str],
+ op: OperatorType,
+ pattern: Any,
+ flags: Optional[str],
+ **kw: Any,
) -> ColumnElement[Any]:
if flags is not None:
- flags = coercions.expect(
+ flags_expr = coercions.expect(
roles.BinaryElementRole,
flags,
expr=expr,
operator=operators.regexp_replace_op,
)
+ else:
+ flags_expr = None
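+ # a separate flags_expr variable keeps the Optional[str] ``flags`` argument
+ # from being re-bound to a coerced ColumnElement.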
return _boolean_compare(
expr,
op,
pattern,
- flags=flags,
+ flags=flags_expr,
negate_op=operators.not_regexp_match_op
if op is operators.regexp_match_op
else operators.regexp_match_op,
def _regexp_replace_impl(
expr: ColumnElement[Any],
op: OperatorType,
- pattern,
- replacement,
- flags,
- **kw,
+ pattern: Any,
+ replacement: Any,
+ flags: Optional[str],
+ **kw: Any,
) -> ColumnElement[Any]:
replacement = coercions.expect(
roles.BinaryElementRole,
operator=operators.regexp_replace_op,
)
if flags is not None:
- flags = coercions.expect(
+ flags_expr = coercions.expect(
roles.BinaryElementRole,
flags,
expr=expr,
operator=operators.regexp_replace_op,
)
+ else:
+ flags_expr = None
return _binary_operate(
- expr, op, pattern, replacement=replacement, flags=flags, **kw
+ expr, op, pattern, replacement=replacement, flags=flags_expr, **kw
)
# a mapping of operators with the method they use, along with
# additional keyword arguments to be passed
operator_lookup: Dict[
- str, Tuple[Callable[..., ColumnElement[Any]], util.immutabledict]
+ str,
+ Tuple[
+ Callable[..., ColumnElement[Any]],
+ util.immutabledict[
+ str, Union[OperatorType, Callable[..., ColumnElement[Any]]]
+ ],
+ ],
] = {
"and_": (_conjunction_operate, util.EMPTY_DICT),
"or_": (_conjunction_operate, util.EMPTY_DICT),
from __future__ import annotations
import collections.abc as collections_abc
+import operator
import typing
from typing import Any
from typing import List
from typing import MutableMapping
from typing import Optional
+from typing import TYPE_CHECKING
from . import coercions
from . import roles
from .selectable import HasCTE
from .selectable import HasPrefixes
from .selectable import ReturnsRows
+from .selectable import TableClause
from .sqltypes import NullType
from .visitors import InternalTraversal
from .. import exc
from .. import util
+from ..util.typing import TypeGuard
+
+
+if TYPE_CHECKING:
+
+ def isupdate(dml: DMLState) -> TypeGuard[UpdateDMLState]:
+ ...
+
+ def isdelete(dml: DMLState) -> TypeGuard[DeleteDMLState]:
+ ...
+
+ def isinsert(dml: DMLState) -> TypeGuard[InsertDMLState]:
+ ...
+
+else:
+ isupdate = operator.attrgetter("isupdate")
+ isdelete = operator.attrgetter("isdelete")
+ isinsert = operator.attrgetter("isinsert")
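+# at runtime these helpers simply read the isupdate/isdelete/isinsert flags
+# defined on DMLState below; under TYPE_CHECKING they are declared as
+# TypeGuards, so e.g. "if isupdate(compile_state):" narrows the state's type.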
class DMLState(CompileState):
_ordered_values = None
_parameter_ordering = None
_has_multi_parameters = False
+
isupdate = False
isdelete = False
isinsert = False
_hints = util.immutabledict()
named_with_column = False
+ table: TableClause
+
_return_defaults = False
_return_defaults_columns = None
_returning = ()
import operator
import re
import typing
+from typing import AbstractSet
from typing import Any
from typing import Callable
from typing import cast
from .operators import OperatorType
from .schema import Column
from .schema import DefaultGenerator
+ from .schema import FetchedValue
from .schema import ForeignKey
from .selectable import FromClause
from .selectable import NamedFromClause
"""
- @util.memoized_property
+ @util.ro_memoized_property
def description(self) -> Optional[str]:
return None
_cache_key_traversal = None
- negation_clause: ClauseElement
+ negation_clause: ColumnElement[bool]
if typing.TYPE_CHECKING:
primary_key: bool = False
_is_clone_of: Optional[ColumnElement[_T]]
- @util.memoized_property
- def foreign_keys(self) -> Iterable[ForeignKey]:
- return []
+ foreign_keys: AbstractSet[ForeignKey] = frozenset()
@util.memoized_property
def _proxies(self) -> List[ColumnElement[Any]]:
else:
key = name
+ assert key is not None
+
co: ColumnClause[_T] = ColumnClause(
coercions.expect(roles.TruncatedLabelRole, name)
if name_is_truncatable
co._proxies = [self]
if selectable._is_clone_of is not None:
co._is_clone_of = selectable._is_clone_of.columns.get(key)
- assert key is not None
return key, co
def cast(self, type_: TypeEngine[_T]) -> Cast[_T]:
is_literal = False
table: Optional[FromClause] = None
name: str
+ key: str
def _compare_name_for_result(self, other):
return (hasattr(other, "name") and self.name == other.name) or (
hasattr(other, "_label") and self._label == other._label
)
- @util.memoized_property
+ @util.ro_memoized_property
def description(self) -> str:
return self.name
_selectable=selectable,
is_literal=False,
)
+
c._propagate_attrs = selectable._propagate_attrs
if name is None:
c.key = self.key
onupdate: Optional[DefaultGenerator] = None
default: Optional[DefaultGenerator] = None
- server_default: Optional[DefaultGenerator] = None
- server_onupdate: Optional[DefaultGenerator] = None
+ server_default: Optional[FetchedValue] = None
+ server_onupdate: Optional[FetchedValue] = None
_is_multiparam_column = False
from __future__ import annotations
+from typing import Any
+from typing import TYPE_CHECKING
+
from .base import SchemaEventTarget
from .. import event
+if TYPE_CHECKING:
+ from .schema import Column
+ from .schema import Constraint
+ from .schema import SchemaItem
+ from .schema import Table
+ from ..engine.base import Connection
+ from ..engine.interfaces import ReflectedColumn
+ from ..engine.reflection import Inspector
+
-class DDLEvents(event.Events):
+class DDLEvents(event.Events[SchemaEventTarget]):
"""
Define event listeners for schema objects,
that is, :class:`.SchemaItem` and other :class:`.SchemaEventTarget`
_target_class_doc = "SomeSchemaClassOrObject"
_dispatch_target = SchemaEventTarget
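+ # e.g. event.listen(some_table, "before_create", fn) (placeholder names)
+ # will invoke fn(some_table, connection, **kw) just before the CREATE DDL
+ # is emitted.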
- def before_create(self, target, connection, **kw):
+ def before_create(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called before CREATE statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
"""
- def after_create(self, target, connection, **kw):
+ def after_create(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called after CREATE statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
"""
- def before_drop(self, target, connection, **kw):
+ def before_drop(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called before DROP statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
"""
- def after_drop(self, target, connection, **kw):
+ def after_drop(
+ self, target: SchemaEventTarget, connection: Connection, **kw: Any
+ ) -> None:
r"""Called after DROP statements are emitted.
:param target: the :class:`_schema.MetaData` or :class:`_schema.Table`
"""
- def before_parent_attach(self, target, parent):
+ def before_parent_attach(
+ self, target: SchemaEventTarget, parent: SchemaItem
+ ) -> None:
"""Called before a :class:`.SchemaItem` is associated with
a parent :class:`.SchemaItem`.
"""
- def after_parent_attach(self, target, parent):
+ def after_parent_attach(
+ self, target: SchemaEventTarget, parent: SchemaItem
+ ) -> None:
"""Called after a :class:`.SchemaItem` is associated with
a parent :class:`.SchemaItem`.
"""
- def _sa_event_column_added_to_pk_constraint(self, const, col):
+ def _sa_event_column_added_to_pk_constraint(
+ self, const: Constraint, col: Column[Any]
+ ) -> None:
"""internal event hook used for primary key naming convention
updates.
"""
- def column_reflect(self, inspector, table, column_info):
+ def column_reflect(
+ self, inspector: Inspector, table: Table, column_info: ReflectedColumn
+ ) -> None:
"""Called for each unit of 'column info' retrieved when
a :class:`_schema.Table` is being reflected.
from ._elements_constructors import true as true
from ._elements_constructors import tuple_ as tuple_
from ._elements_constructors import type_coerce as type_coerce
-from ._elements_constructors import typing as typing
from ._elements_constructors import within_group as within_group
from ._selectable_constructors import alias as alias
from ._selectable_constructors import cte as cte
__slots__ = ()
# does not allow text() or select() objects
- c: ColumnCollection
+ c: ColumnCollection[Any]
- @property
+ # this should simply be a plain ``-> str`` @property;
+ # ro_non_memoized_property is used to work around:
+ # https://github.com/python/mypy/issues/12440
+ @util.ro_non_memoized_property
def description(self) -> str:
raise NotImplementedError()
"""
from __future__ import annotations
+from abc import ABC
import collections
+import operator
import typing
from typing import Any
+from typing import Callable
from typing import Dict
from typing import List
from typing import MutableMapping
from typing import Optional
from typing import overload
from typing import Sequence as _typing_Sequence
+from typing import Set
+from typing import Tuple
from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from . import roles
from . import type_api
from . import visitors
+from .base import ColumnCollection
from .base import DedupeColumnCollection
from .base import DialectKWArgs
from .base import Executable
from .. import inspection
from .. import util
from ..util.typing import Literal
+from ..util.typing import Protocol
+from ..util.typing import TypeGuard
if typing.TYPE_CHECKING:
from .type_api import TypeEngine
from ..engine import Connection
from ..engine import Engine
-
+ from ..engine.interfaces import ExecutionContext
+ from ..engine.mock import MockConnection
_T = TypeVar("_T", bound="Any")
_ServerDefaultType = Union["FetchedValue", str, TextClause, ColumnElement]
_TAB = TypeVar("_TAB", bound="Table")
)
-def _get_table_key(name, schema):
+def _get_table_key(name: str, schema: Optional[str]) -> str:
if schema is None:
return name
else:
__visit_name__ = "table"
- constraints = None
+ constraints: Set[Constraint]
"""A collection of all :class:`_schema.Constraint` objects associated with
this :class:`_schema.Table`.
"""
- indexes = None
+ indexes: Set[Index]
"""A collection of all :class:`_schema.Index` objects associated with this
:class:`_schema.Table`.
("schema", InternalTraversal.dp_string)
]
+ if TYPE_CHECKING:
+
+ @util.non_memoized_property
+ def columns(self) -> ColumnCollection[Column[Any]]:
+ ...
+
+ c: ColumnCollection[Column[Any]]
+
def _gen_cache_key(self, anon_map, bindparams):
if self._annotations:
return (self,) + self._annotations_cache_key
)
@property
- def _sorted_constraints(self):
+ def _sorted_constraints(self) -> List[Constraint]:
"""Return the set of constraints as a list, sorted by creation
order.
"""
+
return sorted(self.constraints, key=lambda c: c._creation_order)
@property
)
self.info = kwargs.pop("info", self.info)
+ exclude_columns: _typing_Sequence[str]
+
if autoload:
if not autoload_replace:
# don't replace columns already present.
return metadata.tables[key]
args = []
- for c in self.columns:
- args.append(c._copy(schema=schema))
+ for col in self.columns:
+ args.append(col._copy(schema=schema))
table = Table(
name,
metadata,
*args,
**self.kwargs,
)
- for c in self.constraints:
- if isinstance(c, ForeignKeyConstraint):
- referred_schema = c._referred_schema
+ for const in self.constraints:
+ if isinstance(const, ForeignKeyConstraint):
+ referred_schema = const._referred_schema
if referred_schema_fn:
fk_constraint_schema = referred_schema_fn(
- self, schema, c, referred_schema
+ self, schema, const, referred_schema
)
else:
fk_constraint_schema = (
schema if referred_schema == self.schema else None
)
table.append_constraint(
- c._copy(schema=fk_constraint_schema, target_table=table)
+ const._copy(
+ schema=fk_constraint_schema, target_table=table
+ )
)
- elif not c._type_bound:
+ elif not const._type_bound:
# skip unique constraints that would be generated
# by the 'unique' flag on Column
- if c._column_flag:
+ if const._column_flag:
continue
table.append_constraint(
- c._copy(schema=schema, target_table=table)
+ const._copy(schema=schema, target_table=table)
)
for index in self.indexes:
# skip indexes that would be generated
name = kwargs.pop("name", None)
type_ = kwargs.pop("type_", None)
- args = list(args)
- if args:
- if isinstance(args[0], str):
+ l_args = list(args)
+ del args
+
+ if l_args:
+ if isinstance(l_args[0], str):
if name is not None:
raise exc.ArgumentError(
"May not pass name positionally and as a keyword."
)
- name = args.pop(0)
- if args:
- coltype = args[0]
+ name = l_args.pop(0)
+ if l_args:
+ coltype = l_args[0]
if hasattr(coltype, "_sqla_type"):
if type_ is not None:
raise exc.ArgumentError(
"May not pass type_ positionally and as a keyword."
)
- type_ = args.pop(0)
+ type_ = l_args.pop(0)
if name is not None:
name = quoted_name(name, kwargs.pop("quote", None))
else:
self.nullable = not primary_key
- self.default = kwargs.pop("default", None)
+ default = kwargs.pop("default", None)
+ onupdate = kwargs.pop("onupdate", None)
+
self.server_default = kwargs.pop("server_default", None)
self.server_onupdate = kwargs.pop("server_onupdate", None)
self.system = kwargs.pop("system", False)
self.doc = kwargs.pop("doc", None)
- self.onupdate = kwargs.pop("onupdate", None)
self.autoincrement = kwargs.pop("autoincrement", "auto")
self.constraints = set()
self.foreign_keys = set()
if isinstance(impl, SchemaEventTarget):
impl._set_parent_with_dispatch(self)
- if self.default is not None:
- if isinstance(self.default, (ColumnDefault, Sequence)):
- args.append(self.default)
- else:
- args.append(ColumnDefault(self.default))
+ if default is not None:
+ if not isinstance(default, (ColumnDefault, Sequence)):
+ default = ColumnDefault(default)
+
+ self.default = default
+ l_args.append(default)
+ else:
+ self.default = None
+
+ if onupdate is not None:
+ if not isinstance(onupdate, (ColumnDefault, Sequence)):
+ onupdate = ColumnDefault(onupdate, for_update=True)
+
+ self.onupdate = onupdate
+ l_args.append(onupdate)
+ else:
+ self.onupdate = None
if self.server_default is not None:
if isinstance(self.server_default, FetchedValue):
- args.append(self.server_default._as_for_update(False))
+ l_args.append(self.server_default._as_for_update(False))
else:
- args.append(DefaultClause(self.server_default))
-
- if self.onupdate is not None:
- if isinstance(self.onupdate, (ColumnDefault, Sequence)):
- args.append(self.onupdate)
- else:
- args.append(ColumnDefault(self.onupdate, for_update=True))
+ l_args.append(DefaultClause(self.server_default))
if self.server_onupdate is not None:
if isinstance(self.server_onupdate, FetchedValue):
- args.append(self.server_onupdate._as_for_update(True))
+ l_args.append(self.server_onupdate._as_for_update(True))
else:
- args.append(
+ l_args.append(
DefaultClause(self.server_onupdate, for_update=True)
)
- self._init_items(*args)
+ self._init_items(*l_args)
util.set_creation_order(self)
self._extra_kwargs(**kwargs)
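# A minimal sketch of the reworked default/onupdate handling above, assuming
# the patched 2.0 codebase: plain values and callables passed to Column() are
# normalized into ColumnDefault generators before _init_items() attaches them.

from sqlalchemy import Column, Integer

col = Column("status", Integer, default=0, onupdate=lambda: 1)

print(type(col.default).__name__)   # ScalarElementColumnDefault
print(col.default.arg)              # 0
print(type(col.onupdate).__name__)  # CallableColumnDefault
print(col.onupdate.for_update)      # True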
- foreign_keys = None
+ table: Table
+
+ constraints: Set[Constraint]
+
+ foreign_keys: Set[ForeignKey]
"""A collection of all :class:`_schema.ForeignKey` marker objects
associated with this :class:`_schema.Column`.
"""
- index = None
+ index: bool
"""The value of the :paramref:`_schema.Column.index` parameter.
Does not indicate if this :class:`_schema.Column` is actually indexed
:attr:`_schema.Table.indexes`
"""
- unique = None
+ unique: bool
"""The value of the :paramref:`_schema.Column.unique` parameter.
Does not indicate if this :class:`_schema.Column` is actually subject to
server_default = self.server_default
server_onupdate = self.server_onupdate
if isinstance(server_default, (Computed, Identity)):
+ args.append(server_default._copy(**kw))
server_default = server_onupdate = None
- args.append(self.server_default._copy(**kw))
type_ = self.type
if isinstance(type_, SchemaEventTarget):
__visit_name__ = "foreign_key"
+ parent: Column[Any]
+
def __init__(
self,
- column: Union[str, Column, SQLCoreOperations],
+ column: Union[str, Column[Any], SQLCoreOperations[Any]],
_constraint: Optional["ForeignKeyConstraint"] = None,
use_alter: bool = False,
name: Optional[str] = None,
self._table_column = self._colspec
if not isinstance(
- self._table_column.table, (util.NoneType, TableClause)
+ self._table_column.table, (type(None), TableClause)
):
raise exc.ArgumentError(
"ForeignKey received Column not bound "
# object passes itself in when creating ForeignKey
# markers.
self.constraint = _constraint
- self.parent = None
+
+ # .parent is not Optional under normal use
+ self.parent = None # type: ignore
+
self.use_alter = use_alter
self.name = name
self.onupdate = onupdate
return parenttable, tablekey, colname
def _link_to_col_by_colstring(self, parenttable, table, colname):
- if not hasattr(self.constraint, "_referred_table"):
- self.constraint._referred_table = table
- else:
- assert self.constraint._referred_table is table
-
_column = None
if colname is None:
# colname is None in the case that ForeignKey argument
# was specified as table name only, in which case we
# match the column name to the same column on the
# parent.
- key = self.parent
- _column = table.c.get(self.parent.key, None)
+ # this use case wasn't working in the later 1.x series,
+ # as it had no test coverage; fixed in 2.0
+ parent = self.parent
+ assert parent is not None
+ key = parent.key
+ _column = table.c.get(key, None)
elif self.link_to_name:
key = colname
for c in table.c:
key,
)
- self._set_target_column(_column)
+ return _column
def _set_target_column(self, column):
- assert isinstance(self.parent.table, Table)
+ assert self.parent is not None
# propagate TypeEngine to parent if it didn't have one
if self.parent.type._isnull:
If no target column has been established, an exception
is raised.
- .. versionchanged:: 0.9.0
- Foreign key target column resolution now occurs as soon as both
- the ForeignKey object and the remote Column to which it refers
- are both associated with the same MetaData object.
-
"""
if isinstance(self._colspec, str):
"parent MetaData" % parenttable
)
else:
- raise exc.NoReferencedColumnError(
- "Could not initialize target column for "
- "ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'"
- % (self._colspec, parenttable.name, tablekey, colname),
- tablekey,
- colname,
+ table = parenttable.metadata.tables[tablekey]
+ return self._link_to_col_by_colstring(
+ parenttable, table, colname
)
+
elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
return _column
_column = self._colspec
return _column
- def _set_parent(self, column, **kw):
- if self.parent is not None and self.parent is not column:
+ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+ assert isinstance(parent, Column)
+
+ if self.parent is not None and self.parent is not parent:
raise exc.InvalidRequestError(
"This ForeignKey already has a parent !"
)
- self.parent = column
+ self.parent = parent
self.parent.foreign_keys.add(self)
self.parent._on_table_attach(self._set_table)
def _set_remote_table(self, table):
- parenttable, tablekey, colname = self._resolve_col_tokens()
- self._link_to_col_by_colstring(parenttable, table, colname)
+ parenttable, _, colname = self._resolve_col_tokens()
+ _column = self._link_to_col_by_colstring(parenttable, table, colname)
+ self._set_target_column(_column)
+ assert self.constraint is not None
self.constraint._validate_dest_table(table)
def _remove_from_metadata(self, metadata):
if table_key in parenttable.metadata.tables:
table = parenttable.metadata.tables[table_key]
try:
- self._link_to_col_by_colstring(parenttable, table, colname)
+ _column = self._link_to_col_by_colstring(
+ parenttable, table, colname
+ )
except exc.NoReferencedColumnError:
# this is OK, we'll try later
pass
+ else:
+ self._set_target_column(_column)
+
parenttable.metadata._fk_memos[fk_key].append(self)
elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
self._set_target_column(_column)
+if TYPE_CHECKING:
+
+ def default_is_sequence(
+ obj: Optional[DefaultGenerator],
+ ) -> TypeGuard[Sequence]:
+ ...
+
+ def default_is_clause_element(
+ obj: Optional[DefaultGenerator],
+ ) -> TypeGuard[ColumnElementColumnDefault]:
+ ...
+
+ def default_is_scalar(
+ obj: Optional[DefaultGenerator],
+ ) -> TypeGuard[ScalarElementColumnDefault]:
+ ...
+
+else:
+ default_is_sequence = operator.attrgetter("is_sequence")
+
+ default_is_clause_element = operator.attrgetter("is_clause_element")
+
+ default_is_scalar = operator.attrgetter("is_scalar")
+
+
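# A minimal sketch of how the TypeGuard helpers above are meant to be used,
# assuming the patched 2.0 codebase: at runtime default_is_scalar() is just
# operator.attrgetter("is_scalar"), but under TYPE_CHECKING it narrows
# Optional[DefaultGenerator] to ScalarElementColumnDefault so .arg can be read
# without casts, mirroring the engine-side prefetch code earlier in this diff.

from typing import Any, Dict

from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.sql.schema import default_is_scalar


def scalar_defaults(table: Table) -> Dict[str, Any]:
    """Collect plain Python scalar defaults, skipping callables and SQL exprs."""
    out: Dict[str, Any] = {}
    for col in table.columns:
        if col.default is not None and default_is_scalar(col.default):
            # narrowed to ScalarElementColumnDefault; .arg is the raw value
            out[col.key] = col.default.arg
    return out


t = Table(
    "t",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("status", Integer, default=0),
)
print(scalar_defaults(t))  # {'status': 0}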
class DefaultGenerator(Executable, SchemaItem):
"""Base class for column *default* values."""
is_sequence = False
is_server_default = False
+ is_clause_element = False
+ is_callable = False
is_scalar = False
- column = None
+ column: Optional[Column[Any]]
def __init__(self, for_update=False):
self.for_update = for_update
- @util.memoized_property
- def is_callable(self):
- raise NotImplementedError()
-
- def _set_parent(self, column, **kw):
- self.column = column
+ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+ if TYPE_CHECKING:
+ assert isinstance(parent, Column)
+ self.column = parent
if self.for_update:
self.column.onupdate = self
else:
)
-class ColumnDefault(DefaultGenerator):
+class ColumnDefault(DefaultGenerator, ABC):
"""A plain default value on a column.
This could correspond to a constant, a callable function,
"""
- def __init__(self, arg, **kwargs):
+ arg: Any
+
+ @overload
+ def __new__(
+ cls, arg: Callable[..., Any], for_update: bool = ...
+ ) -> CallableColumnDefault:
+ ...
+
+ @overload
+ def __new__(
+ cls, arg: ColumnElement[Any], for_update: bool = ...
+ ) -> ColumnElementColumnDefault:
+ ...
+
+ # if this overload returned ScalarElementColumnDefault, which is what's
+ # actually returned, mypy complains that the
+ # overloads overlap w/ incompatible return types.
+ @overload
+ def __new__(cls, arg: object, for_update: bool = ...) -> ColumnDefault:
+ ...
+
+ def __new__(
+ cls, arg: Any = None, for_update: bool = False
+ ) -> ColumnDefault:
"""Construct a new :class:`.ColumnDefault`.
statement and parameters.
"""
- super(ColumnDefault, self).__init__(**kwargs)
+
if isinstance(arg, FetchedValue):
raise exc.ArgumentError(
"ColumnDefault may not be a server-side default type."
)
- if callable(arg):
- arg = self._maybe_wrap_callable(arg)
+ elif callable(arg):
+ cls = CallableColumnDefault
+ elif isinstance(arg, ClauseElement):
+ cls = ColumnElementColumnDefault
+ elif arg is not None:
+ cls = ScalarElementColumnDefault
+
+ return object.__new__(cls)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}({self.arg!r})"
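# A minimal sketch, assuming the patched 2.0 codebase: the __new__ overloads
# above make ColumnDefault() act as a factory that returns the subclass
# matching the argument type.

from sqlalchemy import text
from sqlalchemy.sql.schema import (
    CallableColumnDefault,
    ColumnDefault,
    ColumnElementColumnDefault,
    ScalarElementColumnDefault,
)

assert type(ColumnDefault(0)) is ScalarElementColumnDefault
assert type(ColumnDefault(lambda: 42)) is CallableColumnDefault
assert type(ColumnDefault(text("now()"))) is ColumnElementColumnDefault

# for_update=True is what Column(onupdate=...) passes along internally
upd = ColumnDefault(lambda ctx: 99, for_update=True)
assert upd.is_callable and upd.for_update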
+
+
+class ScalarElementColumnDefault(ColumnDefault):
+ """default generator for a fixed scalar Python value
+
+ .. versionadded:: 2.0
+
+ """
+
+ is_scalar = True
+
+ def __init__(self, arg: Any, for_update: bool = False):
+ self.for_update = for_update
self.arg = arg
- @util.memoized_property
- def is_callable(self):
- return callable(self.arg)
- @util.memoized_property
- def is_clause_element(self):
- return isinstance(self.arg, ClauseElement)
+# _SQLExprDefault = Union["ColumnElement[Any]", "TextClause", "SelectBase"]
+_SQLExprDefault = Union["ColumnElement[Any]", "TextClause"]
- @util.memoized_property
- def is_scalar(self):
- return (
- not self.is_callable
- and not self.is_clause_element
- and not self.is_sequence
- )
+
+class ColumnElementColumnDefault(ColumnDefault):
+ """default generator for a SQL expression
+
+ .. versionadded:: 2.0
+
+ """
+
+ is_clause_element = True
+
+ arg: _SQLExprDefault
+
+ def __init__(
+ self,
+ arg: _SQLExprDefault,
+ for_update: bool = False,
+ ):
+ self.for_update = for_update
+ self.arg = arg
@util.memoized_property
@util.preload_module("sqlalchemy.sql.sqltypes")
def _arg_is_typed(self):
sqltypes = util.preloaded.sql_sqltypes
- if self.is_clause_element:
- return not isinstance(self.arg.type, sqltypes.NullType)
- else:
- return False
+ return not isinstance(self.arg.type, sqltypes.NullType)
+
+
+class _CallableColumnDefaultProtocol(Protocol):
+ def __call__(self, context: ExecutionContext) -> Any:
+ ...
- def _maybe_wrap_callable(self, fn):
+
+class CallableColumnDefault(ColumnDefault):
+ """default generator for a callable Python function
+
+ .. versionadded:: 2.0
+
+ """
+
+ is_callable = True
+ arg: _CallableColumnDefaultProtocol
+
+ def __init__(
+ self,
+ arg: Union[_CallableColumnDefaultProtocol, Callable[[], Any]],
+ for_update: bool = False,
+ ):
+ self.for_update = for_update
+ self.arg = self._maybe_wrap_callable(arg)
+
+ def _maybe_wrap_callable(
+ self, fn: Union[_CallableColumnDefaultProtocol, Callable[[], Any]]
+ ) -> _CallableColumnDefaultProtocol:
"""Wrap callables that don't accept a context.
This is to allow easy compatibility with default callables
that aren't specific to accepting of a context.
"""
+
try:
argspec = util.get_callable_argspec(fn, no_self=True)
except TypeError:
- return util.wrap_callable(lambda ctx: fn(), fn)
+ return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore
defaulted = argspec[3] is not None and len(argspec[3]) or 0
positionals = len(argspec[0]) - defaulted
if positionals == 0:
- return util.wrap_callable(lambda ctx: fn(), fn)
+ return util.wrap_callable(lambda ctx: fn(), fn) # type: ignore
elif positionals == 1:
- return fn
+ return fn # type: ignore
else:
raise exc.ArgumentError(
"ColumnDefault Python function takes zero or one "
"positional arguments"
)
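# A minimal sketch, assuming the patched 2.0 codebase: _maybe_wrap_callable()
# accepts both zero-argument callables and callables taking the execution
# context, and rejects anything else.  The None passed below stands in for the
# ExecutionContext, which the zero-argument wrapper simply ignores.

import datetime

from sqlalchemy.sql.schema import CallableColumnDefault

d1 = CallableColumnDefault(lambda: datetime.datetime.utcnow())
print(d1.arg(None))  # wrapper discards the context and calls the lambda

d2 = CallableColumnDefault(lambda ctx: 42)
print(d2.arg(None))  # one-argument callable is used as-is -> 42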
- def __repr__(self):
- return "ColumnDefault(%r)" % (self.arg,)
-
class IdentityOptions:
"""Defines options for a named database sequence or an identity column.
is_sequence = True
+ column: Optional[Column[Any]] = None
+
def __init__(
self,
name,
else:
self.data_type = None
- @util.memoized_property
- def is_callable(self):
- return False
-
- @util.memoized_property
- def is_clause_element(self):
- return False
-
@util.preload_module("sqlalchemy.sql.functions")
def next_value(self):
"""Return a :class:`.next_value` function element
__visit_name__ = "constraint"
+ _creation_order: int
+ _column_flag: bool
+
def __init__(
self,
name=None,
class ColumnCollectionMixin:
-
- columns = None
"""A :class:`_expression.ColumnCollection` of :class:`_schema.Column`
objects.
"""
+ columns: ColumnCollection[Column[Any]]
+
_allow_multiple_tables = False
+ if TYPE_CHECKING:
+
+ def _set_parent_with_dispatch(
+ self, parent: SchemaEventTarget, **kw: Any
+ ) -> None:
+ ...
+
def __init__(self, *columns, **kw):
_autoattach = kw.pop("_autoattach", True)
self._column_flag = kw.pop("_column_flag", False)
)
)
- def _col_expressions(self, table):
+ def _col_expressions(self, table: Table) -> List[Column[Any]]:
return [
table.c[col] if isinstance(col, str) else col
for col in self._pending_colargs
]
- def _set_parent(self, table, **kw):
- for col in self._col_expressions(table):
+ def _set_parent(self, parent: SchemaEventTarget, **kw: Any) -> None:
+ if TYPE_CHECKING:
+ assert isinstance(parent, Table)
+ for col in self._col_expressions(parent):
if col is not None:
self.columns.add(col)
self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
)
- columns = None
+ columns: DedupeColumnCollection[Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
"""
self.sqltext = coercions.expect(roles.DDLExpressionRole, sqltext)
- columns = []
+ columns: List[Column[Any]] = []
visitors.traverse(self.sqltext, {}, {"column": columns.append})
super(CheckConstraint, self).__init__(
assert table is self.parent
self._set_parent_with_dispatch(table)
- def _append_element(self, column, fk):
+ def _append_element(self, column: Column[Any], fk: ForeignKey) -> None:
self.columns.add(column)
self.elements.append(fk)
- columns = None
+ columns: DedupeColumnCollection[Column[Any]]
"""A :class:`_expression.ColumnCollection` representing the set of columns
for this constraint.
"""
- elements = None
+ elements: List[ForeignKey]
"""A sequence of :class:`_schema.ForeignKey` objects.
Each :class:`_schema.ForeignKey`
self._validate_dialect_kwargs(kw)
- self.expressions = []
+ self.expressions: List[ColumnElement[Any]] = []
# will call _set_parent() if table-bound column
# objects are present
ColumnCollectionMixin.__init__(
)
if info:
self.info = info
- self._schemas = set()
- self._sequences = {}
- self._fk_memos = collections.defaultdict(list)
+ self._schemas: Set[str] = set()
+ self._sequences: Dict[str, Sequence] = {}
+ self._fk_memos: Dict[
+ Tuple[str, str], List[ForeignKey]
+ ] = collections.defaultdict(list)
- tables: Dict[str, Table]
+ tables: util.FacadeDict[str, Table]
"""A dictionary of :class:`_schema.Table`
objects keyed to their name or "table key".
def _remove_table(self, name, schema):
key = _get_table_key(name, schema)
- removed = dict.pop(self.tables, key, None)
+ removed = dict.pop(self.tables, key, None) # type: ignore
if removed is not None:
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
"""
return ddl.sort_tables(
- sorted(self.tables.values(), key=lambda t: t.key)
+ sorted(self.tables.values(), key=lambda t: t.key) # type: ignore
)
def reflect(
self,
- bind: Union["Engine", "Connection"],
+ bind: Union[Engine, Connection],
schema: Optional[str] = None,
views: bool = False,
only: Optional[_typing_Sequence[str]] = None,
autoload_replace: bool = True,
resolve_fks: bool = True,
**dialect_kwargs: Any,
- ):
+ ) -> None:
r"""Load all available table definitions from the database.
Automatically creates ``Table`` entries in this ``MetaData`` for any
if schema is not None:
reflect_opts["schema"] = schema
- available = util.OrderedSet(insp.get_table_names(schema))
+ available: util.OrderedSet[str] = util.OrderedSet(
+ insp.get_table_names(schema)
+ )
if views:
available.update(insp.get_view_names(schema))
if schema is not None:
- available_w_schema = util.OrderedSet(
+ available_w_schema: util.OrderedSet[str] = util.OrderedSet(
["%s.%s" % (schema, name) for name in available]
)
else:
def create_all(
self,
- bind: Union["Engine", "Connection"],
+ bind: Union[Engine, Connection, MockConnection],
tables: Optional[_typing_Sequence[Table]] = None,
checkfirst: bool = True,
- ):
+ ) -> None:
"""Create all tables stored in this metadata.
Conditional by default, will not attempt to recreate tables already
def drop_all(
self,
- bind: Union["Engine", "Connection"],
+ bind: Union[Engine, Connection, MockConnection],
tables: Optional[_typing_Sequence[Table]] = None,
checkfirst: bool = True,
- ):
+ ) -> None:
"""Drop all tables stored in this metadata.
Conditional by default, will not attempt to drop tables not present in
_is_clone_of: Optional[FromClause]
- schema = None
+ schema: Optional[str] = None
"""Define the 'schema' attribute for this :class:`_expression.FromClause`.
This is typically ``None`` for most objects except that of
"""
return self._cloned_set.intersection(other._cloned_set)
- @property
+ @util.non_memoized_property
def description(self) -> str:
"""A brief description of this :class:`_expression.FromClause`.
return self.columns
@util.memoized_property
- def columns(self) -> ColumnCollection:
+ def columns(self) -> ColumnCollection[Any]:
"""A named-based collection of :class:`_expression.ColumnElement`
objects maintained by this :class:`_expression.FromClause`.
# this is awkward. maybe there's a better way
if TYPE_CHECKING:
- c: ColumnCollection
+ c: ColumnCollection[Any]
else:
c = property(
attrgetter("columns"),
_is_table = True
+ fullname: str
+
implicit_returning = False
""":class:`_expression.TableClause`
doesn't support having a primary key or column
__visit_name__ = "integer"
+ if TYPE_CHECKING:
+
+ @util.ro_memoized_property
+ def _type_affinity(self) -> Type[Integer]:
+ ...
+
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
operators.truediv: {Numeric: self.__class__},
}
- @util.non_memoized_property
- def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
+ @util.ro_non_memoized_property
+ def _type_affinity(self) -> Type[Interval]:
return Interval
"""
return self
- @util.memoized_property
+ @util.ro_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[_T]]]:
"""Return a rudimental 'affinity' value expressing the general class
of type."""
else:
return self.__class__
- @util.memoized_property
+ @util.ro_memoized_property
def _generic_type_affinity(
self,
) -> Type[TypeEngine[_T]]:
tt.impl = tt.impl_instance = typedesc
return tt
- @util.non_memoized_property
+ @util.ro_non_memoized_property
def _type_affinity(self) -> Optional[Type[TypeEngine[Any]]]:
return self.impl_instance._type_affinity
from .langhelpers import PluginLoader as PluginLoader
from .langhelpers import portable_instancemethod as portable_instancemethod
from .langhelpers import quoted_token_parser as quoted_token_parser
+from .langhelpers import ro_memoized_property as ro_memoized_property
+from .langhelpers import ro_non_memoized_property as ro_non_memoized_property
from .langhelpers import safe_reraise as safe_reraise
from .langhelpers import set_creation_order as set_creation_order
from .langhelpers import string_or_unprintable as string_or_unprintable
EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
-class FacadeDict(ImmutableDictBase[Any, Any]):
+class FacadeDict(ImmutableDictBase[_KT, _VT]):
"""A dictionary that is not publicly mutable."""
def __new__(cls, *args):
_F = TypeVar("_F", bound=Callable[..., Any])
_MP = TypeVar("_MP", bound="memoized_property[Any]")
_MA = TypeVar("_MA", bound="HasMemoized.memoized_attribute[Any]")
-_HP = TypeVar("_HP", bound="hybridproperty")
-_HM = TypeVar("_HM", bound="hybridmethod")
+_HP = TypeVar("_HP", bound="hybridproperty[Any]")
+_HM = TypeVar("_HM", bound="hybridmethod[Any]")
if compat.py310:
# superclass has non-memoized, the class hierarchy of the descriptors
# would need to be reversed; "class non_memoized(memoized)". so there's no
# way to achieve this.
+# additional issues, RO properties:
+# https://github.com/python/mypy/issues/12440
if TYPE_CHECKING:
+
+ # allow memoized and non-memoized to be freely mixed by having them
+ # be the same class
memoized_property = generic_fn_descriptor
non_memoized_property = generic_fn_descriptor
+
+ # for read-only situations, mypy treats only @property as read-only.
+ # read-only is needed when a subtype specializes the return type
+ # of a property, meaning assignment needs to be disallowed
+ ro_memoized_property = property
+ ro_non_memoized_property = property
else:
- memoized_property = _memoized_property
- non_memoized_property = _non_memoized_property
+ memoized_property = ro_memoized_property = _memoized_property
+ non_memoized_property = ro_non_memoized_property = _non_memoized_property
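# A minimal sketch of why the ro_* aliases resolve to property under
# TYPE_CHECKING: a read-only descriptor is what lets a subclass narrow the
# declared return type, as the _type_affinity overrides elsewhere in this diff
# do.  The classes below are made up for illustration; only
# util.ro_memoized_property comes from this diff.

from typing import Optional, Type

from sqlalchemy import util


class HasAffinity:
    @util.ro_memoized_property
    def affinity(self) -> Optional[Type["HasAffinity"]]:
        return None


class Narrowed(HasAffinity):
    # narrower return type is acceptable for a read-only descriptor
    @util.ro_memoized_property
    def affinity(self) -> Type["Narrowed"]:
        return Narrowed


print(Narrowed().affinity)  # computed once, then memoized on the instance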
def memoized_instancemethod(fn: _F) -> _F:
return default
-def assert_arg_type(arg: Any, argtype: Type[Any], name: str) -> Any:
+def assert_arg_type(
+ arg: Any, argtype: Union[Tuple[Type[Any], ...], Type[Any]], name: str
+) -> Any:
if isinstance(arg, argtype):
return arg
else:
return self.fget(cls) # type: ignore
-class hybridproperty:
- def __init__(self, func):
+class hybridproperty(Generic[_T]):
+ def __init__(self, func: Callable[..., _T]):
self.func = func
self.clslevel = func
- def __get__(self, instance, owner):
+ def __get__(self, instance: Any, owner: Any) -> _T:
if instance is None:
clsval = self.clslevel(owner)
return clsval
else:
return self.func(instance)
- def classlevel(self, func):
+ def classlevel(self, func: Callable[..., Any]) -> hybridproperty[_T]:
self.clslevel = func
return self
-class hybridmethod:
+class hybridmethod(Generic[_T]):
"""Decorate a function as cls- or instance- level."""
- def __init__(self, func):
+ def __init__(self, func: Callable[..., _T]):
self.func = self.__func__ = func
self.clslevel = func
- def __get__(self, instance, owner):
+ def __get__(self, instance: Any, owner: Any) -> Callable[..., _T]:
if instance is None:
- return self.clslevel.__get__(owner, owner.__class__)
+ return self.clslevel.__get__(owner, owner.__class__) # type:ignore
else:
- return self.func.__get__(instance, owner)
+ return self.func.__get__(instance, owner) # type:ignore
- def classlevel(self, func):
+ def classlevel(self, func: Callable[..., Any]) -> hybridmethod[_T]:
self.clslevel = func
return self
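# A minimal sketch of the runtime behavior of the now-generic hybridmethod
# descriptor; the _T parameter lets mypy carry the return type through both
# the class-level and instance-level access paths.  The Registry class below
# is made up for illustration, not SQLAlchemy code.

from sqlalchemy.util.langhelpers import hybridmethod


class Registry:
    _global_items = ["a", "b"]

    def __init__(self) -> None:
        self._items = ["x"]

    @hybridmethod
    def items(self) -> list:
        return self._items          # instance-level variant

    @items.classlevel
    def items(cls) -> list:
        return cls._global_items    # class-level variant


print(Registry.items())    # ['a', 'b']
print(Registry().items())  # ['x']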
if compat.py310:
from typing import TypeGuard as TypeGuard
+ from typing import TypeAlias as TypeAlias
else:
from typing_extensions import TypeGuard as TypeGuard
+ from typing_extensions import TypeAlias as TypeAlias
if typing.TYPE_CHECKING or compat.py38:
from typing import SupportsIndex as SupportsIndex
[tool.mypy]
mypy_path = "./lib/"
show_error_codes = true
-strict = false
+strict = true
incremental = true
-# disabled checking
[[tool.mypy.overrides]]
-module="sqlalchemy.*"
-ignore_errors = true
-warn_unused_ignores = false
-
-strict = true
-
-# some debate at
-# https://github.com/python/mypy/issues/8754.
-# implicit_reexport = true
-# individual packages or even modules should be listed here
-# with strictness-specificity set up. there's no way we are going to get
-# the whole library 100% strictly typed, so we have to tune this based on
-# the type of module or package we are dealing with
-
-[[tool.mypy.overrides]]
# ad-hoc ignores
module = [
"sqlalchemy.engine.reflection", # interim, should be strict
+
+ # TODO for strict:
+ "sqlalchemy.ext.asyncio.*",
+ "sqlalchemy.ext.automap",
+ "sqlalchemy.ext.compiler",
+ "sqlalchemy.ext.declarative.*",
+ "sqlalchemy.ext.mutable",
+ "sqlalchemy.ext.horizontal_shard",
+
+ "sqlalchemy.sql._selectable_constructors",
+ "sqlalchemy.sql._dml_constructors",
+
+ # TODO for non-strict:
+ "sqlalchemy.ext.baked",
+ "sqlalchemy.ext.instrumentation",
+ "sqlalchemy.ext.indexable",
+ "sqlalchemy.ext.orderinglist",
+ "sqlalchemy.ext.serializer",
+
+ "sqlalchemy.sql.selectable", # would be nice as strict
+ "sqlalchemy.sql.ddl",
+ "sqlalchemy.sql.functions", # would be nice as strict
+ "sqlalchemy.sql.lambdas",
+ "sqlalchemy.sql.dml", # would be nice as strict
+ "sqlalchemy.sql.util",
+
+ # not yet classified:
+ "sqlalchemy.orm.*",
+ "sqlalchemy.dialects.*",
+ "sqlalchemy.cyextension.*",
+ "sqlalchemy.future.*",
+ "sqlalchemy.testing.*",
+
]
+warn_unused_ignores = false
ignore_errors = true
# strict checking
[[tool.mypy.overrides]]
+
module = [
- "sqlalchemy.sql.annotation",
- "sqlalchemy.sql.cache_key",
- "sqlalchemy.sql._elements_constructors",
- "sqlalchemy.sql.operators",
- "sqlalchemy.sql.type_api",
- "sqlalchemy.sql.roles",
- "sqlalchemy.sql.visitors",
- "sqlalchemy.sql._py_util",
+ # packages
"sqlalchemy.connectors.*",
+ "sqlalchemy.event.*",
+ "sqlalchemy.ext.*",
+ "sqlalchemy.sql.*",
"sqlalchemy.engine.*",
- "sqlalchemy.ext.hybrid",
- "sqlalchemy.ext.associationproxy",
"sqlalchemy.pool.*",
- "sqlalchemy.event.*",
+
+ # modules
"sqlalchemy.events",
"sqlalchemy.exc",
"sqlalchemy.inspection",
"sqlalchemy.schema",
"sqlalchemy.types",
]
+
+warn_unused_ignores = false
ignore_errors = false
strict = true
[[tool.mypy.overrides]]
module = [
- #"sqlalchemy.sql.*",
- "sqlalchemy.sql.sqltypes",
- "sqlalchemy.sql.elements",
+ "sqlalchemy.engine.cursor",
+ "sqlalchemy.engine.default",
+
+ "sqlalchemy.sql.base",
"sqlalchemy.sql.coercions",
"sqlalchemy.sql.compiler",
- #"sqlalchemy.sql.default_comparator",
+ "sqlalchemy.sql.crud",
+ "sqlalchemy.sql.elements", # would be nice as strict
"sqlalchemy.sql.naming",
+ "sqlalchemy.sql.schema", # would be nice as strict
+ "sqlalchemy.sql.sqltypes", # would be nice as strict
"sqlalchemy.sql.traversals",
+
"sqlalchemy.util.*",
- "sqlalchemy.engine.cursor",
- "sqlalchemy.engine.default",
]
+warn_unused_ignores = false
ignore_errors = false
# mostly strict without requiring totally untyped things to be
# /home/classic/dev/sqlalchemy/test/profiles.txt
# This file is written out on a per-environment basis.
-# For each test in aaa_profiling, the corresponding function and
+# For each test in aaa_profiling, the corresponding function and
# environment is located within this file. If it doesn't exist,
# the test is skipped.
-# If a callcount does exist, it is compared to what we received.
+# If a callcount does exist, it is compared to what we received.
# assertions are raised if the counts do not match.
-#
-# To add a new callcount test, apply the function_call_count
-# decorator and re-run the tests using the --write-profiles
+#
+# To add a new callcount test, apply the function_call_count
+# decorator and re-run the tests using the --write-profiles
# option - this file will be rewritten including the new count.
-#
+#
# TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 72
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 70
-test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 70
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_mysqldb_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mariadb_pymysql_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_mssql_pyodbc_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_oracle_cx_oracle_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_postgresql_psycopg2_dbapiunicode_nocextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_cextensions 75
+test.aaa_profiling.test_compiler.CompileTest.test_insert x86_64_linux_cpython_3.10_sqlite_pysqlite_dbapiunicode_nocextensions 75
# TEST: test.aaa_profiling.test_compiler.CompileTest.test_select
assert_raises_message(
sa.exc.ArgumentError,
r"SQL expression for WHERE/HAVING role expected, "
- r"got (?:Sequence|ColumnDefault|DefaultClause)\('y'.*\)",
+ r"got (?:Sequence|(?:ScalarElement)ColumnDefault|"
+ r"DefaultClause)\('y'.*\)",
t.select().where,
const,
)
"%s"
", name='someconstraint')" % repr(ck.sqltext),
),
- (ColumnDefault(("foo", "bar")), "ColumnDefault(('foo', 'bar'))"),
+ (
+ ColumnDefault(("foo", "bar")),
+ "ScalarElementColumnDefault(('foo', 'bar'))",
+ ),
):
eq_(repr(const), exp)
a2 = a.to_metadata(m2)
assert b2.c.y.references(a2.c.x)
+ def test_fk_w_no_colname(self):
+ """test a ForeignKey that refers to table name only. the column
+ name is assumed to be the same col name on parent table.
+
+ this is a little used feature from long ago that nonetheless is
+ still in the code.
+
+ The feature was found to be not working but is repaired for
+ SQLAlchemy 2.0.
+
+ """
+ m1 = MetaData()
+ a = Table("a", m1, Column("x", Integer))
+ b = Table("b", m1, Column("x", Integer, ForeignKey("a")))
+ assert b.c.x.references(a.c.x)
+
+ m2 = MetaData()
+ b2 = b.to_metadata(m2)
+ a2 = a.to_metadata(m2)
+ assert b2.c.x.references(a2.c.x)
+
+ def test_fk_w_no_colname_name_missing(self):
+ """test a ForeignKey that refers to table name only. the column
+ name is assumed to be the same col name on parent table.
+
+ this is a little used feature from long ago that nonetheless is
+ still in the code.
+
+ """
+ m1 = MetaData()
+ a = Table("a", m1, Column("x", Integer))
+ b = Table("b", m1, Column("y", Integer, ForeignKey("a")))
+
+ with expect_raises_message(
+ exc.NoReferencedColumnError,
+ "Could not initialize target column for ForeignKey 'a' on "
+ "table 'b': table 'a' has no column named 'y'",
+ ):
+ assert b.c.y.references(a.c.x)
+
def test_column_collection_constraint_w_ad_hoc_columns(self):
"""Test ColumnCollectionConstraint that has columns that aren't
part of the Table.