]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added tests about adapters/casters selection, some fixes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 1 Apr 2020 13:25:58 +0000 (02:25 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 1 Apr 2020 13:37:09 +0000 (02:37 +1300)
- Added Transformer.cast() function;
- Can return binary results even with no input parameters;
- libpq functions accept sequences rather than lists.

psycopg3/adaptation.py
psycopg3/cursor.py
psycopg3/pq/pq_ctypes.py
psycopg3/types/numeric.py
tests/test_adapt.py [new file with mode: 0644]

index 3db374ce587a2787995651a972c6b50df581315e..082f32ad7a8419cd86f03dcac6128dc155c69b4a 100644 (file)
@@ -175,7 +175,7 @@ class Transformer:
     connection: Optional[BaseConnection]
     cursor: Optional[BaseCursor]
 
-    def __init__(self, context: AdaptContext):
+    def __init__(self, context: AdaptContext = None):
         if context is None:
             self.connection = None
             self.cursor = None
@@ -239,7 +239,7 @@ class Transformer:
 
         return out, types
 
-    def adapt(self, obj: None, fmt: Format) -> MaybeOid:
+    def adapt(self, obj: None, fmt: Format = Format.TEXT) -> MaybeOid:
         if obj is None:
             return None, type_oid["text"]
 
@@ -286,6 +286,15 @@ class Transformer:
                 v = func(v)
             yield v
 
+    def cast(
+        self, data: Optional[bytes], oid: Oid, fmt: Format = Format.TEXT
+    ) -> Any:
+        if data is not None:
+            f = self.get_cast_function(oid, fmt)
+            return f(data)
+        else:
+            return None
+
     def get_cast_function(self, oid: Oid, fmt: Format) -> TypecasterFunc:
         try:
             return self._cast_funcs[oid, fmt]
index ec6416820ca3d9056b1461253b59df01fe6d44e0..091a4a9caaa89e7673c6f6e4db9a4a8baae432a4 100644 (file)
@@ -66,7 +66,14 @@ class BaseCursor:
                 result_format=Format(self.binary),
             )
         else:
-            self.conn.pgconn.send_query(query)
+            # if we don't have to, let's use exec_ as it can run more than
+            # one query in one go
+            if self.binary:
+                self.conn.pgconn.send_query_params(
+                    query, (), result_format=Format(self.binary)
+                )
+            else:
+                self.conn.pgconn.send_query(query)
 
         return self.conn._exec_gen(self.conn.pgconn)
 
