]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Type annotate postgresql/sqlite/mysql insert
authorMehdi Gmira <mgmira@wiremind.io>
Wed, 28 Jun 2023 13:52:39 +0000 (09:52 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Thu, 29 Jun 2023 22:10:08 +0000 (00:10 +0200)
### Description
The goal is to annotate postgresql specific apis that are under postgresql/dml.py file.
I've looked around to see what types are used for similar apis, hope I got it right :)

### Checklist

This pull request is:

- [x] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #10021
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10021
Pull-request-sha: 0562f093974520b162de31e8473a4d4d6656d529

Change-Id: I142f8929505c0263fcf45072d888df7ae81e6e85

lib/sqlalchemy/dialects/_typing.py [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/dml.py
lib/sqlalchemy/dialects/mysql/mariadb.py
lib/sqlalchemy/dialects/postgresql/dml.py
lib/sqlalchemy/dialects/sqlite/dml.py
test/typing/plain_files/dialects/mysql/mysql_stuff.py [new file with mode: 0644]
test/typing/plain_files/dialects/postgresql/pg_stuff.py
test/typing/plain_files/dialects/sqlite/sqlite_stuff.py [new file with mode: 0644]

diff --git a/lib/sqlalchemy/dialects/_typing.py b/lib/sqlalchemy/dialects/_typing.py
new file mode 100644 (file)
index 0000000..932742b
--- /dev/null
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from typing import Any
+from typing import Iterable
+from typing import Mapping
+from typing import Optional
+from typing import Union
+
+from ..sql._typing import _DDLColumnArgument
+from ..sql.elements import DQLDMLClauseElement
+from ..sql.schema import ColumnCollectionConstraint
+from ..sql.schema import Index
+
+
+_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None]
+_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]]
+_OnConflictIndexWhereT = Optional[DQLDMLClauseElement]
+_OnConflictSetT = Optional[Mapping[Any, Any]]
+_OnConflictWhereT = Union[DQLDMLClauseElement, str, None]
index 7c724c6f12fae6e8d0f854d670d5e93dd3e78445..dfa39f6e086cbff05a14e983b3477701f7a39862 100644 (file)
@@ -1,26 +1,37 @@
+# mysql/dml.py
 # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
 
+from typing import Any
+from typing import List
+from typing import Mapping
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 from ... import exc
 from ... import util
+from ...sql._typing import _DMLTableArgument
 from ...sql.base import _exclusive_against
 from ...sql.base import _generative
 from ...sql.base import ColumnCollection
+from ...sql.base import ReadOnlyColumnCollection
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
+from ...sql.elements import KeyedColumnElement
 from ...sql.expression import alias
+from ...sql.selectable import NamedFromClause
 from ...util.typing import Self
 
 
 __all__ = ("Insert", "insert")
 
 
-def insert(table):
+def insert(table: _DMLTableArgument) -> Insert:
     """Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert`
     construct.
 
@@ -55,7 +66,9 @@ class Insert(StandardInsert):
     inherit_cache = False
 
     @property
-    def inserted(self):
+    def inserted(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         """Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE
         statement
 
@@ -87,7 +100,7 @@ class Insert(StandardInsert):
         return self.inserted_alias.columns
 
     @util.memoized_property
-    def inserted_alias(self):
+    def inserted_alias(self) -> NamedFromClause:
         return alias(self.table, name="inserted")
 
     @_generative
@@ -98,7 +111,7 @@ class Insert(StandardInsert):
             "has an ON DUPLICATE KEY clause present"
         },
     )
