]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use text oid for text/unknown types on PG 9.6
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 24 Dec 2020 03:42:41 +0000 (04:42 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 24 Dec 2020 03:52:07 +0000 (04:52 +0100)
This creates a difference between PG 96 and >= 10 as in the latter less
casts are required. However unknown oid cannot be used to prepare queries
in PG 9.6.

18 files changed:
psycopg3/psycopg3/_queries.py
psycopg3/psycopg3/_transform.py
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/oids.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/date.py
psycopg3/psycopg3/types/json.py
psycopg3/psycopg3/types/network.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/types/range.py
psycopg3/psycopg3/types/singletons.py
psycopg3/psycopg3/types/text.py
psycopg3/psycopg3/types/uuid.py
psycopg3_c/psycopg3_c/adapt.pyx
tests/test_adapt.py
tests/test_prepared.py
tests/test_prepared_async.py

index b7067d4723304965303460c11308ef64f67cda42..cdf4c1883d912036f83fb89ac980d4382e0315dd 100644 (file)
@@ -12,6 +12,7 @@ from functools import lru_cache
 from . import errors as e
 from .pq import Format
 from .sql import Composable
+from .oids import TEXT_OID, INVALID_OID
 from .proto import Query, Params
 
 if TYPE_CHECKING:
@@ -31,6 +32,7 @@ class PostgresQuery:
 
     _parts: List[QueryPart]
     _query = b""
+    _unknown_oid = INVALID_OID
 
     def __init__(self, transformer: "Transformer"):
         self._tx = transformer
@@ -40,6 +42,11 @@ class PostgresQuery:
         self.formats: Optional[List[Format]] = None
 
         self._order: Optional[List[str]] = None
+        if (
+            self._tx.connection
+            and self._tx.connection.pgconn.server_version < 100000
+        ):
+            self._unknown_oid = TEXT_OID
 
     def convert(self, query: Query, vars: Optional[Params]) -> None:
         """
@@ -84,7 +91,7 @@ class PostgresQuery:
                     ts.append(dumper.oid)
                 else:
                     ps.append(None)
-                    ts.append(0)
+                    ts.append(self._unknown_oid)
             self.types = tuple(ts)
         else:
             self.params = None
index 4fb23043d06afccedd7fa78a93d57070dc4a3333..36fe978b221b42426e3a885c0be0678bd3bfa8cc 100644 (file)
@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING
 
 from . import errors as e
 from .pq import Format
-from .oids import builtins, INVALID_OID
+from .oids import INVALID_OID
 from .proto import AdaptContext, DumpersMap
 from .proto import LoadFunc, LoadersMap
 from .cursor import BaseCursor
@@ -19,8 +19,6 @@ if TYPE_CHECKING:
     from .pq.proto import PGresult
     from .adapt import Dumper, Loader
 
-TEXT_OID = builtins["text"].oid
-
 
 class Transformer:
     """
index 55dfbded4ff622705cf751415a53bf15ca3356f8..f9a03765bb95409ff00da40dce1522a67ce16296 100644 (file)
@@ -10,13 +10,11 @@ from typing import Any, cast, Callable, Optional, Type, Union
 from . import pq
 from . import proto
 from .pq import Format as Format
-from .oids import builtins
+from .oids import TEXT_OID
 from .proto import AdaptContext, DumpersMap, DumperType, LoadersMap, LoaderType
 from .cursor import BaseCursor
 from .connection import BaseConnection
 
-TEXT_OID = builtins["text"].oid
-
 
 class Dumper(ABC):
     """
@@ -26,10 +24,24 @@ class Dumper(ABC):
     globals: DumpersMap = {}
     connection: Optional[BaseConnection]
 
+    # A class-wide oid, which will be used by default by instances unless
+    # the subclass overrides it in init.
+    _oid: int = 0
+
     def __init__(self, src: type, context: AdaptContext = None):
         self.src = src
         self.context = context
         self.connection = connection_from_context(context)
+        self.oid = self._oid
+        """The oid to pass to the server, if known."""
+
+        # Postgres 9.6 doesn't deal well with unknown oids
+        if (
+            not self.oid
+            and self.connection
+            and self.connection.pgconn.server_version < 100000
+        ):
+            self.oid = TEXT_OID
 
     @abstractmethod
     def dump(self, obj: Any) -> bytes:
@@ -49,11 +61,6 @@ class Dumper(ABC):
             esc = pq.Escaping()
             return b"'%s'" % esc.escape_string(value)
 
-    @property
-    def oid(self) -> int:
-        """The oid to pass to the server, if known."""
-        return 0
-
     @classmethod
     def register(
         cls,
index ccaef1f886360cd12c85d978bede3289713e5481..95e67a999b15cc6bdb26a6ecb489c3d5b71bc81a 100644 (file)
@@ -6,8 +6,6 @@ Maps of builtin types and names
 
 from typing import Dict, Iterator, Optional, Union
 
-INVALID_OID = 0
-
 
 class TypeInfo:
     def __init__(self, name: str, oid: int, array_oid: int):
@@ -177,3 +175,9 @@ for r in [
     # fmt: on
 ]:
     builtins.add(BuiltinTypeInfo(*r))
+
+
+# A few oids used a bit everywhere
+INVALID_OID = 0
+TEXT_OID = builtins["text"].oid
+TEXT_ARRAY_OID = builtins["text"].array_oid
index 7cfe8cc188f1347f1106a2297f9246f0e5eeeff2..b552675fd9021b960c784d4540a3072ba8970d8f 100644 (file)
@@ -9,23 +9,18 @@ import struct
 from typing import Any, Iterator, List, Optional, Type
 
 from .. import errors as e
-from ..oids import builtins
+from ..oids import builtins, TEXT_OID, TEXT_ARRAY_OID
 from ..adapt import Format, Dumper, Loader, Transformer
 from ..proto import AdaptContext
 
-TEXT_OID = builtins["text"].oid
-TEXT_ARRAY_OID = builtins["text"].array_oid
-
 
 class BaseListDumper(Dumper):
+
+    _oid = TEXT_ARRAY_OID
+
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
         self._tx = Transformer(context)
-        self._array_oid = 0
-
-    @property
-    def oid(self) -> int:
-        return self._array_oid or TEXT_ARRAY_OID
 
     def _get_array_oid(self, base_oid: int) -> int:
         """
@@ -99,7 +94,7 @@ class ListDumper(BaseListDumper):
         dump_list(obj)
 
         if oid:
-            self._array_oid = self._get_array_oid(oid)
+            self.oid = self._get_array_oid(oid)
 
         return b"".join(tokens)
 
@@ -154,7 +149,7 @@ class ListBinaryDumper(BaseListDumper):
         if not oid:
             oid = TEXT_OID
 
-        self._array_oid = self._get_array_oid(oid)
+        self.oid = self._get_array_oid(oid)
 
         data[0] = _struct_head.pack(len(dims), hasnull, oid)
         data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims)