index d11414b837d7ebc5d282112def0882691b32edb8..1b26f71fb66ec3fb45956ee6fc292111a48972a9 100644 (file)
@@ -196,9 +196,9 @@ class PGconn:
     def exec_params(
         self,
         command: bytes,
-        param_values: List[Optional[bytes]],
-        param_types: Optional[List[Oid]] = None,
-        param_formats: Optional[List[Format]] = None,
+        param_values: Sequence[Optional[bytes]],
+        param_types: Optional[Sequence[Oid]] = None,
+        param_formats: Optional[Sequence[Format]] = None,
         result_format: Format = Format.TEXT,
     ) -> "PGresult":
         args = self._query_params_args(
@@ -212,9 +212,9 @@ class PGconn:
     def send_query_params(
         self,
         command: bytes,
-        param_values: List[Optional[bytes]],
-        param_types: Optional[List[Oid]] = None,
-        param_formats: Optional[List[Format]] = None,
+        param_values: Sequence[Optional[bytes]],
+        param_types: Optional[Sequence[Oid]] = None,
+        param_formats: Optional[Sequence[Format]] = None,
         result_format: Format = Format.TEXT,
     ) -> None:
         args = self._query_params_args(
@@ -228,9 +228,9 @@ class PGconn:
     def _query_params_args(
         self,
         command: bytes,
-        param_values: List[Optional[bytes]],
-        param_types: Optional[List[Oid]] = None,
-        param_formats: Optional[List[Format]] = None,
+        param_values: Sequence[Optional[bytes]],
+        param_types: Optional[Sequence[Oid]] = None,
+        param_formats: Optional[Sequence[Format]] = None,
         result_format: Format = Format.TEXT,
     ) -> Any:
         if not isinstance(command, bytes):
@@ -281,7 +281,7 @@ class PGconn:
         self,
         name: bytes,
         command: bytes,
-        param_types: Optional[List[Oid]] = None,
+        param_types: Optional[Sequence[Oid]] = None,
     ) -> "PGresult":
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
@@ -306,8 +306,8 @@ class PGconn:
     def exec_prepared(
         self,
         name: bytes,
-        param_values: List[bytes],
-        param_formats: Optional[List[int]] = None,
+        param_values: Sequence[bytes],
+        param_formats: Optional[Sequence[int]] = None,
         result_format: int = 0,
     ) -> "PGresult":
         if not isinstance(name, bytes):
index dc1f74777be92511a3e848688de523beb3111d65..ffc3ffef207d7415a06caa7faedacd4caabdb7e1 100644 (file)
@@ -19,6 +19,9 @@ def adapt_int(obj: int) -> Tuple[bytes, int]:
     return _encode(str(obj))[0], type_oid["numeric"]
 
 
-@Typecaster.text(type_oid["numeric"])
+@Typecaster.text(type_oid["int4"])
+@Typecaster.text(type_oid["int8"])
+@Typecaster.text(type_oid["oid"])
+@Typecaster.text(type_oid["numeric"])  # TODO: wrong: return Decimal
 def cast_int(data: bytes) -> int:
     return int(_decode(data)[0])
diff --git a/tests/test_adapt.py b/tests/test_adapt.py
new file mode 100644 (file)
index 0000000..d981cd4
--- /dev/null
@@ -0,0 +1,108 @@
+import pytest
+from psycopg3.adaptation import Transformer, Format, Adapter, Typecaster
+from psycopg3.types.oids import type_oid
+
+
+@pytest.mark.parametrize(
+    "data, format, result, type",
+    [
+        (None, Format.TEXT, None, "text"),
+        (None, Format.BINARY, None, "text"),
+        (1, Format.TEXT, b"1", "numeric"),
+        ("hello", Format.TEXT, b"hello", "text"),
+        ("hello", Format.BINARY, b"hello", "text"),
+    ],
+)
+def test_adapt(data, format, result, type):
+    t = Transformer()
+    rv = t.adapt(data, format)
+    if isinstance(rv, tuple):
+        assert rv[0] == result
+        assert rv[1] == type_oid[type]
+    else:
+        assert rv == result
+
+
+def test_adapt_connection_ctx(conn):
+    Adapter.register(str, lambda s: s.encode("ascii") + b"t", conn)
+    Adapter.register_binary(str, lambda s: s.encode("ascii") + b"b", conn)
+
+    cur = conn.cursor()
+    cur.execute("select %s, %b", ["hello", "world"])
+    assert cur.fetchone() == ("hellot", "worldb")
+
+
+def test_adapt_cursor_ctx(conn):
+    Adapter.register(str, lambda s: s.encode("ascii") + b"t", conn)
+    Adapter.register_binary(str, lambda s: s.encode("ascii") + b"b", conn)
+
+    cur = conn.cursor()
+    Adapter.register(str, lambda s: s.encode("ascii") + b"tc", cur)
+    Adapter.register_binary(str, lambda s: s.encode("ascii") + b"bc", cur)
+
+    cur.execute("select %s, %b", ["hello", "world"])
+    assert cur.fetchone() == ("hellotc", "worldbc")
+
+    cur = conn.cursor()
+    cur.execute("select %s, %b", ["hello", "world"])
+    assert cur.fetchone() == ("hellot", "worldb")
+
+
+@pytest.mark.parametrize(
+    "data, format, type, result",
+    [
+        (None, Format.TEXT, "text", None),
+        (None, Format.BINARY, "text", None),
+        (b"1", Format.TEXT, "int4", 1),
+        (b"hello", Format.TEXT, "text", "hello"),
+        (b"hello", Format.BINARY, "text", "hello"),
+    ],
+)
+def test_cast(data, format, type, result):
+    t = Transformer()
+    rv = t.cast(data, type_oid[type], format)
+    assert rv == result
+
+
+def test_cast_connection_ctx(conn):
+    Typecaster.register(
+        type_oid["text"], lambda b: b.decode("ascii") + "t", conn
+    )
+    Typecaster.register_binary(
+        type_oid["text"], lambda b: b.decode("ascii") + "b", conn
+    )
+
+    r = conn.cursor().execute("select 'hello'::text").fetchone()
+    assert r == ("hellot",)
+    r = conn.cursor(binary=True).execute("select 'hello'::text").fetchone()
+    assert r == ("hellob",)
+
+
+def test_cast_cursor_ctx(conn):
+    Typecaster.register(
+        type_oid["text"], lambda b: b.decode("ascii") + "t", conn
+    )
+    Typecaster.register_binary(
+        type_oid["text"], lambda b: b.decode("ascii") + "b", conn
+    )
+
+    cur = conn.cursor()
+    Typecaster.register(
+        type_oid["text"], lambda b: b.decode("ascii") + "tc", cur
+    )
+    Typecaster.register_binary(
+        type_oid["text"], lambda b: b.decode("ascii") + "bc", cur
+    )
+
+    r = cur.execute("select 'hello'::text").fetchone()
+    assert r == ("hellotc",)
+    cur.binary = True
+    r = cur.execute("select 'hello'::text").fetchone()
+    assert r == ("hellobc",)
+
+    cur = conn.cursor()
+    r = cur.execute("select 'hello'::text").fetchone()
+    assert r == ("hellot",)
+    cur.binary = True
+    r = cur.execute("select 'hello'::text").fetchone()
+    assert r == ("hellob",)