func_id=self.varlist_id.upper(),
)
- _url_pattern = (
- "https://www.postgresql.org/docs/{version}/{section}.html#{func_id}"
- )
+ _url_pattern = "https://www.postgresql.org/docs/{version}/{section}.html#{func_id}"
class LibpqReader:
del ann["return"]
-def process_signature(
- app, what, name, obj, options, signature, return_annotation
-):
+def process_signature(app, what, name, obj, options, signature, return_annotation):
pass
self._dumpers_by_oid[dumper.format][dumper.oid] = dumper
- def register_loader(
- self, oid: Union[int, str], loader: Type["Loader"]
- ) -> None:
+ def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
"""
Configure the context to use *loader* to convert data of oid *oid*.
if isinstance(oid, str):
oid = self.types[oid].oid
if not isinstance(oid, int):
- raise TypeError(
- f"loaders should be registered on oid, got {oid} instead"
- )
+ raise TypeError(f"loaders should be registered on oid, got {oid} instead")
if _psycopg:
loader = self._get_optimised(loader)
)
raise e.ProgrammingError(msg)
- def get_loader(
- self, oid: int, format: pq.Format
- ) -> Optional[Type["Loader"]]:
+ def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
"""
Return the loader class for the given oid and format.
elif pq.__impl__ == "python":
_psycopg = None # type: ignore
else:
- raise ImportError(
- f"can't find _psycopg optimised module in {pq.__impl__!r}"
- )
+ raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")
the async paths.
"""
- re_srv_rr = re.compile(
- r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)"
- )
+ re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
"""Update the parameters host and port after SRV lookup."""
# The query is not to be prepared yet
return Prepare.NO, b""
- def _should_discard(
- self, prep: Prepare, results: Sequence["PGresult"]
- ) -> bool:
+ def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
"""Check if we need to discard our entire state: it should happen on
rollback or on dropping objects, because the same object may get
recreated and postgres would fail internal lookups.
if result.status != ExecStatus.COMMAND_OK:
continue
cmdstat = result.command_status
- if cmdstat and (
- cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"
- ):
+ if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
return self.clear()
return False
This method updates `params` and `types`.
"""
if vars is not None:
- params = _validate_and_reorder_params(
- self._parts, vars, self._order
- )
+ params = _validate_and_reorder_params(self._parts, vars, self._order)
assert self._want_formats is not None
self.params = self._tx.dump_sequence(params, self._want_formats)
self.types = self._tx.types or ()
else:
if seen[part.item][1] != part.format:
raise e.ProgrammingError(
- f"placeholder '{part.item}' cannot have"
- f" different formats"
+ f"placeholder '{part.item}' cannot have different formats"
)
chunks.append(seen[part.item][0])
f" {len(vars)} parameters were passed"
)
if vars and not isinstance(parts[0].item, int):
- raise TypeError(
- "named placeholders require a mapping of parameters"
- )
+ raise TypeError("named placeholders require a mapping of parameters")
return vars # type: ignore[return-value]
else:
"positional placeholders (%s) require a sequence of parameters"
)
try:
- return [
- vars[item] for item in order or () # type: ignore[call-overload]
- ]
+ return [vars[item] for item in order or ()] # type: ignore[call-overload]
except KeyError:
raise e.ProgrammingError(
f"query parameter missing:"
if bqual is None:
raise TypeError("if format_id is specified, bqual must be too")
if not 0 <= format_id < 0x80000000:
- raise ValueError(
- "format_id must be a non-negative 32-bit integer"
- )
+ raise ValueError("format_id must be a non-negative 32-bit integer")
if len(bqual) > 64:
raise ValueError("bqual must be not longer than 64 chars")
if len(gtrid) > 64:
self.get_loader(result.ftype(i), fmt).load for i in range(nf)
]
- def set_dumper_types(
- self, types: Sequence[int], format: pq.Format
- ) -> None:
- self._row_dumpers = [
- self.get_dumper_by_oid(oid, format) for oid in types
- ]
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
self.types = tuple(types)
self.formats = [format] * len(types)
- def set_loader_types(
- self, types: Sequence[int], format: pq.Format
- ) -> None:
- self._row_loaders = [
- self.get_loader(oid, format).load for oid in types
- ]
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+ self._row_loaders = [self.get_loader(oid, format).load for oid in types]
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
return dumper
- def load_rows(
- self, row0: int, row1: int, make_row: RowMaker[Row]
- ) -> List[Row]:
+ def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
res = self._pgresult
if not res:
raise e.InterfaceError("result not set")
return make_row(record)
- def load_sequence(
- self, record: Sequence[Optional[bytes]]
- ) -> Tuple[Any, ...]:
+ def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
if len(self._row_loaders) != len(record):
raise e.ProgrammingError(
f"cannot load sequence of {len(record)} items:"
try:
async with conn.transaction():
- async with conn.cursor(
- binary=True, row_factory=dict_row
- ) as cur:
- await cur.execute(
- cls._get_info_query(conn), {"name": name}
- )
+ async with conn.cursor(binary=True, row_factory=dict_row) as cur:
+ await cur.execute(cls._get_info_query(conn), {"name": name})
recs = await cur.fetchall()
except e.UndefinedObject:
return None
elif not recs:
return None
else:
- raise e.ProgrammingError(
- f"found {len(recs)} different types named {name}"
- )
+ raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
def register(self, context: Optional[AdaptContext] = None) -> None:
"""
if key.endswith("[]"):
key = key[:-2]
elif not isinstance(key, (int, tuple)):
- raise TypeError(
- f"the key must be an oid or a name, got {type(key)}"
- )
+ raise TypeError(f"the key must be an oid or a name, got {type(key)}")
try:
return self._registry[key]
except KeyError:
- raise KeyError(
- f"couldn't find the type {key!r} in the types registry"
- )
+ raise KeyError(f"couldn't find the type {key!r} in the types registry")
@overload
def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
else:
return t.oid
- def get_by_subtype(
- self, cls: Type[T], subtype: Union[int, str]
- ) -> Optional[T]:
+ def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
"""
Return info about a `TypeInfo` subclass by its element name or oid.
try:
zi: tzinfo = ZoneInfo(sname)
except KeyError:
- logger.warning(
- "unknown PostgreSQL timezone: %r; will use UTC", sname
- )
+ logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
zi = timezone.utc
_timezones[tzname] = zi
) -> None:
...
- def set_dumper_types(
- self, types: Sequence[int], format: pq.Format
- ) -> None:
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
...
- def set_loader_types(
- self, types: Sequence[int], format: pq.Format
- ) -> None:
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
...
def dump_sequence(
def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
...
- def load_rows(
- self, row0: int, row1: int, make_row: "RowMaker[Row]"
- ) -> List["Row"]:
+ def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
...
def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
...
- def load_sequence(
- self, record: Sequence[Optional[bytes]]
- ) -> Tuple[Any, ...]:
+ def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
...
def get_loader(self, oid: int, format: pq.Format) -> Loader:
rv = rv.replace(b"\\", b"\\\\")
return rv
- def get_key(
- self, obj: Any, format: PyFormat
- ) -> Union[type, Tuple[type, ...]]:
+ def get_key(self, obj: Any, format: PyFormat) -> Union[type, Tuple[type, ...]]:
"""
Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
`~psycopg.abc.Dumper` protocol. Look at its definition for details.
# Base implementation, not thread safe.
# Subclasses must call it holding a lock
self._check_intrans("isolation_level")
- self._isolation_level = (
- IsolationLevel(value) if value is not None else None
- )
+ self._isolation_level = IsolationLevel(value) if value is not None else None
self._begin_statement = b""
@property
try:
cb(diag)
except Exception as ex:
- logger.exception(
- "error processing notice callback '%s': %s", cb, ex
- )
+ logger.exception("error processing notice callback '%s': %s", cb, ex)
def add_notify_handler(self, callback: NotifyHandler) -> None:
"""
if result_format == Format.TEXT:
self.pgconn.send_query(command)
else:
- self.pgconn.send_query_params(
- command, None, result_format=result_format
- )
+ self.pgconn.send_query_params(command, None, result_format=result_format)
result = (yield from execute(self.pgconn))[-1]
if result.status not in (ExecStatus.COMMAND_OK, ExecStatus.TUPLES_OK):
if result.status == ExecStatus.FATAL_ERROR:
- raise e.error_from_result(
- result, encoding=pgconn_encoding(self.pgconn)
- )
+ raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
else:
raise e.InterfaceError(
f"unexpected result {ExecStatus(result.status).name}"
parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")
if self.deferrable is not None:
- parts.append(
- b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE"
- )
+ parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")
self._begin_statement = b" ".join(parts)
return self._begin_statement
)
xid = self._tpc[0]
self._tpc = (xid, True)
- yield from self._exec_command(
- SQL("PREPARE TRANSACTION {}").format(str(xid))
- )
+ yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
- def _tpc_finish_gen(
- self, action: str, xid: Union[Xid, str, None]
- ) -> PQGen[None]:
+ def _tpc_finish_gen(self, action: str, xid: Union[Xid, str, None]) -> PQGen[None]:
fname = f"tpc_{action}()"
if xid is None:
if not self._tpc:
server_cursor_factory: Type[ServerCursor[Row]]
row_factory: RowFactory[Row]
- def __init__(
- self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None
- ):
+ def __init__(self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None):
super().__init__(pgconn)
self.row_factory = row_factory or cast(RowFactory[Row], tuple_row)
self.lock = threading.Lock()
self.close()
@classmethod
- def _get_connection_params(
- cls, conninfo: str, **kwargs: Any
- ) -> Dict[str, Any]:
+ def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
"""Manipulate connection parameters before connecting.
:param conninfo: Connection string as received by `~Connection.connect()`.
ns = self.wait(notifies(self.pgconn))
enc = pgconn_encoding(self.pgconn)
for pgn in ns:
- n = Notify(
- pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
- )
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n
def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
ns = await self.wait(notifies(self.pgconn))
enc = pgconn_encoding(self.pgconn)
for pgn in ns:
- n = Notify(
- pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
- )
+ n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
yield n
async def wait(self, gen: PQGen[RV]) -> RV:
return await waiting.wait_async(gen, self.pgconn.socket)
@classmethod
- async def _wait_conn(
- cls, gen: PQGenConn[RV], timeout: Optional[int]
- ) -> RV:
+ async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
return await waiting.wait_conn_async(gen, timeout)
def _set_autocommit(self, value: bool) -> None:
def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
self._no_set_async("isolation_level")
- async def set_isolation_level(
- self, value: Optional[IsolationLevel]
- ) -> None:
+ async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
"""Async version of the `~Connection.isolation_level` setter."""
async with self.lock:
super()._set_isolation_level(value)
tmp.update(kwargs)
kwargs = tmp
- conninfo = " ".join(
- f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items()
- )
+ conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items())
# Verify the result is valid
_parse_conninfo(conninfo)
#LIBPQ-CONNSTRING
"""
opts = _parse_conninfo(conninfo)
- rv = {
- opt.keyword.decode(): opt.val.decode()
- for opt in opts
- if opt.val is not None
- }
+ rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
for k, v in kwargs.items():
if v is not None:
rv[k] = v
self._pgresult: "PGresult" = tx.pgresult
if self._pgresult.binary_tuples == pq.Format.TEXT:
- self.formatter = TextFormatter(
- tx, encoding=pgconn_encoding(self._pgconn)
- )
+ self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
else:
self.formatter = BinaryFormatter(tx)
"""
registry = self.cursor.adapters.types
- oids = [
- t if isinstance(t, int) else registry.get_oid(t) for t in types
- ]
+ oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
if self._pgresult.status == ExecStatus.COPY_IN:
- self.formatter.transformer.set_dumper_types(
- oids, self.formatter.format
- )
+ self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
else:
- self.formatter.transformer.set_loader_types(
- oids, self.formatter.format
- )
+ self.formatter.transformer.set_loader_types(oids, self.formatter.format)
# High level copy protocol generators (state change of the Copy object)
if not exc:
return
- if (
- self.connection.pgconn.transaction_status
- != pq.TransactionStatus.ACTIVE
- ):
+ if self.connection.pgconn.transaction_status != pq.TransactionStatus.ACTIVE:
# The server has already finished to send copy data. The connection
# is already in a good state.
return
def __init__(self, cursor: "AsyncCursor[Any]"):
super().__init__(cursor)
- self._queue: asyncio.Queue[bytes] = asyncio.Queue(
- maxsize=self.QUEUE_SIZE
- )
+ self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE)
self._worker: Optional[asyncio.Future[None]] = None
async def __aenter__(self) -> "AsyncCopy":
return data
elif isinstance(data, str):
- raise TypeError(
- "cannot copy str data in binary mode: use bytes instead"
- )
+ raise TypeError("cannot copy str data in binary mode: use bytes instead")
else:
raise TypeError(f"can't write {type(data).__name__}")
}
-def _dump_sub(
- m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl
-) -> bytes:
+def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
return __map[m.group(0)]
_load_repl = {v: k for k, v in _dump_repl.items()}
-def _load_sub(
- m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl
-) -> bytes:
+def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
return __map[m.group(0)]
self._execute_send(query, binary=False)
results = yield from execute(self._pgconn)
if len(results) != 1:
- raise e.ProgrammingError(
- "COPY cannot be mixed with other operations"
- )
+ raise e.ProgrammingError("COPY cannot be mixed with other operations")
result = results[0]
self._check_copy_result(result)
f" {ExecStatus(result.status).name}"
)
- def _set_current_result(
- self, i: int, format: Optional[Format] = None
- ) -> None:
+ def _set_current_result(self, i: int, format: Optional[Format] = None) -> None:
"""
Select one of the results in the cursor as the active one.
"""
if not res:
raise e.ProgrammingError("no result available")
elif res.status != ExecStatus.TUPLES_OK:
- raise e.ProgrammingError(
- "the last operation didn't produce a result"
- )
+ raise e.ProgrammingError("the last operation didn't produce a result")
def _check_copy_result(self, result: "PGresult") -> None:
"""
elif mode == "absolute":
newpos = value
else:
- raise ValueError(
- f"bad mode: {mode}. It should be 'relative' or 'absolute'"
- )
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
if not 0 <= newpos < self.pgresult.ntuples:
raise IndexError("position out of bound")
self._pos = newpos
__module__ = "psycopg"
__slots__ = ()
- def __init__(
- self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]
- ):
+ def __init__(self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]):
super().__init__(connection)
self._row_factory = row_factory
try:
with self._conn.lock:
self._conn.wait(
- self._execute_gen(
- query, params, prepare=prepare, binary=binary
- )
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
)
except e.Error as ex:
raise ex.with_traceback(None)
"""
try:
with self._conn.lock:
- self._conn.wait(
- self._executemany_gen(query, params_seq, returning)
- )
+ self._conn.wait(self._executemany_gen(query, params_seq, returning))
except e.Error as ex:
raise ex.with_traceback(None)
Iterate row-by-row on a result from the database.
"""
with self._conn.lock:
- self._conn.wait(
- self._stream_send_gen(query, params, binary=binary)
- )
+ self._conn.wait(self._stream_send_gen(query, params, binary=binary))
first = True
while self._conn.wait(self._stream_fetchone_gen(first)):
rec = self._tx.load_row(0, self._make_row)
"""
self._check_result_for_fetch()
assert self.pgresult
- records = self._tx.load_rows(
- self._pos, self.pgresult.ntuples, self._make_row
- )
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
self._pos = self.pgresult.ntuples
return records
try:
async with self._conn.lock:
await self._conn.wait(
- self._execute_gen(
- query, params, prepare=prepare, binary=binary
- )
+ self._execute_gen(query, params, prepare=prepare, binary=binary)
)
except e.Error as ex:
raise ex.with_traceback(None)
binary: Optional[bool] = None,
) -> AsyncIterator[Row]:
async with self._conn.lock:
- await self._conn.wait(
- self._stream_send_gen(query, params, binary=binary)
- )
+ await self._conn.wait(self._stream_send_gen(query, params, binary=binary))
first = True
while await self._conn.wait(self._stream_fetchone_gen(first)):
rec = self._tx.load_row(0, self._make_row)
async def fetchall(self) -> List[Row]:
self._check_result_for_fetch()
assert self.pgresult
- records = self._tx.load_rows(
- self._pos, self.pgresult.ntuples, self._make_row
- )
+ records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
self._pos = self.pgresult.ntuples
return records
DATETIME = DBAPITypeObject(
"DATETIME", "timestamp timestamptz date time timetz interval".split()
)
-NUMBER = DBAPITypeObject(
- "NUMBER", "int2 int4 int8 float4 float8 numeric".split()
-)
+NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
ROWID = DBAPITypeObject("ROWID", ("oid",))
STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
sqlstate: Optional[str] = None
def __init__(
- self,
- *args: Sequence[Any],
- info: ErrorInfo = None,
- encoding: str = "utf-8"
+ self, *args: Sequence[Any], info: ErrorInfo = None, encoding: str = "utf-8"
):
super().__init__(*args)
self._info = info
}
-def sqlcode(
- const_name: str, code: str
-) -> Callable[[Type[Error]], Type[Error]]:
+def sqlcode(const_name: str, code: str) -> Callable[[Type[Error]], Type[Error]]:
"""
Decorator to associate an exception class to a sqlstate.
"""
RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
- MultirangeInfo(
- "datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082
- ),
- MultirangeInfo(
- "int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23
- ),
- MultirangeInfo(
- "int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20
- ),
- MultirangeInfo(
- "nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700
- ),
- MultirangeInfo(
- "tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114
- ),
- MultirangeInfo(
- "tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184
- ),
+ MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
+ MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
+ MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
+ MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
+ MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
+ MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
# autogenerated: end
]:
types.add(t)
)
known = {
- line[4:].split("(", 1)[0]
- for line in lines[:istart]
- if line.startswith("def ")
+ line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
}
signatures = []
arg6: Optional[Array[c_int]],
arg7: int,
) -> int: ...
-def PQcancel(
- arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int
-) -> int: ...
+def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ...
def PQsetNoticeReceiver(
arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
) -> Callable[[Any], PGresult_struct]: ...
def PQnotifies(
arg1: Optional[PGconn_struct],
) -> Optional[pointer[PGnotify_struct]]: ... # type: ignore
-def PQputCopyEnd(
- arg1: Optional[PGconn_struct], arg2: Optional[bytes]
-) -> int: ...
+def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ...
# Arg 2 is a pointer, reported as _CArgObject by mypy
-def PQgetCopyData(
- arg1: Optional[PGconn_struct], arg2: Any, arg3: int
-) -> int: ...
+def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ...
def PQsetResultAttrs(
arg1: Optional[PGresult_struct],
arg2: int,
def binary_tuples(self) -> int:
...
- def get_value(
- self, row_number: int, column_number: int
- ) -> Optional[bytes]:
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
...
@property
...
@classmethod
- def _options_from_array(
- cls, opts: Sequence[Any]
- ) -> List["ConninfoOption"]:
+ def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
...
bmsg = bmsg.split(b":", 1)[-1].strip()
else:
- raise TypeError(
- f"PGconn or PGresult expected, got {type(obj).__name__}"
- )
+ raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")
if bmsg:
msg = bmsg.decode(encoding, "replace")
raise TypeError(f"bytes expected, got {type(command)} instead")
self._ensure_pgconn()
if not impl.PQsendQuery(self._pgconn_ptr, command):
- raise e.OperationalError(
- f"sending query failed: {error_message(self)}"
- )
+ raise e.OperationalError(f"sending query failed: {error_message(self)}")
def exec_params(
self,
atypes = (impl.Oid * nparams)(*param_types)
self._ensure_pgconn()
- if not impl.PQsendPrepare(
- self._pgconn_ptr, name, command, nparams, atypes
- ):
+ if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes):
raise e.OperationalError(
f"sending query and params failed: {error_message(self)}"
)
for b in param_values
)
)
- alenghts = (c_int * nparams)(
- *(len(p) if p else 0 for p in param_values)
- )
+ alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
else:
nparams = 0
aparams = alenghts = None
raise TypeError(f"'name' must be bytes, got {type(name)} instead")
if not isinstance(command, bytes):
- raise TypeError(
- f"'command' must be bytes, got {type(command)} instead"
- )
+ raise TypeError(f"'command' must be bytes, got {type(command)} instead")
if not param_types:
nparams = 0
if param_values:
nparams = len(param_values)
aparams = (c_char_p * nparams)(*param_values)
- alenghts = (c_int * nparams)(
- *(len(p) if p else 0 for p in param_values)
- )
+ alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
else:
nparams = 0
aparams = alenghts = None
def consume_input(self) -> None:
if 1 != impl.PQconsumeInput(self._pgconn_ptr):
- raise e.OperationalError(
- f"consuming input failed: {error_message(self)}"
- )
+ raise e.OperationalError(f"consuming input failed: {error_message(self)}")
def is_busy(self) -> int:
return impl.PQisBusy(self._pgconn_ptr)
def flush(self) -> int:
if not self._pgconn_ptr:
- raise e.OperationalError(
- "flushing failed: the connection is closed"
- )
+ raise e.OperationalError("flushing failed: the connection is closed")
rv: int = impl.PQflush(self._pgconn_ptr)
if rv < 0:
raise e.OperationalError(f"flushing failed: {error_message(self)}")
buffer = bytes(buffer)
rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
if rv < 0:
- raise e.OperationalError(
- f"sending copy data failed: {error_message(self)}"
- )
+ raise e.OperationalError(f"sending copy data failed: {error_message(self)}")
return rv
def put_copy_end(self, error: Optional[bytes] = None) -> int:
rv = impl.PQputCopyEnd(self._pgconn_ptr, error)
if rv < 0:
- raise e.OperationalError(
- f"sending copy end failed: {error_message(self)}"
- )
+ raise e.OperationalError(f"sending copy end failed: {error_message(self)}")
return rv
def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
buffer_ptr = c_char_p()
- nbytes = impl.PQgetCopyData(
- self._pgconn_ptr, byref(buffer_ptr), async_
- )
+ nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_)
if nbytes == -2:
raise e.OperationalError(
f"receiving copy data failed: {error_message(self)}"
def encrypt_password(
self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
) -> bytes:
- out = impl.PQencryptPasswordConn(
- self._pgconn_ptr, passwd, user, algorithm
- )
+ out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm)
if not out:
raise e.OperationalError(
f"password encryption failed: {error_message(self)}"
:raises ~e.OperationalError: if the flush request failed.
"""
if impl.PQsendFlushRequest(self._pgconn_ptr) == 0:
- raise e.OperationalError(
- f"flush request failed: {error_message(self)}"
- )
+ raise e.OperationalError(f"flush request failed: {error_message(self)}")
def _call_bytes(
self, func: Callable[[impl.PGconn_struct], Optional[bytes]]
def binary_tuples(self) -> int:
return impl.PQbinaryTuples(self._pgresult_ptr)
- def get_value(
- self, row_number: int, column_number: int
- ) -> Optional[bytes]:
- length: int = impl.PQgetlength(
- self._pgresult_ptr, row_number, column_number
- )
+ def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+ length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number)
if length:
v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number)
return string_at(v, length)
def set_attributes(self, descriptions: List[PGresAttDesc]) -> None:
structs = [
- impl.PGresAttDesc_struct(*desc) # type: ignore
- for desc in descriptions
+ impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore
]
array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore
rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array)
See :pq:`PQcancel()` for details.
"""
buf = create_string_buffer(256)
- res = impl.PQcancel(
- self.pgcancel_ptr, pointer(buf), len(buf) # type: ignore
- )
+ res = impl.PQcancel(self.pgcancel_ptr, pointer(buf), len(buf)) # type: ignore
if not res:
raise e.OperationalError(
f"cancel failed: {buf.value.decode('utf8', 'ignore')}"
def escape_literal(self, data: "abc.Buffer") -> memoryview:
if not self.conn:
- raise e.OperationalError(
- "escape_literal failed: no connection provided"
- )
+ raise e.OperationalError("escape_literal failed: no connection provided")
self.conn._ensure_pgconn()
# TODO: might be done without copy (however C does that)
def escape_identifier(self, data: "abc.Buffer") -> memoryview:
if not self.conn:
- raise e.OperationalError(
- "escape_identifier failed: no connection provided"
- )
+ raise e.OperationalError("escape_identifier failed: no connection provided")
self.conn._ensure_pgconn()
pointer(t_cast(c_ulong, len_out)),
)
else:
- out = impl.PQescapeBytea(
- data, len(data), pointer(t_cast(c_ulong, len_out))
- )
+ out = impl.PQescapeBytea(data, len(data), pointer(t_cast(c_ulong, len_out)))
if not out:
raise MemoryError(
f"couldn't allocate for escape_bytea of {len(data)} bytes"
# ascii except alnum and underscore
-_re_clean = re.compile(
- "[" + re.escape(" !\"#$%&'()*+,-./:;<=>?@[\\]^`{|}~") + "]"
-)
+_re_clean = re.compile("[" + re.escape(" !\"#$%&'()*+,-./:;<=>?@[\\]^`{|}~") + "]")
@functools.lru_cache(512)
# The above result only returned COMMAND_OK. Get the cursor shape
yield from self._describe_gen(cur)
- def _describe_gen(
- self, cur: BaseCursor[ConnectionType, Row]
- ) -> PQGen[None]:
+ def _describe_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
conn = cur._conn
conn.pgconn.send_describe_portal(self.name.encode(cur._encoding))
results = yield from execute(conn.pgconn)
sql.SQL("ALL") if num is None else sql.Literal(num),
sql.Identifier(self.name),
)
- res = yield from cur._conn._exec_command(
- query, result_format=self._format
- )
+ res = yield from cur._conn._exec_command(query, result_format=self._format)
cur.pgresult = res
cur._tx.set_pgresult(res, set_loaders=False)
self, cur: BaseCursor[ConnectionType, Row], value: int, mode: str
) -> PQGen[None]:
if mode not in ("relative", "absolute"):
- raise ValueError(
- f"bad mode: {mode}. It should be 'relative' or 'absolute'"
- )
+ raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
query = sql.SQL("MOVE{} {} FROM {}").format(
sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
sql.Literal(value),
try:
with self._conn.lock:
- self._conn.wait(
- self._helper._declare_gen(self, query, params, binary)
- )
+ self._conn.wait(self._helper._declare_gen(self, query, params, binary))
except e.Error as ex:
raise ex.with_traceback(None)
returning: bool = True,
) -> None:
"""Method not implemented for server-side cursors."""
- raise e.NotSupportedError(
- "executemany not supported on server-side cursors"
- )
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
def fetchone(self) -> Optional[Row]:
with self._conn.lock:
def __iter__(self) -> Iterator[Row]:
while True:
with self._conn.lock:
- recs = self._conn.wait(
- self._helper._fetch_gen(self, self.itersize)
- )
+ recs = self._conn.wait(self._helper._fetch_gen(self, self.itersize))
for rec in recs:
self._pos += 1
yield rec
*,
returning: bool = True,
) -> None:
- raise e.NotSupportedError(
- "executemany not supported on server-side cursors"
- )
+ raise e.NotSupportedError("executemany not supported on server-side cursors")
async def fetchone(self) -> Optional[Row]:
async with self._conn.lock:
_obj: List[Composable]
def __init__(self, seq: Sequence[Any]):
- seq = [
- obj if isinstance(obj, Composable) else Literal(obj) for obj in seq
- ]
+ seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
super().__init__(seq)
def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
# operational error that might arise in the block.
raise
except Exception as exc2:
- logger.warning(
- "error ignored in rollback of %s: %s", self, exc2
- )
+ logger.warning("error ignored in rollback of %s: %s", self, exc2)
return False
def _commit_gen(self) -> PQGen[PGresult]:
def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
if isinstance(exc_val, Rollback):
- logger.debug(
- f"{self._conn}: Explicit rollback from: ", exc_info=True
- )
+ logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True)
ex = self._pop_savepoint("rollback")
self._exited = True
else:
# inner transaction: it always has a name
if not self._savepoint_name:
- self._savepoint_name = (
- f"_pg3_{self._conn._num_transactions + 1}"
- )
+ self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"
self._stack_index = self._conn._num_transactions
self._conn._num_transactions += 1
) -> bool:
if self._conn.pgconn.status == ConnStatus.OK:
with self._conn.lock:
- return self._conn.wait(
- self._exit_gen(exc_type, exc_val, exc_tb)
- )
+ return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
else:
return False
) -> bool:
if self._conn.pgconn.status == ConnStatus.OK:
async with self._conn.lock:
- return await self._conn.wait(
- self._exit_gen(exc_type, exc_val, exc_tb)
- )
+ return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
else:
return False
_struct_head = struct.Struct("!III") # ndims, hasnull, elem oid
_pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
-_unpack_head = cast(
- Callable[[bytes], Tuple[int, int, int]], _struct_head.unpack_from
-)
+_unpack_head = cast(Callable[[bytes], Tuple[int, int, int]], _struct_head.unpack_from)
_struct_dim = struct.Struct("!II") # dim, lower bound
_pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
-_unpack_dim = cast(
- Callable[[bytes, int], Tuple[int, int]], _struct_dim.unpack_from
-)
+_unpack_dim = cast(Callable[[bytes, int], Tuple[int, int]], _struct_dim.unpack_from)
TEXT_ARRAY_OID = postgres.types["text"].array_oid
NoneType: type = type(None)
super().__init__(cls, context)
self.sub_dumper: Optional[Dumper] = None
if self.element_oid and context:
- sdclass = context.adapters.get_dumper_by_oid(
- self.element_oid, self.format
- )
+ sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
self.sub_dumper = sdclass(NoneType, context)
def _find_list_element(self, L: List[Any]) -> Any:
else:
for item in L:
if not isinstance(item, self.cls):
- raise e.DataError(
- "nested lists have inconsistent depths"
- )
+ raise e.DataError("nested lists have inconsistent depths")
dump_list(item, dim + 1) # type: ignore
dump_list(obj, 0)
else:
if not stack:
wat = (
- t[:10].decode("utf8", "replace") + "..."
- if len(t) > 10
- else ""
+ t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
)
raise e.DataError(f"malformed array, unexpected '{wat}'")
if t == b"NULL":
return agg(dims)
-def register_array(
- info: TypeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
if not info.array_oid:
raise ValueError(f"the type info {info} doesn't describe an array")
# A friendly error warning instead of an AttributeError in case fetch()
# failed and it wasn't noticed.
if not info:
- raise TypeError(
- "no info passed. Is the requested composite available?"
- )
+ raise TypeError("no info passed. Is the requested composite available?")
# Register arrays and type info
info.register(context)
adapters.register_dumper(factory, dumper)
# Default to the text dumper because it is more flexible
- dumper = type(
- f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid}
- )
+ dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid})
adapters.register_dumper(factory, dumper)
info.python_type = factory
_struct_timetz = struct.Struct("!qi") # microseconds, sec tz offset
_pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
-_unpack_timetz = cast(
- Callable[[bytes], Tuple[int, int]], _struct_timetz.unpack
-)
+_unpack_timetz = cast(Callable[[bytes], Tuple[int, int]], _struct_timetz.unpack)
_struct_interval = struct.Struct("!qii") # microseconds, days, months
_pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
def dump(self, obj: datetime) -> bytes:
delta = obj - _pg_datetimetz_epoch
- micros = delta.microseconds + 1_000_000 * (
- 86_400 * delta.days + delta.seconds
- )
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
return pack_int8(micros)
def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
def dump(self, obj: datetime) -> bytes:
delta = obj - _pg_datetime_epoch
- micros = delta.microseconds + 1_000_000 * (
- 86_400 * delta.days + delta.seconds
- )
+ micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
return pack_int8(micros)
elif ds.startswith(b"G"): # German
self._order = self._ORDER_DMY
elif ds.startswith(b"S") or ds.startswith(b"P"): # SQL or Postgres
- self._order = (
- self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
- )
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
else:
raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
try:
return time(h, m, s, us)
except ValueError:
- raise DataError(
- f"time not supported by Python: hour={h}"
- ) from None
+ raise DataError(f"time not supported by Python: hour={h}") from None
class TimetzLoader(Loader):
try:
return time(h, m, s, us, timezone(timedelta(seconds=-off)))
except ValueError:
- raise DataError(
- f"time not supported by Python: hour={h}"
- ) from None
+ raise DataError(f"time not supported by Python: hour={h}") from None
class TimestampLoader(Loader):
elif ds.startswith(b"G"): # German
self._order = self._ORDER_DMY
elif ds.startswith(b"S"): # SQL
- self._order = (
- self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
- )
+ self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
elif ds.startswith(b"P"): # Postgres
- self._order = (
- self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
- )
+ self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
self._re_format = self._re_format_pg
else:
raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
us = 0
try:
- return datetime(
- int(ye), imo, int(da), int(ho), int(mi), int(se), us
- )
+ return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us)
except ValueError as e:
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse timestamp {s!r}: {e}") from None
return _pg_datetime_epoch + timedelta(microseconds=micros)
except OverflowError:
if micros <= 0:
- raise DataError(
- "timestamp too small (before year 1)"
- ) from None
+ raise DataError("timestamp too small (before year 1)") from None
else:
- raise DataError(
- "timestamp too large (after year 10K)"
- ) from None
+ raise DataError("timestamp too large (after year 10K)") from None
class TimestamptzLoader(Loader):
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
- self._timezone = get_tzinfo(
- self.connection.pgconn if self.connection else None
- )
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
ds = _get_datestyle(self.connection)
if not ds.startswith(b"I"): # not ISO
dt = None
ex: Exception
try:
- dt = datetime(
- int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc
- )
+ dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc)
return (dt - tzoff).astimezone(self._timezone)
except OverflowError as e:
# If we have created the temporary 'dt' it means that we have a
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
- self._timezone = get_tzinfo(
- self.connection.pgconn if self.connection else None
- )
+ self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
def load(self, data: Buffer) -> datetime:
micros = unpack_int8(data)[0]
if utcoff:
usoff = 1_000_000 * int(utcoff.total_seconds())
try:
- ts = _pg_datetime_epoch + timedelta(
- microseconds=micros + usoff
- )
+ ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
except OverflowError:
pass # will raise downstream
else:
return ts.replace(tzinfo=self._timezone)
if micros <= 0:
- raise DataError(
- "timestamp too small (before year 1)"
- ) from None
+ raise DataError("timestamp too small (before year 1)") from None
else:
- raise DataError(
- "timestamp too large (after year 10K)"
- ) from None
+ raise DataError("timestamp too large (after year 10K)") from None
class IntervalLoader(Loader):
_month_abbr = {
n: i
- for i, n in enumerate(
- b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1
- )
+ for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1)
}
# Pad to get microseconds from a fraction of seconds
start = m.end()
if start < len(s):
- raise e.DataError(
- f"error parsing hstore: unparsed data after char {start}"
- )
+ raise e.DataError(f"error parsing hstore: unparsed data after char {start}")
return rv
-def register_hstore(
- info: TypeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
"""Register the adapters to load and dump hstore.
:param info: The object with the information about the hstore type.
def __getitem__(self, index: slice) -> "Multirange[T]":
...
- def __getitem__(
- self, index: Union[int, slice]
- ) -> "Union[Range[T],Multirange[T]]":
+ def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
if isinstance(index, int):
return self._ranges[index]
else:
else:
return (self.cls,)
- def upgrade(
- self, obj: Multirange[Any], format: PyFormat
- ) -> "BaseMultirangeDumper":
+ def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
# If we are a subclass whose oid is specified we don't need upgrade
if self.cls is not Multirange:
return self
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
- self._load = self._tx.get_loader(
- self.subtype_oid, format=self.format
- ).load
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
class MultirangeLoader(BaseMultirangeLoader[T]):
# A friendly error warning instead of an AttributeError in case fetch()
# failed and it wasn't noticed.
if not info:
- raise TypeError(
- "no info passed. Is the requested multirange available?"
- )
+ raise TypeError("no info passed. Is the requested multirange available?")
# Register arrays and type info
info.register(context)
adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
adapters.register_dumper(DateMultirange, DateMultirangeDumper)
adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
- adapters.register_dumper(
- TimestamptzMultirange, TimestamptzMultirangeDumper
- )
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper)
adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
- adapters.register_dumper(
- TimestampMultirange, TimestampMultirangeBinaryDumper
- )
- adapters.register_dumper(
- TimestamptzMultirange, TimestamptzMultirangeBinaryDumper
- )
+ adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper)
+ adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper)
adapters.register_loader("int4multirange", Int4MultirangeLoader)
adapters.register_loader("int8multirange", Int8MultirangeLoader)
adapters.register_loader("nummultirange", NumericMultirangeLoader)
adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
- adapters.register_loader(
- "tstzmultirange", TimestampTZMultirangeBinaryLoader
- )
+ adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader)
try:
return _decimal_special[sign]
except KeyError:
- raise e.DataError(
- f"bad value for numeric sign: 0x{sign:X}"
- ) from None
+ raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
def __getstate__(self) -> Dict[str, Any]:
return {
- slot: getattr(self, slot)
- for slot in self.__slots__
- if hasattr(self, slot)
+ slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
}
def __setstate__(self, state: Dict[str, Any]) -> None:
return dump_range_binary(obj, dump)
-def dump_range_binary(
- obj: Range[Any], dump: Callable[[Any], Buffer]
-) -> Buffer:
+def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
if not obj:
return _EMPTY_HEAD
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
- self._load = self._tx.get_loader(
- self.subtype_oid, format=self.format
- ).load
+ self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
class RangeLoader(BaseRangeLoader[T]):
return load_range_binary(data, self._load)
-def load_range_binary(
- data: Buffer, load: Callable[[Buffer], Any]
-) -> Range[Any]:
+def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]:
head = data[0]
if head & RANGE_EMPTY:
return Range(empty=True)
return Range(min, max, lb + ub)
-def register_range(
- info: RangeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None:
"""Register the adapters to load and dump a range type.
:param info: The object with the information about the range to register.
return dumps(obj, hex=True).encode() # type: ignore
-def register_shapely(
- info: TypeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
"""Register Shapely dumper and loaders.
After invoking this function on an adapter, the queries retrieving
class _StrDumper(_BaseStrDumper):
def dump(self, obj: str) -> bytes:
if "\x00" in obj:
- raise DataError(
- "PostgreSQL text fields cannot contain NUL (0x00) bytes"
- )
+ raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
else:
return obj.encode(self._encoding)
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
- self._esc = Escaping(
- self.connection.pgconn if self.connection else None
- )
+ self._esc = Escaping(self.connection.pgconn if self.connection else None)
def dump(self, obj: bytes) -> Buffer:
return self._esc.escape_bytea(obj)
RW = EVENT_READ | EVENT_WRITE
-def wait_selector(
- gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
-) -> RV:
+def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
"""
Wait for a generator using the best strategy available.
return rv
-async def wait_conn_async(
- gen: PQGenConn[RV], timeout: Optional[float] = None
-) -> RV:
+async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
"""
Coroutine waiting for a connection generator to complete.
return rv
-def wait_epoll(
- gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
-) -> RV:
+def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
"""
Wait for a generator using epoll where supported.
[build-system]
requires = ["setuptools>=49.2.0", "wheel>=0.37"]
build-backend = "setuptools.build_meta"
-
-[tool.black]
-line-length = 79
shapely
[flake8]
-max-line-length = 85
+max-line-length = 88
ignore = W503, E203
# This package shouldn't be imported before psycopg itself, or weird things
# will happen
if "psycopg" not in sys.modules:
- raise ImportError(
- "the psycopg package should be imported before psycopg_c"
- )
+ raise ImportError("the psycopg package should be imported before psycopg_c")
from .version import __version__ as __version__ # noqa
set_loaders: bool = True,
format: Optional[pq.Format] = None,
) -> None: ...
- def set_dumper_types(
- self, types: Sequence[int], format: pq.Format
- ) -> None: ...
- def set_loader_types(
- self, types: Sequence[int], format: pq.Format
- ) -> None: ...
+ def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...
+ def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...
def dump_sequence(
self, params: Sequence[Any], formats: Sequence[PyFormat]
) -> Sequence[Optional[abc.Buffer]]: ...
def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
- def load_rows(
- self, row0: int, row1: int, make_row: RowMaker[Row]
- ) -> List[Row]: ...
+ def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ...
def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ...
- def load_sequence(
- self, record: Sequence[Optional[bytes]]
- ) -> Tuple[Any, ...]: ...
+ def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]: ...
def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
# Generators
[build-system]
requires = ["setuptools>=49.2.0", "wheel>=0.37", "Cython>=3.0a5"]
build-backend = "setuptools.build_meta"
-
-[tool.black]
-line-length = 79
-e {toxinidir}/../psycopg_pool
[flake8]
-max-line-length = 85
+max-line-length = 88
ignore = W503, E203
max_lifetime: float = 60 * 60.0,
max_idle: float = 10 * 60.0,
reconnect_timeout: float = 5 * 60.0,
- reconnect_failed: Optional[
- Callable[["BasePool[ConnectionType]"], None]
- ] = None,
+ reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]] = None,
num_workers: int = 3,
):
min_size, max_size = self._check_size(min_size, max_size)
"""`!True` if the pool is closed."""
return self._closed
- def _check_size(
- self, min_size: int, max_size: Optional[int]
- ) -> Tuple[int, int]:
+ def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
if max_size is None:
max_size = min_size
if max_size < min_size:
raise ValueError("max_size must be greater or equal than min_size")
if min_size == max_size == 0:
- raise ValueError(
- "if min_size is 0 max_size must be greater or than 0"
- )
+ raise ValueError("if min_size is 0 max_size must be greater or than 0")
return min_size, max_size
Add some randomness to avoid mass reconnection.
"""
- conn._expire_at = monotonic() + self._jitter(
- self.max_lifetime, -0.05, 0.0
- )
+ conn._expire_at = monotonic() + self._jitter(self.max_lifetime, -0.05, 0.0)
class ConnectionAttempt:
conninfo, *args, min_size=min_size, **kwargs
)
- def _check_size(
- self, min_size: int, max_size: Optional[int]
- ) -> Tuple[int, int]:
+ def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
if max_size is None:
max_size = min_size
self.run_task(AddConnection(self))
if not self._pool_full_event.wait(timeout):
self.close() # stop all the threads
- raise PoolTimeout(
- f"pool initialization incomplete after {timeout} sec"
- )
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
with self._lock:
assert self._pool_full_event
)
return conn
- async def _maybe_close_connection(
- self, conn: AsyncConnection[Any]
- ) -> bool:
+ async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
# Close the connection if no client is waiting for it, or if the pool
# is closed. For extra refcare remove the pool reference from it.
# Maintain the stats.
self._nconns -= 1
return True
- async def resize(
- self, min_size: int, max_size: Optional[int] = None
- ) -> None:
+ async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
min_size, max_size = self._check_size(min_size, max_size)
logger.info(
logger.info("waiting for pool %r initialization", self.name)
if not self._pool_full_event.wait(timeout):
self.close() # stop all the threads
- raise PoolTimeout(
- f"pool initialization incomplete after {timeout} sec"
- )
+ raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
with self._lock:
assert self._pool_full_event
logger.info("pool %r is ready to use", self.name)
@contextmanager
- def connection(
- self, timeout: Optional[float] = None
- ) -> Iterator[Connection[Any]]:
+ def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]:
"""Context manager to obtain a connection from the pool.
Return the connection immediately if available, otherwise wait up to
"""
now = monotonic()
if not attempt:
- attempt = ConnectionAttempt(
- reconnect_timeout=self.reconnect_timeout
- )
+ attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
try:
conn = self._connect()
with self._lock:
if self._nconns < self._max_size and self._waiting:
self._nconns += 1
- logger.info(
- "growing pool %r to %s", self.name, self._nconns
- )
+ logger.info("growing pool %r to %s", self.name, self._nconns)
self.run_task(AddConnection(self, growing=True))
else:
self._growing = False
def __init__(self, pool: "ConnectionPool"):
self.pool = ref(pool)
- logger.debug(
- "task created in %s: %s", threading.current_thread().name, self
- )
+ logger.debug("task created in %s: %s", threading.current_thread().name, self)
def __repr__(self) -> str:
pool = self.pool()
# Pool is no more working. Quietly discard the operation.
return
- logger.debug(
- "task running in %s: %s", threading.current_thread().name, self
- )
+ logger.debug("task running in %s: %s", threading.current_thread().name, self)
self._run(pool)
def tick(self) -> None:
*,
open: bool = True,
connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
- configure: Optional[
- Callable[[AsyncConnection[Any]], Awaitable[None]]
- ] = None,
- reset: Optional[
- Callable[[AsyncConnection[Any]], Awaitable[None]]
- ] = None,
+ configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+ reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
**kwargs: Any,
):
self.connection_class = connection_class
self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
await self.putconn(conn)
- async def getconn(
- self, timeout: Optional[float] = None
- ) -> AsyncConnection[Any]:
+ async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
logger.info("connection requested from %r", self.name)
self._stats[self._REQUESTS_NUM] += 1
else:
await self._return_connection(conn)
- async def _maybe_close_connection(
- self, conn: AsyncConnection[Any]
- ) -> bool:
+ async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
# If the pool is closed just close the connection instead of returning
# it to the pool. For extra refcare remove the pool reference from it.
if not self._closed:
) -> None:
await self.close()
- async def resize(
- self, min_size: int, max_size: Optional[int] = None
- ) -> None:
+ async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
min_size, max_size = self._check_size(min_size, max_size)
ngrow = max(0, min_size - self._min_size)
"""Run a maintenance task in a worker."""
self._tasks.put_nowait(task)
- async def schedule_task(
- self, task: "MaintenanceTask", delay: float
- ) -> None:
+ async def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
"""Run a maintenance task in a worker in the future."""
await self._sched.enter(delay, task.tick)
ex,
)
- async def _connect(
- self, timeout: Optional[float] = None
- ) -> AsyncConnection[Any]:
+ async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
self._stats[self._CONNECTIONS_NUM] += 1
kwargs = self.kwargs
if timeout:
"""
now = monotonic()
if not attempt:
- attempt = ConnectionAttempt(
- reconnect_timeout=self.reconnect_timeout
- )
+ attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
try:
conn = await self._connect()
async with self._lock:
if self._nconns < self._max_size and self._waiting:
self._nconns += 1
- logger.info(
- "growing pool %r to %s", self.name, self._nconns
- )
+ logger.info("growing pool %r to %s", self.name, self._nconns)
self.run_task(AddConnection(self, growing=True))
else:
self._growing = False
class ReturnConnection(MaintenanceTask):
"""Clean up and return a connection to the pool."""
- def __init__(
- self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"
- ):
+ def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"):
super().__init__(pool)
self.conn = conn
time = monotonic() + delay
return self.enterabs(time, action)
- def enterabs(
- self, time: float, action: Optional[Callable[[], Any]]
- ) -> Task:
+ def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
"""Enter a new task in the queue at an absolute time.
Schedule a `!None` to stop the execution.
EMPTY_QUEUE_TIMEOUT = 600.0
- async def enter(
- self, delay: float, action: Optional[Callable[[], Any]]
- ) -> Task:
+ async def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
"""Enter a new task in the queue delayed in the future.
Schedule a `!None` to stop the execution.
time = monotonic() + delay
return await self.enterabs(time, action)
- async def enterabs(
- self, time: float, action: Optional[Callable[[], Any]]
- ) -> Task:
+ async def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
"""Enter a new task in the queue at an absolute time.
Schedule a `!None` to stop the execution.
[flake8]
-max-line-length = 85
+max-line-length = 88
ignore = W503, E203
requires = ["setuptools>=49.2.0", "wheel>=0.37"]
build-backend = "setuptools.build_meta"
-[tool.black]
-line-length = 79
-
[tool.pytest.ini_options]
asyncio_mode = "auto"
filterwarnings = [
pgconn.trace(tracefile.fileno())
try:
- pgconn.set_trace_flags(
- pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE
- )
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
except psycopg.NotSupportedError:
pass
try:
conn = pq.PGconn.connect(dsn.encode())
if conn.status != pq.ConnStatus.OK:
- pytest.fail(
- f"bad connection: {conn.error_message.decode('utf8', 'replace')}"
- )
+ pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
msg = check_connection_version(conn.server_version, request.function)
if msg:
conn.finish()
pytest.skip(str(e))
-def warm_up_database(
- dsn: str, __first_connection: List[bool] = [True]
-) -> None:
+def warm_up_database(dsn: str, __first_connection: List[bool] = [True]) -> None:
"""Connect to the database before returning a connection.
Sometimes in the CI, the first test fails with a timeout, probably because
@property
def types_names(self):
- types = [
- t.as_string(self.conn).replace('"', "")
- for t in self.types_names_sql
- ]
+ types = [t.as_string(self.conn).replace('"', "") for t in self.types_names_sql]
return types
def _get_type_name(self, tx, schema, value):
field_values.append(sql.SQL("{} {}").format(name, type))
fields = sql.SQL(", ").join(field_values)
- return sql.SQL(
- "create table {table} (id serial primary key, {fields})"
- ).format(table=self.table_name, fields=fields)
+ return sql.SQL("create table {table} (id serial primary key, {fields})").format(
+ table=self.table_name, fields=fields
+ )
@property
def insert_stmt(self):
- phs = [
- sql.Placeholder(format=self.format)
- for i in range(len(self.schema))
- ]
+ phs = [sql.Placeholder(format=self.format) for i in range(len(self.schema))]
return sql.SQL("insert into {} ({}) values ({})").format(
self.table_name,
sql.SQL(", ").join(self.fields_names),
@property
def select_stmt(self):
fields = sql.SQL(", ").join(self.fields_names)
- return sql.SQL("select {} from {} order by id").format(
- fields, self.table_name
- )
+ return sql.SQL("select {} from {} order by id").format(fields, self.table_name)
@contextmanager
def find_insert_problem(self, conn):
return tuple(self.example(spec) for spec in self.schema)
else:
return tuple(
- self.make(spec) if random() > nulls else None
- for spec in self.schema
+ self.make(spec) if random() > nulls else None for spec in self.schema
)
def assert_record(self, got, want):
for cls in dumpers.keys():
if isinstance(cls, str):
cls = deep_import(cls)
- if (
- issubclass(cls, Multirange)
- and self.conn.info.server_version < 140000
- ):
+ if issubclass(cls, Multirange) and self.conn.info.server_version < 140000:
continue
rv.add(cls)
self._makers[cls] = meth
return meth
else:
- raise NotImplementedError(
- f"cannot make fake objects of class {cls}"
- )
+ raise NotImplementedError(f"cannot make fake objects of class {cls}")
def get_matcher(self, spec):
cls = spec if isinstance(spec, type) else spec[0]
else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
)
else:
- return choice(
- (0.0, -0.0, float("-inf"), float("inf"), float("nan"))
- )
+ return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan")))
def match_float(self, spec, got, want, approx=False, rel=None):
if got is not None and isnan(got):
def make_JsonFloat(self, spec):
# A float limited to what json accepts
# this exponent should generate no inf
- return float(
- f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}"
- )
+ return float(f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}")
def schema_list(self, cls):
while True:
return spec[0](sorted(out))
def example_Multirange(self, spec):
- return self.make_Multirange(
- spec, length=1, empty_chance=0, no_bound_chance=0
- )
+ return self.make_Multirange(spec, length=1, empty_chance=0, no_bound_chance=0)
def make_Int4Multirange(self, spec):
return self.make_Multirange((spec, Int4))
if unit is not None:
if want.lower is not None and not want.lower_inc:
- want = type(want)(
- want.lower + unit, want.upper, "[" + want.bounds[1]
- )
+ want = type(want)(want.lower + unit, want.upper, "[" + want.bounds[1])
if want.upper_inc:
- want = type(want)(
- want.lower, want.upper + unit, want.bounds[0] + ")"
- )
+ want = type(want)(want.lower, want.upper + unit, want.bounds[0] + ")")
if spec[1] == (dt.datetime, True) and not want.isempty:
# work around https://bugs.python.org/issue45347
raise ValueError("pproxy program not found")
cmdline = [pproxy, "--reuse"]
cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
- cmdline.extend(
- ["-r", f"tunnel://{self.server_host}:{self.server_port}"]
- )
+ cmdline.extend(["-r", f"tunnel://{self.server_host}:{self.server_port}"])
self.proc = sp.Popen(cmdline, stdout=sp.DEVNULL)
logging.info("proxy started")
self.conn = conn
def check_tpc(self):
- val = int(
- self.conn.execute("show max_prepared_transactions").fetchone()[0]
- )
+ val = int(self.conn.execute("show max_prepared_transactions").fetchone()[0])
if not val:
pytest.skip("prepared transactions disabled in the database")
def pytest_configure(config):
- config.addinivalue_line(
- "markers", "pool: test related to the psycopg_pool package"
- )
+ config.addinivalue_line("markers", "pool: test related to the psycopg_pool package")
def pytest_collection_modifyitems(items):
assert p.max_size == 2
-@pytest.mark.parametrize(
- "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
def test_bad_size(dsn, min_size, max_size):
with pytest.raises(ValueError):
NullConnectionPool(min_size=min_size, max_size=max_size)
@pytest.mark.slow
@pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
def test_no_queue_timeout(deaf_port):
- with NullConnectionPool(
- kwargs={"host": "localhost", "port": deaf_port}
- ) as p:
+ with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p:
with pytest.raises(PoolTimeout):
with p.connection(timeout=1):
pass
ensure_waiting(p)
pids.append(conn.info.backend_pid)
- conn.pgconn.exec_(
- b"copy (select * from generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
assert conn.info.transaction_status == TransactionStatus.ACTIVE
p.putconn(conn)
t.join()
p.open()
-@pytest.mark.parametrize(
- "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
def test_bad_resize(dsn, min_size, max_size):
with NullConnectionPool() as p:
with pytest.raises(ValueError):
assert p.max_size == 2
-@pytest.mark.parametrize(
- "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
async def test_bad_size(dsn, min_size, max_size):
with pytest.raises(ValueError):
AsyncNullConnectionPool(min_size=min_size, max_size=max_size)
@pytest.mark.slow
async def test_setup_no_timeout(dsn, proxy):
with pytest.raises(PoolTimeout):
- async with AsyncNullConnectionPool(
- proxy.client_dsn, num_workers=1
- ) as p:
+ async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
await p.wait(0.2)
async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
async def worker(n):
t0 = time()
async with p.connection() as conn:
- cur = await conn.execute(
- "select pg_backend_pid() from pg_sleep(0.2)"
- )
+ cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
(pid,) = await cur.fetchone() # type: ignore[misc]
t1 = time()
results.append((n, t1 - t0, pid))
t0 = time()
try:
async with p.connection() as conn:
- cur = await conn.execute(
- "select pg_backend_pid() from pg_sleep(0.2)"
- )
+ cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
(pid,) = await cur.fetchone() # type: ignore[misc]
except PoolTimeout as e:
t1 = time()
timeout = 0.25 if n == 3 else None
try:
async with p.connection(timeout=timeout) as conn:
- cur = await conn.execute(
- "select pg_backend_pid() from pg_sleep(0.2)"
- )
+ cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
(pid,) = await cur.fetchone() # type: ignore[misc]
except PoolTimeout as e:
t1 = time()
await ensure_waiting(p)
pids.append(conn.info.backend_pid)
- conn.pgconn.exec_(
- b"copy (select * from generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
assert conn.info.transaction_status == TransactionStatus.ACTIVE
await p.putconn(conn)
await asyncio.gather(t)
await p.open()
-@pytest.mark.parametrize(
- "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
async def test_bad_resize(dsn, min_size, max_size):
async with AsyncNullConnectionPool() as p:
with pytest.raises(ValueError):
assert p.max_size == (max_size if max_size is not None else min_size)
-@pytest.mark.parametrize(
- "min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
def test_bad_size(dsn, min_size, max_size):
with pytest.raises(ValueError):
pool.ConnectionPool(min_size=min_size, max_size=max_size)
def test_kwargs(dsn):
- with pool.ConnectionPool(
- dsn, kwargs={"autocommit": True}, min_size=1
- ) as p:
+ with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, min_size=1) as p:
with p.connection() as conn:
assert conn.autocommit
@pytest.mark.slow
def test_setup_no_timeout(dsn, proxy):
with pytest.raises(pool.PoolTimeout):
- with pool.ConnectionPool(
- proxy.client_dsn, min_size=1, num_workers=1
- ) as p:
+ with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
p.wait(0.2)
with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
with pool.ConnectionPool(dsn, min_size=1) as p:
conn = p.getconn()
pid = conn.info.backend_pid
- conn.pgconn.exec_(
- b"copy (select * from generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
assert conn.info.transaction_status == TransactionStatus.ACTIVE
p.putconn(conn)
t1 = time()
results.append((n, t1 - t0))
- with pool.ConnectionPool(
- dsn, min_size=min_size, max_size=4, num_workers=3
- ) as p:
+ with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p:
p.wait(1.0)
results: List[Tuple[int, float]] = []
@pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)])
async def test_min_size_max_size(dsn, min_size, max_size):
- async with pool.AsyncConnectionPool(
- dsn, min_size=min_size, max_size=max_size
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=min_size, max_size=max_size) as p:
assert p.min_size == min_size
assert p.max_size == (max_size if max_size is not None else min_size)
-@pytest.mark.parametrize(
- "min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
async def test_bad_size(dsn, min_size, max_size):
with pytest.raises(ValueError):
pool.AsyncConnectionPool(min_size=min_size, max_size=max_size)
class MyConn(psycopg.AsyncConnection[Any]):
pass
- async with pool.AsyncConnectionPool(
- dsn, connection_class=MyConn, min_size=1
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, connection_class=MyConn, min_size=1) as p:
async with p.connection() as conn:
assert isinstance(conn, MyConn)
async def test_wait_ready(dsn, monkeypatch):
delay_connection(monkeypatch, 0.1)
with pytest.raises(pool.PoolTimeout):
- async with pool.AsyncConnectionPool(
- dsn, min_size=4, num_workers=1
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
await p.wait(0.3)
async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
async with conn.transaction():
await conn.execute("set default_transaction_read_only to on")
- async with pool.AsyncConnectionPool(
- dsn, min_size=1, configure=configure
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
await p.wait(timeout=1.0)
async with p.connection() as conn:
assert inits == 1
async def configure(conn):
await conn.execute("select 1")
- async with pool.AsyncConnectionPool(
- dsn, min_size=1, configure=configure
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
with pytest.raises(pool.PoolTimeout):
await p.wait(timeout=0.5)
async with conn.transaction():
await conn.execute("WAT")
- async with pool.AsyncConnectionPool(
- dsn, min_size=1, configure=configure
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
with pytest.raises(pool.PoolTimeout):
await p.wait(timeout=0.5)
async def worker(n):
t0 = time()
async with p.connection() as conn:
- cur = await conn.execute(
- "select pg_backend_pid() from pg_sleep(0.2)"
- )
+ cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
(pid,) = await cur.fetchone() # type: ignore[misc]
t1 = time()
results.append((n, t1 - t0, pid))
t0 = time()
try:
async with p.connection() as conn:
- cur = await conn.execute(
- "select pg_backend_pid() from pg_sleep(0.2)"
- )
+ cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
(pid,) = await cur.fetchone() # type: ignore[misc]
except pool.PoolTimeout as e:
t1 = time()
timeout = 0.25 if n == 3 else None
try:
async with p.connection(timeout=timeout) as conn:
- cur = await conn.execute(
- "select pg_backend_pid() from pg_sleep(0.2)"
- )
+ cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
(pid,) = await cur.fetchone() # type: ignore[misc]
except pool.PoolTimeout as e:
t1 = time()
async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
conn = await p.getconn()
pid = conn.info.backend_pid
- conn.pgconn.exec_(
- b"copy (select * from generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
assert conn.info.transaction_status == TransactionStatus.ACTIVE
await p.putconn(conn)
async def test_open_wait(dsn, monkeypatch):
delay_connection(monkeypatch, 0.1)
with pytest.raises(pool.PoolTimeout):
- p = pool.AsyncConnectionPool(
- dsn, min_size=4, num_workers=1, open=False
- )
+ p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False)
try:
await p.open(wait=True, timeout=0.3)
finally:
async def test_open_as_wait(dsn, monkeypatch):
delay_connection(monkeypatch, 0.1)
with pytest.raises(pool.PoolTimeout):
- async with pool.AsyncConnectionPool(
- dsn, min_size=4, num_workers=1
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
await p.open(wait=True, timeout=0.3)
async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
async with p.connection() as conn:
await conn.execute("select pg_sleep(0.1)")
- async with pool.AsyncConnectionPool(
- dsn, min_size=2, max_size=4, max_idle=0.2
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p:
await p.wait(5.0)
assert p.max_idle == 0.2
def test_jitter():
- rnds = [
- pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)
- ]
+ rnds = [pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)]
assert 27 <= min(rnds) <= 28
assert 35 < max(rnds) < 36
@pytest.mark.slow
@pytest.mark.timing
async def test_max_lifetime(dsn):
- async with pool.AsyncConnectionPool(
- dsn, min_size=1, max_lifetime=0.2
- ) as p:
+ async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
await asyncio.sleep(0.1)
pids = []
for i in range(5):
# Long query to make sure we have to wait on send
pgconn.send_query(
- b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
- % (b"x" * 1_000_000)
+ b"/* %s */ select pg_sleep(0.01); select 1 as foo;" % (b"x" * 1_000_000)
)
# send loop
def test_send_query_compact_test(pgconn):
# Like the above test but use psycopg facilities for compactness
pgconn.send_query(
- b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
- % (b"x" * 1_000_000)
+ b"/* %s */ select pg_sleep(0.01); select 1 as foo;" % (b"x" * 1_000_000)
)
results = execute_wait(pgconn)
pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
-@pytest.mark.parametrize(
- "fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]
-)
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
def test_send_prepared_binary_out(pgconn, fmt, out):
val = b"foo\00bar"
pgconn.send_prepare(b"", b"select $1::bytea")
(res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
- pgconn.send_query_prepared(
- b"", [val], param_formats=[1], result_format=fmt
- )
+ pgconn.send_query_prepared(b"", [val], param_formats=[1], result_format=fmt)
(res,) = execute_wait(pgconn)
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == out
"""
sample_binary_rows = [
- bytes.fromhex("".join(row.split()))
- for row in sample_binary_value.split("\n\n")
+ bytes.fromhex("".join(row.split())) for row in sample_binary_value.split("\n\n")
]
sample_binary = b"".join(sample_binary_rows)
@pytest.mark.parametrize("scs", ["on", "off"])
def test_escape_literal_1char(pgconn, scs):
- res = pgconn.exec_(
- f"set standard_conforming_strings to {scs}".encode("ascii")
- )
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
assert res.status == pq.ExecStatus.COMMAND_OK
esc = pq.Escaping(pgconn)
special = {b"'": b"''''", b"\\": b" E'\\\\'"}
@pytest.mark.parametrize("scs", ["on", "off"])
def test_escape_identifier_1char(pgconn, scs):
- res = pgconn.exec_(
- f"set standard_conforming_strings to {scs}".encode("ascii")
- )
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
assert res.status == pq.ExecStatus.COMMAND_OK
esc = pq.Escaping(pgconn)
special = {b'"': b'""""', b"\\": b'"\\"'}
@pytest.mark.parametrize("scs", ["on", "off"])
def test_escape_string_1char(pgconn, scs):
esc = pq.Escaping(pgconn)
- res = pgconn.exec_(
- f"set standard_conforming_strings to {scs}".encode("ascii")
- )
+ res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
assert res.status == pq.ExecStatus.COMMAND_OK
special = {b"'": b"''", b"\\": b"\\" if scs == "on" else b"\\\\"}
for c in range(1, 128):
data = bytes(range(256))
esc = pq.Escaping()
escdata = esc.escape_bytea(data)
- res = pgconn.exec_params(
- b"select '%s'::bytea" % escdata, [], result_format=1
- )
+ res = pgconn.exec_params(b"select '%s'::bytea" % escdata, [], result_format=1)
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == data
def test_exec_params_nulls(pgconn):
- res = pgconn.exec_params(
- b"select $1::text, $2::text, $3::text", [b"hi", b"", None]
- )
+ res = pgconn.exec_params(b"select $1::text, $2::text, $3::text", [b"hi", b"", None])
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == b"hi"
assert res.get_value(0, 1) == b""
pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
-@pytest.mark.parametrize(
- "fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]
-)
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
def test_exec_params_binary_out(pgconn, fmt, out):
val = b"foo\00bar"
res = pgconn.exec_params(
pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
-@pytest.mark.parametrize(
- "fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]
-)
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
def test_exec_prepared_binary_out(pgconn, fmt, out):
val = b"foo\00bar"
res = pgconn.prepare(b"", b"select $1::bytea")
assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
- res = pgconn.exec_prepared(
- b"", [val], param_formats=[1], result_format=fmt
- )
+ res = pgconn.exec_prepared(b"", [val], param_formats=[1], result_format=fmt)
assert res.status == pq.ExecStatus.TUPLES_OK
assert res.get_value(0, 0) == out
def test_connect_async_bad(dsn):
- parsed_dsn = {
- e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val
- }
+ parsed_dsn = {e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val}
parsed_dsn[b"dbname"] = b"psycopg_test_not_for_real"
dsn = b" ".join(b"%s='%s'" % item for item in parsed_dsn.items())
conn = pq.PGconn.connect_start(dsn)
tracef = tmp_path / "trace"
with tracef.open("w") as f:
pgconn.trace(f.fileno())
- pgconn.set_trace_flags(
- pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE
- )
+ pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
pgconn.exec_(b"select 1")
pgconn.untrace()
pgconn.exec_(b"select 2")
def test_error_field(pgconn):
res = pgconn.exec_(b"select wat")
- assert (
- res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED) == b"ERROR"
- )
+ assert res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED) == b"ERROR"
assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703"
assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
res.clear()
@pytest.mark.parametrize("n", range(4))
def test_ntuples(pgconn, n):
- res = pgconn.exec_params(
- b"select generate_series(1, $1)", [str(n).encode("ascii")]
- )
+ res = pgconn.exec_params(b"select generate_series(1, $1)", [str(n).encode("ascii")])
assert res.ntuples == n
res.clear()
assert res.ntuples == 0
assert pgconn.pipeline_status == pq.PipelineStatus.OFF
pgconn.enter_pipeline_mode()
pgconn.send_query_params(b"select $1", [b"1"])
- with pytest.raises(
- psycopg.OperationalError, match="cannot exit pipeline mode"
- ):
+ with pytest.raises(psycopg.OperationalError, match="cannot exit pipeline mode"):
pgconn.exit_pipeline_mode()
cur = cnn.cursor()
if test == "copy":
- with cur.copy(
- f"copy testdec from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy testdec from stdin (format {format.name})") as copy:
for j in range(nrows):
copy.write_row(
- [
- Decimal(randrange(10000000000)) / 100
- for i in range(ncols)
- ]
+ [Decimal(randrange(10000000000)) / 100 for i in range(ncols)]
)
elif test == "insert":
# Create and start all the threads: they will get stuck on the event
ev = threading.Event()
threads = [
- threading.Thread(
- target=worker, args=(pool, 0.002, ev), daemon=True
- )
+ threading.Thread(target=worker, args=(pool, 0.002, ev), daemon=True)
for i in range(opt.num_clients)
]
for t in threads:
self.measures = []
def start(self, interval):
- self.worker = threading.Thread(
- target=self._run, args=(interval,), daemon=True
- )
+ self.worker = threading.Thread(target=self._run, args=(interval,), daemon=True)
self.worker.start()
def stop(self):
from argparse import ArgumentParser
parser = ArgumentParser(description=__doc__)
- parser.add_argument(
- "--dsn", default="", help="connection string to the database"
- )
+ parser.add_argument("--dsn", default="", help="connection string to the database")
parser.add_argument(
"--min_size",
default=5,
def test_register_dumper_by_class_name(conn):
dumper = make_dumper("x")
assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
- conn.adapters.register_dumper(
- f"{MyStr.__module__}.{MyStr.__qualname__}", dumper
- )
+ conn.adapters.register_dumper(f"{MyStr.__module__}.{MyStr.__qualname__}", dumper)
assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
pass
cur = conn.cursor()
- cur.execute(
- "select %s::text, %b::text", [MyString("hello"), MyString("world")]
- )
+ cur.execute("select %s::text, %b::text", [MyString("hello"), MyString("world")])
assert cur.fetchone() == ("hello", "world")
def test_none_type_argument(conn, fmt_in):
cur = conn.cursor()
cur.execute("create table none_args (id serial primary key, num integer)")
- cur.execute(
- "insert into none_args (num) values (%s) returning id", (None,)
- )
+ cur.execute("insert into none_args (num) values (%s) returning id", (None,))
assert cur.fetchone()[0]
else:
# Binary types cannot be passed as unknown oids.
with pytest.raises(e.DatatypeMismatch):
- cur.execute(
- f"insert into testjson (data) values (%{fmt_in})", ["{}"]
- )
+ cur.execute(f"insert into testjson (data) values (%{fmt_in})", ["{}"])
@pytest.mark.parametrize("fmt_in", PyFormat.AUTO)
def test_identify_closure(dsn):
def closer():
time.sleep(0.2)
- conn2.execute(
- "select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]
- )
+ conn2.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
conn = psycopg.connect(dsn)
conn2 = psycopg.connect(dsn)
# Because of a bad status check, we commit even when a commit is already on
# its way. We can detect this condition from the warnings.
notices = Queue() # type: ignore[var-annotated]
- aconn.add_notice_handler(
- lambda diag: notices.put_nowait(diag.message_primary)
- )
+ aconn.add_notice_handler(lambda diag: notices.put_nowait(diag.message_primary))
stop = False
async def committer():
def test_broken(conn):
with pytest.raises(psycopg.OperationalError):
- conn.execute(
- "select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]
- )
+ conn.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
assert conn.closed
assert conn.broken
conn.close()
with pytest.raises(ZeroDivisionError):
with psycopg.connect(dsn) as conn:
- conn.pgconn.exec_(
- b"copy (select generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
status = conn.info.transaction_status
assert status == conn.TransactionStatus.ACTIVE
1 / 0
conn.add_notice_handler(cb1)
conn.add_notice_handler(cb2)
conn.add_notice_handler("the wrong thing")
- conn.add_notice_handler(
- lambda diag: severities.append(diag.severity_nonlocalized)
- )
+ conn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
conn.pgconn.exec_(b"set client_min_messages to notice")
cur = conn.cursor()
- cur.execute(
- "do $$begin raise notice 'hello notice'; end$$ language plpgsql"
- )
+ cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
assert messages == ["hello notice"]
assert severities == ["NOTICE"]
conn.remove_notice_handler(cb1)
conn.remove_notice_handler("the wrong thing")
- cur.execute(
- "do $$begin raise warning 'hello warning'; end$$ language plpgsql"
- )
+ cur.execute("do $$begin raise warning 'hello warning'; end$$ language plpgsql")
assert len(caplog.records) == 3
assert messages == ["hello notice"]
assert severities == ["NOTICE", "WARNING"]
with pytest.raises(ZeroDivisionError):
async with await psycopg.AsyncConnection.connect(dsn) as conn:
- conn.pgconn.exec_(
- b"copy (select generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
status = conn.info.transaction_status
assert status == conn.TransactionStatus.ACTIVE
1 / 0
aconn.add_notice_handler(cb1)
aconn.add_notice_handler(cb2)
aconn.add_notice_handler("the wrong thing")
- aconn.add_notice_handler(
- lambda diag: severities.append(diag.severity_nonlocalized)
- )
+ aconn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
aconn.pgconn.exec_(b"set client_min_messages to notice")
cur = aconn.cursor()
- await cur.execute(
- "do $$begin raise notice 'hello notice'; end$$ language plpgsql"
- )
+ await cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
assert messages == ["hello notice"]
assert severities == ["NOTICE"]
for attr in tx_params:
guc = tx_params[attr]["guc"]
- cur = await aconn.execute(
- "select current_setting(%s)", [f"transaction_{guc}"]
- )
+ cur = await aconn.execute("select current_setting(%s)", [f"transaction_{guc}"])
pgval = (await cur.fetchone())[0]
assert tx_values_map[pgval] == value
monkeypatch.setenv("PGAPPNAME", "hello test")
with psycopg.connect(**dsn) as conn:
- assert (
- conn.info.get_parameters()["application_name"] == "hello test"
- )
+ assert conn.info.get_parameters()["application_name"] == "hello test"
def test_dsn_env(self, dsn, monkeypatch):
dsn = conninfo_to_dict(dsn)
with pytest.raises(psycopg.OperationalError):
conn.info.backend_pid
- @pytest.mark.skipif(
- sys.platform == "win32", reason="no IANA db on Windows"
- )
+ @pytest.mark.skipif(sys.platform == "win32", reason="no IANA db on Windows")
def test_timezone(self, conn):
conn.execute("set timezone to 'Europe/Rome'")
tz = conn.info.timezone
"""
sample_binary_rows = [
- bytes.fromhex("".join(row.split()))
- for row in sample_binary_str.split("\n\n")
+ bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
]
sample_binary = b"".join(sample_binary_rows)
want = sample_binary_rows
cur = conn.cursor()
- with cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
for row in want:
got = copy.read()
assert got == row
- assert (
- conn.info.transaction_status == conn.TransactionStatus.ACTIVE
- )
+ assert conn.info.transaction_status == conn.TransactionStatus.ACTIVE
assert copy.read() == b""
assert copy.read() == b""
want = sample_binary_rows
cur = conn.cursor()
- with cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
assert list(copy) == want
assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
@pytest.mark.parametrize("format", Format)
def test_rows(conn, format):
cur = conn.cursor()
- with cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
copy.set_types(["int4", "int4", "text"])
rows = list(copy.rows())
chars = list(map(chr, range(1, 256))) + [eur]
conn.execute("set client_encoding to utf8")
rows = []
- query = sql.SQL(
- "copy (select unnest({}::text[])) to stdout (format {})"
- ).format(chars, sql.SQL(format.name))
+ query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+ chars, sql.SQL(format.name)
+ )
with cur.copy(query) as copy:
copy.set_types(["text"])
while 1:
@pytest.mark.parametrize("format", Format)
def test_read_row_notypes(conn, format):
cur = conn.cursor()
- with cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
rows = []
while 1:
row = copy.read_row()
break
rows.append(row)
- ref = [
- tuple(py_to_raw(i, format) for i in record)
- for record in sample_records
- ]
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
assert rows == ref
@pytest.mark.parametrize("format", Format)
def test_rows_notypes(conn, format):
cur = conn.cursor()
- with cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
rows = list(copy.rows())
- ref = [
- tuple(py_to_raw(i, format) for i in record)
- for record in sample_records
- ]
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
assert rows == ref
@pytest.mark.parametrize("format", Format)
def test_copy_out_badntypes(conn, format, err):
cur = conn.cursor()
- with cur.copy(
- f"copy ({sample_values}) to stdout (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
copy.set_types([0] * (len(sample_records[0]) + err))
with pytest.raises(e.ProgrammingError):
copy.read_row()
cur = conn.cursor()
ensure_table(cur, sample_tabledef)
- with cur.copy(
- f"copy copy_in (data) from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
copy.write_row(("hello",))
rec = cur.execute("select data from copy_in").fetchone()
def test_copy_out_error_with_copy_not_finished(conn):
cur = conn.cursor()
with pytest.raises(ZeroDivisionError):
- with cur.copy(
- "copy (select generate_series(1, 1000000)) to stdout"
- ) as copy:
+ with cur.copy("copy (select generate_series(1, 1000000)) to stdout") as copy:
copy.read_row()
1 / 0
gc_collect()
n.append(len(gc.get_objects()))
- assert (
- n[0] == n[1] == n[2]
- ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
@pytest.mark.slow
gc_collect()
n.append(len(gc.get_objects()))
- assert (
- n[0] == n[1] == n[2]
- ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
def py_to_raw(item, fmt):
for row in want:
got = await copy.read()
assert got == row
- assert (
- aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE
- )
+ assert aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE
assert await copy.read() == b""
assert await copy.read() == b""
chars = list(map(chr, range(1, 256))) + [eur]
await aconn.execute("set client_encoding to utf8")
rows = []
- query = sql.SQL(
- "copy (select unnest({}::text[])) to stdout (format {})"
- ).format(chars, sql.SQL(format.name))
+ query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+ chars, sql.SQL(format.name)
+ )
async with cur.copy(query) as copy:
copy.set_types(["text"])
while 1:
break
rows.append(row)
- ref = [
- tuple(py_to_raw(i, format) for i in record)
- for record in sample_records
- ]
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
assert rows == ref
f"copy ({sample_values}) to stdout (format {format.name})"
) as copy:
rows = await alist(copy.rows())
- ref = [
- tuple(py_to_raw(i, format) for i in record)
- for record in sample_records
- ]
+ ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
assert rows == ref
async def test_copy_in_buffers(aconn, format, buffer):
cur = aconn.cursor()
await ensure_table(cur, sample_tabledef)
- async with cur.copy(
- f"copy copy_in from stdin (format {format.name})"
- ) as copy:
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
await copy.write(globals()[buffer])
await cur.execute("select * from copy_in order by 1")
async def test_copy_out_error_with_copy_finished(aconn):
cur = aconn.cursor()
with pytest.raises(ZeroDivisionError):
- async with cur.copy(
- "copy (select generate_series(1, 2)) to stdout"
- ) as copy:
+ async with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
await copy.read_row()
1 / 0
cur = aconn.cursor()
await ensure_table(cur, sample_tabledef)
- async with cur.copy(
- f"copy copy_in from stdin (format {format.name})"
- ) as copy:
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
for row in sample_records:
if format == Format.BINARY:
row = tuple(
cur = aconn.cursor()
await ensure_table(cur, sample_tabledef)
- async with cur.copy(
- f"copy copy_in from stdin (format {format.name})"
- ) as copy:
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
copy.set_types(["int4", "int4", "text"])
for row in sample_records:
await copy.write_row(row)
async def test_worker_life(aconn, format, buffer):
cur = aconn.cursor()
await ensure_table(cur, sample_tabledef)
- async with cur.copy(
- f"copy copy_in from stdin (format {format.name})"
- ) as copy:
+ async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
assert not copy._worker
await copy.write(globals()[buffer])
assert copy._worker
gc_collect()
n.append(len(gc.get_objects()))
- assert (
- n[0] == n[1] == n[2]
- ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
@pytest.mark.slow
gc_collect()
n.append(len(gc.get_objects()))
- assert (
- n[0] == n[1] == n[2]
- ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
async def ensure_table(cur, tabledef, name="copy_in"):
assert cur.rowcount == 0
cur.executemany("delete from execmany where id = %s", [])
assert cur.rowcount == 0
- cur.executemany(
- "delete from execmany where id = %s returning num", [(-1,), (-2,)]
- )
+ cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)])
assert cur.rowcount == 0
cur.execute("create table test_rowcount_notuples (id int primary key)")
assert cur.rowcount == -1
- cur.execute(
- "insert into test_rowcount_notuples select generate_series(1, 42)"
- )
+ cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)")
assert cur.rowcount == 42
@pytest.mark.parametrize("fmt", PyFormat)
@pytest.mark.parametrize("fmt_out", pq.Format)
@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
-@pytest.mark.parametrize(
- "row_factory", ["tuple_row", "dict_row", "namedtuple_row"]
-)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
faker.format = fmt
faker.choose_schema(ncols=5)
work()
gc_collect()
n.append(len(gc.get_objects()))
- assert (
- n[0] == n[1] == n[2]
- ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
def my_row_factory(
titles = [c.name for c in cursor.description]
def mkrow(values):
- return [
- f"{value.upper()}{title}"
- for title, value in zip(titles, values)
- ]
+ return [f"{value.upper()}{title}" for title, value in zip(titles, values)]
return mkrow
else:
async def test_execute_sequence(aconn):
cur = aconn.cursor()
- rv = await cur.execute(
- "select %s::int, %s::text, %s::text", [1, "foo", None]
- )
+ rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
assert rv is cur
assert len(cur._results) == 1
assert cur.pgresult.get_value(0, 0) == b"1"
async def test_executemany_no_data(aconn, execmany):
cur = aconn.cursor()
- await cur.executemany(
- "insert into execmany(num, data) values (%s, %s)", []
- )
+ await cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
assert cur.rowcount == 0
await cur.execute("select 1 from generate_series(1, 42)")
assert cur.rowcount == 42
- await cur.execute(
- "create table test_rowcount_notuples (id int primary key)"
- )
+ await cur.execute("create table test_rowcount_notuples (id int primary key)")
assert cur.rowcount == -1
await cur.execute(
@pytest.mark.parametrize("fmt", PyFormat)
@pytest.mark.parametrize("fmt_out", pq.Format)
@pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
-@pytest.mark.parametrize(
- "row_factory", ["tuple_row", "dict_row", "namedtuple_row"]
-)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
faker.format = fmt
faker.choose_schema(ncols=5)
async def work():
async with await psycopg.AsyncConnection.connect(dsn) as conn:
- async with conn.cursor(
- binary=fmt_out, row_factory=row_factory
- ) as cur:
+ async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
await cur.execute(faker.drop_stmt)
await cur.execute(faker.create_stmt)
async with faker.find_insert_problem_async(conn):
gc_collect()
n.append(len(gc.get_objects()))
- assert (
- n[0] == n[1] == n[2]
- ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+ assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
],
)
@pytest.mark.asyncio
-async def test_resolve_hostaddr_async_bad(
- monkeypatch, conninfo, env, fake_resolve
-):
+async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resolve):
if env:
for k, v in env.items():
monkeypatch.setenv(k, v)
params = conninfo_to_dict(conninfo)
with pytest.raises(psycopg.Error):
- await psycopg._dns.resolve_hostaddr_async( # type: ignore[attr-defined]
- params
- )
+ await psycopg._dns.resolve_hostaddr_async(params) # type: ignore[attr-defined]
@pytest.mark.asyncio
got.append(conninfo)
1 / 0
- monkeypatch.setattr(
- psycopg.AsyncConnection, "_connect_gen", fake_connect_gen
- )
+ monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
# TODO: not enabled by default, but should be usable to make a subclass
class AsyncDnsConnection(psycopg.AsyncConnection[Row]):
else:
for entry in ans:
pri, w, port, target = entry.split()
- rv.append(
- SRV("IN", "SRV", int(pri), int(w), int(port), target)
- )
+ rv.append(SRV("IN", "SRV", int(pri), int(w), int(port), target))
return rv
conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary))
conn.execute(f"set client_encoding to {enc}")
cur = conn.cursor()
- cur.execute(
- "do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql"
- )
+ cur.execute("do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql")
assert msgs == [f"hello {eur}"]
def test_no_prepare_multi(conn):
res = []
for i in range(10):
- cur = conn.execute(
- "select count(*) from pg_prepared_statements; select 1"
- )
+ cur = conn.execute("select count(*) from pg_prepared_statements; select 1")
res.append(cur.fetchone()[0])
assert res == [0] * 10
conn.execute("create table prepared_test (num int)", prepare=False)
conn.prepare_threshold = 0
conn.execute(query)
- cur = conn.execute(
- "select count(*) from pg_prepared_statements", prepare=False
- )
+ cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False)
assert cur.fetchone() == (1,)
{"enum_col": ["foo"]},
)
- cur = conn.execute(
- "select count(*) from pg_prepared_statements", prepare=False
- )
+ cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False)
assert cur.fetchone()[0] == 3
for i in range(3):
with pytest.raises(ZeroDivisionError):
with conn.transaction():
- conn.execute(
- "CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')"
- )
- conn.execute(
- "CREATE TABLE preptable(id integer, bar prepenum[])"
- )
+ conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+ conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
conn.cursor().execute(
"INSERT INTO preptable (bar) "
"VALUES (%(enum_col)s::prepenum[])",
async def test_auto_prepare_conn(aconn):
res = []
for i in range(10):
- cur = await aconn.execute(
- "select count(*) from pg_prepared_statements"
- )
+ cur = await aconn.execute("select count(*) from pg_prepared_statements")
res.append((await cur.fetchone())[0])
assert res == [0] * 5 + [1] * 5
aconn.prepare_threshold = None
res = []
for i in range(10):
- cur = await aconn.execute(
- "select count(*) from pg_prepared_statements"
- )
+ cur = await aconn.execute("select count(*) from pg_prepared_statements")
res.append((await cur.fetchone())[0])
assert res == [0] * 10
[dt.date(2020, 12, 10), 42, Decimal(42)],
prepare=True,
)
- cur = await aconn.execute(
- "select parameter_types from pg_prepared_statements"
- )
+ cur = await aconn.execute("select parameter_types from pg_prepared_statements")
(rec,) = await cur.fetchall()
assert rec[0] == ["date", "smallint", "numeric"]
"select statement from pg_prepared_statements order by prepare_time",
prepare=False,
)
- assert await cur.fetchall() == [
- (f"select {i}",) for i in ["'a'", 6, 7, 8, 9]
- ]
+ assert await cur.fetchall() == [(f"select {i}",) for i in ["'a'", 6, 7, 8, 9]]
async def test_different_types(aconn):
for i in range(2):
await aconn.execute("insert into testjson (data) values (%s)", ["{}"])
- cur = await aconn.execute(
- "select parameter_types from pg_prepared_statements"
- )
+ cur = await aconn.execute("select parameter_types from pg_prepared_statements")
assert await cur.fetchall() == [(["jsonb"],)]
# Shut up warnings
PsycopgTests.failUnless = PsycopgTests.assertTrue # type: ignore[assignment]
-PsycopgTPCTests.assertEquals = ( # type: ignore[assignment]
- PsycopgTPCTests.assertEqual
-)
+PsycopgTPCTests.assertEquals = PsycopgTPCTests.assertEqual # type: ignore[assignment]
@pytest.mark.parametrize(
cur.close()
assert cur.closed
- assert not conn.execute(
- "select * from pg_cursors where name = 'foo'"
- ).fetchone()
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
del cur
assert not recwarn, [str(w.message) for w in recwarn.list]
cur.execute("select generate_series(1, 10) as bar")
assert cur.closed
- assert not conn.execute(
- "select * from pg_cursors where name = 'foo'"
- ).fetchone()
+ assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
del cur
assert not recwarn, [str(w.message) for w in recwarn.list]
cur.execute("select generate_series(1, %s) as foo", (3,))
assert cur.fetchone() == (1,)
- cur.execute(
- "select %s::text as bar, %s::text as baz", ("hello", "world")
- )
+ cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
assert cur.fetchone() == ("hello", "world")
assert cur.description[0].name == "bar"
assert cur.description[0].type_code == cur.adapters.types["text"].oid
await cur.execute("select generate_series(1, %s) as foo", (3,))
assert await cur.fetchone() == (1,)
- await cur.execute(
- "select %s::text as bar, %s::text as baz", ("hello", "world")
- )
+ await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
assert await cur.fetchone() == ("hello", "world")
assert cur.description[0].name == "bar"
assert cur.description[0].type_code == cur.adapters.types["text"].oid
async def test_no_result(aconn):
async with aconn.cursor("foo") as cur:
- await cur.execute(
- "select generate_series(1, %s) as bar where false", (3,)
- )
+ await cur.execute("select generate_series(1, %s) as bar where false", (3,))
assert len(cur.description) == 1
assert (await cur.fetchall()) == []
s.as_string(conn)
def test_auto_literal(self, conn):
- s = sql.SQL("select {}, {}, {}").format(
- "he'lo", 10, dt.date(2020, 1, 1)
- )
+ s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1))
assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'"
def test_execute(self, conn):
cur.execute(
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
sql.Identifier("test_compose"),
- sql.SQL(", ").join(
- map(sql.Identifier, ["foo", "bar", "ba'z"])
- ),
+ sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
(sql.Placeholder() * 3).join(", "),
),
(10, "a", "b", "c"),
cur.executemany(
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
sql.Identifier("test_compose"),
- sql.SQL(", ").join(
- map(sql.Identifier, ["foo", "bar", "ba'z"])
- ),
+ sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
(sql.Placeholder() * 3).join(", "),
),
[(10, "a", "b", "c"), (20, "d", "e", "f")],
assert sql.Literal(None).as_string(conn) == "NULL"
assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
assert sql.Literal(42).as_string(conn) == "42"
- assert (
- sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
- )
+ assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
def test_as_bytes(self, conn):
assert sql.Literal(None).as_bytes(conn) == b"NULL"
assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
assert sql.Literal(42).as_bytes(conn) == b"42"
- assert (
- sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'"
- )
+ assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'"
conn.execute("set client_encoding to utf8")
assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode()
assert obj.as_string(conn) == '"foo", bar, 42'
obj = sql.SQL(", ").join(
- sql.Composed(
- [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
- )
+ sql.Composed([sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)])
)
assert isinstance(obj, sql.Composed)
assert obj.as_string(conn) == '"foo", bar, 42'
def test_repr(self):
obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
- assert (
- repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])"""
- )
+ assert repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])"""
assert str(obj) == repr(obj)
def test_eq(self):
conn.close()
with psycopg.connect(dsn) as conn:
- xids = [
- x for x in conn.tpc_recover() if x.database == conn.info.dbname
- ]
+ xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
assert len(xids) == 1
xid = xids[0]
conn.close()
with psycopg.connect(dsn) as conn:
- xids = [
- x for x in conn.tpc_recover() if x.database == conn.info.dbname
- ]
+ xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
assert len(xids) == 1
xid = xids[0]
conn.close()
with psycopg.connect(dsn) as conn:
- xid = [
- x for x in conn.tpc_recover() if x.database == conn.info.dbname
- ][0]
+ xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
assert 10 == xid.format_id
assert "uni" == xid.gtrid
assert "code" == xid.bqual
conn.close()
with psycopg.connect(dsn) as conn:
- xid = [
- x for x in conn.tpc_recover() if x.database == conn.info.dbname
- ][0]
+ xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
assert xid.format_id is None
assert xid.gtrid == "transaction-id"
assert aconn.info.transaction_status == TransactionStatus.INTRANS
cur = aconn.cursor()
- await cur.execute(
- "insert into test_tpc values ('test_tpc_commit_rec')"
- )
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
assert tpc.count_xacts() == 0
assert tpc.count_test_records() == 0
assert aconn.info.transaction_status == TransactionStatus.INTRANS
cur = aconn.cursor()
- await cur.execute(
- "insert into test_tpc values ('test_tpc_rollback_1p')"
- )
+ await cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
assert tpc.count_xacts() == 0
assert tpc.count_test_records() == 0
assert aconn.info.transaction_status == TransactionStatus.INTRANS
cur = aconn.cursor()
- await cur.execute(
- "insert into test_tpc values ('test_tpc_commit_rec')"
- )
+ await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
assert tpc.count_xacts() == 0
assert tpc.count_test_records() == 0
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
xids = [
- x
- for x in await aconn.tpc_recover()
- if x.database == aconn.info.dbname
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
]
assert len(xids) == 1
xid = xids[0]
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
xids = [
- x
- for x in await aconn.tpc_recover()
- if x.database == aconn.info.dbname
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
]
assert len(xids) == 1
xid = xids[0]
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
xid = [
- x
- for x in await aconn.tpc_recover()
- if x.database == aconn.info.dbname
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
][0]
assert 10 == xid.format_id
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
xid = [
- x
- for x in await aconn.tpc_recover()
- if x.database == aconn.info.dbname
+ x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
][0]
assert xid.format_id is None
try:
with pytest.raises(ZeroDivisionError):
with conn.transaction():
- conn.pgconn.exec_(
- b"copy (select generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
status = conn.info.transaction_status
assert status == conn.TransactionStatus.ACTIVE
1 / 0
try:
with pytest.raises(ZeroDivisionError):
async with conn.transaction():
- conn.pgconn.exec_(
- b"copy (select generate_series(1, 10)) to stdout"
- )
+ conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
status = conn.info.transaction_status
assert status == conn.TransactionStatus.ACTIVE
1 / 0
assert aconn.autocommit is autocommit
-async def test_autocommit_off_but_no_tx_started_successful_exit(
- aconn, svcconn
-):
+async def test_autocommit_off_but_no_tx_started_successful_exit(aconn, svcconn):
"""
Scenario:
* Connection has autocommit off but no transaction has been initiated
assert not inserted(svcconn)
-async def test_autocommit_off_and_tx_in_progress_successful_exit(
- aconn, svcconn
-):
+async def test_autocommit_off_and_tx_in_progress_successful_exit(aconn, svcconn):
"""
Scenario:
* Connection has autocommit off and a transaction is already in
assert not inserted(svcconn)
-async def test_autocommit_off_and_tx_in_progress_exception_exit(
- aconn, svcconn
-):
+async def test_autocommit_off_and_tx_in_progress_exception_exit(aconn, svcconn):
"""
Scenario:
* Connection has autocommit off and a transaction is already in
assert not inserted(svcconn)
-async def test_nested_inner_scope_exception_handled_in_outer_scope(
- aconn, svcconn
-):
+async def test_nested_inner_scope_exception_handled_in_outer_scope(aconn, svcconn):
"""
An exception escaping the inner transaction context causes changes made
within that inner context to be discarded, but the error can then be
assert not await inserted(aconn)
-async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(
- aconn, svcconn
-):
+async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(aconn, svcconn):
"""
Rolling back an enclosing transaction does not impact an outer transaction.
"""
def _test_reveal(stmts, type, mypy):
- ignore = (
- "" if type.startswith("Optional") else "# type: ignore[assignment]"
- )
+ ignore = "" if type.startswith("Optional") else "# type: ignore[assignment]"
stmts = "\n".join(f" {line}" for line in stmts.splitlines())
src = f"""\
assert res[1] == [val]
-@pytest.mark.parametrize(
- "array, type", [([1, 32767], "int2"), ([1, 32768], "int4")]
-)
+@pytest.mark.parametrize("array, type", [([1, 32767], "int2"), ([1, 32768], "int4")])
def test_array_mixed_numbers(array, type):
tx = Transformer()
dumper = tx.get_dumper(array, PyFormat.BINARY)
assert dumper.oid == builtins[type].array_oid
-@pytest.mark.parametrize(
- "wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split()
-)
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split())
@pytest.mark.parametrize("fmt_in", PyFormat)
@pytest.mark.parametrize("fmt_out", pq.Format)
def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
def test_empty_list_after_choice(conn, fmt_in):
cur = conn.cursor()
cur.execute("create table test (id serial primary key, data float[])")
- cur.executemany(
- f"insert into test (data) values (%{fmt_in})", [([1.0],), ([],)]
- )
+ cur.executemany(f"insert into test (data) values (%{fmt_in})", [([1.0],), ([],)])
cur.execute("select data from test order by id")
assert cur.fetchall() == [([1.0],), ([],)]
def test_quote_bool(conn, val):
tx = Transformer()
- assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(
- val
- ).lower().encode("ascii")
+ assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(val).lower().encode(
+ "ascii"
+ )
cur = conn.cursor()
cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0]
assert res == (chr(i),)
- cur.execute(
- "select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256))
- )
+ cur.execute("select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256)))
res = cur.fetchone()[0]
assert res == tuple(map(chr, range(1, 256)))
(b"foo'", b"'foo", b'"bar', b'bar"'),
),
(
- "10::int, null::text, 20::float,"
- " null::text, 'foo'::text, 'bar'::bytea ",
+ "10::int, null::text, 20::float, null::text, 'foo'::text, 'bar'::bytea ",
(10, None, 20.0, None, "foo", b"bar"),
),
],
assert res.baz == 20.0
assert isinstance(res.baz, float)
- res = cur.execute(
- "select array[row('hello', 10, 30)::testcomp]"
- ).fetchone()[0]
+ res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
assert len(res) == 1
assert res[0].baz == 30.0
assert isinstance(res[0].baz, float)
assert res.baz == 20.0
assert isinstance(res.baz, float)
- res = cur.execute(
- "select array[row('hello', 10, 30)::testcomp]"
- ).fetchone()[0]
+ res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
assert len(res) == 1
assert res[0].baz == 30.0
assert isinstance(res[0].baz, float)
cur.execute(f"select '{expr}'::date")
assert cur.fetchone()[0] == as_date(val)
- @pytest.mark.parametrize(
- "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
- )
+ @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
def test_load_date_datestyle(self, conn, datestyle_out):
cur = conn.cursor(binary=False)
cur.execute(f"set datestyle = {datestyle_out}, YMD")
assert cur.fetchone()[0] == dt.date(2000, 1, 2)
@pytest.mark.parametrize("val", ["min", "max"])
- @pytest.mark.parametrize(
- "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
- )
+ @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
def test_load_date_overflow(self, conn, val, datestyle_out):
cur = conn.cursor(binary=False)
cur.execute(f"set datestyle = {datestyle_out}, YMD")
- cur.execute(
- "select %t + %s::int", (as_date(val), -1 if val == "min" else 1)
- )
+ cur.execute("select %t + %s::int", (as_date(val), -1 if val == "min" else 1))
with pytest.raises(DataError):
cur.fetchone()[0]
@pytest.mark.parametrize("val", ["min", "max"])
def test_load_date_overflow_binary(self, conn, val):
cur = conn.cursor(binary=True)
- cur.execute(
- "select %s + %s::int", (as_date(val), -1 if val == "min" else 1)
- )
+ cur.execute("select %s + %s::int", (as_date(val), -1 if val == "min" else 1))
with pytest.raises(DataError):
cur.fetchone()[0]
]
@pytest.mark.parametrize("val, expr", load_datetime_samples)
- @pytest.mark.parametrize(
- "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
- )
+ @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
@pytest.mark.parametrize("datestyle_in", ["DMY", "MDY", "YMD"])
def test_load_datetime(self, conn, val, expr, datestyle_in, datestyle_out):
cur = conn.cursor(binary=False)
assert cur.fetchone()[0] == as_dt(val)
@pytest.mark.parametrize("val", ["min", "max"])
- @pytest.mark.parametrize(
- "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
- )
+ @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
def test_load_datetime_overflow(self, conn, val, datestyle_out):
cur = conn.cursor(binary=False)
cur.execute(f"set datestyle = {datestyle_out}, YMD")
@pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")])
@pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"])
@pytest.mark.parametrize("datestyle_in", ["DMY", "MDY", "YMD"])
- def test_load_datetimetz_tzname(
- self, conn, val, expr, datestyle_in, datestyle_out
- ):
+ def test_load_datetimetz_tzname(self, conn, val, expr, datestyle_in, datestyle_out):
cur = conn.cursor(binary=False)
cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
cur.execute("set timezone to '-02:00'")
"SELECT %s::text, %s::text", [date(2020, 12, 31), date.max]
).fetchone()
assert rec == ("2020-12-31", "infinity")
- rec = cur.execute(
- "select '2020-12-31'::date, 'infinity'::date"
- ).fetchone()
+ rec = cur.execute("select '2020-12-31'::date, 'infinity'::date").fetchone()
assert rec == (date(2020, 12, 31), date(9999, 12, 31))
def test_load_copy(self, conn):
def as_date(s):
- return (
- dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s)
- )
+ return dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s)
def as_time(s):
wrapper = getattr(multirange, wrapper)
mr1 = Multirange() # type: ignore[var-annotated]
mr2 = Multirange([Range(bounds="()")]) # type: ignore[var-annotated]
- cur = conn.execute(
- f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in}""", ([mr1, mr2],)
- )
+ cur = conn.execute(f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in}""", ([mr1, mr2],))
assert cur.fetchone()[0] is True
@pytest.mark.parametrize("format", pq.Format)
def test_copy_in(conn, min, max, bounds, format):
cur = conn.cursor()
- cur.execute(
- "create table copymr (id serial primary key, mr datemultirange)"
- )
+ cur.execute("create table copymr (id serial primary key, mr datemultirange)")
if bounds != "empty":
min = dt.date(*map(int, min.split(","))) if min else None
mr = Multirange([r])
try:
- with cur.copy(
- f"copy copymr (mr) from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
copy.write_row([mr])
except e.InternalError_:
if not min and not max and format == pq.Format.BINARY:
- pytest.xfail(
- "TODO: add annotation to dump multirange with no type info"
- )
+ pytest.xfail("TODO: add annotation to dump multirange with no type info")
else:
raise
@pytest.mark.parametrize("format", pq.Format)
def test_copy_in_empty_wrappers(conn, wrapper, format):
cur = conn.cursor()
- cur.execute(
- "create table copymr (id serial primary key, mr datemultirange)"
- )
+ cur.execute("create table copymr (id serial primary key, mr datemultirange)")
mr = getattr(multirange, wrapper)()
- with cur.copy(
- f"copy copymr (mr) from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
copy.write_row([mr])
rec = cur.execute("select mr from copymr order by id").fetchone()
mr = Multirange() # type: ignore[var-annotated]
- with cur.copy(
- f"copy copymr (mr) from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
copy.set_types([pgtype])
copy.write_row([mr])
@pytest.mark.parametrize("val", ["192.168.0.1", "2001:db8::"])
def test_address_dump(conn, fmt_in, val):
cur = conn.cursor()
- cur.execute(
- f"select %{fmt_in} = %s::inet", (ipaddress.ip_address(val), val)
- )
+ cur.execute(f"select %{fmt_in} = %s::inet", (ipaddress.ip_address(val), val))
assert cur.fetchone()[0] is True
cur.execute(
f"select %{fmt_in} = array[null, %s]::inet[]",
@pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
def test_network_dump(conn, fmt_in, val):
cur = conn.cursor()
- cur.execute(
- f"select %{fmt_in} = %s::cidr", (ipaddress.ip_network(val), val)
- )
+ cur.execute(f"select %{fmt_in} = %s::cidr", (ipaddress.ip_network(val), val))
assert cur.fetchone()[0] is True
cur.execute(
f"select %{fmt_in} = array[NULL, %s]::cidr[]",
def test_dump_float_approx(conn, val, expr):
assert isinstance(val, float)
cur = conn.cursor()
- cur.execute(
- f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,)
- )
+ cur.execute(f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,))
assert cur.fetchone()[0] is True
- cur.execute(
- f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,)
- )
+ cur.execute(f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,))
assert cur.fetchone()[0] is True
for f in funcs:
expr = f(i)
val = Decimal(expr)
- cur.execute(
- f"select %{fmt_in}::text, %s::decimal::text", [val, expr]
- )
+ cur.execute(f"select %{fmt_in}::text, %s::decimal::text", [val, expr])
want, got = cur.fetchone()
assert got == want
wrapper = getattr(range_module, wrapper)
r1 = wrapper(empty=True)
r2 = wrapper(bounds="()")
- cur = conn.execute(
- f"""select '{{empty,"(,)"}}' = %{fmt_in}""", ([r1, r2],)
- )
+ cur = conn.execute(f"""select '{{empty,"(,)"}}' = %{fmt_in}""", ([r1, r2],))
assert cur.fetchone()[0] is True
r1 = Range(empty=True) # type: ignore[var-annotated]
r2 = Range(bounds="()") # type: ignore[var-annotated]
cur = conn.cursor(binary=fmt_out)
- (got,) = cur.execute(
- f"select array['empty'::{pgtype}, '(,)'::{pgtype}]"
- ).fetchone()
+ (got,) = cur.execute(f"select array['empty'::{pgtype}, '(,)'::{pgtype}]").fetchone()
assert got == [r1, r2]
r = Range(min, max, bounds) # type: ignore[var-annotated]
sub = type2sub[pgtype]
cur = conn.cursor(binary=fmt_out)
- cur.execute(
- f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
- )
+ cur.execute(f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds))
# normalise discrete ranges
if r.upper_inc and isinstance(r.upper, int):
bounds = "[)" if r.lower_inc else "()"
r = Range(empty=True)
try:
- with cur.copy(
- f"copy copyrange (r) from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
copy.write_row([r])
except e.ProtocolViolation:
if not min and not max and format == pq.Format.BINARY:
- pytest.xfail(
- "TODO: add annotation to dump ranges with no type info"
- )
+ pytest.xfail("TODO: add annotation to dump ranges with no type info")
else:
raise
cls = getattr(range_module, wrapper)
r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds)
- with cur.copy(
- f"copy copyrange (r) from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
copy.write_row([r])
rec = cur.execute("select r from copyrange order by id").fetchone()
else:
r = Range(None, None, bounds)
- with cur.copy(
- f"copy copyrange (r) from stdin (format {format.name})"
- ) as copy:
+ with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
copy.set_types([pgtype])
copy.write_row([r])
def test_no_adapter(conn):
point = Point(1.2, 3.4)
- with pytest.raises(
- psycopg.ProgrammingError, match="cannot adapt type 'Point'"
- ):
+ with pytest.raises(psycopg.ProgrammingError, match="cannot adapt type 'Point'"):
conn.execute("SELECT pg_typeof(%s)", [point]).fetchone()[0]
SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
assert (
- shapely_conn.execute(
- "SELECT pg_typeof(%s)", [SAMPLE_POINT]
- ).fetchone()[0]
+ shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POINT]).fetchone()[0]
== "geometry"
)
assert (
- shapely_conn.execute(
- "SELECT pg_typeof(%s)", [SAMPLE_POLYGON]
- ).fetchone()[0]
+ shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POLYGON]).fetchone()[0]
== "geometry"
)
assert cur.fetchone()[0] is True
-@pytest.mark.parametrize(
- "typename", ["text", "varchar", "name", "bpchar", '"char"']
-)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar", '"char"'])
@pytest.mark.parametrize("fmt_out", pq.Format)
def test_load_1char(conn, typename, fmt_out):
cur = conn.cursor(binary=fmt_out)
got = (got_maj, got_min, got_fix)
# Parse a spec like "> 9.6"
- m = re.match(
- r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want
- )
+ m = re.match(r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want)
if m is None:
pytest.fail(f"bad wanted version spec: {want}")
else:
want = (want_maj, want_min, want_fix)
- op = getattr(
- operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)]
- )
+ op = getattr(operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)])
if not op(got, want):
revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="}
from ruamel.yaml import YAML # pip install ruamel.yaml
logger = logging.getLogger()
-logging.basicConfig(
- level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
def fetch_user(username):
from psycopg.errors import get_base_exception
logger = logging.getLogger()
-logging.basicConfig(
- level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
def main():
continue
# Parse an error
- m = re.match(
- r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line
- )
+ m = re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line)
if m:
sqlstate, macro, spec = m.groups()
# skip sqlstates without specs as they are not publicly visible
new = []
for query in queries:
- out = sp.run(
- ["psql", "-AXqt", "-c", query], stdout=sp.PIPE, check=True
- )
+ out = sp.run(["psql", "-AXqt", "-c", query], stdout=sp.PIPE, check=True)
new.extend(out.stdout.splitlines())
new = [b" " * 4 + line if line else b"" for line in new] # indent
codespell
[flake8]
-max-line-length = 85
+max-line-length = 88
ignore = W503, E203
extend-exclude = .venv
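The max-line-length bump from 85 to 88 in the [flake8] section matches the rest of this diff, which collapses statements previously wrapped to fit 85 columns into single lines of up to 88 characters. 88 is the default line length of the Black formatter; assuming Black is the tool driving the reformat (the diff itself does not name it), the matching formatter setting would typically live in pyproject.toml, roughly as sketched here (the target-version list is hypothetical):

[tool.black]
line-length = 88
# hypothetical: list the Python versions the project actually supports
target-version = ["py37", "py38", "py39", "py310"]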