-    def on_duplicate_key_update(self, *args, **kw) -> Self:
+    def on_duplicate_key_update(self, *args: _UpdateArg, **kw: Any) -> Self:
         r"""
         Specifies the ON DUPLICATE KEY UPDATE clause.
 
@@ -157,19 +170,22 @@ class Insert(StandardInsert):
         else:
             values = kw
 
-        inserted_alias = getattr(self, "inserted_alias", None)
-        self._post_values_clause = OnDuplicateClause(inserted_alias, values)
+        self._post_values_clause = OnDuplicateClause(
+            self.inserted_alias, values
+        )
         return self
 
 
 class OnDuplicateClause(ClauseElement):
     __visit_name__ = "on_duplicate_key_update"
 
-    _parameter_ordering = None
+    _parameter_ordering: Optional[List[str]] = None
 
     stringify_dialect = "mysql"
 
-    def __init__(self, inserted_alias, update):
+    def __init__(
+        self, inserted_alias: NamedFromClause, update: _UpdateArg
+    ) -> None:
         self.inserted_alias = inserted_alias
 
         # auto-detect that parameters should be ordered.   This is copied from
@@ -196,3 +212,8 @@ class OnDuplicateClause(ClauseElement):
                 "of a Table object"
             )
         self.update = update
+
+
+_UpdateArg = Union[
+    Mapping[Any, Any], List[Tuple[str, Any]], ColumnCollection[Any, Any]
+]
index 05190dff4192cdb9fce0f80609ef3967108e1aad..a6ee5dfac93f7d0868f51137d4bd42e610d3b773 100644 (file)
@@ -1,5 +1,10 @@
+# mysql/mariadb.py
+# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
 # mypy: ignore-errors
-
 from .base import MariaDBIdentifierPreparer
 from .base import MySQLDialect
 
index 829237bfe42ee4ef62666b7bd717e80050bb868b..dee7af3311e24e18d347da1633a48872bbffe17c 100644 (file)
@@ -1,21 +1,32 @@
-# postgresql/on_conflict.py
+# postgresql/dml.py
 # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import Optional
 
 from . import ext
+from .._typing import _OnConflictConstraintT
+from .._typing import _OnConflictIndexElementsT
+from .._typing import _OnConflictIndexWhereT
+from .._typing import _OnConflictSetT
+from .._typing import _OnConflictWhereT
 from ... import util
 from ...sql import coercions
 from ...sql import roles
 from ...sql import schema
+from ...sql._typing import _DMLTableArgument
 from ...sql.base import _exclusive_against
 from ...sql.base import _generative
 from ...sql.base import ColumnCollection
+from ...sql.base import ReadOnlyColumnCollection
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
+from ...sql.elements import KeyedColumnElement
 from ...sql.expression import alias
 from ...util.typing import Self
 
@@ -23,7 +34,7 @@ from ...util.typing import Self
 __all__ = ("Insert", "insert")
 
 
-def insert(table):
+def insert(table: _DMLTableArgument) -> Insert:
     """Construct a PostgreSQL-specific variant :class:`_postgresql.Insert`
     construct.
 
@@ -57,7 +68,9 @@ class Insert(StandardInsert):
     inherit_cache = False
 
     @util.memoized_property
