refactor: use the assignment operator in assignments followed by an if
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Thu, 27 Mar 2025 00:25:49 +0000 (01:25 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
          Thu, 3 Apr 2025 11:00:46 +0000 (12:00 +0100)
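
The pattern applied across all 97 files is Python's assignment expression (the ":=" or "walrus" operator, added in Python 3.8 by PEP 572): a value that was assigned to a name and then immediately tested in an if is now bound inside the condition itself. A minimal sketch of the shapes involved, using stand-in names (fetch, process, fetch_status, OK) that do not come from the patch:

    # Before: a throwaway assignment followed by an if
    data = fetch()
    if data:
        process(data)

    # After: bind and test in a single expression (Python >= 3.8)
    if data := fetch():
        process(data)

    # A comparison or "is None" test needs its own parentheses, otherwise
    # the walrus would capture the result of the whole comparison:
    if (status := fetch_status()) == OK:
        process(status)

Several hunks below also carry related cleanups the shorter code made room for: multi-line signatures folded onto one line, and nested if/else chains flattened into elif.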
97 files changed:
docs/conf.py
docs/lib/libpq_docs.py
docs/lib/pg3_docs.py
docs/lib/ticket_role.py
psycopg/psycopg/__init__.py
psycopg/psycopg/_adapters_map.py
psycopg/psycopg/_column.py
psycopg/psycopg/_connection_base.py
psycopg/psycopg/_conninfo_attempts.py
psycopg/psycopg/_conninfo_attempts_async.py
psycopg/psycopg/_conninfo_utils.py
psycopg/psycopg/_copy.py
psycopg/psycopg/_copy_async.py
psycopg/psycopg/_copy_base.py
psycopg/psycopg/_cursor_base.py
psycopg/psycopg/_dns.py
psycopg/psycopg/_encodings.py
psycopg/psycopg/_pipeline.py
psycopg/psycopg/_preparing.py
psycopg/psycopg/_py_transformer.py
psycopg/psycopg/_queries.py
psycopg/psycopg/_tpc.py
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/adapt.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
psycopg/psycopg/crdb/connection.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/dbapi20.py
psycopg/psycopg/generators.py
psycopg/psycopg/pq/_debug.py
psycopg/psycopg/pq/_pq_ctypes.py
psycopg/psycopg/pq/misc.py
psycopg/psycopg/pq/pq_ctypes.py
psycopg/psycopg/rows.py
psycopg/psycopg/server_cursor.py
psycopg/psycopg/sql.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/composite.py
psycopg/psycopg/types/datetime.py
psycopg/psycopg/types/hstore.py
psycopg/psycopg/types/json.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/numeric.py
psycopg/psycopg/types/range.py
psycopg/psycopg/types/string.py
psycopg/psycopg/waiting.py
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
psycopg_pool/psycopg_pool/sched.py
psycopg_pool/psycopg_pool/sched_async.py
tests/conftest.py
tests/crdb/test_cursor.py
tests/crdb/test_cursor_async.py
tests/fix_crdb.py
tests/fix_db.py
tests/fix_faker.py
tests/fix_pq.py
tests/fix_proxy.py
tests/pq/test_async.py
tests/pq/test_pgconn.py
tests/scripts/copytest.py
tests/scripts/pipeline-demo.py
tests/scripts/spiketest.py
tests/test_adapt.py
tests/test_concurrency.py
tests/test_concurrency_async.py
tests/test_copy.py
tests/test_copy_async.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_cursor_client.py
tests/test_cursor_client_async.py
tests/test_cursor_raw.py
tests/test_cursor_raw_async.py
tests/test_cursor_server.py
tests/test_cursor_server_async.py
tests/test_generators.py
tests/test_rows.py
tests/test_tpc.py
tests/test_tpc_async.py
tests/test_typeinfo.py
tests/types/test_array.py
tests/types/test_numeric.py
tests/types/test_numpy.py
tools/async_to_sync.py
tools/build/copy_to_binary.py
tools/bump_version.py
tools/update_backer.py
tools/update_error_prefixes.py
tools/update_errors.py
tools/update_oids.py

diff --git a/docs/conf.py b/docs/conf.py
index cac973651c63f397e31916a01a61d33259bd3ee3..cf87980d32f32a01eba0359c8a9678e315219adb 100644
@@ -22,7 +22,6 @@ import psycopg
 docs_dir = Path(__file__).parent
 sys.path.append(str(docs_dir / "lib"))
 
-
 # -- Project information -----------------------------------------------------
 
 project = "psycopg"
@@ -55,12 +54,10 @@ templates_path = ["_templates"]
 # This pattern also affects html_static_path and html_extra_path.
 exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", ".venv"]
 
-
 # -- Options for HTML output -------------------------------------------------
 
 # The announcement may be in the website but not shipped with the docs
-ann_file = docs_dir / "../../templates/docs3-announcement.html"
-if ann_file.exists():
+if (ann_file := (docs_dir / "../../templates/docs3-announcement.html")).exists():
     with ann_file.open() as f:
         announcement = f.read()
 else:
@@ -84,9 +81,7 @@ html_theme_options = {
     "sidebar_hide_name": False,
     "light_logo": "psycopg.svg",
     "dark_logo": "psycopg.svg",
-    "light_css_variables": {
-        "admonition-font-size": "1rem",
-    },
+    "light_css_variables": {"admonition-font-size": "1rem"},
 }
 
 # Add any paths that contain custom static files (such as style sheets) here,
diff --git a/docs/lib/libpq_docs.py b/docs/lib/libpq_docs.py
index a5488d1515ba0ebbd055c4309b623dc59ea60005..8bd098e82a2287a39a70d12d197f8d8e7bcc359c 100644
@@ -61,13 +61,11 @@ class LibpqParser(HTMLParser):
         self.add_function(data)
 
     def handle_sect1(self, tag, attrs):
-        attrs = dict(attrs)
-        if "id" in attrs:
+        if "id" in (attrs := dict(attrs)):
             self.section_id = attrs["id"]
 
     def handle_varlistentry(self, tag, attrs):
-        attrs = dict(attrs)
-        if "id" in attrs:
+        if "id" in (attrs := dict(attrs)):
             self.varlist_id = attrs["id"]
 
     def add_function(self, func_name):
@@ -89,10 +87,8 @@ class LibpqReader:
     # must be set before using the rest of the class.
     app = None
 
-    _url_pattern = (
-        "https://raw.githubusercontent.com/postgres/postgres/{branch}"
-        "/doc/src/sgml/libpq.sgml"
-    )
+    _url_pattern = "https://raw.githubusercontent.com/postgres/postgres/"
+    _url_pattern += "{branch}/doc/src/sgml/libpq.sgml"
 
     data = None
 
@@ -113,8 +109,7 @@ class LibpqReader:
             parser.feed(f.read())
 
     def download(self):
-        filename = os.environ.get("LIBPQ_DOCS_FILE")
-        if filename:
+        if filename := os.environ.get("LIBPQ_DOCS_FILE"):
             logger.info("reading postgres libpq docs from %s", filename)
             with open(filename, "rb") as f:
                 data = f.read()
@@ -156,7 +151,6 @@ def pq_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
     if "(" in text:
         func, noise = text.split("(", 1)
         noise = "(" + noise
-
     else:
         func = text
         noise = ""
diff --git a/docs/lib/pg3_docs.py b/docs/lib/pg3_docs.py
index 4388cc9d4cd2087f653521f488cbfff8c75f555a..8bd771f2d323ec4a000ef316678eb1736a739feb 100644
@@ -19,8 +19,7 @@ def process_docstring(app, what, name, obj, options, lines):
 
 
 def before_process_signature(app, obj, bound_method):
-    ann = getattr(obj, "__annotations__", {})
-    if "return" in ann:
+    if "return" in (ann := getattr(obj, "__annotations__", {})):
         # Drop "return: None" from the function signatures
         if ann["return"] is None:
             del ann["return"]
@@ -68,14 +67,12 @@ def recover_defined_module(m, skip_modules=()):
     for fn in walk_modules(mdir):
         assert fn.startswith(mdir)
         modname = os.path.splitext(fn[len(mdir) + 1 :])[0].replace("/", ".")
-        modname = f"{m.__name__}.{modname}"
-        if modname in skip_modules:
+        if (modname := f"{m.__name__}.{modname}") in skip_modules:
             continue
         with open(fn) as f:
             classnames = re.findall(r"^class\s+([^(:]+)", f.read(), re.M)
             for cls in classnames:
-                cls = deep_import(f"{modname}.{cls}")
-                if cls.__module__ != modname:
+                if (cls := deep_import(f"{modname}.{cls}")).__module__ != modname:
                     recovered_classes[cls] = modname
 
 
diff --git a/docs/lib/ticket_role.py b/docs/lib/ticket_role.py
index f8f935bf52fb412d21620edb77403257a145a02d..107e67d6df62e587075db4fff5660b27f79fd5ae 100644
@@ -16,8 +16,7 @@ from docutils.parsers.rst import roles
 
 
 def ticket_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
-    cfg = inliner.document.settings.env.app.config
-    if cfg.ticket_url is None:
+    if (cfg := inliner.document.settings.env.app.config).ticket_url is None:
         msg = inliner.reporter.warning(
             "ticket not configured: please configure ticket_url in conf.py"
         )
diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py
index cd1ad261ea67a55783fa4ef01adefe5f71b3d958..581d9b23be230c15b5abb394e8301938a009560a 100644
@@ -32,8 +32,7 @@ from ._connection_info import ConnectionInfo
 from .connection_async import AsyncConnection
 
 # Set the logger to a quiet default, can be enabled if needed
-logger = logging.getLogger("psycopg")
-if logger.level == logging.NOTSET:
+if (logger := logging.getLogger("psycopg")).level == logging.NOTSET:
     logger.setLevel(logging.WARNING)
 
 # DBAPI compliance
diff --git a/psycopg/psycopg/_adapters_map.py b/psycopg/psycopg/_adapters_map.py
index d8fedfa12199c08e87d58e73fdbedc30da627397..f3142c02abee7a484021ade8248360bd42d336de 100644
@@ -69,9 +69,7 @@ class AdaptersMap:
     _optimised: dict[type, type] = {}
 
     def __init__(
-        self,
-        template: AdaptersMap | None = None,
-        types: TypesRegistry | None = None,
+        self, template: AdaptersMap | None = None, types: TypesRegistry | None = None
     ):
         if template:
             self._dumpers = template._dumpers.copy()
@@ -179,8 +177,7 @@ class AdaptersMap:
         if _psycopg:
             loader = self._get_optimised(loader)
 
-        fmt = loader.format
-        if not self._own_loaders[fmt]:
+        if not self._own_loaders[(fmt := loader.format)]:
             self._loaders[fmt] = self._loaders[fmt].copy()
             self._own_loaders[fmt] = True
 
@@ -282,8 +279,7 @@ class AdaptersMap:
         from psycopg import types
 
         if cls.__module__.startswith(types.__name__):
-            new = cast("type[RV]", getattr(_psycopg, cls.__name__, None))
-            if new:
+            if new := cast("type[RV]", getattr(_psycopg, cls.__name__, None)):
                 self._optimised[cls] = new
                 return new
 
diff --git a/psycopg/psycopg/_column.py b/psycopg/psycopg/_column.py
index 372775cf8c297528dacde05079818a93775da8e5..8a7c806ca5eb14f8092a871048ddd5afc8e6704d 100644
@@ -20,8 +20,7 @@ class Column(Sequence[Any]):
         res = cursor.pgresult
         assert res
 
-        fname = res.fname(index)
-        if fname:
+        if fname := res.fname(index):
             self._name = fname.decode(cursor._encoding)
         else:
             # COPY_OUT results have columns but no name
diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py
index 65440ce2fc6f6041bd59f9462d6015a53d3e8f50..7ac6af1a5ccb233ff7393a7225947c744e033905 100644
@@ -247,8 +247,7 @@ class BaseConnection(Generic[Row]):
 
     def _check_intrans_gen(self, attribute: str) -> PQGen[None]:
         # Raise an exception if we are in a transaction
-        status = self.pgconn.transaction_status
-        if status == IDLE and self._pipeline:
+        if (status := self.pgconn.transaction_status) == IDLE and self._pipeline:
             yield from self._pipeline._sync_gen()
             status = self.pgconn.transaction_status
         if status != IDLE:
@@ -503,10 +502,7 @@ class BaseConnection(Generic[Row]):
         self._check_connection_ok()
 
         if self._pipeline:
-            cmd = partial(
-                self.pgconn.send_close_prepared,
-                name,
-            )
+            cmd = partial(self.pgconn.send_close_prepared, name)
             self._pipeline.command_queue.append(cmd)
             self._pipeline.result_queue.append(None)
             return
diff --git a/psycopg/psycopg/_conninfo_attempts.py b/psycopg/psycopg/_conninfo_attempts.py
index 7bc96dd9976b0e2e659d2cde11c9242bf1b3822c..f853e0537c502511eb6baa74901090765c1c0466 100644
@@ -85,8 +85,7 @@ def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
         # Local path, or no host to resolve
         return [params]
 
-    hostaddr = get_param(params, "hostaddr")
-    if hostaddr:
+    if get_param(params, "hostaddr"):
         # Already resolved
         return [params]
 
@@ -94,8 +93,7 @@ def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
         # If the host is already an ip address don't try to resolve it
         return [{**params, "hostaddr": host}]
 
-    port = get_param(params, "port")
-    if not port:
+    if not (port := get_param(params, "port")):
         port_def = get_param_def("port")
         port = port_def and port_def.compiled or "5432"
 
diff --git a/psycopg/psycopg/_conninfo_attempts_async.py b/psycopg/psycopg/_conninfo_attempts_async.py
index e50e4f95efe534f4a6cb1b42d8ff4c9e47267e21..08f363d1ac7235ee5dcd5ca6658979d708135a3e 100644
@@ -83,8 +83,7 @@ async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
         # Local path, or no host to resolve
         return [params]
 
-    hostaddr = get_param(params, "hostaddr")
-    if hostaddr:
+    if get_param(params, "hostaddr"):
         # Already resolved
         return [params]
 
@@ -92,8 +91,7 @@ async def _resolve_hostnames(params: ConnDict) -> list[ConnDict]:
         # If the host is already an ip address don't try to resolve it
         return [{**params, "hostaddr": host}]
 
-    port = get_param(params, "port")
-    if not port:
+    if not (port := get_param(params, "port")):
         port_def = get_param_def("port")
         port = port_def and port_def.compiled or "5432"
 
diff --git a/psycopg/psycopg/_conninfo_utils.py b/psycopg/psycopg/_conninfo_utils.py
index 844e71abd01c74fa75bc43441e7f209b2a0b4aa2..e959bbf155b32a3fe7aefbd01e9bba872146f0ca 100644
@@ -75,12 +75,10 @@ def get_param(params: ConnMapping, name: str) -> str | None:
 
     # TODO: check if in service
 
-    paramdef = get_param_def(name)
-    if not paramdef:
+    if not (paramdef := get_param_def(name)):
         return None
 
-    env = os.environ.get(paramdef.envvar)
-    if env is not None:
+    if (env := os.environ.get(paramdef.envvar)) is not None:
         return env
 
     return None
diff --git a/psycopg/psycopg/_copy.py b/psycopg/psycopg/_copy.py
index 32f8fc8573e149f0e7888b4851b3847c7854ee9c..b218ecd21574b2855606e58448d3027a140918c8 100644
@@ -81,8 +81,7 @@ class Copy(BaseCopy["Connection[Any]"]):
     def __iter__(self) -> Iterator[Buffer]:
         """Implement block-by-block iteration on :sql:`COPY TO`."""
         while True:
-            data = self.read()
-            if not data:
+            if not (data := self.read()):
                 break
             yield data
 
@@ -102,8 +101,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         bytes, unless data types are specified using `set_types()`.
         """
         while True:
-            record = self.read_row()
-            if record is None:
+            if (record := self.read_row()) is None:
                 break
             yield record
 
@@ -125,14 +123,12 @@ class Copy(BaseCopy["Connection[Any]"]):
         If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
         text mode it can be either `!bytes` or `!str`.
         """
-        data = self.formatter.write(buffer)
-        if data:
+        if data := self.formatter.write(buffer):
             self._write(data)
 
     def write_row(self, row: Sequence[Any]) -> None:
         """Write a record to a table after a :sql:`COPY FROM` operation."""
-        data = self.formatter.write_row(row)
-        if data:
+        if data := self.formatter.write_row(row):
             self._write(data)
 
     def finish(self, exc: BaseException | None) -> None:
@@ -143,8 +139,7 @@ class Copy(BaseCopy["Connection[Any]"]):
         using the `Copy` object outside a block.
         """
         if self._direction == COPY_IN:
-            data = self.formatter.end()
-            if data:
+            if data := self.formatter.end():
                 self._write(data)
             self.writer.finish(exc)
             self._finished = True
@@ -257,8 +252,7 @@ class QueuedLibpqWriter(LibpqWriter):
         """
         try:
             while True:
-                data = self._queue.get()
-                if not data:
+                if not (data := self._queue.get()):
                     break
                 self.connection.wait(copy_to(self._pgconn, data, flush=PREFER_FLUSH))
         except BaseException as ex:
diff --git a/psycopg/psycopg/_copy_async.py b/psycopg/psycopg/_copy_async.py
index 22ef3b197adab598b16c35957b93b2caab34d38a..2d36353a44eaa90fa04a1e242c3c20138cbb938c 100644
@@ -78,8 +78,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
     async def __aiter__(self) -> AsyncIterator[Buffer]:
         """Implement block-by-block iteration on :sql:`COPY TO`."""
         while True:
-            data = await self.read()
-            if not data:
+            if not (data := (await self.read())):
                 break
             yield data
 
@@ -99,8 +98,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         bytes, unless data types are specified using `set_types()`.
         """
         while True:
-            record = await self.read_row()
-            if record is None:
+            if (record := (await self.read_row())) is None:
                 break
             yield record
 
@@ -122,14 +120,12 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         If the :sql:`COPY` is in binary format `!buffer` must be `!bytes`. In
         text mode it can be either `!bytes` or `!str`.
         """
-        data = self.formatter.write(buffer)
-        if data:
+        if data := self.formatter.write(buffer):
             await self._write(data)
 
     async def write_row(self, row: Sequence[Any]) -> None:
         """Write a record to a table after a :sql:`COPY FROM` operation."""
-        data = self.formatter.write_row(row)
-        if data:
+        if data := self.formatter.write_row(row):
             await self._write(data)
 
     async def finish(self, exc: BaseException | None) -> None:
@@ -140,8 +136,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
         using the `Copy` object outside a block.
         """
         if self._direction == COPY_IN:
-            data = self.formatter.end()
-            if data:
+            if data := self.formatter.end():
                 await self._write(data)
             await self.writer.finish(exc)
             self._finished = True
@@ -256,8 +251,7 @@ class AsyncQueuedLibpqWriter(AsyncLibpqWriter):
         """
         try:
             while True:
-                data = await self._queue.get()
-                if not data:
+                if not (data := (await self._queue.get())):
                     break
                 await self.connection.wait(
                     copy_to(self._pgconn, data, flush=PREFER_FLUSH)
diff --git a/psycopg/psycopg/_copy_base.py b/psycopg/psycopg/_copy_base.py
index 6acc77719a77809ebff829e499fdde40f85baddc..17d740c1cd247384f6b9ebf92548210065d01b22 100644
@@ -73,17 +73,13 @@ class BaseCopy(Generic[ConnectionType]):
     formatter: Formatter
 
     def __init__(
-        self,
-        cursor: BaseCursor[ConnectionType, Any],
-        *,
-        binary: bool | None = None,
+        self, cursor: BaseCursor[ConnectionType, Any], *, binary: bool | None = None
     ):
         self.cursor = cursor
         self.connection = cursor.connection
         self._pgconn = self.connection.pgconn
 
-        result = cursor.pgresult
-        if result:
+        if result := cursor.pgresult:
             self._direction = result.status
             if self._direction != COPY_IN and self._direction != COPY_OUT:
                 raise e.ProgrammingError(
@@ -162,12 +158,10 @@ class BaseCopy(Generic[ConnectionType]):
         return memoryview(b"")
 
     def _read_row_gen(self) -> PQGen[tuple[Any, ...] | None]:
-        data = yield from self._read_gen()
-        if not data:
+        if not (data := (yield from self._read_gen())):
             return None
 
-        row = self.formatter.parse_row(data)
-        if row is None:
+        if (row := self.formatter.parse_row(data)) is None:
             # Get the final result to finish the copy operation
             yield from self._read_gen()
             self._finished = True
diff --git a/psycopg/psycopg/_cursor_base.py b/psycopg/psycopg/_cursor_base.py
index fe71a1451b140f3c39d7d4248a913e996139ed6f..3a1d6d45b361dc20c27facd47d138b3ee816ab9c 100644
@@ -395,8 +395,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         query = self._convert_query(statement)
 
         self._execute_send(query, binary=False)
-        results = yield from execute(self._pgconn)
-        if len(results) != 1:
+        if len(results := (yield from execute(self._pgconn))) != 1:
             raise e.ProgrammingError("COPY cannot be mixed with other operations")
 
         self._check_copy_result(results[0])
@@ -473,8 +472,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         """
         Raise an appropriate error message for an unexpected database result
         """
-        status = result.status
-        if status == FATAL_ERROR:
+        if (status := result.status) == FATAL_ERROR:
             raise e.error_from_result(result, encoding=self._encoding)
         elif status == PIPELINE_ABORTED:
             raise e.PipelineAborted("pipeline aborted")
@@ -518,19 +516,17 @@ class BaseCursor(Generic[ConnectionType, Row]):
             # Received from execute()
             self._results[:] = results
             self._select_current_result(0)
-
+        # Received from executemany()
+        elif self._execmany_returning:
+            first_batch = not self._results
+            self._results.extend(results)
+            if first_batch:
+                self._select_current_result(0)
         else:
-            # Received from executemany()
-            if self._execmany_returning:
-                first_batch = not self._results
-                self._results.extend(results)
-                if first_batch:
-                    self._select_current_result(0)
-            else:
-                # In non-returning case, set rowcount to the cumulated number of
-                # rows of executed queries.
-                for res in results:
-                    self._rowcount += res.command_tuples or 0
+            # In non-returning case, set rowcount to the cumulated number of
+            # rows of executed queries.
+            for res in results:
+                self._rowcount += res.command_tuples or 0
 
     def _send_prepare(self, name: bytes, query: PostgresQuery) -> None:
         if self._conn._pipeline:
@@ -572,12 +568,11 @@ class BaseCursor(Generic[ConnectionType, Row]):
     def _check_result_for_fetch(self) -> None:
         if self.closed:
             raise e.InterfaceError("the cursor is closed")
-        res = self.pgresult
-        if not res:
+
+        if not (res := self.pgresult):
             raise e.ProgrammingError("no result available")
 
-        status = res.status
-        if status == TUPLES_OK:
+        if (status := res.status) == TUPLES_OK:
             return
         elif status == FATAL_ERROR:
             raise e.error_from_result(res, encoding=self._encoding)
diff --git a/psycopg/psycopg/_dns.py b/psycopg/psycopg/_dns.py
index b642b2bca343200b767dac6e5030fbdb9a55da07..4d50942802a5ddd979a8b8a7b8ce4789dd27e7d8 100644
@@ -62,11 +62,9 @@ async def resolve_hostaddr_async(params: dict[str, Any]) -> dict[str, Any]:
             ports.append(str(attempt["port"]))
 
     out = params.copy()
-    shosts = ",".join(hosts)
-    if shosts:
+    if shosts := ",".join(hosts):
         out["host"] = shosts
-    shostaddrs = ",".join(hostaddrs)
-    if shostaddrs:
+    if shostaddrs := ",".join(hostaddrs):
         out["hostaddr"] = shostaddrs
     sports = ",".join(ports)
     if ports:
@@ -103,8 +101,7 @@ class Rfc2782Resolver:
 
     def resolve(self, params: dict[str, Any]) -> dict[str, Any]:
         """Update the parameters host and port after SRV lookup."""
-        attempts = self._get_attempts(params)
-        if not attempts:
+        if not (attempts := self._get_attempts(params)):
             return params
 
         hps = []
@@ -118,8 +115,7 @@ class Rfc2782Resolver:
 
     async def resolve_async(self, params: dict[str, Any]) -> dict[str, Any]:
         """Update the parameters host and port after SRV lookup."""
-        attempts = self._get_attempts(params)
-        if not attempts:
+        if not (attempts := self._get_attempts(params)):
             return params
 
         hps = []
@@ -144,9 +140,7 @@ class Rfc2782Resolver:
         host_arg: str = params.get("host", os.environ.get("PGHOST", ""))
         hosts_in = host_arg.split(",")
         port_arg: str = str(params.get("port", os.environ.get("PGPORT", "")))
-        ports_in = port_arg.split(",")
-
-        if len(ports_in) == 1:
+        if len((ports_in := port_arg.split(","))) == 1:
             # If only one port is specified, it applies to all the hosts.
             ports_in *= len(hosts_in)
         if len(ports_in) != len(hosts_in):
@@ -159,8 +153,7 @@ class Rfc2782Resolver:
         out = []
         srv_found = False
         for host, port in zip(hosts_in, ports_in):
-            m = self.re_srv_rr.match(host)
-            if m or port.lower() == "srv":
+            if (m := self.re_srv_rr.match(host)) or port.lower() == "srv":
                 srv_found = True
                 target = m.group("target") if m else None
                 hp = HostPort(host=host, port=port, totry=True, target=target)
diff --git a/psycopg/psycopg/_encodings.py b/psycopg/psycopg/_encodings.py
index d1ef6dd2d880c8401d80eee4534170a334057331..6bb859d8b173f65f817a6889d0630320447db54d 100644
@@ -98,8 +98,7 @@ def conninfo_encoding(conninfo: str) -> str:
     from .conninfo import conninfo_to_dict
 
     params = conninfo_to_dict(conninfo)
-    pgenc = params.get("client_encoding")
-    if pgenc:
+    if pgenc := params.get("client_encoding"):
         try:
             return pg2pyenc(str(pgenc).encode())
         except NotSupportedError:
diff --git a/psycopg/psycopg/_pipeline.py b/psycopg/psycopg/_pipeline.py
index f4752829a006f04516d0bbe378af84a5401b589a..8ae21b66fab69d0eca8ef22fcbe23c030bedf389 100644
@@ -142,8 +142,7 @@ class BasePipeline:
 
         exception = None
         while self.result_queue:
-            results = yield from fetch_many(self.pgconn)
-            if not results:
+            if not (results := (yield from fetch_many(self.pgconn))):
                 # No more results to fetch, but there may still be pending
                 # commands.
                 break
diff --git a/psycopg/psycopg/_preparing.py b/psycopg/psycopg/_preparing.py
index 10a4874940eb6eff887740fd42a1dfad6b4a2d42..7a1b0f5d2efb8a90134a9efe1b8e68f5eaacd0d4 100644
@@ -64,9 +64,7 @@ class PrepareManager:
             # The user doesn't want this query to be prepared
             return Prepare.NO, b""
 
-        key = self.key(query)
-        name = self._names.get(key)
-        if name:
+        if name := self._names.get(key := self.key(query)):
             # The query was already prepared in this session
             return Prepare.YES, name
 
@@ -101,8 +99,7 @@ class PrepareManager:
             # We cannot prepare a multiple statement
             return False
 
-        status = results[0].status
-        if COMMAND_OK != status != TUPLES_OK:
+        if COMMAND_OK != results[0].status != TUPLES_OK:
             # We don't prepare failed queries or other weird results
             return False
 
@@ -133,8 +130,7 @@ class PrepareManager:
         if self.prepare_threshold is None:
             return None
 
-        key = self.key(query)
-        if key in self._counts:
+        if (key := self.key(query)) in self._counts:
             if prep is Prepare.SHOULD:
                 del self._counts[key]
                 self._names[key] = name
@@ -155,11 +151,7 @@ class PrepareManager:
             return key
 
     def validate(
-        self,
-        key: Key,
-        prep: Prepare,
-        name: bytes,
-        results: Sequence[PGresult],
+        self, key: Key, prep: Prepare, name: bytes, results: Sequence[PGresult]
     ) -> None:
         """Validate cached entry with 'key' by checking query 'results'.
 
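
Worth noting in the _preparing.py hunk above: two assignment expressions stacked in one condition. The inner walrus binds key, the outer binds name, and both names remain usable after the if. A runnable toy version with hypothetical names (make_key standing in for the real key derivation):

    cache = {"select 1": b"_pg3_0"}

    def make_key(query: str) -> str:
        return query.strip().lower()   # stand-in, not PrepareManager.key()

    query = "  SELECT 1  "
    if name := cache.get(key := make_key(query)):
        print(f"prepared as {name!r} under key {key!r}")
    else:
        print(f"not prepared; key {key!r} still in scope")   # key outlives the if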
diff --git a/psycopg/psycopg/_py_transformer.py b/psycopg/psycopg/_py_transformer.py
index 0af6884cbea5872bb0ba7fa5e87458a85dec94f5..820c567d677c73aff71251a4276f999bb922fca6 100644
@@ -179,8 +179,7 @@ class Transformer(AdaptContext):
         # right size.
         if self._row_dumpers:
             for i in range(nparams):
-                param = params[i]
-                if param is not None:
+                if (param := params[i]) is not None:
                     out[i] = self._row_dumpers[i].dump(param)
             return out
 
@@ -188,8 +187,7 @@ class Transformer(AdaptContext):
         pqformats = [TEXT] * nparams
 
         for i in range(nparams):
-            param = params[i]
-            if param is None:
+            if (param := params[i]) is None:
                 continue
             dumper = self.get_dumper(param, formats[i])
             out[i] = dumper.dump(param)
@@ -212,8 +210,7 @@ class Transformer(AdaptContext):
             try:
                 type_sql = self._oid_types[oid]
             except KeyError:
-                ti = self.adapters.types.get(oid)
-                if ti:
+                if ti := self.adapters.types.get(oid):
                     if oid < 8192:
                         # builtin: prefer "timestamptz" to "timestamp with time zone"
                         type_sql = ti.name.encode(self.encoding)
@@ -254,8 +251,7 @@ class Transformer(AdaptContext):
                 cache[key] = dumper = dcls(key, self)
 
         # Check if the dumper requires an upgrade to handle this specific value
-        key1 = dumper.get_key(obj, format)
-        if key1 is key:
+        if (key1 := dumper.get_key(obj, format)) is key:
             return dumper
 
         # If it does, ask the dumper to create its own upgraded version
@@ -298,8 +294,7 @@ class Transformer(AdaptContext):
         return dumper
 
     def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> list[Row]:
-        res = self._pgresult
-        if not res:
+        if not (res := self._pgresult):
             raise e.InterfaceError("result not set")
 
         if not (0 <= row0 <= self._ntuples and 0 <= row1 <= self._ntuples):
@@ -311,16 +306,14 @@ class Transformer(AdaptContext):
         for row in range(row0, row1):
             record: list[Any] = [None] * self._nfields
             for col in range(self._nfields):
-                val = res.get_value(row, col)
-                if val is not None:
+                if (val := res.get_value(row, col)) is not None:
                     record[col] = self._row_loaders[col](val)
             records.append(make_row(record))
 
         return records
 
     def load_row(self, row: int, make_row: RowMaker[Row]) -> Row | None:
-        res = self._pgresult
-        if not res:
+        if not (res := self._pgresult):
             return None
 
         if not 0 <= row < self._ntuples:
@@ -328,8 +321,7 @@ class Transformer(AdaptContext):
 
         record: list[Any] = [None] * self._nfields
         for col in range(self._nfields):
-            val = res.get_value(row, col)
-            if val is not None:
+            if (val := res.get_value(row, col)) is not None:
                 record[col] = self._row_loaders[col](val)
 
         return make_row(record)
@@ -352,10 +344,8 @@ class Transformer(AdaptContext):
         except KeyError:
             pass
 
-        loader_cls = self._adapters.get_loader(oid, format)
-        if not loader_cls:
-            loader_cls = self._adapters.get_loader(INVALID_OID, format)
-            if not loader_cls:
+        if not (loader_cls := self._adapters.get_loader(oid, format)):
+            if not (loader_cls := self._adapters.get_loader(INVALID_OID, format)):
                 raise e.InterfaceError("unknown oid loader not found")
         loader = self._loaders[format][oid] = loader_cls(oid, self)
         return loader
diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py
index 22e5b4bf66823b0b3ead20351a8c07285ab6bcf3..98fa0a7ff9c6309134e3a012e51d69acbc047373 100644
@@ -372,8 +372,7 @@ def _split_query(
             rv.append(QueryPart(pre, 0, PyFormat.AUTO))
             break
 
-        ph = m.group(0)
-        if ph == b"%%":
+        if (ph := m.group(0)) == b"%%":
             # unescape '%%' to '%' if necessary, then merge the parts
             if collapse_double_percent:
                 ph = b"%"
diff --git a/psycopg/psycopg/_tpc.py b/psycopg/psycopg/_tpc.py
index e3719010cd318ff431397e7d55b203b053f35b21..8c730d64412de870dcd0983edd70a1036e3f430b 100644
@@ -52,8 +52,7 @@ class Xid:
 
     @classmethod
     def _parse_string(cls, s: str) -> Xid:
-        m = _re_xid.match(s)
-        if not m:
+        if not (m := _re_xid.match(s)):
             raise ValueError("bad Xid format")
 
         format_id = int(m.group(1))
diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py
index c9126364ebbe2265bc41e4919ad13b6e8b686ab9..47d6835d91d9953dcc8fd144cbc85dbe44c1b3d4 100644
@@ -171,8 +171,7 @@ ORDER BY t.oid
     @classmethod
     def _has_to_regtype_function(cls, conn: BaseConnection[Any]) -> bool:
         # to_regtype() introduced in PostgreSQL 9.4 and CockroachDB 22.2
-        info = conn.info
-        if info.vendor == "PostgreSQL":
+        if (info := conn.info).vendor == "PostgreSQL":
             return info.server_version >= 90400
         elif info.vendor == "CockroachDB":
             return info.server_version >= 220200
@@ -195,10 +194,8 @@ ORDER BY t.oid
         pass
 
     def get_type_display(self, oid: int | None = None, fmod: int | None = None) -> str:
-        parts = []
-        parts.append(self.name)
-        mod = self.typemod.get_modifier(fmod) if fmod is not None else ()
-        if mod:
+        parts = [self.name]
+        if mod := (self.typemod.get_modifier(fmod) if fmod is not None else ()):
             parts.append(f"({','.join(map(str, mod))})")
 
         if oid == self.array_oid:
diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py
index 2918f555b52bfa483243400433765d7e1a206dd3..58db3640066b2f8bcb90319e92f8451ce2b060cb 100644
@@ -56,8 +56,7 @@ class Dumper(abc.Dumper, ABC):
         for most types and you won't likely have to implement this method in a
         subclass.
         """
-        value = self.dump(obj)
-        if value is None:
+        if (value := self.dump(obj)) is None:
             return b"NULL"
 
         if self.connection:
diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py
index 70d77467d313d084b4558d85ed319ed9564065de..bfff3fc1e0b15c0ec0e45e8a404a9ebc0869695c 100644
@@ -382,8 +382,7 @@ class Connection(BaseConnection[Row]):
         with self.lock:
             self._check_connection_ok()
 
-            pipeline = self._pipeline
-            if pipeline is None:
+            if (pipeline := self._pipeline) is None:
                 # WARNING: reference loop, broken ahead.
                 pipeline = self._pipeline = Pipeline(self)
 
diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py
index 3c6615a24529ae31fc0cfe36812e32e4346e5b17..9f4ecb4640b38a1059ffeae2800cb7910834226a 100644
@@ -407,8 +407,7 @@ class AsyncConnection(BaseConnection[Row]):
         async with self.lock:
             self._check_connection_ok()
 
-            pipeline = self._pipeline
-            if pipeline is None:
+            if (pipeline := self._pipeline) is None:
                 # WARNING: reference loop, broken ahead.
                 pipeline = self._pipeline = AsyncPipeline(self)
 
diff --git a/psycopg/psycopg/conninfo.py b/psycopg/psycopg/conninfo.py
index e8c33876b4d52f3785d45353dbe8421c67e4a3ea..f72ec34e01129cb43db78a3f3684ec917a3ddfd9 100644
@@ -114,8 +114,7 @@ def _param_escape(s: str) -> str:
     if not s:
         return "''"
 
-    s = re_escape.sub(r"\\\1", s)
-    if re_space.search(s):
+    if re_space.search(s := re_escape.sub(r"\\\1", s)):
         s = "'" + s + "'"
 
     return s
diff --git a/psycopg/psycopg/crdb/connection.py b/psycopg/psycopg/crdb/connection.py
index 60db9a876ec6ab2e0ab518b4c17f4bc57409719c..411434a75d9be39e6d638db6fc682507f455f3df 100644
@@ -86,20 +86,17 @@ class CrdbConnectionInfo(ConnectionInfo):
 
         Return a number in the PostgreSQL format (e.g. 21.2.10 -> 210210).
         """
-        sver = self.parameter_status("crdb_version")
-        if not sver:
+        if not (sver := self.parameter_status("crdb_version")):
             raise e.InternalError("'crdb_version' parameter status not set")
 
-        ver = self.parse_crdb_version(sver)
-        if ver is None:
+        if (ver := self.parse_crdb_version(sver)) is None:
             raise e.InterfaceError(f"couldn't parse CockroachDB version from: {sver!r}")
 
         return ver
 
     @classmethod
     def parse_crdb_version(self, sver: str) -> int | None:
-        m = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
-        if not m:
+        if not (m := re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)):
             return None
 
         return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py
index 52451dc5f207b72a8104c6ae3c21968117636d10..ccde998637a5e163fda90bb89a97ae2b4d9ac2bd 100644
@@ -108,8 +108,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                 # If there is already a pipeline, ride it, in order to avoid
                 # sending unnecessary Sync.
                 with self._conn.lock:
-                    p = self._conn._pipeline
-                    if p:
+                    if p := self._conn._pipeline:
                         self._conn.wait(
                             self._executemany_gen_pipeline(query, params_seq, returning)
                         )
@@ -153,8 +152,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                 first = True
                 while self._conn.wait(self._stream_fetchone_gen(first)):
                     for pos in range(size):
-                        rec = self._tx.load_row(pos, self._make_row)
-                        if rec is None:
+                        if (rec := self._tx.load_row(pos, self._make_row)) is None:
                             break
                         yield rec
                     first = False
@@ -188,8 +186,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         """
         self._fetch_pipeline()
         self._check_result_for_fetch()
-        record = self._tx.load_row(self._pos, self._make_row)
-        if record is not None:
+        if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
             self._pos += 1
         return record
 
@@ -234,8 +231,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
             return self._tx.load_row(pos, self._make_row)
 
         while True:
-            row = load(self._pos)
-            if row is None:
+            if (row := load(self._pos)) is None:
                 break
             self._pos += 1
             yield row
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py
index 0296c51791f8ea767cbe1414a67cce98f01a8b10..1edb726f1f4eb7fef8eb6437790b249ec98ea30a 100644
@@ -98,11 +98,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         return self
 
     async def executemany(
-        self,
-        query: Query,
-        params_seq: Iterable[Params],
-        *,
-        returning: bool = False,
+        self, query: Query, params_seq: Iterable[Params], *, returning: bool = False
     ) -> None:
         """
         Execute the same command with a sequence of input data.
@@ -112,8 +108,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                 # If there is already a pipeline, ride it, in order to avoid
                 # sending unnecessary Sync.
                 async with self._conn.lock:
-                    p = self._conn._pipeline
-                    if p:
+                    if p := self._conn._pipeline:
                         await self._conn.wait(
                             self._executemany_gen_pipeline(query, params_seq, returning)
                         )
@@ -157,8 +152,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                 first = True
                 while await self._conn.wait(self._stream_fetchone_gen(first)):
                     for pos in range(size):
-                        rec = self._tx.load_row(pos, self._make_row)
-                        if rec is None:
+                        if (rec := self._tx.load_row(pos, self._make_row)) is None:
                             break
                         yield rec
                     first = False
@@ -196,8 +190,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         """
         await self._fetch_pipeline()
         self._check_result_for_fetch()
-        record = self._tx.load_row(self._pos, self._make_row)
-        if record is not None:
+        if (record := self._tx.load_row(self._pos, self._make_row)) is not None:
             self._pos += 1
         return record
 
@@ -216,9 +209,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         if not size:
             size = self.arraysize
         records = self._tx.load_rows(
-            self._pos,
-            min(self._pos + size, self.pgresult.ntuples),
-            self._make_row,
+            self._pos, min(self._pos + size, self.pgresult.ntuples), self._make_row
         )
         self._pos += len(records)
         return records
@@ -244,8 +235,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
             return self._tx.load_row(pos, self._make_row)
 
         while True:
-            row = load(self._pos)
-            if row is None:
+            if (row := load(self._pos)) is None:
                 break
             self._pos += 1
             yield row
diff --git a/psycopg/psycopg/dbapi20.py b/psycopg/psycopg/dbapi20.py
index cdcf8655f796c3af83bc77235772588a2d8e8b45..a6f0d75c4521cd70386330b64bca88bafcd961b0 100644
@@ -71,8 +71,7 @@ class Binary:
         self.obj = obj
 
     def __repr__(self) -> str:
-        sobj = repr(self.obj)
-        if len(sobj) > 40:
+        if len((sobj := repr(self.obj))) > 40:
             sobj = f"{sobj[:35]} ... ({len(sobj)} byteschars)"
         return f"{self.__class__.__name__}({sobj})"
 
diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py
index f1827be2befb633e45fe4fb615b43f2a60a3c8a6..f558cc6e3891e12997cd8c0c81bd50814f6b0233 100644
@@ -70,8 +70,7 @@ def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]:
         if conn.status == BAD:
             encoding = conninfo_encoding(conninfo)
             raise e.OperationalError(
-                f"connection is bad: {conn.get_error_message(encoding)}",
-                pgconn=conn,
+                f"connection is bad: {conn.get_error_message(encoding)}", pgconn=conn
             )
 
         status = conn.connect_poll()
@@ -107,8 +106,8 @@ def _cancel(cancel_conn: PGcancelConn, *, timeout: float = 0.0) -> PQGenConn[Non
     while True:
         if deadline and monotonic() > deadline:
             raise e.CancellationTimeout("cancellation timeout expired")
-        status = cancel_conn.poll()
-        if status == POLL_OK:
+
+        if (status := cancel_conn.poll()) == POLL_OK:
             break
         elif status == POLL_READING:
             yield cancel_conn.socket, WAIT_R
@@ -150,13 +149,11 @@ def _send(pgconn: PGconn) -> PQGen[None]:
     to retrieve the results available.
     """
     while True:
-        f = pgconn.flush()
-        if f == 0:
+        if pgconn.flush() == 0:
             break
 
         while True:
-            ready = yield WAIT_RW
-            if ready:
+            if ready := (yield WAIT_RW):
                 break
 
         if ready & READY_R:
@@ -220,8 +217,7 @@ def _fetch(pgconn: PGconn) -> PQGen[PGresult | None]:
     """
     if pgconn.is_busy():
         while True:
-            ready = yield WAIT_R
-            if ready:
+            if (yield WAIT_R):
                 break
 
         while True:
@@ -229,8 +225,7 @@ def _fetch(pgconn: PGconn) -> PQGen[PGresult | None]:
             if not pgconn.is_busy():
                 break
             while True:
-                ready = yield WAIT_R
-                if ready:
+                if (yield WAIT_R):
                     break
 
     _consume_notifies(pgconn)
@@ -250,8 +245,7 @@ def _pipeline_communicate(
 
     while True:
         while True:
-            ready = yield WAIT_RW
-            if ready:
+            if ready := (yield WAIT_RW):
                 break
 
         if ready & READY_R:
@@ -260,28 +254,23 @@ def _pipeline_communicate(
 
             res: list[PGresult] = []
             while not pgconn.is_busy():
-                r = pgconn.get_result()
-                if r is None:
+                if (r := pgconn.get_result()) is None:
                     if not res:
                         break
                     results.append(res)
                     res = []
+                elif (status := r.status) == PIPELINE_SYNC:
+                    assert not res
+                    results.append([r])
+                elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
+                    # This shouldn't happen, but insisting hard enough, it will.
+                    # For instance, in test_executemany_badquery(), with the COPY
+                    # statement and the AsyncClientCursor, which disables
+                    # prepared statements).
+                    # Bail out from the resulting infinite loop.
+                    raise e.NotSupportedError("COPY cannot be used in pipeline mode")
                 else:
-                    status = r.status
-                    if status == PIPELINE_SYNC:
-                        assert not res
-                        results.append([r])
-                    elif status == COPY_IN or status == COPY_OUT or status == COPY_BOTH:
-                        # This shouldn't happen, but insisting hard enough, it will.
-                        # For instance, in test_executemany_badquery(), with the COPY
-                        # statement and the AsyncClientCursor, which disables
-                        # prepared statements).
-                        # Bail out from the resulting infinite loop.
-                        raise e.NotSupportedError(
-                            "COPY cannot be used in pipeline mode"
-                        )
-                    else:
-                        res.append(r)
+                    res.append(r)
 
         if ready & READY_W:
             pgconn.flush()
@@ -295,8 +284,7 @@ def _pipeline_communicate(
 def _consume_notifies(pgconn: PGconn) -> None:
     # Consume notifies
     while True:
-        n = pgconn.notifies()
-        if not n:
+        if not (n := pgconn.notifies()):
             break
         if pgconn.notify_handler:
             pgconn.notify_handler(n)
@@ -308,8 +296,7 @@ def notifies(pgconn: PGconn) -> PQGen[list[pq.PGnotify]]:
 
     ns = []
     while True:
-        n = pgconn.notifies()
-        if n:
+        if n := pgconn.notifies():
             ns.append(n)
             if pgconn.notify_handler:
                 pgconn.notify_handler(n)
@@ -327,8 +314,7 @@ def copy_from(pgconn: PGconn) -> PQGen[memoryview | PGresult]:
 
         # would block
         while True:
-            ready = yield WAIT_R
-            if ready:
+            if (yield WAIT_R):
                 break
         pgconn.consume_input()
 
@@ -337,12 +323,12 @@ def copy_from(pgconn: PGconn) -> PQGen[memoryview | PGresult]:
         return data
 
     # Retrieve the final result of copy
-    results = yield from _fetch_many(pgconn)
-    if len(results) > 1:
+
+    if len(results := (yield from _fetch_many(pgconn))) > 1:
         # TODO: too brutal? Copy worked.
         raise e.ProgrammingError("you cannot mix COPY with other operations")
-    result = results[0]
-    if result.status != COMMAND_OK:
+
+    if (result := results[0]).status != COMMAND_OK:
         raise e.error_from_result(result, encoding=pgconn._encoding)
 
     return result
@@ -357,8 +343,7 @@ def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]:
     # do it upstream the queue decoupling the writer task from the producer one.
     while pgconn.put_copy_data(buffer) == 0:
         while True:
-            ready = yield WAIT_W
-            if ready:
+            if (yield WAIT_W):
                 break
 
     # Flushing often has a good effect on macOS because memcpy operations
@@ -368,11 +353,10 @@ def copy_to(pgconn: PGconn, buffer: Buffer, flush: bool = True) -> PQGen[None]:
         # Repeat until the message is flushed to the server
         while True:
             while True:
-                ready = yield WAIT_W
-                if ready:
+                if (yield WAIT_W):
                     break
-            f = pgconn.flush()
-            if f == 0:
+
+            if pgconn.flush() == 0:
                 break
 
 
@@ -380,18 +364,16 @@ def copy_end(pgconn: PGconn, error: bytes | None) -> PQGen[PGresult]:
     # Retry enqueuing end copy message until successful
     while pgconn.put_copy_end(error) == 0:
         while True:
-            ready = yield WAIT_W
-            if ready:
+            if (yield WAIT_W):
                 break
 
     # Repeat until the message is flushed to the server
     while True:
         while True:
-            ready = yield WAIT_W
-            if ready:
+            if (yield WAIT_W):
                 break
-        f = pgconn.flush()
-        if f == 0:
+
+        if pgconn.flush() == 0:
             break
 
     # Retrieve the final result of copy
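
A detail visible in the generators.py hunks above: when the readiness flag received from yield is not used again, the patch drops the name entirely and tests the parenthesized yield expression itself (if (yield WAIT_R):), while a walrus binding is kept where the flag is reused (if ready := (yield WAIT_RW): ... ready & READY_R). A runnable toy version of the first form, with stand-in values:

    def wait_readable():
        # The flag sent in is only tested, never reused, so no name is bound.
        while True:
            if (yield "WAIT_R"):
                break

    g = wait_readable()
    next(g)             # advance to the first yield; we receive "WAIT_R"
    try:
        g.send(True)    # a truthy flag breaks the loop and ends the generator
    except StopIteration:
        pass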
diff --git a/psycopg/psycopg/pq/_debug.py b/psycopg/psycopg/pq/_debug.py
index d55be281d113b55dfa3f53fb3ffa622b8ca160a6..6fb389112be2b57d2ffc87c26e749c656368b152 100644
@@ -56,8 +56,7 @@ class PGconnDebug:
         return f"<{cls} {info} at 0x{id(self):x}>"
 
     def __getattr__(self, attr: str) -> Any:
-        value = getattr(self._pgconn, attr)
-        if callable(value):
+        if callable((value := getattr(self._pgconn, attr))):
             return debugging(value)
         else:
             logger.info("PGconn.%s -> %s", attr, value)
diff --git a/psycopg/psycopg/pq/_pq_ctypes.py b/psycopg/psycopg/pq/_pq_ctypes.py
index 9128f1532d6812a9d2470b7d0e0d547f8d61975d..99e49357c7c59d66e0958167754350c0fc07e492 100644
@@ -16,8 +16,7 @@ from typing import Any, NoReturn
 from .misc import find_libpq_full_path, version_pretty
 from ..errors import NotSupportedError
 
-libname = find_libpq_full_path()
-if not libname:
+if not (libname := find_libpq_full_path()):
     raise ImportError("libpq library not found")
 
 pq = ctypes.cdll.LoadLibrary(libname)
@@ -30,8 +29,7 @@ class FILE(Structure):
 FILE_ptr = POINTER(FILE)
 
 if sys.platform == "linux":
-    libcname = ctypes.util.find_library("c")
-    if not libcname:
+    if not (libcname := ctypes.util.find_library("c")):
         # Likely this is a system using musl libc, see the following bug:
         # https://github.com/python/cpython/issues/65821
         libcname = "libc.so"
@@ -41,7 +39,6 @@ if sys.platform == "linux":
     fdopen.argtypes = (c_int, c_char_p)
     fdopen.restype = FILE_ptr
 
-
 # Get the libpq version to define what functions are available.
 
 PQlibVersion = pq.PQlibVersion
@@ -680,23 +677,14 @@ PQfreemem.restype = None
 
 if libpq_version >= 100000:
     PQencryptPasswordConn = pq.PQencryptPasswordConn
-    PQencryptPasswordConn.argtypes = [
-        PGconn_ptr,
-        c_char_p,
-        c_char_p,
-        c_char_p,
-    ]
+    PQencryptPasswordConn.argtypes = [PGconn_ptr, c_char_p, c_char_p, c_char_p]
     PQencryptPasswordConn.restype = POINTER(c_char)
 else:
     PQencryptPasswordConn = not_supported_before("PQencryptPasswordConn", 100000)
 
 if libpq_version >= 170000:
     PQchangePassword = pq.PQchangePassword
-    PQchangePassword.argtypes = [
-        PGconn_ptr,
-        c_char_p,
-        c_char_p,
-    ]
+    PQchangePassword.argtypes = [PGconn_ptr, c_char_p, c_char_p]
     PQchangePassword.restype = PGresult_ptr
 else:
     PQchangePassword = not_supported_before("PQchangePassword", 170000)
diff --git a/psycopg/psycopg/pq/misc.py b/psycopg/psycopg/pq/misc.py
index 31be494a0fa0d982c877b18c99cf51d3637a3726..81d932ca4d1d36a3699d128d51c390c214ea2007 100644
@@ -52,11 +52,9 @@ class PGresAttDesc(NamedTuple):
 @cache
 def find_libpq_full_path() -> str | None:
     if sys.platform == "win32":
-        libname = ctypes.util.find_library("libpq.dll")
-        if libname is None:
+        if (libname := ctypes.util.find_library("libpq.dll")) is None:
             return None
         libname = str(Path(libname).resolve())
-
     elif sys.platform == "darwin":
         libname = ctypes.util.find_library("libpq.dylib")
         # (hopefully) temporary hack: libpq not in a standard place
@@ -67,8 +65,7 @@ def find_libpq_full_path() -> str | None:
                 import subprocess as sp
 
                 libdir = sp.check_output(["pg_config", "--libdir"]).strip().decode()
-                libname = os.path.join(libdir, "libpq.dylib")
-                if not os.path.exists(libname):
+                if not os.path.exists((libname := os.path.join(libdir, "libpq.dylib"))):
                     libname = None
             except Exception as ex:
                 logger.debug("couldn't use pg_config to find libpq: %s", ex)
@@ -129,16 +126,14 @@ PREFIXES = re.compile(
 
 def strip_severity(msg: str) -> str:
     """Strip severity and whitespaces from error message."""
-    m = PREFIXES.match(msg)
-    if m:
+    if m := PREFIXES.match(msg):
         msg = msg[m.span()[1] :]
 
     return msg.strip()
 
 
 def _clean_error_message(msg: bytes, encoding: str) -> str:
-    smsg = msg.decode(encoding, "replace")
-    if smsg:
+    if smsg := msg.decode(encoding, "replace"):
         return strip_severity(smsg)
     else:
         return "no error details available"
@@ -169,8 +164,7 @@ def connection_summary(pgconn: abc.PGconn) -> str:
     else:
         status = ConnStatus(pgconn.status).name
 
-    sparts = " ".join("%s=%s" % part for part in parts)
-    if sparts:
+    if sparts := " ".join(("%s=%s" % part for part in parts)):
         sparts = f" ({sparts})"
     return f"[{status}]{sparts}"
 
diff --git a/psycopg/psycopg/pq/pq_ctypes.py b/psycopg/psycopg/pq/pq_ctypes.py
index e35a0486fd7f360660b2efb0b579317ab8f36dd0..6229bb02d59b930b49301144064b1e0613dca24d 100644
@@ -109,9 +109,7 @@ class PGconn:
     def connect(cls, conninfo: bytes) -> PGconn:
         if not isinstance(conninfo, bytes):
             raise TypeError(f"bytes expected, got {type(conninfo)} instead")
-
-        pgconn_ptr = impl.PQconnectdb(conninfo)
-        if not pgconn_ptr:
+        if not (pgconn_ptr := impl.PQconnectdb(conninfo)):
             raise MemoryError("couldn't allocate PGconn")
         return cls(pgconn_ptr)
 
@@ -119,9 +117,7 @@ class PGconn:
     def connect_start(cls, conninfo: bytes) -> PGconn:
         if not isinstance(conninfo, bytes):
             raise TypeError(f"bytes expected, got {type(conninfo)} instead")
-
-        pgconn_ptr = impl.PQconnectStart(conninfo)
-        if not pgconn_ptr:
+        if not (pgconn_ptr := impl.PQconnectStart(conninfo)):
             raise MemoryError("couldn't allocate PGconn")
         return cls(pgconn_ptr)
 
@@ -151,8 +147,7 @@ class PGconn:
     @property
     def info(self) -> list[ConninfoOption]:
         self._ensure_pgconn()
-        opts = impl.PQconninfo(self._pgconn_ptr)
-        if not opts:
+        if not (opts := impl.PQconninfo(self._pgconn_ptr)):
             raise MemoryError("couldn't allocate connection info")
         try:
             return Conninfo._options_from_array(opts)
@@ -246,8 +241,7 @@ class PGconn:
 
     @property
     def socket(self) -> int:
-        rv = self._call_int(impl.PQsocket)
-        if rv == -1:
+        if (rv := self._call_int(impl.PQsocket)) == -1:
             raise e.OperationalError("the connection is lost")
         return rv
 
@@ -280,8 +274,7 @@ class PGconn:
         if not isinstance(command, bytes):
             raise TypeError(f"bytes expected, got {type(command)} instead")
         self._ensure_pgconn()
-        rv = impl.PQexec(self._pgconn_ptr, command)
-        if not rv:
+        if not (rv := impl.PQexec(self._pgconn_ptr, command)):
             raise e.OperationalError(
                 f"executing query failed: {self.get_error_message()}"
             )
@@ -308,8 +301,7 @@ class PGconn:
             command, param_values, param_types, param_formats, result_format
         )
         self._ensure_pgconn()
-        rv = impl.PQexecParams(*args)
-        if not rv:
+        if not (rv := impl.PQexecParams(*args)):
             raise e.OperationalError(
                 f"executing query failed: {self.get_error_message()}"
             )
@@ -333,10 +325,7 @@ class PGconn:
             )
 
     def send_prepare(
-        self,
-        name: bytes,
-        command: bytes,
-        param_types: Sequence[int] | None = None,
+        self, name: bytes, command: bytes, param_types: Sequence[int] | None = None
     ) -> None:
         atypes: Array[impl.Oid] | None
         if not param_types:
@@ -432,10 +421,7 @@ class PGconn:
         )
 
     def prepare(
-        self,
-        name: bytes,
-        command: bytes,
-        param_types: Sequence[int] | None = None,
+        self, name: bytes, command: bytes, param_types: Sequence[int] | None = None
     ) -> PGresult:
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
@@ -451,8 +437,7 @@ class PGconn:
             atypes = (impl.Oid * nparams)(*param_types)
 
         self._ensure_pgconn()
-        rv = impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)
-        if not rv:
+        if not (rv := impl.PQprepare(self._pgconn_ptr, name, command, nparams, atypes)):
             raise e.OperationalError(
                 f"preparing query failed: {self.get_error_message()}"
             )
@@ -514,8 +499,7 @@ class PGconn:
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
         self._ensure_pgconn()
-        rv = impl.PQdescribePrepared(self._pgconn_ptr, name)
-        if not rv:
+        if not (rv := impl.PQdescribePrepared(self._pgconn_ptr, name)):
             raise e.OperationalError(
                 f"describe prepared failed: {self.get_error_message()}"
             )
@@ -534,8 +518,7 @@ class PGconn:
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
         self._ensure_pgconn()
-        rv = impl.PQdescribePortal(self._pgconn_ptr, name)
-        if not rv:
+        if not (rv := impl.PQdescribePortal(self._pgconn_ptr, name)):
             raise e.OperationalError(
                 f"describe portal failed: {self.get_error_message()}"
             )
@@ -554,8 +537,7 @@ class PGconn:
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
         self._ensure_pgconn()
-        rv = impl.PQclosePrepared(self._pgconn_ptr, name)
-        if not rv:
+        if not (rv := impl.PQclosePrepared(self._pgconn_ptr, name)):
             raise e.OperationalError(
                 f"close prepared failed: {self.get_error_message()}"
             )
@@ -574,8 +556,7 @@ class PGconn:
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
         self._ensure_pgconn()
-        rv = impl.PQclosePortal(self._pgconn_ptr, name)
-        if not rv:
+        if not (rv := impl.PQclosePortal(self._pgconn_ptr, name)):
             raise e.OperationalError(f"close portal failed: {self.get_error_message()}")
         return PGresult(rv)
 
@@ -635,8 +616,7 @@ class PGconn:
 
         See :pq:`PQcancelCreate` for details.
         """
-        rv = impl.PQcancelCreate(self._pgconn_ptr)
-        if not rv:
+        if not (rv := impl.PQcancelCreate(self._pgconn_ptr)):
             raise e.OperationalError("couldn't create cancelConn object")
         return PGcancelConn(rv)
 
@@ -646,14 +626,12 @@ class PGconn:
 
         See :pq:`PQgetCancel` for details.
         """
-        rv = impl.PQgetCancel(self._pgconn_ptr)
-        if not rv:
+        if not (rv := impl.PQgetCancel(self._pgconn_ptr)):
             raise e.OperationalError("couldn't create cancel object")
         return PGcancel(rv)
 
     def notifies(self) -> PGnotify | None:
-        ptr = impl.PQnotifies(self._pgconn_ptr)
-        if ptr:
+        if ptr := impl.PQnotifies(self._pgconn_ptr):
             c = ptr.contents
             rv = PGnotify(c.relname, c.be_pid, c.extra)
             impl.PQfreemem(ptr)
@@ -664,16 +642,14 @@ class PGconn:
     def put_copy_data(self, buffer: abc.Buffer) -> int:
         if not isinstance(buffer, bytes):
             buffer = bytes(buffer)
-        rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
-        if rv < 0:
+        if (rv := impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))) < 0:
             raise e.OperationalError(
                 f"sending copy data failed: {self.get_error_message()}"
             )
         return rv
 
     def put_copy_end(self, error: bytes | None = None) -> int:
-        rv = impl.PQputCopyEnd(self._pgconn_ptr, error)
-        if rv < 0:
+        if (rv := impl.PQputCopyEnd(self._pgconn_ptr, error)) < 0:
             raise e.OperationalError(
                 f"sending copy end failed: {self.get_error_message()}"
             )
@@ -756,8 +732,7 @@ class PGconn:
             )
 
     def make_empty_result(self, exec_status: int) -> PGresult:
-        rv = impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)
-        if not rv:
+        if not (rv := impl.PQmakeEmptyPGresult(self._pgconn_ptr, exec_status)):
             raise MemoryError("couldn't allocate empty PGresult")
         return PGresult(rv)
 
@@ -791,8 +766,8 @@ class PGconn:
         :raises ~e.OperationalError: if the connection is not in pipeline mode
             or if sync failed.
         """
-        rv = impl.PQpipelineSync(self._pgconn_ptr)
-        if rv == 0:
+
+        if (rv := impl.PQpipelineSync(self._pgconn_ptr)) == 0:
             raise e.OperationalError("connection not in pipeline mode")
         if rv != 1:
             raise e.OperationalError("failed to sync pipeline")
@@ -928,11 +903,10 @@ class PGresult:
         if length:
             v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number)
             return string_at(v, length)
+        elif impl.PQgetisnull(self._pgresult_ptr, row_number, column_number):
+            return None
         else:
-            if impl.PQgetisnull(self._pgresult_ptr, row_number, column_number):
-                return None
-            else:
-                return b""
+            return b""
 
     @property
     def nparams(self) -> int:
@@ -959,8 +933,8 @@ class PGresult:
             impl.PGresAttDesc_struct(*desc) for desc in descriptions  # type: ignore
         ]
         array = (impl.PGresAttDesc_struct * len(structs))(*structs)  # type: ignore
-        rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array)
-        if rv == 0:
+
+        if impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array) == 0:
             raise e.OperationalError("PQsetResultAttrs failed")
 
 
@@ -1011,8 +985,7 @@ class PGcancelConn:
 
     @property
     def socket(self) -> int:
-        rv = impl.PQcancelSocket(self.pgcancelconn_ptr)
-        if rv == -1:
+        if (rv := impl.PQcancelSocket(self.pgcancelconn_ptr)) == -1:
             raise e.OperationalError("cancel connection not opened")
         return rv
 
@@ -1095,8 +1068,7 @@ class Conninfo:
 
     @classmethod
     def get_defaults(cls) -> list[ConninfoOption]:
-        opts = impl.PQconndefaults()
-        if not opts:
+        if not (opts := impl.PQconndefaults()):
             raise MemoryError("couldn't allocate connection defaults")
         try:
             return cls._options_from_array(opts)
@@ -1157,8 +1129,8 @@ class Escaping:
         # TODO: might be done without copy (however C does that)
         if not isinstance(data, bytes):
             data = bytes(data)
-        out = impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))
-        if not out:
+
+        if not (out := impl.PQescapeLiteral(self.conn._pgconn_ptr, data, len(data))):
             raise e.OperationalError(
                 f"escape_literal failed: {self.conn.get_error_message()} bytes"
             )
@@ -1174,8 +1146,8 @@ class Escaping:
 
         if not isinstance(data, bytes):
             data = bytes(data)
-        out = impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))
-        if not out:
+
+        if not (out := impl.PQescapeIdentifier(self.conn._pgconn_ptr, data, len(data))):
             raise e.OperationalError(
                 f"escape_identifier failed: {self.conn.get_error_message()} bytes"
             )
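
The hunks above parenthesize the assignment whenever the bound value feeds a
comparison, because ":=" binds more loosely than "==" or "<". A minimal
standalone sketch of the pitfall (names here are illustrative, not from the
library):

    def f() -> int:
        return -1

    # What the refactor writes: parentheses make rv capture the call result.
    if (rv := f()) == -1:
        assert rv == -1

    # Without them the comparison binds first and the name captures a bool.
    if rv2 := f() == -1:
        assert rv2 is True

    # The target of ":=" must be a bare name, so negating the result also
    # needs parentheses: "not (out := ...)" parses, "not out := ..." doesn't.
    if not (out := f() + 1):
        assert out == 0
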
index 1053fc80cbaddf918594afb730c25eee4b36c327..a295940675319123f5d76b713df4a8c0c18c25dd 100644 (file)
@@ -116,14 +116,15 @@ def dict_row(cursor: BaseCursor[Any, Any]) -> RowMaker[DictRow]:
 
     The dictionary keys are taken from the column names of the returned columns.
     """
-    names = _get_names(cursor)
-    if names is None:
-        return no_result
+    if (names := _get_names(cursor)) is not None:
+
+        def dict_row_(values: Sequence[Any]) -> dict[str, Any]:
+            return dict(zip(names, values))
 
-    def dict_row_(values: Sequence[Any]) -> dict[str, Any]:
-        return dict(zip(names, values))
+        return dict_row_
 
-    return dict_row_
+    else:
+        return no_result
 
 
 def namedtuple_row(cursor: BaseCursor[Any, Any]) -> RowMaker[NamedTuple]:
@@ -132,17 +133,12 @@ def namedtuple_row(cursor: BaseCursor[Any, Any]) -> RowMaker[NamedTuple]:
     The field names are taken from the column names of the returned columns,
     with some mangling to deal with invalid names.
     """
-    res = cursor.pgresult
-    if not res:
-        return no_result
-
-    nfields = _get_nfields(res)
-    if nfields is None:
+    if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None:
+        nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
+        return nt._make
+    else:
         return no_result
 
-    nt = _make_nt(cursor._encoding, *(res.fname(i) for i in range(nfields)))
-    return nt._make
-
 
 @functools.lru_cache(512)
 def _make_nt(enc: str, *names: bytes) -> type[NamedTuple]:
@@ -161,14 +157,15 @@ def class_row(cls: type[T]) -> BaseRowFactory[T]:
     """
 
     def class_row_(cursor: BaseCursor[Any, Any]) -> RowMaker[T]:
-        names = _get_names(cursor)
-        if names is None:
-            return no_result
+        if (names := _get_names(cursor)) is not None:
 
-        def class_row__(values: Sequence[Any]) -> T:
-            return cls(**dict(zip(names, values)))
+            def class_row__(values: Sequence[Any]) -> T:
+                return cls(**dict(zip(names, values)))
 
-        return class_row__
+            return class_row__
+
+        else:
+            return no_result
 
     return class_row_
 
@@ -197,14 +194,15 @@ def kwargs_row(func: Callable[..., T]) -> BaseRowFactory[T]:
     """
 
     def kwargs_row_(cursor: BaseCursor[Any, T]) -> RowMaker[T]:
-        names = _get_names(cursor)
-        if names is None:
-            return no_result
+        if (names := _get_names(cursor)) is not None:
 
-        def kwargs_row__(values: Sequence[Any]) -> T:
-            return func(**dict(zip(names, values)))
+            def kwargs_row__(values: Sequence[Any]) -> T:
+                return func(**dict(zip(names, values)))
 
-        return kwargs_row__
+            return kwargs_row__
+
+        else:
+            return no_result
 
     return kwargs_row_
 
@@ -214,21 +212,17 @@ def scalar_row(cursor: BaseCursor[Any, Any]) -> RowMaker[Any]:
     Generate a row factory returning the first column
     as a scalar value.
     """
-    res = cursor.pgresult
-    if not res:
-        return no_result
+    if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None:
+        if nfields < 1:
+            raise e.ProgrammingError("at least one column expected")
 
-    nfields = _get_nfields(res)
-    if nfields is None:
-        return no_result
-
-    if nfields < 1:
-        raise e.ProgrammingError("at least one column expected")
+        def scalar_row_(values: Sequence[Any]) -> Any:
+            return values[0]
 
-    def scalar_row_(values: Sequence[Any]) -> Any:
-        return values[0]
+        return scalar_row_
 
-    return scalar_row_
+    else:
+        return no_result
 
 
 def no_result(values: Sequence[Any]) -> NoReturn:
@@ -242,19 +236,14 @@ def no_result(values: Sequence[Any]) -> NoReturn:
 
 
 def _get_names(cursor: BaseCursor[Any, Any]) -> list[str] | None:
-    res = cursor.pgresult
-    if not res:
-        return None
-
-    nfields = _get_nfields(res)
-    if nfields is None:
+    if (res := cursor.pgresult) and (nfields := _get_nfields(res)) is not None:
+        enc = cursor._encoding
+        return [
+            res.fname(i).decode(enc) for i in range(nfields)  # type: ignore[union-attr]
+        ]
+    else:
         return None
 
-    enc = cursor._encoding
-    return [
-        res.fname(i).decode(enc) for i in range(nfields)  # type: ignore[union-attr]
-    ]
-
 
 def _get_nfields(res: PGresult) -> int | None:
     """
index f375846e0682fe2a5712140a404abd7c018e7995..bd30d9c872c3d6ea32dcd16a2623f39842a7338b 100644 (file)
@@ -40,12 +40,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
 
     __slots__ = "_name _scrollable _withhold _described itersize _format".split()
 
-    def __init__(
-        self,
-        name: str,
-        scrollable: bool | None,
-        withhold: bool,
-    ):
+    def __init__(self, name: str, scrollable: bool | None, withhold: bool):
         self._name = name
         self._scrollable = scrollable
         self._withhold = withhold
@@ -95,10 +90,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
         return self._pos if tuples else None
 
     def _declare_gen(
-        self,
-        query: Query,
-        params: Params | None = None,
-        binary: bool | None = None,
+        self, query: Query, params: Params | None = None, binary: bool | None = None
     ) -> PQGen[None]:
         """Generator implementing `ServerCursor.execute()`."""
 
@@ -195,10 +187,7 @@ class ServerCursorMixin(BaseCursor[ConnectionType, Row]):
         if not isinstance(query, sql.Composable):
             query = sql.SQL(query)
 
-        parts = [
-            sql.SQL("DECLARE"),
-            sql.Identifier(self._name),
-        ]
+        parts = [sql.SQL("DECLARE"), sql.Identifier(self._name)]
         if self._scrollable is not None:
             parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
         parts.append(sql.SQL("CURSOR"))
@@ -295,11 +284,7 @@ class ServerCursor(ServerCursorMixin["Connection[Any]", Row], Cursor[Row]):
         return self
 
     def executemany(
-        self,
-        query: Query,
-        params_seq: Iterable[Params],
-        *,
-        returning: bool = True,
+        self, query: Query, params_seq: Iterable[Params], *, returning: bool = True
     ) -> None:
         """Method not implemented for server-side cursors."""
         raise e.NotSupportedError("executemany not supported on server-side cursors")
@@ -428,11 +413,7 @@ class AsyncServerCursor(
         return self
 
     async def executemany(
-        self,
-        query: Query,
-        params_seq: Iterable[Params],
-        *,
-        returning: bool = True,
+        self, query: Query, params_seq: Iterable[Params], *, returning: bool = True
     ) -> None:
         raise e.NotSupportedError("executemany not supported on server-side cursors")
 
index 81a96a0981b27243156624af5d1e8299ab0a18df..172110afec099ec2aa58caaac2e70038ee37122f 100644 (file)
@@ -79,10 +79,8 @@ class Composable(ABC):
         :type context: `connection` or `cursor`
 
         """
-        conn = context.connection if context else None
-        enc = conn_encoding(conn)
-        b = self.as_bytes(context)
-        if isinstance(b, bytes):
+        enc = conn_encoding(context.connection if context else None)
+        if isinstance((b := self.as_bytes(context)), bytes):
             return b.decode(enc)
         else:
             # buffer object
@@ -373,8 +371,7 @@ class Identifier(Composable):
         return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
 
     def as_bytes(self, context: AdaptContext | None = None) -> bytes:
-        conn = context.connection if context else None
-        if conn:
+        if conn := (context.connection if context else None):
             esc = Escaping(conn.pgconn)
             enc = conn_encoding(conn)
             escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
index 156ee890fbb3ab1670992bab253fc1b036f1caba..5c80444fa49481104700e9d39241925afa319384 100644 (file)
@@ -52,8 +52,7 @@ class BaseListDumper(RecursiveDumper):
         Find the first non-null element of an eventually nested list
         """
         items = list(self._flatiter(L, set()))
-        types = {type(item): item for item in items}
-        if not types:
+        if not (types := {type(item): item for item in items}):
             return None
 
         if len(types) == 1:
@@ -62,8 +61,7 @@ class BaseListDumper(RecursiveDumper):
             # More than one type in the list. It might still be good, as long
             # as they dump with the same oid (e.g. IPv4Network, IPv6Network).
             dumpers = [self._tx.get_dumper(item, format) for item in types.values()]
-            oids = {d.oid for d in dumpers}
-            if len(oids) == 1:
+            if len({d.oid for d in dumpers}) == 1:
                 t, v = types.popitem()
             else:
                 raise e.DataError(
@@ -106,8 +104,7 @@ class BaseListDumper(RecursiveDumper):
         Return text info as fallback.
         """
         if base_oid:
-            info = self._tx.adapters.types.get(base_oid)
-            if info:
+            if info := self._tx.adapters.types.get(base_oid):
                 return info
 
         return self._tx.adapters.types["text"]
@@ -120,8 +117,7 @@ class ListDumper(BaseListDumper):
         if self.oid:
             return self.cls
 
-        item = self._find_list_element(obj, format)
-        if item is None:
+        if (item := self._find_list_element(obj, format)) is None:
             return self.cls
 
         sd = self._tx.get_dumper(item, format)
@@ -132,8 +128,7 @@ class ListDumper(BaseListDumper):
         if self.oid:
             return self
 
-        item = self._find_list_element(obj, format)
-        if item is None:
+        if (item := self._find_list_element(obj, format)) is None:
             # Empty lists can only be dumped as text if the type is unknown.
             return self
 
@@ -170,8 +165,7 @@ class ListDumper(BaseListDumper):
                 if isinstance(item, list):
                     dump_list(item)
                 elif item is not None:
-                    ad = self._dump_item(item)
-                    if ad is None:
+                    if (ad := self._dump_item(item)) is None:
                         tokens.append(b"NULL")
                     else:
                         if needs_quotes(ad):
@@ -224,8 +218,7 @@ class ListBinaryDumper(BaseListDumper):
         if self.oid:
             return self.cls
 
-        item = self._find_list_element(obj, format)
-        if item is None:
+        if (item := self._find_list_element(obj, format)) is None:
             return (self.cls,)
 
         sd = self._tx.get_dumper(item, format)
@@ -236,8 +229,7 @@ class ListBinaryDumper(BaseListDumper):
         if self.oid:
             return self
 
-        item = self._find_list_element(obj, format)
-        if item is None:
+        if (item := self._find_list_element(obj, format)) is None:
             return ListDumper(self.cls, self._tx)
 
         sd = self._tx.get_dumper(item, format.from_pq(self.format))
@@ -396,15 +388,14 @@ def _load_text(
     if data and data[0] == b"["[0]:
         if isinstance(data, memoryview):
             data = bytes(data)
-        idx = data.find(b"=")
-        if idx == -1:
+
+        if (idx := data.find(b"=")) == -1:
             raise e.DataError("malformed array: no '=' after dimension information")
         data = data[idx + 1 :]
 
     re_parse = _get_array_parse_regexp(delimiter)
     for m in re_parse.finditer(data):
-        t = m.group(1)
-        if t == b"{":
+        if (t := m.group(1)) == b"{":
             if stack:
                 stack[-1].append(a)
             stack.append(a)
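
For context on the _load_text hunk above: the parser keeps a stack of open
lists while scanning "{" and "}" tokens. A simplified, self-contained version
of that stack discipline (the real code tokenizes with a regexp and also
handles quoting, NULLs and custom delimiters):

    def parse_nested(tokens: list[str]) -> list:
        stack: list[list] = [[]]       # sentinel root
        for t in tokens:
            if t == "{":
                new: list = []
                stack[-1].append(new)  # attach to the enclosing level
                stack.append(new)      # and descend into it
            elif t == "}":
                stack.pop()            # close the current level
            else:
                stack[-1].append(t)
        return stack[0][0]

    print(parse_nested(["{", "1", "{", "2", "3", "}", "}"]))  # ['1', ['2', '3']]
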
index 1c0f747dae7750c0b3755eaa4884756428824234..ad4931239f80605bd65a8924a05ca320565dfd80 100644 (file)
@@ -94,8 +94,7 @@ class SequenceDumper(RecursiveDumper):
                 continue
 
             dumper = self._tx.get_dumper(item, PyFormat.from_pq(self.format))
-            ad = dumper.dump(item)
-            if ad is None:
+            if (ad := dumper.dump(item)) is None:
                 ad = b""
             elif not ad:
                 ad = b'""'
index 016e216bc3d4c63b922f1d19394fb5e55626ce3d..1fcab4b9711cec676ae92e76d3385d5e67af71e6 100644 (file)
@@ -70,8 +70,7 @@ class _BaseTimeDumper(Dumper):
         raise NotImplementedError
 
     def _get_offset(self, obj: time) -> timedelta:
-        offset = obj.utcoffset()
-        if offset is None:
+        if (offset := obj.utcoffset()) is None:
             raise DataError(
                 f"cannot calculate the offset of tzinfo '{obj.tzinfo}' without a date"
             )
@@ -236,8 +235,7 @@ class DateLoader(Loader):
 
     def __init__(self, oid: int, context: AdaptContext | None = None):
         super().__init__(oid, context)
-        ds = _get_datestyle(self.connection)
-        if ds.startswith(b"I"):  # ISO
+        if (ds := _get_datestyle(self.connection)).startswith(b"I"):  # ISO
             self._order = self._ORDER_YMD
         elif ds.startswith(b"G"):  # German
             self._order = self._ORDER_DMY
@@ -290,8 +288,7 @@ class TimeLoader(Loader):
     _re_format = re.compile(rb"^(\d+):(\d+):(\d+)(?:\.(\d+))?")
 
     def load(self, data: Buffer) -> time:
-        m = self._re_format.match(data)
-        if not m:
+        if not (m := self._re_format.match(data)):
             s = bytes(data).decode("utf8", "replace")
             raise DataError(f"can't parse time {s!r}")
 
@@ -337,8 +334,7 @@ class TimetzLoader(Loader):
     )
 
     def load(self, data: Buffer) -> time:
-        m = self._re_format.match(data)
-        if not m:
+        if not (m := self._re_format.match(data)):
             s = bytes(data).decode("utf8", "replace")
             raise DataError(f"can't parse timetz {s!r}")
 
@@ -415,9 +411,7 @@ class TimestampLoader(Loader):
 
     def __init__(self, oid: int, context: AdaptContext | None = None):
         super().__init__(oid, context)
-
-        ds = _get_datestyle(self.connection)
-        if ds.startswith(b"I"):  # ISO
+        if (ds := _get_datestyle(self.connection)).startswith(b"I"):  # ISO
             self._order = self._ORDER_YMD
         elif ds.startswith(b"G"):  # German
             self._order = self._ORDER_DMY
@@ -430,8 +424,7 @@ class TimestampLoader(Loader):
             raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
 
     def load(self, data: Buffer) -> datetime:
-        m = self._re_format.match(data)
-        if not m:
+        if not (m := self._re_format.match(data)):
             raise _get_timestamp_load_error(self.connection, data) from None
 
         if self._order == self._ORDER_YMD:
@@ -499,8 +492,7 @@ class TimestamptzLoader(Loader):
         super().__init__(oid, context)
         self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
 
-        ds = _get_datestyle(self.connection)
-        if ds.startswith(b"I"):  # ISO
+        if _get_datestyle(self.connection).startswith(b"I"):  # ISO
             self._load_method = self._load_iso
         else:
             self._load_method = self._load_notimpl
@@ -510,8 +502,7 @@ class TimestamptzLoader(Loader):
 
     @staticmethod
     def _load_iso(self: TimestamptzLoader, data: Buffer) -> datetime:
-        m = self._re_format.match(data)
-        if not m:
+        if not (m := self._re_format.match(data)):
             raise _get_timestamp_load_error(self.connection, data) from None
 
         ye, mo, da, ho, mi, se, fr, sgn, oh, om, os = m.groups()
@@ -582,10 +573,9 @@ class TimestamptzBinaryLoader(Loader):
             # timezone) we can still save the day by shifting the value by the
             # timezone offset and then replacing the timezone.
             if self._timezone:
-                utcoff = self._timezone.utcoffset(
+                if utcoff := self._timezone.utcoffset(
                     datetime.min if micros < 0 else datetime.max
-                )
-                if utcoff:
+                ):
                     usoff = 1_000_000 * int(utcoff.total_seconds())
                     try:
                         ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
@@ -624,8 +614,7 @@ class IntervalLoader(Loader):
 
     @staticmethod
     def _load_postgres(self: IntervalLoader, data: Buffer) -> timedelta:
-        m = self._re_interval.match(data)
-        if not m:
+        if not (m := self._re_interval.match(data)):
             s = bytes(data).decode("utf8", "replace")
             raise DataError(f"can't parse interval {s!r}")
 
@@ -679,19 +668,15 @@ class IntervalBinaryLoader(Loader):
 
 
 def _get_datestyle(conn: BaseConnection[Any] | None) -> bytes:
-    if conn:
-        ds = conn.pgconn.parameter_status(b"DateStyle")
-        if ds:
-            return ds
+    if conn and (ds := conn.pgconn.parameter_status(b"DateStyle")):
+        return ds
 
     return b"ISO, DMY"
 
 
 def _get_intervalstyle(conn: BaseConnection[Any] | None) -> bytes:
-    if conn:
-        ints = conn.pgconn.parameter_status(b"IntervalStyle")
-        if ints:
-            return ints
+    if conn and (ints := conn.pgconn.parameter_status(b"IntervalStyle")):
+        return ints
 
     return b"unknown"
 
@@ -705,8 +690,7 @@ def _get_timestamp_load_error(
         if not s:
             return False
 
-        ds = _get_datestyle(conn)
-        if not ds.startswith(b"P"):  # Postgres
+        if not _get_datestyle(conn).startswith(b"P"):  # Postgres
             return len(s.split()[0]) > 10  # date is first token
         else:
             return len(s.split()[-1]) > 4  # year is last token
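
Two idioms recur in the datetime loaders above: binding a regexp match and
failing fast when it is None, and guarding a walrus behind "and" so it only
runs when a connection exists. A sketch with toy inputs (conn here is any
object with a parameter_status method, not the real connection class):

    import re

    _re_time = re.compile(rb"^(\d+):(\d+):(\d+)")

    def load_time(data: bytes) -> tuple[int, ...]:
        # Bind the match and test it in a single step; fail fast on None.
        if not (m := _re_time.match(data)):
            raise ValueError(f"can't parse time {data!r}")
        return tuple(int(g) for g in m.groups())

    def get_datestyle(conn) -> bytes:
        # Short-circuit: the walrus right of "and" only runs, and ds is
        # only bound, when conn is truthy.
        if conn and (ds := conn.parameter_status(b"DateStyle")):
            return ds
        return b"ISO, DMY"

    print(load_time(b"10:20:30"))  # (10, 20, 30)
    print(get_datestyle(None))     # b'ISO, DMY'
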
index 7f1da6e67d3fdcc3c2fe5c2d55974880d830d4a4..e87f9d66bf87d6ac765e983e831df3ce04233e1b 100644 (file)
@@ -84,8 +84,7 @@ class HstoreLoader(RecursiveLoader):
             if m is None or m.start() != start:
                 raise e.DataError(f"error parsing hstore pair at char {start}")
             k = _re_unescape.sub(r"\1", m.group(1))
-            v = m.group(2)
-            if v is not None:
+            if (v := m.group(2)) is not None:
                 v = _re_unescape.sub(r"\1", v)
 
             rv[k] = v
index 51bb22e034fcee6444c2912fb86d6298a6d03d6f..e4fbb25c634907ec5d8889f49af4b9787449a0a7 100644 (file)
@@ -98,16 +98,14 @@ def set_json_loads(
 
 @cache
 def _make_dumper(base: type[abc.Dumper], dumps: JsonDumpsFunction) -> type[abc.Dumper]:
-    name = base.__name__
-    if not name.startswith("Custom"):
+    if not (name := base.__name__).startswith("Custom"):
         name = f"Custom{name}"
     return type(name, (base,), {"_dumps": dumps})
 
 
 @cache
 def _make_loader(base: type[Loader], loads: JsonLoadsFunction) -> type[Loader]:
-    name = base.__name__
-    if not name.startswith("Custom"):
+    if not (name := base.__name__).startswith("Custom"):
         name = f"Custom{name}"
     return type(name, (base,), {"_loads": loads})
 
@@ -120,8 +118,7 @@ class _JsonWrapper:
         self.dumps = dumps
 
     def __repr__(self) -> str:
-        sobj = repr(self.obj)
-        if len(sobj) > 40:
+        if len((sobj := repr(self.obj))) > 40:
             sobj = f"{sobj[:35]} ... ({len(sobj)} chars)"
         return f"{self.__class__.__name__}({sobj})"
 
@@ -149,8 +146,7 @@ class _JsonDumper(Dumper):
             obj = obj.obj
         else:
             dumps = self.dumps
-        data = dumps(obj)
-        if isinstance(data, str):
+        if isinstance((data := dumps(obj)), str):
             return data.encode()
         return data
 
@@ -173,8 +169,7 @@ class JsonbBinaryDumper(_JsonDumper):
     oid = _oids.JSONB_OID
 
     def dump(self, obj: Any) -> Buffer | None:
-        obj_bytes = super().dump(obj)
-        if obj_bytes is not None:
+        if (obj_bytes := super().dump(obj)) is not None:
             return b"\x01" + obj_bytes
         else:
             return None
@@ -214,8 +209,7 @@ class JsonbBinaryLoader(_JsonLoader):
     def load(self, data: Buffer) -> Any:
         if data and data[0] != 1:
             raise DataError("unknown jsonb binary format: {data[0]}")
-        data = data[1:]
-        if not isinstance(data, bytes):
+        if not isinstance((data := data[1:]), bytes):
             data = bytes(data)
         return self.loads(data)
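
_make_dumper and _make_loader above memoize one dynamically built subclass per
(base, function) pair. A sketch of the same cached type() factory with generic
names (the attribute layout is simplified):

    from functools import cache
    from typing import Callable

    class Base:
        _dumps: Callable = staticmethod(repr)

    @cache
    def make_custom(base: type, dumps: Callable) -> type:
        if not (name := base.__name__).startswith("Custom"):
            name = f"Custom{name}"
        # type(name, bases, namespace) builds the subclass at runtime.
        return type(name, (base,), {"_dumps": staticmethod(dumps)})

    A = make_custom(Base, str)
    assert A is make_custom(Base, str)  # the cache returns the same class
    print(A.__name__, A._dumps(42))     # CustomBase 42
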
 
index 8c1f75ec116e3682aa312bc2c0c81d0112529d29..cbfecd690e39af5ddbcd45ff8c1eb1eb72c83c82 100644 (file)
@@ -190,8 +190,7 @@ class BaseMultirangeDumper(RecursiveDumper):
         if self.cls is not Multirange:
             return self.cls
 
-        item = self._get_item(obj)
-        if item is not None:
+        if (item := self._get_item(obj)) is not None:
             sd = self._tx.get_dumper(item, self._adapt_format)
             return (self.cls, sd.get_key(item, format))
         else:
@@ -202,8 +201,7 @@ class BaseMultirangeDumper(RecursiveDumper):
         if self.cls is not Multirange:
             return self
 
-        item = self._get_item(obj)
-        if item is None:
+        if (item := self._get_item(obj)) is None:
             return self
 
         dumper: BaseMultirangeDumper
@@ -257,8 +255,7 @@ class MultirangeDumper(BaseMultirangeDumper):
         if not obj:
             return b"{}"
 
-        item = self._get_item(obj)
-        if item is not None:
+        if (item := self._get_item(obj)) is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
         else:
             dump = fail_dump
@@ -275,8 +272,7 @@ class MultirangeBinaryDumper(BaseMultirangeDumper):
     format = Format.BINARY
 
     def dump(self, obj: Multirange[Any]) -> Buffer | None:
-        item = self._get_item(obj)
-        if item is not None:
+        if (item := self._get_item(obj)) is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
         else:
             dump = fail_dump
index 8c0e64ed1559952fbbbf22d9b37da528fb6bcb84..28fbc9aa1d4c6f2fab938360c2f82d380be14beb 100644 (file)
@@ -40,8 +40,7 @@ class _IntDumper(Dumper):
         return str(obj).encode()
 
     def quote(self, obj: Any) -> Buffer:
-        value = self.dump(obj)
-        if value is None:
+        if (value := self.dump(obj)) is None:
             return b"NULL"
         return value if obj >= 0 else b" " + value
 
@@ -63,9 +62,7 @@ class _SpecialValuesDumper(Dumper):
         return str(obj).encode()
 
     def quote(self, obj: Any) -> Buffer:
-        value = self.dump(obj)
-
-        if value is None:
+        if (value := self.dump(obj)) is None:
             return b"NULL"
         if not isinstance(value, bytes):
             value = bytes(value)
@@ -158,11 +155,10 @@ class IntDumper(Dumper):
                 return self._int2_dumper
             else:
                 return self._int4_dumper
+        elif -(2**63) <= obj < 2**63:
+            return self._int8_dumper
         else:
-            if -(2**63) <= obj < 2**63:
-                return self._int8_dumper
-            else:
-                return self._int_numeric_dumper
+            return self._int_numeric_dumper
 
 
 class Int2BinaryDumper(Int2Dumper):
@@ -453,8 +449,7 @@ def dump_decimal_to_numeric_binary(obj: Decimal) -> bytearray | bytes:
 
     # Equivalent of 0-padding left to align the py digits to the pg digits
     # but without changing the digits tuple.
-    mod = (ndigits - dscale) % DEC_DIGITS
-    if mod:
+    if mod := ((ndigits - dscale) % DEC_DIGITS):
         wi = DEC_DIGITS - mod
         ndigits += wi
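
The IntDumper.get_key chain above picks the narrowest PostgreSQL integer type
that can hold the value. The same dispatch, flattened into a sketch that
returns type names instead of dumper classes:

    def pg_int_type(n: int) -> str:
        if -(2**31) <= n < 2**31:
            return "int2" if -(2**15) <= n < 2**15 else "int4"
        elif -(2**63) <= n < 2**63:
            return "int8"
        else:
            return "numeric"

    print([pg_int_type(n) for n in (1, 40_000, 2**40, 2**70)])
    # ['int2', 'int4', 'int8', 'numeric']
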
 
index 6488d3d768af8b361e511f95a3d96e54473d141a..f5074829ea46e4dd298f8f0b9ca1c6db7ee07e7b 100644 (file)
@@ -185,17 +185,15 @@ class Range(Generic[T]):
                 # It doesn't seem that Python has an ABC for ordered types.
                 if x < self._lower:  # type: ignore[operator]
                     return False
-            else:
-                if x <= self._lower:  # type: ignore[operator]
-                    return False
+            elif x <= self._lower:  # type: ignore[operator]
+                return False
 
         if self._upper is not None:
             if self._bounds[1] == "]":
                 if x > self._upper:  # type: ignore[operator]
                     return False
-            else:
-                if x >= self._upper:  # type: ignore[operator]
-                    return False
+            elif x >= self._upper:  # type: ignore[operator]
+                return False
 
         return True
 
@@ -223,8 +221,7 @@ class Range(Generic[T]):
             return NotImplemented
         for attr in ("_lower", "_upper", "_bounds"):
             self_value = getattr(self, attr)
-            other_value = getattr(other, attr)
-            if self_value == other_value:
+            if self_value == (other_value := getattr(other, attr)):
                 pass
             elif self_value is None:
                 return True
@@ -294,9 +291,7 @@ class BaseRangeDumper(RecursiveDumper):
         # If we are a subclass whose oid is specified we don't need upgrade
         if self.cls is not Range:
             return self.cls
-
-        item = self._get_item(obj)
-        if item is not None:
+        if (item := self._get_item(obj)) is not None:
             sd = self._tx.get_dumper(item, self._adapt_format)
             return (self.cls, sd.get_key(item, format))
         else:
@@ -306,9 +301,7 @@ class BaseRangeDumper(RecursiveDumper):
         # If we are a subclass whose oid is specified we don't need upgrade
         if self.cls is not Range:
             return self
-
-        item = self._get_item(obj)
-        if item is None:
+        if (item := self._get_item(obj)) is None:
             return self
 
         dumper: BaseRangeDumper
@@ -355,8 +348,7 @@ class RangeDumper(BaseRangeDumper):
     """
 
     def dump(self, obj: Range[Any]) -> Buffer | None:
-        item = self._get_item(obj)
-        if item is not None:
+        if (item := self._get_item(obj)) is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
         else:
             dump = fail_dump
@@ -371,8 +363,7 @@ def dump_range_text(obj: Range[Any], dump: DumpFunc) -> Buffer:
     parts: list[Buffer] = [b"[" if obj.lower_inc else b"("]
 
     def dump_item(item: Any) -> Buffer:
-        ad = dump(item)
-        if ad is None:
+        if (ad := dump(item)) is None:
             return b""
         elif not ad:
             return b'""'
@@ -402,8 +393,7 @@ class RangeBinaryDumper(BaseRangeDumper):
     format = Format.BINARY
 
     def dump(self, obj: Range[Any]) -> Buffer | None:
-        item = self._get_item(obj)
-        if item is not None:
+        if (item := self._get_item(obj)) is not None:
             dump = self._tx.get_dumper(item, self._adapt_format).dump
         else:
             dump = fail_dump
@@ -424,8 +414,7 @@ def dump_range_binary(obj: Range[Any], dump: DumpFunc) -> Buffer:
         head |= RANGE_UB_INC
 
     if obj.lower is not None:
-        data = dump(obj.lower)
-        if data is not None:
+        if (data := dump(obj.lower)) is not None:
             out += pack_len(len(data))
             out += data
         else:
@@ -434,8 +423,7 @@ def dump_range_binary(obj: Range[Any], dump: DumpFunc) -> Buffer:
         head |= RANGE_LB_INF
 
     if obj.upper is not None:
-        data = dump(obj.upper)
-        if data is not None:
+        if (data := dump(obj.upper)) is not None:
             out += pack_len(len(data))
             out += data
         else:
@@ -472,27 +460,21 @@ class RangeLoader(BaseRangeLoader[T]):
 def load_range_text(data: Buffer, load: LoadFunc) -> tuple[Range[Any], int]:
     if data == b"empty":
         return Range(empty=True), 5
-
-    m = _re_range.match(data)
-    if m is None:
+    if (m := _re_range.match(data)) is None:
         raise e.DataError(
             f"failed to parse range: '{bytes(data).decode('utf8', 'replace')}'"
         )
 
     lower = None
-    item = m.group(3)
-    if item is None:
-        item = m.group(2)
-        if item is not None:
+    if (item := m.group(3)) is None:
+        if (item := m.group(2)) is not None:
             lower = load(_re_undouble.sub(rb"\1", item))
     else:
         lower = load(item)
 
     upper = None
-    item = m.group(5)
-    if item is None:
-        item = m.group(4)
-        if item is not None:
+    if (item := m.group(5)) is None:
+        if (item := m.group(4)) is not None:
             upper = load(_re_undouble.sub(rb"\1", item))
     else:
         upper = load(item)
@@ -530,8 +512,7 @@ class RangeBinaryLoader(BaseRangeLoader[T]):
 
 
 def load_range_binary(data: Buffer, load: LoadFunc) -> Range[Any]:
-    head = data[0]
-    if head & RANGE_EMPTY:
+    if (head := data[0]) & RANGE_EMPTY:
         return Range(empty=True)
 
     lb = "[" if head & RANGE_LB_INC else "("
index 262ac658c42f54b0c7a44b6a36084b127fd80375..ef4fceb284380b75cacb0f0f077f448f29d3712d 100644 (file)
@@ -138,8 +138,7 @@ class BytesDumper(Dumper):
         return self._esc.escape_bytea(obj)
 
     def quote(self, obj: Buffer) -> Buffer:
-        escaped = self.dump(obj)
-        if escaped is None:
+        if (escaped := self.dump(obj)) is None:
             return b"NULL"
 
         # We cannot use the base quoting because escape_bytea already returns
index 05994b5d57276288fa67344a72353713183b42e5..7df0b960d6a1021331f697c676d4f555e23cc6e5 100644 (file)
@@ -54,8 +54,7 @@ def wait_selector(gen: PQGen[RV], fileno: int, interval: float | None = None) ->
         with DefaultSelector() as sel:
             sel.register(fileno, s)
             while True:
-                rlist = sel.select(timeout=interval)
-                if not rlist:
+                if not (rlist := sel.select(timeout=interval)):
                     gen.send(READY_NONE)
                     continue
 
@@ -89,8 +88,7 @@ def wait_conn(gen: PQGenConn[RV], interval: float | None = None) -> RV:
         with DefaultSelector() as sel:
             sel.register(fileno, s)
             while True:
-                rlist = sel.select(timeout=interval)
-                if not rlist:
+                if not (rlist := sel.select(timeout=interval)):
                     gen.send(READY_NONE)
                     continue
 
@@ -134,7 +132,7 @@ async def wait_async(gen: PQGen[RV], fileno: int, interval: float | None = None)
         while True:
             reader = s & WAIT_R
             writer = s & WAIT_W
-            if not reader and not writer:
+            if not (reader or writer):
                 raise e.InternalError(f"bad poll status: {s}")
             ev.clear()
             ready = 0
@@ -195,7 +193,7 @@ async def wait_conn_async(gen: PQGenConn[RV], interval: float | None = None) ->
         while True:
             reader = s & WAIT_R
             writer = s & WAIT_W
-            if not reader and not writer:
+            if not (reader or writer):
                 raise e.InternalError(f"bad poll status: {s}")
             ev.clear()
             ready = 0  # type: ignore[assignment]
@@ -297,8 +295,7 @@ def wait_epoll(gen: PQGen[RV], fileno: int, interval: float | None = None) -> RV
             evmask = _epoll_evmasks[s]
             epoll.register(fileno, evmask)
             while True:
-                fileevs = epoll.poll(interval)
-                if not fileevs:
+                if not (fileevs := epoll.poll(interval)):
                     gen.send(READY_NONE)
                     continue
                 ev = fileevs[0][1]
@@ -344,8 +341,7 @@ def wait_poll(gen: PQGen[RV], fileno: int, interval: float | None = None) -> RV:
         evmask = _poll_evmasks[s]
         poll.register(fileno, evmask)
         while True:
-            fileevs = poll.poll(interval)
-            if not fileevs:
+            if not (fileevs := poll.poll(interval)):
                 gen.send(READY_NONE)
                 continue
 
@@ -374,8 +370,7 @@ def _is_select_patched() -> bool:
     Currently supported: gevent.
     """
     # If not imported, don't import it.
-    m = sys.modules.get("gevent.monkey")
-    if m:
+    if m := sys.modules.get("gevent.monkey"):
         try:
             if m.is_module_patched("select"):
                 return True
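
The wait loops above treat an empty list from select()/poll() as a timeout
tick, and "not reader and not writer" folds to the equivalent
"not (reader or writer)" by De Morgan. A self-contained sketch of the
select-with-timeout shape, using a socket pair instead of a libpq socket:

    import selectors
    import socket

    r, w = socket.socketpair()
    with selectors.DefaultSelector() as sel:
        sel.register(r, selectors.EVENT_READ)

        # An empty result list means the interval elapsed with no event.
        if not (events := sel.select(timeout=0.1)):
            print("timeout, nothing ready yet")

        w.send(b"x")
        if events := sel.select(timeout=0.1):
            key, _mask = events[0]
            print("ready:", key.fileobj.recv(1))  # b'x'

    r.close()
    w.close()
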
index 9f9984f227accfd47d2fa1b09d72e7d158ba4768..963c8bd3eed9a9789a78424efc397cc34cc22094 100644 (file)
@@ -145,8 +145,7 @@ class BasePool:
                 raise PoolClosed(f"the pool {self.name!r} is not open yet")
 
     def _check_pool_putconn(self, conn: BaseConnection[Any]) -> None:
-        pool = getattr(conn, "_pool", None)
-        if pool is self:
+        if (pool := getattr(conn, "_pool", None)) is self:
             return
 
         if pool:
index 37b7bc707d95447c3b2555538aa2a63b74c4e411..0370702ef62555bef397868443405087413d9c25 100644 (file)
@@ -162,8 +162,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool[CT]):
             while self._waiting:
                 # If there is a client waiting (which is still waiting and
                 # hasn't timed out), give it the connection and notify it.
-                pos = self._waiting.popleft()
-                if pos.set(conn):
+                if self._waiting.popleft().set(conn):
                     break
             else:
                 # No client waiting for a connection: close the connection
index b73504824b3b3aad707bc090fd09417d1a78f77a..53a2201c9c17ed337d13c7b5304bf2588888c93c 100644 (file)
@@ -158,13 +158,11 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool[ACT])
             while self._waiting:
                 # If there is a client waiting (which is still waiting and
                 # hasn't timed out), give it the connection and notify it.
-                pos = self._waiting.popleft()
-                if await pos.set(conn):
+                if await self._waiting.popleft().set(conn):
                     break
             else:
                 # No client waiting for a connection: close the connection
                 await conn.close()
-
                 # If we have been asked to wait for pool init, notify the
                 # waiter if the pool is ready.
                 if self._pool_full_event:
index 115894486c5ebe1755da33ebc2bebd0e2b0fe60f..24090ca9793d7e4c723e4c463c5264ca240cc077 100644 (file)
@@ -232,8 +232,7 @@ class ConnectionPool(Generic[CT], BasePool):
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         with self._lock:
-            conn = self._get_ready_connection(timeout)
-            if not conn:
+            if not (conn := self._get_ready_connection(timeout)):
                 # No connection available: put the client in the waiting queue
                 t0 = monotonic()
                 pos: WaitingClient[CT] = WaitingClient()
@@ -571,9 +570,7 @@ class ConnectionPool(Generic[CT], BasePool):
         StopWorker is received.
         """
         while True:
-            task = q.get()
-
-            if isinstance(task, StopWorker):
+            if isinstance((task := q.get()), StopWorker):
                 logger.debug("terminating working task %s", current_thread_name())
                 return
 
@@ -606,8 +603,7 @@ class ConnectionPool(Generic[CT], BasePool):
 
         if self._configure:
             self._configure(conn)
-            status = conn.pgconn.transaction_status
-            if status != TransactionStatus.IDLE:
+            if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
                 sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
                     f"connection left in status {sname} by configure function {self._configure}: discarded"
@@ -734,8 +730,8 @@ class ConnectionPool(Generic[CT], BasePool):
             while self._waiting:
                 # If there is a client waiting (which is still waiting and
                 # hasn't timed out), give it the connection and notify it.
-                pos = self._waiting.popleft()
-                if pos.set(conn):
+
+                if self._waiting.popleft().set(conn):
                     break
             else:
                 # No client waiting for a connection: put it back into the pool
@@ -749,8 +745,7 @@ class ConnectionPool(Generic[CT], BasePool):
         """
         Bring a connection to IDLE state or close it.
         """
-        status = conn.pgconn.transaction_status
-        if status == TransactionStatus.IDLE:
+        if (status := conn.pgconn.transaction_status) == TransactionStatus.IDLE:
             pass
         elif status == TransactionStatus.UNKNOWN:
             # Connection closed
@@ -776,8 +771,7 @@ class ConnectionPool(Generic[CT], BasePool):
         if self._reset:
             try:
                 self._reset(conn)
-                status = conn.pgconn.transaction_status
-                if status != TransactionStatus.IDLE:
+                if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
                     sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
                         f"connection left in status {sname} by reset function {self._reset}: discarded"
index fc52a6fe40053f4b8b2b2ab94f62e1a38f6f1903..af725a8e46f3477e56a581fefacb8ae2a3611197 100644 (file)
@@ -260,8 +260,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         # Critical section: decide here if there's a connection ready
         # or if the client needs to wait.
         async with self._lock:
-            conn = await self._get_ready_connection(timeout)
-            if not conn:
+            if not (conn := (await self._get_ready_connection(timeout))):
                 # No connection available: put the client in the waiting queue
                 t0 = monotonic()
                 pos: WaitingClient[ACT] = WaitingClient()
@@ -623,9 +622,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         StopWorker is received.
         """
         while True:
-            task = await q.get()
-
-            if isinstance(task, StopWorker):
+            if isinstance((task := (await q.get())), StopWorker):
                 logger.debug("terminating working task %s", current_task_name())
                 return
 
@@ -658,8 +655,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
 
         if self._configure:
             await self._configure(conn)
-            status = conn.pgconn.transaction_status
-            if status != TransactionStatus.IDLE:
+            if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
                 sname = TransactionStatus(status).name
                 raise e.ProgrammingError(
                     f"connection left in status {sname} by configure function"
@@ -789,13 +785,12 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
             while self._waiting:
                 # If there is a client waiting (which is still waiting and
                 # hasn't timed out), give it the connection and notify it.
-                pos = self._waiting.popleft()
-                if await pos.set(conn):
+
+                if await self._waiting.popleft().set(conn):
                     break
             else:
                 # No client waiting for a connection: put it back into the pool
                 self._pool.append(conn)
-
                 # If we have been asked to wait for pool init, notify the
                 # waiter if the pool is full.
                 if self._pool_full_event and len(self._pool) >= self._min_size:
@@ -805,10 +800,8 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         """
         Bring a connection to IDLE state or close it.
         """
-        status = conn.pgconn.transaction_status
-        if status == TransactionStatus.IDLE:
+        if (status := conn.pgconn.transaction_status) == TransactionStatus.IDLE:
             pass
-
         elif status == TransactionStatus.UNKNOWN:
             # Connection closed
             return
@@ -835,8 +828,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if self._reset:
             try:
                 await self._reset(conn)
-                status = conn.pgconn.transaction_status
-                if status != TransactionStatus.IDLE:
+                if (status := conn.pgconn.transaction_status) != TransactionStatus.IDLE:
                     sname = TransactionStatus(status).name
                     raise e.ProgrammingError(
                         f"connection left in status {sname} by reset function"
index f41179858f48a5d2cd4fd789b15d49637fc70e72..6fc4e2b91b74ee9e607de2c9569c50435b5a3b6a 100644 (file)
@@ -66,8 +66,7 @@ class Scheduler:
         while True:
             with self._lock:
                 now = monotonic()
-                task = q[0] if q else None
-                if task:
+                if task := (q[0] if q else None):
                     if task.time <= now:
                         heappop(q)
                     else:
index 86f59d03630e4db3b6283e1cd433221830a482cf..3db6533d6c85bf49650444ffc0749bbc77035e5d 100644 (file)
@@ -62,8 +62,7 @@ class AsyncScheduler:
         while True:
             async with self._lock:
                 now = monotonic()
-                task = q[0] if q else None
-                if task:
+                if task := (q[0] if q else None):
                     if task.time <= now:
                         heappop(q)
                     else:
index 4a1d4de4111d8d49e25e79393e5945c978bddedf..e25ce52184ca73c333d9e73e972f812afab29fc4 100644 (file)
@@ -49,11 +49,8 @@ def pytest_addoption(parser):
 
 
 def pytest_report_header(config):
-    rv = []
-
-    rv.append(f"default selector: {selectors.DefaultSelector.__name__}")
-    loop = config.getoption("--loop")
-    if loop != "default":
+    rv = [f"default selector: {selectors.DefaultSelector.__name__}"]
+    if (loop := config.getoption("--loop")) != "default":
         rv.append(f"asyncio loop: {loop}")
 
     return rv
@@ -65,8 +62,7 @@ def pytest_sessionstart(session):
     # In case of segfault, pytest doesn't get a chance to write failed tests
     # in the cache. As a consequence, retries would find no test failed and
     # assume that all tests passed in the previous run, making the whole run pass.
-    cache = session.config.cache
-    if cache.get("segfault", False):
+    if (cache := session.config.cache).get("segfault", False):
         session.warn(Warning("Previous run resulted in segfault! Not running any test"))
         session.warn(Warning("(delete '.pytest_cache/v/segfault' to clear this state)"))
         raise session.Failed
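
The test fixtures here and below share one idiom: a helper returns either a
skip message or None, and the walrus folds bind-and-test into the if. A sketch
with a toy stand-in for check_postgres_version / check_crdb_version:

    import pytest

    def check_version(got: int, minimum: int) -> str | None:  # toy helper
        return None if got >= minimum else f"server version {got} < {minimum}"

    def skip_if_old(server_version: int) -> None:
        # A non-empty message is always truthy, so the bare "if msg" is safe.
        if msg := check_version(server_version, 140000):
            pytest.skip(msg)
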
index b41de3023c510fb2e782a01e186633bd14cf9b1e..0739a27c8e8cbc3fdcef8f65ab36c6bde280a2e5 100644 (file)
@@ -69,8 +69,7 @@ def test_changefeed(conn_cls, dsn, conn, testfeed, fmt_out):
     # We often find the record with {"after": null} at least one more time
     # in the queue. Let's tolerate an extra one.
     for i in range(2):
-        row = q.get()
-        if row is None:
+        if (row := q.get()) is None:
             break
         assert json.loads(row.value)["after"] is None, json
     else:
index 4c803fa87c17b74d102666f1a10b911f63c5e920..c5a1986aeb7a7f99bb88a78dd200e2e13d086822 100644 (file)
@@ -70,8 +70,7 @@ async def test_changefeed(aconn_cls, dsn, aconn, testfeed, fmt_out):
     # We often find the record with {"after": null} at least one more time
     # in the queue. Let's tolerate an extra one.
     for i in range(2):
-        row = await q.get()
-        if row is None:
+        if (row := (await q.get())) is None:
             break
         assert json.loads(row.value)["after"] is None, json
     else:
index f1679f6f08aaeb513d623f5cc0bdec8671473502..60d721c2155a28494f34c2159d23617fa3cca04e 100644 (file)
@@ -38,12 +38,10 @@ def check_crdb_version(got, mark):
     pred = VersionCheck.parse(spec)
     pred.whose = "CockroachDB"
 
-    msg = pred.get_skip_message(got)
-    if not msg:
+    if not (msg := pred.get_skip_message(got)):
         return None
 
-    reason = crdb_skip_message(reason)
-    if reason:
+    if reason := crdb_skip_message(reason):
         msg = f"{msg}: {reason}"
 
     return msg
index 342e5139251152c28121eb16b694f6337855e19d..7b1879495cd7ebf59900e26656bdd286dd4bdfec 100644 (file)
@@ -46,8 +46,7 @@ def pytest_addoption(parser):
 
 
 def pytest_report_header(config):
-    dsn = config.getoption("--test-dsn")
-    if dsn is None:
+    if (dsn := config.getoption("--test-dsn")) is None:
         return []
 
     try:
@@ -56,9 +55,7 @@ def pytest_report_header(config):
     except Exception as ex:
         server_version = f"unknown ({ex})"
 
-    return [
-        f"Server version: {server_version}",
-    ]
+    return [f"Server version: {server_version}"]
 
 
 def pytest_collection_modifyitems(items):
@@ -93,8 +90,7 @@ def session_dsn(request):
     """
     Return the dsn used to connect to the `--test-dsn` database (session-wide).
     """
-    dsn = request.config.getoption("--test-dsn")
-    if dsn is None:
+    if (dsn := request.config.getoption("--test-dsn")) is None:
         pytest.skip("skipping test as no --test-dsn")
 
     warm_up_database(dsn)
@@ -130,8 +126,7 @@ def tracefile(request):
     """Open and yield a file for libpq client/server communication traces if
     the --pq-trace option is set.
     """
-    tracefile = request.config.getoption("--pq-trace")
-    if not tracefile:
+    if not (tracefile := request.config.getoption("--pq-trace")):
         yield None
         return
 
@@ -191,9 +186,7 @@ def pgconn_debug(request):
 def pgconn(dsn, request, tracefile):
     """Return a PGconn connection open to `--test-dsn`."""
     check_connection_version(request.node)
-
-    conn = pq.PGconn.connect(dsn.encode())
-    if conn.status != pq.ConnStatus.OK:
+    if (conn := pq.PGconn.connect(dsn.encode())).status != pq.ConnStatus.OK:
         pytest.fail(f"bad connection: {conn.get_error_message()}")
 
     with maybe_trace(conn, tracefile, request.function):
@@ -298,8 +291,7 @@ def patch_exec(conn, monkeypatch):
     L = ListPopAll()
 
     def _exec_command(command, *args, **kwargs):
-        cmdcopy = command
-        if isinstance(cmdcopy, bytes):
+        if isinstance((cmdcopy := command), bytes):
             cmdcopy = cmdcopy.decode(conn.info.encoding)
         elif isinstance(cmdcopy, sql.Composable):
             cmdcopy = cmdcopy.as_string(conn)
@@ -330,15 +322,12 @@ def check_connection_version(node):
     for mark in node.iter_markers():
         if mark.name == "pg":
             assert len(mark.args) == 1
-            msg = check_postgres_version(pg_version, mark.args[0])
-            if msg:
+            if msg := check_postgres_version(pg_version, mark.args[0]):
                 pytest.skip(msg)
-
         elif mark.name in ("crdb", "crdb_skip"):
             from .fix_crdb import check_crdb_version
 
-            msg = check_crdb_version(crdb_version, mark)
-            if msg:
+            if msg := check_crdb_version(crdb_version, mark):
                 pytest.skip(msg)
 
 
@@ -366,12 +355,9 @@ def warm_up_database(dsn: str) -> None:
     try:
         with psycopg.connect(dsn, connect_timeout=10) as conn:
             conn.execute("select 1")
-
             pg_version = conn.info.server_version
-
             crdb_version = None
-            param = conn.info.parameter_status("crdb_version")
-            if param:
+            if param := conn.info.parameter_status("crdb_version"):
                 from psycopg.crdb import CrdbConnectionInfo
 
                 crdb_version = CrdbConnectionInfo.parse_crdb_version(param)
index d4eee38753e35e83c860f276b98d71131f18e76f..9f6a8c84790c74e9a773cb8998c1a53539932ffd 100644 (file)
@@ -157,8 +157,7 @@ class Faker:
                     try:
                         cur.execute(self._insert_field_stmt(j), (val,))
                     except psycopg.DatabaseError as e:
-                        r = repr(val)
-                        if len(r) > 200:
+                        if len((r := repr(val))) > 200:
                             r = f"{r[:200]}... ({len(r)} chars)"
                         raise Exception(
                             f"value {r!r} at record {i} column0 {j} failed insert: {e}"
@@ -182,8 +181,7 @@ class Faker:
                     try:
                         await acur.execute(self._insert_field_stmt(j), (val,))
                     except psycopg.DatabaseError as e:
-                        r = repr(val)
-                        if len(r) > 200:
+                        if len((r := repr(val))) > 200:
                             r = f"{r[:200]}... ({len(r)} chars)"
                         raise Exception(
                             f"value {r!r} at record {i} column0 {j} failed insert: {e}"
@@ -201,8 +199,7 @@ class Faker:
     def choose_schema(self, ncols=20):
         schema: list[tuple[type, ...] | type] = []
         while len(schema) < ncols:
-            s = self.make_schema(choice(self.types))
-            if s is not None:
+            if (s := self.make_schema(choice(self.types))) is not None:
                 schema.append(s)
         self.schema = schema
         return schema
@@ -273,8 +270,7 @@ class Faker:
         except KeyError:
             pass
 
-        meth = self._get_method("make", cls)
-        if meth:
+        if meth := self._get_method("make", cls):
             self._makers[cls] = meth
             return meth
         else:
@@ -292,9 +288,7 @@ class Faker:
 
         parts = name.split(".")
         for i in range(len(parts) - 1, -1, -1):
-            mname = f"{prefix}_{'_'.join(parts[-(i + 1) :])}"
-            meth = getattr(self, mname, None)
-            if meth:
+            if meth := getattr(self, f"{prefix}_{'_'.join(parts[-(i + 1):])}", None):
                 return meth
 
         return None
@@ -306,8 +300,7 @@ class Faker:
     def example(self, spec):
         # A good representative of the object - no degenerate case
         cls = spec if isinstance(spec, type) else spec[0]
-        meth = self._get_method("example", cls)
-        if meth:
+        if meth := self._get_method("example", cls):
             return meth(spec)
         else:
             return self.make(spec)
@@ -422,11 +415,10 @@ class Faker:
     def match_float(self, spec, got, want, rel=None):
         if got is not None and isnan(got):
             assert isnan(want)
+        elif rel or self._server_rounds():
+            assert got == pytest.approx(want, rel=rel)
         else:
-            if rel or self._server_rounds():
-                assert got == pytest.approx(want, rel=rel)
-            else:
-                assert got == want
+            assert got == want
 
     def _server_rounds(self):
         """Return True if the connected server perform float rounding"""
@@ -509,21 +501,16 @@ class Faker:
 
     def schema_list(self, cls):
         while True:
-            scls = choice(self.types)
-            if scls is cls:
-                continue
-            if scls is float:
-                # TODO: float lists are currently adapted as decimal.
-                # There may be rounding errors or problems with inf.
+            # TODO: float lists are currently adapted as decimal.
+            # There may be rounding errors or problems with inf.
+            if (scls := choice(self.types)) is cls or scls is float:
                 continue
 
             # CRDB doesn't support arrays of json
             # https://github.com/cockroachdb/cockroach/issues/23468
             if self.conn.info.vendor == "CockroachDB" and scls in (Json, Jsonb):
                 continue
-
-            schema = self.make_schema(scls)
-            if schema is not None:
+            if (schema := self.make_schema(scls)) is not None:
                 break
 
         return (cls, schema)
@@ -583,8 +570,7 @@ class Faker:
 
         out: list[Range[Any]] = []
         for i in range(length):
-            r = self.make_Range((Range, spec[1]), **kwargs)
-            if r.isempty:
+            if (r := self.make_Range((Range, spec[1]), **kwargs)).isempty:
                 continue
             for r2 in out:
                 if overlap(r, r2):
@@ -839,8 +825,7 @@ class Faker:
         rec_types = [list, dict]
         scal_types = [type(None), int, JsonFloat, bool, str]
         if random() < container_chance:
-            cls = choice(rec_types)
-            if cls is list:
+            if (cls := choice(rec_types)) is list:
                 return [
                     self._make_json(container_chance=container_chance / 2.0)
                     for i in range(randrange(self.json_max_length))
index 1cff7e18bd6f146a95b66f6f5708da94fab1e4fb..9fbf3ce82f4342ac31bb2a9223a3b6ac16a54765 100644 (file)
@@ -41,8 +41,8 @@ def pytest_configure(config):
 def pytest_runtest_setup(item):
     for m in item.iter_markers(name="libpq"):
         assert len(m.args) == 1
-        msg = check_libpq_version(pq.version(), m.args[0])
-        if msg:
+
+        if msg := check_libpq_version(pq.version(), m.args[0]):
             pytest.skip(msg)
 
 
@@ -81,8 +81,7 @@ def setpgenv(monkeypatch):
 
 @pytest.fixture
 def trace(libpq):
-    pqver = pq.__build_version__
-    if pqver < 140000:
+    if (pqver := pq.__build_version__) < 140000:
         pytest.skip(f"trace not available on libpq {pqver}")
     if sys.platform != "linux":
         pytest.skip(f"trace not available on {sys.platform}")
index dbc03a9f4231809ac3fe12be7a96a548862b697b..ba8575e994430575e9afa0316f176852f21732c7 100644 (file)
@@ -73,8 +73,8 @@ class Proxy:
             return
 
         logging.info("starting proxy")
-        pproxy = which("pproxy")
-        if not pproxy:
+
+        if not (pproxy := which("pproxy")):
             raise ValueError("pproxy program not found")
         cmdline = [pproxy, "--reuse"]
         cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
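
The parentheses in this hunk are load-bearing: ":=" binds more loosely than "not", and its target must be a bare name, so the unparenthesized spelling is a SyntaxError. A sketch (whether pproxy is installed does not matter; both branches print):

    from shutil import which

    # Valid: parenthesize the named expression, then negate its value.
    if not (prog := which("pproxy")):
        print("pproxy program not found")
    else:
        print("found at", prog)

    # Invalid: "not prog" is not an assignment target.
    # if not prog := which("pproxy"):   # SyntaxError
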
index aff6ccdf9fae8966f0fddd46883c0145d88cb715..404239fb4824280a6fa0b1215c9f9b7f7039c3e9 100644 (file)
@@ -24,8 +24,7 @@ def test_send_query(pgconn):
     # send loop
     waited_on_send = 0
     while True:
-        f = pgconn.flush()
-        if f == 0:
+        if pgconn.flush() == 0:
             break
 
         waited_on_send += 1
@@ -48,8 +47,8 @@ def test_send_query(pgconn):
         if pgconn.is_busy():
             select([pgconn.socket], [], [])
             continue
-        res = pgconn.get_result()
-        if res is None:
+
+        if (res := pgconn.get_result()) is None:
             break
         assert res.status == pq.ExecStatus.TUPLES_OK
         results.append(res)
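
Draining results until get_result() returns None is the textbook walrus loop. The test keeps while True because is_busy()/select() run between reads; when nothing happens between iterations, the binding can move straight into the while header. A sketch with a list iterator standing in for the libpq calls:

    # Stand-in for pgconn.get_result(): results, then None when drained.
    source = iter(["TUPLES_OK", "TUPLES_OK", None])

    def get_result():
        return next(source)

    results = []
    while (res := get_result()) is not None:
        results.append(res)
    print(results)  # ['TUPLES_OK', 'TUPLES_OK']
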
index 02953b4de6b0d8ab52f13ca1daf3c25e05351e7f..657cab9b8a541946e93362e59d56aa50253e5aed 100644 (file)
@@ -30,8 +30,8 @@ def wait(
     poll = getattr(conn, poll_method)
     while True:
         assert conn.status != pq.ConnStatus.BAD, conn.error_message
-        rv = poll()
-        if rv == return_on:
+
+        if (rv := poll()) == return_on:
             return
         elif rv == pq.PollingStatus.READING:
             select([conn.socket], [], [], timeout)
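
Note the parentheses in (rv := poll()) == return_on: ":=" has the lowest precedence of any Python operator, so without them rv would be bound to the boolean result of the comparison rather than to poll()'s return value. Both spellings parse; they mean different things:

    def poll():
        return 1  # stand-in for a PollingStatus value

    READING = 1

    if (rv := poll()) == READING:
        print(rv)  # 1: rv holds poll()'s return value

    if rv := poll() == READING:
        print(rv)  # True: rv holds the comparison result
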
@@ -362,11 +362,10 @@ def test_used_password(pgconn, dsn, monkeypatch):
-    # so it may be that has_password is false but still a password was
+    # so it may be that no password is configured but still a password was
     # requested by the server and passed by libpq.
     info = pq.Conninfo.parse(dsn.encode())
-    has_password = (
-        "PGPASSWORD" in os.environ
-        or [i for i in info if i.keyword == b"password"][0].val is not None
-    )
-    if has_password:
+
+    if "PGPASSWORD" in os.environ:
+        assert pgconn.used_password
+    if [i for i in info if i.keyword == b"password"][0].val is not None:
         assert pgconn.used_password
 
     pgconn.finish()
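
This is the one hunk in the commit that is not a walrus rewrite: an "A or B" guard over a single assertion becomes two independent if/assert pairs. The forms are logically equivalent ("(A or B) implies P" holds exactly when "A implies P" and "B implies P" both hold), and the split version shows which condition tripped when the assertion fails. An exhaustive check of the equivalence:

    from itertools import product

    # "if a or b: assert p" fails in exactly the same cases as
    # "if a: assert p" followed by "if b: assert p".
    for a, b, p in product([False, True], repeat=3):
        combined_fails = (a or b) and not p
        split_fails = (a and not p) or (b and not p)
        assert combined_fails == split_fails
    print("equivalent on all 8 combinations")
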
index 7303fa9679e2d57bf90a0973d2d85b00bb11a1d9..8157a329368d45613cfe5938be84290e07261438 100755 (executable)
@@ -136,9 +136,7 @@ def parse_cmdline() -> Namespace:
         default=logging.INFO,
     )
 
-    args = parser.parse_args()
-
-    if args.writer:
+    if (args := parser.parse_args()).writer:
         try:
             getattr(psycopg.copy, args.writer)
         except AttributeError:
index 5af239dddc9613fdf4509e4be9bbe38a2a174214..8b26dc72e2f7ae7b02946f3d80950c14d12d49d5 100644 (file)
@@ -93,17 +93,13 @@ class LoggingPGconn:
         self._logger.info("sent prepared '%s' with %s", name.decode(), param_values)
 
     def send_prepare(
-        self,
-        name: bytes,
-        command: bytes,
-        param_types: Sequence[int] | None = None,
+        self, name: bytes, command: bytes, param_types: Sequence[int] | None = None
     ) -> None:
         self._pgconn.send_prepare(name, command, param_types)
         self._logger.info("prepare %s as '%s'", command.decode(), name.decode())
 
     def get_result(self) -> pq.abc.PGresult | None:
-        r = self._pgconn.get_result()
-        if r is not None:
+        if (r := self._pgconn.get_result()) is not None:
             self._logger.info("got %s result", pq.ExecStatus(r.status).name)
         return r
 
@@ -190,8 +186,7 @@ def pipeline_demo_pq(rows_to_send: int, logger: logging.Logger) -> None:
     ):
         while results_queue:
             fetched = waiting.wait(
-                pipeline_communicate(pgconn, commands),
-                pgconn.socket,
+                pipeline_communicate(pgconn, commands), pgconn.socket
             )
             assert not commands, commands
             for results in fetched:
@@ -213,8 +208,7 @@ async def pipeline_demo_pq_async(rows_to_send: int, logger: logging.Logger) -> None:
     ):
         while results_queue:
             fetched = await waiting.wait_async(
-                pipeline_communicate(pgconn, commands),
-                pgconn.socket,
+                pipeline_communicate(pgconn, commands), pgconn.socket
             )
             assert not commands, commands
             for results in fetched:
index 334433e57660fb9ce891a9bbbdd891958f328014..82d274fd8eff0ca5956e4497dad6a3317ad43070 100644 (file)
@@ -21,8 +21,7 @@ from psycopg.rows import Row
 
 
 def main() -> None:
-    opt = parse_cmdline()
-    if opt.loglevel:
+    if (opt := parse_cmdline()).loglevel:
         loglevel = getattr(logging, opt.loglevel.upper())
         logging.basicConfig(
             level=loglevel, format="%(asctime)s %(levelname)s %(message)s"
@@ -110,8 +109,8 @@ class DelayedConnection(psycopg.Connection[Row]):
         t0 = time.time()
         conn = super().connect(conninfo, **kwargs)
         t1 = time.time()
-        wait = max(0.0, conn_delay - (t1 - t0))
-        if wait:
+
+        if wait := max(0.0, conn_delay - (t1 - t0)):
             time.sleep(wait)
         return conn
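
The guard "if wait := max(0.0, ...)" leans on 0.0 being falsy, exactly as the two-line original did; note that the walrus binds the name whether or not the branch is taken. A sketch:

    import time

    def maybe_sleep(conn_delay, elapsed):
        # 0.0 is falsy: a connection that already took conn_delay or
        # longer skips the sleep. "wait" is bound either way.
        if wait := max(0.0, conn_delay - elapsed):
            time.sleep(wait)
        return wait

    print(maybe_sleep(0.5, 0.2))  # sleeps ~0.3s, then prints ~0.3
    print(maybe_sleep(0.5, 0.9))  # no sleep, prints 0.0
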
 
index 8115ceb221313f0911de3ad286e4800965aa0475..764375715ba022c36b698ef50bc8bb63b87702b7 100644 (file)
@@ -51,13 +51,7 @@ def test_quote(data, result):
     assert dumper.quote(data) == result
 
 
-@pytest.mark.parametrize(
-    "data, result",
-    [
-        ("hello", b"'hello'"),
-        ("", b"NULL"),
-    ],
-)
+@pytest.mark.parametrize("data, result", [("hello", b"'hello'"), ("", b"NULL")])
 def test_quote_none(data, result, global_adapters):
     psycopg.adapters.register_dumper(str, StrNoneDumper)
     t = Transformer()
@@ -424,8 +418,8 @@ def test_optimised_adapters():
     for n in dir(_psycopg):
         if n.startswith("_") or n in ("CDumper", "CLoader"):
             continue
-        obj = getattr(_psycopg, n)
-        if not isinstance(obj, type):
+
+        if not isinstance((obj := getattr(_psycopg, n)), type):
             continue
         if not issubclass(obj, (_psycopg.CDumper, _psycopg.CLoader)):
             continue
@@ -449,12 +443,10 @@ def test_optimised_adapters():
 
     # Check that every optimised adapter is the optimised version of a Py one
     for n in dir(psycopg.types):
-        mod = getattr(psycopg.types, n)
-        if not isinstance(mod, ModuleType):
+        if not isinstance((mod := getattr(psycopg.types, n)), ModuleType):
             continue
         for n1 in dir(mod):
-            obj = getattr(mod, n1)
-            if not isinstance(obj, type):
+            if not isinstance((obj := getattr(mod, n1)), type):
                 continue
             if not issubclass(obj, (Dumper, Loader)):
                 continue
index 28889333a9637fd808b071fa25c5cc6baad245ec..030965580fc244837c806ced6d8c8f989add96e0 100644 (file)
@@ -323,8 +323,7 @@ with psycopg.connect({dsn!r}, application_name={APPNAME!r}) as conn:
     def run_process():
         nonlocal proc
         proc = sp.Popen(
-            [sys.executable, "-s", "-c", script],
-            creationflags=creationflags,
+            [sys.executable, "-s", "-c", script], creationflags=creationflags
         )
         proc.communicate()
 
@@ -335,8 +334,8 @@ with psycopg.connect({dsn!r}, application_name={APPNAME!r}) as conn:
         cur = conn.execute(
             "select pid from pg_stat_activity where application_name = %s", (APPNAME,)
         )
-        rec = cur.fetchone()
-        if rec:
+
+        if rec := cur.fetchone():
             pid = rec[0]
             break
         time.sleep(0.1)
index 65bc2539b71e55644128fa904bce042b9e386ee9..63b7b3f81d06cd064767ceaf21db792a570619f0 100644 (file)
@@ -248,8 +248,8 @@ asyncio.run(main())
         cur = conn.execute(
             "select pid from pg_stat_activity where application_name = %s", (APPNAME,)
         )
-        rec = cur.fetchone()
-        if rec:
+
+        if rec := cur.fetchone():
             pid = rec[0]
             break
         time.sleep(0.1)
index c2e6598fcf6d58ec91acabc29a86150358e0ac37..6ae5d8e05839deae8348e4cf759537b005efdc68 100644 (file)
@@ -145,8 +145,7 @@ def test_copy_out_allchars(conn, format):
     with cur.copy(query) as copy:
         copy.set_types(["text"])
         while True:
-            row = copy.read_row()
-            if not row:
+            if not (row := copy.read_row()):
                 break
             assert len(row) == 1
             rows.append(row[0])
@@ -160,8 +159,7 @@ def test_read_row_notypes(conn, format):
     with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         rows = []
         while True:
-            row = copy.read_row()
-            if not row:
+            if not (row := copy.read_row()):
                 break
             rows.append(row)
 
@@ -736,15 +734,13 @@ def test_copy_to_leaks(conn_cls, dsn, faker, fmt, set_types, method, gc):
 
                     if method == "read":
                         while True:
-                            tmp = copy.read()
-                            if not tmp:
+                            if not copy.read():
                                 break
                     elif method == "iter":
                         list(copy)
                     elif method == "row":
                         while True:
-                            tmp = copy.read_row()
-                            if tmp is None:
+                            if copy.read_row() is None:
                                 break
                     elif method == "rows":
                         list(copy.rows())
@@ -864,8 +860,7 @@ class DataGenerator:
     def blocks(self):
         f = self.file()
         while True:
-            block = f.read(self.block_size)
-            if not block:
+            if not (block := f.read(self.block_size)):
                 break
             yield block
 
@@ -880,8 +875,7 @@ class DataGenerator:
     def sha(self, f):
         m = hashlib.sha256()
         while True:
-            block = f.read()
-            if not block:
+            if not (block := f.read()):
                 break
             if isinstance(block, str):
                 block = block.encode()
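
These copy loops keep their while True / break shells because the commit only rewrites assignments followed by an if, but the pure read-until-EOF pattern in blocks() and sha() can go one step further and hoist the binding into the while header. A self-contained sketch, assuming a file-like object whose read() returns an empty value at EOF:

    import hashlib
    import io

    def sha(f):
        m = hashlib.sha256()
        # One step beyond the hunk above: bind and test in the loop header.
        while block := f.read(8192):
            if isinstance(block, str):
                block = block.encode()
            m.update(block)
        return m.hexdigest()

    print(sha(io.BytesIO(b"hello world")))
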
index 1168bc31b3b478af1c11eb52a5b28babc235d9db..72a838e5e23654fb7b337e578862bb5e686fdbb7 100644 (file)
@@ -149,8 +149,7 @@ async def test_copy_out_allchars(aconn, format):
     async with cur.copy(query) as copy:
         copy.set_types(["text"])
         while True:
-            row = await copy.read_row()
-            if not row:
+            if not (row := (await copy.read_row())):
                 break
             assert len(row) == 1
             rows.append(row[0])
@@ -166,8 +165,7 @@ async def test_read_row_notypes(aconn, format):
     ) as copy:
         rows = []
         while True:
-            row = await copy.read_row()
-            if not row:
+            if not (row := (await copy.read_row())):
                 break
             rows.append(row)
 
@@ -752,15 +750,13 @@ async def test_copy_to_leaks(aconn_cls, dsn, faker, fmt, set_types, method, gc):
 
                     if method == "read":
                         while True:
-                            tmp = await copy.read()
-                            if not tmp:
+                            if not (await copy.read()):
                                 break
                     elif method == "iter":
                         await alist(copy)
                     elif method == "row":
                         while True:
-                            tmp = await copy.read_row()
-                            if tmp is None:
+                            if (await copy.read_row()) is None:
                                 break
                     elif method == "rows":
                         await alist(copy.rows())
@@ -879,8 +875,7 @@ class DataGenerator:
     def blocks(self):
         f = self.file()
         while True:
-            block = f.read(self.block_size)
-            if not block:
+            if not (block := f.read(self.block_size)):
                 break
             yield block
 
@@ -895,8 +890,7 @@ class DataGenerator:
     def sha(self, f):
         m = hashlib.sha256()
         while True:
-            block = f.read()
-            if not block:
+            if not (block := f.read()):
                 break
             if isinstance(block, str):
                 block = block.encode()
index e31791e96d3612444520e229f00b55ddafc189dd..cf5551ca0f95551d76c2172be5960d0802e73c73 100644 (file)
@@ -86,13 +86,11 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
 
                 if fetch == "one":
                     while True:
-                        tmp = cur.fetchone()
-                        if tmp is None:
+                        if cur.fetchone() is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = cur.fetchmany(3)
-                        if not tmp:
+                        if not cur.fetchmany(3):
                             break
                 elif fetch == "all":
                     cur.fetchall()
index f54b31aa75d03b5bea5305d82242b9f6c0abbdf6..fa6b70175da19a23b72a6feb6e5efc7c4b4be417 100644 (file)
@@ -87,13 +87,11 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc)
 
                 if fetch == "one":
                     while True:
-                        tmp = await cur.fetchone()
-                        if tmp is None:
+                        if (await cur.fetchone()) is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = await cur.fetchmany(3)
-                        if not tmp:
+                        if not (await cur.fetchmany(3)):
                             break
                 elif fetch == "all":
                     await cur.fetchall()
index 70a4d55db4ea3fe9d2aa5f8139f7b1eba1db9943..fd87cf551979989f6675c7b8620b682dbde63d33 100644 (file)
@@ -96,13 +96,11 @@ def test_leak(conn_cls, dsn, faker, fetch, row_factory, gc):
 
                 if fetch == "one":
                     while True:
-                        tmp = cur.fetchone()
-                        if tmp is None:
+                        if cur.fetchone() is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = cur.fetchmany(3)
-                        if not tmp:
+                        if not cur.fetchmany(3):
                             break
                 elif fetch == "all":
                     cur.fetchall()
index 753c954348ce4a77ff35087843e703b3a0cdccf6..a580f0d6662326fb4a7a2e293dc4fd77424e8cb7 100644 (file)
@@ -97,13 +97,11 @@ async def test_leak(aconn_cls, dsn, faker, fetch, row_factory, gc):
 
                 if fetch == "one":
                     while True:
-                        tmp = await cur.fetchone()
-                        if tmp is None:
+                        if (await cur.fetchone()) is None:
                             break
                 elif fetch == "many":
                     while True:
-                        tmp = await cur.fetchmany(3)
-                        if not tmp:
+                        if not (await cur.fetchmany(3)):
                             break
                 elif fetch == "all":
                     await cur.fetchall()
index 64ef05125a4eb0350deb8ee3def96b635fa8c4c5..0f2155e3ff6d36b9a230fdc4cc1c49b95d757021 100644 (file)
@@ -94,13 +94,11 @@ def test_leak(conn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc):
 
                     if fetch == "one":
                         while True:
-                            tmp = cur.fetchone()
-                            if tmp is None:
+                            if cur.fetchone() is None:
                                 break
                     elif fetch == "many":
                         while True:
-                            tmp = cur.fetchmany(3)
-                            if not tmp:
+                            if not cur.fetchmany(3):
                                 break
                     elif fetch == "all":
                         cur.fetchall()
index 42a241e4a8c05dba38692c1daaa1cffee57aabb1..f2f9553ad7576f09d8949b47f1577ca6373c3da9 100644 (file)
@@ -91,13 +91,11 @@ async def test_leak(aconn_cls, dsn, faker, fmt, fmt_out, fetch, row_factory, gc)
 
                     if fetch == "one":
                         while True:
-                            tmp = await cur.fetchone()
-                            if tmp is None:
+                            if (await cur.fetchone()) is None:
                                 break
                     elif fetch == "many":
                         while True:
-                            tmp = await cur.fetchmany(3)
-                            if not tmp:
+                            if not (await cur.fetchmany(3)):
                                 break
                     elif fetch == "all":
                         await cur.fetchall()
index 25df3a9c2b485573f6733738cd79c7232483ad22..365f0b76235e5c31b18c123cb7edd94b18ee0bec 100644 (file)
@@ -385,8 +385,7 @@ def test_row_factory(conn):
     recs = cur.fetchall()
     cur.scroll(0, "absolute")
     while True:
-        rec = cur.fetchone()
-        if not rec:
+        if not (rec := cur.fetchone()):
             break
         recs.append(rec)
     assert recs == [[1, -1], [1, -2], [1, -3]] * 2
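
fetchone() loops have a second modern spelling besides the walrus: two-argument iter(callable, sentinel) calls the function until it returns the sentinel. That form is equivalent to the hunk above whenever the loop stops exactly at None, which is what fetchone() returns once the result set is exhausted. A sketch with a list standing in for the cursor:

    # Stand-in for cur.fetchone(): rows, then None when exhausted.
    rows = iter([[1, -1], [1, -2], [1, -3], None])

    def fetchone():
        return next(rows)

    recs = []
    for rec in iter(fetchone, None):  # stops when fetchone() returns None
        recs.append(rec)
    print(recs)  # [[1, -1], [1, -2], [1, -3]]
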
index febd01be190ef01044510abf92051e7a8da6f09b..cede008f76162cb1e0bcbb8da108fc34f84e74fa 100644 (file)
@@ -391,8 +391,7 @@ async def test_row_factory(aconn):
     recs = await cur.fetchall()
     await cur.scroll(0, "absolute")
     while True:
-        rec = await cur.fetchone()
-        if not rec:
+        if not (rec := (await cur.fetchone())):
             break
         recs.append(rec)
     assert recs == [[1, -1], [1, -2], [1, -3]] * 2
index 2975af43ab3d06a32ffc6b43631e83bd207c3215..80c8ae0df8c89f2a41e7c2568725fd16734b6966 100644 (file)
@@ -14,8 +14,8 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch):
     OperationalError has a pgconn attribute set with needs_password.
     """
     gen = generators.connect(dsn)
-    pgconn = waiting.wait_conn(gen)
-    if not pgconn.used_password:
+
+    if not (pgconn := waiting.wait_conn(gen)).used_password:
         pytest.skip("test connection needs no password")
 
     with monkeypatch.context() as m:
index 93240b5eba92c6c8ab830befbf226a9cf3dd6661..d60bda16fa31c712500d3ae34f84319b54cbcb2c 100644 (file)
@@ -113,8 +113,7 @@ def test_scalar_row(conn):
 
 
 @pytest.mark.parametrize(
-    "factory",
-    "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split(),
+    "factory", "tuple_row dict_row namedtuple_row class_row args_row kwargs_row".split()
 )
 def test_no_result(factory, conn):
     cur = conn.cursor(row_factory=factory_from_name(factory))
@@ -151,8 +150,7 @@ def test_no_column_class_row(conn):
 
 
 def factory_from_name(name):
-    factory = getattr(rows, name)
-    if factory is rows.class_row:
+    if (factory := getattr(rows, name)) is rows.class_row:
         factory = factory(Person)
     if factory is rows.args_row:
         factory = factory(argf)
index 90ba2f09d2223756d236caf78669a7c19378cb55..85c172f07c6b82e1795f562a590f8b154bb81e27 100644 (file)
@@ -11,8 +11,8 @@ pytestmark = pytest.mark.crdb_skip("2-phase commit")
 
 def test_tpc_disabled(conn, pipeline):
     cur = conn.execute("show max_prepared_transactions")
-    val = int(cur.fetchone()[0])
-    if val:
+
+    if int(cur.fetchone()[0]):
         pytest.skip("prepared transactions enabled")
 
     conn.rollback()
index 8a448c7147ed1cab9862b2c61be81bfb5b5dd006..b3d1ed2d3a3a42a2ec61b9239b5b2a291be127ab 100644 (file)
@@ -8,8 +8,8 @@ pytestmark = pytest.mark.crdb_skip("2-phase commit")
 
 async def test_tpc_disabled(aconn, apipeline):
     cur = await aconn.execute("show max_prepared_transactions")
-    val = int((await cur.fetchone())[0])
-    if val:
+
+    if int((await cur.fetchone())[0]):
         pytest.skip("prepared transactions enabled")
 
     await aconn.rollback()
index 94a704eb55945c68d2356ca27130f6a03613e7c6..3df6fb5102d38cfd5cab560192c3a8776fe6d2bc 100644 (file)
@@ -22,8 +22,7 @@ def test_fetch(conn, name, status, encoding):
         conn.execute("select set_config('client_encoding', %s, false)", [encoding])
 
     if status:
-        status = getattr(TransactionStatus, status)
-        if status == TransactionStatus.INTRANS:
+        if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
             conn.execute("select 1")
     else:
         conn.autocommit = True
@@ -54,8 +53,7 @@ async def test_fetch_async(aconn, name, status, encoding):
         )
 
     if status:
-        status = getattr(TransactionStatus, status)
-        if status == TransactionStatus.INTRANS:
+        if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
             await aconn.execute("select 1")
     else:
         await aconn.set_autocommit(True)
@@ -100,8 +98,8 @@ def test_fetch_not_found(conn, name, status, info_cls, monkeypatch):
             return exit_orig(self, exc_type, exc_val, exc_tb)
 
         monkeypatch.setattr(psycopg.Transaction, "__exit__", exit)
-    status = getattr(TransactionStatus, status)
-    if status == TransactionStatus.INTRANS:
+
+    if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
         conn.execute("select 1")
 
     assert conn.info.transaction_status == status
@@ -122,8 +120,8 @@ async def test_fetch_not_found_async(aconn, name, status, info_cls, monkeypatch)
             return await exit_orig(self, exc_type, exc_val, exc_tb)
 
         monkeypatch.setattr(psycopg.AsyncTransaction, "__aexit__", aexit)
-    status = getattr(TransactionStatus, status)
-    if status == TransactionStatus.INTRANS:
+
+    if (status := getattr(TransactionStatus, status)) == TransactionStatus.INTRANS:
         await aconn.execute("select 1")
 
     assert aconn.info.transaction_status == status
index fdb9c888f2b82b9505cd2bc03d43a21673c39ae3..57d5b40344fa4deeea3470d2a975585be2adc342 100644 (file)
@@ -204,8 +204,7 @@ def test_numbers_array(num, type, fmt_in):
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
-    wrapper = getattr(psycopg.types.numeric, wrapper)
-    if wrapper is Decimal:
+    if (wrapper := getattr(psycopg.types.numeric, wrapper)) is Decimal:
         want_cls = Decimal
     else:
         assert wrapper.__mro__[1] in (int, float)
index da5fe9066bbafd3a28d5ca00178e6953b45b2e00..84f70b73711a704a6fdcb4d4ff3113f52c1ac877 100644 (file)
@@ -458,13 +458,7 @@ def test_dump_numeric_exhaustive(conn, fmt_in):
 
 
 @pytest.mark.pg(">= 14")
-@pytest.mark.parametrize(
-    "val, expr",
-    [
-        ("inf", "Infinity"),
-        ("-inf", "-Infinity"),
-    ],
-)
+@pytest.mark.parametrize("val, expr", [("inf", "Infinity"), ("-inf", "-Infinity")])
 def test_dump_numeric_binary_inf(conn, val, expr):
     cur = conn.cursor()
     val = Decimal(val)
@@ -492,8 +486,8 @@ def test_dump_numeric_binary_inf(conn, val, expr):
 def test_load_numeric_binary(conn, expr):
     cur = conn.cursor(binary=1)
     res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
-    val = Decimal(expr)
-    if val.is_nan():
+
+    if (val := Decimal(expr)).is_nan():
         assert res.is_nan()
     else:
         assert res == val
@@ -531,13 +525,7 @@ def test_load_numeric_exhaustive(conn, fmt_out):
 
 
 @pytest.mark.pg(">= 14")
-@pytest.mark.parametrize(
-    "val, expr",
-    [
-        ("inf", "Infinity"),
-        ("-inf", "-Infinity"),
-    ],
-)
+@pytest.mark.parametrize("val, expr", [("inf", "Infinity"), ("-inf", "-Infinity")])
 def test_load_numeric_binary_inf(conn, val, expr):
     cur = conn.cursor(binary=1)
     res = cur.execute(f"select '{expr}'::numeric").fetchone()[0]
@@ -546,14 +534,7 @@ def test_load_numeric_binary_inf(conn, val, expr):
 
 
 @pytest.mark.parametrize(
-    "val",
-    [
-        "0",
-        "0.0",
-        "0.000000000000000000001",
-        "-0.000000000000000000001",
-        "nan",
-    ],
+    "val", ["0", "0.0", "0.000000000000000000001", "-0.000000000000000000001", "nan"]
 )
 def test_numeric_as_float(conn, val):
     cur = conn.cursor()
index 23ce25f357a9906e46cabed82b2e10c6b9594181..2561be575e12ba6336b84ac968cfeb5bd2903c1c 100644 (file)
@@ -19,8 +19,7 @@ skip_numpy2 = pytest.mark.skipif(
 
 
 def _get_arch_size() -> int:
-    psize = struct.calcsize("P") * 8
-    if psize not in (32, 64):
+    if (psize := (struct.calcsize("P") * 8)) not in (32, 64):
         msg = f"the pointer size {psize} is unusual"
         raise ValueError(msg)
     return psize
@@ -198,9 +197,7 @@ def test_copy_by_oid(conn, val, nptype, pgtypes, fmt):
 
     fnames = [f"f{t}" for t in pgtypes]
     fields = [f"f{t} {t}" for fname, t in zip(fnames, pgtypes)]
-    cur.execute(
-        f"create table numpyoid (id serial primary key, {', '.join(fields)})",
-    )
+    cur.execute(f"create table numpyoid (id serial primary key, {', '.join(fields)})")
     with cur.copy(
         f"copy numpyoid ({', '.join(fnames)}) from stdin (format {fmt.name})"
     ) as copy:
index d714f5095d88a8728c5455bdb95bd0e98993e4e2..b3dc18695f6a6217854db4a06bf77be9af9e6d17 100755 (executable)
@@ -66,14 +66,12 @@ logger = logging.getLogger()
 
 
 def main() -> int:
-    opt = parse_cmdline()
-    if opt.container:
+    if (opt := parse_cmdline()).container:
         return run_in_container(opt.container)
 
     logging.basicConfig(level=opt.log_level, format="%(levelname)s %(message)s")
 
-    current_ver = ".".join(map(str, sys.version_info[:2]))
-    if current_ver != PYVER:
+    if (current_ver := ".".join(map(str, sys.version_info[:2]))) != PYVER:
         logger.warning(
             "Expecting output generated by Python %s; you are running %s instead.",
             PYVER,
@@ -488,8 +486,7 @@ class BlanksInserter(ast.NodeTransformer):  # type: ignore
         new_body.append(before)
         for i in range(1, len(body)):
             after = body[i]
-            nblanks = after.lineno - before.end_lineno - 1
-            if nblanks > 0:
+            if after.lineno - before.end_lineno - 1 > 0:
                 # Inserting one blank is enough.
                 blank = ast.Comment(
                     value="",
@@ -617,8 +614,7 @@ def parse_cmdline() -> Namespace:
         help="the files to process (process all files if not specified)",
     )
 
-    opt = parser.parse_args()
-    if not opt.inputs:
+    if not (opt := parser.parse_args()).inputs:
         opt.inputs = [PROJECT_DIR / Path(fn) for fn in ALL_INPUTS]
 
     fp: Path
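
Here attribute access chains directly onto the named expression: the parenthesized walrus is just a value, while opt stays bound to the whole namespace for the rest of the function. A sketch with SimpleNamespace standing in for argparse's result object (the default list is hypothetical):

    from types import SimpleNamespace

    def parse_args():
        # Stand-in for ArgumentParser.parse_args().
        return SimpleNamespace(inputs=[], container=None)

    # Attribute access applies to the expression's value; opt itself
    # stays bound to the whole namespace afterwards.
    if not (opt := parse_args()).inputs:
        opt.inputs = ["all_inputs.py"]  # hypothetical default
    print(opt.inputs, opt.container)
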
index 4834d6b8045721ab868c6892513856bf25eb6b0f..f4c1e8d24a72fdd67ea5e874c0086cd855899aae 100755 (executable)
@@ -11,17 +11,16 @@ from pathlib import Path
 
 curdir = Path(__file__).parent
 pdir = curdir / "../.."
-target = pdir / "psycopg_binary"
 
-if target.exists():
+if (target := (pdir / "psycopg_binary")).exists():
     raise Exception(f"path {target} already exists")
 
 
 def sed_i(pattern: str, repl: str, filename: str | Path) -> None:
     with open(filename, "rb") as f:
         data = f.read()
-    newdata = re.sub(pattern.encode("utf8"), repl.encode("utf8"), data)
-    if newdata != data:
+
+    if (newdata := re.sub(pattern.encode("utf8"), repl.encode("utf8"), data)) != data:
         with open(filename, "wb") as f:
             f.write(newdata)
 
index d1f030fe30cbaa4032f5f45d428f0b163aaa68bb..9fbc3aba0d8b596d6875570cc13f56a99f5898d9 100755 (executable)
@@ -168,8 +168,7 @@ chore: bump {self.package.name} package version to {self.want_version}
         matches = []
         with fp.open() as f:
             for line in f:
-                m = self._ini_regex.match(line)
-                if m:
+                if m := self._ini_regex.match(line):
                     matches.append(m)
 
         if not matches:
@@ -240,8 +239,7 @@ chore: bump {self.package.name} package version to {self.want_version}
         with fp.open() as f:
             lines = f.readlines()
 
-        lns = self._find_lines(r"^[^\s]+ " + re.escape(str(version)), lines)
-        if not lns:
+        if not (lns := self._find_lines("^[^\\s]+ " + re.escape(str(version)), lines)):
             logger.warning("no change log line found")
             return []
 
index a33c5c4efb08496bded2733586de9b1b327d4521..fa7e94f3aba924597315c40a0fd3fbb1a0df2fe2 100755 (executable)
@@ -33,8 +33,7 @@ def get_user_data(data):
         "name": data["name"],
     }
     if data["blog"]:
-        website = data["blog"]
-        if not website.startswith("http"):
+        if not (website := data["blog"]).startswith("http"):
             website = "http://" + website
 
         out["website"] = website
@@ -54,8 +53,8 @@ def update_entry(opt, filedata, entry):
     # entry is a username or a user entry dict
     if isinstance(entry, str):
         username = entry
-        entry = [e for e in filedata if e["username"] == username]
-        if not entry:
+
+        if not (entry := [e for e in filedata if e["username"] == username]):
             raise Exception(f"{username} not found")
         entry = entry[0]
     else:
index 52ba2cb351016099e8166c1e257b0ab1d4f44da9..c84c48b2e0987d42f1bb580b17e1a1b953328aaf 100755 (executable)
@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-"""Find the error prefixes in various l10n used for precise prefixstripping."""
+"""Find the error prefixes in various l10n used for precise prefix stripping."""
 
 import re
 import logging
@@ -90,8 +90,7 @@ def parse_cmdline() -> Namespace:
         help="the file to change [default: %(default)s]",
     )
 
-    opt = parser.parse_args()
-    if not opt.pgroot.is_dir():
+    if not (opt := parser.parse_args()).pgroot.is_dir():
         parser.error("not a valid directory: {opt.pgroot}")
 
     return opt
index b9406c34e52a618286040cbec0b8b43cb48d82fd..bc1264492ab1157d9ffe7c7b1ddcf03cd5c9c9c2 100755 (executable)
@@ -38,20 +38,18 @@ def parse_errors_txt(url):
     page = urlopen(url)
     for line in page.read().decode("ascii").splitlines():
         # Strip comments and skip blanks
-        line = line.split("#")[0].strip()
-        if not line:
+
+        if not (line := line.split("#")[0].strip()):
             continue
 
         # Parse a section
-        m = re.match(r"Section: (Class (..) - .+)", line)
-        if m:
+        if m := re.match("Section: (Class (..) - .+)", line):
             label, class_ = m.groups()
             classes[class_] = label
             continue
 
         # Parse an error
-        m = re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line)
-        if m:
+        if m := re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line):
             sqlstate, macro, spec = m.groups()
             # skip sqlstates without specs as they are not publicly visible
             if not spec:
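
Cascading regex matches are the motivating example from PEP 572 itself: each "if m := re.match(...)" keeps its match object scoped to its own branch with no pre-assignment. A condensed runnable sketch of the parser above, fed two made-up errors.txt lines:

    import re

    classes, errors = {}, []
    lines = [
        "Section: Class 08 - Connection Exception",
        "08000    E    ERRCODE_CONNECTION_EXCEPTION    connection_exception",
    ]
    for line in lines:
        if m := re.match(r"Section: (Class (..) - .+)", line):
            label, class_ = m.groups()
            classes[class_] = label
        elif m := re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line):
            sqlstate, macro, spec = m.groups()
            errors.append((sqlstate, macro, spec))

    print(classes)  # {'08': 'Class 08 - Connection Exception'}
    print(errors)   # [('08000', 'CONNECTION_EXCEPTION', 'connection_exception')]
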
index 96b0df9ab67cc61183570eea7d5d0ddbbe8c1472..8e087a8a7d887ba11695fde2faebf5c66d64e82a 100755 (executable)
@@ -34,9 +34,8 @@ ROOT = Path(__file__).parent.parent
 
 def main() -> None:
     opt = parse_cmdline()
-    conn = psycopg.connect(opt.dsn, autocommit=True)
 
-    if CrdbConnection.is_crdb(conn):
+    if CrdbConnection.is_crdb((conn := psycopg.connect(opt.dsn, autocommit=True))):
         conn = CrdbConnection.connect(opt.dsn, autocommit=True)
         update_crdb_python_oids(conn)
     else: