]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Properly type _generative, decorator, public_factory
authorFederico Caselli <cfederico87@gmail.com>
Wed, 22 Dec 2021 20:45:45 +0000 (21:45 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 30 Dec 2021 23:07:26 +0000 (18:07 -0500)
Good new is that pylance likes it and copies over the
singature and everything.
Bad news is that mypy does not support this yet https://github.com/python/mypy/issues/8645
Other minor bad news is that non_generative is not typed. I've tried using a protocol
like the one in the comment but the signature is not ported over by pylance, so it's
probably best to just live without it to have the correct signature.

notes from mike:  these three decorators are at the core of getting
the library to be typed, more good news is that pylance will
do all the things we like re: public_factory, see
https://github.com/microsoft/pyright/issues/2758#issuecomment-1002788656
.

For @_generative, we will likely move to using pep 673 once mypy
supports it which may be soon.  but overall having the explicit
"return self" in the methods, while a little inconvenient, makes
the typing more straightforward and locally present in the files
rather than being decided at a distance.   having "return self"
present, or not, both have problems, so maybe we will be able
to change it again if things change as far as decorator support.
As it is, I feel like we are barely squeaking by with our decorators,
the typing is already pretty out there.

Change-Id: Ic77e13fc861def76a5925331df85c0aa48d77807
References: #6810

21 files changed:
MANIFEST.in
lib/sqlalchemy/dialects/mysql/dml.py
lib/sqlalchemy/dialects/mysql/expression.py
lib/sqlalchemy/dialects/postgresql/dml.py
lib/sqlalchemy/dialects/sqlite/dml.py
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/orm/base.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/strategy_options.py
lib/sqlalchemy/py.typed [new file with mode: 0644]
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/ddl.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/util/compat.py
lib/sqlalchemy/util/langhelpers.py
lib/sqlalchemy/util/typing.py
setup.cfg

index 0a2c923f1d7ddac089c22fe8e2f9818e3f8a9db9..0cb613385140a78154404fe49bcc57002cea0a24 100644 (file)
@@ -7,7 +7,7 @@ recursive-include test *.py *.dat *.testpatch
 
 # include the pyx and pxd extensions, which otherwise
 # don't come in if --with-cextensions isn't specified.
-recursive-include lib *.pyx *.pxd *.txt
+recursive-include lib *.pyx *.pxd *.txt *.typed
 
 include README* AUTHORS LICENSE CHANGES* tox.ini
 prune doc/build/output
index 790733cbfdab68b44e4cda0ec3e6b492d9cf4166..af3df09226505bf930fb2d8edff4c8cbc470f406 100644 (file)
@@ -1,3 +1,11 @@
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import typing
+
 from ... import exc
 from ... import util
 from ...sql.base import _exclusive_against
@@ -12,6 +20,9 @@ from ...util.langhelpers import public_factory
 __all__ = ("Insert", "insert")
 
 
+SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
+
+
 class Insert(StandardInsert):
     """MySQL-specific implementation of INSERT.
 
@@ -70,7 +81,7 @@ class Insert(StandardInsert):
             "has an ON DUPLICATE KEY clause present"
         },
     )
-    def on_duplicate_key_update(self, *args, **kw):
+    def on_duplicate_key_update(self: SelfInsert, *args, **kw) -> SelfInsert:
         r"""
         Specifies the ON DUPLICATE KEY UPDATE clause.
 
@@ -131,6 +142,7 @@ class Insert(StandardInsert):
 
         inserted_alias = getattr(self, "inserted_alias", None)
         self._post_values_clause = OnDuplicateClause(inserted_alias, values)
+        return self
 
 
 insert = public_factory(
index 7a66e9b1428c2dd75aa7ceaad0e1d4060bea796e..77b985b6ab2eea9152195590ecfcde57402dec94 100644 (file)
@@ -1,3 +1,11 @@
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+
+import typing
+
 from ... import exc
 from ... import util
 from ...sql import coercions
@@ -8,6 +16,9 @@ from ...sql.base import _generative
 from ...sql.base import Generative
 
 
+Selfmatch = typing.TypeVar("Selfmatch", bound="match")
+
+
 class match(Generative, elements.BinaryExpression):
     """Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
 
@@ -99,7 +110,7 @@ class match(Generative, elements.BinaryExpression):
         )
 
     @_generative
-    def in_boolean_mode(self):
+    def in_boolean_mode(self: Selfmatch) -> Selfmatch:
         """Apply the "IN BOOLEAN MODE" modifier to the MATCH expression.
 
         :return: a new :class:`_mysql.match` instance with modifications
@@ -107,9 +118,10 @@ class match(Generative, elements.BinaryExpression):
         """
 
         self.modifiers = self.modifiers.union({"mysql_boolean_mode": True})
+        return self
 
     @_generative
-    def in_natural_language_mode(self):
+    def in_natural_language_mode(self: Selfmatch) -> Selfmatch:
         """Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH
         expression.
 
@@ -118,9 +130,10 @@ class match(Generative, elements.BinaryExpression):
         """
 
         self.modifiers = self.modifiers.union({"mysql_natural_language": True})
+        return self
 
     @_generative
-    def with_query_expansion(self):
+    def with_query_expansion(self: Selfmatch) -> Selfmatch:
         """Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression.
 
         :return: a new :class:`_mysql.match` instance with modifications
@@ -128,3 +141,4 @@ class match(Generative, elements.BinaryExpression):
         """
 
         self.modifiers = self.modifiers.union({"mysql_query_expansion": True})
+        return self
index 4451639f383ceec2f378ec092fc99e5778c0e883..aa21bd8c01c39c0c33d5880282ab4bd9d3a72d28 100644 (file)
@@ -4,6 +4,7 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
+import typing
 
 from . import ext
 from ... import util
@@ -21,6 +22,8 @@ from ...util.langhelpers import public_factory
 
 __all__ = ("Insert", "insert")
 
+SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
+
 
 class Insert(StandardInsert):
     """PostgreSQL-specific implementation of INSERT.
@@ -75,13 +78,13 @@ class Insert(StandardInsert):
     @_generative
     @_on_conflict_exclusive
     def on_conflict_do_update(
-        self,
+        self: SelfInsert,
         constraint=None,
         index_elements=None,
         index_where=None,
         set_=None,
         where=None,
-    ):
+    ) -> SelfInsert:
         r"""
         Specifies a DO UPDATE SET action for ON CONFLICT clause.
 
@@ -138,12 +141,16 @@ class Insert(StandardInsert):
         self._post_values_clause = OnConflictDoUpdate(
             constraint, index_elements, index_where, set_, where
         )
+        return self
 
     @_generative
     @_on_conflict_exclusive
     def on_conflict_do_nothing(
-        self, constraint=None, index_elements=None, index_where=None
-    ):
+        self: SelfInsert,
+        constraint=None,
+        index_elements=None,
+        index_where=None,
+    ) -> SelfInsert:
         """
         Specifies a DO NOTHING action for ON CONFLICT clause.
 
@@ -173,6 +180,7 @@ class Insert(StandardInsert):
         self._post_values_clause = OnConflictDoNothing(
             constraint, index_elements, index_where
         )
+        return self
 
 
 insert = public_factory(
index e4d8bd9434db398ad635becdde0bb6913617e9b3..91f3b7babc07a43347a6e7d93fbb1f91ac45ec18 100644 (file)
@@ -4,6 +4,8 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+import typing
+
 from ... import util
 from ...sql import coercions
 from ...sql import roles
@@ -18,6 +20,8 @@ from ...util.langhelpers import public_factory
 
 __all__ = ("Insert", "insert")
 
+SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
+
 
 class Insert(StandardInsert):
     """SQLite-specific implementation of INSERT.
@@ -71,12 +75,12 @@ class Insert(StandardInsert):
     @_generative
     @_on_conflict_exclusive
     def on_conflict_do_update(
-        self,
+        self: SelfInsert,
         index_elements=None,
         index_where=None,
         set_=None,
         where=None,
-    ):
+    ) -> SelfInsert:
         r"""
         Specifies a DO UPDATE SET action for ON CONFLICT clause.
 
@@ -120,10 +124,13 @@ class Insert(StandardInsert):
         self._post_values_clause = OnConflictDoUpdate(
             index_elements, index_where, set_, where
         )
+        return self
 
     @_generative
     @_on_conflict_exclusive
-    def on_conflict_do_nothing(self, index_elements=None, index_where=None):
+    def on_conflict_do_nothing(
+        self: SelfInsert, index_elements=None, index_where=None
+    ) -> SelfInsert:
         """
         Specifies a DO NOTHING action for ON CONFLICT clause.
 
@@ -141,6 +148,7 @@ class Insert(StandardInsert):
         self._post_values_clause = OnConflictDoNothing(
             index_elements, index_where
         )
+        return self
 
 
 insert = public_factory(
index 794b9d2c77b150ce92a8bb34adace9e16c261649..50d854f826db6359adb4fc1a4b7d09b2e9a05876 100644 (file)
@@ -1702,6 +1702,7 @@ class CursorResult(BaseCursorResult, Result):
     def yield_per(self, num):
         self._yield_per = num
         self.cursor_strategy.yield_per(self, self.cursor, num)
+        return self
 
 
 ResultProxy = CursorResult
index 7d496838a6370eff659f6cb49c1a3fee629c660c..f8e255e92de22741ba58624683ffd7855c636bf1 100644 (file)
@@ -7,11 +7,11 @@
 
 """Define generic result set constructs."""
 
-
 import collections.abc as collections_abc
 import functools
 import itertools
 import operator
+import typing
 
 from .row import Row
 from .. import exc
@@ -257,6 +257,10 @@ def result_tuple(fields, extra=None):
 # filter is applied to rows.
 _NO_ROW = util.symbol("NO_ROW")
 
+SelfResultInternal = typing.TypeVar(
+    "SelfResultInternal", bound="ResultInternal"
+)
+
 
 class ResultInternal(InPlaceGenerative):
     _real_result = None
@@ -614,7 +618,9 @@ class ResultInternal(InPlaceGenerative):
             return row
 
     @_generative
-    def _column_slices(self, indexes):
+    def _column_slices(
+        self: SelfResultInternal, indexes
+    ) -> SelfResultInternal:
         real_result = self._real_result if self._real_result else self
 
         if real_result._source_supports_scalars and len(indexes) == 1:
@@ -623,6 +629,8 @@ class ResultInternal(InPlaceGenerative):
             self._generate_rows = True
             self._metadata = self._metadata._reduce(indexes)
 
+        return self
+
     @HasMemoized.memoized_attribute
     def _unique_strategy(self):
         uniques, strategy = self._unique_filter_state
@@ -668,6 +676,9 @@ class _WithKeys:
         return self._metadata.keys
 
 
+SelfResult = typing.TypeVar("SelfResult", bound="Result")
+
+
 class Result(_WithKeys, ResultInternal):
     """Represent a set of database results.
 
@@ -732,7 +743,7 @@ class Result(_WithKeys, ResultInternal):
         self._soft_close(hard=True)
 
     @_generative
-    def yield_per(self, num):
+    def yield_per(self: SelfResult, num) -> SelfResult:
         """Configure the row-fetching strategy to fetch num rows at a time.
 
         This impacts the underlying behavior of the result when iterating over
@@ -766,9 +777,10 @@ class Result(_WithKeys, ResultInternal):
 
         """
         self._yield_per = num
+        return self
 
     @_generative
-    def unique(self, strategy=None):
+    def unique(self: SelfResult, strategy=None) -> SelfResult:
         """Apply unique filtering to the objects returned by this
         :class:`_engine.Result`.
 
@@ -806,8 +818,11 @@ class Result(_WithKeys, ResultInternal):
 
         """
         self._unique_filter_state = (set(), strategy)
+        return self
 
-    def columns(self, *col_expressions):
+    def columns(
+        self: SelfResultInternal, *col_expressions
+    ) -> SelfResultInternal:
         r"""Establish the columns that should be returned in each row.
 
         This method may be used to limit the columns returned as well
@@ -845,7 +860,7 @@ class Result(_WithKeys, ResultInternal):
         """
         return self._column_slices(col_expressions)
 
-    def scalars(self, index=0):
+    def scalars(self, index=0) -> "ScalarResult":
         """Return a :class:`_result.ScalarResult` filtering object which
         will return single elements rather than :class:`_row.Row` objects.
 
@@ -892,7 +907,7 @@ class Result(_WithKeys, ResultInternal):
             )
         return self._metadata._row_as_tuple_getter(keys)
 
-    def mappings(self):
+    def mappings(self) -> "MappingResult":
         """Apply a mappings filter to returned rows, returning an instance of
         :class:`_result.MappingResult`.
 
@@ -1653,6 +1668,11 @@ def null_result():
     return IteratorResult(SimpleResultMetaData([]), iter([]))
 
 
+SelfChunkedIteratorResult = typing.TypeVar(
+    "SelfChunkedIteratorResult", bound="ChunkedIteratorResult"
+)
+
+
 class ChunkedIteratorResult(IteratorResult):
     """An :class:`.IteratorResult` that works from an iterator-producing callable.
 
@@ -1684,7 +1704,9 @@ class ChunkedIteratorResult(IteratorResult):
         self.dynamic_yield_per = dynamic_yield_per
 
     @_generative
-    def yield_per(self, num):
+    def yield_per(
+        self: SelfChunkedIteratorResult, num
+    ) -> SelfChunkedIteratorResult:
         # TODO: this throws away the iterator which may be holding
         # onto a chunk.   the yield_per cannot be changed once any
         # rows have been fetched.   either find a way to enforce this,
@@ -1693,6 +1715,7 @@ class ChunkedIteratorResult(IteratorResult):
 
         self._yield_per = num
         self.iterator = itertools.chain.from_iterable(self.chunks(num))
+        return self
 
     def _soft_close(self, **kw):
         super(ChunkedIteratorResult, self)._soft_close(**kw)
index c79592625a6c0fe811cea9110b176a13e7932095..31dceb0658365eb201109fb6eedfda85ce6cc978 100644 (file)
 """
 
 import operator
+import typing
 
 from . import exc
 from .. import exc as sa_exc
 from .. import inspection
 from .. import util
+from ..util import typing as compat_typing
 
 
 PASSIVE_NO_RESULT = util.symbol(
@@ -221,13 +223,25 @@ _DEFER_FOR_STATE = util.symbol("DEFER_FOR_STATE")
 _RAISE_FOR_STATE = util.symbol("RAISE_FOR_STATE")
 
 
-def _assertions(*assertions):
+_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Args = compat_typing.ParamSpec("_Args")
+_Self = typing.TypeVar("_Self")
+
+
+def _assertions(
+    *assertions,
+) -> typing.Callable[
+    [typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self]],
+    typing.Callable[compat_typing.Concatenate[_Fn, _Args], _Self],
+]:
     @util.decorator
-    def generate(fn, *args, **kw):
-        self = args[0]
+    def generate(
+        fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
+    ) -> _Self:
         for assertion in assertions:
             assertion(self, fn.__name__)
-        fn(self, *args[1:], **kw)
+        fn(self, *args, **kw)
+        return self
 
     return generate
 
index a2e247f1474ee1b53c3e1e2c4f4150e2896e8908..0e99139b4bd34f855c90aedc4675211a16782001 100644 (file)
@@ -22,6 +22,7 @@ import collections.abc as collections_abc
 import itertools
 import operator
 import types
+import typing
 
 from . import exc as orm_exc
 from . import interfaces
@@ -72,6 +73,8 @@ from ..sql.visitors import InternalTraversal
 
 __all__ = ["Query", "QueryContext", "aliased"]
 
+SelfQuery = typing.TypeVar("SelfQuery", bound="Query")
+
 
 @inspection._self_inspects
 @log.class_logger
@@ -239,8 +242,9 @@ class Query(
         self._from_obj = tuple(fa)
 
     @_generative
-    def _set_lazyload_from(self, state):
+    def _set_lazyload_from(self: SelfQuery, state) -> SelfQuery:
         self.load_options += {"_lazy_loaded_from": state}
+        return self
 
     def _get_condition(self):
         return self._no_criterion_condition(
@@ -617,7 +621,7 @@ class Query(
         )
 
     @_generative
-    def only_return_tuples(self, value):
+    def only_return_tuples(self: SelfQuery, value) -> SelfQuery:
         """When set to True, the query results will always be a tuple.
 
         This is specifically for single element queries. The default is False.
@@ -630,6 +634,7 @@ class Query(
 
         """
         self.load_options += dict(_only_return_tuples=value)
+        return self
 
     @property
     def is_single_entity(self):
@@ -658,7 +663,7 @@ class Query(
         )
 
     @_generative
-    def enable_eagerloads(self, value):
+    def enable_eagerloads(self: SelfQuery, value) -> SelfQuery:
         """Control whether or not eager joins and subqueries are
         rendered.
 
@@ -674,10 +679,12 @@ class Query(
 
         """
         self._compile_options += {"_enable_eagerloads": value}
+        return self
 
     @_generative
-    def _with_compile_options(self, **opt):
+    def _with_compile_options(self: SelfQuery, **opt) -> SelfQuery:
         self._compile_options += opt
+        return self
 
     @util.deprecated_20(
         ":meth:`_orm.Query.with_labels` and :meth:`_orm.Query.apply_labels`",
@@ -735,7 +742,7 @@ class Query(
         return self
 
     @_generative
-    def enable_assertions(self, value):
+    def enable_assertions(self: SelfQuery, value) -> SelfQuery:
         """Control whether assertions are generated.
 
         When set to False, the returned Query will
@@ -755,6 +762,7 @@ class Query(
 
         """
         self._enable_assertions = value
+        return self
 
     @property
     def whereclause(self):
@@ -770,7 +778,7 @@ class Query(
         )
 
     @_generative
-    def _with_current_path(self, path):
+    def _with_current_path(self: SelfQuery, path) -> SelfQuery:
         """indicate that this query applies to objects loaded
         within a certain path.
 
@@ -780,6 +788,7 @@ class Query(
 
         """
         self._compile_options += {"_current_path": path}
+        return self
 
     @_generative
     @_assertions(_no_clauseelement_condition)
@@ -788,8 +797,8 @@ class Query(
         alternative="Use the orm.with_polymorphic() standalone function",
     )
     def with_polymorphic(
-        self, cls_or_mappers, selectable=None, polymorphic_on=None
-    ):
+        self: SelfQuery, cls_or_mappers, selectable=None, polymorphic_on=None
+    ) -> SelfQuery:
         """Load columns for inheriting classes.
 
         This is a legacy method which is replaced by the
@@ -828,9 +837,10 @@ class Query(
         self._compile_options = self._compile_options.add_to_element(
             "_with_polymorphic_adapt_map", ((entity, inspect(wp)),)
         )
+        return self
 
     @_generative
-    def yield_per(self, count):
+    def yield_per(self: SelfQuery, count) -> SelfQuery:
         r"""Yield only ``count`` rows at a time.
 
         The purpose of this method is when fetching very large result sets
@@ -849,6 +859,7 @@ class Query(
 
         """
         self.load_options += {"_yield_per": count}
+        return self
 
     @util.deprecated_20(
         ":meth:`_orm.Query.get`",
@@ -974,7 +985,7 @@ class Query(
         return self._compile_options._current_path
 
     @_generative
-    def correlate(self, *fromclauses):
+    def correlate(self: SelfQuery, *fromclauses) -> SelfQuery:
         """Return a :class:`.Query` construct which will correlate the given
         FROM clauses to that of an enclosing :class:`.Query` or
         :func:`~.expression.select`.
@@ -1002,9 +1013,10 @@ class Query(
             self._correlate = set(self._correlate).union(
                 coercions.expect(roles.FromClauseRole, f) for f in fromclauses
             )
+        return self
 
     @_generative
-    def autoflush(self, setting):
+    def autoflush(self: SelfQuery, setting) -> SelfQuery:
         """Return a Query with a specific 'autoflush' setting.
 
         As of SQLAlchemy 1.4, the :meth:`_orm.Query.autoflush` method
@@ -1014,9 +1026,10 @@ class Query(
 
         """
         self.load_options += {"_autoflush": setting}
+        return self
 
     @_generative
-    def populate_existing(self):
+    def populate_existing(self: SelfQuery) -> SelfQuery:
         """Return a :class:`_query.Query`
         that will expire and refresh all instances
         as they are loaded, or reused from the current :class:`.Session`.
@@ -1028,9 +1041,10 @@ class Query(
 
         """
         self.load_options += {"_populate_existing": True}
+        return self
 
     @_generative
-    def _with_invoke_all_eagers(self, value):
+    def _with_invoke_all_eagers(self: SelfQuery, value) -> SelfQuery:
         """Set the 'invoke all eagers' flag which causes joined- and
         subquery loaders to traverse into already-loaded related objects
         and collections.
@@ -1039,6 +1053,7 @@ class Query(
 
         """
         self.load_options += {"_invoke_all_eagers": value}
+        return self
 
     @util.deprecated_20(
         ":meth:`_orm.Query.with_parent`",
@@ -1103,7 +1118,7 @@ class Query(
         return self.filter(with_parent(instance, property, entity_zero.entity))
 
     @_generative
-    def add_entity(self, entity, alias=None):
+    def add_entity(self: SelfQuery, entity, alias=None) -> SelfQuery:
         """add a mapped entity to the list of result columns
         to be returned."""
 
@@ -1118,9 +1133,10 @@ class Query(
                 roles.ColumnsClauseRole, entity, apply_propagate_attrs=self
             )
         )
+        return self
 
     @_generative
-    def with_session(self, session):
+    def with_session(self: SelfQuery, session) -> SelfQuery:
         """Return a :class:`_query.Query` that will use the given
         :class:`.Session`.
 
@@ -1144,6 +1160,7 @@ class Query(
         """
 
         self.session = session
+        return self
 
     @util.deprecated_20(
         ":meth:`_query.Query.from_self`",
@@ -1344,11 +1361,14 @@ class Query(
         return q
 
     @_generative
-    def _set_enable_single_crit(self, val):
+    def _set_enable_single_crit(self: SelfQuery, val) -> SelfQuery:
         self._compile_options += {"_enable_single_crit": val}
+        return self
 
     @_generative
-    def _from_selectable(self, fromclause, set_entity_from=True):
+    def _from_selectable(
+        self: SelfQuery, fromclause, set_entity_from=True
+    ) -> SelfQuery:
         for attr in (
             "_where_criteria",
             "_order_by_clauses",
@@ -1376,6 +1396,7 @@ class Query(
         # "oldstyle" tests that rely on this and the corresponding
         # "newtyle" that do not.
         self._compile_options += {"_orm_only_from_obj_alias": False}
+        return self
 
     @util.deprecated(
         "1.4",
@@ -1417,7 +1438,7 @@ class Query(
             return None
 
     @_generative
-    def with_entities(self, *entities):
+    def with_entities(self: SelfQuery, *entities) -> SelfQuery:
         r"""Return a new :class:`_query.Query`
         replacing the SELECT list with the
         given entities.
@@ -1443,9 +1464,10 @@ class Query(
         """
         _MemoizedSelectEntities._generate_for_statement(self)
         self._set_entities(entities)
+        return self
 
     @_generative
-    def add_columns(self, *column):
+    def add_columns(self: SelfQuery, *column) -> SelfQuery:
         """Add one or more column expressions to the list
         of result columns to be returned."""
 
@@ -1460,6 +1482,7 @@ class Query(
             )
             for c in column
         )
+        return self
 
     @util.deprecated(
         "1.4",
@@ -1475,7 +1498,7 @@ class Query(
         return self.add_columns(column)
 
     @_generative
-    def options(self, *args):
+    def options(self: SelfQuery, *args) -> SelfQuery:
         """Return a new :class:`_query.Query` object,
         applying the given list of
         mapper options.
@@ -1502,6 +1525,7 @@ class Query(
                     opt.process_query(self)
 
         self._with_options += opts
+        return self
 
     def with_transformation(self, fn):
         """Return a new :class:`_query.Query` object transformed by
@@ -1534,7 +1558,7 @@ class Query(
         return self._execution_options
 
     @_generative
-    def execution_options(self, **kwargs):
+    def execution_options(self: SelfQuery, **kwargs) -> SelfQuery:
         """Set non-SQL options which take effect during execution.
 
         Options allowed here include all of those accepted by
@@ -1569,16 +1593,17 @@ class Query(
 
         """
         self._execution_options = self._execution_options.union(kwargs)
+        return self
 
     @_generative
     def with_for_update(
-        self,
+        self: SelfQuery,
         read=False,
         nowait=False,
         of=None,
         skip_locked=False,
         key_share=False,
-    ):
+    ) -> SelfQuery:
         """return a new :class:`_query.Query`
         with the specified options for the
         ``FOR UPDATE`` clause.
@@ -1633,9 +1658,10 @@ class Query(
             skip_locked=skip_locked,
             key_share=key_share,
         )
+        return self
 
     @_generative
-    def params(self, *args, **kwargs):
+    def params(self: SelfQuery, *args, **kwargs) -> SelfQuery:
         r"""Add values for bind parameters which may have been
         specified in filter().
 
@@ -1653,8 +1679,9 @@ class Query(
                 "which is a dictionary."
             )
         self._params = self._params.union(kwargs)
+        return self
 
-    def where(self, *criterion):
+    def where(self: SelfQuery, *criterion) -> SelfQuery:
         """A synonym for :meth:`.Query.filter`.
 
         .. versionadded:: 1.4
@@ -1664,7 +1691,7 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def filter(self, *criterion):
+    def filter(self: SelfQuery, *criterion) -> SelfQuery:
         r"""Apply the given filtering criterion to a copy
         of this :class:`_query.Query`, using SQL expressions.
 
@@ -1702,6 +1729,7 @@ class Query(
             # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
             self._where_criteria += (criterion,)
+        return self
 
     @util.memoized_property
     def _last_joined_entity(self):
@@ -1795,7 +1823,7 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def order_by(self, *clauses):
+    def order_by(self: SelfQuery, *clauses) -> SelfQuery:
         """Apply one or more ORDER BY criteria to the query and return
         the newly resulting :class:`_query.Query`.
 
@@ -1841,10 +1869,11 @@ class Query(
             # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
             self._order_by_clauses += criterion
+        return self
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def group_by(self, *clauses):
+    def group_by(self: SelfQuery, *clauses) -> SelfQuery:
         """Apply one or more GROUP BY criterion to the query and return
         the newly resulting :class:`_query.Query`.
 
@@ -1884,10 +1913,11 @@ class Query(
             # legacy ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
             self._group_by_clauses += criterion
+        return self
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def having(self, criterion):
+    def having(self: SelfQuery, criterion) -> SelfQuery:
         r"""Apply a HAVING criterion to the query and return the
         newly resulting :class:`_query.Query`.
 
@@ -1909,6 +1939,7 @@ class Query(
                 roles.WhereHavingRole, criterion, apply_propagate_attrs=self
             ),
         )
+        return self
 
     def _set_op(self, expr_fn, *q):
         return self._from_selectable(expr_fn(*([self] + list(q))).subquery())
@@ -2005,7 +2036,7 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition, _no_limit_offset)
-    def join(self, target, *props, **kwargs):
+    def join(self: SelfQuery, target, *props, **kwargs) -> SelfQuery:
         r"""Create a SQL JOIN against this :class:`_query.Query`
         object's criterion
         and apply generatively, returning the newly resulting
@@ -2405,6 +2436,7 @@ class Query(
         self._legacy_setup_joins += joins_to_add
 
         self.__dict__.pop("_last_joined_entity", None)
+        return self
 
     def outerjoin(self, target, *props, **kwargs):
         """Create a left outer join against this ``Query`` object's criterion
@@ -2418,7 +2450,7 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition)
-    def reset_joinpoint(self):
+    def reset_joinpoint(self: SelfQuery) -> SelfQuery:
         """Return a new :class:`.Query`, where the "join point" has
         been reset back to the base FROM entities of the query.
 
@@ -2430,10 +2462,11 @@ class Query(
         """
         self._last_joined_entity = None
         self._aliased_generation = None
+        return self
 
     @_generative
     @_assertions(_no_clauseelement_condition)
-    def select_from(self, *from_obj):
+    def select_from(self: SelfQuery, *from_obj) -> SelfQuery:
         r"""Set the FROM clause of this :class:`.Query` explicitly.
 
         :meth:`.Query.select_from` is often used in conjunction with
@@ -2479,6 +2512,7 @@ class Query(
         """
 
         self._set_select_from(from_obj, False)
+        return self
 
     @util.deprecated_20(
         ":meth:`_orm.Query.select_entity_from`",
@@ -2486,7 +2520,7 @@ class Query(
     )
     @_generative
     @_assertions(_no_clauseelement_condition)
-    def select_entity_from(self, from_obj):
+    def select_entity_from(self: SelfQuery, from_obj) -> SelfQuery:
         r"""Set the FROM clause of this :class:`_query.Query` to a
         core selectable, applying it as a replacement FROM clause
         for corresponding mapped entities.
@@ -2600,6 +2634,7 @@ class Query(
 
         self._set_select_from([from_obj], True)
         self._compile_options += {"_enable_single_crit": False}
+        return self
 
     def __getitem__(self, item):
         return orm_util._getitem(
@@ -2610,7 +2645,7 @@ class Query(
 
     @_generative
     @_assertions(_no_statement_condition)
-    def slice(self, start, stop):
+    def slice(self: SelfQuery, start, stop) -> SelfQuery:
         """Computes the "slice" of the :class:`_query.Query` represented by
         the given indices and returns the resulting :class:`_query.Query`.
 
@@ -2644,28 +2679,31 @@ class Query(
         self._limit_clause, self._offset_clause = sql_util._make_slice(
             self._limit_clause, self._offset_clause, start, stop
         )
+        return self
 
     @_generative
     @_assertions(_no_statement_condition)
-    def limit(self, limit):
+    def limit(self: SelfQuery, limit) -> SelfQuery:
         """Apply a ``LIMIT`` to the query and return the newly resulting
         ``Query``.
 
         """
         self._limit_clause = sql_util._offset_or_limit_clause(limit)
+        return self
 
     @_generative
     @_assertions(_no_statement_condition)
-    def offset(self, offset):
+    def offset(self: SelfQuery, offset) -> SelfQuery:
         """Apply an ``OFFSET`` to the query and return the newly resulting
         ``Query``.
 
         """
         self._offset_clause = sql_util._offset_or_limit_clause(offset)
+        return self
 
     @_generative
     @_assertions(_no_statement_condition)
-    def distinct(self, *expr):
+    def distinct(self: SelfQuery, *expr) -> SelfQuery:
         r"""Apply a ``DISTINCT`` to the query and return the newly resulting
         ``Query``.
 
@@ -2701,6 +2739,7 @@ class Query(
             )
         else:
             self._distinct = True
+        return self
 
     def all(self):
         """Return the results represented by this :class:`_query.Query`
@@ -2722,7 +2761,7 @@ class Query(
 
     @_generative
     @_assertions(_no_clauseelement_condition)
-    def from_statement(self, statement):
+    def from_statement(self: SelfQuery, statement) -> SelfQuery:
         """Execute the given SELECT statement and return results.
 
         This method bypasses all internal statement compilation, and the
@@ -2739,6 +2778,7 @@ class Query(
             roles.SelectStatementRole, statement, apply_propagate_attrs=self
         )
         self._statement = statement
+        return self
 
     def first(self):
         """Return the first result of this ``Query`` or
index 27c12658c0d75001ede3f6ff6ed81c38ea212918..5dc2d393a5af3bc0332a6b33cbc4766696baa3bc 100644 (file)
@@ -12,6 +12,7 @@ import typing
 from typing import Any
 from typing import cast
 from typing import Mapping
+from typing import NoReturn
 from typing import Tuple
 from typing import Union
 
@@ -41,6 +42,8 @@ _COLUMN_TOKEN = "column"
 if typing.TYPE_CHECKING:
     from .mapper import Mapper
 
+Self_AbstractLoad = typing.TypeVar("Self_AbstractLoad", bound="_AbstractLoad")
+
 
 class _AbstractLoad(Generative, LoaderOption):
     _is_strategy_option = True
@@ -658,13 +661,13 @@ class _AbstractLoad(Generative, LoaderOption):
 
     @_generative
     def _set_relationship_strategy(
-        self,
+        self: Self_AbstractLoad,
         attr,
         strategy,
         propagate_to_loaders=True,
         opts=None,
         _reconcile_to_other=None,
-    ) -> "_AbstractLoad":
+    ) -> Self_AbstractLoad:
         strategy = self._coerce_strat(strategy)
 
         self._clone_for_bind_strategy(
@@ -679,8 +682,8 @@ class _AbstractLoad(Generative, LoaderOption):
 
     @_generative
     def _set_column_strategy(
-        self, attrs, strategy, opts=None
-    ) -> "_AbstractLoad":
+        self: Self_AbstractLoad, attrs, strategy, opts=None
+    ) -> Self_AbstractLoad:
         strategy = self._coerce_strat(strategy)
 
         self._clone_for_bind_strategy(
@@ -694,8 +697,8 @@ class _AbstractLoad(Generative, LoaderOption):
 
     @_generative
     def _set_generic_strategy(
-        self, attrs, strategy, _reconcile_to_other=None
-    ) -> "_AbstractLoad":
+        self: Self_AbstractLoad, attrs, strategy, _reconcile_to_other=None
+    ) -> Self_AbstractLoad:
         strategy = self._coerce_strat(strategy)
         self._clone_for_bind_strategy(
             attrs,
@@ -707,7 +710,9 @@ class _AbstractLoad(Generative, LoaderOption):
         return self
 
     @_generative
-    def _set_class_strategy(self, strategy, opts) -> "_AbstractLoad":
+    def _set_class_strategy(
+        self: Self_AbstractLoad, strategy, opts
+    ) -> Self_AbstractLoad:
         strategy = self._coerce_strat(strategy)
 
         self._clone_for_bind_strategy(None, strategy, None, opts=opts)
@@ -722,7 +727,7 @@ class _AbstractLoad(Generative, LoaderOption):
         """
         raise NotImplementedError()
 
-    def options(self, *opts) -> "_AbstractLoad":
+    def options(self: Self_AbstractLoad, *opts) -> NoReturn:
         r"""Apply a series of options as sub-options to this
         :class:`_orm._AbstractLoad` object.
 
@@ -831,6 +836,9 @@ class _AbstractLoad(Generative, LoaderOption):
         return to_chop[i + 1 :]
 
 
+SelfLoad = typing.TypeVar("SelfLoad", bound="Load")
+
+
 class Load(_AbstractLoad):
     """Represents loader options which modify the state of a
     ORM-enabled :class:`_sql.Select` or a legacy :class:`_query.Query` in
@@ -1003,7 +1011,7 @@ class Load(_AbstractLoad):
             parent.context += cloned.context
 
     @_generative
-    def options(self, *opts) -> "_AbstractLoad":
+    def options(self: SelfLoad, *opts) -> SelfLoad:
         r"""Apply a series of options as sub-options to this
         :class:`_orm.Load`
         object.
@@ -1129,6 +1137,9 @@ class Load(_AbstractLoad):
         self.path = PathRegistry.deserialize(self.path)
 
 
+SelfWildcardLoad = typing.TypeVar("SelfWildcardLoad", bound="_WildcardLoad")
+
+
 class _WildcardLoad(_AbstractLoad):
     """represent a standalone '*' load operation"""
 
@@ -1177,7 +1188,7 @@ class _WildcardLoad(_AbstractLoad):
         if opts:
             self.local_opts = util.immutabledict(opts)
 
-    def options(self, *opts) -> "_AbstractLoad":
+    def options(self: SelfWildcardLoad, *opts) -> SelfWildcardLoad:
         raise NotImplementedError("Star option does not support sub-options")
 
     def _apply_to_parent(self, parent):
@@ -1986,7 +1997,7 @@ class _ClassStrategyLoad(_LoadElement):
         return [("loader", cast(PathRegistry, effective_path).natural_path)]
 
 
-def _generate_from_keys(meth, keys, chained, kw):
+def _generate_from_keys(meth, keys, chained, kw) -> _AbstractLoad:
 
     lead_element = None
 
@@ -2041,6 +2052,7 @@ def _generate_from_keys(meth, keys, chained, kw):
                 else:
                     lead_element = meth(lead_element, attr, **kw)
 
+    assert lead_element
     return lead_element
 
 
@@ -2097,12 +2109,12 @@ See :func:`_orm.{fn.__name__}` for usage examples.
 
 
 @loader_unbound_fn
-def contains_eager(*keys, **kw):
+def contains_eager(*keys, **kw) -> _AbstractLoad:
     return _generate_from_keys(Load.contains_eager, keys, True, kw)
 
 
 @loader_unbound_fn
-def load_only(*attrs):
+def load_only(*attrs) -> _AbstractLoad:
     # TODO: attrs against different classes.  we likely have to
     # add some extra state to Load of some kind
     _, lead_element, _ = _parse_attr_argument(attrs[0])
@@ -2110,47 +2122,47 @@ def load_only(*attrs):
 
 
 @loader_unbound_fn
-def joinedload(*keys, **kw):
+def joinedload(*keys, **kw) -> _AbstractLoad:
     return _generate_from_keys(Load.joinedload, keys, False, kw)
 
 
 @loader_unbound_fn
-def subqueryload(*keys):
+def subqueryload(*keys) -> _AbstractLoad:
     return _generate_from_keys(Load.subqueryload, keys, False, {})
 
 
 @loader_unbound_fn
-def selectinload(*keys):
+def selectinload(*keys) -> _AbstractLoad:
     return _generate_from_keys(Load.selectinload, keys, False, {})
 
 
 @loader_unbound_fn
-def lazyload(*keys):
+def lazyload(*keys) -> _AbstractLoad:
     return _generate_from_keys(Load.lazyload, keys, False, {})
 
 
 @loader_unbound_fn
-def immediateload(*keys):
+def immediateload(*keys) -> _AbstractLoad:
     return _generate_from_keys(Load.immediateload, keys, False, {})
 
 
 @loader_unbound_fn
-def noload(*keys):
+def noload(*keys) -> _AbstractLoad:
     return _generate_from_keys(Load.noload, keys, False, {})
 
 
 @loader_unbound_fn
-def raiseload(*keys, **kw):
+def raiseload(*keys, **kw) -> _AbstractLoad:
     return _generate_from_keys(Load.raiseload, keys, False, kw)
 
 
 @loader_unbound_fn
-def defaultload(*keys):
+def defaultload(*keys) -> _AbstractLoad:
     return _generate_from_keys(Load.defaultload, keys, False, {})
 
 
 @loader_unbound_fn
-def defer(key, *addl_attrs, **kw):
+def defer(key, *addl_attrs, **kw) -> _AbstractLoad:
     if addl_attrs:
         util.warn_deprecated(
             "The *addl_attrs on orm.defer is deprecated.  Please use "
@@ -2162,7 +2174,7 @@ def defer(key, *addl_attrs, **kw):
 
 
 @loader_unbound_fn
-def undefer(key, *addl_attrs):
+def undefer(key, *addl_attrs) -> _AbstractLoad:
     if addl_attrs:
         util.warn_deprecated(
             "The *addl_attrs on orm.undefer is deprecated.  Please use "
@@ -2174,19 +2186,19 @@ def undefer(key, *addl_attrs):
 
 
 @loader_unbound_fn
-def undefer_group(name):
+def undefer_group(name) -> _AbstractLoad:
     element = _WildcardLoad()
     return element.undefer_group(name)
 
 
 @loader_unbound_fn
-def with_expression(key, expression):
+def with_expression(key, expression) -> _AbstractLoad:
     return _generate_from_keys(
         Load.with_expression, (key,), False, {"expression": expression}
     )
 
 
 @loader_unbound_fn
-def selectin_polymorphic(base_cls, classes):
+def selectin_polymorphic(base_cls, classes) -> _AbstractLoad:
     ul = Load(base_cls)
     return ul.selectin_polymorphic(classes)
diff --git a/lib/sqlalchemy/py.typed b/lib/sqlalchemy/py.typed
new file mode 100644 (file)
index 0000000..e69de29
index 18765143551d81204dfcaa365e6e9b2e5af4580a..30d5892585dda119aae343479dc4ada758d9d0ca 100644 (file)
@@ -16,6 +16,7 @@ import itertools
 from itertools import zip_longest
 import operator
 import re
+import typing
 
 from . import roles
 from . import visitors
@@ -29,6 +30,7 @@ from .. import exc
 from .. import util
 from ..util import HasMemoized
 from ..util import hybridmethod
+from ..util import typing as compat_typing
 
 try:
     from sqlalchemy.cyextension.util import prefix_anon_map  # noqa
@@ -42,6 +44,10 @@ type_api = None
 
 NO_ARG = util.symbol("NO_ARG")
 
+# if I use sqlalchemy.util.typing, which has the exact same
+# symbols, mypy reports: "error: _Fn? not callable"
+_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+
 
 class Immutable:
     """mark a ClauseElement as 'immutable' when expressions are cloned."""
@@ -101,7 +107,16 @@ def _select_iterables(elements):
     )
 
 
-def _generative(fn):
+_Self = typing.TypeVar("_Self", bound="_GenerativeType")
+_Args = compat_typing.ParamSpec("_Args")
+
+
+class _GenerativeType(compat_typing.Protocol):
+    def _generate(self: "_Self") -> "_Self":
+        ...
+
+
+def _generative(fn: _Fn) -> _Fn:
     """non-caching _generative() decorator.
 
     This is basically the legacy decorator that copies the object and
@@ -110,14 +125,14 @@ def _generative(fn):
     """
 
     @util.decorator
-    def _generative(fn, self, *args, **kw):
+    def _generative(
+        fn: _Fn, self: _Self, *args: _Args.args, **kw: _Args.kwargs
+    ) -> _Self:
         """Mark a method as generative."""
 
         self = self._generate()
         x = fn(self, *args, **kw)
-        assert (
-            x is None or x is self
-        ), "generative methods must return None or self"
+        assert x is self, "generative methods must return self"
         return self
 
     decorated = _generative(fn)
@@ -788,6 +803,9 @@ class ExecutableOption(HasCopyInternals):
         return c
 
 
+SelfExecutable = typing.TypeVar("SelfExecutable", bound="Executable")
+
+
 class Executable(roles.StatementRole, Generative):
     """Mark a :class:`_expression.ClauseElement` as supporting execution.
 
@@ -824,7 +842,7 @@ class Executable(roles.StatementRole, Generative):
         return self.__visit_name__
 
     @_generative
-    def options(self, *options):
+    def options(self: SelfExecutable, *options) -> SelfExecutable:
         """Apply options to this statement.
 
         In the general sense, options are any kind of Python object
@@ -857,9 +875,12 @@ class Executable(roles.StatementRole, Generative):
             coercions.expect(roles.ExecutableOptionRole, opt)
             for opt in options
         )
+        return self
 
     @_generative
-    def _set_compile_options(self, compile_options):
+    def _set_compile_options(
+        self: SelfExecutable, compile_options
+    ) -> SelfExecutable:
         """Assign the compile options to a new value.
 
         :param compile_options: appropriate CacheableOptions structure
@@ -867,15 +888,21 @@ class Executable(roles.StatementRole, Generative):
         """
 
         self._compile_options = compile_options
+        return self
 
     @_generative
-    def _update_compile_options(self, options):
+    def _update_compile_options(
+        self: SelfExecutable, options
+    ) -> SelfExecutable:
         """update the _compile_options with new keys."""
 
         self._compile_options += options
+        return self
 
     @_generative
-    def _add_context_option(self, callable_, cache_args):
+    def _add_context_option(
+        self: SelfExecutable, callable_, cache_args
+    ) -> SelfExecutable:
         """Add a context option to this statement.
 
         These are callable functions that will
@@ -887,9 +914,10 @@ class Executable(roles.StatementRole, Generative):
 
         """
         self._with_context_options += ((callable_, cache_args),)
+        return self
 
     @_generative
-    def execution_options(self, **kw):
+    def execution_options(self: SelfExecutable, **kw) -> SelfExecutable:
         """Set non-SQL options for the statement which take effect during
         execution.
 
@@ -1004,6 +1032,7 @@ class Executable(roles.StatementRole, Generative):
                 "on Connection.execution_options(), not per statement."
             )
         self._execution_options = self._execution_options.union(kw)
+        return self
 
     def get_execution_options(self):
         """Get the non-SQL options which will take effect during execution.
index f415aeaff0344df92e6f9c2d5e2fa2cfcba676d4..ad22fa6da13988bfcaebbac9193e069b69333e53 100644 (file)
@@ -9,6 +9,7 @@ Provides the hierarchy of DDL-defining schema items as well as routines
 to invoke them for a create/drop call.
 
 """
+import typing
 
 from . import roles
 from .base import _generative
@@ -34,6 +35,9 @@ class _DDLCompiles(ClauseElement):
         raise NotImplementedError()
 
 
+SelfDDLElement = typing.TypeVar("SelfDDLElement", bound="DDLElement")
+
+
 class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
     """Base class for DDL expression constructs.
 
@@ -77,7 +81,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
         )
 
     @_generative
-    def against(self, target):
+    def against(self: SelfDDLElement, target) -> SelfDDLElement:
         """Return a copy of this :class:`_schema.DDLElement` which will include
         the given target.
 
@@ -111,9 +115,12 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
         """
 
         self.target = target
+        return self
 
     @_generative
-    def execute_if(self, dialect=None, callable_=None, state=None):
+    def execute_if(
+        self: SelfDDLElement, dialect=None, callable_=None, state=None
+    ) -> SelfDDLElement:
         r"""Return a callable that will execute this
         :class:`_ddl.DDLElement` conditionally within an event handler.
 
@@ -181,6 +188,7 @@ class DDLElement(roles.DDLRole, Executable, _DDLCompiles):
         self.dialect = dialect
         self.callable_ = callable_
         self.state = state
+        return self
 
     def _should_execute(self, target, bind, **kw):
         if isinstance(self.dialect, str):
index 7b3716a68961464aabc2a19880c09cc03a6ea108..ab0a05651b0fbc61c82c9d56919ae856e1ce8f7e 100644 (file)
@@ -10,8 +10,8 @@ Provide :class:`_expression.Insert`, :class:`_expression.Update` and
 
 """
 import collections.abc as collections_abc
+import typing
 
-from sqlalchemy.types import NullType
 from . import coercions
 from . import roles
 from . import util as sql_util
@@ -30,6 +30,7 @@ from .elements import Null
 from .selectable import HasCTE
 from .selectable import HasPrefixes
 from .selectable import ReturnsRows
+from .sqltypes import NullType
 from .visitors import InternalTraversal
 from .. import exc
 from .. import util
@@ -210,6 +211,9 @@ class DeleteDMLState(DMLState):
         self._extra_froms = self._make_extra_froms(statement)
 
 
+SelfUpdateBase = typing.TypeVar("SelfUpdateBase", bound="UpdateBase")
+
+
 class UpdateBase(
     roles.DMLRole,
     HasCTE,
@@ -313,7 +317,7 @@ class UpdateBase(
         )
 
     @_generative
-    def with_dialect_options(self, **opt):
+    def with_dialect_options(self: SelfUpdateBase, **opt) -> SelfUpdateBase:
         """Add dialect options to this INSERT/UPDATE/DELETE object.
 
         e.g.::
@@ -326,6 +330,7 @@ class UpdateBase(
 
         """
         self._validate_dialect_kwargs(opt)
+        return self
 
     def _validate_dialect_kwargs_deprecated(self, dialect_kw):
         util.warn_deprecated_20(
@@ -337,7 +342,7 @@ class UpdateBase(
         self._validate_dialect_kwargs(dialect_kw)
 
     @_generative
-    def returning(self, *cols):
+    def returning(self: SelfUpdateBase, *cols) -> SelfUpdateBase:
         r"""Add a :term:`RETURNING` or equivalent clause to this statement.
 
         e.g.:
@@ -414,6 +419,7 @@ class UpdateBase(
         self._returning += tuple(
             coercions.expect(roles.ColumnsClauseRole, c) for c in cols
         )
+        return self
 
     @property
     def _all_selected_columns(self):
@@ -433,7 +439,9 @@ class UpdateBase(
         ).as_immutable()
 
     @_generative
-    def with_hint(self, text, selectable=None, dialect_name="*"):
+    def with_hint(
+        self: SelfUpdateBase, text, selectable=None, dialect_name="*"
+    ) -> SelfUpdateBase:
         """Add a table hint for a single table to this
         INSERT/UPDATE/DELETE statement.
 
@@ -467,6 +475,10 @@ class UpdateBase(
             selectable = self.table
 
         self._hints = self._hints.union({(selectable, dialect_name): text})
+        return self
+
+
+SelfValuesBase = typing.TypeVar("SelfValuesBase", bound="ValuesBase")
 
 
 class ValuesBase(UpdateBase):
@@ -506,7 +518,7 @@ class ValuesBase(UpdateBase):
             "values present",
         },
     )
-    def values(self, *args, **kwargs):
+    def values(self: SelfValuesBase, *args, **kwargs) -> SelfValuesBase:
         r"""Specify a fixed VALUES clause for an INSERT statement, or the SET
         clause for an UPDATE.
 
@@ -643,7 +655,7 @@ class ValuesBase(UpdateBase):
 
                 if arg and isinstance(arg[0], (list, dict, tuple)):
                     self._multi_values += (arg,)
-                    return
+                    return self
 
                 # tuple values
                 arg = {c.key: value for c, value in zip(self.table.c, arg)}
@@ -681,6 +693,7 @@ class ValuesBase(UpdateBase):
                 self._values = self._values.union(arg)
             else:
                 self._values = util.immutabledict(arg)
+        return self
 
     @_generative
     @_exclusive_against(
@@ -690,7 +703,7 @@ class ValuesBase(UpdateBase):
         },
         defaults={"_returning": _returning},
     )
-    def return_defaults(self, *cols):
+    def return_defaults(self: SelfValuesBase, *cols) -> SelfValuesBase:
         """Make use of a :term:`RETURNING` clause for the purpose
         of fetching server-side expressions and defaults.
 
@@ -776,6 +789,10 @@ class ValuesBase(UpdateBase):
         """
         self._return_defaults = True
         self._return_defaults_columns = cols
+        return self
+
+
+SelfInsert = typing.TypeVar("SelfInsert", bound="Insert")
 
 
 class Insert(ValuesBase):
@@ -918,7 +935,7 @@ class Insert(ValuesBase):
                 self._return_defaults_columns = return_defaults
 
     @_generative
-    def inline(self):
+    def inline(self: SelfInsert) -> SelfInsert:
         """Make this :class:`_expression.Insert` construct "inline" .
 
         When set, no attempt will be made to retrieve the
@@ -936,9 +953,12 @@ class Insert(ValuesBase):
 
         """
         self._inline = True
+        return self
 
     @_generative
-    def from_select(self, names, select, include_defaults=True):
+    def from_select(
+        self: SelfInsert, names, select, include_defaults=True
+    ) -> SelfInsert:
         """Return a new :class:`_expression.Insert` construct which represents
         an ``INSERT...FROM SELECT`` statement.
 
@@ -997,13 +1017,17 @@ class Insert(ValuesBase):
         self._inline = True
         self.include_insert_from_select_defaults = include_defaults
         self.select = coercions.expect(roles.DMLSelectRole, select)
+        return self
+
+
+SelfDMLWhereBase = typing.TypeVar("SelfDMLWhereBase", bound="DMLWhereBase")
 
 
 class DMLWhereBase:
     _where_criteria = ()
 
     @_generative
-    def where(self, *whereclause):
+    def where(self: SelfDMLWhereBase, *whereclause) -> SelfDMLWhereBase:
         """Return a new construct with the given expression(s) added to
         its WHERE clause, joined to the existing clause via AND, if any.
 
@@ -1037,8 +1061,9 @@ class DMLWhereBase:
         for criterion in whereclause:
             where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
             self._where_criteria += (where_criteria,)
+        return self
 
-    def filter(self, *criteria):
+    def filter(self: SelfDMLWhereBase, *criteria) -> SelfDMLWhereBase:
         """A synonym for the :meth:`_dml.DMLWhereBase.where` method.
 
         .. versionadded:: 1.4
@@ -1050,7 +1075,7 @@ class DMLWhereBase:
     def _filter_by_zero(self):
         return self.table
 
-    def filter_by(self, **kwargs):
+    def filter_by(self: SelfDMLWhereBase, **kwargs) -> SelfDMLWhereBase:
         r"""apply the given filtering criterion as a WHERE clause
         to this select.
 
@@ -1081,6 +1106,9 @@ class DMLWhereBase:
         )
 
 
+SelfUpdate = typing.TypeVar("SelfUpdate", bound="Update")
+
+
 class Update(DMLWhereBase, ValuesBase):
     """Represent an Update construct.
 
@@ -1261,7 +1289,7 @@ class Update(DMLWhereBase, ValuesBase):
         self._return_defaults = return_defaults
 
     @_generative
-    def ordered_values(self, *args):
+    def ordered_values(self: SelfUpdate, *args) -> SelfUpdate:
         """Specify the VALUES clause of this UPDATE statement with an explicit
         parameter ordering that will be maintained in the SET clause of the
         resulting UPDATE statement.
@@ -1295,9 +1323,10 @@ class Update(DMLWhereBase, ValuesBase):
 
         kv_generator = DMLState.get_plugin_class(self)._get_crud_kv_pairs
         self._ordered_values = kv_generator(self, args)
+        return self
 
     @_generative
-    def inline(self):
+    def inline(self: SelfUpdate) -> SelfUpdate:
         """Make this :class:`_expression.Update` construct "inline" .
 
         When set, SQL defaults present on :class:`_schema.Column`
@@ -1313,6 +1342,10 @@ class Update(DMLWhereBase, ValuesBase):
 
         """
         self._inline = True
+        return self
+
+
+SelfDelete = typing.TypeVar("SelfDelete", bound="Delete")
 
 
 class Delete(DMLWhereBase, UpdateBase):
index 37425345bd95d10219ebd7a1b8402933593bdf33..f6606e01d5ce0528a2a613267a481934e1ccadf8 100644 (file)
@@ -13,6 +13,7 @@
 import itertools
 import operator
 import re
+import typing
 
 from . import coercions
 from . import operators
@@ -1719,6 +1720,9 @@ class TypeClause(ClauseElement):
         self.type = type_
 
 
+SelfTextClause = typing.TypeVar("SelfTextClause", bound="TextClause")
+
+
 class TextClause(
     roles.DDLConstraintColumnRole,
     roles.DDLExpressionRole,
@@ -1875,7 +1879,9 @@ class TextClause(
         return TextClause(text)
 
     @_generative
-    def bindparams(self, *binds, **names_to_values):
+    def bindparams(
+        self: SelfTextClause, *binds, **names_to_values
+    ) -> SelfTextClause:
         """Establish the values and/or types of bound parameters within
         this :class:`_expression.TextClause` construct.
 
@@ -2000,6 +2006,7 @@ class TextClause(
                 ) from err
             else:
                 new_params[key] = existing._with_value(value, required=False)
+        return self
 
     @util.preload_module("sqlalchemy.sql.selectable")
     def columns(self, *cols, **types):
index 407f1dd33f086cd7296d332d0f63fa68bdebc6e2..89b7c659672bb7f11f6ae950af2c934856850f87 100644 (file)
@@ -88,6 +88,8 @@ __all__ = [
 ]
 
 
+from typing import Callable
+
 from .base import _from_objects
 from .base import _select_iterables
 from .base import ColumnCollection
@@ -175,10 +177,8 @@ from .traversals import CacheKey
 from .visitors import Visitable
 from ..util.langhelpers import public_factory
 
-# factory functions - these pull class-bound constructors and classmethods
-# from SQL elements and selectables into public functions.  This allows
-# the functions to be available in the sqlalchemy.sql.* namespace and
-# to be auto-cross-documenting from the function to the class itself.
+# TODO: proposal is to remove public_factory and replace with traditional
+# functions exported here.
 
 all_ = public_factory(CollectionAggregate._create_all, ".sql.expression.all_")
 any_ = public_factory(CollectionAggregate._create_any, ".sql.expression.any_")
index 802576b89959bbea96fc4ab42eebd8e4683b75f8..8b35dc6ace7ad76de8f447533f7716a2e5eed54a 100644 (file)
@@ -14,6 +14,9 @@ SQL tables and derived rowsets.
 import collections
 import itertools
 from operator import attrgetter
+import typing
+from typing import Type
+from typing import Union
 
 from . import coercions
 from . import operators
@@ -209,6 +212,9 @@ class Selectable(ReturnsRows):
         )
 
 
+SelfHasPrefixes = typing.TypeVar("SelfHasPrefixes", bound="HasPrefixes")
+
+
 class HasPrefixes:
     _prefixes = ()
 
@@ -222,7 +228,7 @@ class HasPrefixes:
         ":meth:`_expression.HasPrefixes.prefix_with`",
         ":paramref:`.HasPrefixes.prefix_with.*expr`",
     )
-    def prefix_with(self, *expr, **kw):
+    def prefix_with(self: SelfHasPrefixes, *expr, **kw) -> SelfHasPrefixes:
         r"""Add one or more expressions following the statement keyword, i.e.
         SELECT, INSERT, UPDATE, or DELETE. Generative.
 
@@ -255,6 +261,7 @@ class HasPrefixes:
                 "Unsupported argument(s): %s" % ",".join(kw)
             )
         self._setup_prefixes(expr, dialect)
+        return self
 
     def _setup_prefixes(self, prefixes, dialect=None):
         self._prefixes = self._prefixes + tuple(
@@ -265,6 +272,9 @@ class HasPrefixes:
         )
 
 
+SelfHasSuffixes = typing.TypeVar("SelfHasSuffixes", bound="HasSuffixes")
+
+
 class HasSuffixes:
     _suffixes = ()
 
@@ -278,7 +288,7 @@ class HasSuffixes:
         ":meth:`_expression.HasSuffixes.suffix_with`",
         ":paramref:`.HasSuffixes.suffix_with.*expr`",
     )
-    def suffix_with(self, *expr, **kw):
+    def suffix_with(self: SelfHasSuffixes, *expr, **kw) -> SelfHasSuffixes:
         r"""Add one or more expressions following the statement as a whole.
 
         This is used to support backend-specific suffix keywords on
@@ -306,6 +316,7 @@ class HasSuffixes:
                 "Unsupported argument(s): %s" % ",".join(kw)
             )
         self._setup_suffixes(expr, dialect)
+        return self
 
     def _setup_suffixes(self, suffixes, dialect=None):
         self._suffixes = self._suffixes + tuple(
@@ -316,6 +327,9 @@ class HasSuffixes:
         )
 
 
+SelfHasHints = typing.TypeVar("SelfHasHints", bound="HasHints")
+
+
 class HasHints:
     _hints = util.immutabledict()
     _statement_hints = ()
@@ -352,7 +366,9 @@ class HasHints:
         return self.with_hint(None, text, dialect_name)
 
     @_generative
-    def with_hint(self, selectable, text, dialect_name="*"):
+    def with_hint(
+        self: SelfHasHints, selectable, text, dialect_name="*"
+    ) -> SelfHasHints:
         r"""Add an indexing or other executional context hint for the given
         selectable to this :class:`_expression.Select` or other selectable
         object.
@@ -398,6 +414,7 @@ class HasHints:
                     ): text
                 }
             )
+        return self
 
 
 class FromClause(roles.AnonymizedFromClauseRole, Selectable):
@@ -2082,6 +2099,9 @@ class CTE(
         return self._restates if self._restates is not None else self
 
 
+SelfHasCTE = typing.TypeVar("SelfHasCTE", bound="HasCTE")
+
+
 class HasCTE(roles.HasCTERole):
     """Mixin that declares a class to include CTE support.
 
@@ -2096,7 +2116,7 @@ class HasCTE(roles.HasCTERole):
     _independent_ctes = ()
 
     @_generative
-    def add_cte(self, cte):
+    def add_cte(self: SelfHasCTE, cte) -> SelfHasCTE:
         """Add a :class:`_sql.CTE` to this statement object that will be
         independently rendered even if not referenced in the statement
         otherwise.
@@ -2161,6 +2181,7 @@ class HasCTE(roles.HasCTERole):
         """
         cte = coercions.expect(roles.IsCTERole, cte)
         self._independent_ctes += (cte,)
+        return self
 
     def cte(self, name=None, recursive=False, nesting=False):
         r"""Return a new :class:`_expression.CTE`,
@@ -2759,6 +2780,9 @@ class ForUpdateArg(ClauseElement):
             self.of = None
 
 
+SelfValues = typing.TypeVar("SelfValues", bound="Values")
+
+
 class Values(Generative, FromClause):
     """Represent a ``VALUES`` construct that can be used as a FROM element
     in a statement.
@@ -2829,7 +2853,7 @@ class Values(Generative, FromClause):
         return [col.type for col in self._column_args]
 
     @_generative
-    def alias(self, name, **kw):
+    def alias(self: SelfValues, name, **kw) -> SelfValues:
         """Return a new :class:`_expression.Values`
         construct that is a copy of this
         one with the given name.
@@ -2846,9 +2870,10 @@ class Values(Generative, FromClause):
         """
         self.name = name
         self.named_with_column = self.name is not None
+        return self
 
     @_generative
-    def lateral(self, name=None):
+    def lateral(self: SelfValues, name=None) -> SelfValues:
         """Return a new :class:`_expression.Values` with the lateral flag set,
         so that
         it renders as LATERAL.
@@ -2861,9 +2886,10 @@ class Values(Generative, FromClause):
         self._is_lateral = True
         if name is not None:
             self.name = name
+        return self
 
     @_generative
-    def data(self, values):
+    def data(self: SelfValues, values) -> SelfValues:
         """Return a new :class:`_expression.Values` construct,
         adding the given data
         to the data list.
@@ -2879,6 +2905,7 @@ class Values(Generative, FromClause):
         """
 
         self._data += (values,)
+        return self
 
     def _populate_column_collection(self):
         for c in self._column_args:
@@ -3312,6 +3339,11 @@ class DeprecatedSelectBaseGenerations:
         self.group_by.non_generative(self, *clauses)
 
 
+SelfGenerativeSelect = typing.TypeVar(
+    "SelfGenerativeSelect", bound="GenerativeSelect"
+)
+
+
 class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
     """Base class for SELECT statements where additional elements can be
     added.
@@ -3371,13 +3403,13 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
 
     @_generative
     def with_for_update(
-        self,
+        self: SelfGenerativeSelect,
         nowait=False,
         read=False,
         of=None,
         skip_locked=False,
         key_share=False,
-    ):
+    ) -> SelfGenerativeSelect:
         """Specify a ``FOR UPDATE`` clause for this
         :class:`_expression.GenerativeSelect`.
 
@@ -3430,6 +3462,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
             skip_locked=skip_locked,
             key_share=key_share,
         )
+        return self
 
     def get_label_style(self):
         """
@@ -3573,7 +3606,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
         )
 
     @_generative
-    def limit(self, limit):
+    def limit(self: SelfGenerativeSelect, limit) -> SelfGenerativeSelect:
         """Return a new selectable with the given LIMIT criterion
         applied.
 
@@ -3603,9 +3636,12 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
 
         self._fetch_clause = self._fetch_clause_options = None
         self._limit_clause = self._offset_or_limit_clause(limit)
+        return self
 
     @_generative
-    def fetch(self, count, with_ties=False, percent=False):
+    def fetch(
+        self: SelfGenerativeSelect, count, with_ties=False, percent=False
+    ) -> SelfGenerativeSelect:
         """Return a new selectable with the given FETCH FIRST criterion
         applied.
 
@@ -3653,9 +3689,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
                 "with_ties": with_ties,
                 "percent": percent,
             }
+        return self
 
     @_generative
-    def offset(self, offset):
+    def offset(self: SelfGenerativeSelect, offset) -> SelfGenerativeSelect:
         """Return a new selectable with the given OFFSET criterion
         applied.
 
@@ -3681,10 +3718,11 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
         """
 
         self._offset_clause = self._offset_or_limit_clause(offset)
+        return self
 
     @_generative
     @util.preload_module("sqlalchemy.sql.util")
-    def slice(self, start, stop):
+    def slice(self: SelfGenerativeSelect, start, stop) -> SelfGenerativeSelect:
         """Apply LIMIT / OFFSET to this statement based on a slice.
 
         The start and stop indices behave like the argument to Python's
@@ -3728,9 +3766,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
         self._limit_clause, self._offset_clause = sql_util._make_slice(
             self._limit_clause, self._offset_clause, start, stop
         )
+        return self
 
     @_generative
-    def order_by(self, *clauses):
+    def order_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect:
         r"""Return a new selectable with the given list of ORDER BY
         criteria applied.
 
@@ -3764,9 +3803,10 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
                 coercions.expect(roles.OrderByRole, clause)
                 for clause in clauses
             )
+        return self
 
     @_generative
-    def group_by(self, *clauses):
+    def group_by(self: SelfGenerativeSelect, *clauses) -> SelfGenerativeSelect:
         r"""Return a new selectable with the given list of GROUP BY
         criterion applied.
 
@@ -3797,6 +3837,7 @@ class GenerativeSelect(DeprecatedSelectBaseGenerations, SelectBase):
                 coercions.expect(roles.GroupByRole, clause)
                 for clause in clauses
             )
+        return self
 
 
 @CompileState.plugin_for("default", "compound_select")
@@ -4658,6 +4699,10 @@ class _MemoizedSelectEntities(
             ) = select_stmt._with_options = ()
 
 
+# TODO: use pep-673 when feasible
+SelfSelect = typing.TypeVar("SelfSelect", bound="Select")
+
+
 class Select(
     HasPrefixes,
     HasSuffixes,
@@ -4737,7 +4782,9 @@ class Select(
     ]
 
     @classmethod
-    def _create(cls, *entities) -> "Select":
+    def _create(
+        cls, *entities: Union[roles.ColumnsClauseRole, Type]
+    ) -> "Select":
         r"""Construct a new :class:`_expression.Select`.
 
 
@@ -4788,7 +4835,7 @@ class Select(
         return self
 
     @classmethod
-    def _create_raw_select(cls, **kw):
+    def _create_raw_select(cls, **kw) -> "Select":
         """Create a :class:`.Select` using raw ``__new__`` with no coercions.
 
         Used internally to build up :class:`.Select` constructs with
@@ -4873,7 +4920,9 @@ class Select(
         return meth(self, statement)
 
     @_generative
-    def join(self, target, onclause=None, isouter=False, full=False):
+    def join(
+        self: SelfSelect, target, onclause=None, isouter=False, full=False
+    ) -> SelfSelect:
         r"""Create a SQL JOIN against this :class:`_expression.Select`
         object's criterion
         and apply generatively, returning the newly resulting
@@ -4939,6 +4988,7 @@ class Select(
         self._setup_joins += (
             (target, onclause, None, {"isouter": isouter, "full": full}),
         )
+        return self
 
     def outerjoin_from(self, from_, target, onclause=None, full=False):
         r"""Create a SQL LEFT OUTER JOIN against this :class:`_expression.Select`
@@ -4955,8 +5005,13 @@ class Select(
 
     @_generative
     def join_from(
-        self, from_, target, onclause=None, isouter=False, full=False
-    ):
+        self: SelfSelect,
+        from_,
+        target,
+        onclause=None,
+        isouter=False,
+        full=False,
+    ) -> SelfSelect:
         r"""Create a SQL JOIN against this :class:`_expression.Select`
         object's criterion
         and apply generatively, returning the newly resulting
@@ -5014,6 +5069,7 @@ class Select(
         self._setup_joins += (
             (target, onclause, from_, {"isouter": isouter, "full": full}),
         )
+        return self
 
     def outerjoin(self, target, onclause=None, full=False):
         """Create a left outer join.
@@ -5211,7 +5267,7 @@ class Select(
         )
 
     @_generative
-    def add_columns(self, *columns):
+    def add_columns(self: SelfSelect, *columns) -> SelfSelect:
         """Return a new :func:`_expression.select` construct with
         the given column expressions added to its columns clause.
 
@@ -5233,6 +5289,7 @@ class Select(
             )
             for column in columns
         ]
+        return self
 
     def _set_entities(self, entities):
         self._raw_columns = [
@@ -5297,7 +5354,7 @@ class Select(
         )
 
     @_generative
-    def with_only_columns(self, *columns, **kw):
+    def with_only_columns(self: SelfSelect, *columns, **kw) -> SelfSelect:
         r"""Return a new :func:`_expression.select` construct with its columns
         clause replaced with the given columns.
 
@@ -5372,6 +5429,7 @@ class Select(
                 "columns", "Select.with_only_columns", columns
             )
         ]
+        return self
 
     @property
     def whereclause(self):
@@ -5393,7 +5451,7 @@ class Select(
     _whereclause = whereclause
 
     @_generative
-    def where(self, *whereclause):
+    def where(self: SelfSelect, *whereclause) -> SelfSelect:
         """Return a new :func:`_expression.select` construct with
         the given expression added to
         its WHERE clause, joined to the existing clause via AND, if any.
@@ -5405,9 +5463,10 @@ class Select(
         for criterion in whereclause:
             where_criteria = coercions.expect(roles.WhereHavingRole, criterion)
             self._where_criteria += (where_criteria,)
+        return self
 
     @_generative
-    def having(self, having):
+    def having(self: SelfSelect, having) -> SelfSelect:
         """Return a new :func:`_expression.select` construct with
         the given expression added to
         its HAVING clause, joined to the existing clause via AND, if any.
@@ -5416,9 +5475,10 @@ class Select(
         self._having_criteria += (
             coercions.expect(roles.WhereHavingRole, having),
         )
+        return self
 
     @_generative
-    def distinct(self, *expr):
+    def distinct(self: SelfSelect, *expr) -> SelfSelect:
         r"""Return a new :func:`_expression.select` construct which
         will apply DISTINCT to its columns clause.
 
@@ -5437,9 +5497,10 @@ class Select(
             )
         else:
             self._distinct = True
+        return self
 
     @_generative
-    def select_from(self, *froms):
+    def select_from(self: SelfSelect, *froms) -> SelfSelect:
         r"""Return a new :func:`_expression.select` construct with the
         given FROM expression(s)
         merged into its list of FROM objects.
@@ -5480,9 +5541,10 @@ class Select(
             )
             for fromclause in froms
         )
+        return self
 
     @_generative
-    def correlate(self, *fromclauses):
+    def correlate(self: SelfSelect, *fromclauses) -> SelfSelect:
         r"""Return a new :class:`_expression.Select`
         which will correlate the given FROM
         clauses to that of an enclosing :class:`_expression.Select`.
@@ -5541,9 +5603,10 @@ class Select(
             self._correlate = self._correlate + tuple(
                 coercions.expect(roles.FromClauseRole, f) for f in fromclauses
             )
+        return self
 
     @_generative
-    def correlate_except(self, *fromclauses):
+    def correlate_except(self: SelfSelect, *fromclauses) -> SelfSelect:
         r"""Return a new :class:`_expression.Select`
         which will omit the given FROM
         clauses from the auto-correlation process.
@@ -5579,6 +5642,7 @@ class Select(
             self._correlate_except = (self._correlate_except or ()) + tuple(
                 coercions.expect(roles.FromClauseRole, f) for f in fromclauses
             )
+        return self
 
     @HasMemoized.memoized_attribute
     def selected_columns(self):
@@ -5959,6 +6023,9 @@ class Select(
         return CompoundSelect._create_intersect_all(self, *other, **kwargs)
 
 
+SelfScalarSelect = typing.TypeVar("SelfScalarSelect", bound="ScalarSelect")
+
+
 class ScalarSelect(roles.InElementRole, Generative, Grouping):
     """Represent a scalar subquery.
 
@@ -5998,18 +6065,19 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
     c = columns
 
     @_generative
-    def where(self, crit):
+    def where(self: SelfScalarSelect, crit) -> SelfScalarSelect:
         """Apply a WHERE clause to the SELECT statement referred to
         by this :class:`_expression.ScalarSelect`.
 
         """
         self.element = self.element.where(crit)
+        return self
 
     def self_group(self, **kwargs):
         return self
 
     @_generative
-    def correlate(self, *fromclauses):
+    def correlate(self: SelfScalarSelect, *fromclauses) -> SelfScalarSelect:
         r"""Return a new :class:`_expression.ScalarSelect`
         which will correlate the given FROM
         clauses to that of an enclosing :class:`_expression.Select`.
@@ -6039,9 +6107,12 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
 
         """
         self.element = self.element.correlate(*fromclauses)
+        return self
 
     @_generative
-    def correlate_except(self, *fromclauses):
+    def correlate_except(
+        self: SelfScalarSelect, *fromclauses
+    ) -> SelfScalarSelect:
         r"""Return a new :class:`_expression.ScalarSelect`
         which will omit the given FROM
         clauses from the auto-correlation process.
@@ -6073,6 +6144,7 @@ class ScalarSelect(roles.InElementRole, Generative, Grouping):
         """
 
         self.element = self.element.correlate_except(*fromclauses)
+        return self
 
 
 class Exists(UnaryExpression):
@@ -6228,6 +6300,9 @@ class Exists(UnaryExpression):
         return e
 
 
+SelfTextualSelect = typing.TypeVar("SelfTextualSelect", bound="TextualSelect")
+
+
 class TextualSelect(SelectBase):
     """Wrap a :class:`_expression.TextClause` construct within a
     :class:`_expression.SelectBase`
@@ -6315,8 +6390,11 @@ class TextualSelect(SelectBase):
         return self
 
     @_generative
-    def bindparams(self, *binds, **bind_as_values):
+    def bindparams(
+        self: SelfTextualSelect, *binds, **bind_as_values
+    ) -> SelfTextualSelect:
         self.element = self.element.bindparams(*binds, **bind_as_values)
+        return self
 
     def _generate_fromclause_column_proxies(self, fromclause):
         fromclause._columns._populate_separate_keys(
index 8e8a20ff57ab78d4f772d46e2a430598eaaf8d74..ef78b181a7cc9c41c245abecda21ab074ee61c61 100644 (file)
@@ -13,8 +13,11 @@ import inspect
 import operator
 import platform
 import sys
+import typing
+
 
 py311 = sys.version_info >= (3, 11)
+py310 = sys.version_info >= (3, 10)
 py39 = sys.version_info >= (3, 9)
 py38 = sys.version_info >= (3, 8)
 pypy = platform.python_implementation() == "PyPy"
@@ -137,11 +140,13 @@ def _formatannotation(annotation, base_module=None):
     """vendored from python 3.7"""
 
     if getattr(annotation, "__module__", None) == "typing":
-        return repr(annotation).replace("typing.", "")
+        return f'"{repr(annotation).replace("typing.", "")}"'
     if isinstance(annotation, type):
         if annotation.__module__ in ("builtins", base_module):
             return repr(annotation.__qualname__)
         return annotation.__module__ + "." + annotation.__qualname__
+    elif isinstance(annotation, typing.TypeVar):
+        return f'"{annotation}"'
     return repr(annotation)
 
 
index ca64296c1a8260918d7e5041e2dadf17a25a24c6..93caa0ee580a5dafb5161c7c858f0a4ba078c3dd 100644 (file)
@@ -20,6 +20,7 @@ import re
 import sys
 import textwrap
 import types
+import typing
 from typing import Any
 from typing import Callable
 from typing import Generic
@@ -31,10 +32,10 @@ import warnings
 
 from . import _collections
 from . import compat
+from . import typing as compat_typing
 from .. import exc
 
 _T = TypeVar("_T")
-_MP = TypeVar("_MP", bound="memoized_property[Any]")
 
 
 def md5_hex(x):
@@ -166,7 +167,13 @@ def map_bits(fn, n):
         n ^= b
 
 
-def decorator(target):
+_Fn = typing.TypeVar("_Fn", bound=typing.Callable)
+_Args = compat_typing.ParamSpec("_Args")
+
+
+def decorator(
+    target: typing.Callable[compat_typing.Concatenate[_Fn, _Args], typing.Any]
+) -> _Fn:
     """A signature-matching decorator factory."""
 
     def decorate(fn):
@@ -198,7 +205,7 @@ def %(name)s%(grouped_args)s:
         decorated.__wrapped__ = fn
         return update_wrapper(decorated, fn)
 
-    return update_wrapper(decorate, target)
+    return typing.cast(_Fn, update_wrapper(decorate, target))
 
 
 def _update_argspec_defaults_into_env(spec, env):
@@ -227,7 +234,16 @@ def _exec_code_in_env(code, env, fn_name):
     return env[fn_name]
 
 
-def public_factory(target, location, class_location=None):
+_TE = TypeVar("_TE")
+
+_P = compat_typing.ParamSpec("_P")
+
+
+def public_factory(
+    target: typing.Callable[_P, _TE],
+    location: str,
+    class_location: Optional[str] = None,
+) -> typing.Callable[_P, _TE]:
     """Produce a wrapping function for the given cls or classmethod.
 
     Rationale here is so that the __init__ method of the
@@ -273,6 +289,7 @@ def %(name)s%(grouped_args)s:
         "__name__": callable_.__module__,
     }
     exec(code, env)
+
     decorated = env[location_name]
 
     if hasattr(fn, "_linked_to"):
@@ -1077,6 +1094,11 @@ def as_interface(obj, cls=None, methods=None, required=None):
     )
 
 
+Selfmemoized_property = TypeVar(
+    "Selfmemoized_property", bound="memoized_property[Any]"
+)
+
+
 class memoized_property(Generic[_T]):
     """A read-only @property that is only evaluated once."""
 
@@ -1090,14 +1112,18 @@ class memoized_property(Generic[_T]):
         self.__name__ = fget.__name__
 
     @overload
-    def __get__(self: _MP, obj: None, cls: Any) -> _MP:
+    def __get__(
+        self: Selfmemoized_property, obj: None, cls: Any
+    ) -> Selfmemoized_property:
         ...
 
     @overload
     def __get__(self, obj: Any, cls: Any) -> _T:
         ...
 
-    def __get__(self: _MP, obj: Any, cls: Any) -> Union[_MP, _T]:
+    def __get__(
+        self: Selfmemoized_property, obj: Any, cls: Any
+    ) -> Union[Selfmemoized_property, _T]:
         if obj is None:
             return self
         obj.__dict__[self.__name__] = result = self.fget(obj)
index 801c4a110403ae871b4a443a0d0c405bac55da7d..e735ce531d4b7f3be20e3a57de5bb4dacca30cf1 100644 (file)
@@ -1,4 +1,5 @@
 from typing import Any
+from typing import Callable  # noqa
 from typing import Generic
 from typing import overload
 from typing import Type
@@ -15,6 +16,12 @@ else:
     from typing_extensions import Protocol  # noqa
     from typing_extensions import TypedDict  # noqa
 
+if compat.py310:
+    from typing import Concatenate
+    from typing import ParamSpec
+else:
+    from typing_extensions import Concatenate  # noqa
+    from typing_extensions import ParamSpec  # noqa
 
 if compat.py311:
     from typing import NotRequired  # noqa
index 80582f604e4b27fa4a27af38834ba61ce0a7aa3a..128f5285d5fcee0649a8d11cfe46b70869df9556 100644 (file)
--- a/setup.cfg
+++ b/setup.cfg
@@ -29,6 +29,7 @@ project_urls =
 
 [options]
 packages = find:
+include_package_data = True
 python_requires = >=3.7
 package_dir =
     =lib