docs_dir = Path(__file__).parent
sys.path.append(str(docs_dir / "lib"))
-
# -- Project information -----------------------------------------------------
project = "psycopg"
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"]
-
# -- Options for HTML output -------------------------------------------------
# The announcement may be in the website but not shipped with the docs
-ann_file = docs_dir / "../../templates/docs3-announcement.html"
-if ann_file.exists():
+if (ann_file := (docs_dir / "../../templates/docs3-announcement.html")).exists():
with ann_file.open() as f:
announcement = f.read()
else:
"sidebar_hide_name": False,
"light_logo": "psycopg.svg",
"dark_logo": "psycopg.svg",
- "light_css_variables": {
- "admonition-font-size": "1rem",
- },
+ "light_css_variables": {"admonition-font-size": "1rem"},
}
# Add any paths that contain custom static files (such as style sheets) here,
self.add_function(data)
def handle_sect1(self, tag, attrs):
- attrs = dict(attrs)
- if "id" in attrs:
+ if "id" in (attrs := dict(attrs)):
self.section_id = attrs["id"]
def handle_varlistentry(self, tag, attrs):
- attrs = dict(attrs)
- if "id" in attrs:
+ if "id" in (attrs := dict(attrs)):
self.varlist_id = attrs["id"]
def add_function(self, func_name):
# must be set before using the rest of the class.
app = None
- _url_pattern = (
- "https://raw.githubusercontent.com/postgres/postgres/{branch}"
- "/doc/src/sgml/libpq.sgml"
- )
+ _url_pattern = "https://raw.githubusercontent.com/postgres/postgres/"
+ _url_pattern += "{branch}/doc/src/sgml/libpq.sgml"
data = None
parser.feed(f.read())
def download(self):
- filename = os.environ.get("LIBPQ_DOCS_FILE")
- if filename:
+ if filename := os.environ.get("LIBPQ_DOCS_FILE"):
logger.info("reading postgres libpq docs from %s", filename)
with open(filename, "rb") as f:
data = f.read()
if "(" in text:
func, noise = text.split("(", 1)
noise = "(" + noise
-
else:
func = text
noise = ""
def before_process_signature(app, obj, bound_method):
- ann = getattr(obj, "__annotations__", {})
- if "return" in ann:
+ if "return" in (ann := getattr(obj, "__annotations__", {})):
# Drop "return: None" from the function signatures
if ann["return"] is None:
del ann["return"]
for fn in walk_modules(mdir):
assert fn.startswith(mdir)
modname = os.path.splitext(fn[len(mdir) + 1 :])[0].replace("/", ".")
- modname = f"{m.__name__}.{modname}"
- if modname in skip_modules:
+ if (modname := f"{m.__name__}.{modname}") in skip_modules:
continue
with open(fn) as f:
classnames = re.findall(r"^class\s+([^(:]+)", f.read(), re.M)
for cls in classnames:
- cls = deep_import(f"{modname}.{cls}")
- if cls.__module__ != modname:
+ if (cls := deep_import(f"{modname}.{cls}")).__module__ != modname:
recovered_classes[cls] = modname
def ticket_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
- cfg = inliner.document.settings.env.app.config
- if cfg.ticket_url is None:
+ if (cfg := inliner.document.settings.env.app.config).ticket_url is None:
msg = inliner.reporter.warning(
"ticket not configured: please configure ticket_url in conf.py"
)
from .connection_async import AsyncConnection
# Set the logger to a quiet default, can be enabled if needed
-logger = logging.getLogger("psycopg")
-if logger.level == logging.NOTSET:
+if (logger := logging.getLogger("psycopg")).level == logging.NOTSET:
logger.setLevel(logging.WARNING)
# DBAPI compliance
_optimised: dict[type, type] = {}
def __init__(
- self,
- template: AdaptersMap | None = None,
- types: TypesRegistry | None = None,
+ self, template: AdaptersMap | None = None, types: TypesRegistry | None = None
):
if template:
self._dumpers = template._dumpers.copy()
if _psycopg:
loader = self._get_optimised(loader)
- fmt = loader.format
- if not self._own_loaders[fmt]:
+ if not self._own_loaders[(fmt := loader.format)]:
self._loaders[fmt] = self._loaders[fmt].copy()
self._own_loaders[fmt] = True
from psycopg import types
if cls.__module__.startswith(types.__name__):
- new = cast("type[RV]", getattr(_psycopg, cls.__name__, None))
- if new:
+ if new := cast("type[RV]", getattr(_psycopg, cls.__name__, None)):
self._optimised[cls] = new
return new
res = cursor.pgresult
assert res
- fname = res.fname(index)
- if fname:
+ if fname := res.fname(index):
self._name = fname.decode(cursor._encoding)
else:
# COPY_OUT results have columns but no name
def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
# Raise an exception if we are in a transaction
- status = self.pgconn.transaction_status
- if status == IDLE and self._pipeline:
+ if (status := self.pgconn.transaction_status) == IDLE and self._pipeline:
yield from self._pipeline._sync_gen()
status = self.pgconn.transaction_status
if status != IDLE:
self._check_connection_ok()
if self._pipeline:
- cmd = partial(
- self.pgconn.send_close_prepared,
- name,
- )
+ cmd = partial(self.pgconn.send_close_prepared, name)
self._pipeline.command_queue.append(cmd)
self._pipeline.result_queue.append(None)
return
# Local path, or no host to resolve
return [params]
- hostaddr = get_param(params, "hostaddr")
- if hostaddr:
+ if get_param(params, "hostaddr"):
# Already resolved
return [params]
# If the host is already an ip address don't try to resolve it
return [{**params, "hostaddr": host}]
- port = get_param(params, "port")
- if not port:
+ if not (port := get_param(params, "port")):
port_def = get_param_def("port")
port = port_def and port_def.compiled or "5432"
# Local path, or no host to resolve
return [params]
- hostaddr = get_param(params, "hostaddr")
- if hostaddr:
+ if get_param(params, "hostaddr"):
# Already resolved
return [params]
# If the host is already an ip address don't try to resolve it
return [{**params, "hostaddr": host}]
- port = get_param(params, "port")
- if not port:
+ if not (port := get_param(params, "port")):
port_def = get_param_def("port")
port = port_def and port_def.compiled or "5432"
# TODO: check if in service
- paramdef = get_param_def(name)
- if not paramdef:
+ if not (paramdef := get_param_def(name)):
return None
- env = os.environ.get(paramdef.envvar)
- if env is not None:
+ if (env := os.environ.get(paramdef.envvar)) is not None:
return env
return None
def __iter__(self) -> Iterator[Buffer]:
"""Implement block-by-block iteration on :sql:`COPY TO`."""
while True:
- data = self.read()
- if not data:
+ if not (data := self.read()):
break
yield data
bytes, unless data types are specified using `set_types()`.
"""
while True:
- record = self.read_row()
- if record is None:
+ if (record := self.read_row()) is None:
break
yield record
If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
text mode it can be either `!bytes` or `!str`.
"""
- data = self.formatter.write(buffer)
- if data:
+ if data := self.formatter.write(buffer):
self._write(data)
def write_row(self, row: Sequence[Any]) -> None:
"""Write a record to a table after a :sql:`COPY FROM` operation."""
- data = self.formatter.write_row(row)
- if data:
+ if data := self.formatter.write_row(row):
self._write(data)
def finish(self, exc: BaseException | None) -> None:
using the `Copy` object outside a block.
"""
if self._direction == COPY_IN:
- data = self.formatter.end()
- if data:
+ if data := self.formatter.end():
self._write(data)
self.writer.finish(exc)
self._finished = True
"""
try:
while True:
- data = self._queue.get()
- if not data:
+ if not (data := self._queue.get()):
break
self.connection.wait(copy_to(self._pgconn, data, flush=PREFER_FLUSH))
except BaseException as ex:
async def __aiter__(self) -> AsyncIterator[Buffer]:
"""Implement block-by-block iteration on :sql:`COPY TO`."""
while True:
- data = await self.read()
- if not data:
+ if not (data := (await self.read())):
break
yield data
bytes, unless data types are specified using `set_types()`.
"""
while True:
- record = await self.read_row()
- if record is None:
+ if (record := (await self.read_row())) is None:
break
yield record
If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
text mode it can be either `!bytes` or `!str`.
"""
- data = self.formatter.write(buffer)
- if data:
+ if data := self.formatter.write(buffer):
await self._write(data)
async def write_row(self, row: Sequence[Any]) -> None:
"""Write a record to a table after a :sql:`COPY FROM` operation."""
- data = self.formatter.write_row(row)
- if data:
+ if data := self.formatter.write_row(row):
await self._write(data)
async def finish(self, exc: BaseException | None) -> None:
using the `Copy` object outside a block.
"""
if self._direction == COPY_IN:
- data = self.formatter.end()
- if data:
+ if data := self.formatter.end():
await self._write(data)
await self.writer.finish(exc)
self._finished = True
"""
try:
while True:
- data = await self._queue.get()
- if not data:
+ if not (data := (await self._queue.get())):
break
await self.connection.wait(
copy_to(self._pgconn, data, flush=PREFER_FLUSH)
formatter: Formatter
def __init__(
- self,
- cursor: BaseCursor[ConnectionType, Any],
- *,
- binary: bool | None = None,
+ self, cursor: BaseCursor[ConnectionType, Any], *, binary: bool | None = None
):
self.cursor = cursor
self.connection = cursor.connection
self._pgconn = self.connection.pgconn
- result = cursor.pgresult
- if result:
+ if result := cursor.pgresult:
self._direction = result.status
if self._direction != COPY_IN and self._direction != COPY_OUT:
raise e.ProgrammingError(
return memoryview(b"")
def _read_row_gen(self) -> PQGen[tuple[Any, ...] | None]:
- data = yield from self._read_gen()
- if not data:
+ if not (data := (yield from self._read_gen())):
return None
- row = self.formatter.parse_row(data)
- if row is None:
+ if (row := self.formatter.parse_row(data)) is None:
# Get the final result to finish the copy operation
yield from self._read_gen()
self._finished = True
query = self._convert_query(statement)
self._execute_send(query, binary=False)
- results = yield from execute(self._pgconn)
- if len(results) != 1:
+ if len(results := (yield from execute(self._pgconn))) != 1:
raise e.ProgrammingError("COPY cannot be mixed with other operations")
self._check_copy_result(results[0])
"""
Raise an appropriate error message for an unexpected database result
"""
- status = result.status
- if status == FATAL_ERROR:
+ if (status := result.status) == FATAL_ERROR:
raise e.error_from_result(result, encoding=self._encoding)
elif status == PIPELINE_ABORTED:
raise e.PipelineAborted("pipeline aborted")
# Received from execute()
self._results[:] = results
self._select_current_result(0)
-
+ # Received from executemany()
+ elif self._execmany_returning:
+ first_batch = not self._results
+ self._results.extend(results)
+ if first_batch:
+ self._select_current_result(0)
else:
- # Received from executemany()
- if self._execmany_returning:
- first_batch = not self._results
- self._results.extend(results)
- if first_batch:
- self._select_current_result(0)
- else:
- # In non-returning case, set rowcount to the cumulated number of
- # rows of executed queries.
- for res in results:
- self._rowcount += res.command_tuples or 0
+ # In non-returning case, set rowcount to the cumulated number of
+ # rows of executed queries.
+ for res in results:
+ self._rowcount += res.command_tuples or 0
def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
if self._conn._pipeline:
def _check_result_for_fetch(self) -> None:
if self.closed:
raise e.InterfaceError("the cursor is closed")
- res = self.pgresult
- if not res:
+
+ if not (res := self.pgresult):
raise e.ProgrammingError("no result available")
- status = res.status
- if status == TUPLES_OK:
+ if (status := res.status) == TUPLES_OK:
return
elif status == FATAL_ERROR:
raise e.error_from_result(res, encoding=self._encoding)
ports.append(str(attempt["port"]))
out = params.copy()
- shosts = ",".join(hosts)
- if shosts:
+ if shosts := ",".join(hosts):
out["host"] = shosts
- shostaddrs = ",".join(hostaddrs)
- if shostaddrs:
+ if shostaddrs := ",".join(hostaddrs):
out["hostaddr"] = shostaddrs
sports = ",".join(ports)
if ports:
def resolve(self, params: dict[str, Any]) -> dict[str, Any]:
"""Update the parameters host and port after SRV lookup."""
- attempts = self._get_attempts(params)
- if not attempts:
+ if not (attempts := self._get_attempts(params)):
return params
hps = []
async def resolve_async(self, params: dict[str, Any]) -> dict[str, Any]:
"""Update the parameters host and port after SRV lookup."""
- attempts = self._get_attempts(params)
- if not attempts:
+ if not (attempts := self._get_attempts(params)):
return params
hps = []
host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
hosts_in = host_arg.split(",")
port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
- ports_in = port_arg.split(",")
-
- if len(ports_in) == 1:
+ if len((ports_in := port_arg.split(","))) == 1:
# If only one port is specified, it applies to all the hosts.
ports_in *= len(hosts_in)
if len(ports_in) != len(hosts_in):
out = []
srv_found = False
for host, port in zip(hosts_in, ports_in):
- m = self.re_srv_rr.match(host)
- if m or port.lower() == "srv":
+ if (m := self.re_srv_rr.match(host)) or port.lower() == "srv":
srv_found = True
target = m.group("target") if m else None
hp = HostPort(host=host, port=port, totry=True, target=target)
from .conninfo import conninfo_to_dict
params = conninfo_to_dict(conninfo)
- pgenc = params.get("client_encoding")
- if pgenc:
+ if pgenc := params.get("client_encoding"):
try:
return pg2pyenc(str(pgenc).encode())
except NotSupportedError:
exception = None
while self.result_queue:
- results = yield from fetch_many(self.pgconn)
- if not results:
+ if not (results := (yield from fetch_many(self.pgconn))):
# No more results to fetch, but there may still be pending
# commands.
break
# The user doesn't want this query to be prepared
return Prepare.NO, b""
- key = self.key(query)
- name = self._names.get(key)
- if name:
+ if name := self._names.get(key := self.key(query)):
# The query was already prepared in this session
return Prepare.YES, name
# We cannot prepare multiple statements
return False
- status = results[0].status
- if COMMAND_OK != status != TUPLES_OK:
+ if COMMAND_OK != results[0].status != TUPLES_OK:
# We don't prepare failed queries or other weird results
return False
if self.prepare_threshold is None:
return None
- key = self.key(query)
- if key in self._counts:
+ if (key := self.key(query)) in self._counts:
if prep is Prepare.SHOULD:
del self._counts[key]
self._names[key] = name
return key
def validate(
- self,
- key: Key,
- prep: Prepare,
- name: bytes,
- results: Sequence[PGresult],
+ self, key: Key, prep: Prepare, name: bytes, results: Sequence[PGresult]
) -> None:
"""Validate cached entry with 'key' by checking query 'results'.
# right size.
if self._row_dumpers:
for i in range(nparams):
- param = params[i]
- if param is not None:
+ if (param := params[i]) is not None:
out[i] = self._row_dumpers[i].dump(param)
return out
pqformats = [TEXT] * nparams
for i in range(nparams):
- param = params[i]
- if param is None:
+ if (param := params[i]) is None:
continue
dumper = self.get_dumper(param, formats[i])
out[i] = dumper.dump(param)
try:
type_sql = self._oid_types[oid]
except KeyError:
- ti = self.adapters.types.get(oid)
- if ti:
+ if ti := self.adapters.types.get(oid):
if oid < 8192:
# builtin: prefer "timestamptz" to "timestamp with time zone"
type_sql = ti.name.encode(self.encoding)
cache[key] = dumper = dcls(key, self)
# Check if the dumper requires an upgrade to handle this specific value
- key1 = dumper.get_key(obj, format)
- if key1 is key:
+ if (key1 := dumper.get_key(obj, format)) is key:
return dumper
# If it does, ask the dumper to create its own upgraded version
return dumper
def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]:
- res = self._pgresult
- if not res:
+ if not (res := self._pgresult):
raise e.InterfaceError("result not set")
if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
for row in range(row0, row1):
record: list[Any] = [None] * self._nfields
for col in range(self._nfields):
- val = res.get_value(row, col)
- if val is not None:
+ if (val := res.get_value(row, col)) is not None:
record[col] = self._row_loaders[col](val)
records.append(make_row(record))
return records
def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None:
- res = self._pgresult
- if not res:
+ if not (res := self._pgresult):
return None
if not 0 <= row < self._ntuples:
record: list[Any] = [None] * self._nfields
for col in range(self._nfields):
- val = res.get_value(row, col)
- if val is not None:
+ if (val := res.get_value(row, col)) is not None:
record[col] = self._row_loaders[col](val)
return make_row(record)
except KeyError:
pass
- loader_cls = self._adapters.get_loader(oid, format)
- if not loader_cls:
- loader_cls = self._adapters.get_loader(INVALID_OID, format)
- if not loader_cls:
+ if not (loader_cls := self._adapters.get_loader(oid, format)):
+ if not (loader_cls := self._adapters.get_loader(INVALID_OID, format)):
raise e.InterfaceError("unknown oid loader not found")
loader = self._loaders[format][oid] = loader_cls(oid, self)
return loader
rv.append(QueryPart(pre, 0, PyFormat.AUTO))
break
- ph = m.group(0)
- if ph == b"%%":
+ if (ph := m.group(0)) == b"%%":
# unescape '%%' to '%' if necessary, then merge the parts
if collapse_double_percent:
ph = b"%"
@classmethod
def _parse_string(cls, s: str) -> Xid:
- m = _re_xid.match(s)
- if not m:
+ if not (m := _re_xid.match(s)):
raise ValueError("bad Xid format")
format_id = int(m.group(1))
@classmethod
def _has_to_regtype_function(cls, conn: BaseConnection[Any]) -> bool:
# to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
- info = conn.info
- if info.vendor == "PostgreSQL":
+ if (info := conn.info).vendor == "PostgreSQL":
return info.server_version >= 90400
elif info.vendor == "CockroachDB":
return info.server_version >= 220200
pass
def get_type_display(self, oid: int | None = None, fmod: int | None = None) -> str:
- parts = []
- parts.append(self.name)
- mod = self.typemod.get_modifier(fmod) if fmod is not None else ()
- if mod:
+ parts = [self.name]
+ if mod := (self.typemod.get_modifier(fmod) if fmod is not None else ()):
parts.append(f"({','.join(map(str, mod))})")
if oid == self.array_oid:
for most types and you won't likely have to implement this method in a
subclass.
"""
- value = self.dump(obj)
- if value is None:
+ if (value := self.dump(obj)) is None:
return b"NULL"
if self.connection:
with self.lock:
self._check_connection_ok()
- pipeline = self._pipeline
- if pipeline is None:
+ if (pipeline := self._pipeline) is None:
# WARNING: reference loop, broken ahead.
pipeline = self._pipeline = Pipeline(self)
async with self.lock:
self._check_connection_ok()
- pipeline = self._pipeline
- if pipeline is None:
+ if (pipeline := self._pipeline) is None:
# WARNING: reference loop, broken ahead.
pipeline = self._pipeline = AsyncPipeline(self)
if not s:
return "''"
- s = re_escape.sub(r"\\\1", s)
- if re_space.search(s):
+ if re_space.search(s := re_escape.sub(r"\\\1", s)):
s = "'" + s + "'"
return s
Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).
"""
- sver = self.parameter_status("crdb_version")
- if not sver:
+ if not (sver := self.parameter_status("crdb_version")):
raise e.InternalError("'crdb_version' parameter status not set")
- ver = self.parse_crdb_version(sver)
- if ver is None:
+ if (ver := self.parse_crdb_version(sver)) is None:
raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
return ver
@classmethod
def parse_crdb_version(self, sver: str) -> int | None:
- m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
- if not m:
+ if not (m := re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)):
return None
return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
# If there is already a pipeline, ride it, in order to avoid
# sending unnecessary Sync.
with self._conn.lock:
- p = self._conn._pipeline
- if p:
+ if p := self._conn._pipeline:
self._conn.wait(
self._executemany_gen_pipeline(query, params_seq, returning)
)
first = True
while self._conn.wait(self._stream_fetchone_gen(first)):
for pos in range(size):
- rec = self._tx.load_row(pos, self._make_row)
- if rec is None:
+ if (rec := self._tx.load_row(pos, self._make_row)) is None:
break
yield rec
first = False
"""
self._fetch_pipeline()
self._check_result_for_fetch()
- record = self._tx.load_row(self._pos, self._make_row)
- if record is not None:
+ if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
self._pos += 1
return record
return self._tx.load_row(pos, self._make_row)
while True:
- row = load(self._pos)
- if row is None:
+ if (row := load(self._pos)) is None:
break
self._pos += 1
yield row
return self
async def executemany(
- self,
- query: Query,
- params_seq: Iterable[Params],
- *,
- returning: bool = False,
+ self, query: Query, params_seq: Iterable[Params], *, returning: bool = False
) -> None:
"""
Execute the same command with a sequence of input data.
# If there is already a pipeline, ride it, in order to avoid
# sending unnecessary Sync.
async with self._conn.lock:
- p = self._conn._pipeline
- if p:
+ if p := self._conn._pipeline:
await self._conn.wait(
self._executemany_gen_pipeline(query, params_seq, returning)
)
first = True
while await self._conn.wait(self._stream_fetchone_gen(first)):
for pos in range(size):
- rec = self._tx.load_row(pos, self._make_row)
- if rec is None:
+ if (rec := self._tx.load_row(pos, self._make_row)) is None:
break
yield rec
first = False
"""
await self._fetch_pipeline()
self._check_result_for_fetch()
- record = self._tx.load_row(self._pos, self._make_row)
- if record is not None:
+ if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
self._pos += 1
return record
if not size:
size = self.arraysize
records = self._tx.load_rows(
- self._pos,
- min(self._pos + size, self.pgresult.ntuples),
- self._make_row,
+ self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
)
self._pos += len(records)
return records
return self._tx.load_row(pos, self._make_row)
while True:
- row = load(self._pos)
- if row is None:
+ if (row := load(self._pos)) is None:
break
self._pos += 1
yield row
self.obj = obj
def __repr__(self) -> str:
- sobj = repr(self.obj)
- if len(sobj) > 40:
+ if len((sobj := repr(self.obj))) > 40:
sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
return f"{self.__class__.__name__}({sobj})"
if conn.status == BAD:
encoding = conninfo_encoding(conninfo)
raise e.OperationalError(
- f"connection is bad: {conn.get_error_message(encoding)}",
- pgconn=conn,
+ f"connection is bad: {conn.get_error_message(encoding)}", pgconn=conn
)
status = conn.connect_poll()
while True:
if deadline and monotonic() > deadline:
raise e.CancellationTimeout("cancellation timeout expired")
- status = cancel_conn.poll()
- if status == POLL_OK:
+
+ if (status := cancel_conn.poll()) == POLL_OK:
break
elif status == POLL_READING:
yield cancel_conn.socket, WAIT_R
to retrieve the results available.
"""
while True:
- f = pgconn.flush()
- if f == 0:
+ if pgconn.flush() == 0:
break
while True:
- ready = yield WAIT_RW
- if ready:
+ if ready := (yield WAIT_RW):
break
if ready & READY_R:
"""
if pgconn.is_busy():
while True:
- ready = yield WAIT_R
- if ready:
+ if (yield WAIT_R):
break
while True:
if not pgconn.is_busy():
break
while True:
- ready = yield WAIT_R
- if ready:
+ if (yield WAIT_R):
break
_consume_notifies(pgconn)
while True:
while True:
- ready = yield WAIT_RW
- if ready:
+ if ready := (yield WAIT_RW):
break
if ready & READY_R:
res: list[PGresult] = []
while not pgconn.is_busy():
- r = pgconn.get_result()
- if r is None:
+ if (r := pgconn.get_result()) is None:
if not res:
break
results.append(res)
res = []
+ elif (status := r.status) == PIPELINE_SYNC:
+ assert not res
+ results.append([r])
+ elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+ # This shouldn't happen, but insisting hard enough, it will.
+ # For instance, in test_executemany_badquery(), with the COPY
+ # statement and the AsyncClientCursor, which disables
+ # prepared statements.
+ # Bail out from the resulting infinite loop.
+ raise e.NotSupportedError("COPY cannot be used in pipeline mode")
else:
- status = r.status
- if status == PIPELINE_SYNC:
- assert not res
- results.append([r])
- elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
- # This shouldn't happen, but insisting hard enough, it will.
- # For instance, in test_executemany_badquery(), with the COPY
- # statement and the AsyncClientCursor, which disables
- # prepared statements.
- # Bail out from the resulting infinite loop.
- raise e.NotSupportedError(
- "COPY cannot be used in pipeline mode"
- )
- else:
- res.append(r)
+ res.append(r)
if ready & READY_W:
pgconn.flush()
def _consume_notifies(pgconn: PGconn) -> None:
# Consume notifies
while True:
- n = pgconn.notifies()
- if not n:
+ if not (n := pgconn.notifies()):
break
if pgconn.notify_handler:
pgconn.notify_handler(n)
ns = []
while True:
- n = pgconn.notifies()
- if n:
+ if n := pgconn.notifies():
ns.append(n)
if pgconn.notify_handler:
pgconn.notify_handler(n)
# would block
while True:
- ready = yield WAIT_R
- if ready:
+ if (yield WAIT_R):
break
pgconn.consume_input()
return data
# Retrieve the final result of copy
- results = yield from _fetch_many(pgconn)
- if len(results) > 1:
+
+ if len(results := (yield from _fetch_many(pgconn))) > 1:
# TODO: too brutal? Copy worked.
raise e.ProgrammingError("you cannot mix COPY with other operations")
- result = results[0]
- if result.status != COMMAND_OK:
+
+ if (result := results[0]).status != COMMAND_OK:
raise e.error_from_result(result, encoding=pgconn._encoding)
return result
# do it upstream the queue decoupling the writer task from the producer one.
while pgconn.put_copy_data(buffer) == 0:
while True:
- ready = yield WAIT_W
- if ready:
+ if (yield WAIT_W):
break
# Flushing often has a good effect on macOS because memcpy operations
# Repeat until the message is flushed to the server
while True:
while True:
- ready = yield WAIT_W
- if ready:
+ if (yield WAIT_W):
break
- f = pgconn.flush()
- if f == 0:
+
+ if pgconn.flush() == 0:
break
# Retry enqueuing end copy message until successful
while pgconn.put_copy_end(error) == 0:
while True:
- ready = yield WAIT_W
- if ready:
+ if (yield WAIT_W):
break
# Repeat until the message is flushed to the server
while True:
while True:
- ready = yield WAIT_W
- if ready:
+ if (yield WAIT_W):
break
- f = pgconn.flush()
- if f == 0:
+
+ if pgconn.flush() == 0:
break
# Retrieve the final result of copy
return f"<{cls} {info} at 0x{id(self):x}>"
def __getattr__(self, attr: str) -> Any:
- value = getattr(self._pgconn, attr)
- if callable(value):
+ if callable((value := getattr(self._pgconn, attr))):
return debugging(value)
else:
logger.info("PGconn.%s -> %s", attr, value)
from .misc import find_libpq_full_path, version_pretty
from ..errors import NotSupportedError
-libname = find_libpq_full_path()
-if not libname:
+if not (libname := find_libpq_full_path()):
raise ImportError("libpq library not found")
pq = ctypes.cdll.LoadLibrary(libname)
FILE_ptr = POINTER(FILE)
if sys.platform == "linux":
- libcname = ctypes.util.find_library("c")
- if not libcname:
+ if not (libcname := ctypes.util.find_library("c")):
# Likely this is a system using musl libc, see the following bug:
# https://github.com/python/cpython/issues/65821
libcname = "libc.so"
fdopen.argtypes = (c_int, c_char_p)
fdopen.restype = FILE_ptr
-
# Get the libpq version to define what functions are available.
PQlibVersion = pq.PQlibVersion
if libpq_version >= 100000:
PQencryptPasswordConn = pq.PQencryptPasswordConn
- PQencryptPasswordConn.argtypes = [
- PGconn_ptr,
- c_char_p,
- c_char_p,
- c_char_p,
- ]
+ PQencryptPasswordConn.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_char_p]
PQencryptPasswordConn.restype = POINTER(c_char)
else:
PQencryptPasswordConn = not_supported_before("PQencryptPasswordConn", 100000)
if libpq_version >= 170000:
PQchangePassword = pq.PQchangePassword
- PQchangePassword.argtypes = [
- PGconn_ptr,
- c_char_p,
- c_char_p,
- ]
+ PQchangePassword.argtypes = [PGconn_ptr, c_char_p, c_char_p]
PQchangePassword.restype = PGresult_ptr
else:
PQchangePassword = not_supported_before("PQchangePassword", 170000)
@cache
def find_libpq_full_path() -> str | None:
if sys.platform == "win32":
- libname = ctypes.util.find_library("libpq.dll")
- if libname is None:
+ if (libname := ctypes.util.find_library("libpq.dll")) is None:
return None
libname = str(Path(libname).resolve())
-
elif sys.platform == "darwin":
libname = ctypes.util.find_library("libpq.dylib")
# (hopefully) temporary hack: libpq not in a standard place
import subprocess as sp
libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
- libname = os.path.join(libdir, "libpq.dylib")
- if not os.path.exists(libname):
+ if not os.path.exists((libname := os.path.join(libdir, "libpq.dylib"))):
libname = None
except Exception as ex:
logger.debug("couldn't use pg_config to find libpq: %s", ex)
def strip_severity(msg: str) -> str:
"""Strip severity and whitespaces from error message."""
- m = PREFIXES.match(msg)
- if m:
+ if m := PREFIXES.match(msg):
msg = msg[m.span()[1] :]
return msg.strip()
def _clean_error_message(msg: bytes, encoding: str) -> str:
- smsg = msg.decode(encoding, "replace")
- if smsg:
+ if smsg := msg.decode(encoding, "replace"):
return strip_severity(smsg)
else:
return "no error details available"
else:
status = ConnStatus(pgconn.status).name
- sparts = " ".join("%s=%s" % part for part in parts)
- if sparts:
+ if sparts := " ".join(("%s=%s" % part for part in parts)):
sparts = f" ({sparts})"
return f"[{status}]{sparts}"
def connect(cls, conninfo: bytes) -> PGconn:
if not isinstance(conninfo, bytes):
raise TypeError(f"bytes expected, got {type(conninfo)} instead")
-
- pgconn_ptr = impl.PQconnectdb(conninfo)
- if not pgconn_ptr:
+ if not (pgconn_ptr := impl.PQconnectdb(conninfo)):
raise MemoryError("couldn't allocate PGconn")
return cls(pgconn_ptr)
def connect_start(cls, conninfo: bytes) -> PGconn:
if not isinstance(conninfo, bytes):
raise TypeError(f"bytes expected, got {type(conninfo)} instead")
-
- pgconn_ptr = impl.PQconnectStart(conninfo)
- if not pgconn_ptr:
+ if not (pgconn_ptr := impl.PQconnectStart(conninfo)):
raise MemoryError("couldn't allocate PGconn")
return cls(pgconn_ptr)
@property
def info(self) -> list[ConninfoOption]:
self._ensure_pgconn()
- opts = impl.PQconninfo(self._pgconn_ptr)
- if not opts:
+ if not (opts := impl.PQconninfo(self._pgconn_ptr)):
raise MemoryError("couldn't allocate connection info")
try:
return Conninfo._options_from_array(opts)
@property
def socket(self) -> int:
- rv = self._call_int(impl.PQsocket)
- if rv == -1:
+ if (rv := self._call_int(impl.PQsocket)) == -1:
raise e.OperationalError("the connection is lost")
return rv
if not isinstance(command, bytes):
raise TypeError(f"bytes expected, got {type(command)} instead")
self._ensure_pgconn()
- rv = impl.PQexec(self._pgconn_ptr, command)
- if not rv:
+ if not (rv := impl.PQexec(self._pgconn_ptr, command)):
raise e.OperationalError(
f"executing query failed: {self.get_error_message()}"
)
command, param_values, param_types, param_formats, result_format
)
self._ensure_pgconn()
- rv = impl.PQexecParams(*args)
- if not rv:
+ if not (rv := impl.PQexecParams(*args)):
raise e.OperationalError(
f"executing query failed: {self.get_error_message()}"
)
)
def send_prepare(
- self,
- name: bytes,
- command: bytes,
- param_types: Sequence[int] | None = None,
+ self, name: bytes, command: bytes, param_types: Sequence[int] | None = None
) -> None:
atypes: Array[impl.Oid] | None
if not param_types:
)
def prepare(
- self,
- name: bytes,
- command: bytes,
- param_types: Sequence[int] | None = None,
+ self, name: bytes, command: bytes, param_types: Sequence[int] | None = None
) -> PGresult:
if not isinstance(name, bytes):
raise TypeError(f"'name' must be bytes, got {type(name)} instead")
atypes = (impl.Oid * nparams)(*param_types)
self._ensure_pgconn()
- rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)
- if not rv:
+ if not (rv := impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)):
raise e.OperationalError(
f"preparing query failed: {self.get_error_message()}"
)
if not isinstance(name, bytes):
raise TypeError(f"'name' must be bytes, got {type(name)} instead")
self._ensure_pgconn()
- rv = impl.PQdescribePrepared(self._pgconn_ptr, name)
- if not rv:
+ if not (rv := impl.PQdescribePrepared(self._pgconn_ptr, name)):
raise e.OperationalError(
f"describe prepared failed: {self.get_error_message()}"
)
if not isinstance(name, bytes):
raise TypeError(f"'name' must be bytes, got {type(name)} instead")
self._ensure_pgconn()
- rv = impl.PQdescribePortal(self._pgconn_ptr, name)
- if not rv:
+ if not (rv := impl.PQdescribePortal(self._pgconn_ptr, name)):
raise e.OperationalError(
f"describe portal failed: {self.get_error_message()}"
)
if not isinstance(name, bytes):
raise TypeError(f"'name' must be bytes, got {type(name)} instead")
self._ensure_pgconn()
- rv = impl.PQclosePrepared(self._pgconn_ptr, name)
- if not rv:
+ if not (rv := impl.PQclosePrepared(self._pgconn_ptr, name)):
raise e.OperationalError(
f"close prepared failed: {self.get_error_message()}"
)
if not isinstance(name, bytes):
raise TypeError(f"'name' must be bytes, got {type(name)} instead")
self._ensure_pgconn()
- rv = impl.PQclosePortal(self._pgconn_ptr, name)
- if not rv:
+ if not (rv := impl.PQclosePortal(self._pgconn_ptr, name)):
raise e.OperationalError(f"close portal failed: {self.get_error_message()}")
return PGresult(rv)
See :pq:`PQcancelCreate` for details.
"""
- rv = impl.PQcancelCreate(self._pgconn_ptr)
- if not rv:
+ if not (rv := impl.PQcancelCreate(self._pgconn_ptr)):
raise e.OperationalError("couldn't create cancelConn object")
return PGcancelConn(rv)
See :pq:`PQgetCancel` for details.
"""
- rv = impl.PQgetCancel(self._pgconn_ptr)
- if not rv:
+ if not (rv := impl.PQgetCancel(self._pgconn_ptr)):
raise e.OperationalError("couldn't create cancel object")
return PGcancel(rv)
def notifies(self) -> PGnotify | None:
- ptr = impl.PQnotifies(self._pgconn_ptr)
- if ptr:
+ if ptr := impl.PQnotifies(self._pgconn_ptr):
c = ptr.contents
rv = PGnotify(c.relname, c.be_pid, c.extra)
impl.PQfreemem(ptr)
def put_copy_data(self, buffer: abc.Buffer) -> int:
if not isinstance(buffer, bytes):
buffer = bytes(buffer)
- rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
- if rv < 0:
+ if (rv := impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))) < 0:
raise e.OperationalError(
f"sending copy data failed: {self.get_error_message()}"
)
return rv
def put_copy_end(self, error: bytes | None = None) -> int:
- rv = impl.PQputCopyEnd(self._pgconn_ptr, error)
- if rv < 0:
+ if (rv := impl.PQputCopyEnd(self._pgconn_ptr, error)) < 0:
raise e.OperationalError(
f"sending copy end failed: {self.get_error_message()}"
)
)
def make_empty_result(self, exec_status: int) -> PGresult:
- rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)
- if not rv:
+ if not (rv := impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)):
raise MemoryError("couldn't allocate empty PGresult")
return PGresult(rv)
:raises ~e.OperationalError: if the connection is not in pipeline mode
or if sync failed.
"""
- rv = impl.PQpipelineSync(self._pgconn_ptr)
- if rv == 0:
+
+ if (rv := impl.PQpipelineSync(self._pgconn_ptr)) == 0:
raise e.OperationalError("connection not in pipeline mode")
if rv != 1:
raise e.OperationalError("failed to sync pipeline")
if length:
v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number)
return string_at(v, length)
+ elif impl.PQgetisnull(self._pgresult_ptr, row_number, column_number):
+ return None
else:
- if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number):
- return None
- else:
- return b""
+ return b""
@property
def nparams(self) -> int:
impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore
]
array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore
- rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array)
- if rv == 0:
+
+ if impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array) == 0:
raise e.OperationalError("PQsetResultAttrs failed")
@property
def socket(self) -> int:
- rv = impl.PQcancelSocket(self.pgcancelconn_ptr)
- if rv == -1:
+ if (rv := impl.PQcancelSocket(self.pgcancelconn_ptr)) == -1:
raise e.OperationalError("cancel connection not opened")
return rv
@classmethod
def get_defaults(cls) -> list[ConninfoOption]:
- opts = impl.PQconndefaults()
- if not opts:
+ if not (opts := impl.PQconndefaults()):
raise MemoryError("couldn't allocate connection defaults")
try:
return cls._options_from_array(opts)
# TODO: might be done without copy (however C does that)
if not isinstance(data, bytes):
data = bytes(data)
- out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))
- if not out:
+
+ if not (out := impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))):
raise e.OperationalError(
f"escape_literal failed: {self.conn.get_error_message()} bytes"
)
if not isinstance(data, bytes):
data = bytes(data)
- out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))
- if not out:
+
+ if not (out := impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))):
raise e.OperationalError(
f"escape_identifier failed: {self.conn.get_error_message()} bytes"
)
The dictionary keys are taken from the column names of the returned columns.
"""
- names = _get_names(cursor)
- if names is None:
- return no_result
+ if (names := _get_names(cursor)) is not None:
+
+ def dict_row_(values: Sequence[Any]) -> dict[str, Any]:
+ return dict(zip(names, values))
- def dict_row_(values: Sequence[Any]) -> dict[str, Any]:
- return dict(zip(names, values))
+ return dict_row_
- return dict_row_
+ else:
+ return no_result
def namedtuple_row(cursor: BaseCursor[Any, Any]) -> RowMaker[NamedTuple]:
The field names are taken from the column names of the returned columns,
with some mangling to deal with invalid names.
"""
- res = cursor.pgresult
- if not res:
- return no_result
-
- nfields = _get_nfields(res)
- if nfields is None:
+ if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None:
+ nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
+ return nt._make
+ else:
return no_result
- nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
- return nt._make
-
@functools.lru_cache(512)
def _make_nt(enc: str, *names: bytes) -> type[NamedTuple]:
"""
def class_row_(cursor: BaseCursor[Any, Any]) -> RowMaker[T]:
- names = _get_names(cursor)
- if names is None:
- return no_result
+ if (names := _get_names(cursor)) is not None:
- def class_row__(values: Sequence[Any]) -> T:
- return cls(**dict(zip(names, values)))
+ def class_row__(values: Sequence[Any]) -> T:
+ return cls(**dict(zip(names, values)))
- return class_row__
+ return class_row__
+
+ else:
+ return no_result
return class_row_
"""
def kwargs_row_(cursor: BaseCursor[Any, T]) -> RowMaker[T]:
- names = _get_names(cursor)
- if names is None:
- return no_result
+ if (names := _get_names(cursor)) is not None:
- def kwargs_row__(values: Sequence[Any]) -> T:
- return func(**dict(zip(names, values)))
+ def kwargs_row__(values: Sequence[Any]) -> T:
+ return func(**dict(zip(names, values)))
- return kwargs_row__
+ return kwargs_row__
+
+ else:
+ return no_result
return kwargs_row_
Generate a row factory returning the first column
as a scalar value.
"""
- res = cursor.pgresult
- if not res:
- return no_result
+ if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None:
+ if nfields < 1:
+ raise e.ProgrammingError("at least one column expected")
- nfields = _get_nfields(res)
- if nfields is None:
- return no_result
-
- if nfields < 1:
- raise e.ProgrammingError("at least one column expected")
+ def scalar_row_(values: Sequence[Any]) -> Any:
+ return values[0]
- def scalar_row_(values: Sequence[Any]) -> Any:
- return values[0]
+ return scalar_row_
- return scalar_row_
+ else:
+ return no_result
def no_result(values: Sequence[Any]) -> NoReturn:
def _get_names(cursor: BaseCursor[Any, Any]) -> list[str] | None:
- res = cursor.pgresult
- if not res:
- return None
-
- nfields = _get_nfields(res)
- if nfields is None:
+ if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None:
+ enc = cursor._encoding
+ return [
+ res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
+ ]
+ else:
return None
- enc = cursor._encoding
- return [
- res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr]
- ]
-
def _get_nfields(res: PGresult) -> int | None:
"""
__slots__ = "_name _scrollable _withhold _described itersize _format".split()
- def __init__(
- self,
- name: str,
- scrollable: bool | None,
- withhold: bool,
- ):
+ def __init__(self, name: str, scrollable: bool | None, withhold: bool):
self._name = name
self._scrollable = scrollable
self._withhold = withhold
return self._pos if tuples else None
def _declare_gen(
- self,
- query: Query,
- params: Params | None = None,
- binary: bool | None = None,
+ self, query: Query, params: Params | None = None, binary: bool | None = None
) -> PQGen[None]:
"""Generator implementing `ServerCursor.execute()`."""
if not isinstance(query, sql.Composable):
query = sql.SQL(query)
- parts = [
- sql.SQL("DECLARE"),
- sql.Identifier(self._name),
- ]
+ parts = [sql.SQL("DECLARE"), sql.Identifier(self._name)]
if self._scrollable is not None:
parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
parts.append(sql.SQL("CURSOR"))
return self
def executemany(
- self,
- query: Query,
- params_seq: Iterable[Params],
- *,
- returning: bool = True,
+ self, query: Query, params_seq: Iterable[Params], *, returning: bool = True
) -> None:
"""Method not implemented for server-side cursors."""
raise e.NotSupportedError("executemany not supported on server-side cursors")
return self
async def executemany(
- self,
- query: Query,
- params_seq: Iterable[Params],
- *,
- returning: bool = True,
+ self, query: Query, params_seq: Iterable[Params], *, returning: bool = True
) -> None:
raise e.NotSupportedError("executemany not supported on server-side cursors")
:type context: `connection` or `cursor`
"""
- conn = context.connection if context else None
- enc = conn_encoding(conn)
- b = self.as_bytes(context)
- if isinstance(b, bytes):
+ enc = conn_encoding(context.connection if context else None)
+ if isinstance((b := self.as_bytes(context)), bytes):
return b.decode(enc)
else:
# buffer object
return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
def as_bytes(self, context: AdaptContext | None = None) -> bytes:
- conn = context.connection if context else None
- if conn:
+ if conn := (context.connection if context else None):
esc = Escaping(conn.pgconn)
enc = conn_encoding(conn)
escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
Find the first non-null element of a possibly nested list
"""
items = list(self._flatiter(L, set()))
- types = {type(item): item for item in items}
- if not types:
+ if not (types := {type(item): item for item in items}):
return None
if len(types) == 1:
# More than one type in the list. It might still be good, as long
# as they dump with the same oid (e.g. IPv4Network, IPv6Network).
dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
- oids = {d.oid for d in dumpers}
- if len(oids) == 1:
+ if len({d.oid for d in dumpers}) == 1:
t, v = types.popitem()
else:
raise e.DataError(
Return text info as fallback.
"""
if base_oid:
- info = self._tx.adapters.types.get(base_oid)
- if info:
+ if info := self._tx.adapters.types.get(base_oid):
return info
return self._tx.adapters.types["text"]
if self.oid:
return self.cls
- item = self._find_list_element(obj, format)
- if item is None:
+ if (item := self._find_list_element(obj, format)) is None:
return self.cls
sd = self._tx.get_dumper(item, format)
if self.oid:
return self
- item = self._find_list_element(obj, format)
- if item is None:
+ if (item := self._find_list_element(obj, format)) is None:
# Empty lists can only be dumped as text if the type is unknown.
return self
if isinstance(item, list):
dump_list(item)
elif item is not None:
- ad = self._dump_item(item)
- if ad is None:
+ if (ad := self._dump_item(item)) is None:
tokens.append(b"NULL")
else:
if needs_quotes(ad):
if self.oid:
return self.cls
- item = self._find_list_element(obj, format)
- if item is None:
+ if (item := self._find_list_element(obj, format)) is None:
return (self.cls,)
sd = self._tx.get_dumper(item, format)
if self.oid:
return self
- item = self._find_list_element(obj, format)
- if item is None:
+ if (item := self._find_list_element(obj, format)) is None:
return ListDumper(self.cls, self._tx)
sd = self._tx.get_dumper(item, format.from_pq(self.format))
if data and data[0] == b"["[0]:
if isinstance(data, memoryview):
data = bytes(data)
- idx = data.find(b"=")
- if idx == -1:
+
+ if (idx := data.find(b"=")) == -1:
raise e.DataError("malformed array: no '=' after dimension information")
data = data[idx + 1 :]
re_parse = _get_array_parse_regexp(delimiter)
for m in re_parse.finditer(data):
- t = m.group(1)
- if t == b"{":
+ if (t := m.group(1)) == b"{":
if stack:
stack[-1].append(a)
stack.append(a)
continue
dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
- ad = dumper.dump(item)
- if ad is None:
+ if (ad := dumper.dump(item)) is None:
ad = b""
elif not ad:
ad = b'""'
raise NotImplementedError
def _get_offset(self, obj: time) -> timedelta:
- offset = obj.utcoffset()
- if offset is None:
+ if (offset := obj.utcoffset()) is None:
raise DataError(
f"cannot calculate the offset of tzinfo '{obj.tzinfo}' without a date"
)
def __init__(self, oid: int, context: AdaptContext | None = None):
super().__init__(oid, context)
- ds = _get_datestyle(self.connection)
- if ds.startswith(b"I"): # ISO
+ if (ds := _get_datestyle(self.connection)).startswith(b"I"): # ISO
self._order = self._ORDER_YMD
elif ds.startswith(b"G"): # German
self._order = self._ORDER_DMY
_re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?")
def load(self, data: Buffer) -> time:
- m = self._re_format.match(data)
- if not m:
+ if not (m := self._re_format.match(data)):
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse time {s!r}")
)
def load(self, data: Buffer) -> time:
- m = self._re_format.match(data)
- if not m:
+ if not (m := self._re_format.match(data)):
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse timetz {s!r}")
def __init__(self, oid: int, context: AdaptContext | None = None):
super().__init__(oid, context)
-
- ds = _get_datestyle(self.connection)
- if ds.startswith(b"I"): # ISO
+ if (ds := _get_datestyle(self.connection)).startswith(b"I"): # ISO
self._order = self._ORDER_YMD
elif ds.startswith(b"G"): # German
self._order = self._ORDER_DMY
raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
def load(self, data: Buffer) -> datetime:
- m = self._re_format.match(data)
- if not m:
+ if not (m := self._re_format.match(data)):
raise _get_timestamp_load_error(self.connection, data) from None
if self._order == self._ORDER_YMD:
super().__init__(oid, context)
self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
- ds = _get_datestyle(self.connection)
- if ds.startswith(b"I"): # ISO
+ if _get_datestyle(self.connection).startswith(b"I"): # ISO
self._load_method = self._load_iso
else:
self._load_method = self._load_notimpl
@staticmethod
def _load_iso(self: TimestamptzLoader, data: Buffer) -> datetime:
- m = self._re_format.match(data)
- if not m:
+ if not (m := self._re_format.match(data)):
raise _get_timestamp_load_error(self.connection, data) from None
ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups()
# timezone) we can still save the day by shifting the value by the
# timezone offset and then replacing the timezone.
if self._timezone:
- utcoff = self._timezone.utcoffset(
+ if utcoff := self._timezone.utcoffset(
datetime.min if micros < 0 else datetime.max
- )
- if utcoff:
+ ):
usoff = 1_000_000 * int(utcoff.total_seconds())
try:
ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
@staticmethod
def _load_postgres(self: IntervalLoader, data: Buffer) -> timedelta:
- m = self._re_interval.match(data)
- if not m:
+ if not (m := self._re_interval.match(data)):
s = bytes(data).decode("utf8", "replace")
raise DataError(f"can't parse interval {s!r}")
def _get_datestyle(conn: BaseConnection[Any] | None) -> bytes:
- if conn:
- ds = conn.pgconn.parameter_status(b"DateStyle")
- if ds:
- return ds
+ if conn and (ds := conn.pgconn.parameter_status(b"DateStyle")):
+ return ds
return b"ISO, DMY"
def _get_intervalstyle(conn: BaseConnection[Any] | None) -> bytes:
- if conn:
- ints = conn.pgconn.parameter_status(b"IntervalStyle")
- if ints:
- return ints
+ if conn and (ints := conn.pgconn.parameter_status(b"IntervalStyle")):
+ return ints
return b"unknown"
if not s:
return False
- ds = _get_datestyle(conn)
- if not ds.startswith(b"P"): # Postgres
+ if not _get_datestyle(conn).startswith(b"P"): # Postgres
return len(s.split()[0]) > 10 # date is first token
else:
return len(s.split()[-1]) > 4 # year is last token
if m is None or m.start() != start:
raise e.DataError(f"error parsing hstore pair at char {start}")
k = _re_unescape.sub(r"\1", m.group(1))
- v = m.group(2)
- if v is not None:
+ if (v := m.group(2)) is not None:
v = _re_unescape.sub(r"\1", v)
rv[k] = v
@cache
def _make_dumper(base: type[abc.Dumper], dumps: JsonDumpsFunction) -> type[abc.Dumper]:
- name = base.__name__
- if not name.startswith("Custom"):
+ if not (name := base.__name__).startswith("Custom"):
name = f"Custom{name}"
return type(name, (base,), {"_dumps": dumps})
@cache
def _make_loader(base: type[Loader], loads: JsonLoadsFunction) -> type[Loader]:
- name = base.__name__
- if not name.startswith("Custom"):
+ if not (name := base.__name__).startswith("Custom"):
name = f"Custom{name}"
return type(name, (base,), {"_loads": loads})
self.dumps = dumps
def __repr__(self) -> str:
- sobj = repr(self.obj)
- if len(sobj) > 40:
+ if len((sobj := repr(self.obj))) > 40:
sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
return f"{self.__class__.__name__}({sobj})"
obj = obj.obj
else:
dumps = self.dumps
- data = dumps(obj)
- if isinstance(data, str):
+ if isinstance((data := dumps(obj)), str):
return data.encode()
return data
oid = _oids.JSONB_OID
def dump(self, obj: Any) -> Buffer | None:
- obj_bytes = super().dump(obj)
- if obj_bytes is not None:
+ if (obj_bytes := super().dump(obj)) is not None:
return b"\x01" + obj_bytes
else:
return None
def load(self, data: Buffer) -> Any:
if data and data[0] != 1:
raise DataError(f"unknown jsonb binary format: {data[0]}")
- data = data[1:]
- if not isinstance(data, bytes):
+ if not isinstance((data := data[1:]), bytes):
data = bytes(data)
return self.loads(data)
if self.cls is not Multirange:
return self.cls
- item = self._get_item(obj)
- if item is not None:
+ if (item := self._get_item(obj)) is not None:
sd = self._tx.get_dumper(item, self._adapt_format)
return (self.cls, sd.get_key(item, format))
else:
if self.cls is not Multirange:
return self
- item = self._get_item(obj)
- if item is None:
+ if (item := self._get_item(obj)) is None:
return self
dumper: BaseMultirangeDumper
if not obj:
return b"{}"
- item = self._get_item(obj)
- if item is not None:
+ if (item := self._get_item(obj)) is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
format = Format.BINARY
def dump(self, obj: Multirange[Any]) -> Buffer | None:
- item = self._get_item(obj)
- if item is not None:
+ if (item := self._get_item(obj)) is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
return str(obj).encode()
def quote(self, obj: Any) -> Buffer:
- value = self.dump(obj)
- if value is None:
+ if (value := self.dump(obj)) is None:
return b"NULL"
return value if obj >= 0 else b" " + value
return str(obj).encode()
def quote(self, obj: Any) -> Buffer:
- value = self.dump(obj)
-
- if value is None:
+ if (value := self.dump(obj)) is None:
return b"NULL"
if not isinstance(value, bytes):
value = bytes(value)
return self._int2_dumper
else:
return self._int4_dumper
+ elif -(2**63) <= obj < 2**63:
+ return self._int8_dumper
else:
- if -(2**63) <= obj < 2**63:
- return self._int8_dumper
- else:
- return self._int_numeric_dumper
+ return self._int_numeric_dumper
class Int2BinaryDumper(Int2Dumper):
# Equivalent of 0-padding left to align the py digits to the pg digits
# but without changing the digits tuple.
- mod = (ndigits - dscale) % DEC_DIGITS
- if mod:
+ if mod := ((ndigits - dscale) % DEC_DIGITS):
wi = DEC_DIGITS - mod
ndigits += wi
# It doesn't seem that Python has an ABC for ordered types.
if x < self._lower: # type: ignore[operator]
return False
- else:
- if x <= self._lower: # type: ignore[operator]
- return False
+ elif x <= self._lower: # type: ignore[operator]
+ return False
if self._upper is not None:
if self._bounds[1] == "]":
if x > self._upper: # type: ignore[operator]
return False
- else:
- if x >= self._upper: # type: ignore[operator]
- return False
+ elif x >= self._upper: # type: ignore[operator]
+ return False
return True
return NotImplemented
for attr in ("_lower", "_upper", "_bounds"):
self_value = getattr(self, attr)
- other_value = getattr(other, attr)
- if self_value == other_value:
+ if self_value == (other_value := getattr(other, attr)):
pass
elif self_value is None:
return True
# If we are a subclass whose oid is specified we don't need upgrade
if self.cls is not Range:
return self.cls
-
- item = self._get_item(obj)
- if item is not None:
+ if (item := self._get_item(obj)) is not None:
sd = self._tx.get_dumper(item, self._adapt_format)
return (self.cls, sd.get_key(item, format))
else:
# If we are a subclass whose oid is specified we don't need upgrade
if self.cls is not Range:
return self
-
- item = self._get_item(obj)
- if item is None:
+ if (item := self._get_item(obj)) is None:
return self
dumper: BaseRangeDumper
"""
def dump(self, obj: Range[Any]) -> Buffer | None:
- item = self._get_item(obj)
- if item is not None:
+ if (item := self._get_item(obj)) is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
parts: list[Buffer] = [b"[" if obj.lower_inc else b"("]
def dump_item(item: Any) -> Buffer:
- ad = dump(item)
- if ad is None:
+ if (ad := dump(item)) is None:
return b""
elif not ad:
return b'""'
format = Format.BINARY
def dump(self, obj: Range[Any]) -> Buffer | None:
- item = self._get_item(obj)
- if item is not None:
+ if (item := self._get_item(obj)) is not None:
dump = self._tx.get_dumper(item, self._adapt_format).dump
else:
dump = fail_dump
head |= RANGE_UB_INC
if obj.lower is not None:
- data = dump(obj.lower)
- if data is not None:
+ if (data := dump(obj.lower)) is not None:
out += pack_len(len(data))
out += data
else:
head |= RANGE_LB_INF
if obj.upper is not None:
- data = dump(obj.upper)
- if data is not None:
+ if (data := dump(obj.upper)) is not None:
out += pack_len(len(data))
out += data
else:
def load_range_text(data: Buffer, load: LoadFunc) -> tuple[Range[Any], int]:
if data == b"empty":
return Range(empty=True), 5
-
- m = _re_range.match(data)
- if m is None:
+ if (m := _re_range.match(data)) is None:
raise e.DataError(
f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
)
lower = None
- item = m.group(3)
- if item is None:
- item = m.group(2)
- if item is not None:
+ if (item := m.group(3)) is None:
+ if (item := m.group(2)) is not None:
lower = load(_re_undouble.sub(rb"\1", item))
else:
lower = load(item)
upper = None
- item = m.group(5)
- if item is None:
- item = m.group(4)
- if item is not None:
+ if (item := m.group(5)) is None:
+ if (item := m.group(4)) is not None:
upper = load(_re_undouble.sub(rb"\1", item))
else:
upper = load(item)
def load_range_binary(data: Buffer, load: LoadFunc) -> Range[Any]:
- head = data[0]
- if head & RANGE_EMPTY:
+ if (head := data[0]) & RANGE_EMPTY:
return Range(empty=True)
lb = "[" if head & RANGE_LB_INC else "("
return self._esc.escape_bytea(obj)
def quote(self, obj: Buffer) -> Buffer:
- escaped = self.dump(obj)
- if escaped is None:
+ if (escaped := self.dump(obj)) is None:
return b"NULL"
# We cannot use the base quoting because escape_bytea already returns
with DefaultSelector() as sel:
sel.register(fileno, s)
while True:
- rlist = sel.select(timeout=interval)
- if not rlist:
+ if not (rlist := sel.select(timeout=interval)):
gen.send(READY_NONE)
continue
with DefaultSelector() as sel:
sel.register(fileno, s)
while True:
- rlist = sel.select(timeout=interval)
- if not rlist:
+ if not (rlist := sel.select(timeout=interval)):
gen.send(READY_NONE)
continue
while True:
reader = s & WAIT_R
writer = s & WAIT_W
- if not reader and not writer:
+ if not (reader or writer):
raise e.InternalError(f"bad poll status: {s}")
ev.clear()
ready = 0
while True:
reader = s & WAIT_R
writer = s & WAIT_W
- if not reader and not writer:
+ if not (reader or writer):
raise e.InternalError(f"bad poll status: {s}")
ev.clear()
ready = 0 # type: ignore[assignment]
evmask = _epoll_evmasks[s]
epoll.register(fileno, evmask)
while True:
- fileevs = epoll.poll(interval)
- if not fileevs:
+ if not (fileevs := epoll.poll(interval)):
gen.send(READY_NONE)
continue
ev = fileevs[0][1]
evmask = _poll_evmasks[s]
poll.register(fileno, evmask)
while True:
- fileevs = poll.poll(interval)
- if not fileevs:
+ if not (fileevs := poll.poll(interval)):
gen.send(READY_NONE)
continue
Currently supported: gevent.
"""
# If not imported, don't import it.
- m = sys.modules.get("gevent.monkey")
- if m:
+ if m := sys.modules.get("gevent.monkey"):
try:
if m.is_module_patched("select"):
return True
raise PoolClosed(f"the pool {self.name!r} is not open yet")
def _check_pool_putconn(self, conn: BaseConnection[Any]) -> None:
- pool = getattr(conn, "_pool", None)
- if pool is self:
+ if (pool := getattr(conn, "_pool", None)) is self:
return
if pool:
while self._waiting:
# If there is a client waiting (which is still waiting and
# hasn't timed out), give it the connection and notify it.
- pos = self._waiting.popleft()
- if pos.set(conn):
+ if self._waiting.popleft().set(conn):
break
else:
# No client waiting for a connection: close the connection
while self._waiting:
# If there is a client waiting (which is still waiting and
# hasn't timed out), give it the connection and notify it.
- pos = self._waiting.popleft()
- if await pos.set(conn):
+ if await self._waiting.popleft().set(conn):
break
else:
# No client waiting for a connection: close the connection
await conn.close()
-
# If we have been asked to wait for pool init, notify the
# waiter if the pool is ready.
if self._pool_full_event:
# Critical section: decide here if there's a connection ready
# or if the client needs to wait.
with self._lock:
- conn = self._get_ready_connection(timeout)
- if not conn:
+ if not (conn := self._get_ready_connection(timeout)):
# No connection available: put the client in the waiting queue
t0 = monotonic()
pos: WaitingClient[CT] = WaitingClient()
StopWorker is received.
"""
while True:
- task = q.get()
-
- if isinstance(task, StopWorker):
+ if isinstance((task := q.get()), StopWorker):
logger.debug("terminating working task %s", current_thread_name())
return
if self._configure:
self._configure(conn)
- status = conn.pgconn.transaction_status
- if status != TransactionStatus.IDLE:
+ if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
sname = TransactionStatus(status).name
raise e.ProgrammingError(
f"connection left in status {sname} by configure function {self._configure}: discarded"
while self._waiting:
# If there is a client waiting (which is still waiting and
# hasn't timed out), give it the connection and notify it.
- pos = self._waiting.popleft()
- if pos.set(conn):
+ if self._waiting.popleft().set(conn):
break
else:
# No client waiting for a connection: put it back into the pool
"""
Bring a connection to IDLE state or close it.
"""
- status = conn.pgconn.transaction_status
- if status == TransactionStatus.IDLE:
+ if (status := conn.pgconn.transaction_status) == TransactionStatus.IDLE:
pass
elif status == TransactionStatus.UNKNOWN:
# Connection closed
if self._reset:
try:
self._reset(conn)
- status = conn.pgconn.transaction_status
- if status != TransactionStatus.IDLE:
+ if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
sname = TransactionStatus(status).name
raise e.ProgrammingError(
f"connection left in status {sname} by reset function {self._reset}: discarded"
# Critical section: decide here if there's a connection ready
# or if the client needs to wait.
async with self._lock:
- conn = await self._get_ready_connection(timeout)
- if not conn:
+ if not (conn := (await self._get_ready_connection(timeout))):
# No connection available: put the client in the waiting queue
t0 = monotonic()
pos: WaitingClient[ACT] = WaitingClient()
StopWorker is received.
"""
while True:
- task = await q.get()
-
- if isinstance(task, StopWorker):
+ if isinstance((task := (await q.get())), StopWorker):
logger.debug("terminating working task %s", current_task_name())
return
if self._configure:
await self._configure(conn)
- status = conn.pgconn.transaction_status
- if status != TransactionStatus.IDLE:
+ if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
sname = TransactionStatus(status).name
raise e.ProgrammingError(
f"connection left in status {sname} by configure function"
while self._waiting:
# If there is a client waiting (which is still waiting and
# hasn't timed out), give it the connection and notify it.
- pos = self._waiting.popleft()
- if await pos.set(conn):
+ if await self._waiting.popleft().set(conn):
break
else:
# No client waiting for a connection: put it back into the pool
self._pool.append(conn)
-
# If we have been asked to wait for pool init, notify the
# waiter if the pool is full.
if self._pool_full_event and len(self._pool) >= self._min_size:
"""
Bring a connection to IDLE state or close it.
"""
- status = conn.pgconn.transaction_status
- if status == TransactionStatus.IDLE:
+ if (status := conn.pgconn.transaction_status) == TransactionStatus.IDLE:
pass
-
elif status == TransactionStatus.UNKNOWN:
# Connection closed
return
if self._reset:
try:
await self._reset(conn)
- status = conn.pgconn.transaction_status
- if status != TransactionStatus.IDLE:
+ if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
sname = TransactionStatus(status).name
raise e.ProgrammingError(
f"connection left in status {sname} by reset function"
while True:
with self._lock:
now = monotonic()
- task = q[0] if q else None
- if task:
+ if task := (q[0] if q else None):
if task.time <= now:
heappop(q)
else:
while True:
async with self._lock:
now = monotonic()
- task = q[0] if q else None
- if task:
+ if task := (q[0] if q else None):
if task.time <= now:
heappop(q)
else:
def pytest_report_header(config):
- rv = []
-
- rv.append(f"default selector: {selectors.DefaultSelector.__name__}")
- loop = config.getoption("--loop")
- if loop != "default":
+ rv = [f"default selector: {selectors.DefaultSelector.__name__}"]
+ if (loop := config.getoption("--loop")) != "default":
rv.append(f"asyncio loop: {loop}")
return rv
# In case of segfault, pytest doesn't get a chance to write failed tests
# in the cache. As a consequence, retries would find no test failed and
# assume that all tests passed in the previous run, making the whole test pass.
- cache = session.config.cache
- if cache.get("segfault", False):
+ if (cache := session.config.cache).get("segfault", False):
session.warn(Warning("Previous run resulted in segfault! Not running any test"))
session.warn(Warning("(delete '.pytest_cache/v/segfault' to clear this state)"))
raise session.Failed
# We often find the record with {"after": null} at least another time
# in the queue. Let's tolerate an extra one.
for i in range(2):
- row = q.get()
- if row is None:
+ if (row := q.get()) is None:
break
assert json.loads(row.value)["after"] is None, json
else:
# We often find the record with {"after": null} at least another time
# in the queue. Let's tolerate an extra one.
for i in range(2):
- row = await q.get()
- if row is None:
+ if (row := (await q.get())) is None:
break
assert json.loads(row.value)["after"] is None, json
else:
pred = VersionCheck.parse(spec)
pred.whose = "CockroachDB"
- msg = pred.get_skip_message(got)
- if not msg:
+ if not (msg := pred.get_skip_message(got)):
return None
- reason = crdb_skip_message(reason)
- if reason:
+ if reason := crdb_skip_message(reason):
msg = f"{msg}: {reason}"
return msg
def pytest_report_header(config):
- dsn = config.getoption("--test-dsn")
- if dsn is None:
+ if (dsn := config.getoption("--test-dsn")) is None:
return []
try:
except Exception as ex:
server_version = f"unknown ({ex})"
- return [
- f"Server version: {server_version}",
- ]
+ return [f"Server version: {server_version}"]
def pytest_collection_modifyitems(items):
"""
Return the dsn used to connect to the `--test-dsn` database (session-wide).
"""
- dsn = request.config.getoption("--test-dsn")
- if dsn is None:
+ if (dsn := request.config.getoption("--test-dsn")) is None:
pytest.skip("skipping test as no --test-dsn")
warm_up_database(dsn)
"""Open and yield a file for libpq client/server communication traces if
--pq-tracefile option is set.
"""
- tracefile = request.config.getoption("--pq-trace")
- if not tracefile:
+ if not (tracefile := request.config.getoption("--pq-trace")):
yield None
return
def pgconn(dsn, request, tracefile):
"""Return a PGconn connection open to `--test-dsn`."""
check_connection_version(request.node)
-
- conn = pq.PGconn.connect(dsn.encode())
- if conn.status != pq.ConnStatus.OK:
+ if (conn := pq.PGconn.connect(dsn.encode())).status != pq.ConnStatus.OK:
pytest.fail(f"bad connection: {conn.get_error_message()}")
with maybe_trace(conn, tracefile, request.function):
L = ListPopAll()
def _exec_command(command, *args, **kwargs):
- cmdcopy = command
- if isinstance(cmdcopy, bytes):
+ if isinstance((cmdcopy := command), bytes):
cmdcopy = cmdcopy.decode(conn.info.encoding)
elif isinstance(cmdcopy, sql.Composable):
cmdcopy = cmdcopy.as_string(conn)
for mark in node.iter_markers():
if mark.name == "pg":
assert len(mark.args) == 1
- msg = check_postgres_version(pg_version, mark.args[0])
- if msg:
+ if msg := check_postgres_version(pg_version, mark.args[0]):
pytest.skip(msg)
-
elif mark.name in ("crdb", "crdb_skip"):
from .fix_crdb import check_crdb_version
- msg = check_crdb_version(crdb_version, mark)
- if msg:
+ if msg := check_crdb_version(crdb_version, mark):
pytest.skip(msg)
try:
with psycopg.connect(dsn, connect_timeout=10) as conn:
conn.execute("select 1")
-
pg_version = conn.info.server_version
-
crdb_version = None
- param = conn.info.parameter_status("crdb_version")
- if param:
+ if param := conn.info.parameter_status("crdb_version"):
from psycopg.crdb import CrdbConnectionInfo
crdb_version = CrdbConnectionInfo.parse_crdb_version(param)
try:
cur.execute(self._insert_field_stmt(j), (val,))
except psycopg.DatabaseError as e:
- r = repr(val)
- if len(r) > 200:
+ if len((r := repr(val))) > 200:
r = f"{r[:200]}... ({len(r)} chars)"
raise Exception(
f"value {r!r} at record {i} column0 {j} failed insert: {e}"
try:
await acur.execute(self._insert_field_stmt(j), (val,))
except psycopg.DatabaseError as e:
- r = repr(val)
- if len(r) > 200:
+ if len((r := repr(val))) > 200:
r = f"{r[:200]}... ({len(r)} chars)"
raise Exception(
f"value {r!r} at record {i} column0 {j} failed insert: {e}"
def choose_schema(self, ncols=20):
schema: list[tuple[type, ...] | type] = []
while len(schema) < ncols:
- s = self.make_schema(choice(self.types))
- if s is not None:
+ if (s := self.make_schema(choice(self.types))) is not None:
schema.append(s)
self.schema = schema
return schema
except KeyError:
pass
- meth = self._get_method("make", cls)
- if meth:
+ if meth := self._get_method("make", cls):
self._makers[cls] = meth
return meth
else:
parts = name.split(".")
for i in range(len(parts) - 1, -1, -1):
- mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
- meth = getattr(self, mname, None)
- if meth:
+ if meth := getattr(self, f"{prefix}_{'_'.join(parts[-(i + 1):])}", None):
return meth
return None
def example(self, spec):
# A good representative of the object - no degenerate case
cls = spec if isinstance(spec, type) else spec[0]
- meth = self._get_method("example", cls)
- if meth:
+ if meth := self._get_method("example", cls):
return meth(spec)
else:
return self.make(spec)
def match_float(self, spec, got, want, rel=None):
if got is not None and isnan(got):
assert isnan(want)
+ elif rel or self._server_rounds():
+ assert got == pytest.approx(want, rel=rel)
else:
- if rel or self._server_rounds():
- assert got == pytest.approx(want, rel=rel)
- else:
- assert got == want
+ assert got == want
def _server_rounds(self):
"""Return True if the connected server perform float rounding"""
def schema_list(self, cls):
while True:
- scls = choice(self.types)
- if scls is cls:
- continue
- if scls is float:
- # TODO: float lists are currently adapted as decimal.
- # There may be rounding errors or problems with inf.
+ # Skip picking the same type; also skip float:
+ # TODO: float lists are currently adapted as decimal.
+ # There may be rounding errors or problems with inf.
+ if (scls := choice(self.types)) is cls or scls is float:
continue
# CRDB doesn't support arrays of json
# https://github.com/cockroachdb/cockroach/issues/23468
if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb):
continue
-
- schema = self.make_schema(scls)
- if schema is not None:
+ if (schema := self.make_schema(scls)) is not None:
break
return (cls, schema)
out: list[Range[Any]] = []
for i in range(length):
- r = self.make_Range((Range, spec[1]), **kwargs)
- if r.isempty:
+ if (r := self.make_Range((Range, spec[1]), **kwargs)).isempty:
continue
for r2 in out:
if overlap(r, r2):
rec_types = [list, dict]
scal_types = [type(None), int, JsonFloat, bool, str]
if random() < container_chance:
- cls = choice(rec_types)
- if cls is list:
+ if (cls := choice(rec_types)) is list:
return [
self._make_json(container_chance=container_chance / 2.0)
for i in range(randrange(self.json_max_length))
def pytest_runtest_setup(item):
for m in item.iter_markers(name="libpq"):
assert len(m.args) == 1
- msg = check_libpq_version(pq.version(), m.args[0])
- if msg:
+ if msg := check_libpq_version(pq.version(), m.args[0]):
pytest.skip(msg)
@pytest.fixture
def trace(libpq):
- pqver = pq.__build_version__
- if pqver < 140000:
+ if (pqver := pq.__build_version__) < 140000:
pytest.skip(f"trace not available on libpq {pqver}")
if sys.platform != "linux":
pytest.skip(f"trace not available on {sys.platform}")
return
logging.info("starting proxy")
- pproxy = which("pproxy")
- if not pproxy:
+ if not (pproxy := which("pproxy")):
raise ValueError("pproxy program not found")
cmdline = [pproxy, "--reuse"]
cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
# send loop
waited_on_send = 0
while True:
- f = pgconn.flush()
- if f == 0:
+ if pgconn.flush() == 0:
break
waited_on_send += 1
if pgconn.is_busy():
select([pgconn.socket], [], [])
continue
- res = pgconn.get_result()
- if res is None:
+ if (res := pgconn.get_result()) is None:
break
assert res.status == pq.ExecStatus.TUPLES_OK
results.append(res)
poll = getattr(conn, poll_method)
while True:
assert conn.status != pq.ConnStatus.BAD, conn.error_message
- rv = poll()
- if rv == return_on:
+ if (rv := poll()) == return_on:
return
elif rv == pq.PollingStatus.READING:
select([conn.socket], [], [], timeout)
# so it may be that has_password is false but still a password was
# requested by the server and passed by libpq.
info = pq.Conninfo.parse(dsn.encode())
- has_password = (
- "PGPASSWORD" in os.environ
- or [i for i in info if i.keyword == b"password"][0].val is not None
- )
- if has_password:
+ if "PGPASSWORD" in os.environ:
+ assert pgconn.used_password
+ if [i for i in info if i.keyword == b"password"][0].val is not None:
assert pgconn.used_password
pgconn.finish()
default=logging.INFO,
)
- args = parser.parse_args()
-
- if args.writer:
+ if (args := parser.parse_args()).writer:
try:
getattr(psycopg.copy, args.writer)
except AttributeError:
self._logger.info("sent prepared '%s' with %s", name.decode(), param_values)
def send_prepare(
- self,
- name: bytes,
- command: bytes,
- param_types: Sequence[int] | None = None,
+ self, name: bytes, command: bytes, param_types: Sequence[int] | None = None
) -> None:
self._pgconn.send_prepare(name, command, param_types)
self._logger.info("prepare %s as '%s'", command.decode(), name.decode())
def get_result(self) -> pq.abc.PGresult | None:
- r = self._pgconn.get_result()
- if r is not None:
+ if (r := self._pgconn.get_result()) is not None:
self._logger.info("got %s result", pq.ExecStatus(r.status).name)
return r
):
while results_queue:
fetched = waiting.wait(
- pipeline_communicate(pgconn, commands),
- pgconn.socket,
+ pipeline_communicate(pgconn, commands), pgconn.socket
)
assert not commands, commands
for results in fetched:
):
while results_queue:
fetched = await waiting.wait_async(
- pipeline_communicate(pgconn, commands),
- pgconn.socket,
+ pipeline_communicate(pgconn, commands), pgconn.socket
)
assert not commands, commands
for results in fetched:
def main() -> None:
- opt = parse_cmdline()
- if opt.loglevel:
+ if (opt := parse_cmdline()).loglevel:
loglevel = getattr(logging, opt.loglevel.upper())
logging.basicConfig(
level=loglevel, format="%(asctime)s %(levelname)s %(message)s"
t0 = time.time()
conn = super().connect(conninfo, **kwargs)
t1 = time.time()
- wait = max(0.0, conn_delay - (t1 - t0))
- if wait:
+ if wait := max(0.0, conn_delay - (t1 - t0)):
time.sleep(wait)
return conn
assert dumper.quote(data) == result
-@pytest.mark.parametrize(
- "data, result",
- [
- ("hello", b"'hello'"),
- ("", b"NULL"),
- ],
-)
+@pytest.mark.parametrize("data, result", [("hello", b"'hello'"), ("", b"NULL")])
def test_quote_none(data, result, global_adapters):
psycopg.adapters.register_dumper(str, StrNoneDumper)
t = Transformer()
for n in dir(_psycopg):
if n.startswith("_") or n in ("CDumper", "CLoader"):
continue
- obj = getattr(_psycopg, n)
- if not isinstance(obj, type):
+ if not isinstance((obj := getattr(_psycopg, n)), type):
continue
if not issubclass(obj, (_psycopg.CDumper, _psycopg.CLoader)):
continue
# Check that every optimised adapter is the optimised version of a Py one
for n in dir(psycopg.types):
- mod = getattr(psycopg.types, n)
- if not isinstance(mod, ModuleType):
+ if not isinstance((mod := getattr(psycopg.types, n)), ModuleType):
continue
for n1 in dir(mod):
- obj = getattr(mod, n1)
- if not isinstance(obj, type):
+ if not isinstance((obj := getattr(mod, n1)), type):
continue
if not issubclass(obj, (Dumper, Loader)):
continue
def run_process():
nonlocal proc
proc = sp.Popen(
- [sys.executable, "-s", "-c", script],
- creationflags=creationflags,
+ [sys.executable, "-s", "-c", script], creationflags=creationflags
)
proc.communicate()
cur = conn.execute(
"select pid from pg_stat_activity where application_name = %s", (APPNAME,)
)
- rec = cur.fetchone()
- if rec:
+ if rec := cur.fetchone():
pid = rec[0]
break
time.sleep(0.1)
cur = conn.execute(
"select pid from pg_stat_activity where application_name = %s", (APPNAME,)
)
- rec = cur.fetchone()
- if rec:
+ if rec := cur.fetchone():
pid = rec[0]
break
time.sleep(0.1)
with cur.copy(query) as copy:
copy.set_types(["text"])
while True:
- row = copy.read_row()
- if not row:
+ if not (row := copy.read_row()):
break
assert len(row) == 1
rows.append(row[0])
with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
rows = []
while True:
- row = copy.read_row()
- if not row:
+ if not (row := copy.read_row()):
break
rows.append(row)
if method == "read":
while True:
- tmp = copy.read()
- if not tmp:
+ if not copy.read():
break
elif method == "iter":
list(copy)
elif method == "row":
while True:
- tmp = copy.read_row()
- if tmp is None:
+ if copy.read_row() is None:
break
elif method == "rows":
list(copy.rows())
def blocks(self):
f = self.file()
while True:
- block = f.read(self.block_size)
- if not block:
+ if not (block := f.read(self.block_size)):
break
yield block
def sha(self, f):
m = hashlib.sha256()
while True:
- block = f.read()
- if not block:
+ if not (block := f.read()):
break
if isinstance(block, str):
block = block.encode()
async with cur.copy(query) as copy:
copy.set_types(["text"])
while True:
- row = await copy.read_row()
- if not row:
+ if not (row := (await copy.read_row())):
break
assert len(row) == 1
rows.append(row[0])
) as copy:
rows = []
while True:
- row = await copy.read_row()
- if not row:
+ if not (row := (await copy.read_row())):
break
rows.append(row)
if method == "read":
while True:
- tmp = await copy.read()
- if not tmp:
+ if not (await copy.read()):
break
elif method == "iter":
await alist(copy)
elif method == "row":
while True:
- tmp = await copy.read_row()
- if tmp is None:
+ if (await copy.read_row()) is None:
break
elif method == "rows":
await alist(copy.rows())
def blocks(self):
f = self.file()
while True:
- block = f.read(self.block_size)
- if not block:
+ if not (block := f.read(self.block_size)):
break
yield block
def sha(self, f):
m = hashlib.sha256()
while True:
- block = f.read()
- if not block:
+ if not (block := f.read()):
break
if isinstance(block, str):
block = block.encode()
if fetch == "one":
while True:
- tmp = cur.fetchone()
- if tmp is None:
+ if cur.fetchone() is None:
break
elif fetch == "many":
while True:
- tmp = cur.fetchmany(3)
- if not tmp:
+ if not cur.fetchmany(3):
break
elif fetch == "all":
cur.fetchall()
if fetch == "one":
while True:
- tmp = await cur.fetchone()
- if tmp is None:
+ if (await cur.fetchone()) is None:
break
elif fetch == "many":
while True:
- tmp = await cur.fetchmany(3)
- if not tmp:
+ if not (await cur.fetchmany(3)):
break
elif fetch == "all":
await cur.fetchall()
if fetch == "one":
while True:
- tmp = cur.fetchone()
- if tmp is None:
+ if cur.fetchone() is None:
break
elif fetch == "many":
while True:
- tmp = cur.fetchmany(3)
- if not tmp:
+ if not cur.fetchmany(3):
break
elif fetch == "all":
cur.fetchall()
if fetch == "one":
while True:
- tmp = await cur.fetchone()
- if tmp is None:
+ if (await cur.fetchone()) is None:
break
elif fetch == "many":
while True:
- tmp = await cur.fetchmany(3)
- if not tmp:
+ if not (await cur.fetchmany(3)):
break
elif fetch == "all":
await cur.fetchall()
if fetch == "one":
while True:
- tmp = cur.fetchone()
- if tmp is None:
+ if cur.fetchone() is None:
break
elif fetch == "many":
while True:
- tmp = cur.fetchmany(3)
- if not tmp:
+ if not cur.fetchmany(3):
break
elif fetch == "all":
cur.fetchall()
if fetch == "one":
while True:
- tmp = await cur.fetchone()
- if tmp is None:
+ if (await cur.fetchone()) is None:
break
elif fetch == "many":
while True:
- tmp = await cur.fetchmany(3)
- if not tmp:
+ if not (await cur.fetchmany(3)):
break
elif fetch == "all":
await cur.fetchall()
recs = cur.fetchall()
cur.scroll(0, "absolute")
while True:
- rec = cur.fetchone()
- if not rec:
+ if not (rec := cur.fetchone()):
break
recs.append(rec)
assert recs == [[1, -1], [1, -2], [1, -3]] * 2
recs = await cur.fetchall()
await cur.scroll(0, "absolute")
while True:
- rec = await cur.fetchone()
- if not rec:
+ if not (rec := (await cur.fetchone())):
break
recs.append(rec)
assert recs == [[1, -1], [1, -2], [1, -3]] * 2
OperationalError has a pgconn attribute set with needs_password.
"""
gen = generators.connect(dsn)
- pgconn = waiting.wait_conn(gen)
- if not pgconn.used_password:
+ if not (pgconn := waiting.wait_conn(gen)).used_password:
pytest.skip("test connection needs no password")
with monkeypatch.context() as m:
@pytest.mark.parametrize(
- "factory",
- "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
+ "factory", "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split()
)
def test_no_result(factory, conn):
cur = conn.cursor(row_factory=factory_from_name(factory))
def factory_from_name(name):
- factory = getattr(rows, name)
- if factory is rows.class_row:
+ if (factory := getattr(rows, name)) is rows.class_row:
factory = factory(Person)
if factory is rows.args_row:
factory = factory(argf)
def test_tpc_disabled(conn, pipeline):
cur = conn.execute("show max_prepared_transactions")
- val = int(cur.fetchone()[0])
- if val:
+ if int(cur.fetchone()[0]):
pytest.skip("prepared transactions enabled")
conn.rollback()
async def test_tpc_disabled(aconn, apipeline):
cur = await aconn.execute("show max_prepared_transactions")
- val = int((await cur.fetchone())[0])
- if val:
+ if int((await cur.fetchone())[0]):
pytest.skip("prepared transactions enabled")
await aconn.rollback()
conn.execute("select set_config('client_encoding', %s, false)", [encoding])
if status:
- status = getattr(TransactionStatus, status)
- if status == TransactionStatus.INTRANS:
+ if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
conn.execute("select 1")
else:
conn.autocommit = True
)
if status:
- status = getattr(TransactionStatus, status)
- if status == TransactionStatus.INTRANS:
+ if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
await aconn.execute("select 1")
else:
await aconn.set_autocommit(True)
return exit_orig(self, exc_type, exc_val, exc_tb)
monkeypatch.setattr(psycopg.Transaction, "__exit__", exit)
- status = getattr(TransactionStatus, status)
- if status == TransactionStatus.INTRANS:
+ if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
conn.execute("select 1")
assert conn.info.transaction_status == status
return await exit_orig(self, exc_type, exc_val, exc_tb)
monkeypatch.setattr(psycopg.AsyncTransaction, "__aexit__", aexit)
- status = getattr(TransactionStatus, status)
- if status == TransactionStatus.INTRANS:
+ if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
await aconn.execute("select 1")
assert aconn.info.transaction_status == status
@pytest.mark.parametrize("fmt_in", PyFormat)
@pytest.mark.parametrize("fmt_out", pq.Format)
def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
- wrapper = getattr(psycopg.types.numeric, wrapper)
- if wrapper is Decimal:
+ if (wrapper := getattr(psycopg.types.numeric, wrapper)) is Decimal:
want_cls = Decimal
else:
assert wrapper.__mro__[1] in (int, float)
@pytest.mark.pg(">= 14")
-@pytest.mark.parametrize(
- "val, expr",
- [
- ("inf", "Infinity"),
- ("-inf", "-Infinity"),
- ],
-)
+@pytest.mark.parametrize("val, expr", [("inf", "Infinity"), ("-inf", "-Infinity")])
def test_dump_numeric_binary_inf(conn, val, expr):
cur = conn.cursor()
val = Decimal(val)
def test_load_numeric_binary(conn, expr):
cur = conn.cursor(binary=1)
res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
- val = Decimal(expr)
- if val.is_nan():
+ if (val := Decimal(expr)).is_nan():
assert res.is_nan()
else:
assert res == val
@pytest.mark.pg(">= 14")
-@pytest.mark.parametrize(
- "val, expr",
- [
- ("inf", "Infinity"),
- ("-inf", "-Infinity"),
- ],
-)
+@pytest.mark.parametrize("val, expr", [("inf", "Infinity"), ("-inf", "-Infinity")])
def test_load_numeric_binary_inf(conn, val, expr):
cur = conn.cursor(binary=1)
res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
@pytest.mark.parametrize(
- "val",
- [
- "0",
- "0.0",
- "0.000000000000000000001",
- "-0.000000000000000000001",
- "nan",
- ],
+ "val", ["0", "0.0", "0.000000000000000000001", "-0.000000000000000000001", "nan"]
)
def test_numeric_as_float(conn, val):
cur = conn.cursor()
def _get_arch_size() -> int:
- psize = struct.calcsize("P") * 8
- if psize not in (32, 64):
+ if (psize := (struct.calcsize("P") * 8)) not in (32, 64):
msg = f"the pointer size {psize} is unusual"
raise ValueError(msg)
return psize
fnames = [f"f{t}" for t in pgtypes]
fields = [f"f{t} {t}" for fname, t in zip(fnames, pgtypes)]
- cur.execute(
- f"create table numpyoid (id serial primary key, {', '.join(fields)})",
- )
+ cur.execute(f"create table numpyoid (id serial primary key, {', '.join(fields)})")
with cur.copy(
f"copy numpyoid ({', '.join(fnames)}) from stdin (format {fmt.name})"
) as copy:
def main() -> int:
- opt = parse_cmdline()
- if opt.container:
+ if (opt := parse_cmdline()).container:
return run_in_container(opt.container)
logging.basicConfig(level=opt.log_level, format="%(levelname)s %(message)s")
- current_ver = ".".join(map(str, sys.version_info[:2]))
- if current_ver != PYVER:
+ if (current_ver := ".".join(map(str, sys.version_info[:2]))) != PYVER:
logger.warning(
"Expecting output generated by Python %s; you are running %s instead.",
PYVER,
new_body.append(before)
for i in range(1, len(body)):
after = body[i]
- nblanks = after.lineno - before.end_lineno - 1
- if nblanks > 0:
+ if after.lineno - before.end_lineno > 1:
# Inserting one blank is enough.
blank = ast.Comment(
value="",
help="the files to process (process all files if not specified)",
)
- opt = parser.parse_args()
- if not opt.inputs:
+ if not (opt := parser.parse_args()).inputs:
opt.inputs = [PROJECT_DIR / Path(fn) for fn in ALL_INPUTS]
fp: Path
curdir = Path(__file__).parent
pdir = curdir / "../.."
-target = pdir / "psycopg_binary"
-if target.exists():
+if (target := (pdir / "psycopg_binary")).exists():
raise Exception(f"path {target} already exists")
def sed_i(pattern: str, repl: str, filename: str | Path) -> None:
with open(filename, "rb") as f:
data = f.read()
- newdata = re.sub(pattern.encode("utf8"), repl.encode("utf8"), data)
- if newdata != data:
+ if (newdata := re.sub(pattern.encode("utf8"), repl.encode("utf8"), data)) != data:
with open(filename, "wb") as f:
f.write(newdata)
matches = []
with fp.open() as f:
for line in f:
- m = self._ini_regex.match(line)
- if m:
+ if m := self._ini_regex.match(line):
matches.append(m)
if not matches:
with fp.open() as f:
lines = f.readlines()
- lns = self._find_lines(r"^[^\s]+ " + re.escape(str(version)), lines)
- if not lns:
+ if not (lns := self._find_lines(r"^[^\s]+ " + re.escape(str(version)), lines)):
logger.warning("no change log line found")
return []
"name": data["name"],
}
if data["blog"]:
- website = data["blog"]
- if not website.startswith("http"):
+ if not (website := data["blog"]).startswith("http"):
website = "http://" + website
out["website"] = website
- # entry is an username or an user entry daat
+ # entry is a username or a user entry data
if isinstance(entry, str):
username = entry
- entry = [e for e in filedata if e["username"] == username]
- if not entry:
+ if not (entry := [e for e in filedata if e["username"] == username]):
raise Exception(f"{username} not found")
entry = entry[0]
else:
#!/usr/bin/env python
-"""Find the error prefixes in various l10n used for precise prefixstripping."""
+"""Find the error prefixes in various l10n used for precise prefix stripping."""
import re
import logging
help="the file to change [default: %(default)s]",
)
- opt = parser.parse_args()
- if not opt.pgroot.is_dir():
+ if not (opt := parser.parse_args()).pgroot.is_dir():
- parser.error("not a valid directory: {opt.pgroot}")
+ parser.error(f"not a valid directory: {opt.pgroot}")
return opt
page = urlopen(url)
for line in page.read().decode("ascii").splitlines():
# Strip comments and skip blanks
- line = line.split("#")[0].strip()
- if not line:
+ if not (line := line.split("#")[0].strip()):
continue
# Parse a section
- m = re.match(r"Section: (Class (..) - .+)", line)
- if m:
+ if m := re.match(r"Section: (Class (..) - .+)", line):
label, class_ = m.groups()
classes[class_] = label
continue
# Parse an error
- m = re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line)
- if m:
+ if m := re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line):
sqlstate, macro, spec = m.groups()
# skip sqlstates without specs as they are not publicly visible
if not spec:
def main() -> None:
opt = parse_cmdline()
- conn = psycopg.connect(opt.dsn, autocommit=True)
- if CrdbConnection.is_crdb(conn):
+ if CrdbConnection.is_crdb((conn := psycopg.connect(opt.dsn, autocommit=True))):
conn = CrdbConnection.connect(opt.dsn, autocommit=True)
update_crdb_python_oids(conn)
else: