]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(crdb) split crdb module into a package
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Jun 2022 02:22:10 +0000 (04:22 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
psycopg/psycopg/crdb/__init__.py [new file with mode: 0644]
psycopg/psycopg/crdb/_types.py [moved from psycopg/psycopg/crdb.py with 50% similarity]
psycopg/psycopg/crdb/connection.py [new file with mode: 0644]

diff --git a/psycopg/psycopg/crdb/__init__.py b/psycopg/psycopg/crdb/__init__.py
new file mode 100644 (file)
index 0000000..323903a
--- /dev/null
@@ -0,0 +1,19 @@
+"""
+CockroachDB support package.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+from . import _types
+from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
+
+adapters = _types.adapters  # exposed by the package
+connect = CrdbConnection.connect
+
+_types.register_crdb_adapters(adapters)
+
+__all__ = [
+    "AsyncCrdbConnection",
+    "CrdbConnection",
+    "CrdbConnectionInfo",
+]
similarity index 50%
rename from psycopg/psycopg/crdb.py
rename to psycopg/psycopg/crdb/_types.py
index 9ae7c6f24d561d3d0a0a582f0c27e89f764614d4..5311e05b0c580dd72cd986df2ce13de69a7fc2ec 100644 (file)
@@ -4,26 +4,14 @@ Types configuration specific for CockroachDB.
 
 # Copyright (C) 2022 The Psycopg Team
 
-import re
 from enum import Enum
-from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
-from ._typeinfo import TypeInfo, TypesRegistry
-
-from . import errors as e
-from .abc import AdaptContext, NoneType
-from .rows import Row, RowFactory, AsyncRowFactory, TupleRow
-from .postgres import TEXT_OID
-from .conninfo import ConnectionInfo
-from .connection import Connection
-from ._adapters_map import AdaptersMap
-from .connection_async import AsyncConnection
-from .types.enum import EnumDumper, EnumBinaryDumper
-from .types.none import NoneDumper
-
-if TYPE_CHECKING:
-    from .pq.abc import PGconn
-    from .cursor import Cursor
-    from .cursor_async import AsyncCursor
+from .._typeinfo import TypeInfo, TypesRegistry
+
+from ..abc import AdaptContext, NoneType
+from ..postgres import TEXT_OID
+from .._adapters_map import AdaptersMap
+from ..types.enum import EnumDumper, EnumBinaryDumper
+from ..types.none import NoneDumper
 
 types = TypesRegistry()
 
@@ -31,152 +19,6 @@ types = TypesRegistry()
 adapters = AdaptersMap(types=types)
 
 
-class _CrdbConnectionMixin:
-
-    _adapters: Optional[AdaptersMap]
-    pgconn: "PGconn"
-
-    @classmethod
-    def is_crdb(
-        cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
-    ) -> bool:
-        """
-        Return True if the server connected to ``conn`` is CockroachDB.
-        """
-        if isinstance(conn, (Connection, AsyncConnection)):
-            conn = conn.pgconn
-
-        return bool(conn.parameter_status(b"crdb_version"))
-
-    @property
-    def adapters(self) -> AdaptersMap:
-        if not self._adapters:
-            # By default, use CockroachDB adapters map
-            self._adapters = AdaptersMap(adapters)
-
-        return self._adapters
-
-    @property
-    def info(self) -> "CrdbConnectionInfo":
-        return CrdbConnectionInfo(self.pgconn)
-
-
-class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
-    # TODO: this method shouldn't require re-definition if the base class
-    # implements a generic self.
-    # https://github.com/psycopg/psycopg/issues/308
-    @overload
-    @classmethod
-    def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        row_factory: RowFactory[Row],
-        prepare_threshold: Optional[int] = 5,
-        cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "CrdbConnection[Row]":
-        ...
-
-    @overload
-    @classmethod
-    def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        prepare_threshold: Optional[int] = 5,
-        cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "CrdbConnection[TupleRow]":
-        ...
-
-    @classmethod
-    def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
-        return super().connect(conninfo, **kwargs)  # type: ignore[return-value]
-
-
-class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
-    # TODO: this method shouldn't require re-definition if the base class
-    # implements a generic self.
-    # https://github.com/psycopg/psycopg/issues/308
-    @overload
-    @classmethod
-    async def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        prepare_threshold: Optional[int] = 5,
-        row_factory: AsyncRowFactory[Row],
-        cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "AsyncCrdbConnection[Row]":
-        ...
-
-    @overload
-    @classmethod
-    async def connect(
-        cls,
-        conninfo: str = "",
-        *,
-        autocommit: bool = False,
-        prepare_threshold: Optional[int] = 5,
-        cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
-        context: Optional[AdaptContext] = None,
-        **kwargs: Union[None, int, str],
-    ) -> "AsyncCrdbConnection[TupleRow]":
-        ...
-
-    @classmethod
-    async def connect(
-        cls, conninfo: str = "", **kwargs: Any
-    ) -> "AsyncCrdbConnection[Any]":
-        return await super().connect(conninfo, **kwargs)  # type: ignore [no-any-return]
-
-
-connect = CrdbConnection.connect
-
-
-class CrdbConnectionInfo(ConnectionInfo):
-    @property
-    def vendor(self) -> str:
-        return "CockroachDB"
-
-    @property
-    def crdb_version(self) -> int:
-        """
-        Return the CockroachDB server version connected.
-
-        Return None if the server is not CockroachDB, else return a number in
-        the PostgreSQL format (e.g. 21.2.10 -> 200210)
-
-        Assume all the connections are on the same db: return a cached result on
-        following calls.
-        """
-        sver = self.parameter_status("crdb_version")
-        if not sver:
-            raise e.InternalError("'crdb_version' parameter status not set")
-
-        ver = self.parse_crdb_version(sver)
-        if ver is None:
-            raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
-
-        return ver
-
-    @classmethod
-    def parse_crdb_version(self, sver: str) -> Optional[int]:
-        m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
-        if not m:
-            return None
-
-        return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
-
-
 class CrdbEnumDumper(EnumDumper):
     oid = TEXT_OID
 
@@ -192,8 +34,8 @@ class CrdbNoneDumper(NoneDumper):
 def register_postgres_adapters(context: AdaptContext) -> None:
     # Same adapters used by PostgreSQL, or a good starting point for customization
 
-    from .types import array, bool, composite, datetime
-    from .types import numeric, string, uuid
+    from ..types import array, bool, composite, datetime
+    from ..types import numeric, string, uuid
 
     array.register_default_adapters(context)
     bool.register_default_adapters(context)
@@ -205,8 +47,8 @@ def register_postgres_adapters(context: AdaptContext) -> None:
 
 
 def register_crdb_adapters(context: AdaptContext) -> None:
-    from . import dbapi20
-    from .types import array
+    from .. import dbapi20
+    from ..types import array
 
     register_postgres_adapters(context)
 
@@ -223,7 +65,7 @@ def register_crdb_adapters(context: AdaptContext) -> None:
 
 
 def register_crdb_string_adapters(context: AdaptContext) -> None:
-    from .types import string
+    from ..types import string
 
     # Dump strings with text oid instead of unknown.
     # Unlike PostgreSQL, CRDB seems able to cast text to most types.
@@ -237,7 +79,7 @@ def register_crdb_enum_adapters(context: AdaptContext) -> None:
 
 
 def register_crdb_json_adapters(context: AdaptContext) -> None:
-    from .types import json
+    from ..types import json
 
     adapters = context.adapters
 
@@ -255,7 +97,7 @@ def register_crdb_json_adapters(context: AdaptContext) -> None:
 
 
 def register_crdb_net_adapters(context: AdaptContext) -> None:
-    from psycopg.types import net
+    from ..types import net
 
     adapters = context.adapters
 
@@ -319,6 +161,3 @@ for t in [
     # autogenerated: end
 ]:
     types.add(t)
-
-
-register_crdb_adapters(adapters)
diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py
new file mode 100644 (file)
index 0000000..a52bad0
--- /dev/null
@@ -0,0 +1,186 @@
+"""
+CockroachDB-specific connections.
+"""
+
+# Copyright (C) 2022 The Psycopg Team
+
+import re
+from typing import Any, Optional, Type, Union, overload, TYPE_CHECKING
+
+from .. import errors as e
+from ..abc import AdaptContext
+from ..rows import Row, RowFactory, AsyncRowFactory, TupleRow
+from ..conninfo import ConnectionInfo
+from ..connection import Connection
+from .._adapters_map import AdaptersMap
+from ..connection_async import AsyncConnection
+from ._types import adapters
+
+if TYPE_CHECKING:
+    from ..pq.abc import PGconn
+    from ..cursor import Cursor
+    from ..cursor_async import AsyncCursor
+
+
+class _CrdbConnectionMixin:
+
+    _adapters: Optional[AdaptersMap]
+    pgconn: "PGconn"
+
+    @classmethod
+    def is_crdb(
+        cls, conn: Union[Connection[Any], AsyncConnection[Any], "PGconn"]
+    ) -> bool:
+        """
+        Return True if the server connected to ``conn`` is CockroachDB.
+        """
+        if isinstance(conn, (Connection, AsyncConnection)):
+            conn = conn.pgconn
+
+        return bool(conn.parameter_status(b"crdb_version"))
+
+    @property
+    def adapters(self) -> AdaptersMap:
+        if not self._adapters:
+            # By default, use CockroachDB adapters map
+            self._adapters = AdaptersMap(adapters)
+
+        return self._adapters
+
+    @property
+    def info(self) -> "CrdbConnectionInfo":
+        return CrdbConnectionInfo(self.pgconn)
+
+
+class CrdbConnection(_CrdbConnectionMixin, Connection[Row]):
+    """
+    Wrapper for a connection to a CockroachDB database.
+    """
+
+    __module__ = "psycopg.crdb"
+
+    # TODO: this method shouldn't require re-definition if the base class
+    # implements a generic self.
+    # https://github.com/psycopg/psycopg/issues/308
+    @overload
+    @classmethod
+    def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        row_factory: RowFactory[Row],
+        prepare_threshold: Optional[int] = 5,
+        cursor_factory: "Optional[Type[Cursor[Row]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "CrdbConnection[Row]":
+        ...
+
+    @overload
+    @classmethod
+    def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        prepare_threshold: Optional[int] = 5,
+        cursor_factory: "Optional[Type[Cursor[Any]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "CrdbConnection[TupleRow]":
+        ...
+
+    @classmethod
+    def connect(cls, conninfo: str = "", **kwargs: Any) -> "CrdbConnection[Any]":
+        """
+        Connect to a database server and return a new `CrdbConnection` instance.
+        """
+        return super().connect(conninfo, **kwargs)  # type: ignore[return-value]
+
+
+class AsyncCrdbConnection(_CrdbConnectionMixin, AsyncConnection[Row]):
+    """
+    Wrapper for an async connection to a CockroachDB database.
+    """
+
+    __module__ = "psycopg.crdb"
+
+    # TODO: this method shouldn't require re-definition if the base class
+    # implements a generic self.
+    # https://github.com/psycopg/psycopg/issues/308
+    @overload
+    @classmethod
+    async def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        prepare_threshold: Optional[int] = 5,
+        row_factory: AsyncRowFactory[Row],
+        cursor_factory: "Optional[Type[AsyncCursor[Row]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "AsyncCrdbConnection[Row]":
+        ...
+
+    @overload
+    @classmethod
+    async def connect(
+        cls,
+        conninfo: str = "",
+        *,
+        autocommit: bool = False,
+        prepare_threshold: Optional[int] = 5,
+        cursor_factory: "Optional[Type[AsyncCursor[Any]]]" = None,
+        context: Optional[AdaptContext] = None,
+        **kwargs: Union[None, int, str],
+    ) -> "AsyncCrdbConnection[TupleRow]":
+        ...
+
+    @classmethod
+    async def connect(
+        cls, conninfo: str = "", **kwargs: Any
+    ) -> "AsyncCrdbConnection[Any]":
+        return await super().connect(conninfo, **kwargs)  # type: ignore [no-any-return]
+
+
+class CrdbConnectionInfo(ConnectionInfo):
+    """
+    `~psycopg.ConnectionInfo` subclass to get info about a CockroachDB database.
+    """
+
+    __module__ = "psycopg.crdb"
+
+    @property
+    def vendor(self) -> str:
+        return "CockroachDB"
+
+    @property
+    def crdb_version(self) -> int:
+        """
+        Return the CockroachDB server version connected.
+
+        Return None if the server is not CockroachDB, else return a number in
+        the PostgreSQL format (e.g. 21.2.10 -> 200210)
+
+        Assume all the connections are on the same db: return a cached result on
+        following calls.
+        """
+        sver = self.parameter_status("crdb_version")
+        if not sver:
+            raise e.InternalError("'crdb_version' parameter status not set")
+
+        ver = self.parse_crdb_version(sver)
+        if ver is None:
+            raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
+
+        return ver
+
+    @classmethod
+    def parse_crdb_version(self, sver: str) -> Optional[int]:
+        m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
+        if not m:
+            return None
+
+        return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))