]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use the _psycopg module from either psycopg_c or psycopg_binary
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 28 Jun 2021 02:14:20 +0000 (03:14 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 28 Jun 2021 11:23:44 +0000 (12:23 +0100)
psycopg/psycopg/_cmodule.py [new file with mode: 0644]
psycopg/psycopg/_wrappers.py [new file with mode: 0644]
psycopg/psycopg/adapt.py
psycopg/psycopg/connection.py
psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/types/numeric.py
psycopg_c/psycopg_c/types/numeric.pyx
tests/test_adapt.py

diff --git a/psycopg/psycopg/_cmodule.py b/psycopg/psycopg/_cmodule.py
new file mode 100644 (file)
index 0000000..0ab4813
--- /dev/null
@@ -0,0 +1,20 @@
+"""
+Simplify access to the _psycopg module
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+from . import pq
+
+# Note: "c" must the first attempt so that mypy associates the variable the
+# right module interface. It will not result Optional, but hey.
+if pq.__impl__ == "c":
+    from psycopg_c import _psycopg
+elif pq.__impl__ == "binary":
+    from psycopg_binary import _psycopg  # type: ignore
+elif pq.__impl__ == "python":
+    _psycopg = None  # type: ignore
+else:
+    raise ImportError(
+        f"can't find _psycopg optimised module in {pq.__impl__!r}"
+    )
diff --git a/psycopg/psycopg/_wrappers.py b/psycopg/psycopg/_wrappers.py
new file mode 100644 (file)
index 0000000..ecbb34c
--- /dev/null
@@ -0,0 +1,51 @@
+"""
+Wrappers for numeric types.
+"""
+
+# Copyright (C) 2020-2021 The Psycopg Team
+
+# Wrappers to force numbers to be cast as specific PostgreSQL types
+
+# These types are implemented here but exposed by `psycopg.types.numeric`.
+# They are defined here to avoid a circular import.
+_MODULE = "psycopg.types.numeric"
+
+
+class Int2(int):
+
+    __module__ = _MODULE
+
+    def __new__(cls, arg: int) -> "Int2":
+        return super().__new__(cls, arg)
+
+
+class Int4(int):
+
+    __module__ = _MODULE
+
+    def __new__(cls, arg: int) -> "Int4":
+        return super().__new__(cls, arg)
+
+
+class Int8(int):
+
+    __module__ = _MODULE
+
+    def __new__(cls, arg: int) -> "Int8":
+        return super().__new__(cls, arg)
+
+
+class IntNumeric(int):
+
+    __module__ = _MODULE
+
+    def __new__(cls, arg: int) -> "IntNumeric":
+        return super().__new__(cls, arg)
+
+
+class Oid(int):
+
+    __module__ = _MODULE
+
+    def __new__(cls, arg: int) -> "Oid":
+        return super().__new__(cls, arg)
index 06b404ba80cf52c6ca465be281b41c69359b953a..6acc9bb1f87ebea97ed739311a568aba936123bd 100644 (file)
@@ -14,6 +14,7 @@ from . import errors as e
 from ._enums import Format as Format
 from .oids import postgres_types
 from .proto import AdaptContext, Buffer as Buffer
+from ._cmodule import _psycopg
 from ._typeinfo import TypesRegistry
 
 if TYPE_CHECKING:
@@ -186,7 +187,7 @@ class AdaptersMap(AdaptContext):
                 f"dumpers should be registered on classes, got {cls} instead"
             )
 
-        if pq.__impl__ != "python":
+        if _psycopg:
             dumper = self._get_optimised(dumper)
 
         # Register the dumper both as its format and as default
@@ -210,7 +211,7 @@ class AdaptersMap(AdaptContext):
                 f"loaders should be registered on oid, got {oid} instead"
             )
 
-        if pq.__impl__ != "python":
+        if _psycopg:
             loader = self._get_optimised(loader)
 
         fmt = loader.format
@@ -272,7 +273,6 @@ class AdaptersMap(AdaptContext):
         # Check if the class comes from psycopg.types and there is a class
         # with the same name in psycopg_c._psycopg.
         from psycopg import types
-        from psycopg_c import _psycopg
 
         if cls.__module__.startswith(types.__name__):
             new = cast(Type[RV], getattr(_psycopg, cls.__name__, None))
@@ -292,9 +292,7 @@ global_adapters = AdaptersMap(types=postgres_types)
 Transformer: Type[proto.Transformer]
 
 # Override it with fast object if available
-if pq.__impl__ == "c":
-    from psycopg_c import _psycopg
-
+if _psycopg:
     Transformer = _psycopg.Transformer
 else:
     from . import _transform
index 906b5fdeb4ffd0ca2ac491e7dabf5d8b7ba113c4..5faedf2616ab8460d48a4db6a27b5c9a28392d04 100644 (file)
@@ -28,6 +28,7 @@ from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn
 from .proto import Query, RV
 from .compat import asynccontextmanager
 from .cursor import Cursor, AsyncCursor
+from ._cmodule import _psycopg
 from .conninfo import _conninfo_connect_timeout, ConnectionInfo
 from .generators import notifies
 from ._preparing import PrepareManager
@@ -47,9 +48,7 @@ if TYPE_CHECKING:
     from .pq.proto import PGconn, PGresult
     from .pool.base import BasePool
 
-if pq.__impl__ == "c":
-    from psycopg_c import _psycopg
-
+if _psycopg:
     connect = _psycopg.connect
     execute = _psycopg.execute
 
index f34868edec670b544a8cdf411e2c0090a95a2ada..f258be088914d5e68cc72891bc7a2d491849af98 100644 (file)
@@ -20,6 +20,7 @@ from .pq import ExecStatus
 from .adapt import Format
 from .proto import ConnectionType, PQGen, Transformer
 from .compat import create_task
+from ._cmodule import _psycopg
 from .generators import copy_from, copy_to, copy_end
 
 if TYPE_CHECKING:
@@ -639,9 +640,7 @@ def _load_sub(
 
 
 # Override functions with fast versions if available
-if pq.__impl__ == "c":
-    from psycopg_c import _psycopg
-
+if _psycopg:
     format_row_text = _psycopg.format_row_text
     format_row_binary = _psycopg.format_row_binary
     parse_row_text = _psycopg.parse_row_text
index 0fd8a05738da516c0411d95318a97cbf7581fa89..3fa4ac93c589cd21f283adb7b2e123beedfa9a5b 100644 (file)
@@ -21,6 +21,7 @@ from .rows import Row, RowFactory
 from .proto import ConnectionType, Query, Params, PQGen
 from .compat import asynccontextmanager
 from ._column import Column
+from ._cmodule import _psycopg
 from ._queries import PostgresQuery
 from ._preparing import Prepare
 
@@ -32,9 +33,7 @@ if TYPE_CHECKING:
 
 execute: Callable[["PGconn"], PQGen[List["PGresult"]]]
 
-if pq.__impl__ == "c":
-    from psycopg_c import _psycopg
-
+if _psycopg:
     execute = _psycopg.execute
 
 else:
index decfa232080ad8d6cebbd828ecb4377cbb4bfc04..2d12af3fe767d6097e440ad193a7ee3519adde38 100644 (file)
@@ -20,33 +20,14 @@ from .._struct import pack_int4, pack_uint4, unpack_int4, unpack_uint4
 from .._struct import pack_int8, unpack_int8
 from .._struct import pack_float8, unpack_float4, unpack_float8
 
-
-# Wrappers to force numbers to be cast as specific PostgreSQL types
-
-
-class Int2(int):
-    def __new__(cls, arg: int) -> "Int2":
-        return super().__new__(cls, arg)
-
-
-class Int4(int):
-    def __new__(cls, arg: int) -> "Int4":
-        return super().__new__(cls, arg)
-
-
-class Int8(int):
-    def __new__(cls, arg: int) -> "Int8":
-        return super().__new__(cls, arg)
-
-
-class IntNumeric(int):
-    def __new__(cls, arg: int) -> "IntNumeric":
-        return super().__new__(cls, arg)
-
-
-class Oid(int):
-    def __new__(cls, arg: int) -> "Oid":
-        return super().__new__(cls, arg)
+# Exposed here
+from .._wrappers import (
+    Int2 as Int2,
+    Int4 as Int4,
+    Int8 as Int8,
+    IntNumeric as IntNumeric,
+    Oid as Oid,
+)
 
 
 class _NumberDumper(Dumper):
index 01a67b3a3f33446222dee01f895cd6a135234260..75af361b1e61ebf7ccd7455e60c6434e874bd6ad 100644 (file)
@@ -21,8 +21,7 @@ from decimal import Decimal, Context, DefaultContext
 
 from psycopg_c._psycopg cimport endian
 from psycopg import errors as e
-
-from psycopg.types.numeric import Int2, Int4, Int8, IntNumeric
+from psycopg._wrappers import Int2, Int4, Int8, IntNumeric
 
 cdef extern from "Python.h":
     # work around https://github.com/cython/cython/issues/3909
index 3988d7d37044619eb6a6533fe3be0fb53b29cd4d..6815d27f1eae00fb139a6d9f7b390f949f0bc014 100644 (file)
@@ -7,6 +7,7 @@ import psycopg
 from psycopg import pq
 from psycopg.adapt import Transformer, Format, Dumper, Loader
 from psycopg.oids import postgres_types as builtins, TEXT_OID
+from psycopg._cmodule import _psycopg
 
 
 @pytest.mark.parametrize(
@@ -292,11 +293,9 @@ def test_no_cast_needed(conn, fmt_in):
     assert cur.fetchone()[0] == 20
 
 
-@pytest.mark.skipif(psycopg.pq.__impl__ == "python", reason="C module test")
+@pytest.mark.skipif(_psycopg is None, reason="C module test")
 def test_optimised_adapters():
 
-    from psycopg_c import _psycopg
-
     # All the optimised adapters available
     c_adapters = {}
     for n in dir(_psycopg):