From: John Clow Date: Wed, 12 Apr 2023 06:34:58 +0000 (-0400) Subject: Adding typing to Postgres dialect file. X-Git-Tag: rel_2_0_11~10^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f3bc7e5e2b0f8242661c8d89797bfcb3503d9948;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Adding typing to Postgres dialect file. Adding typing information for various parameters for Postgres types (in accordance to the docs). This pull request is: - [x] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [ ] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #9594 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9594 Pull-request-sha: c7e39a219108f9e81ad22c008a664b62f09f9d5f Change-Id: I91b377c246c728885a99df297de7a8933835c540 --- diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py index a6b82044b3..86d67fd56d 100644 --- a/lib/sqlalchemy/dialects/postgresql/types.py +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -3,24 +3,46 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors +from __future__ import annotations import datetime as dt +from typing import Any +from typing import Optional +from typing import overload +from typing import Type +from typing import TYPE_CHECKING +from uuid import UUID as _python_UUID from ...sql import sqltypes - +from ...sql import type_api +from ...util.typing import Literal _DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) -class PGUuid(sqltypes.UUID): +class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]): render_bind_cast = True render_literal_cast = True + if TYPE_CHECKING: + + @overload + def __init__( + self: PGUuid[_python_UUID], as_uuid: Literal[True] = ... + ) -> None: + ... + + @overload + def __init__(self: PGUuid[str], as_uuid: Literal[False] = ...) -> None: + ... + + def __init__(self, as_uuid: bool = True) -> None: + ... + -class BYTEA(sqltypes.LargeBinary[bytes]): +class BYTEA(sqltypes.LargeBinary): __visit_name__ = "BYTEA" @@ -53,7 +75,6 @@ PGMacAddr8 = MACADDR8 class MONEY(sqltypes.TypeEngine[str]): - r"""Provide the PostgreSQL MONEY type. Depending on driver, result rows using this type may return a @@ -150,7 +171,9 @@ class TIMESTAMP(sqltypes.TIMESTAMP): __visit_name__ = "TIMESTAMP" - def __init__(self, timezone=False, precision=None): + def __init__( + self, timezone: bool = False, precision: Optional[int] = None + ) -> None: """Construct a TIMESTAMP. :param timezone: boolean value if timezone present, default False @@ -169,7 +192,9 @@ class TIME(sqltypes.TIME): __visit_name__ = "TIME" - def __init__(self, timezone=False, precision=None): + def __init__( + self, timezone: bool = False, precision: Optional[int] = None + ) -> None: """Construct a TIME. :param timezone: boolean value if timezone present, default False @@ -182,14 +207,16 @@ class TIME(sqltypes.TIME): self.precision = precision -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): +class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval): """PostgreSQL INTERVAL type.""" __visit_name__ = "INTERVAL" native = True - def __init__(self, precision=None, fields=None): + def __init__( + self, precision: Optional[int] = None, fields: Optional[str] = None + ) -> None: """Construct an INTERVAL. :param precision: optional integer precision value @@ -204,18 +231,20 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): self.fields = fields @classmethod - def adapt_emulated_to_native(cls, interval, **kw): + def adapt_emulated_to_native( + cls, interval: sqltypes.Interval, **kw: Any # type: ignore[override] + ) -> INTERVAL: return INTERVAL(precision=interval.second_precision) @property - def _type_affinity(self): + def _type_affinity(self) -> Type[sqltypes.Interval]: return sqltypes.Interval - def as_generic(self, allow_nulltype=False): + def as_generic(self, allow_nulltype: bool = False) -> sqltypes.Interval: return sqltypes.Interval(native=True, second_precision=self.precision) @property - def python_type(self): + def python_type(self) -> Type[dt.timedelta]: return dt.timedelta @@ -225,13 +254,15 @@ PGInterval = INTERVAL class BIT(sqltypes.TypeEngine[int]): __visit_name__ = "BIT" - def __init__(self, length=None, varying=False): - if not varying: + def __init__( + self, length: Optional[int] = None, varying: bool = False + ) -> None: + if varying: + # BIT VARYING can be unlimited-length, so no default + self.length = length + else: # BIT without VARYING defaults to length 1 self.length = length or 1 - else: - # but BIT VARYING can be unlimited-length, so no default - self.length = length self.varying = varying