From: Mehdi Gmira Date: Wed, 28 Jun 2023 13:52:39 +0000 (-0400) Subject: Type annotate postgresql/sqlite/mysql insert X-Git-Tag: rel_2_0_18~2^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=79998e531120c2a14bb69f48101ddcc61bc1a3ab;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Type annotate postgresql/sqlite/mysql insert ### Description The goal is to annotate postgresql specific apis that are under postgresql/dml.py file. I've looked around to see what types are used for similar apis, hope I got it right :) ### Checklist This pull request is: - [x] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #10021 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10021 Pull-request-sha: 0562f093974520b162de31e8473a4d4d6656d529 Change-Id: I142f8929505c0263fcf45072d888df7ae81e6e85 --- diff --git a/lib/sqlalchemy/dialects/_typing.py b/lib/sqlalchemy/dialects/_typing.py new file mode 100644 index 0000000000..932742bd04 --- /dev/null +++ b/lib/sqlalchemy/dialects/_typing.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import Any +from typing import Iterable +from typing import Mapping +from typing import Optional +from typing import Union + +from ..sql._typing import _DDLColumnArgument +from ..sql.elements import DQLDMLClauseElement +from ..sql.schema import ColumnCollectionConstraint +from ..sql.schema import Index + + +_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None] +_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]] +_OnConflictIndexWhereT = Optional[DQLDMLClauseElement] +_OnConflictSetT = Optional[Mapping[Any, Any]] +_OnConflictWhereT = Union[DQLDMLClauseElement, str, None] diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 7c724c6f12..dfa39f6e08 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -1,26 +1,37 @@ +# mysql/dml.py # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations +from typing import Any +from typing import List +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import Union from ... import exc from ... import util +from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against from ...sql.base import _generative from ...sql.base import ColumnCollection +from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement +from ...sql.elements import KeyedColumnElement from ...sql.expression import alias +from ...sql.selectable import NamedFromClause from ...util.typing import Self __all__ = ("Insert", "insert") -def insert(table): +def insert(table: _DMLTableArgument) -> Insert: """Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert` construct. @@ -55,7 +66,9 @@ class Insert(StandardInsert): inherit_cache = False @property - def inserted(self): + def inserted( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE statement @@ -87,7 +100,7 @@ class Insert(StandardInsert): return self.inserted_alias.columns @util.memoized_property - def inserted_alias(self): + def inserted_alias(self) -> NamedFromClause: return alias(self.table, name="inserted") @_generative @@ -98,7 +111,7 @@ class Insert(StandardInsert): "has an ON DUPLICATE KEY clause present" }, ) - def on_duplicate_key_update(self, *args, **kw) -> Self: + def on_duplicate_key_update(self, *args: _UpdateArg, **kw: Any) -> Self: r""" Specifies the ON DUPLICATE KEY UPDATE clause. @@ -157,19 +170,22 @@ class Insert(StandardInsert): else: values = kw - inserted_alias = getattr(self, "inserted_alias", None) - self._post_values_clause = OnDuplicateClause(inserted_alias, values) + self._post_values_clause = OnDuplicateClause( + self.inserted_alias, values + ) return self class OnDuplicateClause(ClauseElement): __visit_name__ = "on_duplicate_key_update" - _parameter_ordering = None + _parameter_ordering: Optional[List[str]] = None stringify_dialect = "mysql" - def __init__(self, inserted_alias, update): + def __init__( + self, inserted_alias: NamedFromClause, update: _UpdateArg + ) -> None: self.inserted_alias = inserted_alias # auto-detect that parameters should be ordered. This is copied from @@ -196,3 +212,8 @@ class OnDuplicateClause(ClauseElement): "of a Table object" ) self.update = update + + +_UpdateArg = Union[ + Mapping[Any, Any], List[Tuple[str, Any]], ColumnCollection[Any, Any] +] diff --git a/lib/sqlalchemy/dialects/mysql/mariadb.py b/lib/sqlalchemy/dialects/mysql/mariadb.py index 05190dff41..a6ee5dfac9 100644 --- a/lib/sqlalchemy/dialects/mysql/mariadb.py +++ b/lib/sqlalchemy/dialects/mysql/mariadb.py @@ -1,5 +1,10 @@ +# mysql/mariadb.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors - from .base import MariaDBIdentifierPreparer from .base import MySQLDialect diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 829237bfe4..dee7af3311 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -1,21 +1,32 @@ -# postgresql/on_conflict.py +# postgresql/dml.py # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations + +from typing import Any +from typing import Optional from . import ext +from .._typing import _OnConflictConstraintT +from .._typing import _OnConflictIndexElementsT +from .._typing import _OnConflictIndexWhereT +from .._typing import _OnConflictSetT +from .._typing import _OnConflictWhereT from ... import util from ...sql import coercions from ...sql import roles from ...sql import schema +from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against from ...sql.base import _generative from ...sql.base import ColumnCollection +from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement +from ...sql.elements import KeyedColumnElement from ...sql.expression import alias from ...util.typing import Self @@ -23,7 +34,7 @@ from ...util.typing import Self __all__ = ("Insert", "insert") -def insert(table): +def insert(table: _DMLTableArgument) -> Insert: """Construct a PostgreSQL-specific variant :class:`_postgresql.Insert` construct. @@ -57,7 +68,9 @@ class Insert(StandardInsert): inherit_cache = False @util.memoized_property - def excluded(self): + def excluded( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """Provide the ``excluded`` namespace for an ON CONFLICT statement PG's ON CONFLICT clause allows reference to the row that would @@ -95,11 +108,11 @@ class Insert(StandardInsert): @_on_conflict_exclusive def on_conflict_do_update( self, - constraint=None, - index_elements=None, - index_where=None, - set_=None, - where=None, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, ) -> Self: r""" Specifies a DO UPDATE SET action for ON CONFLICT clause. @@ -161,9 +174,9 @@ class Insert(StandardInsert): @_on_conflict_exclusive def on_conflict_do_nothing( self, - constraint=None, - index_elements=None, - index_where=None, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, ) -> Self: """ Specifies a DO NOTHING action for ON CONFLICT clause. @@ -198,7 +211,16 @@ class Insert(StandardInsert): class OnConflictClause(ClauseElement): stringify_dialect = "postgresql" - def __init__(self, constraint=None, index_elements=None, index_where=None): + constraint_target: Optional[str] + inferred_target_elements: _OnConflictIndexElementsT + inferred_target_whereclause: _OnConflictIndexWhereT + + def __init__( + self, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + ): if constraint is not None: if not isinstance(constraint, str) and isinstance( constraint, @@ -249,11 +271,11 @@ class OnConflictDoUpdate(OnConflictClause): def __init__( self, - constraint=None, - index_elements=None, - index_where=None, - set_=None, - where=None, + constraint: _OnConflictConstraintT = None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, ): super().__init__( constraint=constraint, diff --git a/lib/sqlalchemy/dialects/sqlite/dml.py b/lib/sqlalchemy/dialects/sqlite/dml.py index 23066c7bee..ec428f5b17 100644 --- a/lib/sqlalchemy/dialects/sqlite/dml.py +++ b/lib/sqlalchemy/dialects/sqlite/dml.py @@ -1,27 +1,35 @@ +# sqlite/dml.py # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors # # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations +from typing import Any +from .._typing import _OnConflictIndexElementsT +from .._typing import _OnConflictIndexWhereT +from .._typing import _OnConflictSetT +from .._typing import _OnConflictWhereT from ... import util from ...sql import coercions from ...sql import roles +from ...sql._typing import _DMLTableArgument from ...sql.base import _exclusive_against from ...sql.base import _generative from ...sql.base import ColumnCollection +from ...sql.base import ReadOnlyColumnCollection from ...sql.dml import Insert as StandardInsert from ...sql.elements import ClauseElement +from ...sql.elements import KeyedColumnElement from ...sql.expression import alias from ...util.typing import Self - __all__ = ("Insert", "insert") -def insert(table): +def insert(table: _DMLTableArgument) -> Insert: """Construct a sqlite-specific variant :class:`_sqlite.Insert` construct. @@ -61,7 +69,9 @@ class Insert(StandardInsert): inherit_cache = False @util.memoized_property - def excluded(self): + def excluded( + self, + ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]: """Provide the ``excluded`` namespace for an ON CONFLICT statement SQLite's ON CONFLICT clause allows reference to the row that would @@ -94,10 +104,10 @@ class Insert(StandardInsert): @_on_conflict_exclusive def on_conflict_do_update( self, - index_elements=None, - index_where=None, - set_=None, - where=None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, ) -> Self: r""" Specifies a DO UPDATE SET action for ON CONFLICT clause. @@ -147,7 +157,9 @@ class Insert(StandardInsert): @_generative @_on_conflict_exclusive def on_conflict_do_nothing( - self, index_elements=None, index_where=None + self, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, ) -> Self: """ Specifies a DO NOTHING action for ON CONFLICT clause. @@ -172,7 +184,15 @@ class Insert(StandardInsert): class OnConflictClause(ClauseElement): stringify_dialect = "sqlite" - def __init__(self, index_elements=None, index_where=None): + constraint_target: None + inferred_target_elements: _OnConflictIndexElementsT + inferred_target_whereclause: _OnConflictIndexWhereT + + def __init__( + self, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + ): if index_elements is not None: self.constraint_target = None self.inferred_target_elements = index_elements @@ -192,10 +212,10 @@ class OnConflictDoUpdate(OnConflictClause): def __init__( self, - index_elements=None, - index_where=None, - set_=None, - where=None, + index_elements: _OnConflictIndexElementsT = None, + index_where: _OnConflictIndexWhereT = None, + set_: _OnConflictSetT = None, + where: _OnConflictWhereT = None, ): super().__init__( index_elements=index_elements, diff --git a/test/typing/plain_files/dialects/mysql/mysql_stuff.py b/test/typing/plain_files/dialects/mysql/mysql_stuff.py new file mode 100644 index 0000000000..3fcdc75a97 --- /dev/null +++ b/test/typing/plain_files/dialects/mysql/mysql_stuff.py @@ -0,0 +1,21 @@ +from sqlalchemy import Integer +from sqlalchemy.dialects.mysql import insert +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Test(Base): + __tablename__ = "test_table_json" + + id = mapped_column(Integer, primary_key=True) + data: Mapped[str] = mapped_column() + + +insert(Test).on_duplicate_key_update( + {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo" +).inserted.foo.desc() diff --git a/test/typing/plain_files/dialects/postgresql/pg_stuff.py b/test/typing/plain_files/dialects/postgresql/pg_stuff.py index c90bb67f0e..4567daa386 100644 --- a/test/typing/plain_files/dialects/postgresql/pg_stuff.py +++ b/test/typing/plain_files/dialects/postgresql/pg_stuff.py @@ -9,8 +9,10 @@ from sqlalchemy import Integer from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import Text +from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import array +from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import DeclarativeBase @@ -68,3 +70,10 @@ reveal_type(t1.data) # EXPECTED_TYPE: UUID reveal_type(t1.ident) + +unique = UniqueConstraint(name="my_constraint") +insert(Test).on_conflict_do_nothing( + "foo", [Test.id], Test.id > 0 +).on_conflict_do_update( + unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22 +).excluded.foo.desc() diff --git a/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py b/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py new file mode 100644 index 0000000000..00debda509 --- /dev/null +++ b/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py @@ -0,0 +1,23 @@ +from sqlalchemy import Integer +from sqlalchemy import UniqueConstraint +from sqlalchemy.dialects.sqlite import insert +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column + + +class Base(DeclarativeBase): + pass + + +class Test(Base): + __tablename__ = "test_table_json" + + id = mapped_column(Integer, primary_key=True) + data: Mapped[str] = mapped_column() + + +unique = UniqueConstraint(name="my_constraint") +insert(Test).on_conflict_do_nothing("foo", Test.id > 0).on_conflict_do_update( + unique, Test.id > 0, {"id": 42, Test.data: 99}, Test.id == 22 +).excluded.foo.desc()