-    def excluded(self):
+    def excluded(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         """Provide the ``excluded`` namespace for an ON CONFLICT statement
 
         PG's ON CONFLICT clause allows reference to the row that would
@@ -95,11 +108,11 @@ class Insert(StandardInsert):
     @_on_conflict_exclusive
     def on_conflict_do_update(
         self,
-        constraint=None,
-        index_elements=None,
-        index_where=None,
-        set_=None,
-        where=None,
+        constraint: _OnConflictConstraintT = None,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
+        set_: _OnConflictSetT = None,
+        where: _OnConflictWhereT = None,
     ) -> Self:
         r"""
         Specifies a DO UPDATE SET action for ON CONFLICT clause.
@@ -161,9 +174,9 @@ class Insert(StandardInsert):
     @_on_conflict_exclusive
     def on_conflict_do_nothing(
         self,
-        constraint=None,
-        index_elements=None,
-        index_where=None,
+        constraint: _OnConflictConstraintT = None,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
     ) -> Self:
         """
         Specifies a DO NOTHING action for ON CONFLICT clause.
@@ -198,7 +211,16 @@ class Insert(StandardInsert):
 class OnConflictClause(ClauseElement):
     stringify_dialect = "postgresql"
 
-    def __init__(self, constraint=None, index_elements=None, index_where=None):
+    constraint_target: Optional[str]
+    inferred_target_elements: _OnConflictIndexElementsT
+    inferred_target_whereclause: _OnConflictIndexWhereT
+
+    def __init__(
+        self,
+        constraint: _OnConflictConstraintT = None,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
+    ):
         if constraint is not None:
             if not isinstance(constraint, str) and isinstance(
                 constraint,
@@ -249,11 +271,11 @@ class OnConflictDoUpdate(OnConflictClause):
 
     def __init__(
         self,
-        constraint=None,
-        index_elements=None,
-        index_where=None,
-        set_=None,
-        where=None,
+        constraint: _OnConflictConstraintT = None,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
+        set_: _OnConflictSetT = None,
+        where: _OnConflictWhereT = None,
     ):
         super().__init__(
             constraint=constraint,
index 23066c7beecb564d1361a1c29686d3cf851b63ce..ec428f5b1722fd4c7886770162c1a1266c6930de 100644 (file)
@@ -1,27 +1,35 @@
+# sqlite/dml.py
 # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
 
+from typing import Any
 
+from .._typing import _OnConflictIndexElementsT
+from .._typing import _OnConflictIndexWhereT
+from .._typing import _OnConflictSetT
+from .._typing import _OnConflictWhereT
 from ... import util
 from ...sql import coercions
 from ...sql import roles
+from ...sql._typing import _DMLTableArgument
 from ...sql.base import _exclusive_against
 from ...sql.base import _generative
 from ...sql.base import ColumnCollection
+from ...sql.base import ReadOnlyColumnCollection
 from ...sql.dml import Insert as StandardInsert
 from ...sql.elements import ClauseElement
+from ...sql.elements import KeyedColumnElement
 from ...sql.expression import alias
 from ...util.typing import Self
 
-
 __all__ = ("Insert", "insert")
 
 
-def insert(table):
+def insert(table: _DMLTableArgument) -> Insert:
     """Construct a sqlite-specific variant :class:`_sqlite.Insert`
     construct.
 
@@ -61,7 +69,9 @@ class Insert(StandardInsert):
     inherit_cache = False
 
     @util.memoized_property
-    def excluded(self):
+    def excluded(
+        self,
+    ) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
         """Provide the ``excluded`` namespace for an ON CONFLICT statement
 
         SQLite's ON CONFLICT clause allows reference to the row that would
@@ -94,10 +104,10 @@ class Insert(StandardInsert):
     @_on_conflict_exclusive
     def on_conflict_do_update(
         self,
-        index_elements=None,
-        index_where=None,
-        set_=None,
-        where=None,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
+        set_: _OnConflictSetT = None,
+        where: _OnConflictWhereT = None,
     ) -> Self:
         r"""
         Specifies a DO UPDATE SET action for ON CONFLICT clause.
@@ -147,7 +157,9 @@ class Insert(StandardInsert):
     @_generative
     @_on_conflict_exclusive
     def on_conflict_do_nothing(
-        self, index_elements=None, index_where=None
+        self,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
     ) -> Self:
         """
         Specifies a DO NOTHING action for ON CONFLICT clause.
@@ -172,7 +184,15 @@ class Insert(StandardInsert):
 class OnConflictClause(ClauseElement):
     stringify_dialect = "sqlite"
 
-    def __init__(self, index_elements=None, index_where=None):
+    constraint_target: None
+    inferred_target_elements: _OnConflictIndexElementsT
+    inferred_target_whereclause: _OnConflictIndexWhereT
+
+    def __init__(
+        self,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
+    ):
         if index_elements is not None:
             self.constraint_target = None
             self.inferred_target_elements = index_elements
@@ -192,10 +212,10 @@ class OnConflictDoUpdate(OnConflictClause):
 
     def __init__(
         self,
-        index_elements=None,
-        index_where=None,
-        set_=None,
-        where=None,
+        index_elements: _OnConflictIndexElementsT = None,
+        index_where: _OnConflictIndexWhereT = None,
+        set_: _OnConflictSetT = None,
+        where: _OnConflictWhereT = None,
     ):
         super().__init__(
             index_elements=index_elements,
diff --git a/test/typing/plain_files/dialects/mysql/mysql_stuff.py b/test/typing/plain_files/dialects/mysql/mysql_stuff.py
new file mode 100644 (file)
index 0000000..3fcdc75
--- /dev/null
@@ -0,0 +1,21 @@
+from sqlalchemy import Integer
+from sqlalchemy.dialects.mysql import insert
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Test(Base):
+    __tablename__ = "test_table_json"
+
+    id = mapped_column(Integer, primary_key=True)
+    data: Mapped[str] = mapped_column()
+
+
+insert(Test).on_duplicate_key_update(
+    {"id": 42, Test.data: 99}, [("foo", 44)], data=99, id="foo"
+).inserted.foo.desc()
index c90bb67f0e89f388dcc253259ba128a515141e0d..4567daa38665ddafe36597d0cd283ac5ae067875 100644 (file)
@@ -9,8 +9,10 @@ from sqlalchemy import Integer
 from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import Text
+from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import array
+from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import JSONB
 from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import DeclarativeBase
@@ -68,3 +70,10 @@ reveal_type(t1.data)
 
 # EXPECTED_TYPE: UUID
 reveal_type(t1.ident)
+
+unique = UniqueConstraint(name="my_constraint")
+insert(Test).on_conflict_do_nothing(
+    "foo", [Test.id], Test.id > 0
+).on_conflict_do_update(
+    unique, ["foo"], Test.id > 0, {"id": 42, Test.ident: 99}, Test.id == 22
+).excluded.foo.desc()
diff --git a/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py b/test/typing/plain_files/dialects/sqlite/sqlite_stuff.py
new file mode 100644 (file)
index 0000000..00debda
--- /dev/null
@@ -0,0 +1,23 @@
+from sqlalchemy import Integer
+from sqlalchemy import UniqueConstraint
+from sqlalchemy.dialects.sqlite import insert
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class Test(Base):
+    __tablename__ = "test_table_json"
+
+    id = mapped_column(Integer, primary_key=True)
+    data: Mapped[str] = mapped_column()
+
+
+unique = UniqueConstraint(name="my_constraint")
+insert(Test).on_conflict_do_nothing("foo", Test.id > 0).on_conflict_do_update(
+    unique, Test.id > 0, {"id": 42, Test.data: 99}, Test.id == 22
+).excluded.foo.desc()