from . import proto
from ._enums import Format as Format
from .oids import builtins
-from .proto import AdaptContext
+from .proto import AdaptContext, Buffer as Buffer
if TYPE_CHECKING:
from .connection import BaseConnection
"""The oid to pass to the server, if known."""
@abstractmethod
- def dump(self, obj: Any) -> bytes:
+ def dump(self, obj: Any) -> Buffer:
"""Convert the object *obj* to PostgreSQL representation."""
...
- # TODO: the protocol signature should probably return a Buffer like object
- # (the C implementation may return bytearray)
- def quote(self, obj: Any) -> bytes:
+ def quote(self, obj: Any) -> Buffer:
"""Convert the object *obj* to escaped representation."""
value = self.dump(obj)
)
@abstractmethod
- def load(self, data: bytes) -> Any:
+ def load(self, data: Buffer) -> Any:
"""Convert a PostgreSQL value to a Python object."""
...
Return None if not found.
"""
- # TODO: auto selection
if format == Format.AUTO:
dmaps = [
self._dumpers[pq.Format.BINARY],
from .waiting import Wait, Ready
from .sql import Composable
+# An object implementing the buffer protocol
+Buffer = Union[bytes, bytearray, memoryview]
+
Query = Union[str, bytes, "Composable"]
Params = Union[Sequence[Any], Mapping[str, Any]]
ConnectionType = TypeVar("ConnectionType", bound="BaseConnection")
from .._enums import Format
from .. import errors as e
from ..oids import builtins, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
-from ..adapt import Dumper, Loader, Transformer
+from ..adapt import Buffer, Dumper, Loader, Transformer
from ..proto import AdaptContext
"""
)
- def load(self, data: bytes) -> List[Any]:
+ def load(self, data: Buffer) -> List[Any]:
rv = None
stack: List[Any] = []
cast = self._tx.get_loader(self.base_oid, self.format).load
format = pq.Format.BINARY
- def load(self, data: bytes) -> List[Any]:
+ def load(self, data: Buffer) -> List[Any]:
ndims, hasnull, oid = _struct_head.unpack_from(data[:12])
if not ndims:
return []
from .. import sql
from .. import errors as e
from ..oids import TypeInfo, TEXT_OID
-from ..adapt import Format, Dumper, Loader, Transformer
+from ..adapt import Buffer, Format, Dumper, Loader, Transformer
from ..proto import AdaptContext
from . import array
class RecordLoader(BaseCompositeLoader):
- def load(self, data: bytes) -> Tuple[Any, ...]:
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
if data == b"()":
return ()
super().__init__(oid, context)
self._tx = Transformer(context)
- def load(self, data: bytes) -> Tuple[Any, ...]:
+ def load(self, data: Buffer) -> Tuple[Any, ...]:
if not self._types_set:
self._config_types(data)
self._types_set = True
fields_types: List[int]
_types_set = False
- def load(self, data: bytes) -> Any:
+ def load(self, data: Buffer) -> Any:
if not self._types_set:
self._config_types(data)
self._types_set = True
format = pq.Format.BINARY
factory: Callable[..., Any]
- def load(self, data: bytes) -> Any:
+ def load(self, data: Buffer) -> Any:
r = super().load(data)
return type(self).factory(*r)
from ..pq import Format
from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
from ..proto import AdaptContext
from ..errors import InterfaceError, DataError
super().__init__(oid, context)
self._format = self._format_from_context()
- def load(self, data: bytes) -> date:
+ def load(self, data: Buffer) -> date:
+ if isinstance(data, memoryview):
+ data = bytes(data)
try:
return datetime.strptime(data.decode("utf8"), self._format).date()
except ValueError as e:
_format = "%H:%M:%S.%f"
_format_no_micro = _format.replace(".%f", "")
- def load(self, data: bytes) -> time:
+ def load(self, data: Buffer) -> time:
# check if the data contains microseconds
+ if isinstance(data, memoryview):
+ data = bytes(data)
fmt = self._format if b"." in data else self._format_no_micro
try:
return datetime.strptime(data.decode("utf8"), fmt).time()
super().__init__(oid, context)
- def load(self, data: bytes) -> time:
+ def load(self, data: Buffer) -> time:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
# Hack to convert +HH in +HHMM
if data[-3] in (43, 45):
data += b"00"
return dt.time().replace(tzinfo=dt.tzinfo)
- def _load_py36(self, data: bytes) -> time:
+ def _load_py36(self, data: Buffer) -> time:
+ if isinstance(data, memoryview):
+ data = bytes(data)
# Drop seconds from timezone for Python 3.6
# Also, Python 3.6 doesn't support HHMM, only HH:MM
if data[-6] in (43, 45): # +-HH:MM -> +-HHMM
super().__init__(oid, context)
self._format_no_micro = self._format.replace(".%f", "")
- def load(self, data: bytes) -> datetime:
+ def load(self, data: Buffer) -> datetime:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
# check if the data contains microseconds
fmt = (
self._format if data.find(b".", 19) >= 0 else self._format_no_micro
setattr(self, "load", self._load_notimpl)
return ""
- def load(self, data: bytes) -> datetime:
+ def load(self, data: Buffer) -> datetime:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
# Hack to convert +HH in +HHMM
if data[-3] in (43, 45):
data += b"00"
return super().load(data)
- def _load_py36(self, data: bytes) -> datetime:
+ def _load_py36(self, data: Buffer) -> datetime:
+ if isinstance(data, memoryview):
+ data = bytes(data)
# Drop seconds from timezone for Python 3.6
# Also, Python 3.6 doesn't support HHMM, only HH:MM
tzsep = (43, 45) # + and - bytes
return super().load(data)
- def _load_notimpl(self, data: bytes) -> datetime:
+ def _load_notimpl(self, data: Buffer) -> datetime:
+ if isinstance(data, memoryview):
+ data = bytes(data)
raise NotImplementedError(
"can't parse datetimetz with DateStyle"
f" {self._get_datestyle().decode('ascii')}: {data.decode('ascii')}"
if ints != b"postgres":
setattr(self, "load", self._load_notimpl)
- def load(self, data: bytes) -> timedelta:
+ def load(self, data: Buffer) -> timedelta:
m = self._re_interval.match(data)
if not m:
raise ValueError("can't parse interval: {data.decode('ascii')}")
except OverflowError as e:
raise DataError(str(e))
- def _load_notimpl(self, data: bytes) -> timedelta:
+ def _load_notimpl(self, data: Buffer) -> timedelta:
+ if isinstance(data, memoryview):
+ data = bytes(data)
ints = (
self.connection
and self.connection.pgconn.parameter_status(b"IntervalStyle")
from ..pq import Format
from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
from ..errors import DataError
JsonDumpsFunction = Callable[[Any], str]
format = Format.TEXT
- def load(self, data: bytes) -> Any:
+ def load(self, data: Buffer) -> Any:
+ # Json crashes on memoryview
+ if isinstance(data, memoryview):
+ data = bytes(data)
return json.loads(data)
format = Format.BINARY
- def load(self, data: bytes) -> Any:
+ def load(self, data: Buffer) -> Any:
if data and data[0] != 1:
raise DataError("unknown jsonb binary format: {data[0]}")
- return json.loads(data[1:])
+ data = data[1:]
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return json.loads(data)
from ..pq import Format
from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
from ..proto import AdaptContext
if TYPE_CHECKING:
format = Format.TEXT
- def load(self, data: bytes) -> Union[Address, Interface]:
+ def load(self, data: Buffer) -> Union[Address, Interface]:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
if b"/" in data:
return ip_interface(data.decode("utf8"))
else:
format = Format.TEXT
- def load(self, data: bytes) -> Network:
+ def load(self, data: Buffer) -> Network:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+
return ip_network(data.decode("utf8"))
from ..pq import Format
from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
_PackInt = Callable[[int], bytes]
_PackFloat = Callable[[float], bytes]
format = Format.TEXT
- def load(self, data: bytes) -> int:
+ def load(self, data: Buffer) -> int:
# it supports bytes directly
return int(data)
format = Format.BINARY
- def load(self, data: bytes) -> int:
+ def load(self, data: Buffer) -> int:
return _unpack_int2(data)[0]
format = Format.BINARY
- def load(self, data: bytes) -> int:
+ def load(self, data: Buffer) -> int:
return _unpack_int4(data)[0]
format = Format.BINARY
- def load(self, data: bytes) -> int:
+ def load(self, data: Buffer) -> int:
return _unpack_int8(data)[0]
format = Format.BINARY
- def load(self, data: bytes) -> int:
+ def load(self, data: Buffer) -> int:
return _unpack_uint4(data)[0]
format = Format.TEXT
- def load(self, data: bytes) -> float:
+ def load(self, data: Buffer) -> float:
# it supports bytes directly
return float(data)
format = Format.BINARY
- def load(self, data: bytes) -> float:
+ def load(self, data: Buffer) -> float:
return _unpack_float4(data)[0]
format = Format.BINARY
- def load(self, data: bytes) -> float:
+ def load(self, data: Buffer) -> float:
return _unpack_float8(data)[0]
format = Format.TEXT
- def load(self, data: bytes) -> Decimal:
+ def load(self, data: Buffer) -> Decimal:
+ if isinstance(data, memoryview):
+ data = bytes(data)
return Decimal(data.decode("utf8"))
from .. import errors as e
from ..pq import Format
from ..oids import builtins, TypeInfo
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
from ..proto import AdaptContext
from . import array
subtype_oid: int
cls: Type[Range[T]]
- def load(self, data: bytes) -> Range[T]:
+ def load(self, data: Buffer) -> Range[T]:
if data == b"empty":
return self.cls(empty=True)
cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load
- bounds = (data[:1] + data[-1:]).decode("utf-8")
+ bounds = _int2parens[data[0]] + _int2parens[data[-1]]
min, max = (
cast(token) if token is not None else None
for token in self._parse_record(data[1:-1])
return self.cls(min, max, bounds)
+_int2parens = {ord(c): c for c in "[]()"}
+
+
# Python wrappers for builtin range types
from ..pq import Format
from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
class BoolDumper(Dumper):
format = Format.TEXT
- def load(self, data: bytes) -> bool:
+ def load(self, data: Buffer) -> bool:
return data == b"t"
format = Format.BINARY
- def load(self, data: bytes) -> bool:
+ def load(self, data: Buffer) -> bool:
return data != b"\x00"
from ..pq import Format, Escaping
from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
from ..proto import AdaptContext
from ..errors import DataError
enc = conn.client_encoding
self._encoding = enc if enc != "ascii" else ""
- def load(self, data: bytes) -> Union[bytes, str]:
+ def load(self, data: Buffer) -> Union[bytes, str]:
if self._encoding:
if isinstance(data, memoryview):
return bytes(data).decode(self._encoding)
if not hasattr(self.__class__, "_escaping"):
self.__class__._escaping = Escaping()
- def load(self, data: bytes) -> bytes:
+ def load(self, data: Buffer) -> bytes:
return self._escaping.unescape_bytea(data)
format = Format.BINARY
- def load(self, data: bytes) -> bytes:
+ def load(self, data: Buffer) -> bytes:
return data
from ..pq import Format
from ..oids import builtins
-from ..adapt import Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader
from ..proto import AdaptContext
if TYPE_CHECKING:
imported = True
- def load(self, data: bytes) -> "uuid.UUID":
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
return UUID(data.decode("utf8"))
format = Format.BINARY
- def load(self, data: bytes) -> "uuid.UUID":
+ def load(self, data: Buffer) -> "uuid.UUID":
+ if isinstance(data, memoryview):
+ data = bytes(data)
return UUID(bytes=data)
if attval.len == -1: # NULL_LEN
pyval = None
else:
- # TODO: no copy
- b = attval.value[:attval.len]
+ b = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(
+ self._pgresult,
+ <unsigned char *>attval.value, attval.len))
pyval = PyObject_CallFunctionObjArgs(
(<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
pyval = (<RowLoader>loader).cloader.cload(
attval.value, attval.len)
else:
- # TODO: no copy
- b = attval.value[:attval.len]
+ b = PyMemoryView_FromObject(
+ ViewBuffer._from_buffer(
+ self._pgresult,
+ <unsigned char *>attval.value, attval.len))
pyval = PyObject_CallFunctionObjArgs(
(<RowLoader>loader).loadfunc, <PyObject *>b, NULL)
class MyTextLoader(TextLoader):
def load(self, data):
- return (data * 2).decode("utf-8")
+ return (bytes(data) * 2).decode("utf-8")
MyTextLoader.register("text", conn)
assert conn.execute("select 'hello'::text").fetchone()[0] == "hellohello"
format = pq.Format.TEXT
def load(self, b):
- return b.decode("ascii") + suffix
+ return bytes(b).decode("ascii") + suffix
return TestLoader
import pytest
import psycopg3
from psycopg3 import pq
+from psycopg3 import sql
from psycopg3.oids import builtins
from psycopg3.adapt import Format, Transformer
from psycopg3.types import array
cur.execute("select %s::int[]", (obj,))
assert cur.fetchone()[0] == want
+ stmt = sql.SQL("copy (select {}::int[]) to stdout (format {})").format(
+ obj, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["int4[]"])
+ (got,) = copy.read_row()
+
+ assert got == want
+
def test_array_register(conn):
cur = conn.cursor()
import psycopg3.types
from psycopg3 import pq
+from psycopg3 import sql
from psycopg3.types import Json, Jsonb
from psycopg3.adapt import Format
assert cur.fetchone()[0] == json.loads(val)
+@pytest.mark.parametrize("val", samples)
+@pytest.mark.parametrize("jtype", ["json", "jsonb"])
+@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
+def test_json_load_copy(conn, val, jtype, fmt_out):
+ cur = conn.cursor()
+ stmt = sql.SQL("copy (select {}::{}) to stdout (format {})").format(
+ val, sql.Identifier(jtype), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([jtype])
+ (got,) = copy.read_row()
+
+ assert got == json.loads(val)
+
+
@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
@pytest.mark.parametrize("wrapper", ["Json", "Jsonb"])
def test_json_dump_customise(conn, wrapper, fmt_in):
import pytest
from psycopg3 import pq
+from psycopg3 import sql
from psycopg3.adapt import Format
@pytest.mark.parametrize("val", ["127.0.0.1/32", "::ffff:102:300/128"])
def test_inet_load_address(conn, fmt_out, val):
binary_check(fmt_out)
+ addr = ipaddress.ip_address(val.split("/", 1)[0])
cur = conn.cursor(binary=fmt_out)
+
cur.execute("select %s::inet", (val,))
- addr = ipaddress.ip_address(val.split("/", 1)[0])
assert cur.fetchone()[0] == addr
+
cur.execute("select array[null, %s::inet]", (val,))
assert cur.fetchone()[0] == [None, addr]
+ stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["inet"])
+ (got,) = copy.read_row()
+
+ assert got == addr
+
@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
@pytest.mark.parametrize("val", ["127.0.0.1/24", "::ffff:102:300/127"])
def test_inet_load_network(conn, fmt_out, val):
binary_check(fmt_out)
+ pyval = ipaddress.ip_interface(val)
cur = conn.cursor(binary=fmt_out)
+
cur.execute("select %s::inet", (val,))
- assert cur.fetchone()[0] == ipaddress.ip_interface(val)
+ assert cur.fetchone()[0] == pyval
+
cur.execute("select array[null, %s::inet]", (val,))
- assert cur.fetchone()[0] == [None, ipaddress.ip_interface(val)]
+ assert cur.fetchone()[0] == [None, pyval]
+
+ stmt = sql.SQL("copy (select {}::inet) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["inet"])
+ (got,) = copy.read_row()
+
+ assert got == pyval
@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
def test_cidr_load(conn, fmt_out, val):
binary_check(fmt_out)
+ pyval = ipaddress.ip_network(val)
cur = conn.cursor(binary=fmt_out)
+
cur.execute("select %s::cidr", (val,))
- assert cur.fetchone()[0] == ipaddress.ip_network(val)
+ assert cur.fetchone()[0] == pyval
+
cur.execute("select array[null, %s::cidr]", (val,))
- assert cur.fetchone()[0] == [None, ipaddress.ip_network(val)]
+ assert cur.fetchone()[0] == [None, pyval]
+
+ stmt = sql.SQL("copy (select {}::cidr) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["cidr"])
+ (got,) = copy.read_row()
+
+ assert got == pyval
def binary_check(fmt):
).fetchone()
assert res == eur
+ stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ (res,) = copy.read_row()
+
+ assert res == eur
+
@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
def test_load_badenc(conn, typename, fmt_out):
+ conn.autocommit = True
cur = conn.cursor(binary=fmt_out)
conn.client_encoding = "latin1"
- with pytest.raises(psycopg3.DatabaseError):
+ with pytest.raises(psycopg3.DataError):
cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
+ stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ with pytest.raises(psycopg3.DataError):
+ copy.read_row()
+
@pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
cur = conn.cursor(binary=fmt_out)
conn.client_encoding = "ascii"
- (res,) = cur.execute(
- f"select chr(%s::int)::{typename}", (ord(eur),)
- ).fetchone()
+ cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
+ assert cur.fetchone()[0] == eur.encode("utf8")
+
+ stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+ ord(eur), sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types([typename])
+ (res,) = copy.read_row()
+
assert res == eur.encode("utf8")
import pytest
from psycopg3 import pq
+from psycopg3 import sql
from psycopg3.adapt import Format
cur.execute("select %s::uuid", (val,))
assert cur.fetchone()[0] == UUID(val)
+ stmt = sql.SQL("copy (select {}::uuid) to stdout (format {})").format(
+ val, sql.SQL(fmt_out.name)
+ )
+ with cur.copy(stmt) as copy:
+ copy.set_types(["uuid"])
+ (res,) = copy.read_row()
+
+ assert res == UUID(val)
+
@pytest.mark.subprocess
def test_lazy_load(dsn):