]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Adding typing to Postgres dialect file.
authorJohn Clow <jclow@canopyservicing.com>
Wed, 12 Apr 2023 06:34:58 +0000 (02:34 -0400)
committerFederico Caselli <cfederico87@gmail.com>
Wed, 12 Apr 2023 20:54:01 +0000 (22:54 +0200)
Adding typing information for various parameters for Postgres types (in accordance to the docs).

This pull request is:

- [x] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [ ] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #9594
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9594
Pull-request-sha: c7e39a219108f9e81ad22c008a664b62f09f9d5f

Change-Id: I91b377c246c728885a99df297de7a8933835c540

lib/sqlalchemy/dialects/postgresql/types.py

index a6b82044b3055b50ed82302b6d2191818dc0c000..86d67fd56dd499baee039954060dfed4a6f4ebf4 100644 (file)
@@ -3,24 +3,46 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
+from __future__ import annotations
 
 import datetime as dt
+from typing import Any
+from typing import Optional
+from typing import overload
+from typing import Type
+from typing import TYPE_CHECKING
+from uuid import UUID as _python_UUID
 
 from ...sql import sqltypes
-
+from ...sql import type_api
+from ...util.typing import Literal
 
 _DECIMAL_TYPES = (1231, 1700)
 _FLOAT_TYPES = (700, 701, 1021, 1022)
 _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
 
 
-class PGUuid(sqltypes.UUID):
+class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]):
     render_bind_cast = True
     render_literal_cast = True
 
+    if TYPE_CHECKING:
+
+        @overload
+        def __init__(
+            self: PGUuid[_python_UUID], as_uuid: Literal[True] = ...
+        ) -> None:
+            ...
+
+        @overload
+        def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None:
+            ...
+
+        def __init__(self, as_uuid: bool = True) -> None:
+            ...
+
 
-class BYTEA(sqltypes.LargeBinary[bytes]):
+class BYTEA(sqltypes.LargeBinary):
     __visit_name__ = "BYTEA"
 
 
@@ -53,7 +75,6 @@ PGMacAddr8 = MACADDR8
 
 
 class MONEY(sqltypes.TypeEngine[str]):
-
     r"""Provide the PostgreSQL MONEY type.
 
     Depending on driver, result rows using this type may return a
@@ -150,7 +171,9 @@ class TIMESTAMP(sqltypes.TIMESTAMP):
 
     __visit_name__ = "TIMESTAMP"
 
-    def __init__(self, timezone=False, precision=None):
+    def __init__(
+        self, timezone: bool = False, precision: Optional[int] = None
+    ) -> None:
         """Construct a TIMESTAMP.
 
         :param timezone: boolean value if timezone present, default False
@@ -169,7 +192,9 @@ class TIME(sqltypes.TIME):
 
     __visit_name__ = "TIME"
 
-    def __init__(self, timezone=False, precision=None):
+    def __init__(
+        self, timezone: bool = False, precision: Optional[int] = None
+    ) -> None:
         """Construct a TIME.
 
         :param timezone: boolean value if timezone present, default False
@@ -182,14 +207,16 @@ class TIME(sqltypes.TIME):
         self.precision = precision
 
 
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
 
     """PostgreSQL INTERVAL type."""
 
     __visit_name__ = "INTERVAL"
     native = True
 
-    def __init__(self, precision=None, fields=None):
+    def __init__(
+        self, precision: Optional[int] = None, fields: Optional[str] = None
+    ) -> None:
         """Construct an INTERVAL.
 
         :param precision: optional integer precision value
@@ -204,18 +231,20 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
         self.fields = fields
 
     @classmethod
-    def adapt_emulated_to_native(cls, interval, **kw):
+    def adapt_emulated_to_native(
+        cls, interval: sqltypes.Interval, **kw: Any  # type: ignore[override]
+    ) -> INTERVAL:
         return INTERVAL(precision=interval.second_precision)
 
     @property
-    def _type_affinity(self):
+    def _type_affinity(self) -> Type[sqltypes.Interval]:
         return sqltypes.Interval
 
-    def as_generic(self, allow_nulltype=False):
+    def as_generic(self, allow_nulltype: bool = False) -> sqltypes.Interval:
         return sqltypes.Interval(native=True, second_precision=self.precision)
 
     @property
-    def python_type(self):
+    def python_type(self) -> Type[dt.timedelta]:
         return dt.timedelta
 
 
@@ -225,13 +254,15 @@ PGInterval = INTERVAL
 class BIT(sqltypes.TypeEngine[int]):
     __visit_name__ = "BIT"
 
-    def __init__(self, length=None, varying=False):
-        if not varying:
+    def __init__(
+        self, length: Optional[int] = None, varying: bool = False
+    ) -> None:
+        if varying:
+            # BIT VARYING can be unlimited-length, so no default
+            self.length = length
+        else:
             # BIT without VARYING defaults to length 1
             self.length = length or 1
-        else:
-            # but BIT VARYING can be unlimited-length, so no default
-            self.length = length
         self.varying = varying