From: Daniele Varrazzo
Date: Thu, 27 Mar 2025 00:25:49 +0000 (+0100)
Subject: refactor: use the walrus operator in assignments followed by an if
X-Git-Tag: 3.2.7~8^2~1
X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5209f5e717b070fb665317d1e79cb70626f4550a;p=thirdparty%2Fpsycopg.git

refactor: use the walrus operator in assignments followed by an if
---
diff --git a/docs/conf.py b/docs/conf.py
index cac973651..cf87980d3 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -22,7 +22,6 @@ import psycopg docs_dir = Path(__file__).parent sys.path.append(str(docs_dir / "lib")) - # -- Project information ----------------------------------------------------- project = "psycopg"
@@ -55,12 +54,10 @@ templates_path = ["_templates"] # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"] - # -- Options for HTML output ------------------------------------------------- # The announcement may be in the website but not shipped with the docs -ann_file = docs_dir / "../../templates/docs3-announcement.html" -if ann_file.exists(): +if (ann_file := (docs_dir / "../../templates/docs3-announcement.html")).exists(): with ann_file.open() as f: announcement = f.read() else:
@@ -84,9 +81,7 @@ html_theme_options = { "sidebar_hide_name": False, "light_logo": "psycopg.svg", "dark_logo": "psycopg.svg", - "light_css_variables": { - "admonition-font-size": "1rem", - }, + "light_css_variables": {"admonition-font-size": "1rem"}, } # Add any paths that contain custom static files (such as style sheets) here,
diff --git a/docs/lib/libpq_docs.py b/docs/lib/libpq_docs.py
index a5488d151..8bd098e82 100644
--- a/docs/lib/libpq_docs.py
+++ b/docs/lib/libpq_docs.py
@@ -61,13 +61,11 @@ class LibpqParser(HTMLParser): self.add_function(data) def handle_sect1(self, tag, attrs): - attrs = dict(attrs) - if "id" in attrs: + if "id" in (attrs := dict(attrs)): self.section_id = attrs["id"] def handle_varlistentry(self, tag, attrs): - attrs = dict(attrs) - if "id" in attrs: + if "id" in (attrs := dict(attrs)): self.varlist_id = attrs["id"] def add_function(self, func_name):
@@ -89,10 +87,8 @@ class LibpqReader: # must be set before using the rest of the class.
app = None - _url_pattern = ( - "https://raw.githubusercontent.com/postgres/postgres/{branch}" - "/doc/src/sgml/libpq.sgml" - ) + _url_pattern = "https://raw.githubusercontent.com/postgres/postgres/" + _url_pattern += "{branch}/doc/src/sgml/libpq.sgml" data = None
@@ -113,8 +109,7 @@ class LibpqReader: parser.feed(f.read()) def download(self): - filename = os.environ.get("LIBPQ_DOCS_FILE") - if filename: + if filename := os.environ.get("LIBPQ_DOCS_FILE"): logger.info("reading postgres libpq docs from %s", filename) with open(filename, "rb") as f: data = f.read()
@@ -156,7 +151,6 @@ def pq_role(name, rawtext, text, lineno, inliner, options={}, content=[]): if "(" in text: func, noise = text.split("(", 1) noise = "(" + noise - else: func = text noise = ""
diff --git a/docs/lib/pg3_docs.py b/docs/lib/pg3_docs.py
index 4388cc9d4..8bd771f2d 100644
--- a/docs/lib/pg3_docs.py
+++ b/docs/lib/pg3_docs.py
@@ -19,8 +19,7 @@ def process_docstring(app, what, name, obj, options, lines): def before_process_signature(app, obj, bound_method): - ann = getattr(obj, "__annotations__", {}) - if "return" in ann: + if "return" in (ann := getattr(obj, "__annotations__", {})): # Drop "return: None" from the function signatures if ann["return"] is None: del ann["return"]
@@ -68,14 +67,12 @@ def recover_defined_module(m, skip_modules=()): for fn in walk_modules(mdir): assert fn.startswith(mdir) modname = os.path.splitext(fn[len(mdir) + 1 :])[0].replace("/", ".") - modname = f"{m.__name__}.{modname}" - if modname in skip_modules: + if (modname := f"{m.__name__}.{modname}") in skip_modules: continue with open(fn) as f: classnames = re.findall(r"^class\s+([^(:]+)", f.read(), re.M) for cls in classnames: - cls = deep_import(f"{modname}.{cls}") - if cls.__module__ != modname: + if (cls := deep_import(f"{modname}.{cls}")).__module__ != modname: recovered_classes[cls] = modname
diff --git a/docs/lib/ticket_role.py b/docs/lib/ticket_role.py
index f8f935bf5..107e67d6d 100644
--- a/docs/lib/ticket_role.py
+++ b/docs/lib/ticket_role.py
@@ -16,8 +16,7 @@ from docutils.parsers.rst import roles def ticket_role(name, rawtext, text, lineno, inliner, options={}, content=[]): - cfg = inliner.document.settings.env.app.config - if cfg.ticket_url is None: + if (cfg := inliner.document.settings.env.app.config).ticket_url is None: msg = inliner.reporter.warning( "ticket not configured: please configure ticket_url in conf.py" )
diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py
index cd1ad261e..581d9b23b 100644
--- a/psycopg/psycopg/__init__.py
+++ b/psycopg/psycopg/__init__.py
@@ -32,8 +32,7 @@ from ._connection_info import ConnectionInfo from .connection_async import AsyncConnection # Set the logger to a quiet default, can be enabled if needed -logger = logging.getLogger("psycopg") -if logger.level == logging.NOTSET: +if (logger := logging.getLogger("psycopg")).level == logging.NOTSET: logger.setLevel(logging.WARNING) # DBAPI compliance
diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py
index d8fedfa12..f3142c02a 100644
--- a/psycopg/psycopg/_adapters_map.py
+++ b/psycopg/psycopg/_adapters_map.py
@@ -69,9 +69,7 @@ class AdaptersMap: _optimised: dict[type, type] = {} def __init__( - self, - template: AdaptersMap | None = None, - types: TypesRegistry | None = None, + self, template: AdaptersMap | None = None, types: TypesRegistry | None = None ): if template: self._dumpers = template._dumpers.copy()
@@ -179,8 +177,7 @@ class AdaptersMap: if _psycopg: loader = self._get_optimised(loader) -
fmt = loader.format - if not self._own_loaders[fmt]: + if not self._own_loaders[(fmt := loader.format)]: self._loaders[fmt] = self._loaders[fmt].copy() self._own_loaders[fmt] = True
@@ -282,8 +279,7 @@ class AdaptersMap: from psycopg import types if cls.__module__.startswith(types.__name__): - new = cast("type[RV]", getattr(_psycopg, cls.__name__, None)) - if new: + if new := cast("type[RV]", getattr(_psycopg, cls.__name__, None)): self._optimised[cls] = new return new
diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py
index 372775cf8..8a7c806ca 100644
--- a/psycopg/psycopg/_column.py
+++ b/psycopg/psycopg/_column.py
@@ -20,8 +20,7 @@ class Column(Sequence[Any]): res = cursor.pgresult assert res - fname = res.fname(index) - if fname: + if fname := res.fname(index): self._name = fname.decode(cursor._encoding) else: # COPY_OUT results have columns but no name
diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py
index 65440ce2f..7ac6af1a5 100644
--- a/psycopg/psycopg/_connection_base.py
+++ b/psycopg/psycopg/_connection_base.py
@@ -247,8 +247,7 @@ class BaseConnection(Generic[Row]): def _check_intrans_gen(self, attribute: str) -> PQGen[None]: # Raise an exception if we are in a transaction - status = self.pgconn.transaction_status - if status == IDLE and self._pipeline: + if (status := self.pgconn.transaction_status) == IDLE and self._pipeline: yield from self._pipeline._sync_gen() status = self.pgconn.transaction_status if status != IDLE:
@@ -503,10 +502,7 @@ class BaseConnection(Generic[Row]): self._check_connection_ok() if self._pipeline: - cmd = partial( - self.pgconn.send_close_prepared, - name, - ) + cmd = partial(self.pgconn.send_close_prepared, name) self._pipeline.command_queue.append(cmd) self._pipeline.result_queue.append(None) return
diff --git a/psycopg/psycopg/_conninfo_attempts.py b/psycopg/psycopg/_conninfo_attempts.py
index 7bc96dd99..f853e0537 100644
--- a/psycopg/psycopg/_conninfo_attempts.py
+++ b/psycopg/psycopg/_conninfo_attempts.py
@@ -85,8 +85,7 @@ def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: # Local path, or no host to resolve return [params] - hostaddr = get_param(params, "hostaddr") - if hostaddr: + if get_param(params, "hostaddr"): # Already resolved return [params]
@@ -94,8 +93,7 @@ def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: # If the host is already an ip address don't try to resolve it return [{**params, "hostaddr": host}] - port = get_param(params, "port") - if not port: + if not (port := get_param(params, "port")): port_def = get_param_def("port") port = port_def and port_def.compiled or "5432"
diff --git a/psycopg/psycopg/_conninfo_attempts_async.py b/psycopg/psycopg/_conninfo_attempts_async.py
index e50e4f95e..08f363d1a 100644
--- a/psycopg/psycopg/_conninfo_attempts_async.py
+++ b/psycopg/psycopg/_conninfo_attempts_async.py
@@ -83,8 +83,7 @@ async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: # Local path, or no host to resolve return [params] - hostaddr = get_param(params, "hostaddr") - if hostaddr: + if get_param(params, "hostaddr"): # Already resolved return [params]
@@ -92,8 +91,7 @@ async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: # If the host is already an ip address don't try to resolve it return [{**params, "hostaddr": host}] - port = get_param(params, "port") - if not port: + if not (port := get_param(params, "port")): port_def = get_param_def("port") port = port_def and port_def.compiled or "5432"
diff --git a/psycopg/psycopg/_conninfo_utils.py b/psycopg/psycopg/_conninfo_utils.py
index 844e71abd..e959bbf15 100644
--- a/psycopg/psycopg/_conninfo_utils.py
+++ b/psycopg/psycopg/_conninfo_utils.py
@@ -75,12 +75,10 @@ def get_param(params: ConnMapping, name: str) -> str | None: # TODO: check if in service - paramdef = get_param_def(name) - if not paramdef: + if not (paramdef := get_param_def(name)): return None - env = os.environ.get(paramdef.envvar) - if env is not None: + if (env := os.environ.get(paramdef.envvar)) is not None: return env return None
diff --git a/psycopg/psycopg/_copy.py b/psycopg/psycopg/_copy.py
index 32f8fc857..b218ecd21 100644
--- a/psycopg/psycopg/_copy.py
+++ b/psycopg/psycopg/_copy.py
@@ -81,8 +81,7 @@ class Copy(BaseCopy["Connection[Any]"]): def __iter__(self) -> Iterator[Buffer]: """Implement block-by-block iteration on :sql:`COPY TO`.""" while True: - data = self.read() - if not data: + if not (data := self.read()): break yield data
@@ -102,8 +101,7 @@ class Copy(BaseCopy["Connection[Any]"]): bytes, unless data types are specified using `set_types()`. """ while True: - record = self.read_row() - if record is None: + if (record := self.read_row()) is None: break yield record
@@ -125,14 +123,12 @@ class Copy(BaseCopy["Connection[Any]"]): If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In text mode it can be either `!bytes` or `!str`. """ - data = self.formatter.write(buffer) - if data: + if data := self.formatter.write(buffer): self._write(data) def write_row(self, row: Sequence[Any]) -> None: """Write a record to a table after a :sql:`COPY FROM` operation.""" - data = self.formatter.write_row(row) - if data: + if data := self.formatter.write_row(row): self._write(data) def finish(self, exc: BaseException | None) -> None:
@@ -143,8 +139,7 @@ class Copy(BaseCopy["Connection[Any]"]): using the `Copy` object outside a block. """ if self._direction == COPY_IN: - data = self.formatter.end() - if data: + if data := self.formatter.end(): self._write(data) self.writer.finish(exc) self._finished = True
@@ -257,8 +252,7 @@ class QueuedLibpqWriter(LibpqWriter): """ try: while True: - data = self._queue.get() - if not data: + if not (data := self._queue.get()): break self.connection.wait(copy_to(self._pgconn, data, flush=PREFER_FLUSH)) except BaseException as ex:
diff --git a/psycopg/psycopg/_copy_async.py b/psycopg/psycopg/_copy_async.py
index 22ef3b197..2d36353a4 100644
--- a/psycopg/psycopg/_copy_async.py
+++ b/psycopg/psycopg/_copy_async.py
@@ -78,8 +78,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): async def __aiter__(self) -> AsyncIterator[Buffer]: """Implement block-by-block iteration on :sql:`COPY TO`.""" while True: - data = await self.read() - if not data: + if not (data := (await self.read())): break yield data
@@ -99,8 +98,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): bytes, unless data types are specified using `set_types()`. """ while True: - record = await self.read_row() - if record is None: + if (record := (await self.read_row())) is None: break yield record
@@ -122,14 +120,12 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In text mode it can be either `!bytes` or `!str`.
""" - data = self.formatter.write(buffer) - if data: + if data := self.formatter.write(buffer): await self._write(data) async def write_row(self, row: Sequence[Any]) -> None: """Write a record to a table after a :sql:`COPY FROM` operation.""" - data = self.formatter.write_row(row) - if data: + if data := self.formatter.write_row(row): await self._write(data) async def finish(self, exc: BaseException | None) -> None: @@ -140,8 +136,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): using the `Copy` object outside a block. """ if self._direction == COPY_IN: - data = self.formatter.end() - if data: + if data := self.formatter.end(): await self._write(data) await self.writer.finish(exc) self._finished = True @@ -256,8 +251,7 @@ class AsyncQueuedLibpqWriter(AsyncLibpqWriter): """ try: while True: - data = await self._queue.get() - if not data: + if not (data := (await self._queue.get())): break await self.connection.wait( copy_to(self._pgconn, data, flush=PREFER_FLUSH) diff --git a/psycopg/psycopg/_copy_base.py b/psycopg/psycopg/_copy_base.py index 6acc77719..17d740c1c 100644 --- a/psycopg/psycopg/_copy_base.py +++ b/psycopg/psycopg/_copy_base.py @@ -73,17 +73,13 @@ class BaseCopy(Generic[ConnectionType]): formatter: Formatter def __init__( - self, - cursor: BaseCursor[ConnectionType, Any], - *, - binary: bool | None = None, + self, cursor: BaseCursor[ConnectionType, Any], *, binary: bool | None = None ): self.cursor = cursor self.connection = cursor.connection self._pgconn = self.connection.pgconn - result = cursor.pgresult - if result: + if result := cursor.pgresult: self._direction = result.status if self._direction != COPY_IN and self._direction != COPY_OUT: raise e.ProgrammingError( @@ -162,12 +158,10 @@ class BaseCopy(Generic[ConnectionType]): return memoryview(b"") def _read_row_gen(self) -> PQGen[tuple[Any, ...] 
| None]: - data = yield from self._read_gen() - if not data: + if not (data := (yield from self._read_gen())): return None - row = self.formatter.parse_row(data) - if row is None: + if (row := self.formatter.parse_row(data)) is None: # Get the final result to finish the copy operation yield from self._read_gen() self._finished = True
diff --git a/psycopg/psycopg/_cursor_base.py b/psycopg/psycopg/_cursor_base.py
index fe71a1451..3a1d6d45b 100644
--- a/psycopg/psycopg/_cursor_base.py
+++ b/psycopg/psycopg/_cursor_base.py
@@ -395,8 +395,7 @@ class BaseCursor(Generic[ConnectionType, Row]): query = self._convert_query(statement) self._execute_send(query, binary=False) - results = yield from execute(self._pgconn) - if len(results) != 1: + if len(results := (yield from execute(self._pgconn))) != 1: raise e.ProgrammingError("COPY cannot be mixed with other operations") self._check_copy_result(results[0])
@@ -473,8 +472,7 @@ class BaseCursor(Generic[ConnectionType, Row]): """ Raise an appropriate error message for an unexpected database result """ - status = result.status - if status == FATAL_ERROR: + if (status := result.status) == FATAL_ERROR: raise e.error_from_result(result, encoding=self._encoding) elif status == PIPELINE_ABORTED: raise e.PipelineAborted("pipeline aborted")
@@ -518,19 +516,17 @@ class BaseCursor(Generic[ConnectionType, Row]): # Received from execute() self._results[:] = results self._select_current_result(0) - + # Received from executemany() + elif self._execmany_returning: + first_batch = not self._results + self._results.extend(results) + if first_batch: + self._select_current_result(0) else: - # Received from executemany() - if self._execmany_returning: - first_batch = not self._results - self._results.extend(results) - if first_batch: - self._select_current_result(0) - else: - # In non-returning case, set rowcount to the cumulated number of - # rows of executed queries.
+ for res in results: + self._rowcount += res.command_tuples or 0 def _send_prepare(self, name: bytes, query: PostgresQuery) -> None: if self._conn._pipeline:
@@ -572,12 +568,11 @@ class BaseCursor(Generic[ConnectionType, Row]): def _check_result_for_fetch(self) -> None: if self.closed: raise e.InterfaceError("the cursor is closed") - res = self.pgresult - if not res: + + if not (res := self.pgresult): raise e.ProgrammingError("no result available") - status = res.status - if status == TUPLES_OK: + if (status := res.status) == TUPLES_OK: return elif status == FATAL_ERROR: raise e.error_from_result(res, encoding=self._encoding)
diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py
index b642b2bca..4d5094280 100644
--- a/psycopg/psycopg/_dns.py
+++ b/psycopg/psycopg/_dns.py
@@ -62,11 +62,9 @@ async def resolve_hostaddr_async(params: dict[str, Any]) -> dict[str, Any]: ports.append(str(attempt["port"])) out = params.copy() - shosts = ",".join(hosts) - if shosts: + if shosts := ",".join(hosts): out["host"] = shosts - shostaddrs = ",".join(hostaddrs) - if shostaddrs: + if shostaddrs := ",".join(hostaddrs): out["hostaddr"] = shostaddrs sports = ",".join(ports) if ports:
@@ -103,8 +101,7 @@ class Rfc2782Resolver: def resolve(self, params: dict[str, Any]) -> dict[str, Any]: """Update the parameters host and port after SRV lookup.""" - attempts = self._get_attempts(params) - if not attempts: + if not (attempts := self._get_attempts(params)): return params hps = []
@@ -118,8 +115,7 @@ class Rfc2782Resolver: async def resolve_async(self, params: dict[str, Any]) -> dict[str, Any]: """Update the parameters host and port after SRV lookup.""" - attempts = self._get_attempts(params) - if not attempts: + if not (attempts := self._get_attempts(params)): return params hps = []
@@ -144,9 +140,7 @@ class Rfc2782Resolver: host_arg: str = params.get("host", os.environ.get("PGHOST", "")) hosts_in = host_arg.split(",") port_arg: str = str(params.get("port", os.environ.get("PGPORT", ""))) - ports_in = port_arg.split(",") - - if len(ports_in) == 1: + if len((ports_in := port_arg.split(","))) == 1: # If only one port is specified, it applies to all the hosts. ports_in *= len(hosts_in) if len(ports_in) != len(hosts_in):
@@ -159,8 +153,7 @@ class Rfc2782Resolver: out = [] srv_found = False for host, port in zip(hosts_in, ports_in): - m = self.re_srv_rr.match(host) - if m or port.lower() == "srv": + if (m := self.re_srv_rr.match(host)) or port.lower() == "srv": srv_found = True target = m.group("target") if m else None hp = HostPort(host=host, port=port, totry=True, target=target)
diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py
index d1ef6dd2d..6bb859d8b 100644
--- a/psycopg/psycopg/_encodings.py
+++ b/psycopg/psycopg/_encodings.py
@@ -98,8 +98,7 @@ def conninfo_encoding(conninfo: str) -> str: from .conninfo import conninfo_to_dict params = conninfo_to_dict(conninfo) - pgenc = params.get("client_encoding") - if pgenc: + if pgenc := params.get("client_encoding"): try: return pg2pyenc(str(pgenc).encode()) except NotSupportedError:
diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py
index f4752829a..8ae21b66f 100644
--- a/psycopg/psycopg/_pipeline.py
+++ b/psycopg/psycopg/_pipeline.py
@@ -142,8 +142,7 @@ class BasePipeline: exception = None while self.result_queue: - results = yield from fetch_many(self.pgconn) - if not results: + if not (results := (yield from fetch_many(self.pgconn))): # No more results to fetch, but there may still be pending # commands.
break
diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py
index 10a487494..7a1b0f5d2 100644
--- a/psycopg/psycopg/_preparing.py
+++ b/psycopg/psycopg/_preparing.py
@@ -64,9 +64,7 @@ class PrepareManager: # The user doesn't want this query to be prepared return Prepare.NO, b"" - key = self.key(query) - name = self._names.get(key) - if name: + if name := self._names.get(key := self.key(query)): # The query was already prepared in this session return Prepare.YES, name
@@ -101,8 +99,7 @@ class PrepareManager: # We cannot prepare a multiple statement return False - status = results[0].status - if COMMAND_OK != status != TUPLES_OK: + if COMMAND_OK != results[0].status != TUPLES_OK: # We don't prepare failed queries or other weird results return False
@@ -133,8 +130,7 @@ class PrepareManager: if self.prepare_threshold is None: return None - key = self.key(query) - if key in self._counts: + if (key := self.key(query)) in self._counts: if prep is Prepare.SHOULD: del self._counts[key] self._names[key] = name
@@ -155,11 +151,7 @@ class PrepareManager: return key def validate( - self, - key: Key, - prep: Prepare, - name: bytes, - results: Sequence[PGresult], + self, key: Key, prep: Prepare, name: bytes, results: Sequence[PGresult] ) -> None: """Validate cached entry with 'key' by checking query 'results'.
diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py
index 0af6884cb..820c567d6 100644
--- a/psycopg/psycopg/_py_transformer.py
+++ b/psycopg/psycopg/_py_transformer.py
@@ -179,8 +179,7 @@ class Transformer(AdaptContext): # right size. if self._row_dumpers: for i in range(nparams): - param = params[i] - if param is not None: + if (param := params[i]) is not None: out[i] = self._row_dumpers[i].dump(param) return out
@@ -188,8 +187,7 @@ class Transformer(AdaptContext): pqformats = [TEXT] * nparams for i in range(nparams): - param = params[i] - if param is None: + if (param := params[i]) is None: continue dumper = self.get_dumper(param, formats[i]) out[i] = dumper.dump(param)
@@ -212,8 +210,7 @@ class Transformer(AdaptContext): try: type_sql = self._oid_types[oid] except KeyError: - ti = self.adapters.types.get(oid) - if ti: + if ti := self.adapters.types.get(oid): if oid < 8192: # builtin: prefer "timestamptz" to "timestamp with time zone" type_sql = ti.name.encode(self.encoding)
@@ -254,8 +251,7 @@ class Transformer(AdaptContext): cache[key] = dumper = dcls(key, self) # Check if the dumper requires an upgrade to handle this specific value - key1 = dumper.get_key(obj, format) - if key1 is key: + if (key1 := dumper.get_key(obj, format)) is key: return dumper # If it does, ask the dumper to create its own upgraded version
@@ -298,8 +294,7 @@ class Transformer(AdaptContext): return dumper def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]: - res = self._pgresult - if not res: + if not (res := self._pgresult): raise e.InterfaceError("result not set") if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
@@ -311,16 +306,14 @@ class Transformer(AdaptContext): for row in range(row0, row1): record: list[Any] = [None] * self._nfields for col in range(self._nfields): - val = res.get_value(row, col) - if val is not None: + if (val := res.get_value(row, col)) is not None: record[col] = self._row_loaders[col](val) records.append(make_row(record)) return records def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None: - res = self._pgresult - if not res: + if not (res := self._pgresult): return None if not 0
<= row < self._ntuples:
@@ -328,8 +321,7 @@ class Transformer(AdaptContext): record: list[Any] = [None] * self._nfields for col in range(self._nfields): - val = res.get_value(row, col) - if val is not None: + if (val := res.get_value(row, col)) is not None: record[col] = self._row_loaders[col](val) return make_row(record)
@@ -352,10 +344,8 @@ class Transformer(AdaptContext): except KeyError: pass - loader_cls = self._adapters.get_loader(oid, format) - if not loader_cls: - loader_cls = self._adapters.get_loader(INVALID_OID, format) - if not loader_cls: + if not (loader_cls := self._adapters.get_loader(oid, format)): + if not (loader_cls := self._adapters.get_loader(INVALID_OID, format)): raise e.InterfaceError("unknown oid loader not found") loader = self._loaders[format][oid] = loader_cls(oid, self) return loader
diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py
index 22e5b4bf6..98fa0a7ff 100644
--- a/psycopg/psycopg/_queries.py
+++ b/psycopg/psycopg/_queries.py
@@ -372,8 +372,7 @@ def _split_query( rv.append(QueryPart(pre, 0, PyFormat.AUTO)) break - ph = m.group(0) - if ph == b"%%": + if (ph := m.group(0)) == b"%%": # unescape '%%' to '%' if necessary, then merge the parts if collapse_double_percent: ph = b"%"
diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py
index e3719010c..8c730d644 100644
--- a/psycopg/psycopg/_tpc.py
+++ b/psycopg/psycopg/_tpc.py
@@ -52,8 +52,7 @@ class Xid: @classmethod def _parse_string(cls, s: str) -> Xid: - m = _re_xid.match(s) - if not m: + if not (m := _re_xid.match(s)): raise ValueError("bad Xid format") format_id = int(m.group(1))
diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py
index c9126364e..47d6835d9 100644
--- a/psycopg/psycopg/_typeinfo.py
+++ b/psycopg/psycopg/_typeinfo.py
@@ -171,8 +171,7 @@ ORDER BY t.oid @classmethod def _has_to_regtype_function(cls, conn: BaseConnection[Any]) -> bool: # to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2 - info = conn.info - if info.vendor == "PostgreSQL": + if (info := conn.info).vendor == "PostgreSQL": return info.server_version >= 90400 elif info.vendor == "CockroachDB": return info.server_version >= 220200
@@ -195,10 +194,8 @@ ORDER BY t.oid pass def get_type_display(self, oid: int | None = None, fmod: int | None = None) -> str: - parts = [] - parts.append(self.name) - mod = self.typemod.get_modifier(fmod) if fmod is not None else () - if mod: + parts = [self.name] + if mod := (self.typemod.get_modifier(fmod) if fmod is not None else ()): parts.append(f"({','.join(map(str, mod))})") if oid == self.array_oid:
diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py
index 2918f555b..58db36400 100644
--- a/psycopg/psycopg/adapt.py
+++ b/psycopg/psycopg/adapt.py
@@ -56,8 +56,7 @@ class Dumper(abc.Dumper, ABC): for most types and you won't likely have to implement this method in a subclass. """ - value = self.dump(obj) - if value is None: + if (value := self.dump(obj)) is None: return b"NULL" if self.connection:
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
index 70d77467d..bfff3fc1e 100644
--- a/psycopg/psycopg/connection.py
+++ b/psycopg/psycopg/connection.py
@@ -382,8 +382,7 @@ class Connection(BaseConnection[Row]): with self.lock: self._check_connection_ok() - pipeline = self._pipeline - if pipeline is None: + if (pipeline := self._pipeline) is None: # WARNING: reference loop, broken ahead.
pipeline = self._pipeline = Pipeline(self)
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
index 3c6615a24..9f4ecb464 100644
--- a/psycopg/psycopg/connection_async.py
+++ b/psycopg/psycopg/connection_async.py
@@ -407,8 +407,7 @@ class AsyncConnection(BaseConnection[Row]): async with self.lock: self._check_connection_ok() - pipeline = self._pipeline - if pipeline is None: + if (pipeline := self._pipeline) is None: # WARNING: reference loop, broken ahead. pipeline = self._pipeline = AsyncPipeline(self)
diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py
index e8c33876b..f72ec34e0 100644
--- a/psycopg/psycopg/conninfo.py
+++ b/psycopg/psycopg/conninfo.py
@@ -114,8 +114,7 @@ def _param_escape(s: str) -> str: if not s: return "''" - s = re_escape.sub(r"\\\1", s) - if re_space.search(s): + if re_space.search(s := re_escape.sub(r"\\\1", s)): s = "'" + s + "'" return s
diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py
index 60db9a876..411434a75 100644
--- a/psycopg/psycopg/crdb/connection.py
+++ b/psycopg/psycopg/crdb/connection.py
@@ -86,20 +86,17 @@ class CrdbConnectionInfo(ConnectionInfo): Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210). """ - sver = self.parameter_status("crdb_version") - if not sver: + if not (sver := self.parameter_status("crdb_version")): raise e.InternalError("'crdb_version' parameter status not set") - ver = self.parse_crdb_version(sver) - if ver is None: + if (ver := self.parse_crdb_version(sver)) is None: raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}") return ver @classmethod def parse_crdb_version(self, sver: str) -> int | None: - m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver) - if not m: + if not (m := re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)): return None return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
index 52451dc5f..ccde99863 100644
--- a/psycopg/psycopg/cursor.py
+++ b/psycopg/psycopg/cursor.py
@@ -108,8 +108,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): # If there is already a pipeline, ride it, in order to avoid # sending unnecessary Sync.
with self._conn.lock: - p = self._conn._pipeline - if p: + if p := self._conn._pipeline: self._conn.wait( self._executemany_gen_pipeline(query, params_seq, returning) )
@@ -153,8 +152,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): first = True while self._conn.wait(self._stream_fetchone_gen(first)): for pos in range(size): - rec = self._tx.load_row(pos, self._make_row) - if rec is None: + if (rec := self._tx.load_row(pos, self._make_row)) is None: break yield rec first = False
@@ -188,8 +186,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): """ self._fetch_pipeline() self._check_result_for_fetch() - record = self._tx.load_row(self._pos, self._make_row) - if record is not None: + if (record := self._tx.load_row(self._pos, self._make_row)) is not None: self._pos += 1 return record
@@ -234,8 +231,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]): return self._tx.load_row(pos, self._make_row) while True: - row = load(self._pos) - if row is None: + if (row := load(self._pos)) is None: break self._pos += 1 yield row
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
index 0296c5179..1edb726f1 100644
--- a/psycopg/psycopg/cursor_async.py
+++ b/psycopg/psycopg/cursor_async.py
@@ -98,11 +98,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): return self async def executemany( - self, - query: Query, - params_seq: Iterable[Params], - *, - returning: bool = False, + self, query: Query, params_seq: Iterable[Params], *, returning: bool = False ) -> None: """ Execute the same command with a sequence of input data.
@@ -112,8 +108,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): # If there is already a pipeline, ride it, in order to avoid # sending unnecessary Sync. async with self._conn.lock: - p = self._conn._pipeline - if p: + if p := self._conn._pipeline: await self._conn.wait( self._executemany_gen_pipeline(query, params_seq, returning) )
@@ -157,8 +152,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): first = True while await self._conn.wait(self._stream_fetchone_gen(first)): for pos in range(size): - rec = self._tx.load_row(pos, self._make_row) - if rec is None: + if (rec := self._tx.load_row(pos, self._make_row)) is None: break yield rec first = False
@@ -196,8 +190,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): """ await self._fetch_pipeline() self._check_result_for_fetch() - record = self._tx.load_row(self._pos, self._make_row) - if record is not None: + if (record := self._tx.load_row(self._pos, self._make_row)) is not None: self._pos += 1 return record
@@ -216,9 +209,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): if not size: size = self.arraysize records = self._tx.load_rows( - self._pos, - min(self._pos + size, self.pgresult.ntuples), - self._make_row, + self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row ) self._pos += len(records) return records
@@ -244,8 +235,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): return self._tx.load_row(pos, self._make_row) while True: - row = load(self._pos) - if row is None: + if (row := load(self._pos)) is None: break self._pos += 1 yield row
diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py
index cdcf8655f..a6f0d75c4 100644
--- a/psycopg/psycopg/dbapi20.py
+++ b/psycopg/psycopg/dbapi20.py
@@ -71,8 +71,7 @@ class Binary: self.obj = obj def __repr__(self) -> str: - sobj = repr(self.obj) - if len(sobj) > 40: + if len((sobj := repr(self.obj))) > 40: sobj = f"{sobj[:35]} ...
({len(sobj)} byteschars)" return f"{self.__class__.__name__}({sobj})"
diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py
index f1827be2b..f558cc6e3 100644
--- a/psycopg/psycopg/generators.py
+++ b/psycopg/psycopg/generators.py
@@ -70,8 +70,7 @@ def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]: if conn.status == BAD: encoding = conninfo_encoding(conninfo) raise e.OperationalError( - f"connection is bad: {conn.get_error_message(encoding)}", - pgconn=conn, + f"connection is bad: {conn.get_error_message(encoding)}", pgconn=conn ) status = conn.connect_poll()
@@ -107,8 +106,8 @@ def _cancel(cancel_conn: PGcancelConn, *, timeout: float = 0.0) -> PQGenConn[Non while True: if deadline and monotonic() > deadline: raise e.CancellationTimeout("cancellation timeout expired") - status = cancel_conn.poll() - if status == POLL_OK: + + if (status := cancel_conn.poll()) == POLL_OK: break elif status == POLL_READING: yield cancel_conn.socket, WAIT_R
@@ -150,13 +149,11 @@ def _send(pgconn: PGconn) -> PQGen[None]: to retrieve the results available. """ while True: - f = pgconn.flush() - if f == 0: + if pgconn.flush() == 0: break while True: - ready = yield WAIT_RW - if ready: + if ready := (yield WAIT_RW): break if ready & READY_R:
@@ -220,8 +217,7 @@ def _fetch(pgconn: PGconn) -> PQGen[PGresult | None]: """ if pgconn.is_busy(): while True: - ready = yield WAIT_R - if ready: + if (yield WAIT_R): break while True:
@@ -229,8 +225,7 @@ if not pgconn.is_busy(): break while True: - ready = yield WAIT_R - if ready: + if (yield WAIT_R): break _consume_notifies(pgconn)
@@ -250,8 +245,7 @@ def _pipeline_communicate( while True: while True: - ready = yield WAIT_RW - if ready: + if ready := (yield WAIT_RW): break if ready & READY_R:
@@ -260,28 +254,23 @@ res: list[PGresult] = [] while not pgconn.is_busy(): - r = pgconn.get_result() - if r is None: + if (r := pgconn.get_result()) is None: if not res: break results.append(res) res = [] + elif (status := r.status) == PIPELINE_SYNC: + assert not res + results.append([r]) + elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: + # This shouldn't happen, but insisting hard enough, it will. + # For instance, in test_executemany_badquery(), with the COPY + # statement and the AsyncClientCursor, which disables + # prepared statements). + # Bail out from the resulting infinite loop. + raise e.NotSupportedError("COPY cannot be used in pipeline mode") else: - status = r.status - if status == PIPELINE_SYNC: - assert not res - results.append([r]) - elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH: - # This shouldn't happen, but insisting hard enough, it will. - # For instance, in test_executemany_badquery(), with the COPY - # statement and the AsyncClientCursor, which disables - # prepared statements). - # Bail out from the resulting infinite loop.
- raise e.NotSupportedError( - "COPY cannot be used in pipeline mode" - ) - else: - res.append(r) + res.append(r) if ready & READY_W: pgconn.flush()
@@ -295,8 +284,7 @@ def _consume_notifies(pgconn: PGconn) -> None: # Consume notifies while True: - n = pgconn.notifies() - if not n: + if not (n := pgconn.notifies()): break if pgconn.notify_handler: pgconn.notify_handler(n)
@@ -308,8 +296,7 @@ def notifies(pgconn: PGconn) -> PQGen[list[pq.PGnotify]]: ns = [] while True: - n = pgconn.notifies() - if n: + if n := pgconn.notifies(): ns.append(n) if pgconn.notify_handler: pgconn.notify_handler(n)
@@ -327,8 +314,7 @@ def copy_from(pgconn: PGconn) -> PQGen[memoryview | PGresult]: # would block while True: - ready = yield WAIT_R - if ready: + if (yield WAIT_R): break pgconn.consume_input()
@@ -337,12 +323,12 @@ def copy_from(pgconn: PGconn) -> PQGen[memoryview | PGresult]: return data # Retrieve the final result of copy - results = yield from _fetch_many(pgconn) - if len(results) > 1: + + if len(results := (yield from _fetch_many(pgconn))) > 1: # TODO: too brutal? Copy worked. raise e.ProgrammingError("you cannot mix COPY with other operations") - result = results[0] - if result.status != COMMAND_OK: + + if (result := results[0]).status != COMMAND_OK: raise e.error_from_result(result, encoding=pgconn._encoding) return result
@@ -357,8 +343,7 @@ def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]: # do it upstream the queue decoupling the writer task from the producer one. while pgconn.put_copy_data(buffer) == 0: while True: - ready = yield WAIT_W - if ready: + if (yield WAIT_W): break # Flushing often has a good effect on macOS because memcpy operations
@@ -368,11 +353,10 @@ def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]: # Repeat until it the message is flushed to the server while True: while True: - ready = yield WAIT_W - if ready: + if (yield WAIT_W): break - f = pgconn.flush() - if f == 0: + + if pgconn.flush() == 0: break
@@ -380,18 +364,16 @@ def copy_end(pgconn: PGconn, error: bytes | None) -> PQGen[PGresult]: # Retry enqueuing end copy message until successful while pgconn.put_copy_end(error) == 0: while True: - ready = yield WAIT_W - if ready: + if (yield WAIT_W): break # Repeat until it the message is flushed to the server while True: while True: - ready = yield WAIT_W - if ready: + if (yield WAIT_W): break - f = pgconn.flush() - if f == 0: + + if pgconn.flush() == 0: break # Retrieve the final result of copy
diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py
index d55be281d..6fb389112 100644
--- a/psycopg/psycopg/pq/_debug.py
+++ b/psycopg/psycopg/pq/_debug.py
@@ -56,8 +56,7 @@ class PGconnDebug: return f"<{cls} {info} at 0x{id(self):x}>" def __getattr__(self, attr: str) -> Any: - value = getattr(self._pgconn, attr) - if callable(value): + if callable((value := getattr(self._pgconn, attr))): return debugging(value) else: logger.info("PGconn.%s -> %s", attr, value)
diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py
index 9128f1532..99e49357c 100644
--- a/psycopg/psycopg/pq/_pq_ctypes.py
+++ b/psycopg/psycopg/pq/_pq_ctypes.py
@@ -16,8 +16,7 @@ from typing import Any, NoReturn from .misc import find_libpq_full_path, version_pretty from ..errors import NotSupportedError -libname = find_libpq_full_path() -if not libname: +if not (libname := find_libpq_full_path()): raise ImportError("libpq library not found") pq = ctypes.cdll.LoadLibrary(libname)
@@ -30,8 +29,7 @@ class FILE(Structure): FILE_ptr = POINTER(FILE) if sys.platform == "linux": - libcname = ctypes.util.find_library("c") - if not libcname: + if not (libcname := ctypes.util.find_library("c")): # Likely this is a system using musl libc, see the following bug: # https://github.com/python/cpython/issues/65821 libcname = "libc.so"
@@ -41,7 +39,6 @@ if sys.platform == "linux": fdopen.argtypes = (c_int, c_char_p) fdopen.restype = FILE_ptr - # Get the libpq version to define what functions are available. PQlibVersion = pq.PQlibVersion
@@ -680,23 +677,14 @@ PQfreemem.restype = None if libpq_version >= 100000: PQencryptPasswordConn = pq.PQencryptPasswordConn - PQencryptPasswordConn.argtypes = [ - PGconn_ptr, - c_char_p, - c_char_p, - c_char_p, - ] + PQencryptPasswordConn.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_char_p] PQencryptPasswordConn.restype = POINTER(c_char) else: PQencryptPasswordConn = not_supported_before("PQencryptPasswordConn", 100000) if libpq_version >= 170000: PQchangePassword = pq.PQchangePassword - PQchangePassword.argtypes = [ - PGconn_ptr, - c_char_p, - c_char_p, - ] + PQchangePassword.argtypes = [PGconn_ptr, c_char_p, c_char_p] PQchangePassword.restype = PGresult_ptr else: PQchangePassword = not_supported_before("PQchangePassword", 170000)
diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py
index 31be494a0..81d932ca4 100644
--- a/psycopg/psycopg/pq/misc.py
+++ b/psycopg/psycopg/pq/misc.py
@@ -52,11 +52,9 @@ class PGresAttDesc(NamedTuple): @cache def find_libpq_full_path() -> str | None: if sys.platform == "win32": - libname = ctypes.util.find_library("libpq.dll") - if libname is None: + if (libname := ctypes.util.find_library("libpq.dll")) is None: return None libname = str(Path(libname).resolve()) - elif sys.platform == "darwin": libname = ctypes.util.find_library("libpq.dylib") # (hopefully) temporary hack: libpq not in a standard place
@@ -67,8 +65,7 @@ def find_libpq_full_path() -> str | None: import subprocess as sp libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode() - libname = os.path.join(libdir, "libpq.dylib") - if not os.path.exists(libname): + if not os.path.exists((libname := os.path.join(libdir, "libpq.dylib"))): libname = None except Exception as ex: logger.debug("couldn't use pg_config to find libpq: %s", ex)
@@ -129,16 +126,14 @@ PREFIXES = re.compile( def strip_severity(msg: str) -> str: """Strip severity and whitespaces from error message.""" - m = PREFIXES.match(msg) - if m: + if m := PREFIXES.match(msg): msg = msg[m.span()[1] :] return msg.strip() def _clean_error_message(msg: bytes, encoding: str) -> str: - smsg = msg.decode(encoding, "replace") - if smsg: + if smsg := msg.decode(encoding, "replace"): return strip_severity(smsg) else: return "no error details available"
@@ -169,8 +164,7 @@ def connection_summary(pgconn: abc.PGconn) -> str: else: status = ConnStatus(pgconn.status).name - sparts = " ".join("%s=%s" % part for part in parts) - if sparts: + if sparts := " ".join(("%s=%s" % part for part in parts)): sparts = f" ({sparts})" return f"[{status}]{sparts}"
diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py
index e35a0486f..6229bb02d 100644
--- a/psycopg/psycopg/pq/pq_ctypes.py
+++ b/psycopg/psycopg/pq/pq_ctypes.py
@@ -109,9 +109,7 @@ class PGconn: def connect(cls, conninfo: bytes) -> PGconn: if not isinstance(conninfo, bytes): raise TypeError(f"bytes expected, got {type(conninfo)} instead") - - pgconn_ptr = impl.PQconnectdb(conninfo) - if not pgconn_ptr: + if not (pgconn_ptr :=
impl.PQconnectdb(conninfo)): raise MemoryError("couldn't allocate PGconn") return cls(pgconn_ptr)
@@ -119,9 +117,7 @@ class PGconn: def connect_start(cls, conninfo: bytes) -> PGconn: if not isinstance(conninfo, bytes): raise TypeError(f"bytes expected, got {type(conninfo)} instead") - - pgconn_ptr = impl.PQconnectStart(conninfo) - if not pgconn_ptr: + if not (pgconn_ptr := impl.PQconnectStart(conninfo)): raise MemoryError("couldn't allocate PGconn") return cls(pgconn_ptr)
@@ -151,8 +147,7 @@ class PGconn: @property def info(self) -> list[ConninfoOption]: self._ensure_pgconn() - opts = impl.PQconninfo(self._pgconn_ptr) - if not opts: + if not (opts := impl.PQconninfo(self._pgconn_ptr)): raise MemoryError("couldn't allocate connection info") try: return Conninfo._options_from_array(opts)
@@ -246,8 +241,7 @@ class PGconn: @property def socket(self) -> int: - rv = self._call_int(impl.PQsocket) - if rv == -1: + if (rv := self._call_int(impl.PQsocket)) == -1: raise e.OperationalError("the connection is lost") return rv
@@ -280,8 +274,7 @@ class PGconn: if not isinstance(command, bytes): raise TypeError(f"bytes expected, got {type(command)} instead") self._ensure_pgconn() - rv = impl.PQexec(self._pgconn_ptr, command) - if not rv: + if not (rv := impl.PQexec(self._pgconn_ptr, command)): raise e.OperationalError( f"executing query failed: {self.get_error_message()}" )
@@ -308,8 +301,7 @@ class PGconn: command, param_values, param_types, param_formats, result_format ) self._ensure_pgconn() - rv = impl.PQexecParams(*args) - if not rv: + if not (rv := impl.PQexecParams(*args)): raise e.OperationalError( f"executing query failed: {self.get_error_message()}" )
@@ -333,10 +325,7 @@ class PGconn: ) def send_prepare( - self, - name: bytes, - command: bytes, - param_types: Sequence[int] | None = None, + self, name: bytes, command: bytes, param_types: Sequence[int] | None = None ) -> None: atypes: Array[impl.Oid] | None if not param_types:
@@ -432,10 +421,7 @@ class PGconn: ) def prepare( - self, - name: bytes, - command: bytes, - param_types: Sequence[int] | None = None, + self, name: bytes, command: bytes, param_types: Sequence[int] | None = None ) -> PGresult: if not isinstance(name, bytes): raise TypeError(f"'name' must be bytes, got {type(name)} instead")
@@ -451,8 +437,7 @@ class PGconn: atypes = (impl.Oid * nparams)(*param_types) self._ensure_pgconn() - rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes) - if not rv: + if not (rv := impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)): raise e.OperationalError( f"preparing query failed: {self.get_error_message()}" )
@@ -514,8 +499,7 @@ class PGconn: if not isinstance(name, bytes): raise TypeError(f"'name' must be bytes, got {type(name)} instead") self._ensure_pgconn() - rv = impl.PQdescribePrepared(self._pgconn_ptr, name) - if not rv: + if not (rv := impl.PQdescribePrepared(self._pgconn_ptr, name)): raise e.OperationalError( f"describe prepared failed: {self.get_error_message()}" )
@@ -534,8 +518,7 @@ class PGconn: if not isinstance(name, bytes): raise TypeError(f"'name' must be bytes, got {type(name)} instead") self._ensure_pgconn() - rv = impl.PQdescribePortal(self._pgconn_ptr, name) - if not rv: + if not (rv := impl.PQdescribePortal(self._pgconn_ptr, name)): raise e.OperationalError( f"describe portal failed: {self.get_error_message()}" )
@@ -554,8 +537,7 @@ class PGconn: if not isinstance(name, bytes): raise TypeError(f"'name' must be bytes, got {type(name)} instead") self._ensure_pgconn() - rv =
impl.PQclosePrepared(self._pgconn_ptr, name) - if not rv: + if not (rv := impl.PQclosePrepared(self._pgconn_ptr, name)): raise e.OperationalError( f"close prepared failed: {self.get_error_message()}" )
@@ -574,8 +556,7 @@ class PGconn: if not isinstance(name, bytes): raise TypeError(f"'name' must be bytes, got {type(name)} instead") self._ensure_pgconn() - rv = impl.PQclosePortal(self._pgconn_ptr, name) - if not rv: + if not (rv := impl.PQclosePortal(self._pgconn_ptr, name)): raise e.OperationalError(f"close portal failed: {self.get_error_message()}") return PGresult(rv)
@@ -635,8 +616,7 @@ class PGconn: See :pq:`PQcancelCreate` for details. """ - rv = impl.PQcancelCreate(self._pgconn_ptr) - if not rv: + if not (rv := impl.PQcancelCreate(self._pgconn_ptr)): raise e.OperationalError("couldn't create cancelConn object") return PGcancelConn(rv)
@@ -646,14 +626,12 @@ class PGconn: See :pq:`PQgetCancel` for details. """ - rv = impl.PQgetCancel(self._pgconn_ptr) - if not rv: + if not (rv := impl.PQgetCancel(self._pgconn_ptr)): raise e.OperationalError("couldn't create cancel object") return PGcancel(rv) def notifies(self) -> PGnotify | None: - ptr = impl.PQnotifies(self._pgconn_ptr) - if ptr: + if ptr := impl.PQnotifies(self._pgconn_ptr): c = ptr.contents rv = PGnotify(c.relname, c.be_pid, c.extra) impl.PQfreemem(ptr)
@@ -664,16 +642,14 @@ class PGconn: def put_copy_data(self, buffer: abc.Buffer) -> int: if not isinstance(buffer, bytes): buffer = bytes(buffer) - rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer)) - if rv < 0: + if (rv := impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))) < 0: raise e.OperationalError( f"sending copy data failed: {self.get_error_message()}" ) return rv def put_copy_end(self, error: bytes | None = None) -> int: - rv = impl.PQputCopyEnd(self._pgconn_ptr, error) - if rv < 0: + if (rv := impl.PQputCopyEnd(self._pgconn_ptr, error)) < 0: raise e.OperationalError( f"sending copy end failed: {self.get_error_message()}" )
@@ -756,8 +732,7 @@ class PGconn: ) def make_empty_result(self, exec_status: int) -> PGresult: - rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status) - if not rv: + if not (rv := impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)): raise MemoryError("couldn't allocate empty PGresult") return PGresult(rv)
@@ -791,8 +766,8 @@ class PGconn: :raises ~e.OperationalError: if the connection is not in pipeline mode or if sync failed.
""" - rv = impl.PQpipelineSync(self._pgconn_ptr) - if rv == 0: + + if (rv := impl.PQpipelineSync(self._pgconn_ptr)) == 0: raise e.OperationalError("connection not in pipeline mode") if rv != 1: raise e.OperationalError("failed to sync pipeline") @@ -928,11 +903,10 @@ class PGresult: if length: v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number) return string_at(v, length) + elif impl.PQgetisnull(self._pgresult_ptr, row_number, column_number): + return None else: - if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number): - return None - else: - return b"" + return b"" @property def nparams(self) -> int: @@ -959,8 +933,8 @@ class PGresult: impl.PGresAttDesc_struct(*desc) for desc in descriptions # type: ignore ] array = (impl.PGresAttDesc_struct * len(structs))(*structs) # type: ignore - rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array) - if rv == 0: + + if impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array) == 0: raise e.OperationalError("PQsetResultAttrs failed") @@ -1011,8 +985,7 @@ class PGcancelConn: @property def socket(self) -> int: - rv = impl.PQcancelSocket(self.pgcancelconn_ptr) - if rv == -1: + if (rv := impl.PQcancelSocket(self.pgcancelconn_ptr)) == -1: raise e.OperationalError("cancel connection not opened") return rv @@ -1095,8 +1068,7 @@ class Conninfo: @classmethod def get_defaults(cls) -> list[ConninfoOption]: - opts = impl.PQconndefaults() - if not opts: + if not (opts := impl.PQconndefaults()): raise MemoryError("couldn't allocate connection defaults") try: return cls._options_from_array(opts) @@ -1157,8 +1129,8 @@ class Escaping: # TODO: might be done without copy (however C does that) if not isinstance(data, bytes): data = bytes(data) - out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data)) - if not out: + + if not (out := impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))): raise e.OperationalError( f"escape_literal failed: {self.conn.get_error_message()} bytes" ) @@ -1174,8 +1146,8 @@ class Escaping: if not isinstance(data, bytes): data = bytes(data) - out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data)) - if not out: + + if not (out := impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))): raise e.OperationalError( f"escape_identifier failed: {self.conn.get_error_message()} bytes" ) diff --git a/psycopg/psycopg/rows.py b/psycopg/psycopg/rows.py index 1053fc80c..a29594067 100644 --- a/psycopg/psycopg/rows.py +++ b/psycopg/psycopg/rows.py @@ -116,14 +116,15 @@ def dict_row(cursor: BaseCursor[Any, Any]) -> RowMaker[DictRow]: The dictionary keys are taken from the column names of the returned columns. """ - names = _get_names(cursor) - if names is None: - return no_result + if (names := _get_names(cursor)) is not None: + + def dict_row_(values: Sequence[Any]) -> dict[str, Any]: + return dict(zip(names, values)) - def dict_row_(values: Sequence[Any]) -> dict[str, Any]: - return dict(zip(names, values)) + return dict_row_ - return dict_row_ + else: + return no_result def namedtuple_row(cursor: BaseCursor[Any, Any]) -> RowMaker[NamedTuple]: @@ -132,17 +133,12 @@ def namedtuple_row(cursor: BaseCursor[Any, Any]) -> RowMaker[NamedTuple]: The field names are taken from the column names of the returned columns, with some mangling to deal with invalid names. 
""" - res = cursor.pgresult - if not res: - return no_result - - nfields = _get_nfields(res) - if nfields is None: + if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None: + nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields))) + return nt._make + else: return no_result - nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields))) - return nt._make - @functools.lru_cache(512) def _make_nt(enc: str, *names: bytes) -> type[NamedTuple]: @@ -161,14 +157,15 @@ def class_row(cls: type[T]) -> BaseRowFactory[T]: """ def class_row_(cursor: BaseCursor[Any, Any]) -> RowMaker[T]: - names = _get_names(cursor) - if names is None: - return no_result + if (names := _get_names(cursor)) is not None: - def class_row__(values: Sequence[Any]) -> T: - return cls(**dict(zip(names, values))) + def class_row__(values: Sequence[Any]) -> T: + return cls(**dict(zip(names, values))) - return class_row__ + return class_row__ + + else: + return no_result return class_row_ @@ -197,14 +194,15 @@ def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]: """ def kwargs_row_(cursor: BaseCursor[Any, T]) -> RowMaker[T]: - names = _get_names(cursor) - if names is None: - return no_result + if (names := _get_names(cursor)) is not None: - def kwargs_row__(values: Sequence[Any]) -> T: - return func(**dict(zip(names, values))) + def kwargs_row__(values: Sequence[Any]) -> T: + return func(**dict(zip(names, values))) - return kwargs_row__ + return kwargs_row__ + + else: + return no_result return kwargs_row_ @@ -214,21 +212,17 @@ def scalar_row(cursor: BaseCursor[Any, Any]) -> RowMaker[Any]: Generate a row factory returning the first column as a scalar value. """ - res = cursor.pgresult - if not res: - return no_result + if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None: + if nfields < 1: + raise e.ProgrammingError("at least one column expected") - nfields = _get_nfields(res) - if nfields is None: - return no_result - - if nfields < 1: - raise e.ProgrammingError("at least one column expected") + def scalar_row_(values: Sequence[Any]) -> Any: + return values[0] - def scalar_row_(values: Sequence[Any]) -> Any: - return values[0] + return scalar_row_ - return scalar_row_ + else: + return no_result def no_result(values: Sequence[Any]) -> NoReturn: @@ -242,19 +236,14 @@ def no_result(values: Sequence[Any]) -> NoReturn: def _get_names(cursor: BaseCursor[Any, Any]) -> list[str] | None: - res = cursor.pgresult - if not res: - return None - - nfields = _get_nfields(res) - if nfields is None: + if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None: + enc = cursor._encoding + return [ + res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr] + ] + else: return None - enc = cursor._encoding - return [ - res.fname(i).decode(enc) for i in range(nfields) # type: ignore[union-attr] - ] - def _get_nfields(res: PGresult) -> int | None: """ diff --git a/psycopg/psycopg/server_cursor.py b/psycopg/psycopg/server_cursor.py index f375846e0..bd30d9c87 100644 --- a/psycopg/psycopg/server_cursor.py +++ b/psycopg/psycopg/server_cursor.py @@ -40,12 +40,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]): __slots__ = "_name _scrollable _withhold _described itersize _format".split() - def __init__( - self, - name: str, - scrollable: bool | None, - withhold: bool, - ): + def __init__(self, name: str, scrollable: bool | None, withhold: bool): self._name = name self._scrollable = scrollable self._withhold = withhold @@ -95,10 +90,7 @@ 
class ServerCursorMixin(BaseCursor[ConnectionType, Row]): return self._pos if tuples else None def _declare_gen( - self, - query: Query, - params: Params | None = None, - binary: bool | None = None, + self, query: Query, params: Params | None = None, binary: bool | None = None ) -> PQGen[None]: """Generator implementing `ServerCursor.execute()`."""
@@ -195,10 +187,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]): if not isinstance(query, sql.Composable): query = sql.SQL(query) - parts = [ - sql.SQL("DECLARE"), - sql.Identifier(self._name), - ] + parts = [sql.SQL("DECLARE"), sql.Identifier(self._name)] if self._scrollable is not None: parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL")) parts.append(sql.SQL("CURSOR"))
@@ -295,11 +284,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]): return self def executemany( - self, - query: Query, - params_seq: Iterable[Params], - *, - returning: bool = True, + self, query: Query, params_seq: Iterable[Params], *, returning: bool = True ) -> None: """Method not implemented for server-side cursors.""" raise e.NotSupportedError("executemany not supported on server-side cursors")
@@ -428,11 +413,7 @@ class AsyncServerCursor( return self async def executemany( - self, - query: Query, - params_seq: Iterable[Params], - *, - returning: bool = True, + self, query: Query, params_seq: Iterable[Params], *, returning: bool = True ) -> None: raise e.NotSupportedError("executemany not supported on server-side cursors")
diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py
index 81a96a098..172110afe 100644
--- a/psycopg/psycopg/sql.py
+++ b/psycopg/psycopg/sql.py
@@ -79,10 +79,8 @@ class Composable(ABC): :type context: `connection` or `cursor` """ - conn = context.connection if context else None - enc = conn_encoding(conn) - b = self.as_bytes(context) - if isinstance(b, bytes): + enc = conn_encoding(context.connection if context else None) + if isinstance((b := self.as_bytes(context)), bytes): return b.decode(enc) else: # buffer object
@@ -373,8 +371,7 @@ class Identifier(Composable): return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})" def as_bytes(self, context: AdaptContext | None = None) -> bytes: - conn = context.connection if context else None - if conn: + if conn := (context.connection if context else None): esc = Escaping(conn.pgconn) enc = conn_encoding(conn) escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
diff --git a/psycopg/psycopg/types/array.py b/psycopg/psycopg/types/array.py
index 156ee890f..5c80444fa 100644
--- a/psycopg/psycopg/types/array.py
+++ b/psycopg/psycopg/types/array.py
@@ -52,8 +52,7 @@ class BaseListDumper(RecursiveDumper): Find the first non-null element of an eventually nested list """ items = list(self._flatiter(L, set())) - types = {type(item): item for item in items} - if not types: + if not (types := {type(item): item for item in items}): return None if len(types) == 1:
@@ -62,8 +61,7 @@ class BaseListDumper(RecursiveDumper): # More than one type in the list. It might be still good, as long # as they dump with the same oid (e.g. IPv4Network, IPv6Network). dumpers = [self._tx.get_dumper(item, format) for item in types.values()] - oids = {d.oid for d in dumpers} - if len(oids) == 1: + if len({d.oid for d in dumpers}) == 1: t, v = types.popitem() else: raise e.DataError(
@@ -106,8 +104,7 @@ class BaseListDumper(RecursiveDumper): Return text info as fallback.
""" if base_oid: - info = self._tx.adapters.types.get(base_oid) - if info: + if info := self._tx.adapters.types.get(base_oid): return info return self._tx.adapters.types["text"] @@ -120,8 +117,7 @@ class ListDumper(BaseListDumper): if self.oid: return self.cls - item = self._find_list_element(obj, format) - if item is None: + if (item := self._find_list_element(obj, format)) is None: return self.cls sd = self._tx.get_dumper(item, format) @@ -132,8 +128,7 @@ class ListDumper(BaseListDumper): if self.oid: return self - item = self._find_list_element(obj, format) - if item is None: + if (item := self._find_list_element(obj, format)) is None: # Empty lists can only be dumped as text if the type is unknown. return self @@ -170,8 +165,7 @@ class ListDumper(BaseListDumper): if isinstance(item, list): dump_list(item) elif item is not None: - ad = self._dump_item(item) - if ad is None: + if (ad := self._dump_item(item)) is None: tokens.append(b"NULL") else: if needs_quotes(ad): @@ -224,8 +218,7 @@ class ListBinaryDumper(BaseListDumper): if self.oid: return self.cls - item = self._find_list_element(obj, format) - if item is None: + if (item := self._find_list_element(obj, format)) is None: return (self.cls,) sd = self._tx.get_dumper(item, format) @@ -236,8 +229,7 @@ class ListBinaryDumper(BaseListDumper): if self.oid: return self - item = self._find_list_element(obj, format) - if item is None: + if (item := self._find_list_element(obj, format)) is None: return ListDumper(self.cls, self._tx) sd = self._tx.get_dumper(item, format.from_pq(self.format)) @@ -396,15 +388,14 @@ def _load_text( if data and data[0] == b"["[0]: if isinstance(data, memoryview): data = bytes(data) - idx = data.find(b"=") - if idx == -1: + + if (idx := data.find(b"=")) == -1: raise e.DataError("malformed array: no '=' after dimension information") data = data[idx + 1 :] re_parse = _get_array_parse_regexp(delimiter) for m in re_parse.finditer(data): - t = m.group(1) - if t == b"{": + if (t := m.group(1)) == b"{": if stack: stack[-1].append(a) stack.append(a) diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index 1c0f747da..ad4931239 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -94,8 +94,7 @@ class SequenceDumper(RecursiveDumper): continue dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format)) - ad = dumper.dump(item) - if ad is None: + if (ad := dumper.dump(item)) is None: ad = b"" elif not ad: ad = b'""' diff --git a/psycopg/psycopg/types/datetime.py b/psycopg/psycopg/types/datetime.py index 016e216bc..1fcab4b97 100644 --- a/psycopg/psycopg/types/datetime.py +++ b/psycopg/psycopg/types/datetime.py @@ -70,8 +70,7 @@ class _BaseTimeDumper(Dumper): raise NotImplementedError def _get_offset(self, obj: time) -> timedelta: - offset = obj.utcoffset() - if offset is None: + if (offset := obj.utcoffset()) is None: raise DataError( f"cannot calculate the offset of tzinfo '{obj.tzinfo}' without a date" ) @@ -236,8 +235,7 @@ class DateLoader(Loader): def __init__(self, oid: int, context: AdaptContext | None = None): super().__init__(oid, context) - ds = _get_datestyle(self.connection) - if ds.startswith(b"I"): # ISO + if (ds := _get_datestyle(self.connection)).startswith(b"I"): # ISO self._order = self._ORDER_YMD elif ds.startswith(b"G"): # German self._order = self._ORDER_DMY @@ -290,8 +288,7 @@ class TimeLoader(Loader): _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?") def load(self, data: Buffer) -> time: - m = 
self._re_format.match(data) - if not m: + if not (m := self._re_format.match(data)): s = bytes(data).decode("utf8", "replace") raise DataError(f"can't parse time {s!r}") @@ -337,8 +334,7 @@ class TimetzLoader(Loader): ) def load(self, data: Buffer) -> time: - m = self._re_format.match(data) - if not m: + if not (m := self._re_format.match(data)): s = bytes(data).decode("utf8", "replace") raise DataError(f"can't parse timetz {s!r}") @@ -415,9 +411,7 @@ class TimestampLoader(Loader): def __init__(self, oid: int, context: AdaptContext | None = None): super().__init__(oid, context) - - ds = _get_datestyle(self.connection) - if ds.startswith(b"I"): # ISO + if (ds := _get_datestyle(self.connection)).startswith(b"I"): # ISO self._order = self._ORDER_YMD elif ds.startswith(b"G"): # German self._order = self._ORDER_DMY @@ -430,8 +424,7 @@ class TimestampLoader(Loader): raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}") def load(self, data: Buffer) -> datetime: - m = self._re_format.match(data) - if not m: + if not (m := self._re_format.match(data)): raise _get_timestamp_load_error(self.connection, data) from None if self._order == self._ORDER_YMD: @@ -499,8 +492,7 @@ class TimestamptzLoader(Loader): super().__init__(oid, context) self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None) - ds = _get_datestyle(self.connection) - if ds.startswith(b"I"): # ISO + if _get_datestyle(self.connection).startswith(b"I"): # ISO self._load_method = self._load_iso else: self._load_method = self._load_notimpl @@ -510,8 +502,7 @@ class TimestamptzLoader(Loader): @staticmethod def _load_iso(self: TimestamptzLoader, data: Buffer) -> datetime: - m = self._re_format.match(data) - if not m: + if not (m := self._re_format.match(data)): raise _get_timestamp_load_error(self.connection, data) from None ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups() @@ -582,10 +573,9 @@ class TimestamptzBinaryLoader(Loader): # timezone) we can still save the day by shifting the value by the # timezone offset and then replacing the timezone. 
if self._timezone: - utcoff = self._timezone.utcoffset( + if utcoff := self._timezone.utcoffset( datetime.min if micros < 0 else datetime.max - ) - if utcoff: + ): usoff = 1_000_000 * int(utcoff.total_seconds()) try: ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff) @@ -624,8 +614,7 @@ class IntervalLoader(Loader): @staticmethod def _load_postgres(self: IntervalLoader, data: Buffer) -> timedelta: - m = self._re_interval.match(data) - if not m: + if not (m := self._re_interval.match(data)): s = bytes(data).decode("utf8", "replace") raise DataError(f"can't parse interval {s!r}") @@ -679,19 +668,15 @@ class IntervalBinaryLoader(Loader): def _get_datestyle(conn: BaseConnection[Any] | None) -> bytes: - if conn: - ds = conn.pgconn.parameter_status(b"DateStyle") - if ds: - return ds + if conn and (ds := conn.pgconn.parameter_status(b"DateStyle")): + return ds return b"ISO, DMY" def _get_intervalstyle(conn: BaseConnection[Any] | None) -> bytes: - if conn: - ints = conn.pgconn.parameter_status(b"IntervalStyle") - if ints: - return ints + if conn and (ints := conn.pgconn.parameter_status(b"IntervalStyle")): + return ints return b"unknown" @@ -705,8 +690,7 @@ def _get_timestamp_load_error( if not s: return False - ds = _get_datestyle(conn) - if not ds.startswith(b"P"): # Postgres + if not _get_datestyle(conn).startswith(b"P"): # Postgres return len(s.split()[0]) > 10 # date is first token else: return len(s.split()[-1]) > 4 # year is last token diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py index 7f1da6e67..e87f9d66b 100644 --- a/psycopg/psycopg/types/hstore.py +++ b/psycopg/psycopg/types/hstore.py @@ -84,8 +84,7 @@ class HstoreLoader(RecursiveLoader): if m is None or m.start() != start: raise e.DataError(f"error parsing hstore pair at char {start}") k = _re_unescape.sub(r"\1", m.group(1)) - v = m.group(2) - if v is not None: + if (v := m.group(2)) is not None: v = _re_unescape.sub(r"\1", v) rv[k] = v diff --git a/psycopg/psycopg/types/json.py b/psycopg/psycopg/types/json.py index 51bb22e03..e4fbb25c6 100644 --- a/psycopg/psycopg/types/json.py +++ b/psycopg/psycopg/types/json.py @@ -98,16 +98,14 @@ def set_json_loads( @cache def _make_dumper(base: type[abc.Dumper], dumps: JsonDumpsFunction) -> type[abc.Dumper]: - name = base.__name__ - if not name.startswith("Custom"): + if not (name := base.__name__).startswith("Custom"): name = f"Custom{name}" return type(name, (base,), {"_dumps": dumps}) @cache def _make_loader(base: type[Loader], loads: JsonLoadsFunction) -> type[Loader]: - name = base.__name__ - if not name.startswith("Custom"): + if not (name := base.__name__).startswith("Custom"): name = f"Custom{name}" return type(name, (base,), {"_loads": loads}) @@ -120,8 +118,7 @@ class _JsonWrapper: self.dumps = dumps def __repr__(self) -> str: - sobj = repr(self.obj) - if len(sobj) > 40: + if len((sobj := repr(self.obj))) > 40: sobj = f"{sobj[:35]} ... 
({len(sobj)} chars)" return f"{self.__class__.__name__}({sobj})" @@ -149,8 +146,7 @@ class _JsonDumper(Dumper): obj = obj.obj else: dumps = self.dumps - data = dumps(obj) - if isinstance(data, str): + if isinstance((data := dumps(obj)), str): return data.encode() return data @@ -173,8 +169,7 @@ class JsonbBinaryDumper(_JsonDumper): oid = _oids.JSONB_OID def dump(self, obj: Any) -> Buffer | None: - obj_bytes = super().dump(obj) - if obj_bytes is not None: + if (obj_bytes := super().dump(obj)) is not None: return b"\x01" + obj_bytes else: return None @@ -214,8 +209,7 @@ class JsonbBinaryLoader(_JsonLoader): def load(self, data: Buffer) -> Any: if data and data[0] != 1: raise DataError("unknown jsonb binary format: {data[0]}") - data = data[1:] - if not isinstance(data, bytes): + if not isinstance((data := data[1:]), bytes): data = bytes(data) return self.loads(data) diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py index 8c1f75ec1..cbfecd690 100644 --- a/psycopg/psycopg/types/multirange.py +++ b/psycopg/psycopg/types/multirange.py @@ -190,8 +190,7 @@ class BaseMultirangeDumper(RecursiveDumper): if self.cls is not Multirange: return self.cls - item = self._get_item(obj) - if item is not None: + if (item := self._get_item(obj)) is not None: sd = self._tx.get_dumper(item, self._adapt_format) return (self.cls, sd.get_key(item, format)) else: @@ -202,8 +201,7 @@ class BaseMultirangeDumper(RecursiveDumper): if self.cls is not Multirange: return self - item = self._get_item(obj) - if item is None: + if (item := self._get_item(obj)) is None: return self dumper: BaseMultirangeDumper @@ -257,8 +255,7 @@ class MultirangeDumper(BaseMultirangeDumper): if not obj: return b"{}" - item = self._get_item(obj) - if item is not None: + if (item := self._get_item(obj)) is not None: dump = self._tx.get_dumper(item, self._adapt_format).dump else: dump = fail_dump @@ -275,8 +272,7 @@ class MultirangeBinaryDumper(BaseMultirangeDumper): format = Format.BINARY def dump(self, obj: Multirange[Any]) -> Buffer | None: - item = self._get_item(obj) - if item is not None: + if (item := self._get_item(obj)) is not None: dump = self._tx.get_dumper(item, self._adapt_format).dump else: dump = fail_dump diff --git a/psycopg/psycopg/types/numeric.py b/psycopg/psycopg/types/numeric.py index 8c0e64ed1..28fbc9aa1 100644 --- a/psycopg/psycopg/types/numeric.py +++ b/psycopg/psycopg/types/numeric.py @@ -40,8 +40,7 @@ class _IntDumper(Dumper): return str(obj).encode() def quote(self, obj: Any) -> Buffer: - value = self.dump(obj) - if value is None: + if (value := self.dump(obj)) is None: return b"NULL" return value if obj >= 0 else b" " + value @@ -63,9 +62,7 @@ class _SpecialValuesDumper(Dumper): return str(obj).encode() def quote(self, obj: Any) -> Buffer: - value = self.dump(obj) - - if value is None: + if (value := self.dump(obj)) is None: return b"NULL" if not isinstance(value, bytes): value = bytes(value) @@ -158,11 +155,10 @@ class IntDumper(Dumper): return self._int2_dumper else: return self._int4_dumper + elif -(2**63) <= obj < 2**63: + return self._int8_dumper else: - if -(2**63) <= obj < 2**63: - return self._int8_dumper - else: - return self._int_numeric_dumper + return self._int_numeric_dumper class Int2BinaryDumper(Int2Dumper): @@ -453,8 +449,7 @@ def dump_decimal_to_numeric_binary(obj: Decimal) -> bytearray | bytes: # Equivalent of 0-padding left to align the py digits to the pg digits # but without changing the digits tuple. 
- mod = (ndigits - dscale) % DEC_DIGITS - if mod: + if mod := ((ndigits - dscale) % DEC_DIGITS): wi = DEC_DIGITS - mod ndigits += wi diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py index 6488d3d76..f5074829e 100644 --- a/psycopg/psycopg/types/range.py +++ b/psycopg/psycopg/types/range.py @@ -185,17 +185,15 @@ class Range(Generic[T]): # It doesn't seem that Python has an ABC for ordered types. if x < self._lower: # type: ignore[operator] return False - else: - if x <= self._lower: # type: ignore[operator] - return False + elif x <= self._lower: # type: ignore[operator] + return False if self._upper is not None: if self._bounds[1] == "]": if x > self._upper: # type: ignore[operator] return False - else: - if x >= self._upper: # type: ignore[operator] - return False + elif x >= self._upper: # type: ignore[operator] + return False return True @@ -223,8 +221,7 @@ class Range(Generic[T]): return NotImplemented for attr in ("_lower", "_upper", "_bounds"): self_value = getattr(self, attr) - other_value = getattr(other, attr) - if self_value == other_value: + if self_value == (other_value := getattr(other, attr)): pass elif self_value is None: return True @@ -294,9 +291,7 @@ class BaseRangeDumper(RecursiveDumper): # If we are a subclass whose oid is specified we don't need upgrade if self.cls is not Range: return self.cls - - item = self._get_item(obj) - if item is not None: + if (item := self._get_item(obj)) is not None: sd = self._tx.get_dumper(item, self._adapt_format) return (self.cls, sd.get_key(item, format)) else: @@ -306,9 +301,7 @@ class BaseRangeDumper(RecursiveDumper): # If we are a subclass whose oid is specified we don't need upgrade if self.cls is not Range: return self - - item = self._get_item(obj) - if item is None: + if (item := self._get_item(obj)) is None: return self dumper: BaseRangeDumper @@ -355,8 +348,7 @@ class RangeDumper(BaseRangeDumper): """ def dump(self, obj: Range[Any]) -> Buffer | None: - item = self._get_item(obj) - if item is not None: + if (item := self._get_item(obj)) is not None: dump = self._tx.get_dumper(item, self._adapt_format).dump else: dump = fail_dump @@ -371,8 +363,7 @@ def dump_range_text(obj: Range[Any], dump: DumpFunc) -> Buffer: parts: list[Buffer] = [b"[" if obj.lower_inc else b"("] def dump_item(item: Any) -> Buffer: - ad = dump(item) - if ad is None: + if (ad := dump(item)) is None: return b"" elif not ad: return b'""' @@ -402,8 +393,7 @@ class RangeBinaryDumper(BaseRangeDumper): format = Format.BINARY def dump(self, obj: Range[Any]) -> Buffer | None: - item = self._get_item(obj) - if item is not None: + if (item := self._get_item(obj)) is not None: dump = self._tx.get_dumper(item, self._adapt_format).dump else: dump = fail_dump @@ -424,8 +414,7 @@ def dump_range_binary(obj: Range[Any], dump: DumpFunc) -> Buffer: head |= RANGE_UB_INC if obj.lower is not None: - data = dump(obj.lower) - if data is not None: + if (data := dump(obj.lower)) is not None: out += pack_len(len(data)) out += data else: @@ -434,8 +423,7 @@ def dump_range_binary(obj: Range[Any], dump: DumpFunc) -> Buffer: head |= RANGE_LB_INF if obj.upper is not None: - data = dump(obj.upper) - if data is not None: + if (data := dump(obj.upper)) is not None: out += pack_len(len(data)) out += data else: @@ -472,27 +460,21 @@ class RangeLoader(BaseRangeLoader[T]): def load_range_text(data: Buffer, load: LoadFunc) -> tuple[Range[Any], int]: if data == b"empty": return Range(empty=True), 5 - - m = _re_range.match(data) - if m is None: + if (m := 
_re_range.match(data)) is None: raise e.DataError( f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'" ) lower = None - item = m.group(3) - if item is None: - item = m.group(2) - if item is not None: + if (item := m.group(3)) is None: + if (item := m.group(2)) is not None: lower = load(_re_undouble.sub(rb"\1", item)) else: lower = load(item) upper = None - item = m.group(5) - if item is None: - item = m.group(4) - if item is not None: + if (item := m.group(5)) is None: + if (item := m.group(4)) is not None: upper = load(_re_undouble.sub(rb"\1", item)) else: upper = load(item) @@ -530,8 +512,7 @@ class RangeBinaryLoader(BaseRangeLoader[T]): def load_range_binary(data: Buffer, load: LoadFunc) -> Range[Any]: - head = data[0] - if head & RANGE_EMPTY: + if (head := data[0]) & RANGE_EMPTY: return Range(empty=True) lb = "[" if head & RANGE_LB_INC else "(" diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py index 262ac658c..ef4fceb28 100644 --- a/psycopg/psycopg/types/string.py +++ b/psycopg/psycopg/types/string.py @@ -138,8 +138,7 @@ class BytesDumper(Dumper): return self._esc.escape_bytea(obj) def quote(self, obj: Buffer) -> Buffer: - escaped = self.dump(obj) - if escaped is None: + if (escaped := self.dump(obj)) is None: return b"NULL" # We cannot use the base quoting because escape_bytea already returns diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py index 05994b5d5..7df0b960d 100644 --- a/psycopg/psycopg/waiting.py +++ b/psycopg/psycopg/waiting.py @@ -54,8 +54,7 @@ def wait_selector(gen: PQGen[RV], fileno: int, interval: float | None = None) -> with DefaultSelector() as sel: sel.register(fileno, s) while True: - rlist = sel.select(timeout=interval) - if not rlist: + if not (rlist := sel.select(timeout=interval)): gen.send(READY_NONE) continue @@ -89,8 +88,7 @@ def wait_conn(gen: PQGenConn[RV], interval: float | None = None) -> RV: with DefaultSelector() as sel: sel.register(fileno, s) while True: - rlist = sel.select(timeout=interval) - if not rlist: + if not (rlist := sel.select(timeout=interval)): gen.send(READY_NONE) continue @@ -134,7 +132,7 @@ async def wait_async(gen: PQGen[RV], fileno: int, interval: float | None = None) while True: reader = s & WAIT_R writer = s & WAIT_W - if not reader and not writer: + if not (reader or writer): raise e.InternalError(f"bad poll status: {s}") ev.clear() ready = 0 @@ -195,7 +193,7 @@ async def wait_conn_async(gen: PQGenConn[RV], interval: float | None = None) -> while True: reader = s & WAIT_R writer = s & WAIT_W - if not reader and not writer: + if not (reader or writer): raise e.InternalError(f"bad poll status: {s}") ev.clear() ready = 0 # type: ignore[assignment] @@ -297,8 +295,7 @@ def wait_epoll(gen: PQGen[RV], fileno: int, interval: float | None = None) -> RV evmask = _epoll_evmasks[s] epoll.register(fileno, evmask) while True: - fileevs = epoll.poll(interval) - if not fileevs: + if not (fileevs := epoll.poll(interval)): gen.send(READY_NONE) continue ev = fileevs[0][1] @@ -344,8 +341,7 @@ def wait_poll(gen: PQGen[RV], fileno: int, interval: float | None = None) -> RV: evmask = _poll_evmasks[s] poll.register(fileno, evmask) while True: - fileevs = poll.poll(interval) - if not fileevs: + if not (fileevs := poll.poll(interval)): gen.send(READY_NONE) continue @@ -374,8 +370,7 @@ def _is_select_patched() -> bool: Currently supported: gevent. """ # If not imported, don't import it. 
- m = sys.modules.get("gevent.monkey") - if m: + if m := sys.modules.get("gevent.monkey"): try: if m.is_module_patched("select"): return True diff --git a/psycopg_pool/psycopg_pool/base.py b/psycopg_pool/psycopg_pool/base.py index 9f9984f22..963c8bd3e 100644 --- a/psycopg_pool/psycopg_pool/base.py +++ b/psycopg_pool/psycopg_pool/base.py @@ -145,8 +145,7 @@ class BasePool: raise PoolClosed(f"the pool {self.name!r} is not open yet") def _check_pool_putconn(self, conn: BaseConnection[Any]) -> None: - pool = getattr(conn, "_pool", None) - if pool is self: + if (pool := getattr(conn, "_pool", None)) is self: return if pool: diff --git a/psycopg_pool/psycopg_pool/null_pool.py b/psycopg_pool/psycopg_pool/null_pool.py index 37b7bc707..0370702ef 100644 --- a/psycopg_pool/psycopg_pool/null_pool.py +++ b/psycopg_pool/psycopg_pool/null_pool.py @@ -162,8 +162,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]): while self._waiting: # If there is a client waiting (which is still waiting and # hasn't timed out), give it the connection and notify it. - pos = self._waiting.popleft() - if pos.set(conn): + if self._waiting.popleft().set(conn): break else: # No client waiting for a connection: close the connection diff --git a/psycopg_pool/psycopg_pool/null_pool_async.py b/psycopg_pool/psycopg_pool/null_pool_async.py index b73504824..53a2201c9 100644 --- a/psycopg_pool/psycopg_pool/null_pool_async.py +++ b/psycopg_pool/psycopg_pool/null_pool_async.py @@ -158,13 +158,11 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT]) while self._waiting: # If there is a client waiting (which is still waiting and # hasn't timed out), give it the connection and notify it. - pos = self._waiting.popleft() - if await pos.set(conn): + if await self._waiting.popleft().set(conn): break else: # No client waiting for a connection: close the connection await conn.close() - # If we have been asked to wait for pool init, notify the # waiter if the pool is ready. if self._pool_full_event: diff --git a/psycopg_pool/psycopg_pool/pool.py b/psycopg_pool/psycopg_pool/pool.py index 115894486..24090ca97 100644 --- a/psycopg_pool/psycopg_pool/pool.py +++ b/psycopg_pool/psycopg_pool/pool.py @@ -232,8 +232,7 @@ class ConnectionPool(Generic[CT], BasePool): # Critical section: decide here if there's a connection ready # or if the client needs to wait. with self._lock: - conn = self._get_ready_connection(timeout) - if not conn: + if not (conn := self._get_ready_connection(timeout)): # No connection available: put the client in the waiting queue t0 = monotonic() pos: WaitingClient[CT] = WaitingClient() @@ -571,9 +570,7 @@ class ConnectionPool(Generic[CT], BasePool): StopWorker is received. """ while True: - task = q.get() - - if isinstance(task, StopWorker): + if isinstance((task := q.get()), StopWorker): logger.debug("terminating working task %s", current_thread_name()) return @@ -606,8 +603,7 @@ class ConnectionPool(Generic[CT], BasePool): if self._configure: self._configure(conn) - status = conn.pgconn.transaction_status - if status != TransactionStatus.IDLE: + if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE: sname = TransactionStatus(status).name raise e.ProgrammingError( f"connection left in status {sname} by configure function {self._configure}: discarded" @@ -734,8 +730,8 @@ class ConnectionPool(Generic[CT], BasePool): while self._waiting: # If there is a client waiting (which is still waiting and # hasn't timed out), give it the connection and notify it. 
- pos = self._waiting.popleft() - if pos.set(conn): + + if self._waiting.popleft().set(conn): break else: # No client waiting for a connection: put it back into the pool @@ -749,8 +745,7 @@ class ConnectionPool(Generic[CT], BasePool): """ Bring a connection to IDLE state or close it. """ - status = conn.pgconn.transaction_status - if status == TransactionStatus.IDLE: + if (status := conn.pgconn.transaction_status) == TransactionStatus.IDLE: pass elif status == TransactionStatus.UNKNOWN: # Connection closed @@ -776,8 +771,7 @@ class ConnectionPool(Generic[CT], BasePool): if self._reset: try: self._reset(conn) - status = conn.pgconn.transaction_status - if status != TransactionStatus.IDLE: + if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE: sname = TransactionStatus(status).name raise e.ProgrammingError( f"connection left in status {sname} by reset function {self._reset}: discarded" diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index fc52a6fe4..af725a8e4 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -260,8 +260,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): # Critical section: decide here if there's a connection ready # or if the client needs to wait. async with self._lock: - conn = await self._get_ready_connection(timeout) - if not conn: + if not (conn := (await self._get_ready_connection(timeout))): # No connection available: put the client in the waiting queue t0 = monotonic() pos: WaitingClient[ACT] = WaitingClient() @@ -623,9 +622,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): StopWorker is received. """ while True: - task = await q.get() - - if isinstance(task, StopWorker): + if isinstance((task := (await q.get())), StopWorker): logger.debug("terminating working task %s", current_task_name()) return @@ -658,8 +655,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): if self._configure: await self._configure(conn) - status = conn.pgconn.transaction_status - if status != TransactionStatus.IDLE: + if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE: sname = TransactionStatus(status).name raise e.ProgrammingError( f"connection left in status {sname} by configure function" @@ -789,13 +785,12 @@ class AsyncConnectionPool(Generic[ACT], BasePool): while self._waiting: # If there is a client waiting (which is still waiting and # hasn't timed out), give it the connection and notify it. - pos = self._waiting.popleft() - if await pos.set(conn): + + if await self._waiting.popleft().set(conn): break else: # No client waiting for a connection: put it back into the pool self._pool.append(conn) - # If we have been asked to wait for pool init, notify the # waiter if the pool is full. if self._pool_full_event and len(self._pool) >= self._min_size: @@ -805,10 +800,8 @@ class AsyncConnectionPool(Generic[ACT], BasePool): """ Bring a connection to IDLE state or close it. 
""" - status = conn.pgconn.transaction_status - if status == TransactionStatus.IDLE: + if (status := conn.pgconn.transaction_status) == TransactionStatus.IDLE: pass - elif status == TransactionStatus.UNKNOWN: # Connection closed return @@ -835,8 +828,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): if self._reset: try: await self._reset(conn) - status = conn.pgconn.transaction_status - if status != TransactionStatus.IDLE: + if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE: sname = TransactionStatus(status).name raise e.ProgrammingError( f"connection left in status {sname} by reset function" diff --git a/psycopg_pool/psycopg_pool/sched.py b/psycopg_pool/psycopg_pool/sched.py index f41179858..6fc4e2b91 100644 --- a/psycopg_pool/psycopg_pool/sched.py +++ b/psycopg_pool/psycopg_pool/sched.py @@ -66,8 +66,7 @@ class Scheduler: while True: with self._lock: now = monotonic() - task = q[0] if q else None - if task: + if task := (q[0] if q else None): if task.time <= now: heappop(q) else: diff --git a/psycopg_pool/psycopg_pool/sched_async.py b/psycopg_pool/psycopg_pool/sched_async.py index 86f59d036..3db6533d6 100644 --- a/psycopg_pool/psycopg_pool/sched_async.py +++ b/psycopg_pool/psycopg_pool/sched_async.py @@ -62,8 +62,7 @@ class AsyncScheduler: while True: async with self._lock: now = monotonic() - task = q[0] if q else None - if task: + if task := (q[0] if q else None): if task.time <= now: heappop(q) else: diff --git a/tests/conftest.py b/tests/conftest.py index 4a1d4de41..e25ce5218 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,11 +49,8 @@ def pytest_addoption(parser): def pytest_report_header(config): - rv = [] - - rv.append(f"default selector: {selectors.DefaultSelector.__name__}") - loop = config.getoption("--loop") - if loop != "default": + rv = [f"default selector: {selectors.DefaultSelector.__name__}"] + if (loop := config.getoption("--loop")) != "default": rv.append(f"asyncio loop: {loop}") return rv @@ -65,8 +62,7 @@ def pytest_sessionstart(session): # In case of segfault, pytest doesn't get a chance to write failed tests # in the cache. As a consequence, retries would find no test failed and # assume that all tests passed in the previous run, making the whole test pass. - cache = session.config.cache - if cache.get("segfault", False): + if (cache := session.config.cache).get("segfault", False): session.warn(Warning("Previous run resulted in segfault! Not running any test")) session.warn(Warning("(delete '.pytest_cache/v/segfault' to clear this state)")) raise session.Failed diff --git a/tests/crdb/test_cursor.py b/tests/crdb/test_cursor.py index b41de3023..0739a27c8 100644 --- a/tests/crdb/test_cursor.py +++ b/tests/crdb/test_cursor.py @@ -69,8 +69,7 @@ def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out): # We often find the record with {"after": null} at least another time # in the queue. Let's tolerate an extra one. for i in range(2): - row = q.get() - if row is None: + if (row := q.get()) is None: break assert json.loads(row.value)["after"] is None, json else: diff --git a/tests/crdb/test_cursor_async.py b/tests/crdb/test_cursor_async.py index 4c803fa87..c5a1986ae 100644 --- a/tests/crdb/test_cursor_async.py +++ b/tests/crdb/test_cursor_async.py @@ -70,8 +70,7 @@ async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out): # We often find the record with {"after": null} at least another time # in the queue. Let's tolerate an extra one. 
     for i in range(2):
-        row = await q.get()
-        if row is None:
+        if (row := (await q.get())) is None:
             break
         assert json.loads(row.value)["after"] is None, json
     else:
diff --git a/tests/fix_crdb.py b/tests/fix_crdb.py
index f1679f6f0..60d721c21 100644
--- a/tests/fix_crdb.py
+++ b/tests/fix_crdb.py
@@ -38,12 +38,10 @@ def check_crdb_version(got, mark):
         pred = VersionCheck.parse(spec)
         pred.whose = "CockroachDB"
 
-    msg = pred.get_skip_message(got)
-    if not msg:
+    if not (msg := pred.get_skip_message(got)):
         return None
 
-    reason = crdb_skip_message(reason)
-    if reason:
+    if reason := crdb_skip_message(reason):
         msg = f"{msg}: {reason}"
 
     return msg
diff --git a/tests/fix_db.py b/tests/fix_db.py
index 342e51392..7b1879495 100644
--- a/tests/fix_db.py
+++ b/tests/fix_db.py
@@ -46,8 +46,7 @@ def pytest_addoption(parser):
 
 
 def pytest_report_header(config):
-    dsn = config.getoption("--test-dsn")
-    if dsn is None:
+    if (dsn := config.getoption("--test-dsn")) is None:
         return []
 
     try:
@@ -56,9 +55,7 @@
     except Exception as ex:
         server_version = f"unknown ({ex})"
 
-    return [
-        f"Server version: {server_version}",
-    ]
+    return [f"Server version: {server_version}"]
 
 
 def pytest_collection_modifyitems(items):
@@ -93,8 +90,7 @@ def session_dsn(request):
     """
     Return the dsn used to connect to the `--test-dsn` database (session-wide).
     """
-    dsn = request.config.getoption("--test-dsn")
-    if dsn is None:
+    if (dsn := request.config.getoption("--test-dsn")) is None:
         pytest.skip("skipping test as no --test-dsn")
 
     warm_up_database(dsn)
@@ -130,8 +126,7 @@ def tracefile(request):
     """Open and yield a file for libpq client/server communication traces if
     --pq-tracefile option is set.
     """
-    tracefile = request.config.getoption("--pq-trace")
-    if not tracefile:
+    if not (tracefile := request.config.getoption("--pq-trace")):
         yield None
         return
@@ -191,9 +186,7 @@ def pgconn_debug(request):
 def pgconn(dsn, request, tracefile):
     """Return a PGconn connection open to `--test-dsn`."""
     check_connection_version(request.node)
-
-    conn = pq.PGconn.connect(dsn.encode())
-    if conn.status != pq.ConnStatus.OK:
+    if (conn := pq.PGconn.connect(dsn.encode())).status != pq.ConnStatus.OK:
         pytest.fail(f"bad connection: {conn.get_error_message()}")
 
     with maybe_trace(conn, tracefile, request.function):
@@ -298,8 +291,7 @@ def patch_exec(conn, monkeypatch):
     L = ListPopAll()
 
     def _exec_command(command, *args, **kwargs):
-        cmdcopy = command
-        if isinstance(cmdcopy, bytes):
+        if isinstance((cmdcopy := command), bytes):
             cmdcopy = cmdcopy.decode(conn.info.encoding)
         elif isinstance(cmdcopy, sql.Composable):
             cmdcopy = cmdcopy.as_string(conn)
@@ -330,15 +322,12 @@ def check_connection_version(node):
     for mark in node.iter_markers():
         if mark.name == "pg":
             assert len(mark.args) == 1
-            msg = check_postgres_version(pg_version, mark.args[0])
-            if msg:
+            if msg := check_postgres_version(pg_version, mark.args[0]):
                 pytest.skip(msg)
-
         elif mark.name in ("crdb", "crdb_skip"):
             from .fix_crdb import check_crdb_version
 
-            msg = check_crdb_version(crdb_version, mark)
-            if msg:
+            if msg := check_crdb_version(crdb_version, mark):
                 pytest.skip(msg)
@@ -366,12 +355,11 @@ def warm_up_database(dsn: str) -> None:
     try:
         with psycopg.connect(dsn, connect_timeout=10) as conn:
             conn.execute("select 1")
             pg_version = conn.info.server_version
 
             crdb_version = None
-            param = conn.info.parameter_status("crdb_version")
-            if param:
+            if param := conn.info.parameter_status("crdb_version"):
                 from psycopg.crdb import CrdbConnectionInfo

                 crdb_version = CrdbConnectionInfo.parse_crdb_version(param)
diff --git a/tests/fix_faker.py b/tests/fix_faker.py
index d4eee3875..9f6a8c847 100644
--- a/tests/fix_faker.py
+++ b/tests/fix_faker.py
@@ -157,8 +157,7 @@ class Faker:
                 try:
                     cur.execute(self._insert_field_stmt(j), (val,))
                 except psycopg.DatabaseError as e:
-                    r = repr(val)
-                    if len(r) > 200:
+                    if len((r := repr(val))) > 200:
                         r = f"{r[:200]}... ({len(r)} chars)"
                     raise Exception(
                         f"value {r!r} at record {i} column0 {j} failed insert: {e}"
@@ -182,8 +181,7 @@ class Faker:
                 try:
                     await acur.execute(self._insert_field_stmt(j), (val,))
                 except psycopg.DatabaseError as e:
-                    r = repr(val)
-                    if len(r) > 200:
+                    if len((r := repr(val))) > 200:
                         r = f"{r[:200]}... ({len(r)} chars)"
                     raise Exception(
                         f"value {r!r} at record {i} column0 {j} failed insert: {e}"
@@ -201,8 +199,7 @@ class Faker:
     def choose_schema(self, ncols=20):
         schema: list[tuple[type, ...] | type] = []
         while len(schema) < ncols:
-            s = self.make_schema(choice(self.types))
-            if s is not None:
+            if (s := self.make_schema(choice(self.types))) is not None:
                 schema.append(s)
         self.schema = schema
         return schema
@@ -273,8 +270,7 @@ class Faker:
         except KeyError:
             pass
 
-        meth = self._get_method("make", cls)
-        if meth:
+        if meth := self._get_method("make", cls):
             self._makers[cls] = meth
             return meth
         else:
@@ -292,9 +288,7 @@ class Faker:
         parts = name.split(".")
         for i in range(len(parts) - 1, -1, -1):
-            mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
-            meth = getattr(self, mname, None)
-            if meth:
+            if meth := getattr(self, f"{prefix}_{'_'.join(parts[-(i + 1):])}", None):
                 return meth
 
         return None
@@ -306,8 +300,7 @@ class Faker:
     def example(self, spec):
         # A good representative of the object - no degenerate case
         cls = spec if isinstance(spec, type) else spec[0]
-        meth = self._get_method("example", cls)
-        if meth:
+        if meth := self._get_method("example", cls):
             return meth(spec)
         else:
             return self.make(spec)
@@ -422,11 +415,10 @@ class Faker:
     def match_float(self, spec, got, want, rel=None):
         if got is not None and isnan(got):
             assert isnan(want)
+        elif rel or self._server_rounds():
+            assert got == pytest.approx(want, rel=rel)
         else:
-            if rel or self._server_rounds():
-                assert got == pytest.approx(want, rel=rel)
-            else:
-                assert got == want
+            assert got == want
 
     def _server_rounds(self):
         """Return True if the connected server perform float rounding"""
@@ -509,21 +501,16 @@ class Faker:
 
     def schema_list(self, cls):
         while True:
-            scls = choice(self.types)
-            if scls is cls:
-                continue
-            if scls is float:
-                # TODO: float lists are currently adapted as decimal.
-                # There may be rounding errors or problems with inf.
+            # TODO: float lists are currently adapted as decimal.
+            # There may be rounding errors or problems with inf.
+            if (scls := choice(self.types)) is cls or scls is float:
                 continue
 
             # CRDB doesn't support arrays of json
             # https://github.com/cockroachdb/cockroach/issues/23468
             if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb):
                 continue
-
-            schema = self.make_schema(scls)
-            if schema is not None:
+            if (schema := self.make_schema(scls)) is not None:
                 break
 
         return (cls, schema)
@@ -583,8 +570,7 @@ class Faker:
         out: list[Range[Any]] = []
         for i in range(length):
-            r = self.make_Range((Range, spec[1]), **kwargs)
-            if r.isempty:
+            if (r := self.make_Range((Range, spec[1]), **kwargs)).isempty:
                 continue
             for r2 in out:
                 if overlap(r, r2):
@@ -839,8 +825,7 @@ class Faker:
         rec_types = [list, dict]
         scal_types = [type(None), int, JsonFloat, bool, str]
         if random() < container_chance:
-            cls = choice(rec_types)
-            if cls is list:
+            if (cls := choice(rec_types)) is list:
                 return [
                     self._make_json(container_chance=container_chance / 2.0)
                     for i in range(randrange(self.json_max_length))
diff --git a/tests/fix_pq.py b/tests/fix_pq.py
index 1cff7e18b..9fbf3ce82 100644
--- a/tests/fix_pq.py
+++ b/tests/fix_pq.py
@@ -41,8 +41,8 @@ def pytest_configure(config):
 def pytest_runtest_setup(item):
     for m in item.iter_markers(name="libpq"):
         assert len(m.args) == 1
-        msg = check_libpq_version(pq.version(), m.args[0])
-        if msg:
+
+        if msg := check_libpq_version(pq.version(), m.args[0]):
             pytest.skip(msg)
@@ -81,8 +81,7 @@ def setpgenv(monkeypatch):
 
 
 @pytest.fixture
 def trace(libpq):
-    pqver = pq.__build_version__
-    if pqver < 140000:
+    if (pqver := pq.__build_version__) < 140000:
         pytest.skip(f"trace not available on libpq {pqver}")
     if sys.platform != "linux":
         pytest.skip(f"trace not available on {sys.platform}")
diff --git a/tests/fix_proxy.py b/tests/fix_proxy.py
index dbc03a9f4..ba8575e99 100644
--- a/tests/fix_proxy.py
+++ b/tests/fix_proxy.py
@@ -73,8 +73,8 @@ class Proxy:
             return
 
         logging.info("starting proxy")
-        pproxy = which("pproxy")
-        if not pproxy:
+
+        if not (pproxy := which("pproxy")):
             raise ValueError("pproxy program not found")
         cmdline = [pproxy, "--reuse"]
         cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
diff --git a/tests/pq/test_async.py b/tests/pq/test_async.py
index aff6ccdf9..404239fb4 100644
--- a/tests/pq/test_async.py
+++ b/tests/pq/test_async.py
@@ -24,8 +24,7 @@ def test_send_query(pgconn):
     # send loop
     waited_on_send = 0
     while True:
-        f = pgconn.flush()
-        if f == 0:
+        if pgconn.flush() == 0:
             break
 
         waited_on_send += 1
@@ -48,8 +47,8 @@ def test_send_query(pgconn):
         if pgconn.is_busy():
             select([pgconn.socket], [], [])
             continue
-        res = pgconn.get_result()
-        if res is None:
+
+        if (res := pgconn.get_result()) is None:
             break
         assert res.status == pq.ExecStatus.TUPLES_OK
         results.append(res)
diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py
index 02953b4de..657cab9b8 100644
--- a/tests/pq/test_pgconn.py
+++ b/tests/pq/test_pgconn.py
@@ -30,8 +30,8 @@ def wait(
     poll = getattr(conn, poll_method)
     while True:
         assert conn.status != pq.ConnStatus.BAD, conn.error_message
-        rv = poll()
-        if rv == return_on:
+
+        if (rv := poll()) == return_on:
             return
         elif rv == pq.PollingStatus.READING:
             select([conn.socket], [], [], timeout)
@@ -362,11 +362,10 @@ def test_used_password(pgconn, dsn, monkeypatch):
     # so it may be that has_password is false but still a password was
    # requested by the server and passed by libpq.
     info = pq.Conninfo.parse(dsn.encode())
-    has_password = (
-        "PGPASSWORD" in os.environ
-        or [i for i in info if i.keyword == b"password"][0].val is not None
-    )
-    if has_password:
+    if (
+        "PGPASSWORD" in os.environ
+        or [i for i in info if i.keyword == b"password"][0].val is not None
+    ):
         assert pgconn.used_password
 
     pgconn.finish()
diff --git a/tests/scripts/copytest.py b/tests/scripts/copytest.py
index 7303fa967..8157a3293 100755
--- a/tests/scripts/copytest.py
+++ b/tests/scripts/copytest.py
@@ -136,9 +136,7 @@ def parse_cmdline() -> Namespace:
         default=logging.INFO,
     )
 
-    args = parser.parse_args()
-
-    if args.writer:
+    if (args := parser.parse_args()).writer:
         try:
             getattr(psycopg.copy, args.writer)
         except AttributeError:
diff --git a/tests/scripts/pipeline-demo.py b/tests/scripts/pipeline-demo.py
index 5af239ddd..8b26dc72e 100644
--- a/tests/scripts/pipeline-demo.py
+++ b/tests/scripts/pipeline-demo.py
@@ -93,17 +93,13 @@ class LoggingPGconn:
         self._logger.info("sent prepared '%s' with %s", name.decode(), param_values)
 
     def send_prepare(
-        self,
-        name: bytes,
-        command: bytes,
-        param_types: Sequence[int] | None = None,
+        self, name: bytes, command: bytes, param_types: Sequence[int] | None = None
     ) -> None:
         self._pgconn.send_prepare(name, command, param_types)
         self._logger.info("prepare %s as '%s'", command.decode(), name.decode())
 
     def get_result(self) -> pq.abc.PGresult | None:
-        r = self._pgconn.get_result()
-        if r is not None:
+        if (r := self._pgconn.get_result()) is not None:
             self._logger.info("got %s result", pq.ExecStatus(r.status).name)
         return r
 
@@ -190,8 +186,7 @@ def pipeline_demo_pq(rows_to_send: int, logger: logging.Logger) -> None:
     ):
         while results_queue:
             fetched = waiting.wait(
-                pipeline_communicate(pgconn, commands),
-                pgconn.socket,
+                pipeline_communicate(pgconn, commands), pgconn.socket
             )
             assert not commands, commands
             for results in fetched:
@@ -213,8 +208,7 @@ async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> N
     ):
         while results_queue:
             fetched = await waiting.wait_async(
-                pipeline_communicate(pgconn, commands),
-                pgconn.socket,
+                pipeline_communicate(pgconn, commands), pgconn.socket
             )
             assert not commands, commands
             for results in fetched:
diff --git a/tests/scripts/spiketest.py b/tests/scripts/spiketest.py
index 334433e57..82d274fd8 100644
--- a/tests/scripts/spiketest.py
+++ b/tests/scripts/spiketest.py
@@ -21,8 +21,7 @@ from psycopg.rows import Row
 
 
 def main() -> None:
-    opt = parse_cmdline()
-    if opt.loglevel:
+    if (opt := parse_cmdline()).loglevel:
         loglevel = getattr(logging, opt.loglevel.upper())
         logging.basicConfig(
             level=loglevel, format="%(asctime)s %(levelname)s %(message)s"
         )
@@ -110,8 +109,8 @@ class DelayedConnection(psycopg.Connection[Row]):
         t0 = time.time()
         conn = super().connect(conninfo, **kwargs)
         t1 = time.time()
-        wait = max(0.0, conn_delay - (t1 - t0))
-        if wait:
+
+        if wait := max(0.0, conn_delay - (t1 - t0)):
             time.sleep(wait)
         return conn
diff --git a/tests/test_adapt.py b/tests/test_adapt.py
index 8115ceb22..764375715 100644
--- a/tests/test_adapt.py
+++ b/tests/test_adapt.py
@@ -51,13 +51,7 @@ def test_quote(data, result):
     assert dumper.quote(data) == result
 
 
-@pytest.mark.parametrize(
-    "data, result",
-    [
-        ("hello", b"'hello'"),
-        ("", b"NULL"),
-    ],
-)
+@pytest.mark.parametrize("data, result", [("hello", b"'hello'"), ("", b"NULL")])
 def test_quote_none(data, result, global_adapters):
     psycopg.adapters.register_dumper(str, StrNoneDumper)
     t = Transformer()
@@ -424,8 +418,8 @@ def test_optimised_adapters():
     for n in dir(_psycopg):
         if n.startswith("_") or n in ("CDumper", "CLoader"):
             continue
-        obj = getattr(_psycopg, n)
-        if not isinstance(obj, type):
+
+        if not isinstance((obj := getattr(_psycopg, n)), type):
             continue
         if not issubclass(obj, (_psycopg.CDumper, _psycopg.CLoader)):
             continue
@@ -449,12 +443,10 @@ def test_optimised_adapters():
     # Check that every optimised adapter is the optimised version of a Py one
     for n in dir(psycopg.types):
-        mod = getattr(psycopg.types, n)
-        if not isinstance(mod, ModuleType):
+        if not isinstance((mod := getattr(psycopg.types, n)), ModuleType):
             continue
 
         for n1 in dir(mod):
-            obj = getattr(mod, n1)
-            if not isinstance(obj, type):
+            if not isinstance((obj := getattr(mod, n1)), type):
                 continue
             if not issubclass(obj, (Dumper, Loader)):
                 continue
diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py
index 28889333a..030965580 100644
--- a/tests/test_concurrency.py
+++ b/tests/test_concurrency.py
@@ -323,8 +323,7 @@ with psycopg.connect({dsn!r}, application_name={APPNAME!r}) as conn:
     def run_process():
         nonlocal proc
         proc = sp.Popen(
-            [sys.executable, "-s", "-c", script],
-            creationflags=creationflags,
+            [sys.executable, "-s", "-c", script], creationflags=creationflags
         )
         proc.communicate()
 
@@ -335,8 +334,8 @@ with psycopg.connect({dsn!r}, application_name={APPNAME!r}) as conn:
         cur = conn.execute(
             "select pid from pg_stat_activity where application_name = %s", (APPNAME,)
         )
-        rec = cur.fetchone()
-        if rec:
+
+        if rec := cur.fetchone():
             pid = rec[0]
             break
         time.sleep(0.1)
diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py
index 65bc2539b..63b7b3f81 100644
--- a/tests/test_concurrency_async.py
+++ b/tests/test_concurrency_async.py
@@ -248,8 +248,8 @@ asyncio.run(main())
         cur = conn.execute(
             "select pid from pg_stat_activity where application_name = %s", (APPNAME,)
         )
-        rec = cur.fetchone()
-        if rec:
+
+        if rec := cur.fetchone():
             pid = rec[0]
             break
         time.sleep(0.1)
diff --git a/tests/test_copy.py b/tests/test_copy.py
index c2e6598fc..6ae5d8e05 100644
--- a/tests/test_copy.py
+++ b/tests/test_copy.py
@@ -145,8 +145,7 @@ def test_copy_out_allchars(conn, format):
     with cur.copy(query) as copy:
         copy.set_types(["text"])
         while True:
-            row = copy.read_row()
-            if not row:
+            if not (row := copy.read_row()):
                 break
             assert len(row) == 1
             rows.append(row[0])
@@ -160,8 +159,7 @@ def test_read_row_notypes(conn, format):
     with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         rows = []
         while True:
-            row = copy.read_row()
-            if not row:
+            if not (row := copy.read_row()):
                 break
             rows.append(row)
 
@@ -736,15 +734,13 @@ def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method, gc):
 
                 if method == "read":
                     while True:
-                        tmp = copy.read()
-                        if not tmp:
+                        if not copy.read():
                             break
                 elif method == "iter":
                     list(copy)
                 elif method == "row":
                     while True:
-                        tmp = copy.read_row()
-                        if tmp is None:
+                        if copy.read_row() is None:
                             break
                 elif method == "rows":
                     list(copy.rows())
@@ -864,8 +860,7 @@ class DataGenerator:
     def blocks(self):
         f = self.file()
         while True:
-            block = f.read(self.block_size)
-            if not block:
+            if not (block := f.read(self.block_size)):
                 break
             yield block
 
@@ -880,8 +875,7 @@ class DataGenerator:
     def sha(self, f):
         m = hashlib.sha256()
         while True:
-            block = f.read()
-            if not block:
+            if not (block := f.read()):
                 break
             if isinstance(block, str):
                 block = block.encode()
diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py
index 1168bc31b..72a838e5e 100644
--- a/tests/test_copy_async.py
+++ b/tests/test_copy_async.py
@@ -149,8 +149,7 @@ async def test_copy_out_allchars(aconn, format):
     async with cur.copy(query) as copy:
         copy.set_types(["text"])
         while True:
-            row = await copy.read_row()
-            if not row:
+            if not (row := (await copy.read_row())):
                 break
             assert len(row) == 1
             rows.append(row[0])
@@ -166,8 +165,7 @@ async def test_read_row_notypes(aconn, format):
     ) as copy:
         rows = []
         while True:
-            row = await copy.read_row()
-            if not row:
+            if not (row := (await copy.read_row())):
                 break
             rows.append(row)
 
@@ -752,15 +750,13 @@ async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method, gc):
 
                 if method == "read":
                     while True:
-                        tmp = await copy.read()
-                        if not tmp:
+                        if not (await copy.read()):
                             break
                 elif method == "iter":
                     await alist(copy)
                 elif method == "row":
                     while True:
-                        tmp = await copy.read_row()
-                        if tmp is None:
+                        if (await copy.read_row()) is None:
                             break
                 elif method == "rows":
                     await alist(copy.rows())
@@ -879,8 +875,7 @@ class DataGenerator:
     def blocks(self):
         f = self.file()
         while True:
-            block = f.read(self.block_size)
-            if not block:
+            if not (block := f.read(self.block_size)):
                 break
             yield block
 
@@ -895,8 +890,7 @@ class DataGenerator:
     def sha(self, f):
         m = hashlib.sha256()
         while True:
-            block = f.read()
-            if not block:
+            if not (block := f.read()):
                 break
             if isinstance(block, str):
                 block = block.encode()
diff --git a/tests/test_cursor.py b/tests/test_cursor.py
index e31791e96..cf5551ca0 100644
--- a/tests/test_cursor.py
+++ b/tests/test_cursor.py
@@ -86,13 +86,11 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
 
                 if fetch == "one":
                     while True:
-                        tmp = cur.fetchone()
-                        if tmp is None:
+                        if cur.fetchone() is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = cur.fetchmany(3)
-                        if not tmp:
+                        if not cur.fetchmany(3):
                             break
                 elif fetch == "all":
                     cur.fetchall()
diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py
index f54b31aa7..fa6b70175 100644
--- a/tests/test_cursor_async.py
+++ b/tests/test_cursor_async.py
@@ -87,13 +87,11 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc)
 
                 if fetch == "one":
                     while True:
-                        tmp = await cur.fetchone()
-                        if tmp is None:
+                        if (await cur.fetchone()) is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = await cur.fetchmany(3)
-                        if not tmp:
+                        if not (await cur.fetchmany(3)):
                             break
                 elif fetch == "all":
                     await cur.fetchall()
diff --git a/tests/test_cursor_client.py b/tests/test_cursor_client.py
index 70a4d55db..fd87cf551 100644
--- a/tests/test_cursor_client.py
+++ b/tests/test_cursor_client.py
@@ -96,13 +96,11 @@ def test_leak(conn_cls, dsn, faker, fetch, row_factory, gc):
 
                 if fetch == "one":
                     while True:
-                        tmp = cur.fetchone()
-                        if tmp is None:
+                        if cur.fetchone() is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = cur.fetchmany(3)
-                        if not tmp:
+                        if not cur.fetchmany(3):
                             break
                 elif fetch == "all":
                     cur.fetchall()
diff --git a/tests/test_cursor_client_async.py b/tests/test_cursor_client_async.py
index 753c95434..a580f0d66 100644
--- a/tests/test_cursor_client_async.py
+++ b/tests/test_cursor_client_async.py
@@ -97,13 +97,11 @@ async def test_leak(aconn_cls, dsn, faker, fetch, row_factory, gc):
 
                 if fetch == "one":
                     while True:
-                        tmp = await cur.fetchone()
-                        if tmp is None:
+                        if (await cur.fetchone()) is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = await cur.fetchmany(3)
-                        if not tmp:
+                        if not (await cur.fetchmany(3)):
                             break
                 elif fetch == "all":
                     await cur.fetchall()
diff --git a/tests/test_cursor_raw.py b/tests/test_cursor_raw.py
index 64ef05125..0f2155e3f 100644
--- a/tests/test_cursor_raw.py
+++ b/tests/test_cursor_raw.py
@@ -94,13 +94,11 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
 
                 if fetch == "one":
                     while True:
-                        tmp = cur.fetchone()
-                        if tmp is None:
+                        if cur.fetchone() is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = cur.fetchmany(3)
-                        if not tmp:
+                        if not cur.fetchmany(3):
                             break
                 elif fetch == "all":
                     cur.fetchall()
diff --git a/tests/test_cursor_raw_async.py b/tests/test_cursor_raw_async.py
index 42a241e4a..f2f9553ad 100644
--- a/tests/test_cursor_raw_async.py
+++ b/tests/test_cursor_raw_async.py
@@ -91,13 +91,11 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc)
 
                 if fetch == "one":
                     while True:
-                        tmp = await cur.fetchone()
-                        if tmp is None:
+                        if (await cur.fetchone()) is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = await cur.fetchmany(3)
-                        if not tmp:
+                        if not (await cur.fetchmany(3)):
                             break
                 elif fetch == "all":
                     await cur.fetchall()
diff --git a/tests/test_cursor_server.py b/tests/test_cursor_server.py
index 25df3a9c2..365f0b762 100644
--- a/tests/test_cursor_server.py
+++ b/tests/test_cursor_server.py
@@ -385,8 +385,7 @@ def test_row_factory(conn):
     recs = cur.fetchall()
     cur.scroll(0, "absolute")
     while True:
-        rec = cur.fetchone()
-        if not rec:
+        if not (rec := cur.fetchone()):
             break
         recs.append(rec)
     assert recs == [[1, -1], [1, -2], [1, -3]] * 2
diff --git a/tests/test_cursor_server_async.py b/tests/test_cursor_server_async.py
index febd01be1..cede008f7 100644
--- a/tests/test_cursor_server_async.py
+++ b/tests/test_cursor_server_async.py
@@ -391,8 +391,7 @@ async def test_row_factory(aconn):
     recs = await cur.fetchall()
     await cur.scroll(0, "absolute")
     while True:
-        rec = await cur.fetchone()
-        if not rec:
+        if not (rec := (await cur.fetchone())):
             break
         recs.append(rec)
     assert recs == [[1, -1], [1, -2], [1, -3]] * 2
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 2975af43a..80c8ae0df 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -14,8 +14,8 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch):
     OperationalError has a pgconn attribute set with needs_password.
     """
     gen = generators.connect(dsn)
-    pgconn = waiting.wait_conn(gen)
-    if not pgconn.used_password:
+
+    if not (pgconn := waiting.wait_conn(gen)).used_password:
         pytest.skip("test connection needs no password")
 
     with monkeypatch.context() as m:
diff --git a/tests/test_rows.py b/tests/test_rows.py
index 93240b5eb..d60bda16f 100644
--- a/tests/test_rows.py
+++ b/tests/test_rows.py
@@ -113,8 +113,7 @@ def test_scalar_row(conn):
 
 
 @pytest.mark.parametrize(
-    "factory",
-    "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
+    "factory", "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split()
 )
 def test_no_result(factory, conn):
     cur = conn.cursor(row_factory=factory_from_name(factory))
@@ -151,8 +150,7 @@ def test_no_column_class_row(conn):
 
 
 def factory_from_name(name):
-    factory = getattr(rows, name)
-    if factory is rows.class_row:
+    if (factory := getattr(rows, name)) is rows.class_row:
         factory = factory(Person)
     if factory is rows.args_row:
         factory = factory(argf)
diff --git a/tests/test_tpc.py b/tests/test_tpc.py
index 90ba2f09d..85c172f07 100644
--- a/tests/test_tpc.py
+++ b/tests/test_tpc.py
@@ -11,8 +11,8 @@ pytestmark = pytest.mark.crdb_skip("2-phase commit")
 
 def test_tpc_disabled(conn, pipeline):
     cur = conn.execute("show max_prepared_transactions")
-    val = int(cur.fetchone()[0])
-    if val:
+
+    if int(cur.fetchone()[0]):
         pytest.skip("prepared transactions enabled")
 
     conn.rollback()
diff --git a/tests/test_tpc_async.py b/tests/test_tpc_async.py
index 8a448c714..b3d1ed2d3 100644
--- a/tests/test_tpc_async.py
+++ b/tests/test_tpc_async.py
@@ -8,8 +8,8 @@ pytestmark = pytest.mark.crdb_skip("2-phase commit")
 
 async def test_tpc_disabled(aconn, apipeline):
     cur = await aconn.execute("show max_prepared_transactions")
-    val = int((await cur.fetchone())[0])
-    if val:
+
+    if int((await cur.fetchone())[0]):
         pytest.skip("prepared transactions enabled")
 
     await aconn.rollback()
diff --git a/tests/test_typeinfo.py b/tests/test_typeinfo.py
index 94a704eb5..3df6fb510 100644
--- a/tests/test_typeinfo.py
+++ b/tests/test_typeinfo.py
@@ -22,8 +22,7 @@ def test_fetch(conn, name, status, encoding):
     conn.execute("select set_config('client_encoding', %s, false)", [encoding])
 
     if status:
-        status = getattr(TransactionStatus, status)
-        if status == TransactionStatus.INTRANS:
+        if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
             conn.execute("select 1")
     else:
         conn.autocommit = True
@@ -54,8 +53,7 @@ async def test_fetch_async(aconn, name, status, encoding):
     )
 
     if status:
-        status = getattr(TransactionStatus, status)
-        if status == TransactionStatus.INTRANS:
+        if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
             await aconn.execute("select 1")
     else:
         await aconn.set_autocommit(True)
@@ -100,8 +98,8 @@ def test_fetch_not_found(conn, name, status, info_cls, monkeypatch):
         return exit_orig(self, exc_type, exc_val, exc_tb)
 
     monkeypatch.setattr(psycopg.Transaction, "__exit__", exit)
-    status = getattr(TransactionStatus, status)
-    if status == TransactionStatus.INTRANS:
+
+    if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
         conn.execute("select 1")
 
     assert conn.info.transaction_status == status
@@ -122,8 +120,8 @@ async def test_fetch_not_found_async(aconn, name, status, info_cls, monkeypatch)
         return await exit_orig(self, exc_type, exc_val, exc_tb)
 
     monkeypatch.setattr(psycopg.AsyncTransaction, "__aexit__", aexit)
-    status = getattr(TransactionStatus, status)
-    if status == TransactionStatus.INTRANS:
+
+    if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
         await aconn.execute("select 1")
 
     assert aconn.info.transaction_status == status
diff --git a/tests/types/test_array.py b/tests/types/test_array.py
index fdb9c888f..57d5b4034 100644
--- a/tests/types/test_array.py
+++ b/tests/types/test_array.py
@@ -204,8 +204,7 @@ def test_numbers_array(num, type, fmt_in):
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
-    wrapper = getattr(psycopg.types.numeric, wrapper)
-    if wrapper is Decimal:
+    if (wrapper := getattr(psycopg.types.numeric, wrapper)) is Decimal:
         want_cls = Decimal
     else:
         assert wrapper.__mro__[1] in (int, float)
diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py
index da5fe9066..84f70b737 100644
--- a/tests/types/test_numeric.py
+++ b/tests/types/test_numeric.py
@@ -458,13 +458,7 @@ def test_dump_numeric_exhaustive(conn, fmt_in):
 
 
 @pytest.mark.pg(">= 14")
-@pytest.mark.parametrize(
-    "val, expr",
-    [
-        ("inf", "Infinity"),
-        ("-inf", "-Infinity"),
-    ],
-)
+@pytest.mark.parametrize("val, expr", [("inf", "Infinity"), ("-inf", "-Infinity")])
 def test_dump_numeric_binary_inf(conn, val, expr):
     cur = conn.cursor()
     val = Decimal(val)
@@ -492,8 +486,8 @@ def test_dump_numeric_binary_inf(conn, val, expr):
 def test_load_numeric_binary(conn, expr):
     cur = conn.cursor(binary=1)
     res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
-    val = Decimal(expr)
-    if val.is_nan():
+
+    if (val := Decimal(expr)).is_nan():
         assert res.is_nan()
     else:
         assert res == val
@@ -531,13 +525,7 @@ def test_load_numeric_exhaustive(conn, fmt_out):
 
 
 @pytest.mark.pg(">= 14")
-@pytest.mark.parametrize(
-    "val, expr",
-    [
-        ("inf", "Infinity"),
-        ("-inf", "-Infinity"),
-    ],
-)
+@pytest.mark.parametrize("val, expr", [("inf", "Infinity"), ("-inf", "-Infinity")])
 def test_load_numeric_binary_inf(conn, val, expr):
     cur = conn.cursor(binary=1)
     res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
@@ -546,14 +534,7 @@ def test_load_numeric_binary_inf(conn, val, expr):
 
 
 @pytest.mark.parametrize(
-    "val",
-    [
-        "0",
-        "0.0",
-        "0.000000000000000000001",
-        "-0.000000000000000000001",
-        "nan",
-    ],
+    "val", ["0", "0.0", "0.000000000000000000001", "-0.000000000000000000001", "nan"]
 )
 def test_numeric_as_float(conn, val):
     cur = conn.cursor()
diff --git a/tests/types/test_numpy.py b/tests/types/test_numpy.py
index 23ce25f35..2561be575 100644
--- a/tests/types/test_numpy.py
+++ b/tests/types/test_numpy.py
@@ -19,8 +19,7 @@ skip_numpy2 = pytest.mark.skipif(
 
 
 def _get_arch_size() -> int:
-    psize = struct.calcsize("P") * 8
-    if psize not in (32, 64):
+    if (psize := (struct.calcsize("P") * 8)) not in (32, 64):
         msg = f"the pointer size {psize} is unusual"
         raise ValueError(msg)
     return psize
@@ -198,9 +197,7 @@ def test_copy_by_oid(conn, val, nptype, pgtypes, fmt):
     fnames = [f"f{t}" for t in pgtypes]
     fields = [f"f{t} {t}" for fname, t in zip(fnames, pgtypes)]
 
-    cur.execute(
-        f"create table numpyoid (id serial primary key, {', '.join(fields)})",
-    )
+    cur.execute(f"create table numpyoid (id serial primary key, {', '.join(fields)})")
     with cur.copy(
         f"copy numpyoid ({', '.join(fnames)}) from stdin (format {fmt.name})"
     ) as copy:
diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py
index d714f5095..b3dc18695 100755
--- a/tools/async_to_sync.py
+++ b/tools/async_to_sync.py
@@ -66,14 +66,12 @@ logger = logging.getLogger()
 
 
 def main() -> int:
-    opt = parse_cmdline()
-    if opt.container:
+    if (opt := parse_cmdline()).container:
parse_cmdline()).container:
         return run_in_container(opt.container)

     logging.basicConfig(level=opt.log_level, format="%(levelname)s %(message)s")

-    current_ver = ".".join(map(str, sys.version_info[:2]))
-    if current_ver != PYVER:
+    if (current_ver := ".".join(map(str, sys.version_info[:2]))) != PYVER:
         logger.warning(
             "Expecting output generated by Python %s; you are running %s instead.",
             PYVER,
@@ -488,8 +486,7 @@ class BlanksInserter(ast.NodeTransformer):  # type: ignore
             new_body.append(before)
         for i in range(1, len(body)):
             after = body[i]
-            nblanks = after.lineno - before.end_lineno - 1
-            if nblanks > 0:
+            if after.lineno - before.end_lineno - 1 > 0:
                 # Inserting one blank is enough.
                 blank = ast.Comment(
                     value="",
@@ -617,8 +614,7 @@ def parse_cmdline() -> Namespace:
         help="the files to process (process all files if not specified)",
     )

-    opt = parser.parse_args()
-    if not opt.inputs:
+    if not (opt := parser.parse_args()).inputs:
         opt.inputs = [PROJECT_DIR / Path(fn) for fn in ALL_INPUTS]

     fp: Path
diff --git a/tools/build/copy_to_binary.py b/tools/build/copy_to_binary.py
index 4834d6b80..f4c1e8d24 100755
--- a/tools/build/copy_to_binary.py
+++ b/tools/build/copy_to_binary.py
@@ -11,17 +11,16 @@ from pathlib import Path

 curdir = Path(__file__).parent
 pdir = curdir / "../.."
-target = pdir / "psycopg_binary"
-if target.exists():
+if (target := (pdir / "psycopg_binary")).exists():
     raise Exception(f"path {target} already exists")


 def sed_i(pattern: str, repl: str, filename: str | Path) -> None:
     with open(filename, "rb") as f:
         data = f.read()
-    newdata = re.sub(pattern.encode("utf8"), repl.encode("utf8"), data)
-    if newdata != data:
+
+    if (newdata := re.sub(pattern.encode("utf8"), repl.encode("utf8"), data)) != data:
         with open(filename, "wb") as f:
             f.write(newdata)
diff --git a/tools/bump_version.py b/tools/bump_version.py
index d1f030fe3..9fbc3aba0 100755
--- a/tools/bump_version.py
+++ b/tools/bump_version.py
@@ -168,8 +167,7 @@ chore: bump {self.package.name} package version to {self.want_version}
         matches = []
         with fp.open() as f:
             for line in f:
-                m = self._ini_regex.match(line)
-                if m:
+                if m := self._ini_regex.match(line):
                     matches.append(m)

         if not matches:
@@ -240,8 +239,7 @@ chore: bump {self.package.name} package version to {self.want_version}
         with fp.open() as f:
             lines = f.readlines()

-        lns = self._find_lines(r"^[^\s]+ " + re.escape(str(version)), lines)
-        if not lns:
+        if not (lns := self._find_lines("^[^\\s]+ " + re.escape(str(version)), lines)):
             logger.warning("no change log line found")
             return []

diff --git a/tools/update_backer.py b/tools/update_backer.py
index a33c5c4ef..fa7e94f3a 100755
--- a/tools/update_backer.py
+++ b/tools/update_backer.py
@@ -33,8 +33,7 @@ def get_user_data(data):
         "name": data["name"],
     }
     if data["blog"]:
-        website = data["blog"]
-        if not website.startswith("http"):
+        if not (website := data["blog"]).startswith("http"):
             website = "http://" + website
         out["website"] = website

@@ -54,8 +53,8 @@ def update_entry(opt, filedata, entry):
     # entry is a username or a user entry data
     if isinstance(entry, str):
         username = entry
-        entry = [e for e in filedata if e["username"] == username]
-        if not entry:
+
+        if not (entry := [e for e in filedata if e["username"] == username]):
             raise Exception(f"{username} not found")
         entry = entry[0]
     else:
diff --git a/tools/update_error_prefixes.py b/tools/update_error_prefixes.py
index 52ba2cb35..c84c48b2e 100755
--- a/tools/update_error_prefixes.py
+++ b/tools/update_error_prefixes.py
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-"""Find the 
error prefixes in various l10n used for precise prefixstripping."""
+"""Find the error prefixes in various l10n used for precise prefix stripping."""

 import re
 import logging
@@ -90,8 +90,7 @@ def parse_cmdline() -> Namespace:
         help="the file to change [default: %(default)s]",
     )

-    opt = parser.parse_args()
-    if not opt.pgroot.is_dir():
+    if not (opt := parser.parse_args()).pgroot.is_dir():
         parser.error(f"not a valid directory: {opt.pgroot}")

     return opt

diff --git a/tools/update_errors.py b/tools/update_errors.py
index b9406c34e..bc1264492 100755
--- a/tools/update_errors.py
+++ b/tools/update_errors.py
@@ -38,20 +38,18 @@ def parse_errors_txt(url):
     page = urlopen(url)
     for line in page.read().decode("ascii").splitlines():
         # Strip comments and skip blanks
-        line = line.split("#")[0].strip()
-        if not line:
+
+        if not (line := line.split("#")[0].strip()):
             continue

         # Parse a section
-        m = re.match(r"Section: (Class (..) - .+)", line)
-        if m:
+        if m := re.match("Section: (Class (..) - .+)", line):
             label, class_ = m.groups()
             classes[class_] = label
             continue

         # Parse an error
-        m = re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line)
-        if m:
+        if m := re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line):
             sqlstate, macro, spec = m.groups()
             # skip sqlstates without specs as they are not publicly visible
             if not spec:
diff --git a/tools/update_oids.py b/tools/update_oids.py
index 96b0df9ab..8e087a8a7 100755
--- a/tools/update_oids.py
+++ b/tools/update_oids.py
@@ -34,9 +34,8 @@ ROOT = Path(__file__).parent.parent

 def main() -> None:
     opt = parse_cmdline()

-    conn = psycopg.connect(opt.dsn, autocommit=True)
-    if CrdbConnection.is_crdb(conn):
+    if CrdbConnection.is_crdb((conn := psycopg.connect(opt.dsn, autocommit=True))):
         conn = CrdbConnection.connect(opt.dsn, autocommit=True)
         update_crdb_python_oids(conn)
     else:
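
Every hunk above applies the same transformation: an assignment immediately followed by an `if` that tests the assigned name is collapsed into a single `if` using the assignment expression ("walrus") operator, available since Python 3.8 (PEP 572). A minimal sketch of the pattern, with illustrative names not taken from the psycopg codebase:

    import re

    # Before: assign on one line, test on the next.
    m = re.match(r"(\d+)", "42 apples")
    if m:
        print(m.group(1))

    # After: bind and test in a single expression.
    if m := re.match(r"(\d+)", "42 apples"):
        print(m.group(1))

    # When the condition tests something other than the bound name itself
    # (an attribute, a comparison, a method call), the assignment expression
    # needs its own parentheses, because := binds more loosely than the
    # comparison operators:
    if (count := len("abc")) > 2:
        print(count)

In both forms the name is bound in the enclosing scope, exactly as a plain assignment would bind it, so it remains usable after the `if` block; that is why the refactoring is behavior-preserving.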