- name: Run tests
run: tox -e pep8
+
+ run-pep484:
+ name: pep484-${{ matrix.python-version }}
+ runs-on: ${{ matrix.os }}
+ strategy:
+ # run this job using this matrix.
+ matrix:
+ os:
+ - "ubuntu-latest"
+ python-version:
+ - "3.10"
+
+ fail-fast: false
+
+ # steps to run in each job. Some are GitHub Actions; others run shell commands.
+ steps:
+ - name: Checkout repo
+ uses: actions/checkout@v2
+
+ - name: Set up python
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install --upgrade tox setuptools
+ pip list
+
+ - name: Run tests
+ run: tox -e pep484
recursive-include examples *.py *.xml
recursive-include test *.py *.dat *.testpatch
+# for some reason, in some environments, stale Cython .c files
+# are being pulled in; these should never be in a dist
+exclude lib/sqlalchemy/cyextension/*.c
+exclude lib/sqlalchemy/cyextension/*.so
+
# include the pyx and pxd extensions, which otherwise
# don't come in if --with-cextensions isn't specified.
recursive-include lib *.pyx *.pxd *.txt *.typed
def unique_list(seq, hashfunc=None):
return cunique_list(seq, hashfunc)
-cdef class OrderedSet(set):
+cdef class OrderedSet:
cdef list _list
+ cdef set _set
def __init__(self, d=None):
- set.__init__(self)
if d is not None:
self._list = cunique_list(d)
- set.update(self, self._list)
+ self._set = set(self._list)
else:
self._list = []
+ self._set = set()
cdef OrderedSet _copy(self):
cdef OrderedSet cp = OrderedSet.__new__(OrderedSet)
cp._list = list(self._list)
- set.update(cp, cp._list)
+ cp._set = set(cp._list)
return cp
cdef OrderedSet _from_list(self, list new_list):
cdef OrderedSet new = OrderedSet.__new__(OrderedSet)
new._list = new_list
- set.update(new, new_list)
+ new._set = set(new_list)
return new
def add(self, element):
if element not in self:
self._list.append(element)
- PySet_Add(self, element)
+ PySet_Add(self._set, element)
def remove(self, element):
# set.remove will raise if element is not in self
- set.remove(self, element)
+ self._set.remove(element)
self._list.remove(element)
def insert(self, Py_ssize_t pos, element):
if element not in self:
self._list.insert(pos, element)
- PySet_Add(self, element)
+ PySet_Add(self._set, element)
def discard(self, element):
if element in self:
- set.remove(self, element)
+ self._set.remove(element)
self._list.remove(element)
def clear(self):
- set.clear(self)
+ self._set.clear()
self._list = []
def __getitem__(self, key):
__str__ = __repr__
- def update(self, iterable):
- for e in iterable:
- if e not in self:
- self._list.append(e)
- set.add(self, e)
- return self
+ def update(self, *iterables):
+ for iterable in iterables:
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ self._set.add(e)
def __ior__(self, iterable):
- return self.update(iterable)
+ self.update(iterable)
+ return self
def union(self, other):
result = self._copy()
result.update(other)
return result
+ def __len__(self) -> int:
+ return len(self._set)
+
+ def __eq__(self, other):
+ return self._set == other
+
+ def __ne__(self, other):
+ return self._set != other
+
+ def __contains__(self, element):
+ return element in self._set
+
def __or__(self, other):
return self.union(other)
cdef set other_set = self._to_set(other)
- set.intersection_update(self, other_set)
+ self._set.intersection_update(other_set)
self._list = [a for a in self._list if a in other_set]
- return self
def __iand__(self, other):
- return self.intersection_update(other)
+ self.intersection_update(other)
+ return self
def symmetric_difference_update(self, other):
- set.symmetric_difference_update(self, other)
+ self._set.symmetric_difference_update(other)
self._list = [a for a in self._list if a in self]
self._list += [a for a in other if a in self]
- return self
def __ixor__(self, other):
- return self.symmetric_difference_update(other)
+ self.symmetric_difference_update(other)
+ return self
def difference_update(self, other):
- set.difference_update(self, other)
+ self._set.difference_update(other)
self._list = [a for a in self._list if a in self]
- return self
def __isub__(self, other):
- return self.difference_update(other)
+ self.difference_update(other)
+ return self
cdef object cy_id(object item):
__delitem__ = __setitem__ = __setattr__ = _immutable
+class ImmutableDictBase(dict):
+ def _immutable(self, *a, **kw):
+ _immutable_fn(self)
+
+ @classmethod
+ def __class_getitem__(cls, key):
+ return cls
+
+ __delitem__ = __setitem__ = __setattr__ = _immutable
+ clear = pop = popitem = setdefault = update = _immutable
+
cdef class immutabledict(dict):
def __repr__(self):
return f"immutabledict({dict.__repr__(self)})"
+ @classmethod
+ def __class_getitem__(cls, key):
+ return cls
+
def union(self, *args, **kw):
cdef dict to_merge = None
cdef immutabledict result
from .util import _distill_raw_params
from .util import TransactionalContext
from .. import exc
+from .. import inspection
from .. import log
from .. import util
from ..sql import compiler
from ..sql import util as sql_util
+from ..sql._typing import _ExecuteOptions
+from ..sql._typing import _ExecuteParams
if typing.TYPE_CHECKING:
from .interfaces import Dialect
+ from .reflection import Inspector # noqa
from .url import URL
from ..pool import Pool
from ..pool import PoolProxiedConnection
NO_OPTIONS = util.immutabledict()
-class Connection(ConnectionEventsTarget):
+class Connection(ConnectionEventsTarget, inspection.Inspectable["Inspector"]):
"""Provides high-level functionality for a wrapped DB-API connection.
The :class:`_engine.Connection` object is procured by calling
return self.execute(statement, parameters, execution_options).scalars()
- def execute(self, statement, parameters=None, execution_options=None):
+ def execute(
+ self,
+ statement,
+ parameters: Optional[_ExecuteParams] = None,
+ execution_options: Optional[_ExecuteOptions] = None,
+ ):
r"""Executes a SQL statement construct and returns a
:class:`_engine.Result`.
self.connection._commit_twophase_impl(self.xid, self._is_prepared)
-class Engine(ConnectionEventsTarget, log.Identified):
+class Engine(
+ ConnectionEventsTarget, log.Identified, inspection.Inspectable["Inspector"]
+):
"""
Connects a :class:`~sqlalchemy.pool.Pool` and
:class:`~sqlalchemy.engine.interfaces.Dialect` together to provide a
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import Any
+from typing import Union
from . import base
from . import url as _url
"is deprecated and will be removed in a future release. ",
),
)
-def create_engine(url, **kwargs):
+def create_engine(url: Union[str, "_url.URL"], **kwargs: Any) -> "base.Engine":
"""Create a new :class:`_engine.Engine` instance.
The standard calling form is to send the :ref:`URL <database_urls>` as the
@inspection._self_inspects
-class Inspector:
+class Inspector(inspection.Inspectable["Inspector"]):
"""Performs database schema inspection.
The Inspector acts as a proxy to the reflection methods of the
"""
import collections
from itertools import chain
+import threading
import weakref
from . import legacy
from . import registry
from .. import exc
from .. import util
-from ..util import threading
from ..util.concurrency import AsyncAdaptedLock
:exc:`.DBAPIError`.
"""
+import typing
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import Union
from .util import _preloaded
from .util import compat
+if typing.TYPE_CHECKING:
+ from .engine.interfaces import Dialect
+ from .sql._typing import _ExecuteParams
+ from .sql.compiler import Compiled
+ from .sql.elements import ClauseElement
+
_version_token = None
class HasDescriptionCode:
"""helper which adds 'code' as an attribute and '_code_str' as a method"""
- code = None
+ code: Optional[str] = None
- def __init__(self, *arg, **kw):
+ def __init__(self, *arg: Any, **kw: Any):
code = kw.pop("code", None)
if code is not None:
self.code = code
super(HasDescriptionCode, self).__init__(*arg, **kw)
- def _code_str(self):
+ def _code_str(self) -> str:
if not self.code:
return ""
else:
)
)
- def __str__(self):
+ def __str__(self) -> str:
message = super(HasDescriptionCode, self).__str__()
if self.code:
message = "%s %s" % (message, self._code_str())
class SQLAlchemyError(HasDescriptionCode, Exception):
"""Generic error class."""
- def _message(self):
+ def _message(self) -> str:
# rules:
#
# 1. single arg string will usually be a unicode
# SQLAlchemy though this is happening in at least one known external
# library, call str() which does a repr().
#
+ text: str
+
if len(self.args) == 1:
- text = self.args[0]
+ arg_text = self.args[0]
- if isinstance(text, bytes):
- text = compat.decode_backslashreplace(text, "utf-8")
+ if isinstance(arg_text, bytes):
+ text = compat.decode_backslashreplace(arg_text, "utf-8")
# This is for when the argument is not a string of any sort.
# Otherwise, converting this exception to string would fail for
# non-string arguments.
else:
- text = str(text)
+ text = str(arg_text)
return text
else:
# a repr() of the tuple
return str(self.args)
- def _sql_message(self):
+ def _sql_message(self) -> str:
message = self._message()
if self.code:
return message
- def __str__(self):
+ def __str__(self) -> str:
return self._sql_message()
"""
- def __init__(self, target):
+ def __init__(self, target: Any):
super(ObjectNotExecutableError, self).__init__(
"Not an executable object: %r" % target
)
self.target = target
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return self.__class__, (self.target,)
"""
- def __init__(self, message, cycles, edges, msg=None, code=None):
+ def __init__(
+ self,
+ message: str,
+ cycles: Any,
+ edges: Any,
+ msg: Optional[str] = None,
+ code: Optional[str] = None,
+ ):
if msg is None:
message += " (%s)" % ", ".join(repr(s) for s in cycles)
else:
self.cycles = cycles
self.edges = edges
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return (
self.__class__,
(None, self.cycles, self.edges, self.args[0]),
code = "l7de"
- def __init__(self, compiler, element_type, message=None):
+ def __init__(
+ self,
+ compiler: "Compiled",
+ element_type: Type["ClauseElement"],
+ message: Optional[str] = None,
+ ):
super(UnsupportedCompilationError, self).__init__(
"Compiler %r can't render element of type %s%s"
% (compiler, element_type, ": %s" % message if message else "")
self.element_type = element_type
self.message = message
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return self.__class__, (self.compiler, self.element_type, self.message)
"""
- invalidate_pool = False
+ invalidate_pool: bool = False
class InvalidatePoolError(DisconnectionError):
"""
- invalidate_pool = True
+ invalidate_pool: bool = True
class TimeoutError(SQLAlchemyError): # noqa
"""
- def __init__(self, message, tname):
+ def __init__(self, message: str, tname: str):
NoReferenceError.__init__(self, message)
self.table_name = tname
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return self.__class__, (self.args[0], self.table_name)
"""
- def __init__(self, message, tname, cname):
+ def __init__(self, message: str, tname: str, cname: str):
NoReferenceError.__init__(self, message)
self.table_name = tname
self.column_name = cname
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return (
self.__class__,
(self.args[0], self.table_name, self.column_name),
"""
- statement = None
+ statement: Optional[str] = None
"""The string SQL statement being invoked when this exception occurred."""
- params = None
+ params: Optional["_ExecuteParams"] = None
"""The parameter list being used when this exception occurred."""
- orig = None
- """The DBAPI exception object."""
+ orig: Optional[BaseException] = None
+ """The original exception that was thrown.
+
+ """
- ismulti = None
+ ismulti: Optional[bool] = None
+ """multi parameter passed to repr_params(). None is meaningful."""
def __init__(
self,
- message,
- statement,
- params,
- orig,
- hide_parameters=False,
- code=None,
- ismulti=None,
+ message: str,
+ statement: Optional[str],
+ params: Optional["_ExecuteParams"],
+ orig: Optional[BaseException],
+ hide_parameters: bool = False,
+ code: Optional[str] = None,
+ ismulti: Optional[bool] = None,
):
SQLAlchemyError.__init__(self, message, code=code)
self.statement = statement
self.orig = orig
self.ismulti = ismulti
self.hide_parameters = hide_parameters
- self.detail = []
+ self.detail: List[str] = []
- def add_detail(self, msg):
+ def add_detail(self, msg: str) -> None:
self.detail.append(msg)
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return (
self.__class__,
(
)
@_preloaded.preload_module("sqlalchemy.sql.util")
- def _sql_message(self):
- util = _preloaded.preloaded.sql_util
+ def _sql_message(self) -> str:
+ if typing.TYPE_CHECKING:
+ from .sql import util
+ else:
+ util = _preloaded.preloaded.sql_util
details = [self._message()]
if self.statement:
code = "dbapi"
+ # I don't think I'm going to try to do overloads like this everywhere
+ # in the library, but as this module is an early effort at typing
+ # everything, I am sort of just practicing
+
+ @overload
@classmethod
def instance(
cls,
- statement,
- params,
- orig,
- dbapi_base_err,
- hide_parameters=False,
- connection_invalidated=False,
- dialect=None,
- ismulti=None,
- ):
+ statement: str,
+ params: "_ExecuteParams",
+ orig: DontWrapMixin,
+ dbapi_base_err: Type[Exception],
+ hide_parameters: bool = False,
+ connection_invalidated: bool = False,
+ dialect: Optional["Dialect"] = None,
+ ismulti: Optional[bool] = None,
+ ) -> DontWrapMixin:
+ ...
+
+ @overload
+ @classmethod
+ def instance(
+ cls,
+ statement: str,
+ params: "_ExecuteParams",
+ orig: Exception,
+ dbapi_base_err: Type[Exception],
+ hide_parameters: bool = False,
+ connection_invalidated: bool = False,
+ dialect: Optional["Dialect"] = None,
+ ismulti: Optional[bool] = None,
+ ) -> StatementError:
+ ...
+
+ @overload
+ @classmethod
+ def instance(
+ cls,
+ statement: str,
+ params: "_ExecuteParams",
+ orig: BaseException,
+ dbapi_base_err: Type[Exception],
+ hide_parameters: bool = False,
+ connection_invalidated: bool = False,
+ dialect: Optional["Dialect"] = None,
+ ismulti: Optional[bool] = None,
+ ) -> BaseException:
+ ...
+
+ @classmethod
+ def instance(
+ cls,
+ statement: str,
+ params: "_ExecuteParams",
+ orig: Union[BaseException, DontWrapMixin],
+ dbapi_base_err: Type[Exception],
+ hide_parameters: bool = False,
+ connection_invalidated: bool = False,
+ dialect: Optional["Dialect"] = None,
+ ismulti: Optional[bool] = None,
+ ) -> Union[BaseException, DontWrapMixin]:
# Don't ever wrap these, just return them directly as if
# DBAPIError didn't exist.
if (
ismulti=ismulti,
)
- def __reduce__(self):
+ def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
return (
self.__class__,
(
def __init__(
self,
- statement,
- params,
- orig,
- hide_parameters=False,
- connection_invalidated=False,
- code=None,
- ismulti=None,
+ statement: str,
+ params: "_ExecuteParams",
+ orig: BaseException,
+ hide_parameters: bool = False,
+ connection_invalidated: bool = False,
+ code: Optional[str] = None,
+ ismulti: Optional[bool] = None,
):
try:
text = str(orig)
class SADeprecationWarning(HasDescriptionCode, DeprecationWarning):
"""Issued for usage of deprecated APIs."""
- deprecated_since = None
+ deprecated_since: Optional[str] = None
"Indicates the version that started raising this deprecation warning"
"""
- deprecated_since = "1.4"
+ deprecated_since: Optional[str] = "1.4"
"Indicates the version that started raising this deprecation warning"
- def __str__(self):
+ def __str__(self) -> str:
return (
super(Base20DeprecationWarning, self).__str__()
+ " (Background on SQLAlchemy 2.0 at: https://sqlalche.me/e/b8d9)"
"""
- deprecated_since = None
+ deprecated_since: Optional[str] = None
"Indicates the version that started raising this deprecation warning"
in a forwards-compatible way.
"""
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import Generic
+from typing import overload
+from typing import Type
+from typing import TypeVar
+from typing import Union
from . import exc
-from . import util
+from .util.typing import Literal
+_T = TypeVar("_T", bound=Any)
-_registrars = util.defaultdict(list)
+_registrars: Dict[type, Union[Literal[True], Callable[[Any], Any]]] = {}
-def inspect(subject, raiseerr=True):
+class Inspectable(Generic[_T]):
+ """define a class as inspectable.
+
+ This allows typing to set up a linkage between an object that
+ can be inspected and the type of inspection it returns.
+
+ """
+
+
+@overload
+def inspect(subject: Inspectable[_T], raiseerr: bool = True) -> _T:
+ ...
+
+
+@overload
+def inspect(subject: Any, raiseerr: bool = True) -> Any:
+ ...
+
+
+def inspect(subject: Any, raiseerr: bool = True) -> Any:
"""Produce an inspection object for the given target.
The returned value in some cases may be the
type_ = type(subject)
for cls in type_.__mro__:
if cls in _registrars:
- reg = _registrars[cls]
- if reg is True:
+ reg = _registrars.get(cls, None)
+ if reg is None:
+ continue
+ elif reg is True:
return subject
ret = reg(subject)
if ret is not None:
- break
+ return ret
else:
reg = ret = None
return ret
-def _inspects(*types):
- def decorate(fn_or_cls):
+def _inspects(
+ *types: type,
+) -> Callable[[Callable[[Any], Any]], Callable[[Any], Any]]:
+ def decorate(fn_or_cls: Callable[[Any], Any]) -> Callable[[Any], Any]:
for type_ in types:
if type_ in _registrars:
raise AssertionError(
return decorate
-def _self_inspects(cls):
- _inspects(cls)(True)
+def _self_inspects(cls: Type[_T]) -> Type[_T]:
+ if cls in _registrars:
+ raise AssertionError("Type %s is already " "registered" % cls)
+ _registrars[cls] = True
return cls
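# Illustrative sketch (hypothetical classes, not part of the patch): a class
# advertises its inspection type through Inspectable[_T], a runtime inspector
# is registered with _inspects(), and the overloads above let type checkers
# narrow the return type of inspect().
class HypotheticalInfo:
    pass


class HypotheticalThing(Inspectable[HypotheticalInfo]):
    pass


@_inspects(HypotheticalThing)
def _inspect_hypothetical(subject: HypotheticalThing) -> HypotheticalInfo:
    return HypotheticalInfo()


# inspect(HypotheticalThing()) now returns a HypotheticalInfo at runtime and
# is typed as HypotheticalInfo via the Inspectable overload.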
instance only.
"""
-
import logging
import sys
+from typing import Any
+from typing import Optional
+from typing import overload
+from typing import Set
+from typing import Type
+from typing import TypeVar
+from typing import Union
+
+from .util.typing import Literal
+
+_IT = TypeVar("_IT", bound="Identified")
+_EchoFlagType = Union[None, bool, Literal["debug"]]
# set initial level to WARN. This so that
# log statements don't occur in the absence of explicit
rootlogger.setLevel(logging.WARN)
-def _add_default_handler(logger):
+def _add_default_handler(logger: logging.Logger) -> None:
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s")
logger.addHandler(handler)
-_logged_classes = set()
+_logged_classes: Set[Type["Identified"]] = set()
-def _qual_logger_name_for_cls(cls):
+def _qual_logger_name_for_cls(cls: Type["Identified"]) -> str:
return (
getattr(cls, "_sqla_logger_namespace", None)
or cls.__module__ + "." + cls.__name__
)
-def class_logger(cls):
+def class_logger(cls: Type[_IT]) -> Type[_IT]:
logger = logging.getLogger(_qual_logger_name_for_cls(cls))
- cls._should_log_debug = lambda self: logger.isEnabledFor(logging.DEBUG)
- cls._should_log_info = lambda self: logger.isEnabledFor(logging.INFO)
+ cls._should_log_debug = lambda self: logger.isEnabledFor( # type: ignore[assignment] # noqa E501
+ logging.DEBUG
+ )
+ cls._should_log_info = lambda self: logger.isEnabledFor( # type: ignore[assignment] # noqa E501
+ logging.INFO
+ )
cls.logger = logger
_logged_classes.add(cls)
return cls
class Identified:
- logging_name = None
+ logging_name: Optional[str] = None
- def _should_log_debug(self):
+ logger: Union[logging.Logger, "InstanceLogger"]
+
+ _echo: _EchoFlagType
+
+ def _should_log_debug(self) -> bool:
return self.logger.isEnabledFor(logging.DEBUG)
- def _should_log_info(self):
+ def _should_log_info(self) -> bool:
return self.logger.isEnabledFor(logging.INFO)
"debug": logging.DEBUG,
}
- def __init__(self, echo, name):
+ _echo: _EchoFlagType
+
+ def __init__(self, echo: _EchoFlagType, name: str):
self.echo = echo
self.logger = logging.getLogger(name)
#
# Boilerplate convenience methods
#
- def debug(self, msg, *args, **kwargs):
+ def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
"""Delegate a debug call to the underlying logger."""
self.log(logging.DEBUG, msg, *args, **kwargs)
- def info(self, msg, *args, **kwargs):
+ def info(self, msg: str, *args: Any, **kwargs: Any) -> None:
"""Delegate an info call to the underlying logger."""
self.log(logging.INFO, msg, *args, **kwargs)
- def warning(self, msg, *args, **kwargs):
+ def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
"""Delegate a warning call to the underlying logger."""
self.log(logging.WARNING, msg, *args, **kwargs)
warn = warning
- def error(self, msg, *args, **kwargs):
+ def error(self, msg: str, *args: Any, **kwargs: Any) -> None:
"""
Delegate an error call to the underlying logger.
"""
self.log(logging.ERROR, msg, *args, **kwargs)
- def exception(self, msg, *args, **kwargs):
+ def exception(self, msg: str, *args: Any, **kwargs: Any) -> None:
"""Delegate an exception call to the underlying logger."""
kwargs["exc_info"] = 1
self.log(logging.ERROR, msg, *args, **kwargs)
- def critical(self, msg, *args, **kwargs):
+ def critical(self, msg: str, *args: Any, **kwargs: Any) -> None:
"""Delegate a critical call to the underlying logger."""
self.log(logging.CRITICAL, msg, *args, **kwargs)
- def log(self, level, msg, *args, **kwargs):
+ def log(self, level: int, msg: str, *args: Any, **kwargs: Any) -> None:
"""Delegate a log call to the underlying logger.
The level here is determined by the echo
if level >= selected_level:
self.logger._log(level, msg, args, **kwargs)
- def isEnabledFor(self, level):
+ def isEnabledFor(self, level: int) -> bool:
"""Is this logger enabled for level 'level'?"""
if self.logger.manager.disable >= level:
return False
return level >= self.getEffectiveLevel()
- def getEffectiveLevel(self):
+ def getEffectiveLevel(self) -> int:
"""What's the effective level for this logger?"""
level = self._echo_map[self.echo]
return level
-def instance_logger(instance, echoflag=None):
+def instance_logger(
+ instance: Identified, echoflag: _EchoFlagType = None
+) -> None:
"""create a logger for an instance that implements :class:`.Identified`."""
if instance.logging_name:
instance._echo = echoflag
+ logger: Union[logging.Logger, InstanceLogger]
+
if echoflag in (False, None):
# if no echo setting or False, return a Logger directly,
# avoiding overhead of filtering
``logging.DEBUG``.
"""
- def __get__(self, instance, owner):
+ @overload
+ def __get__(
+ self, instance: "Literal[None]", owner: "echo_property"
+ ) -> "echo_property":
+ ...
+
+ @overload
+ def __get__(
+ self, instance: Identified, owner: Type[Identified]
+ ) -> _EchoFlagType:
+ ...
+
+ def __get__(
+ self, instance: Optional[Identified], owner: Type[Identified]
+ ) -> Union["echo_property", _EchoFlagType]:
if instance is None:
return self
else:
return instance._echo
- def __set__(self, instance, value):
+ def __set__(self, instance: Identified, value: _EchoFlagType) -> None:
instance_logger(instance, echoflag=value)
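# Illustrative sketch: InstanceLogger maps the echo flag onto logging levels
# via _echo_map ("debug" selects logging.DEBUG), while the echo_property
# descriptor above rebuilds the instance logger whenever the flag is assigned.
import logging

lgr = InstanceLogger("debug", "sqlalchemy.engine.Engine")
assert lgr.getEffectiveLevel() == logging.DEBUG
assert lgr.isEnabledFor(logging.INFO)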
"""
import operator
+import threading
import weakref
from sqlalchemy.util.compat import inspect_getfullargspec
"attribute_mapped_collection",
]
-__instrumentation_mutex = util.threading.Lock()
+__instrumentation_mutex = threading.Lock()
class _PlainColumnGetter:
from ..util import hybridmethod
from ..util import hybridproperty
+if typing.TYPE_CHECKING:
+ from .state import InstanceState # noqa
+
_T = TypeVar("_T", bound=Any)
return False
-class DeclarativeAttributeIntercept(type):
+class DeclarativeAttributeIntercept(
+ type, inspection.Inspectable["Mapper[Any]"]
+):
"""Metaclass that may be used in conjunction with the
:class:`_orm.DeclarativeBase` class to support addition of class
attributes dynamically.
_del_attribute(cls, key)
-class DeclarativeMeta(type):
+class DeclarativeMeta(type, inspection.Inspectable["Mapper[Any]"]):
def __init__(cls, classname, bases, dict_, **kw):
# early-consume registry from the initial declarative base,
# assign privately to not conflict with subclass attributes named
cls.metadata = cls.registry.metadata
-class DeclarativeBaseNoMeta:
+class DeclarativeBaseNoMeta(inspection.Inspectable["Mapper"]):
"""Same as :class:`_orm.DeclarativeBase`, but does not use a metaclass
to intercept new attributes.
cls._sa_registry.map_declaratively(cls)
-class DeclarativeBase(metaclass=DeclarativeAttributeIntercept):
+class DeclarativeBase(
+ inspection.Inspectable["InstanceState"],
+ metaclass=DeclarativeAttributeIntercept,
+):
"""Base class used for declarative class definitions.
The :class:`_orm.DeclarativeBase` allows for the creation of new
from functools import reduce
from itertools import chain
import sys
+import threading
from typing import Generic
from typing import Type
from typing import TypeVar
NO_ATTRIBUTE = util.symbol("NO_ATTRIBUTE")
# lock used to synchronize the "mapper configure" step
-_CONFIGURE_MUTEX = util.threading.RLock()
+_CONFIGURE_MUTEX = threading.RLock()
@inspection._self_inspects
ORMEntityColumnsClauseRole,
sql_base.MemoizedHasCacheKey,
InspectionAttr,
+ log.Identified,
Generic[_MC],
):
"""Defines an association between a Python class and a database table or
yield c
@HasMemoized.memoized_attribute
- def attrs(self):
+ def attrs(self) -> util.ImmutableProperties["MapperProperty"]:
"""A namespace of all :class:`.MapperProperty` objects
associated this mapper.
"""
+import threading
import traceback
import weakref
from .. import util
from ..util import chop_traceback
from ..util import queue as sqla_queue
-from ..util import threading
class QueuePool(Pool):
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+from typing import Dict
-class prefix_anon_map(dict):
+
+class prefix_anon_map(Dict[str, str]):
"""A map that creates new keys for missing key access.
Considers keys of the form "<ident> <name>" to produce
return value
-class cache_anon_map(dict):
+class cache_anon_map(Dict[int, str]):
"""A map that creates new keys for missing key access.
Produces an incrementing sequence given a series of unique keys.
--- /dev/null
+from typing import Any
+from typing import Mapping
+from typing import Sequence
+from typing import Union
+
+_SingleExecuteParams = Mapping[str, Any]
+_MultiExecuteParams = Sequence[_SingleExecuteParams]
+_ExecuteParams = Union[_SingleExecuteParams, _MultiExecuteParams]
+_ExecuteOptions = Mapping[str, Any]
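# Illustrative sketch (hypothetical values): the shapes these aliases describe,
# as accepted by Connection.execute() in engine/base.py above.
single: _SingleExecuteParams = {"id": 5}
multi: _MultiExecuteParams = [{"id": 5}, {"id": 6}]
params: _ExecuteParams = multi            # a single mapping or a sequence
options: _ExecuteOptions = {"stream_results": True}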
"""High level utilities which build upon other modules here.
"""
-
from collections import deque
from itertools import chain
+import typing
+from typing import Any
+from typing import cast
+from typing import Optional
from . import coercions
from . import operators
from . import roles
from . import visitors
+from ._typing import _ExecuteParams
+from ._typing import _MultiExecuteParams
+from ._typing import _SingleExecuteParams
from .annotation import _deep_annotate # noqa
from .annotation import _deep_deannotate # noqa
from .annotation import _shallow_annotate # noqa
from .. import exc
from .. import util
+if typing.TYPE_CHECKING:
+ from ..engine.row import Row
+
def join_condition(a, b, a_subset=None, consider_as_foreign_keys=None):
"""Create a join condition between two tables or selectables.
class _repr_base:
- _LIST = 0
- _TUPLE = 1
- _DICT = 2
+ _LIST: int = 0
+ _TUPLE: int = 1
+ _DICT: int = 2
__slots__ = ("max_chars",)
- def trunc(self, value):
+ def trunc(self, value: Any) -> str:
rep = repr(value)
lenrep = len(rep)
if lenrep > self.max_chars:
__slots__ = ("row",)
- def __init__(self, row, max_chars=300):
+ def __init__(self, row: "Row", max_chars: int = 300):
self.row = row
self.max_chars = max_chars
- def __repr__(self):
+ def __repr__(self) -> str:
trunc = self.trunc
return "(%s%s)" % (
", ".join(trunc(value) for value in self.row),
__slots__ = "params", "batches", "ismulti"
- def __init__(self, params, batches, max_chars=300, ismulti=None):
- self.params = params
+ def __init__(
+ self,
+ params: _ExecuteParams,
+ batches: int,
+ max_chars: int = 300,
+ ismulti: Optional[bool] = None,
+ ):
+ self.params: _ExecuteParams = params
self.ismulti = ismulti
self.batches = batches
self.max_chars = max_chars
- def __repr__(self):
+ def __repr__(self) -> str:
if self.ismulti is None:
return self.trunc(self.params)
else:
return self.trunc(self.params)
- if self.ismulti and len(self.params) > self.batches:
- msg = " ... displaying %i of %i total bound parameter sets ... "
- return " ".join(
- (
- self._repr_multi(self.params[: self.batches - 2], typ)[
- 0:-1
- ],
- msg % (self.batches, len(self.params)),
- self._repr_multi(self.params[-2:], typ)[1:],
+ if self.ismulti:
+ multi_params = cast(_MultiExecuteParams, self.params)
+
+ if len(self.params) > self.batches:
+ msg = (
+ " ... displaying %i of %i total bound parameter sets ... "
)
- )
- elif self.ismulti:
- return self._repr_multi(self.params, typ)
+ return " ".join(
+ (
+ self._repr_multi(
+ multi_params[: self.batches - 2],
+ typ,
+ )[0:-1],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(multi_params[-2:], typ)[1:],
+ )
+ )
+ else:
+ return self._repr_multi(multi_params, typ)
else:
- return self._repr_params(self.params, typ)
+ return self._repr_params(
+ cast(_SingleExecuteParams, self.params), typ
+ )
- def _repr_multi(self, multi_params, typ):
+ def _repr_multi(
+ self, multi_params: _MultiExecuteParams, typ: int
+ ) -> str:
if multi_params:
if isinstance(multi_params[0], list):
elem_type = self._LIST
else:
return "(%s)" % elements
- def _repr_params(self, params, typ):
+ def _repr_params(self, params: _SingleExecuteParams, typ: int) -> str:
trunc = self.trunc
if typ is self._DICT:
return "{%s}" % (
from .compat import py38
from .compat import py39
from .compat import pypy
-from .compat import threading
from .compat import win32
from .concurrency import asyncio
from .concurrency import await_fallback
"""Collection classes and helpers."""
import collections.abc as collections_abc
import operator
+import threading
import types
+import typing
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import Optional
+from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TypeVar
+from typing import Union
+from typing import ValuesView
import weakref
-from .compat import threading
+from ._has_cy import HAS_CYEXTENSION
+from .typing import Literal
-try:
- from sqlalchemy.cyextension.immutabledict import ImmutableContainer
- from sqlalchemy.cyextension.immutabledict import immutabledict
- from sqlalchemy.cyextension.collections import IdentitySet
- from sqlalchemy.cyextension.collections import OrderedSet
- from sqlalchemy.cyextension.collections import unique_list # noqa
-except ImportError:
+if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
from ._py_collections import immutabledict
from ._py_collections import IdentitySet
from ._py_collections import ImmutableContainer
+ from ._py_collections import ImmutableDictBase
from ._py_collections import OrderedSet
from ._py_collections import unique_list # noqa
+else:
+ from sqlalchemy.cyextension.immutabledict import ImmutableContainer
+ from sqlalchemy.cyextension.immutabledict import ImmutableDictBase
+ from sqlalchemy.cyextension.immutabledict import immutabledict
+ from sqlalchemy.cyextension.collections import IdentitySet
+ from sqlalchemy.cyextension.collections import OrderedSet
+ from sqlalchemy.cyextension.collections import unique_list # noqa
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
-EMPTY_SET = frozenset()
+
+EMPTY_SET: FrozenSet[Any] = frozenset()
def coerce_to_immutabledict(d):
return immutabledict(d)
-EMPTY_DICT = immutabledict()
+EMPTY_DICT: immutabledict[Any, Any] = immutabledict()
-class FacadeDict(ImmutableContainer, dict):
+class FacadeDict(ImmutableDictBase[Any, Any]):
"""A dictionary that is not publicly mutable."""
- clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
-
def __new__(cls, *args):
new = dict.__new__(cls)
return new
return "FacadeDict(%s)" % dict.__repr__(self)
-class Properties:
+_DT = TypeVar("_DT", bound=Any)
+
+
+class Properties(Generic[_T]):
"""Provide a __getattr__/__setattr__ interface over a dict."""
__slots__ = ("_data",)
+ _data: Dict[str, _T]
+
def __init__(self, data):
object.__setattr__(self, "_data", data)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._data)
- def __iter__(self):
+ def __iter__(self) -> Iterator[_T]:
return iter(list(self._data.values()))
def __dir__(self):
def __setitem__(self, key, obj):
self._data[key] = obj
- def __getitem__(self, key):
+ def __getitem__(self, key: str) -> _T:
return self._data[key]
def __delitem__(self, key):
def __setstate__(self, state):
object.__setattr__(self, "_data", state["_data"])
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> _T:
try:
return self._data[key]
except KeyError:
raise AttributeError(key)
- def __contains__(self, key):
+ def __contains__(self, key: str) -> bool:
return key in self._data
- def as_immutable(self):
+ def as_immutable(self) -> "ImmutableProperties[_T]":
"""Return an immutable proxy for this :class:`.Properties`."""
return ImmutableProperties(self._data)
def update(self, value):
self._data.update(value)
- def get(self, key, default=None):
+ @overload
+ def get(self, key: str) -> Optional[_T]:
+ ...
+
+ @overload
+ def get(self, key: str, default: Union[_DT, _T]) -> Union[_DT, _T]:
+ ...
+
+ def get(
+ self, key: str, default: Optional[Union[_DT, _T]] = None
+ ) -> Optional[Union[_T, _DT]]:
if key in self:
return self[key]
else:
return default
- def keys(self):
+ def keys(self) -> List[str]:
return list(self._data)
- def values(self):
+ def values(self) -> List[_T]:
return list(self._data.values())
- def items(self):
+ def items(self) -> List[Tuple[str, _T]]:
return list(self._data.items())
- def has_key(self, key):
+ def has_key(self, key: str) -> bool:
return key in self._data
def clear(self):
self._data.clear()
-class OrderedProperties(Properties):
+class OrderedProperties(Properties[_T]):
"""Provide a __getattr__/__setattr__ interface with an OrderedDict
as backing store."""
Properties.__init__(self, OrderedDict())
-class ImmutableProperties(ImmutableContainer, Properties):
+class ImmutableProperties(ImmutableContainer, Properties[_T]):
"""Provide immutable dict/object attribute to an underlying dictionary."""
__slots__ = ()
self.add(o)
-class PopulateDict(dict):
+class PopulateDict(Dict[_KT, _VT]):
"""A dict which populates missing values via a creation function.
Note the creation function takes a key, unlike
"""
- def __init__(self, creator):
+ def __init__(self, creator: Callable[[_KT], _VT]):
self.creator = creator
- def __missing__(self, key):
+ def __missing__(self, key: Any) -> Any:
self[key] = val = self.creator(key)
return val
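# Illustrative sketch (hypothetical creator): PopulateDict calls the creator
# with the missing key and memoizes the result.
squares: PopulateDict[int, int] = PopulateDict(lambda k: k * k)
assert squares[4] == 16    # built via __missing__ on first access
assert 4 in squares        # and stored for subsequent lookups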
-class WeakPopulateDict(dict):
+class WeakPopulateDict(Dict[_KT, _VT]):
"""Like PopulateDict, but assumes a self + a method and does not create
a reference cycle.
"""
- def __init__(self, creator_method):
+ def __init__(self, creator_method: types.MethodType):
self.creator = creator_method.__func__
weakself = creator_method.__self__
self.weakself = weakref.ref(weakself)
- def __missing__(self, key):
+ def __missing__(self, key: Any) -> Any:
self[key] = val = self.creator(self.weakself(), key)
return val
ordered_column_set = OrderedSet
-_getters = PopulateDict(operator.itemgetter)
-
-_property_getters = PopulateDict(
- lambda idx: property(operator.itemgetter(idx))
-)
-
-
-class UniqueAppender:
+class UniqueAppender(Generic[_T]):
"""Appends items to a collection ensuring uniqueness.
Additional appends() of the same object are ignored. Membership is
determined by identity (``is a``) not equality (``==``).
"""
- def __init__(self, data, via=None):
+ __slots__ = "data", "_data_appender", "_unique"
+
+ data: Union[Iterable[_T], Set[_T], List[_T]]
+ _data_appender: Callable[[_T], None]
+ _unique: Dict[int, Literal[True]]
+
+ def __init__(
+ self,
+ data: Union[Iterable[_T], Set[_T], List[_T]],
+ via: Optional[str] = None,
+ ):
self.data = data
self._unique = {}
if via:
- self._data_appender = getattr(data, via)
+ self._data_appender = getattr(data, via) # type: ignore[assignment] # noqa E501
elif hasattr(data, "append"):
- self._data_appender = data.append
+ self._data_appender = cast("List[_T]", data).append # type: ignore[assignment] # noqa E501
elif hasattr(data, "add"):
- self._data_appender = data.add
+ self._data_appender = cast("Set[_T]", data).add # type: ignore[assignment] # noqa E501
- def append(self, item):
+ def append(self, item: _T) -> None:
id_ = id(item)
if id_ not in self._unique:
- self._data_appender(item)
+ self._data_appender(item) # type: ignore[call-arg]
self._unique[id_] = True
- def __iter__(self):
+ def __iter__(self) -> Iterator[_T]:
return iter(self.data)
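# Illustrative sketch (hypothetical values): UniqueAppender de-duplicates by
# object identity, not equality.
target: list = []
ua = UniqueAppender(target)
row = {"a": 1}
ua.append(row)
ua.append(row)        # same object; ignored
ua.append({"a": 1})   # equal but distinct object; appended
assert len(target) == 2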
return arg
-def to_list(x, default=None):
+@overload
+def to_list(x: Sequence[_T], default: Optional[List[_T]] = None) -> List[_T]:
+ ...
+
+
+@overload
+def to_list(
+ x: Optional[Sequence[_T]], default: Optional[List[_T]] = None
+) -> Optional[List[_T]]:
+ ...
+
+
+def to_list(
+ x: Optional[Sequence[_T]], default: Optional[List[_T]] = None
+) -> Optional[List[_T]]:
if x is None:
return default
if not isinstance(x, collections_abc.Iterable) or isinstance(
x, (str, bytes)
):
- return [x]
+ return [cast(_T, x)]
elif isinstance(x, list):
return x
else:
yield elem
-class LRUCache(dict):
+class LRUCache(typing.MutableMapping[_KT, _VT]):
"""Dictionary with 'squishy' removal of least
recently used items.
"""
- __slots__ = "capacity", "threshold", "size_alert", "_counter", "_mutex"
+ __slots__ = (
+ "capacity",
+ "threshold",
+ "size_alert",
+ "_data",
+ "_counter",
+ "_mutex",
+ )
+
+ capacity: int
+ threshold: float
+ size_alert: Callable[["LRUCache[_KT, _VT]"], None]
def __init__(self, capacity=100, threshold=0.5, size_alert=None):
self.capacity = capacity
self.size_alert = size_alert
self._counter = 0
self._mutex = threading.Lock()
+ self._data: Dict[_KT, Tuple[_KT, _VT, List[int]]] = {}
def _inc_counter(self):
self._counter += 1
return self._counter
- def get(self, key, default=None):
- item = dict.get(self, key, default)
- if item is not default:
- item[2] = self._inc_counter()
+ @overload
+ def get(self, key: _KT) -> Optional[_VT]:
+ ...
+
+ @overload
+ def get(self, key: _KT, default: Union[_VT, _T]) -> Union[_VT, _T]:
+ ...
+
+ def get(
+ self, key: _KT, default: Optional[Union[_VT, _T]] = None
+ ) -> Optional[Union[_VT, _T]]:
+ item = self._data.get(key, default)
+ if item is not default and item is not None:
+ item[2][0] = self._inc_counter()
return item[1]
else:
return default
- def __getitem__(self, key):
- item = dict.__getitem__(self, key)
- item[2] = self._inc_counter()
+ def __getitem__(self, key: _KT) -> _VT:
+ item = self._data[key]
+ item[2][0] = self._inc_counter()
return item[1]
- def values(self):
- return [i[1] for i in dict.values(self)]
+ def __iter__(self) -> Iterator[_KT]:
+ return iter(self._data)
- def setdefault(self, key, value):
- if key in self:
- return self[key]
- else:
- self[key] = value
- return value
-
- def __setitem__(self, key, value):
- item = dict.get(self, key)
- if item is None:
- item = [key, value, self._inc_counter()]
- dict.__setitem__(self, key, item)
- else:
- item[1] = value
+ def __len__(self) -> int:
+ return len(self._data)
+
+ def values(self) -> ValuesView[_VT]:
+ return typing.ValuesView({k: i[1] for k, i in self._data.items()})
+
+ def __setitem__(self, key: _KT, value: _VT) -> None:
+ self._data[key] = (key, value, [self._inc_counter()])
self._manage_size()
+ def __delitem__(self, __v: _KT) -> None:
+ del self._data[__v]
+
@property
- def size_threshold(self):
+ def size_threshold(self) -> float:
return self.capacity + self.capacity * self.threshold
- def _manage_size(self):
+ def _manage_size(self) -> None:
if not self._mutex.acquire(False):
return
try:
while len(self) > self.capacity + self.capacity * self.threshold:
if size_alert:
size_alert = False
- self.size_alert(self)
+ self.size_alert(self) # type: ignore
by_counter = sorted(
- dict.values(self), key=operator.itemgetter(2), reverse=True
+ self._data.values(),
+ key=operator.itemgetter(2),
+ reverse=True,
)
for item in by_counter[self.capacity :]:
try:
- del self[item[0]]
+ del self._data[item[0]]
except KeyError:
# deleted elsewhere; skip
continue
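# Illustrative usage sketch for the LRUCache above (hypothetical values):
# capacity and threshold are soft limits; once len() exceeds
# capacity * (1 + threshold), _manage_size() trims the least recently used
# entries back down to capacity.
cache = LRUCache(capacity=2, threshold=0.5)
for k, v in [("a", 1), ("b", 2), ("c", 3), ("d", 4)]:
    cache[k] = v
assert len(cache) <= 3
value = cache.get("a")   # the value if "a" survived trimming, else None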
a callable that will return a key to store/retrieve an object.
"""
+ __slots__ = "createfunc", "scopefunc", "registry"
+
def __init__(self, createfunc, scopefunc):
"""Construct a new :class:`.ScopedRegistry`.
def clear(self):
try:
- del self.registry.value
+ del self.registry.value # type: ignore
except AttributeError:
pass
import asyncio
from contextvars import copy_context as _copy_context
import sys
+import typing
from typing import Any
from typing import Callable
from typing import Coroutine
-import greenlet
+import greenlet # type: ignore # noqa
from .langhelpers import memoized_property
from .. import exc
-try:
+if not typing.TYPE_CHECKING:
+ try:
- # If greenlet.gr_context is present in current version of greenlet,
- # it will be set with a copy of the current context on creation.
- # Refs: https://github.com/python-greenlet/greenlet/pull/198
- getattr(greenlet.greenlet, "gr_context")
-except (ImportError, AttributeError):
- _copy_context = None # noqa
+ # If greenlet.gr_context is present in current version of greenlet,
+ # it will be set with a copy of the current context on creation.
+ # Refs: https://github.com/python-greenlet/greenlet/pull/198
+ getattr(greenlet.greenlet, "gr_context")
+ except (ImportError, AttributeError):
+ _copy_context = None # noqa
def is_exit_exception(e):
# Issue for context: https://github.com/python-greenlet/greenlet/issues/173
-class _AsyncIoGreenlet(greenlet.greenlet):
+class _AsyncIoGreenlet(greenlet.greenlet): # type: ignore
def __init__(self, fn, driver):
greenlet.greenlet.__init__(self, fn, driver)
self.driver = driver
self.gr_context = _copy_context()
-def await_only(awaitable: Coroutine) -> Any:
+def await_only(awaitable: Coroutine[Any, Any, Any]) -> Any:
"""Awaits an async function in a sync method.
The sync method must be inside a :func:`greenlet_spawn` context.
return current.driver.switch(awaitable)
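# Illustrative sketch (hypothetical coroutine; requires greenlet): a sync
# function running under greenlet_spawn() (typed below) may call await_only()
# to consume awaitables without being declared async itself.
import asyncio


async def _fetch_value() -> int:
    return 42


def _sync_worker() -> int:
    return await_only(_fetch_value())


async def _main() -> int:
    return await greenlet_spawn(_sync_worker)


# asyncio.run(_main()) returns 42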
-def await_fallback(awaitable: Coroutine) -> Any:
+def await_fallback(awaitable: Coroutine[Any, Any, Any]) -> Any:
"""Awaits an async function in a sync method.
The sync method must be inside a :func:`greenlet_spawn` context.
async def greenlet_spawn(
- fn: Callable, *args, _require_await=False, **kwargs
+ fn: Callable[..., Any],
+ *args: Any,
+ _require_await: bool = False,
+ **kwargs: Any,
) -> Any:
"""Runs a sync function ``fn`` in a new greenlet.
runtime.
"""
-
import sys
+from types import ModuleType
+import typing
+from typing import Any
+from typing import Callable
+from typing import TypeVar
+
+_FN = TypeVar("_FN", bound=Callable[..., Any])
class _ModuleRegistry:
self.module_registry = set()
self.prefix = prefix
- def preload_module(self, *deps):
+ def preload_module(self, *deps: str) -> Callable[[_FN], _FN]:
"""Adds the specified modules to the list to load.
This method can be used both as a normal function and as a decorator.
self.module_registry.update(deps)
return lambda fn: fn
- def import_prefix(self, path):
+ def import_prefix(self, path: str) -> None:
"""Resolve all the modules in the registry that start with the
specified path.
"""
__import__(module, globals(), locals())
self.__dict__[key] = sys.modules[module]
+ if typing.TYPE_CHECKING:
+
+ def __getattr__(self, key: str) -> ModuleType:
+ ...
+
preloaded = _ModuleRegistry()
preload_module = preloaded.preload_module
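# Illustrative sketch (hypothetical helper), mirroring the usage in exc.py
# above: preload_module() records the dotted module name, and once
# import_prefix("sqlalchemy") has resolved it, the module is reachable as an
# attribute of `preloaded` ("sqlalchemy.sql.util" -> preloaded.sql_util).
@preload_module("sqlalchemy.sql.util")
def _hypothetical_helper():
    sql_util = preloaded.sql_util   # resolved lazily, not at import time
    return sql_util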
from itertools import filterfalse
+from typing import Any
+from typing import Dict
+from typing import Generic
+from typing import Iterable
+from typing import Iterator
+from typing import List
+from typing import NoReturn
+from typing import Optional
+from typing import Set
+from typing import TypeVar
+
+_T = TypeVar("_T", bound=Any)
+_KT = TypeVar("_KT", bound=Any)
+_VT = TypeVar("_VT", bound=Any)
class ImmutableContainer:
- def _immutable(self, *arg, **kw):
+ def _immutable(self, *arg: Any, **kw: Any) -> NoReturn:
raise TypeError("%s object is immutable" % self.__class__.__name__)
- __delitem__ = __setitem__ = __setattr__ = _immutable
+ def __delitem__(self, key: Any) -> NoReturn:
+ self._immutable()
+ def __setitem__(self, key: Any, value: Any) -> NoReturn:
+ self._immutable()
-class immutabledict(ImmutableContainer, dict):
+ def __setattr__(self, key: str, value: Any) -> NoReturn:
+ self._immutable()
- clear = pop = popitem = setdefault = update = ImmutableContainer._immutable
+class ImmutableDictBase(ImmutableContainer, Dict[_KT, _VT]):
+ def clear(self) -> NoReturn:
+ self._immutable()
+
+ def pop(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+ self._immutable()
+
+ def popitem(self) -> NoReturn:
+ self._immutable()
+
+ def setdefault(self, key: Any, default: Optional[Any] = None) -> NoReturn:
+ self._immutable()
+
+ def update(self, *arg: Any, **kw: Any) -> NoReturn:
+ self._immutable()
+
+
+class immutabledict(ImmutableDictBase[_KT, _VT]):
def __new__(cls, *args):
new = dict.__new__(cls)
dict.__init__(new, *args)
dict.__init__(new, self)
if __d:
dict.update(new, __d)
- dict.update(new, kw)
+ dict.update(new, kw) # type: ignore
return new
def merge_with(self, *dicts):
return "immutabledict(%s)" % dict.__repr__(self)
-class OrderedSet(set):
+class OrderedSet(Generic[_T]):
+ __slots__ = ("_list", "_set", "__weakref__")
+
+ _list: List[_T]
+ _set: Set[_T]
+
def __init__(self, d=None):
- set.__init__(self)
if d is not None:
self._list = unique_list(d)
- set.update(self, self._list)
+ self._set = set(self._list)
else:
self._list = []
+ self._set = set()
+
+ def __reduce__(self):
+ return (OrderedSet, (self._list,))
- def add(self, element):
+ def add(self, element: _T) -> None:
if element not in self:
self._list.append(element)
- set.add(self, element)
+ self._set.add(element)
- def remove(self, element):
- set.remove(self, element)
+ def remove(self, element: _T) -> None:
+ self._set.remove(element)
self._list.remove(element)
- def insert(self, pos, element):
+ def insert(self, pos: int, element: _T) -> None:
if element not in self:
self._list.insert(pos, element)
- set.add(self, element)
+ self._set.add(element)
- def discard(self, element):
+ def discard(self, element: _T) -> None:
if element in self:
self._list.remove(element)
- set.remove(self, element)
+ self._set.remove(element)
- def clear(self):
- set.clear(self)
+ def clear(self) -> None:
+ self._set.clear()
self._list = []
- def __getitem__(self, key):
+ def __len__(self) -> int:
+ return len(self._set)
+
+ def __eq__(self, other):
+ if not isinstance(other, OrderedSet):
+ return self._set == other
+ else:
+ return self._set == other._set
+
+ def __ne__(self, other):
+ if not isinstance(other, OrderedSet):
+ return self._set != other
+ else:
+ return self._set != other._set
+
+ def __contains__(self, element: Any) -> bool:
+ return element in self._set
+
+ def __getitem__(self, key: int) -> _T:
return self._list[key]
- def __iter__(self):
+ def __iter__(self) -> Iterator[_T]:
return iter(self._list)
- def __add__(self, other):
+ def __add__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
return self.union(other)
- def __repr__(self):
+ def __repr__(self) -> str:
return "%s(%r)" % (self.__class__.__name__, self._list)
__str__ = __repr__
- def update(self, iterable):
- for e in iterable:
- if e not in self:
- self._list.append(e)
- set.add(self, e)
- return self
+ def update(self, *iterables: Iterable[_T]) -> None:
+ for iterable in iterables:
+ for e in iterable:
+ if e not in self:
+ self._list.append(e)
+ self._set.add(e)
- __ior__ = update
+ def __ior__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.update(other)
+ return self
- def union(self, other):
+ def union(self, other: Iterable[_T]) -> "OrderedSet[_T]":
result = self.__class__(self)
result.update(other)
return result
- __or__ = union
+ def __or__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.union(other)
- def intersection(self, other):
+ def intersection(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other = other if isinstance(other, set) else set(other)
return self.__class__(a for a in self if a in other)
- __and__ = intersection
+ def __and__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.intersection(other)
- def symmetric_difference(self, other):
+ def symmetric_difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other_set = other if isinstance(other, set) else set(other)
result = self.__class__(a for a in self if a not in other_set)
result.update(a for a in other if a not in self)
return result
- __xor__ = symmetric_difference
+ def __xor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.symmetric_difference(other)
- def difference(self, other):
+ def difference(self, other: Iterable[_T]) -> "OrderedSet[_T]":
other = other if isinstance(other, set) else set(other)
return self.__class__(a for a in self if a not in other)
- __sub__ = difference
+ def __sub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ return self.difference(other)
- def intersection_update(self, other):
+ def intersection_update(self, other: Iterable[_T]) -> None:
other = other if isinstance(other, set) else set(other)
- set.intersection_update(self, other)
+ self._set.intersection_update(other)
self._list = [a for a in self._list if a in other]
- return self
- __iand__ = intersection_update
+ def __iand__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.intersection_update(other)
+ return self
- def symmetric_difference_update(self, other):
- set.symmetric_difference_update(self, other)
+ def symmetric_difference_update(self, other: Iterable[_T]) -> None:
+ self._set.symmetric_difference_update(other)
self._list = [a for a in self._list if a in self]
self._list += [a for a in other if a in self]
- return self
- __ixor__ = symmetric_difference_update
+ def __ixor__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.symmetric_difference_update(other)
+ return self
- def difference_update(self, other):
- set.difference_update(self, other)
+ def difference_update(self, other: Iterable[_T]) -> None:
+ self._set.difference_update(other)
self._list = [a for a in self._list if a in self]
- return self
- __isub__ = difference_update
+ def __isub__(self, other: Iterable[_T]) -> "OrderedSet[_T]":
+ self.difference_update(other)
+ return self
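# Illustrative sketch (hypothetical values): the OrderedSet above keeps
# insertion order in _list while delegating membership and set operations to
# the internal _set.
s = OrderedSet([3, 1, 2, 1])
assert list(s) == [3, 1, 2]
s |= [2, 5]
assert list(s) == [3, 1, 2, 5]
assert list(s & {1, 5}) == [1, 5]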
class IdentitySet:
# the MIT License: https://www.opensource.org/licenses/mit-license.php
"""Handle Python version/platform incompatibilities."""
+from __future__ import annotations
+
import base64
-import collections
import dataclasses
import inspect
import operator
import platform
import sys
import typing
+from typing import Any
+from typing import Callable
+from typing import Dict
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
py311 = sys.version_info >= (3, 11)
dottedgetter = operator.attrgetter
next = next # noqa
-FullArgSpec = collections.namedtuple(
- "FullArgSpec",
- [
- "args",
- "varargs",
- "varkw",
- "defaults",
- "kwonlyargs",
- "kwonlydefaults",
- "annotations",
- ],
-)
-
-try:
- import threading
-except ImportError:
- import dummy_threading as threading # noqa
+class FullArgSpec(typing.NamedTuple):
+ args: List[str]
+ varargs: Optional[str]
+ varkw: Optional[str]
+ defaults: Optional[Tuple[Any, ...]]
+ kwonlyargs: List[str]
+ kwonlydefaults: Dict[str, Any]
+ annotations: Dict[str, Any]
-def inspect_getfullargspec(func):
+def inspect_getfullargspec(func: Callable[..., Any]) -> FullArgSpec:
"""Fully vendored version of getfullargspec from Python 3.3."""
if inspect.ismethod(func):
)
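# Illustrative sketch (hypothetical function): the vendored spec is returned
# as the FullArgSpec NamedTuple declared above.
def _sample(a, b=2, *args, c, **kw):
    pass


spec = inspect_getfullargspec(_sample)
assert spec.args == ["a", "b"]
assert spec.varargs == "args"
assert spec.kwonlyargs == ["c"]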
-if py38:
+if typing.TYPE_CHECKING or py38:
from importlib import metadata as importlib_metadata
else:
import importlib_metadata # noqa
-if py39:
+if typing.TYPE_CHECKING or py39:
# pep 584 dict union
dict_union = operator.or_ # noqa
else:
def importlib_metadata_get(group):
ep = importlib_metadata.entry_points()
- if hasattr(ep, "select"):
+ if not typing.TYPE_CHECKING and hasattr(ep, "select"):
return ep.select(group=group)
else:
return ep.get(group, ())
return s.encode("latin-1")
-def b64decode(x):
+def b64decode(x: str) -> bytes:
return base64.b64decode(x.encode("ascii"))
-def b64encode(x):
+def b64encode(x: bytes) -> str:
return base64.b64encode(x).decode("ascii")
-def decode_backslashreplace(text, encoding):
+def decode_backslashreplace(text: bytes, encoding: str) -> str:
return text.decode(encoding, errors="backslashreplace")
def inspect_formatargspec(
- args,
- varargs=None,
- varkw=None,
- defaults=None,
- kwonlyargs=(),
- kwonlydefaults={},
- annotations={},
- formatarg=str,
- formatvarargs=lambda name: "*" + name,
- formatvarkw=lambda name: "**" + name,
- formatvalue=lambda value: "=" + repr(value),
- formatreturns=lambda text: " -> " + text,
- formatannotation=_formatannotation,
-):
+ args: List[str],
+ varargs: Optional[str] = None,
+ varkw: Optional[str] = None,
+ defaults: Optional[Sequence[Any]] = None,
+ kwonlyargs: Optional[Sequence[str]] = (),
+ kwonlydefaults: Optional[Mapping[str, Any]] = {},
+ annotations: Mapping[str, Any] = {},
+ formatarg: Callable[[str], str] = str,
+ formatvarargs: Callable[[str], str] = lambda name: "*" + name,
+ formatvarkw: Callable[[str], str] = lambda name: "**" + name,
+ formatvalue: Callable[[Any], str] = lambda value: "=" + repr(value),
+ formatreturns: Callable[[Any], str] = lambda text: " -> " + str(text),
+ formatannotation: Callable[[Any], str] = _formatannotation,
+) -> str:
"""Copy formatargspec from python 3.7 standard library.
Python 3 has deprecated formatargspec and requested that Signature
specs = []
if defaults:
firstdefault = len(args) - len(defaults)
+ else:
+ firstdefault = -1
+
for i, arg in enumerate(args):
spec = formatargandannotation(arg)
if defaults and i >= firstdefault:
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
+import asyncio # noqa
have_greenlet = False
greenlet_error = None
try:
- import greenlet # noqa F401
+ import greenlet # type: ignore # noqa F401
except ImportError as e:
greenlet_error = str(e)
pass
from ._concurrency_py3k import (
_util_async_run_coroutine_function,
) # noqa F401, E501
- from ._concurrency_py3k import asyncio # noqa F401
if not have_greenlet:
- asyncio = None # noqa F811
-
def _not_implemented():
# this conditional is to prevent pylance from considering
# greenlet_spawn() etc as "no return" and dimming out code below it
def is_exit_exception(e): # noqa F811
return not isinstance(e, Exception)
- def await_only(thing): # noqa F811
+ def await_only(thing): # type: ignore # noqa F811
_not_implemented()
- def await_fallback(thing): # noqa F81
+ def await_fallback(thing): # type: ignore # noqa F81
return thing
- def greenlet_spawn(fn, *args, **kw): # noqa F81
+ def greenlet_spawn(fn, *args, **kw): # type: ignore # noqa F81
_not_implemented()
- def AsyncAdaptedLock(*args, **kw): # noqa F81
+ def AsyncAdaptedLock(*args, **kw): # type: ignore # noqa F81
_not_implemented()
- def _util_async_run(fn, *arg, **kw): # noqa F81
+ def _util_async_run(fn, *arg, **kw): # type: ignore # noqa F81
return fn(*arg, **kw)
- def _util_async_run_coroutine_function(fn, *arg, **kw): # noqa F81
+ def _util_async_run_coroutine_function(fn, *arg, **kw): # type: ignore # noqa F81
_not_implemented()
import re
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Optional
from typing import TypeVar
from . import compat
def deprecated_property(
- version,
- message=None,
- add_deprecation_to_docstring=True,
- warning=None,
- enable_warnings=True,
+ version: str,
+ message: Optional[str] = None,
+ add_deprecation_to_docstring: bool = True,
+ warning: Optional[str] = None,
+ enable_warnings: bool = True,
) -> Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]]:
"""the @deprecated decorator with a @property.
great! now it is.
"""
- return lambda fn: property(
- deprecated(
- version,
- message=message,
- add_deprecation_to_docstring=add_deprecation_to_docstring,
- warning=warning,
- enable_warnings=enable_warnings,
- )(fn)
+ return cast(
+ Callable[[Callable[..., _T]], ReadOnlyInstanceDescriptor[_T]],
+ lambda fn: property(
+ deprecated(
+ version,
+ message=message,
+ add_deprecation_to_docstring=add_deprecation_to_docstring,
+ warning=warning,
+ enable_warnings=enable_warnings,
+ )(fn)
+ ),
)
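
For context, the cast() above is needed because property() satisfies the read-only descriptor protocol at runtime but is not statically recognized as the declared ReadOnlyInstanceDescriptor type. A reduced sketch of that idea, assuming Python 3.8+ for typing.Protocol (the names below are illustrative, not the SQLAlchemy helpers):

from typing import Any, Callable, Protocol, TypeVar, cast

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)


class ReadOnlyDescriptor(Protocol[_T_co]):
    def __get__(self, instance: object, owner: Any) -> _T_co:
        ...


def readonly_prop(fn: Callable[..., _T]) -> "ReadOnlyDescriptor[_T]":
    # property() fulfills the protocol at runtime, but the type checker
    # doesn't know that, so cast() bridges declared and concrete types
    return cast("ReadOnlyDescriptor[_T]", property(fn))
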
)
doc = inject_docstring_text(doc, docstring_header, 1)
+ constructor_fn = None
if type(cls) is type:
clsdict = dict(cls.__dict__)
clsdict["__doc__"] = doc
clsdict.pop("__dict__", None)
- cls = type(cls.__name__, cls.__bases__, clsdict)
+ cls = type(cls.__name__, cls.__bases__, clsdict) # type: ignore
if constructor is not None:
constructor_fn = clsdict[constructor]
constructor_fn = getattr(cls, constructor)
if constructor is not None:
+ assert constructor_fn is not None
setattr(
cls,
constructor,
import re
import sys
import textwrap
+import threading
import types
import typing
from typing import Any
from typing import Callable
+from typing import cast
+from typing import Dict
+from typing import FrozenSet
from typing import Generic
+from typing import Iterator
+from typing import List
from typing import Optional
from typing import overload
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import Type
from typing import TypeVar
from typing import Union
import warnings
from . import _collections
from . import compat
from . import typing as compat_typing
+from ._has_cy import HAS_CYEXTENSION
from .. import exc
_T = TypeVar("_T")
_HM = TypeVar("_HM", bound="hybridmethod")
-def md5_hex(x):
+def md5_hex(x: Any) -> str:
x = x.encode("utf-8")
m = hashlib.md5()
m.update(x)
__slots__ = ("warn_only", "_exc_info")
- def __init__(self, warn_only=False):
+ _exc_info: Union[
+ None,
+ Tuple[
+ Type[BaseException],
+ BaseException,
+ types.TracebackType,
+ ],
+ Tuple[None, None, None],
+ ]
+
+ def __init__(self, warn_only: bool = False):
self.warn_only = warn_only
- def __enter__(self):
+ def __enter__(self) -> None:
self._exc_info = sys.exc_info()
- def __exit__(self, type_, value, traceback):
+ def __exit__(
+ self,
+ type_: Optional[Type[BaseException]],
+ value: Optional[BaseException],
+ traceback: Optional[types.TracebackType],
+ ) -> None:
+ assert self._exc_info is not None
# see #2703 for notes
if type_ is None:
exc_type, exc_value, exc_tb = self._exc_info
+ assert exc_value is not None
self._exc_info = None # remove potential circular references
if not self.warn_only:
raise exc_value.with_traceback(exc_tb)
else:
self._exc_info = None # remove potential circular references
+ assert value is not None
raise value.with_traceback(traceback)
-def walk_subclasses(cls):
- seen = set()
+def walk_subclasses(cls: type) -> Iterator[type]:
+ seen: Set[Any] = set()
stack = [cls]
while stack:
yield cls
-def string_or_unprintable(element):
+def string_or_unprintable(element: Any) -> str:
if isinstance(element, str):
return element
else:
return "unprintable element %r" % element
-def clsname_as_plain_name(cls):
+def clsname_as_plain_name(cls: Type[Any]) -> str:
return " ".join(
n.lower() for n in re.findall(r"([A-Z][a-z]+)", cls.__name__)
)
-def method_is_overridden(instance_or_cls, against_method):
+def method_is_overridden(
+ instance_or_cls: Union[Type[Any], object], against_method: types.MethodType
+) -> bool:
"""Return True if the two class methods don't match."""
if not isinstance(instance_or_cls, type):
method_name = against_method.__name__
- current_method = getattr(current_cls, method_name)
+ current_method: types.MethodType = getattr(current_cls, method_name)
return current_method != against_method
-def decode_slice(slc):
+def decode_slice(slc: slice) -> Tuple[Any, ...]:
"""decode a slice object as sent to __getitem__.
takes into account the 2.5 __index__() method, basically.
"""
- ret = []
+ ret: List[Any] = []
for x in slc.start, slc.stop, slc.step:
if hasattr(x, "__index__"):
x = x.__index__()
return tuple(ret)
-def _unique_symbols(used, *bases):
- used = set(used)
+def _unique_symbols(used: Sequence[str], *bases: str) -> Iterator[str]:
+ used_set = set(used)
for base in bases:
pool = itertools.chain(
(base,),
map(lambda i: base + str(i), range(1000)),
)
for sym in pool:
- if sym not in used:
- used.add(sym)
+ if sym not in used_set:
+ used_set.add(sym)
yield sym
break
else:
raise NameError("exhausted namespace for symbol base %s" % base)
-def map_bits(fn, n):
+def map_bits(fn: Callable[[int], Any], n: int) -> Iterator[Any]:
"""Call the given function given each nonzero bit from n."""
while n:
n ^= b
-_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Fn = typing.TypeVar("_Fn", bound=typing.Callable[..., Any])
_Args = compat_typing.ParamSpec("_Args")
def decorator(
- target: typing.Callable[compat_typing.Concatenate[_Fn, _Args], typing.Any]
+ target: typing.Callable[ # type: ignore
+ compat_typing.Concatenate[_Fn, _Args], typing.Any
+ ]
) -> _Fn:
"""A signature-matching decorator factory."""
- def decorate(fn):
+ def decorate(fn: typing.Callable[..., Any]) -> typing.Callable[..., Any]:
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
raise Exception("not a decoratable function")
spec = compat.inspect_getfullargspec(fn)
- env = {}
+ env: Dict[str, Any] = {}
spec = _update_argspec_defaults_into_env(spec, env)
- names = tuple(spec[0]) + spec[1:3] + (fn.__name__,)
+ names = (
+ tuple(cast("Tuple[str, ...]", spec[0]))
+ + cast("Tuple[str, ...]", spec[1:3])
+ + (fn.__name__,)
+ )
targ_name, fn_name = _unique_symbols(names, "target", "fn")
- metadata = dict(target=targ_name, fn=fn_name)
+ metadata: Dict[str, Optional[str]] = dict(target=targ_name, fn=fn_name)
metadata.update(format_argspec_plus(spec, grouped=False))
metadata["name"] = fn.__name__
code = (
)
env.update({targ_name: target, fn_name: fn, "__name__": fn.__module__})
- decorated = _exec_code_in_env(code, env, fn.__name__)
+ decorated = cast(
+ types.FunctionType,
+ _exec_code_in_env(code, env, fn.__name__),
+ )
decorated.__defaults__ = getattr(fn, "__func__", fn).__defaults__
- decorated.__wrapped__ = fn
+
+ # claims to be fixed?
+ # https://github.com/python/mypy/issues/11896
+ decorated.__wrapped__ = fn # type: ignore
return update_wrapper(decorated, fn)
return typing.cast(_Fn, update_wrapper(decorate, target))
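
The decorator() helper above generates code so the wrapper keeps the exact runtime argument spec of the wrapped function; for comparison, a plain functools-based decorator typed with ParamSpec preserves the signature for the type checker only. A minimal sketch, assuming typing_extensions is installed (mirroring the fallback used elsewhere in this change); traced() is a hypothetical example, not part of this patch:

from functools import wraps
from typing import Callable, TypeVar

from typing_extensions import ParamSpec  # available in typing on 3.10+

_P = ParamSpec("_P")
_R = TypeVar("_R")


def traced(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    # the wrapper advertises the same parameters and return type as fn
    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        print("calling %s" % fn.__name__)
        return fn(*args, **kwargs)

    return wrapper
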
)
-def get_cls_kwargs(cls, _set=None):
+@overload
+def get_cls_kwargs(
+ cls: type,
+ *,
+ _set: Optional[Set[str]] = None,
+ raiseerr: compat_typing.Literal[True] = ...,
+) -> Set[str]:
+ ...
+
+
+@overload
+def get_cls_kwargs(
+ cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
+) -> Optional[Set[str]]:
+ ...
+
+
+def get_cls_kwargs(
+ cls: type, *, _set: Optional[Set[str]] = None, raiseerr: bool = False
+) -> Optional[Set[str]]:
r"""Return the full set of inherited kwargs for the given `cls`.
Probes a class's __init__ method, collecting all named arguments. If the
toplevel = _set is None
if toplevel:
_set = set()
+ assert _set is not None
ctr = cls.__dict__.get("__init__", False)
_set.update(names)
if not has_kw and not toplevel:
- return None
+ if raiseerr:
+ raise TypeError(
+ f"given cls {cls} doesn't have an __init__ method"
+ )
+ else:
+ return None
+ else:
+ has_kw = False
if not has_init or has_kw:
for c in cls.__bases__:
- if get_cls_kwargs(c, _set) is None:
+ if get_cls_kwargs(c, _set=_set) is None:
break
_set.discard("self")
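
The paired @overload declarations above let callers that pass raiseerr=True skip a None check, because the Literal[True] overload drops the Optional from the return type. A standalone sketch of the same pattern, assuming Python 3.8+ for typing.Literal (find_names() is a hypothetical function, not SQLAlchemy API):

from typing import Literal, Optional, Set, overload


@overload
def find_names(cls: type, *, raiseerr: Literal[True] = ...) -> Set[str]:
    ...


@overload
def find_names(cls: type, *, raiseerr: bool = False) -> Optional[Set[str]]:
    ...


def find_names(cls: type, *, raiseerr: bool = False) -> Optional[Set[str]]:
    init = cls.__dict__.get("__init__")
    if init is None or not hasattr(init, "__code__"):
        if raiseerr:
            raise TypeError(f"{cls} does not define its own __init__")
        return None
    code = init.__code__
    # positional argument names, minus "self"
    return set(code.co_varnames[1 : code.co_argcount])
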
raise TypeError("Can't inspect callable: %s" % fn)
-def format_argspec_plus(fn, grouped=True):
+def format_argspec_plus(
+ fn: Union[Callable[..., Any], compat.FullArgSpec], grouped: bool = True
+) -> Dict[str, Optional[str]]:
"""Returns a dictionary of formatted, introspected function arguments.
    An enhanced variant of inspect.formatargspec to support code generation.
num_defaults = 0
if spec[3]:
- num_defaults += len(spec[3])
+ num_defaults += len(cast(Tuple[Any], spec[3]))
if spec[4]:
num_defaults += len(spec[4])
+
name_args = spec[0] + spec[4]
+ defaulted_vals: Union[List[str], Tuple[()]]
+
if num_defaults:
defaulted_vals = name_args[0 - num_defaults :]
else:
spec[1],
spec[2],
defaulted_vals,
- formatvalue=lambda x: "=" + x,
+ formatvalue=lambda x: "=" + str(x),
)
if spec[0]:
spec[1],
spec[2],
defaulted_vals,
- formatvalue=lambda x: "=" + x,
+ formatvalue=lambda x: "=" + str(x),
)
else:
apply_kw_proxied = apply_kw
def decorate(cls):
def instrument(name, clslevel=False):
- fn = getattr(target_cls, name)
+ fn = cast(Callable[..., Any], getattr(target_cls, name))
spec = compat.inspect_getfullargspec(fn)
env = {"__name__": fn.__module__}
% metadata
)
- proxy_fn = _exec_code_in_env(code, env, fn.__name__)
+ proxy_fn = cast(
+ Callable[..., Any], _exec_code_in_env(code, env, fn.__name__)
+ )
proxy_fn.__defaults__ = getattr(fn, "__func__", fn).__defaults__
proxy_fn.__doc__ = inject_docstring_text(
fn.__doc__,
except TypeError:
continue
else:
- default_len = spec.defaults and len(spec.defaults) or 0
+ default_len = len(spec.defaults) if spec.defaults else 0
if i == 0:
if spec.varargs:
vargs = spec.varargs
)
if default_len:
+ assert spec.defaults
kw_args.update(
[
(arg, default)
class_hierarchy(class A(object)) returns (A, object), not A plus every
class systemwide that derives from object.
- Old-style classes are discarded and hierarchies rooted on them
- will not be descended.
-
"""
hier = {cls}
if c.__module__ == "builtins" or not hasattr(c, "__subclasses__"):
continue
- for s in [_ for _ in c.__subclasses__() if _ not in hier]:
+ for s in [
+ _
+ for _ in (
+ c.__subclasses__()
+ if not issubclass(c, type)
+ else c.__subclasses__(c)
+ )
+ if _ not in hier
+ ]:
process.append(s)
hier.add(s)
return list(hier)
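
The issubclass(c, type) branch above is needed because, for a class that is itself a metaclass, __subclasses__ is found on the class's own MRO as an unbound descriptor rather than bound through the metaclass, so it has to be handed the class explicitly. A small demonstration (Meta and Widget are throwaway names):

class Meta(type):
    pass


class Widget(metaclass=Meta):
    pass


print(Widget.__subclasses__())  # [] - bound normally via type(Widget)

try:
    Meta.__subclasses__()  # unbound descriptor found on Meta's own MRO
except TypeError as err:
    print("metaclass needs the explicit form:", err)

print(Meta.__subclasses__(Meta))  # [] - the explicit form used above
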
for method in dunders:
try:
- fn = getattr(from_cls, method)
- if not hasattr(fn, "__call__"):
+ maybe_fn = getattr(from_cls, method)
+ if not hasattr(maybe_fn, "__call__"):
continue
- fn = getattr(fn, "__func__", fn)
+ maybe_fn = getattr(maybe_fn, "__func__", maybe_fn)
+ fn = cast(Callable[..., Any], maybe_fn)
+
except AttributeError:
continue
try:
__name__: str
def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
- self.fget = fget
+ self.fget = fget # type: ignore[assignment]
self.__doc__ = doc or fget.__doc__
self.__name__ = fget.__name__
if obj is None:
return self
obj.__dict__[self.__name__] = result = self.fget(obj)
- return result
+ return result # type: ignore
def _reset(self, obj):
memoized_property.reset(obj, self.__name__)
__slots__ = ()
- _memoized_keys = frozenset()
+ _memoized_keys: FrozenSet[str] = frozenset()
def _reset_memoizations(self):
for elem in self._memoized_keys:
__name__: str
def __init__(self, fget: Callable[..., _T], doc: Optional[str] = None):
- self.fget = fget
+ # https://github.com/python/mypy/issues/708
+ self.fget = fget # type: ignore
self.__doc__ = doc or fget.__doc__
self.__name__ = fget.__name__
def counter():
"""Return a threadsafe counter function."""
- lock = compat.threading.Lock()
+ lock = threading.Lock()
counter = itertools.count(1)
# avoid the 2to3 "next" transformation...
"""
- def __init__(self, fget, *arg, **kw):
+ fget: Callable[[Any], Any]
+
+ def __init__(self, fget: Callable[[Any], Any], *arg: Any, **kw: Any):
super(classproperty, self).__init__(fget, *arg, **kw)
self.__doc__ = fget.__doc__
- def __get__(desc, self, cls):
- return desc.fget(cls)
+ def __get__(self, obj: Any, cls: Optional[type] = None) -> Any:
+ return self.fget(cls) # type: ignore
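
The corrected __get__ signature above keeps classproperty usable as an ordinary descriptor; a quick illustration of the read-access behavior it provides (a reduced re-implementation with a hypothetical Config class, for illustration only):

class classproperty(property):
    def __get__(self, obj, cls=None):
        return self.fget(cls)


class Config:
    @classproperty
    def table_prefix(cls):
        return cls.__name__.lower() + "_"


print(Config.table_prefix)    # "config_" - evaluated against the class
print(Config().table_prefix)  # "config_" - same result via an instance
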
class hybridproperty:
class _symbol(int):
- def __new__(self, name, doc=None, canonical=None):
+ name: str
+
+ def __new__(cls, name, doc=None, canonical=None):
"""Construct a new named symbol."""
assert isinstance(name, str)
if canonical is None:
"""
- symbols = {}
- _lock = compat.threading.Lock()
+ symbols: Dict[str, "_symbol"] = {}
+ _lock = threading.Lock()
def __new__(cls, name, doc=None, canonical=None):
with cls._lock:
"""
+ _hash: int
+
def __new__(cls, value, num, args):
interpolated = (value % args) + (
" (this warning may be suppressed after %d occurrences)" % num
super().__init_subclass__()
@classmethod
- def _wrap_w_kw(cls, fn):
- def wrap(*arg, **kw):
+ def _wrap_w_kw(cls, fn: Callable[..., Any]) -> Callable[..., Any]:
+ def wrap(*arg: Any, **kw: Any) -> Any:
return fn(*arg)
return update_wrapper(wrap, fn)
def has_compiled_ext(raise_=False):
- try:
- from sqlalchemy.cyextension import collections # noqa F401
- from sqlalchemy.cyextension import immutabledict # noqa F401
- from sqlalchemy.cyextension import processors # noqa F401
- from sqlalchemy.cyextension import resultproxy # noqa F401
- from sqlalchemy.cyextension import util # noqa F401
-
+ if HAS_CYEXTENSION:
return True
- except ImportError:
- if raise_:
- raise
+ elif raise_:
+ raise ImportError(
+ "cython extensions were expected to be installed, "
+ "but are not present"
+ )
+ else:
return False
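
A typical call site for the reworked check above, either to pick between compiled and pure-Python paths or to fail fast when the extensions are required (the surrounding code is illustrative and assumes a SQLAlchemy build that exposes this helper):

from sqlalchemy.util import has_compiled_ext

if has_compiled_ext():
    print("cython extensions are active")
else:
    print("running on the pure-Python implementations")

# strict use: raise ImportError immediately if the extensions are absent
has_compiled_ext(raise_=True)
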
condition.
"""
-
+import asyncio
from collections import deque
+import threading
from time import time as _time
-from .compat import threading
-from .concurrency import asyncio
from .concurrency import await_fallback
from .concurrency import await_only
from .langhelpers import memoized_property
+import typing
from typing import Any
from typing import Callable # noqa
from typing import Generic
_T = TypeVar("_T", bound=Any)
-if compat.py38:
- from typing import Literal
- from typing import Protocol
- from typing import TypedDict
+if typing.TYPE_CHECKING or not compat.py38:
+ from typing_extensions import Literal # noqa F401
+ from typing_extensions import Protocol # noqa F401
+ from typing_extensions import TypedDict # noqa F401
else:
- from typing_extensions import Literal # noqa
- from typing_extensions import Protocol # noqa
- from typing_extensions import TypedDict # noqa
+ from typing import Literal # noqa F401
+ from typing import Protocol # noqa F401
+ from typing import TypedDict # noqa F401
-if compat.py310:
- from typing import Concatenate
- from typing import ParamSpec
+if typing.TYPE_CHECKING or not compat.py310:
+ from typing_extensions import Concatenate # noqa F401
+ from typing_extensions import ParamSpec # noqa F401
else:
- from typing_extensions import Concatenate # noqa
- from typing_extensions import ParamSpec # noqa
+ from typing import Concatenate # noqa F401
+ from typing import ParamSpec # noqa F401
-_T = TypeVar("_T")
+class _TypeToInstance(Generic[_T]):
+ """describe a variable that moves between a class and an instance of
+ that class.
+ """
-class _TypeToInstance(Generic[_T]):
@overload
def __get__(self, instance: None, owner: Any) -> Type[_T]:
...
def __get__(self, instance: object, owner: Any) -> _T:
...
+ def __get__(self, instance: object, owner: Any) -> Union[Type[_T], _T]:
+ ...
+
@overload
def __set__(self, instance: None, value: Type[_T]) -> None:
...
def __set__(self, instance: object, value: _T) -> None:
...
+ def __set__(self, instance: object, value: Union[Type[_T], _T]) -> None:
+ ...
+
class ReadOnlyInstanceDescriptor(Protocol[_T]):
"""protocol representing an instance-only descriptor"""
"error::DeprecationWarning:test",
"error::DeprecationWarning:sqlalchemy"
]
+
+
+[tool.pyright]
+include = [
+ "lib/sqlalchemy/events.py",
+ "lib/sqlalchemy/exc.py",
+ "lib/sqlalchemy/log.py",
+ "lib/sqlalchemy/inspection.py",
+ "lib/sqlalchemy/schema.py",
+ "lib/sqlalchemy/types.py",
+ "lib/sqlalchemy/util/",
+]
+
+
+
+[tool.mypy]
+mypy_path = "./lib/"
+show_error_codes = true
+strict = false
+incremental = true
+
+# https://github.com/python/mypy/issues/8754
+# we are a pep-561 package, so implicit re-export should be
+# enabled
+implicit_reexport = true
+
+# individual packages or even modules should be listed here
+# with strictness-specificity set up. there's no way we are going to get
+# the whole library 100% strictly typed, so we have to tune this based on
+# the type of module or package we are dealing with
+
+# disabled checking
+[[tool.mypy.overrides]]
+module="sqlalchemy.*"
+ignore_errors = true
+warn_unused_ignores = false
+
+# strict checking
+[[tool.mypy.overrides]]
+module = [
+ "sqlalchemy.events",
+ "sqlalchemy.events",
+ "sqlalchemy.exc",
+ "sqlalchemy.inspection",
+ "sqlalchemy.schema",
+ "sqlalchemy.types",
+]
+ignore_errors = false
+strict = true
+
+# partial checking, internals can be untyped
+[[tool.mypy.overrides]]
+module="sqlalchemy.util.*"
+ignore_errors = false
+
+# util is for internal use so we can get by without everything
+# being typed
+allow_untyped_defs = true
+check_untyped_defs = false
+allow_untyped_calls = true
+
+
lib/sqlalchemy/types.py:F401
lib/sqlalchemy/sql/expression.py:F401
-[mypy]
-mypy_path = ./lib/
-strict = True
-incremental = True
-#plugins = sqlalchemy.ext.mypy.plugin
-
-[mypy-sqlalchemy.*]
-ignore_errors = True
-
-[mypy-sqlalchemy.ext.mypy.*]
-ignore_errors = False
-
[sqla_testing]
requirement_cls = test.requirements:DefaultRequirements
profile_file = test/profiles.txt
eq_(o.difference(iter([3, 4])), util.OrderedSet([2, 5]))
eq_(o.intersection(iter([3, 4, 6])), util.OrderedSet([3, 4]))
- eq_(o.union(iter([3, 4, 6])), util.OrderedSet([2, 3, 4, 5, 6]))
+ eq_(o.union(iter([3, 4, 6])), util.OrderedSet([3, 2, 4, 5, 6]))
+
+ def test_len(self):
+ eq_(len(util.OrderedSet([1, 2, 3])), 3)
+
+ def test_eq_no_insert_order(self):
+ eq_(util.OrderedSet([3, 2, 4, 5]), util.OrderedSet([2, 3, 4, 5]))
+
+ ne_(util.OrderedSet([3, 2, 4, 5]), util.OrderedSet([3, 2, 4, 5, 6]))
+
+ def test_eq_non_ordered_set(self):
+ eq_(util.OrderedSet([3, 2, 4, 5]), {2, 3, 4, 5})
+
+ ne_(util.OrderedSet([3, 2, 4, 5]), {3, 2, 4, 5, 6})
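
The new tests pin down the revised semantics: iteration still follows insertion order, while length, membership and equality are delegated to the underlying set, so comparisons against plain sets and differently-ordered OrderedSets succeed. For example (assuming a SQLAlchemy build with this change):

from sqlalchemy.util import OrderedSet

o = OrderedSet([3, 2, 4, 5])
print(list(o))            # [3, 2, 4, 5] - insertion order preserved
print(len(o))             # 4
print(2 in o)             # True
print(o == {2, 3, 4, 5})  # True - equality ignores ordering
print(o == OrderedSet([2, 3, 4, 5]))  # also True
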
def test_repr(self):
o = util.OrderedSet([])
lambda: d.update({2: 4}),
)
if hasattr(d, "pop"):
- calls += (d.pop, d.popitem)
+ calls += (lambda: d.pop(2), d.popitem)
for m in calls:
with expect_raises_message(TypeError, "object is immutable"):
m()
cls_ = testing.db.dialect.__class__
class SomeDialect(cls_):
+ supports_statement_cache = True
+
def initialize(self, connection):
super(SomeDialect, self).initialize(connection)
m1.append("initialize")
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
from sqlalchemy.testing.schema import Table
-from sqlalchemy.util import has_compiled_ext
-from sqlalchemy.util import OrderedSet
from test.orm import _fixtures
-if has_compiled_ext():
- # cython ordered set is immutable, subclass it with a python
- # class so that its method can be replaced
- _OrderedSet = OrderedSet
-
- class OrderedSet(_OrderedSet):
- pass
-
-
class MergeTest(_fixtures.FixtureTest):
"""Session.merge() functionality"""
users,
properties={
"addresses": relationship(
- Address, backref="user", collection_class=OrderedSet
+ Address, backref="user", collection_class=set
)
},
)
u = User(
id=7,
name="fred",
- addresses=OrderedSet(
- [
- Address(id=1, email_address="fred1"),
- Address(id=2, email_address="fred2"),
- ]
- ),
+ addresses={
+ Address(id=1, email_address="fred1"),
+ Address(id=2, email_address="fred2"),
+ },
)
eq_(load.called, 0)
User(
id=7,
name="fred",
- addresses=OrderedSet(
- [
- Address(id=1, email_address="fred1"),
- Address(id=2, email_address="fred2"),
- ]
- ),
+ addresses={
+ Address(id=1, email_address="fred1"),
+ Address(id=2, email_address="fred2"),
+ },
),
)
users,
properties={
"addresses": relationship(
- Address, backref="user", collection_class=OrderedSet
+ Address, backref="user", collection_class=set
)
},
)
u = User(
id=None,
name="fred",
- addresses=OrderedSet(
- [
- Address(id=None, email_address="fred1"),
- Address(id=None, email_address="fred2"),
- ]
- ),
+ addresses={
+ Address(id=None, email_address="fred1"),
+ Address(id=None, email_address="fred2"),
+ },
)
eq_(load.called, 0)
sess.query(User).one(),
User(
name="fred",
- addresses=OrderedSet(
- [
- Address(email_address="fred1"),
- Address(email_address="fred2"),
- ]
- ),
+ addresses={
+ Address(email_address="fred1"),
+ Address(email_address="fred2"),
+ },
),
)
"addresses": relationship(
Address,
backref="user",
- collection_class=OrderedSet,
- order_by=addresses.c.id,
+ collection_class=set,
cascade="all, delete-orphan",
)
},
u = User(
id=7,
name="fred",
- addresses=OrderedSet(
- [
- Address(id=1, email_address="fred1"),
- Address(id=2, email_address="fred2"),
- ]
- ),
+ addresses={
+ Address(id=1, email_address="fred1"),
+ Address(id=2, email_address="fred2"),
+ },
)
sess = fixture_session()
sess.add(u)
u = User(
id=7,
name="fred",
- addresses=OrderedSet(
- [
- Address(id=3, email_address="fred3"),
- Address(id=4, email_address="fred4"),
- ]
- ),
+ addresses={
+ Address(id=3, email_address="fred3"),
+ Address(id=4, email_address="fred4"),
+ },
)
u = sess.merge(u)
User(
id=7,
name="fred",
- addresses=OrderedSet(
- [
- Address(id=3, email_address="fred3"),
- Address(id=4, email_address="fred4"),
- ]
- ),
+ addresses={
+ Address(id=3, email_address="fred3"),
+ Address(id=4, email_address="fred4"),
+ },
),
)
sess.flush()
User(
id=7,
name="fred",
- addresses=OrderedSet(
- [
- Address(id=3, email_address="fred3"),
- Address(id=4, email_address="fred4"),
- ]
- ),
+ addresses={
+ Address(id=3, email_address="fred3"),
+ Address(id=4, email_address="fred4"),
+ },
),
)
Address,
backref="user",
order_by=addresses.c.id,
- collection_class=OrderedSet,
+ collection_class=set,
)
},
)
u = User(
id=7,
name="fred",
- addresses=OrderedSet([a, Address(id=2, email_address="fred2")]),
+ addresses={a, Address(id=2, email_address="fred2")},
)
sess = fixture_session()
sess.add(u)
User(
id=7,
name="fred jones",
- addresses=OrderedSet(
- [
- Address(id=2, email_address="fred2"),
- Address(id=3, email_address="fred3"),
- ]
- ),
+ addresses={
+ Address(id=2, email_address="fred2"),
+ Address(id=3, email_address="fred3"),
+ },
),
)
oracle,mssql,sqlite_file: python reap_dbs.py db_idents.txt
+[testenv:pep484]
+deps=
+ greenlet != 0.4.17
+ importlib_metadata; python_version < '3.8'
+ mypy
+ pyright
+commands =
+ mypy ./lib/sqlalchemy
+ pyright
+
[testenv:mypy]
deps=
pytest>=7.0.0rc1,<8
flake8 ./lib/ ./test/ ./examples/ setup.py doc/build/conf.py {posargs}
black --check ./lib/ ./test/ ./examples/ setup.py doc/build/conf.py
+
# command run in the github action when cext are active.
[testenv:github-cext]
deps = {[testenv]deps}