index f48a398ec575af2ec5b64864fe6b3af5e9fd066b..3f26525d0ad57d4d65e4bb5ec0b9458f127bed86 100644 (file)
@@ -12,7 +12,7 @@ from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING
 
 from .. import sql
 from .. import errors as e
-from ..oids import builtins, TypeInfo
+from ..oids import builtins, TypeInfo, TEXT_OID
 from ..adapt import Format, Dumper, Loader, Transformer
 from ..proto import AdaptContext
 from . import array
@@ -21,9 +21,6 @@ if TYPE_CHECKING:
     from ..connection import Connection, AsyncConnection
 
 
-TEXT_OID = builtins["text"].oid
-
-
 class CompositeInfo(TypeInfo):
     """Manage information about a composite type.
 
@@ -184,6 +181,10 @@ class SequenceDumper(Dumper):
 
 @Dumper.text(tuple)
 class TupleDumper(SequenceDumper):
+
+    # Should be this, but it doesn't work
+    # _oid = builtins["record"].oid
+
     def dump(self, obj: Tuple[Any, ...]) -> bytes:
         return self._dump_sequence(obj, b"(", b")", b",")
 
index d7aa36a7820b8fecd6d038f1d5f75ac79800a7af..b47f782310b8083939b06bd40577865c3e1f3b80 100644 (file)
@@ -18,7 +18,7 @@ from ..errors import InterfaceError, DataError
 @Dumper.text(date)
 class DateDumper(Dumper):
 
-    oid = builtins["date"].oid
+    _oid = builtins["date"].oid
 
     def dump(self, obj: date) -> bytes:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
@@ -29,7 +29,7 @@ class DateDumper(Dumper):
 @Dumper.text(time)
 class TimeDumper(Dumper):
 
-    oid = builtins["timetz"].oid
+    _oid = builtins["timetz"].oid
 
     def dump(self, obj: time) -> bytes:
         return str(obj).encode("utf8")
@@ -38,7 +38,7 @@ class TimeDumper(Dumper):
 @Dumper.text(datetime)
 class DateTimeDumper(Dumper):
 
-    oid = builtins["timestamptz"].oid
+    _oid = builtins["timestamptz"].oid
 
     def dump(self, obj: date) -> bytes:
         # NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
@@ -49,7 +49,7 @@ class DateTimeDumper(Dumper):
 @Dumper.text(timedelta)
 class TimeDeltaDumper(Dumper):
 
-    oid = builtins["interval"].oid
+    _oid = builtins["interval"].oid
 
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
index 0e046626c04145010accedddfe5a498648fc33f2..3fe79d0c53218116d5d1d0beaffd916ebf179d52 100644 (file)
@@ -39,12 +39,12 @@ class _JsonDumper(Dumper):
 @Dumper.text(Json)
 @Dumper.binary(Json)
 class JsonDumper(_JsonDumper):
-    oid = builtins["json"].oid
+    _oid = builtins["json"].oid
 
 
 @Dumper.text(Jsonb)
 class JsonbDumper(_JsonDumper):
-    oid = builtins["jsonb"].oid
+    _oid = builtins["jsonb"].oid
 
 
 @Dumper.binary(Jsonb)
index 6fdd4f8b97fd107b60ac63c90f041a7f268433aa..a04e6117a9d28fc79d589d51d2eaee82794e6654 100644 (file)
@@ -29,7 +29,7 @@ ip_network: Callable[[str], Network]
 @Dumper.text("ipaddress.IPv6Interface")
 class InterfaceDumper(Dumper):
 
-    oid = builtins["inet"].oid
+    _oid = builtins["inet"].oid
 
     def dump(self, obj: Interface) -> bytes:
         return str(obj).encode("utf8")
@@ -39,7 +39,7 @@ class InterfaceDumper(Dumper):
 @Dumper.text("ipaddress.IPv6Network")
 class NetworkDumper(Dumper):
 
-    oid = builtins["cidr"].oid
+    _oid = builtins["cidr"].oid
 
     def dump(self, obj: Network) -> bytes:
         return str(obj).encode("utf8")
index 2a70a8e6f2f1914eaa0b52d0e3952f9bd76806d0..573df51975519a80777bad840111ea4c64239937 100644 (file)
@@ -79,7 +79,7 @@ class SpecialValuesDumper(NumberDumper):
 
 @Dumper.text(int)
 class IntDumper(NumberDumper):
-    oid = builtins["int8"].oid
+    _oid = builtins["int8"].oid
 
 
 @Dumper.binary(int)
@@ -90,7 +90,7 @@ class IntBinaryDumper(IntDumper):
 
 @Dumper.text(float)
 class FloatDumper(SpecialValuesDumper):
-    oid = builtins["float8"].oid
+    _oid = builtins["float8"].oid
 
     _special = {
         b"inf": b"'Infinity'::float8",
@@ -101,13 +101,15 @@ class FloatDumper(SpecialValuesDumper):
 
 @Dumper.binary(float)
 class FloatBinaryDumper(NumberDumper):
+    _oid = builtins["float8"].oid
+
     def dump(self, obj: float) -> bytes:
         return _pack_float8(obj)
 
 
 @Dumper.text(Decimal)
 class DecimalDumper(SpecialValuesDumper):
-    oid = builtins["numeric"].oid
+    _oid = builtins["numeric"].oid
 
     _special = {
         b"Infinity": b"'Infinity'::numeric",
@@ -118,22 +120,22 @@ class DecimalDumper(SpecialValuesDumper):
 
 @Dumper.text(Int2)
 class Int2Dumper(NumberDumper):
-    oid = builtins["int2"].oid
+    _oid = builtins["int2"].oid
 
 
 @Dumper.text(Int4)
 class Int4Dumper(NumberDumper):
-    oid = builtins["int4"].oid
+    _oid = builtins["int4"].oid
 
 
 @Dumper.text(Int8)
 class Int8Dumper(NumberDumper):
-    oid = builtins["int8"].oid
+    _oid = builtins["int8"].oid
 
 
 @Dumper.text(Oid)
 class OidDumper(NumberDumper):
-    oid = builtins["oid"].oid
+    _oid = builtins["oid"].oid
 
 
 @Dumper.binary(Int2)
index 3432a8d82e3748c2830a18598d959a3fe01f4a04..7b2d396ecf62a7d5d773aa5c44a6dc8de5f62253 100644 (file)
@@ -289,32 +289,32 @@ class DateTimeTZRange(Range[datetime]):
 
 @Dumper.text(Int4Range)
 class Int4RangeDumper(RangeDumper):
-    oid = builtins["int4range"].oid
+    _oid = builtins["int4range"].oid
 
 
 @Dumper.text(Int8Range)
 class Int8RangeDumper(RangeDumper):
-    oid = builtins["int8range"].oid
+    _oid = builtins["int8range"].oid
 
 
 @Dumper.text(DecimalRange)
 class NumRangeDumper(RangeDumper):
-    oid = builtins["numrange"].oid
+    _oid = builtins["numrange"].oid
 
 
 @Dumper.text(DateRange)
 class DateRangeDumper(RangeDumper):
-    oid = builtins["daterange"].oid
+    _oid = builtins["daterange"].oid
 
 
 @Dumper.text(DateTimeRange)
 class TimestampRangeDumper(RangeDumper):
-    oid = builtins["tsrange"].oid
+    _oid = builtins["tsrange"].oid
 
 
 @Dumper.text(DateTimeTZRange)
 class TimestampTZRangeDumper(RangeDumper):
-    oid = builtins["tstzrange"].oid
+    _oid = builtins["tstzrange"].oid
 
 
 # Loaders for builtin range types
@@ -407,7 +407,7 @@ class RangeInfo(TypeInfo):
 
         # generate and register a customized text dumper
         dumper: Type[Dumper] = type(
-            f"{self.name.title()}Dumper", (RangeDumper,), {"oid": self.oid}
+            f"{self.name.title()}Dumper", (RangeDumper,), {"_oid": self.oid}
         )
         dumper.register(range_class, context=context, format=Format.TEXT)
 
index dd374b47484b903c775ecea700dc852b1ad98603..4d695a92326b36832ee23f8f65b7b50be336e3ae 100644 (file)
@@ -11,7 +11,7 @@ from ..adapt import Dumper, Loader
 @Dumper.text(bool)
 class BoolDumper(Dumper):
 
-    oid = builtins["bool"].oid
+    _oid = builtins["bool"].oid
 
     def dump(self, obj: bool) -> bytes:
         return b"t" if obj else b"f"
@@ -23,7 +23,7 @@ class BoolDumper(Dumper):
 @Dumper.binary(bool)
 class BoolBinaryDumper(Dumper):
 
-    oid = builtins["bool"].oid
+    _oid = builtins["bool"].oid
 
     def dump(self, obj: bool) -> bytes:
         return b"\x01" if obj else b"\x00"
index 8315b367ab7f8016aac80e827b750e2a98bf8d92..dfc3667e5bd40d841d530c2624cea9995ae222de 100644 (file)
@@ -17,21 +17,24 @@ if TYPE_CHECKING:
 
 
 class _StringDumper(Dumper):
+
+    _encoding = "utf-8"
+
     def __init__(self, src: type, context: AdaptContext):
         super().__init__(src, context)
 
-        self.encoding = "utf-8"
-        if self.connection:
-            enc = self.connection.client_encoding
+        conn = self.connection
+        if conn:
+            enc = conn.client_encoding
             if enc != "ascii":
-                self.encoding = enc
+                self._encoding = enc
 
 
 @Dumper.binary(str)
 class StringBinaryDumper(_StringDumper):
     def dump(self, obj: str) -> bytes:
         # the server will raise DataError subclass if the string contains 0x00
-        return obj.encode(self.encoding)
+        return obj.encode(self._encoding)
 
 
 @Dumper.text(str)
@@ -42,7 +45,7 @@ class StringDumper(_StringDumper):
                 "PostgreSQL text fields cannot contain NUL (0x00) bytes"
             )
         else:
-            return obj.encode(self.encoding)
+            return obj.encode(self._encoding)
 
 
 @Loader.text(builtins["text"].oid)
@@ -51,21 +54,19 @@ class StringDumper(_StringDumper):
 @Loader.binary(builtins["varchar"].oid)
 @Loader.text(INVALID_OID)
 class TextLoader(Loader):
+
+    _encoding = "utf-8"
+
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
-
-        if self.connection:
-            enc = self.connection.client_encoding
-            if enc != "ascii":
-                self.encoding = enc
-            else:
-                self.encoding = ""
-        else:
-            self.encoding = "utf-8"
+        conn = self.connection
+        if conn:
+            enc = conn.client_encoding
+            self._encoding = enc if enc != "ascii" else ""
 
     def load(self, data: bytes) -> Union[bytes, str]:
-        if self.encoding:
-            return data.decode(self.encoding)
+        if self._encoding:
+            return data.decode(self._encoding)
         else:
             # return bytes for SQL_ASCII db
             return data
@@ -76,32 +77,36 @@ class TextLoader(Loader):
 @Loader.text(builtins["bpchar"].oid)
 @Loader.binary(builtins["bpchar"].oid)
 class UnknownLoader(Loader):
+
+    _encoding = "utf-8"
+
     def __init__(self, oid: int, context: AdaptContext):
         super().__init__(oid, context)
-        self.encoding = (
-            self.connection.client_encoding if self.connection else "utf-8"
-        )
+        conn = self.connection
+        if conn:
+            self._encoding = conn.client_encoding
 
     def load(self, data: bytes) -> str:
-        return data.decode(self.encoding)
+        return data.decode(self._encoding)
 
 
 @Dumper.text(bytes)
 @Dumper.text(bytearray)
 @Dumper.text(memoryview)
 class BytesDumper(Dumper):
-    oid = builtins["bytea"].oid
+
+    _oid = builtins["bytea"].oid
 
     def __init__(self, src: type, context: AdaptContext = None):
         super().__init__(src, context)
-        self.esc = Escaping(
+        self._esc = Escaping(
             self.connection.pgconn if self.connection else None
         )
 
     def dump(self, obj: bytes) -> memoryview:
         # TODO: mypy doesn't complain, but this function has the wrong signature
         # probably dump return value should be extended to Buffer
-        return self.esc.escape_bytea(obj)
+        return self._esc.escape_bytea(obj)
 
 
 @Dumper.binary(bytes)
@@ -109,7 +114,7 @@ class BytesDumper(Dumper):
 @Dumper.binary(memoryview)
 class BytesBinaryDumper(Dumper):
 
-    oid = builtins["bytea"].oid
+    _oid = builtins["bytea"].oid
 
     def dump(
         self, obj: Union[bytes, bytearray, memoryview]
index eb50a1eb84a67932cb67c3806fc315f6ba7e6781..040e51b927c5aff80441266f54425a2fac767b81 100644 (file)
@@ -20,7 +20,7 @@ UUID: Callable[..., "uuid.UUID"]
 @Dumper.text("uuid.UUID")
 class UUIDDumper(Dumper):
 
-    oid = builtins["uuid"].oid
+    _oid = builtins["uuid"].oid
 
     def dump(self, obj: "uuid.UUID") -> bytes:
         return obj.hex.encode("utf8")
index 3de5919cb2991d8435a7c8051776a1feaf8e7284..e04a70aad347607656606af6c4755fd0cf1bb863 100644 (file)
@@ -19,6 +19,7 @@ from cpython.bytes cimport PyBytes_AsStringAndSize
 from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
 from cpython.bytearray cimport PyByteArray_AS_STRING
 
+from psycopg3_c cimport oids
 from psycopg3_c cimport libpq as impl
 from psycopg3_c.adapt cimport cloader_func, get_context_func
 from psycopg3_c.pq_cython cimport Escaping, _buffer_as_string_and_size
@@ -45,7 +46,16 @@ cdef class CDumper:
         self._pgconn = (
             self.connection.pgconn if self.connection is not None else None
         )
-        # oid is implicitly set to 0, subclasses may override it
+
+        # default oid is implicitly set to 0, subclasses may override it
+        # PG 9.6 goes a bit bonker sending unknown oids, so use text instead
+        # (this does cause side effect, and requres casts more often than >= 10)
+        if (
+            self.oid == 0
+            and self._pgconn is not None
+            and self._pgconn.server_version < 100000
+        ):
+            self.oid = oids.TEXT_OID
 
     def dump(self, obj: Any) -> bytes:
         raise NotImplementedError()
index 745421445fd1cc752a98f506564ada78abd7d505..33dc4bdc61f3e3341c74b17fe9032a7729fddb37 100644 (file)
@@ -1,10 +1,7 @@
 import pytest
 
-import psycopg3
 from psycopg3.adapt import Transformer, Format, Dumper, Loader
-from psycopg3.oids import builtins
-
-TEXT_OID = builtins["text"].oid
+from psycopg3.oids import builtins, TEXT_OID
 
 
 @pytest.mark.parametrize(
@@ -137,15 +134,10 @@ def test_load_cursor_ctx_nested(conn, sql, obj, fmt_out):
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
 def test_none_type_argument(conn, fmt_in):
     cur = conn.cursor()
+    cur.execute("create table none_args (id serial primary key, num integer)")
+    cast = "" if conn.pgconn.server_version >= 100000 else "::int"
     cur.execute(
-        """
-        create table test_none_type_argument (
-            id serial primary key, num integer
-        )
-        """
-    )
-    cur.execute(
-        "insert into test_none_type_argument (num) values (%s) returning id",
+        f"insert into none_args (num) values (%s{cast}) returning id",
         (None,),
     )
     assert cur.fetchone()[0]
@@ -158,20 +150,26 @@ def test_return_untyped(conn, fmt_in):
     # Currently string are passed as unknown oid to libpq. This is because
     # unknown is more easily cast by postgres to different types (see jsonb
     # later). However Postgres < 10 refuses to emit unknown types.
-    if conn.pgconn.server_version > 100000:
+    if conn.pgconn.server_version >= 100000:
         cur.execute("select %s, %s", ["hello", 10])
         assert cur.fetchone() == ("hello", 10)
     else:
-        with pytest.raises(psycopg3.errors.IndeterminateDatatype):
-            cur.execute("select %s, %s", ["hello", 10])
-        conn.rollback()
-        cur.execute("select %s::text, %s", ["hello", 10])
+        # We used to tolerate an error on roundtrip for unknown on pg < 10
+        # however after introducing prepared statements the error happens
+        # in every context, so now we cannot just use unknown oid on PG < 10
+        # with pytest.raises(psycopg3.errors.IndeterminateDatatype):
+        #     cur.execute("select %s, %s", ["hello", 10])
+        # conn.rollback()
+        # cur.execute("select %s::text, %s", ["hello", 10])
+        cur.execute("select %s, %s", ["hello", 10])
         assert cur.fetchone() == ("hello", 10)
 
     # It would be nice if above all postgres version behaved consistently.
     # However this below shouldn't break either.
+    # (unfortunately it does: a cast is required for pre 10 versions)
+    cast = "" if conn.pgconn.server_version >= 100000 else "::jsonb"
     cur.execute("create table testjson(data jsonb)")
-    cur.execute("insert into testjson (data) values (%s)", ["{}"])
+    cur.execute(f"insert into testjson (data) values (%s{cast})", ["{}"])
     assert cur.execute("select data from testjson").fetchone() == ({},)
 
 
index 96f851491debb808ee49424e4646e321d1bce4e8..0ed57b41980435c6f4f27dc8a72b41546798c559 100644 (file)
@@ -181,3 +181,18 @@ def test_different_types(conn):
         prepare=False,
     )
     assert cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)]
+
+
+def test_untyped_json(conn):
+    conn.prepare_threshold = 1
+    conn.execute("create table testjson(data jsonb)")
+    if conn.pgconn.server_version >= 100000:
+        cast, t = "", "jsonb"
+    else:
+        cast, t = "::jsonb", "text"
+
+    for i in range(2):
+        conn.execute(f"insert into testjson (data) values (%s{cast})", ["{}"])
+
+    cur = conn.execute("select parameter_types from pg_prepared_statements")
+    assert cur.fetchall() == [([t],)]
index 3a4c573a6cf655b4aceca3ad0938a1a22fb48bc1..a6663d63c6ce502f92cec3392b79ff53d2e74230 100644 (file)
@@ -193,3 +193,22 @@ async def test_different_types(aconn):
         prepare=False,
     )
     assert await cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)]
+
+
+async def test_untyped_json(aconn):
+    aconn.prepare_threshold = 1
+    await aconn.execute("create table testjson(data jsonb)")
+    if aconn.pgconn.server_version >= 100000:
+        cast, t = "", "jsonb"
+    else:
+        cast, t = "::jsonb", "text"
+
+    for i in range(2):
+        await aconn.execute(
+            f"insert into testjson (data) values (%s{cast})", ["{}"]
+        )
+
+    cur = await aconn.execute(
+        "select parameter_types from pg_prepared_statements"
+    )
+    assert await cur.fetchall() == [([t],)]