The types and formats can be read as object attributes.
from . import pq
from . import errors as e
from .sql import Composable
-from .abc import Query, Params
+from .abc import Buffer, Query, Params
from ._enums import PyFormat
if TYPE_CHECKING:
def __init__(self, transformer: "Transformer"):
self._tx = transformer
- self.params: Optional[List[Optional[bytes]]] = None
+ self.params: Optional[Sequence[Optional[Buffer]]] = None
# these are tuples so they can be used as keys e.g. in prepared stmts
self.types: Tuple[int, ...] = ()
self._parts, vars, self._order
)
assert self._want_formats is not None
- self.params, self.types, self.formats = self._tx.dump_sequence(
- params, self._want_formats
- )
+ self.params = self._tx.dump_sequence(params, self._want_formats)
+ self.types = self._tx.types or ()
+ self.formats = self._tx.formats
else:
self.params = None
self.types = ()
from . import pq
from . import postgres
from . import errors as e
-from .abc import LoadFunc, AdaptContext, PyFormat, DumperKey
+from .abc import Buffer, LoadFunc, AdaptContext, PyFormat, DumperKey
from .rows import Row, RowMaker
from .postgres import INVALID_OID
_adapters: "AdaptersMap"
_pgresult: Optional["PGresult"] = None
+ types: Tuple[int, ...]
+ formats: List[pq.Format]
+
def __init__(self, context: Optional[AdaptContext] = None):
+ self.types = ()
+ self.formats = []
# WARNING: don't store context, or you'll create a loop with the Cursor
if context:
) -> None:
self._pgresult = result
- self._ntuples: int
- self._nfields: int
if not result:
self._nfields = self._ntuples = 0
if set_loaders:
self._row_loaders = []
return
- nf = self._nfields = result.nfields
self._ntuples = result.ntuples
+ nf = self._nfields = result.nfields
+
+ if not set_loaders:
+ return
+
+ if not nf:
+ self._row_loaders = []
+ return
- if set_loaders:
- rc = self._row_loaders = []
- for i in range(nf):
- oid = result.ftype(i)
- fmt = result.fformat(i) if format is None else format
- rc.append(self.get_loader(oid, fmt).load) # type: ignore
+ fmt: pq.Format
+ fmt = result.fformat(0) if format is None else format # type: ignore
+ self._row_loaders = [
+ self.get_loader(result.ftype(i), fmt).load for i in range(nf)
+ ]
def set_dumper_types(
self, types: Sequence[int], format: pq.Format
) -> None:
- dumpers: List[Optional["Dumper"]] = []
- for i in range(len(types)):
- dumpers.append(self.get_dumper_by_oid(types[i], format))
-
- self._row_dumpers = dumpers
+ self._row_dumpers = [
+ self.get_dumper_by_oid(oid, format) for oid in types
+ ]
+ self.types = tuple(types)
+ self.formats = [format] * len(types)
def set_loader_types(
self, types: Sequence[int], format: pq.Format
) -> None:
- loaders: List[LoadFunc] = []
- for i in range(len(types)):
- loaders.append(self.get_loader(types[i], format).load)
-
- self._row_loaders = loaders
+ self._row_loaders = [
+ self.get_loader(oid, format).load for oid in types
+ ]
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
- ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]:
- ps: List[Optional[bytes]] = [None] * len(params)
- ts = [INVALID_OID] * len(params)
- fs: List[pq.Format] = [pq.Format.TEXT] * len(params)
+ ) -> Sequence[Optional[Buffer]]:
+ out: List[Optional[Buffer]] = [None] * len(params)
- dumpers = self._row_dumpers
+ change_state = False
+
+ dumpers: List[Optional[Dumper]] = self._row_dumpers
+ types: Optional[List[int]] = None
+ pqformats: Optional[List[pq.Format]] = None
+
+ # If we have dumpers, it means dump_sequnece or set_dumper_types were
+ # called already, in which case self.types and self.formats are set to
+ # sequences of the right size. We may change their contents if
+ # now we find a dumper we didn't have before, for instance because in
+ # an executemany the first records has a null, the second has a value.
if not dumpers:
- dumpers = self._row_dumpers = [None] * len(params)
+ change_state = True
+ dumpers = [None] * len(params)
+ types = [INVALID_OID] * len(params)
+ pqformats = [pq.Format.TEXT] * len(params)
for i in range(len(params)):
param = params[i]
dumper = dumpers[i]
if not dumper:
dumper = dumpers[i] = self.get_dumper(param, formats[i])
- ps[i] = dumper.dump(param)
- ts[i] = dumper.oid
- fs[i] = dumper.format
-
- return ps, tuple(ts), fs
+ change_state = True
+ if not types:
+ types = list(self.types)
+ types[i] = dumper.oid
+ if not pqformats:
+ pqformats = list(self.formats)
+ pqformats[i] = dumper.format
+
+ out[i] = dumper.dump(param)
+
+ if change_state:
+ self._row_dumpers = dumpers
+ assert types is not None
+ self.types = tuple(types)
+ assert pqformats is not None
+ self.formats = pqformats
+
+ return out
def get_dumper(self, obj: Any, format: PyFormat) -> "Dumper":
"""
class Transformer(Protocol):
+
+ types: Tuple[int, ...]
+ formats: Sequence[pq.Format]
+
def __init__(self, context: Optional[AdaptContext] = None):
...
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
- ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]:
+ ) -> Sequence[Optional[Buffer]]:
...
def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
out = bytearray()
out += _pack_int2(len(row))
- adapted, _, _ = tx.dump_sequence(row, [PyFormat.BINARY] * len(row))
+ adapted = tx.dump_sequence(row, [PyFormat.BINARY] * len(row))
for b in adapted:
if b is not None:
out += _pack_int4(len(b))
super().__init__(cls, context)
nfields = len(self.info.field_types)
self._tx.set_dumper_types(self.info.field_types, self.format)
- self._formats = [PyFormat.from_pq(self.format)] * nfields
+ self._formats = (PyFormat.from_pq(self.format),) * nfields
def dump(self, obj: Tuple[Any, ...]) -> bytearray:
out = bytearray(pack_len(len(obj)))
- adapted, _, _ = self._tx.dump_sequence(obj, self._formats)
+ adapted = self._tx.dump_sequence(obj, self._formats)
for i in range(len(obj)):
b = adapted[i]
oid = self.info.field_types[i]
from psycopg.connection import BaseConnection
class Transformer(abc.AdaptContext):
+ types: Tuple[int, ...]
+ formats: Sequence[pq.Format]
def __init__(self, context: Optional[abc.AdaptContext] = None): ...
@property
def connection(self) -> Optional[BaseConnection[Any]]: ...
) -> None: ...
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
- ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ...
+ ) -> Sequence[Optional[abc.Buffer]]: ...
def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
def get_dumper_by_oid(self, oid: int, format: pq.Format) -> abc.Dumper: ...
def load_rows(
cur = conn.cursor()
with cur.copy("copy (select 1) to stdout") as copy:
assert cur._query.query == b"copy (select 1) to stdout"
- assert cur._query.params is None
+ assert not cur._query.params
list(copy)
cur = aconn.cursor()
async with cur.copy("copy (select 1) to stdout") as copy:
assert cur._query.query == b"copy (select 1) to stdout"
- assert cur._query.params is None
+ assert not cur._query.params
async for record in copy:
pass
cur.execute("select 1")
assert cur._query.query == b"select 1"
- assert cur._query.params is None
+ assert not cur._query.params
with pytest.raises(psycopg.DataError):
cur.execute("select %t::int", ["wat"])
await cur.execute("select 1")
assert cur._query.query == b"select 1"
- assert cur._query.params is None
+ assert not cur._query.params
with pytest.raises(psycopg.DataError):
await cur.execute("select %t::int", ["wat"])
@pytest.mark.parametrize(
"query, params, want, wformats, wparams",
[
- (b"", None, b"", None, None),
+ (b"", None, b"", (), ()),
(b"", [], b"", [], []),
(b"%%", [], b"%", [], []),
(b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]),