from . import errors as e
from .pq import Format
from .sql import Composable
+from .oids import TEXT_OID, INVALID_OID
from .proto import Query, Params
if TYPE_CHECKING:
_parts: List[QueryPart]
_query = b""
+ _unknown_oid = INVALID_OID
def __init__(self, transformer: "Transformer"):
self._tx = transformer
self.formats: Optional[List[Format]] = None
self._order: Optional[List[str]] = None
+ if (
+ self._tx.connection
+ and self._tx.connection.pgconn.server_version < 100000
+ ):
+ self._unknown_oid = TEXT_OID
def convert(self, query: Query, vars: Optional[Params]) -> None:
"""
ts.append(dumper.oid)
else:
ps.append(None)
- ts.append(0)
+ ts.append(self._unknown_oid)
self.types = tuple(ts)
else:
self.params = None
from . import errors as e
from .pq import Format
-from .oids import builtins, INVALID_OID
+from .oids import INVALID_OID
from .proto import AdaptContext, DumpersMap
from .proto import LoadFunc, LoadersMap
from .cursor import BaseCursor
from .pq.proto import PGresult
from .adapt import Dumper, Loader
-TEXT_OID = builtins["text"].oid
-
class Transformer:
"""
from . import pq
from . import proto
from .pq import Format as Format
-from .oids import builtins
+from .oids import TEXT_OID
from .proto import AdaptContext, DumpersMap, DumperType, LoadersMap, LoaderType
from .cursor import BaseCursor
from .connection import BaseConnection
-TEXT_OID = builtins["text"].oid
-
class Dumper(ABC):
"""
globals: DumpersMap = {}
connection: Optional[BaseConnection]
+ # A class-wide oid, which will be used by default by instances unless
+ # the subclass overrides it in init.
+ _oid: int = 0
+
def __init__(self, src: type, context: AdaptContext = None):
self.src = src
self.context = context
self.connection = connection_from_context(context)
+ self.oid = self._oid
+ """The oid to pass to the server, if known."""
+
+ # Postgres 9.6 doesn't deal well with unknown oids
+ if (
+ not self.oid
+ and self.connection
+ and self.connection.pgconn.server_version < 100000
+ ):
+ self.oid = TEXT_OID
@abstractmethod
def dump(self, obj: Any) -> bytes:
esc = pq.Escaping()
return b"'%s'" % esc.escape_string(value)
- @property
- def oid(self) -> int:
- """The oid to pass to the server, if known."""
- return 0
-
@classmethod
def register(
cls,
from typing import Dict, Iterator, Optional, Union
-INVALID_OID = 0
-
class TypeInfo:
def __init__(self, name: str, oid: int, array_oid: int):
# fmt: on
]:
builtins.add(BuiltinTypeInfo(*r))
+
+
+# A few oids used a bit everywhere
+INVALID_OID = 0
+TEXT_OID = builtins["text"].oid
+TEXT_ARRAY_OID = builtins["text"].array_oid
from typing import Any, Iterator, List, Optional, Type
from .. import errors as e
-from ..oids import builtins
+from ..oids import builtins, TEXT_OID, TEXT_ARRAY_OID
from ..adapt import Format, Dumper, Loader, Transformer
from ..proto import AdaptContext
-TEXT_OID = builtins["text"].oid
-TEXT_ARRAY_OID = builtins["text"].array_oid
-
class BaseListDumper(Dumper):
+
+ _oid = TEXT_ARRAY_OID
+
def __init__(self, src: type, context: AdaptContext = None):
super().__init__(src, context)
self._tx = Transformer(context)
- self._array_oid = 0
-
- @property
- def oid(self) -> int:
- return self._array_oid or TEXT_ARRAY_OID
def _get_array_oid(self, base_oid: int) -> int:
"""
dump_list(obj)
if oid:
- self._array_oid = self._get_array_oid(oid)
+ self.oid = self._get_array_oid(oid)
return b"".join(tokens)
if not oid:
oid = TEXT_OID
- self._array_oid = self._get_array_oid(oid)
+ self.oid = self._get_array_oid(oid)
data[0] = _struct_head.pack(len(dims), hasnull, oid)
data[1] = b"".join(_struct_dim.pack(dim, 1) for dim in dims)
from .. import sql
from .. import errors as e
-from ..oids import builtins, TypeInfo
+from ..oids import builtins, TypeInfo, TEXT_OID
from ..adapt import Format, Dumper, Loader, Transformer
from ..proto import AdaptContext
from . import array
from ..connection import Connection, AsyncConnection
-TEXT_OID = builtins["text"].oid
-
-
class CompositeInfo(TypeInfo):
"""Manage information about a composite type.
@Dumper.text(tuple)
class TupleDumper(SequenceDumper):
+
+ # Should be this, but it doesn't work
+ # _oid = builtins["record"].oid
+
def dump(self, obj: Tuple[Any, ...]) -> bytes:
return self._dump_sequence(obj, b"(", b")", b",")
@Dumper.text(date)
class DateDumper(Dumper):
- oid = builtins["date"].oid
+ _oid = builtins["date"].oid
def dump(self, obj: date) -> bytes:
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
@Dumper.text(time)
class TimeDumper(Dumper):
- oid = builtins["timetz"].oid
+ _oid = builtins["timetz"].oid
def dump(self, obj: time) -> bytes:
return str(obj).encode("utf8")
@Dumper.text(datetime)
class DateTimeDumper(Dumper):
- oid = builtins["timestamptz"].oid
+ _oid = builtins["timestamptz"].oid
def dump(self, obj: date) -> bytes:
# NOTE: whatever the PostgreSQL DateStyle input format (DMY, MDY, YMD)
@Dumper.text(timedelta)
class TimeDeltaDumper(Dumper):
- oid = builtins["interval"].oid
+ _oid = builtins["interval"].oid
def __init__(self, src: type, context: AdaptContext = None):
super().__init__(src, context)
@Dumper.text(Json)
@Dumper.binary(Json)
class JsonDumper(_JsonDumper):
- oid = builtins["json"].oid
+ _oid = builtins["json"].oid
@Dumper.text(Jsonb)
class JsonbDumper(_JsonDumper):
- oid = builtins["jsonb"].oid
+ _oid = builtins["jsonb"].oid
@Dumper.binary(Jsonb)
@Dumper.text("ipaddress.IPv6Interface")
class InterfaceDumper(Dumper):
- oid = builtins["inet"].oid
+ _oid = builtins["inet"].oid
def dump(self, obj: Interface) -> bytes:
return str(obj).encode("utf8")
@Dumper.text("ipaddress.IPv6Network")
class NetworkDumper(Dumper):
- oid = builtins["cidr"].oid
+ _oid = builtins["cidr"].oid
def dump(self, obj: Network) -> bytes:
return str(obj).encode("utf8")
@Dumper.text(int)
class IntDumper(NumberDumper):
- oid = builtins["int8"].oid
+ _oid = builtins["int8"].oid
@Dumper.binary(int)
@Dumper.text(float)
class FloatDumper(SpecialValuesDumper):
- oid = builtins["float8"].oid
+ _oid = builtins["float8"].oid
_special = {
b"inf": b"'Infinity'::float8",
@Dumper.binary(float)
class FloatBinaryDumper(NumberDumper):
+ _oid = builtins["float8"].oid
+
def dump(self, obj: float) -> bytes:
return _pack_float8(obj)
@Dumper.text(Decimal)
class DecimalDumper(SpecialValuesDumper):
- oid = builtins["numeric"].oid
+ _oid = builtins["numeric"].oid
_special = {
b"Infinity": b"'Infinity'::numeric",
@Dumper.text(Int2)
class Int2Dumper(NumberDumper):
- oid = builtins["int2"].oid
+ _oid = builtins["int2"].oid
@Dumper.text(Int4)
class Int4Dumper(NumberDumper):
- oid = builtins["int4"].oid
+ _oid = builtins["int4"].oid
@Dumper.text(Int8)
class Int8Dumper(NumberDumper):
- oid = builtins["int8"].oid
+ _oid = builtins["int8"].oid
@Dumper.text(Oid)
class OidDumper(NumberDumper):
- oid = builtins["oid"].oid
+ _oid = builtins["oid"].oid
@Dumper.binary(Int2)
@Dumper.text(Int4Range)
class Int4RangeDumper(RangeDumper):
- oid = builtins["int4range"].oid
+ _oid = builtins["int4range"].oid
@Dumper.text(Int8Range)
class Int8RangeDumper(RangeDumper):
- oid = builtins["int8range"].oid
+ _oid = builtins["int8range"].oid
@Dumper.text(DecimalRange)
class NumRangeDumper(RangeDumper):
- oid = builtins["numrange"].oid
+ _oid = builtins["numrange"].oid
@Dumper.text(DateRange)
class DateRangeDumper(RangeDumper):
- oid = builtins["daterange"].oid
+ _oid = builtins["daterange"].oid
@Dumper.text(DateTimeRange)
class TimestampRangeDumper(RangeDumper):
- oid = builtins["tsrange"].oid
+ _oid = builtins["tsrange"].oid
@Dumper.text(DateTimeTZRange)
class TimestampTZRangeDumper(RangeDumper):
- oid = builtins["tstzrange"].oid
+ _oid = builtins["tstzrange"].oid
# Loaders for builtin range types
# generate and register a customized text dumper
dumper: Type[Dumper] = type(
- f"{self.name.title()}Dumper", (RangeDumper,), {"oid": self.oid}
+ f"{self.name.title()}Dumper", (RangeDumper,), {"_oid": self.oid}
)
dumper.register(range_class, context=context, format=Format.TEXT)
@Dumper.text(bool)
class BoolDumper(Dumper):
- oid = builtins["bool"].oid
+ _oid = builtins["bool"].oid
def dump(self, obj: bool) -> bytes:
return b"t" if obj else b"f"
@Dumper.binary(bool)
class BoolBinaryDumper(Dumper):
- oid = builtins["bool"].oid
+ _oid = builtins["bool"].oid
def dump(self, obj: bool) -> bytes:
return b"\x01" if obj else b"\x00"
class _StringDumper(Dumper):
+
+ _encoding = "utf-8"
+
def __init__(self, src: type, context: AdaptContext):
super().__init__(src, context)
- self.encoding = "utf-8"
- if self.connection:
- enc = self.connection.client_encoding
+ conn = self.connection
+ if conn:
+ enc = conn.client_encoding
if enc != "ascii":
- self.encoding = enc
+ self._encoding = enc
@Dumper.binary(str)
class StringBinaryDumper(_StringDumper):
def dump(self, obj: str) -> bytes:
# the server will raise DataError subclass if the string contains 0x00
- return obj.encode(self.encoding)
+ return obj.encode(self._encoding)
@Dumper.text(str)
"PostgreSQL text fields cannot contain NUL (0x00) bytes"
)
else:
- return obj.encode(self.encoding)
+ return obj.encode(self._encoding)
@Loader.text(builtins["text"].oid)
@Loader.binary(builtins["varchar"].oid)
@Loader.text(INVALID_OID)
class TextLoader(Loader):
+
+ _encoding = "utf-8"
+
def __init__(self, oid: int, context: AdaptContext):
super().__init__(oid, context)
-
- if self.connection:
- enc = self.connection.client_encoding
- if enc != "ascii":
- self.encoding = enc
- else:
- self.encoding = ""
- else:
- self.encoding = "utf-8"
+ conn = self.connection
+ if conn:
+ enc = conn.client_encoding
+ self._encoding = enc if enc != "ascii" else ""
def load(self, data: bytes) -> Union[bytes, str]:
- if self.encoding:
- return data.decode(self.encoding)
+ if self._encoding:
+ return data.decode(self._encoding)
else:
# return bytes for SQL_ASCII db
return data
@Loader.text(builtins["bpchar"].oid)
@Loader.binary(builtins["bpchar"].oid)
class UnknownLoader(Loader):
+
+ _encoding = "utf-8"
+
def __init__(self, oid: int, context: AdaptContext):
super().__init__(oid, context)
- self.encoding = (
- self.connection.client_encoding if self.connection else "utf-8"
- )
+ conn = self.connection
+ if conn:
+ self._encoding = conn.client_encoding
def load(self, data: bytes) -> str:
- return data.decode(self.encoding)
+ return data.decode(self._encoding)
@Dumper.text(bytes)
@Dumper.text(bytearray)
@Dumper.text(memoryview)
class BytesDumper(Dumper):
- oid = builtins["bytea"].oid
+
+ _oid = builtins["bytea"].oid
def __init__(self, src: type, context: AdaptContext = None):
super().__init__(src, context)
- self.esc = Escaping(
+ self._esc = Escaping(
self.connection.pgconn if self.connection else None
)
def dump(self, obj: bytes) -> memoryview:
# TODO: mypy doesn't complain, but this function has the wrong signature
# probably dump return value should be extended to Buffer
- return self.esc.escape_bytea(obj)
+ return self._esc.escape_bytea(obj)
@Dumper.binary(bytes)
@Dumper.binary(memoryview)
class BytesBinaryDumper(Dumper):
- oid = builtins["bytea"].oid
+ _oid = builtins["bytea"].oid
def dump(
self, obj: Union[bytes, bytearray, memoryview]
@Dumper.text("uuid.UUID")
class UUIDDumper(Dumper):
- oid = builtins["uuid"].oid
+ _oid = builtins["uuid"].oid
def dump(self, obj: "uuid.UUID") -> bytes:
return obj.hex.encode("utf8")
from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
from cpython.bytearray cimport PyByteArray_AS_STRING
+from psycopg3_c cimport oids
from psycopg3_c cimport libpq as impl
from psycopg3_c.adapt cimport cloader_func, get_context_func
from psycopg3_c.pq_cython cimport Escaping, _buffer_as_string_and_size
self._pgconn = (
self.connection.pgconn if self.connection is not None else None
)
- # oid is implicitly set to 0, subclasses may override it
+
+ # default oid is implicitly set to 0, subclasses may override it
+ # PG 9.6 goes a bit bonker sending unknown oids, so use text instead
+ # (this does cause side effect, and requres casts more often than >= 10)
+ if (
+ self.oid == 0
+ and self._pgconn is not None
+ and self._pgconn.server_version < 100000
+ ):
+ self.oid = oids.TEXT_OID
def dump(self, obj: Any) -> bytes:
raise NotImplementedError()
import pytest
-import psycopg3
from psycopg3.adapt import Transformer, Format, Dumper, Loader
-from psycopg3.oids import builtins
-
-TEXT_OID = builtins["text"].oid
+from psycopg3.oids import builtins, TEXT_OID
@pytest.mark.parametrize(
@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
def test_none_type_argument(conn, fmt_in):
cur = conn.cursor()
+ cur.execute("create table none_args (id serial primary key, num integer)")
+ cast = "" if conn.pgconn.server_version >= 100000 else "::int"
cur.execute(
- """
- create table test_none_type_argument (
- id serial primary key, num integer
- )
- """
- )
- cur.execute(
- "insert into test_none_type_argument (num) values (%s) returning id",
+ f"insert into none_args (num) values (%s{cast}) returning id",
(None,),
)
assert cur.fetchone()[0]
# Currently string are passed as unknown oid to libpq. This is because
# unknown is more easily cast by postgres to different types (see jsonb
# later). However Postgres < 10 refuses to emit unknown types.
- if conn.pgconn.server_version > 100000:
+ if conn.pgconn.server_version >= 100000:
cur.execute("select %s, %s", ["hello", 10])
assert cur.fetchone() == ("hello", 10)
else:
- with pytest.raises(psycopg3.errors.IndeterminateDatatype):
- cur.execute("select %s, %s", ["hello", 10])
- conn.rollback()
- cur.execute("select %s::text, %s", ["hello", 10])
+ # We used to tolerate an error on roundtrip for unknown on pg < 10
+ # however after introducing prepared statements the error happens
+ # in every context, so now we cannot just use unknown oid on PG < 10
+ # with pytest.raises(psycopg3.errors.IndeterminateDatatype):
+ # cur.execute("select %s, %s", ["hello", 10])
+ # conn.rollback()
+ # cur.execute("select %s::text, %s", ["hello", 10])
+ cur.execute("select %s, %s", ["hello", 10])
assert cur.fetchone() == ("hello", 10)
# It would be nice if above all postgres version behaved consistently.
# However this below shouldn't break either.
+ # (unfortunately it does: a cast is required for pre 10 versions)
+ cast = "" if conn.pgconn.server_version >= 100000 else "::jsonb"
cur.execute("create table testjson(data jsonb)")
- cur.execute("insert into testjson (data) values (%s)", ["{}"])
+ cur.execute(f"insert into testjson (data) values (%s{cast})", ["{}"])
assert cur.execute("select data from testjson").fetchone() == ({},)
prepare=False,
)
assert cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)]
+
+
+def test_untyped_json(conn):
+ conn.prepare_threshold = 1
+ conn.execute("create table testjson(data jsonb)")
+ if conn.pgconn.server_version >= 100000:
+ cast, t = "", "jsonb"
+ else:
+ cast, t = "::jsonb", "text"
+
+ for i in range(2):
+ conn.execute(f"insert into testjson (data) values (%s{cast})", ["{}"])
+
+ cur = conn.execute("select parameter_types from pg_prepared_statements")
+ assert cur.fetchall() == [([t],)]
prepare=False,
)
assert await cur.fetchall() == [(["text"],), (["date"],), (["bigint"],)]
+
+
+async def test_untyped_json(aconn):
+ aconn.prepare_threshold = 1
+ await aconn.execute("create table testjson(data jsonb)")
+ if aconn.pgconn.server_version >= 100000:
+ cast, t = "", "jsonb"
+ else:
+ cast, t = "::jsonb", "text"
+
+ for i in range(2):
+ await aconn.execute(
+ f"insert into testjson (data) values (%s{cast})", ["{}"]
+ )
+
+ cur = await aconn.execute(
+ "select parameter_types from pg_prepared_statements"
+ )
+ assert await cur.fetchall() == [([t],)]