git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Restore black max line length to its default
author    Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 30 Jan 2022 16:40:19 +0000 (16:40 +0000)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 30 Jan 2022 16:43:48 +0000 (16:43 +0000)
Trying black --preview shows that it performs several aggressive changes,
made worse by the stricter margin, so relax the margin to the default.
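
For context, black reads its maximum line length from the [tool.black] section
of pyproject.toml. The hunk for that file is not reproduced in this excerpt, so
the snippet below is only an illustrative sketch of the kind of change involved;
the previous value of 79 is an assumed example, not taken from the diff:

    # pyproject.toml -- illustrative sketch, not the actual hunk from this commit
    [tool.black]
    # before: an explicit, stricter margin (79 is an assumed example value)
    # line-length = 79
    # after: the override is dropped, so black falls back to its default of 88 columns

With the override removed, black rewraps calls and messages that previously had
to be split, which is why the hunks below consistently collapse multi-line
expressions back onto a single line.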

112 files changed:
docs/lib/libpq_docs.py
docs/lib/pg3_docs.py
psycopg/psycopg/_adapters_map.py
psycopg/psycopg/_cmodule.py
psycopg/psycopg/_dns.py
psycopg/psycopg/_preparing.py
psycopg/psycopg/_queries.py
psycopg/psycopg/_tpc.py
psycopg/psycopg/_transform.py
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/_tz.py
psycopg/psycopg/abc.py
psycopg/psycopg/adapt.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/conninfo.py
psycopg/psycopg/copy.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/dbapi20.py
psycopg/psycopg/errors.py
psycopg/psycopg/postgres.py
psycopg/psycopg/pq/_pq_ctypes.py
psycopg/psycopg/pq/_pq_ctypes.pyi
psycopg/psycopg/pq/abc.py
psycopg/psycopg/pq/misc.py
psycopg/psycopg/pq/pq_ctypes.py
psycopg/psycopg/rows.py
psycopg/psycopg/server_cursor.py
psycopg/psycopg/sql.py
psycopg/psycopg/transaction.py
psycopg/psycopg/types/array.py
psycopg/psycopg/types/composite.py
psycopg/psycopg/types/datetime.py
psycopg/psycopg/types/hstore.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/numeric.py
psycopg/psycopg/types/range.py
psycopg/psycopg/types/shapely.py
psycopg/psycopg/types/string.py
psycopg/psycopg/waiting.py
psycopg/pyproject.toml
psycopg/tox.ini
psycopg_c/psycopg_c/__init__.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/pyproject.toml
psycopg_c/tox.ini
psycopg_pool/psycopg_pool/base.py
psycopg_pool/psycopg_pool/null_pool.py
psycopg_pool/psycopg_pool/null_pool_async.py
psycopg_pool/psycopg_pool/pool.py
psycopg_pool/psycopg_pool/pool_async.py
psycopg_pool/psycopg_pool/sched.py
psycopg_pool/tox.ini
pyproject.toml
tests/fix_db.py
tests/fix_faker.py
tests/fix_proxy.py
tests/fix_psycopg.py
tests/pool/fix_pool.py
tests/pool/test_null_pool.py
tests/pool/test_null_pool_async.py
tests/pool/test_pool.py
tests/pool/test_pool_async.py
tests/pq/test_async.py
tests/pq/test_copy.py
tests/pq/test_escaping.py
tests/pq/test_exec.py
tests/pq/test_pgconn.py
tests/pq/test_pgresult.py
tests/pq/test_pipeline.py
tests/scripts/dectest.py
tests/scripts/spiketest.py
tests/test_adapt.py
tests/test_concurrency.py
tests/test_concurrency_async.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo.py
tests/test_copy.py
tests/test_copy_async.py
tests/test_cursor.py
tests/test_cursor_async.py
tests/test_dns.py
tests/test_dns_srv.py
tests/test_errors.py
tests/test_prepared.py
tests/test_prepared_async.py
tests/test_psycopg_dbapi20.py
tests/test_server_cursor.py
tests/test_server_cursor_async.py
tests/test_sql.py
tests/test_tpc.py
tests/test_tpc_async.py
tests/test_transaction.py
tests/test_transaction_async.py
tests/test_typing.py
tests/types/test_array.py
tests/types/test_bool.py
tests/types/test_composite.py
tests/types/test_datetime.py
tests/types/test_multirange.py
tests/types/test_net.py
tests/types/test_numeric.py
tests/types/test_range.py
tests/types/test_shapely.py
tests/types/test_string.py
tests/utils.py
tools/update_backer.py
tools/update_errors.py
tools/update_oids.py
tox.ini

index b7d6d57b6321d6bad0feafb8551cae21aa388d09..e8c82caa536732eca8a66c3f82d49382d19b57ee 100644 (file)
@@ -80,9 +80,7 @@ class LibpqParser(HTMLParser):
             func_id=self.varlist_id.upper(),
         )
 
-    _url_pattern = (
-        "https://www.postgresql.org/docs/{version}/{section}.html#{func_id}"
-    )
+    _url_pattern = "https://www.postgresql.org/docs/{version}/{section}.html#{func_id}"
 
 
 class LibpqReader:
index 2ebd54c73e36a04587cca5fe780ece3108497717..4352afbcf070ca4d41f8c79cb7e8faa52ef757ed 100644 (file)
@@ -24,9 +24,7 @@ def before_process_signature(app, obj, bound_method):
             del ann["return"]
 
 
-def process_signature(
-    app, what, name, obj, options, signature, return_annotation
-):
+def process_signature(app, what, name, obj, options, signature, return_annotation):
     pass
 
 
index d6efb8c761cb9dd99c4068dccb8b7548f48e1abc..102b2bf39d283f1ce271d3f359922178d76a3905 100644 (file)
@@ -160,9 +160,7 @@ class AdaptersMap:
 
             self._dumpers_by_oid[dumper.format][dumper.oid] = dumper
 
-    def register_loader(
-        self, oid: Union[int, str], loader: Type["Loader"]
-    ) -> None:
+    def register_loader(self, oid: Union[int, str], loader: Type["Loader"]) -> None:
         """
         Configure the context to use *loader* to convert data of oid *oid*.
 
@@ -176,9 +174,7 @@ class AdaptersMap:
         if isinstance(oid, str):
             oid = self.types[oid].oid
         if not isinstance(oid, int):
-            raise TypeError(
-                f"loaders should be registered on oid, got {oid} instead"
-            )
+            raise TypeError(f"loaders should be registered on oid, got {oid} instead")
 
         if _psycopg:
             loader = self._get_optimised(loader)
@@ -252,9 +248,7 @@ class AdaptersMap:
                 )
             raise e.ProgrammingError(msg)
 
-    def get_loader(
-        self, oid: int, format: pq.Format
-    ) -> Optional[Type["Loader"]]:
+    def get_loader(self, oid: int, format: pq.Format) -> Optional[Type["Loader"]]:
         """
         Return the loader class for the given oid and format.
 
index a553fcc4176bb5e7165724c166aa0ad5c0b751b1..710564cd8514483d9f71923c7d5bff16015227b6 100644 (file)
@@ -21,6 +21,4 @@ elif pq.__impl__ == "binary":
 elif pq.__impl__ == "python":
     _psycopg = None  # type: ignore
 else:
-    raise ImportError(
-        f"can't find _psycopg optimised module in {pq.__impl__!r}"
-    )
+    raise ImportError(f"can't find _psycopg optimised module in {pq.__impl__!r}")
index 0e83d29e2508cbc210187eba9bd6ad917faecd85..e8860f49e5c8d2c62a953ec45844e1d763486a79 100644 (file)
@@ -187,9 +187,7 @@ class Rfc2782Resolver:
     the async paths.
     """
 
-    re_srv_rr = re.compile(
-        r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)"
-    )
+    re_srv_rr = re.compile(r"^(?P<service>_[^\.]+)\.(?P<proto>_[^\.]+)\.(?P<target>.+)")
 
     def resolve(self, params: Dict[str, Any]) -> Dict[str, Any]:
         """Update the parameters host and port after SRV lookup."""
index d7d71d1c0f38848351f852a8f08cde610e8c1c1b..3e409074e80cf0641fb0651fa88048554783a8ab 100644 (file)
@@ -74,9 +74,7 @@ class PrepareManager:
             # The query is not to be prepared yet
             return Prepare.NO, b""
 
-    def _should_discard(
-        self, prep: Prepare, results: Sequence["PGresult"]
-    ) -> bool:
+    def _should_discard(self, prep: Prepare, results: Sequence["PGresult"]) -> bool:
         """Check if we need to discard our entire state: it should happen on
         rollback or on dropping objects, because the same object may get
         recreated and postgres would fail internal lookups.
@@ -86,9 +84,7 @@ class PrepareManager:
                 if result.status != ExecStatus.COMMAND_OK:
                     continue
                 cmdstat = result.command_status
-                if cmdstat and (
-                    cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"
-                ):
+                if cmdstat and (cmdstat.startswith(b"DROP ") or cmdstat == b"ROLLBACK"):
                     return self.clear()
         return False
 
index 2f2ee4238f036b17b01c0580db1d7ac146c282b4..b19c0115967b0e7a9dcedb495b9280f195fbdd2f 100644 (file)
@@ -90,9 +90,7 @@ class PostgresQuery:
         This method updates `params` and `types`.
         """
         if vars is not None:
-            params = _validate_and_reorder_params(
-                self._parts, vars, self._order
-            )
+            params = _validate_and_reorder_params(self._parts, vars, self._order)
             assert self._want_formats is not None
             self.params = self._tx.dump_sequence(params, self._want_formats)
             self.types = self._tx.types or ()
@@ -144,8 +142,7 @@ def _query2pg(
             else:
                 if seen[part.item][1] != part.format:
                     raise e.ProgrammingError(
-                        f"placeholder '{part.item}' cannot have"
-                        f" different formats"
+                        f"placeholder '{part.item}' cannot have different formats"
                     )
                 chunks.append(seen[part.item][0])
 
@@ -184,9 +181,7 @@ def _validate_and_reorder_params(
                 f" {len(vars)} parameters were passed"
             )
         if vars and not isinstance(parts[0].item, int):
-            raise TypeError(
-                "named placeholders require a mapping of parameters"
-            )
+            raise TypeError("named placeholders require a mapping of parameters")
         return vars  # type: ignore[return-value]
 
     else:
@@ -195,9 +190,7 @@ def _validate_and_reorder_params(
                 "positional placeholders (%s) require a sequence of parameters"
             )
         try:
-            return [
-                vars[item] for item in order or ()  # type: ignore[call-overload]
-            ]
+            return [vars[item] for item in order or ()]  # type: ignore[call-overload]
         except KeyError:
             raise e.ProgrammingError(
                 f"query parameter missing:"
index 43e3df62c710f3308062d47a355093ede9c88ae3..35281881ce71f3608cc8d647d2b21edd39d1f61d 100644 (file)
@@ -68,9 +68,7 @@ class Xid:
             if bqual is None:
                 raise TypeError("if format_id is specified, bqual must be too")
             if not 0 <= format_id < 0x80000000:
-                raise ValueError(
-                    "format_id must be a non-negative 32-bit integer"
-                )
+                raise ValueError("format_id must be a non-negative 32-bit integer")
             if len(bqual) > 64:
                 raise ValueError("bqual must be not longer than 64 chars")
             if len(gtrid) > 64:
index b3c30d3b8957b067f8cf38fa2d73485735d864d4..4be169dbffb304f6d37d3dded6f292744625a4bc 100644 (file)
@@ -117,21 +117,13 @@ class Transformer(AdaptContext):
             self.get_loader(result.ftype(i), fmt).load for i in range(nf)
         ]
 
-    def set_dumper_types(
-        self, types: Sequence[int], format: pq.Format
-    ) -> None:
-        self._row_dumpers = [
-            self.get_dumper_by_oid(oid, format) for oid in types
-        ]
+    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
+        self._row_dumpers = [self.get_dumper_by_oid(oid, format) for oid in types]
         self.types = tuple(types)
         self.formats = [format] * len(types)
 
-    def set_loader_types(
-        self, types: Sequence[int], format: pq.Format
-    ) -> None:
-        self._row_loaders = [
-            self.get_loader(oid, format).load for oid in types
-        ]
+    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
+        self._row_loaders = [self.get_loader(oid, format).load for oid in types]
 
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
@@ -214,9 +206,7 @@ class Transformer(AdaptContext):
 
         return dumper
 
-    def load_rows(
-        self, row0: int, row1: int, make_row: RowMaker[Row]
-    ) -> List[Row]:
+    def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]:
         res = self._pgresult
         if not res:
             raise e.InterfaceError("result not set")
@@ -253,9 +243,7 @@ class Transformer(AdaptContext):
 
         return make_row(record)
 
-    def load_sequence(
-        self, record: Sequence[Optional[bytes]]
-    ) -> Tuple[Any, ...]:
+    def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
         if len(self._row_loaders) != len(record):
             raise e.ProgrammingError(
                 f"cannot load sequence of {len(record)} items:"
index 5ac37245b002bb6c36377066eaaba9b6e232020c..853247183840abaaec3686623cb6379944d4a65e 100644 (file)
@@ -113,12 +113,8 @@ class TypeInfo:
 
         try:
             async with conn.transaction():
-                async with conn.cursor(
-                    binary=True, row_factory=dict_row
-                ) as cur:
-                    await cur.execute(
-                        cls._get_info_query(conn), {"name": name}
-                    )
+                async with conn.cursor(binary=True, row_factory=dict_row) as cur:
+                    await cur.execute(cls._get_info_query(conn), {"name": name})
                     recs = await cur.fetchall()
         except e.UndefinedObject:
             return None
@@ -134,9 +130,7 @@ class TypeInfo:
         elif not recs:
             return None
         else:
-            raise e.ProgrammingError(
-                f"found {len(recs)} different types named {name}"
-            )
+            raise e.ProgrammingError(f"found {len(recs)} different types named {name}")
 
     def register(self, context: Optional[AdaptContext] = None) -> None:
         """
@@ -356,15 +350,11 @@ class TypesRegistry:
             if key.endswith("[]"):
                 key = key[:-2]
         elif not isinstance(key, (int, tuple)):
-            raise TypeError(
-                f"the key must be an oid or a name, got {type(key)}"
-            )
+            raise TypeError(f"the key must be an oid or a name, got {type(key)}")
         try:
             return self._registry[key]
         except KeyError:
-            raise KeyError(
-                f"couldn't find the type {key!r} in the types registry"
-            )
+            raise KeyError(f"couldn't find the type {key!r} in the types registry")
 
     @overload
     def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
@@ -403,9 +393,7 @@ class TypesRegistry:
         else:
             return t.oid
 
-    def get_by_subtype(
-        self, cls: Type[T], subtype: Union[int, str]
-    ) -> Optional[T]:
+    def get_by_subtype(self, cls: Type[T], subtype: Union[int, str]) -> Optional[T]:
         """
         Return info about a `TypeInfo` subclass by its element name or oid.
 
index 66110ab5d93bd759760ca5b831c108df1a766480..30f846f1bef8dd409fb24d30ae6c337d02c7ca4b 100644 (file)
@@ -29,9 +29,7 @@ def get_tzinfo(pgconn: Optional[PGconn]) -> tzinfo:
         try:
             zi: tzinfo = ZoneInfo(sname)
         except KeyError:
-            logger.warning(
-                "unknown PostgreSQL timezone: %r; will use UTC", sname
-            )
+            logger.warning("unknown PostgreSQL timezone: %r; will use UTC", sname)
             zi = timezone.utc
 
         _timezones[tzname] = zi
index c9904600fd89c1c579cc9f07799a48c78422441a..9bdd3e1734d43501cfff569e21ee47fef463b446 100644 (file)
@@ -219,14 +219,10 @@ class Transformer(Protocol):
     ) -> None:
         ...
 
-    def set_dumper_types(
-        self, types: Sequence[int], format: pq.Format
-    ) -> None:
+    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None:
         ...
 
-    def set_loader_types(
-        self, types: Sequence[int], format: pq.Format
-    ) -> None:
+    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None:
         ...
 
     def dump_sequence(
@@ -237,17 +233,13 @@ class Transformer(Protocol):
     def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
         ...
 
-    def load_rows(
-        self, row0: int, row1: int, make_row: "RowMaker[Row]"
-    ) -> List["Row"]:
+    def load_rows(self, row0: int, row1: int, make_row: "RowMaker[Row]") -> List["Row"]:
         ...
 
     def load_row(self, row: int, make_row: "RowMaker[Row]") -> Optional["Row"]:
         ...
 
-    def load_sequence(
-        self, record: Sequence[Optional[bytes]]
-    ) -> Tuple[Any, ...]:
+    def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]:
         ...
 
     def get_loader(self, oid: int, format: pq.Format) -> Loader:
index bfea8c47a5dc31b65a9b45db673636ac4182a54d..8341a5a29981c14a9a6327ecba453c245dad440d 100644 (file)
@@ -91,9 +91,7 @@ class Dumper(abc.Dumper, ABC):
             rv = rv.replace(b"\\", b"\\\\")
         return rv
 
-    def get_key(
-        self, obj: Any, format: PyFormat
-    ) -> Union[type, Tuple[type, ...]]:
+    def get_key(self, obj: Any, format: PyFormat) -> Union[type, Tuple[type, ...]]:
         """
         Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
         `~psycopg.abc.Dumper` protocol. Look at its definition for details.
index 3a739e9469b8ca251d0ecd1155e6e7c23fca87ca..88498e75e4c1a20218d9a01610af53edc5a6756e 100644 (file)
@@ -202,9 +202,7 @@ class BaseConnection(Generic[Row]):
         # Base implementation, not thread safe.
         # Subclasses must call it holding a lock
         self._check_intrans("isolation_level")
-        self._isolation_level = (
-            IsolationLevel(value) if value is not None else None
-        )
+        self._isolation_level = IsolationLevel(value) if value is not None else None
         self._begin_statement = b""
 
     @property
@@ -317,9 +315,7 @@ class BaseConnection(Generic[Row]):
             try:
                 cb(diag)
             except Exception as ex:
-                logger.exception(
-                    "error processing notice callback '%s': %s", cb, ex
-                )
+                logger.exception("error processing notice callback '%s': %s", cb, ex)
 
     def add_notify_handler(self, callback: NotifyHandler) -> None:
         """
@@ -419,16 +415,12 @@ class BaseConnection(Generic[Row]):
         if result_format == Format.TEXT:
             self.pgconn.send_query(command)
         else:
-            self.pgconn.send_query_params(
-                command, None, result_format=result_format
-            )
+            self.pgconn.send_query_params(command, None, result_format=result_format)
 
         result = (yield from execute(self.pgconn))[-1]
         if result.status not in (ExecStatus.COMMAND_OK, ExecStatus.TUPLES_OK):
             if result.status == ExecStatus.FATAL_ERROR:
-                raise e.error_from_result(
-                    result, encoding=pgconn_encoding(self.pgconn)
-                )
+                raise e.error_from_result(result, encoding=pgconn_encoding(self.pgconn))
             else:
                 raise e.InterfaceError(
                     f"unexpected result {ExecStatus(result.status).name}"
@@ -472,9 +464,7 @@ class BaseConnection(Generic[Row]):
             parts.append(b"READ ONLY" if self.read_only else b"READ WRITE")
 
         if self.deferrable is not None:
-            parts.append(
-                b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE"
-            )
+            parts.append(b"DEFERRABLE" if self.deferrable else b"NOT DEFERRABLE")
 
         self._begin_statement = b" ".join(parts)
         return self._begin_statement
@@ -558,13 +548,9 @@ class BaseConnection(Generic[Row]):
             )
         xid = self._tpc[0]
         self._tpc = (xid, True)
-        yield from self._exec_command(
-            SQL("PREPARE TRANSACTION {}").format(str(xid))
-        )
+        yield from self._exec_command(SQL("PREPARE TRANSACTION {}").format(str(xid)))
 
-    def _tpc_finish_gen(
-        self, action: str, xid: Union[Xid, str, None]
-    ) -> PQGen[None]:
+    def _tpc_finish_gen(self, action: str, xid: Union[Xid, str, None]) -> PQGen[None]:
         fname = f"tpc_{action}()"
         if xid is None:
             if not self._tpc:
@@ -605,9 +591,7 @@ class Connection(BaseConnection[Row]):
     server_cursor_factory: Type[ServerCursor[Row]]
     row_factory: RowFactory[Row]
 
-    def __init__(
-        self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None
-    ):
+    def __init__(self, pgconn: "PGconn", row_factory: Optional[RowFactory[Row]] = None):
         super().__init__(pgconn)
         self.row_factory = row_factory or cast(RowFactory[Row], tuple_row)
         self.lock = threading.Lock()
@@ -704,9 +688,7 @@ class Connection(BaseConnection[Row]):
             self.close()
 
     @classmethod
-    def _get_connection_params(
-        cls, conninfo: str, **kwargs: Any
-    ) -> Dict[str, Any]:
+    def _get_connection_params(cls, conninfo: str, **kwargs: Any) -> Dict[str, Any]:
         """Manipulate connection parameters before connecting.
 
         :param conninfo: Connection string as received by `~Connection.connect()`.
@@ -854,9 +836,7 @@ class Connection(BaseConnection[Row]):
                 ns = self.wait(notifies(self.pgconn))
             enc = pgconn_encoding(self.pgconn)
             for pgn in ns:
-                n = Notify(
-                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
-                )
+                n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
 
     def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
index 3382c5d0f863bafb84a661096d03520a4bf7a57a..8bd2ca73f6069c0cfb4c183c496f9301b5700436 100644 (file)
@@ -289,18 +289,14 @@ class AsyncConnection(BaseConnection[Row]):
                 ns = await self.wait(notifies(self.pgconn))
             enc = pgconn_encoding(self.pgconn)
             for pgn in ns:
-                n = Notify(
-                    pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid
-                )
+                n = Notify(pgn.relname.decode(enc), pgn.extra.decode(enc), pgn.be_pid)
                 yield n
 
     async def wait(self, gen: PQGen[RV]) -> RV:
         return await waiting.wait_async(gen, self.pgconn.socket)
 
     @classmethod
-    async def _wait_conn(
-        cls, gen: PQGenConn[RV], timeout: Optional[int]
-    ) -> RV:
+    async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
         return await waiting.wait_conn_async(gen, timeout)
 
     def _set_autocommit(self, value: bool) -> None:
@@ -314,9 +310,7 @@ class AsyncConnection(BaseConnection[Row]):
     def _set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
         self._no_set_async("isolation_level")
 
-    async def set_isolation_level(
-        self, value: Optional[IsolationLevel]
-    ) -> None:
+    async def set_isolation_level(self, value: Optional[IsolationLevel]) -> None:
         """Async version of the `~Connection.isolation_level` setter."""
         async with self.lock:
             super()._set_isolation_level(value)
index 6d966edf5b3b5bd2ee6e7add9cec27a835d11745..11606b56789d40559c8bc4b3fc0dd49e89bbf752 100644 (file)
@@ -48,9 +48,7 @@ def make_conninfo(conninfo: str = "", **kwargs: Any) -> str:
         tmp.update(kwargs)
         kwargs = tmp
 
-    conninfo = " ".join(
-        f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items()
-    )
+    conninfo = " ".join(f"{k}={_param_escape(str(v))}" for (k, v) in kwargs.items())
 
     # Verify the result is valid
     _parse_conninfo(conninfo)
@@ -74,11 +72,7 @@ def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
            #LIBPQ-CONNSTRING
     """
     opts = _parse_conninfo(conninfo)
-    rv = {
-        opt.keyword.decode(): opt.val.decode()
-        for opt in opts
-        if opt.val is not None
-    }
+    rv = {opt.keyword.decode(): opt.val.decode() for opt in opts if opt.val is not None}
     for k, v in kwargs.items():
         if v is not None:
             rv[k] = v
index 6ddadd7060e03ac191ed189f42fbfbb54ed9c2a2..e8b45adae70de596c0c1e2ad4488f3d3c7f6cd4c 100644 (file)
@@ -66,9 +66,7 @@ class BaseCopy(Generic[ConnectionType]):
         self._pgresult: "PGresult" = tx.pgresult
 
         if self._pgresult.binary_tuples == pq.Format.TEXT:
-            self.formatter = TextFormatter(
-                tx, encoding=pgconn_encoding(self._pgconn)
-            )
+            self.formatter = TextFormatter(tx, encoding=pgconn_encoding(self._pgconn))
         else:
             self.formatter = BinaryFormatter(tx)
 
@@ -103,18 +101,12 @@ class BaseCopy(Generic[ConnectionType]):
 
         """
         registry = self.cursor.adapters.types
-        oids = [
-            t if isinstance(t, int) else registry.get_oid(t) for t in types
-        ]
+        oids = [t if isinstance(t, int) else registry.get_oid(t) for t in types]
 
         if self._pgresult.status == ExecStatus.COPY_IN:
-            self.formatter.transformer.set_dumper_types(
-                oids, self.formatter.format
-            )
+            self.formatter.transformer.set_dumper_types(oids, self.formatter.format)
         else:
-            self.formatter.transformer.set_loader_types(
-                oids, self.formatter.format
-            )
+            self.formatter.transformer.set_loader_types(oids, self.formatter.format)
 
     # High level copy protocol generators (state change of the Copy object)
 
@@ -164,10 +156,7 @@ class BaseCopy(Generic[ConnectionType]):
         if not exc:
             return
 
-        if (
-            self.connection.pgconn.transaction_status
-            != pq.TransactionStatus.ACTIVE
-        ):
+        if self.connection.pgconn.transaction_status != pq.TransactionStatus.ACTIVE:
             # The server has already finished to send copy data. The connection
             # is already in a good state.
             return
@@ -320,9 +309,7 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
 
     def __init__(self, cursor: "AsyncCursor[Any]"):
         super().__init__(cursor)
-        self._queue: asyncio.Queue[bytes] = asyncio.Queue(
-            maxsize=self.QUEUE_SIZE
-        )
+        self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE)
         self._worker: Optional[asyncio.Future[None]] = None
 
     async def __aenter__(self) -> "AsyncCopy":
@@ -551,9 +538,7 @@ class BinaryFormatter(Formatter):
             return data
 
         elif isinstance(data, str):
-            raise TypeError(
-                "cannot copy str data in binary mode: use bytes instead"
-            )
+            raise TypeError("cannot copy str data in binary mode: use bytes instead")
 
         else:
             raise TypeError(f"can't write {type(data).__name__}")
@@ -653,9 +638,7 @@ _dump_repl = {
 }
 
 
-def _dump_sub(
-    m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl
-) -> bytes:
+def _dump_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _dump_repl) -> bytes:
     return __map[m.group(0)]
 
 
@@ -663,9 +646,7 @@ _load_re = re.compile(b"\\\\[btnvfr\\\\]")
 _load_repl = {v: k for k, v in _dump_repl.items()}
 
 
-def _load_sub(
-    m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl
-) -> bytes:
+def _load_sub(m: Match[bytes], __map: Dict[bytes, bytes] = _load_repl) -> bytes:
     return __map[m.group(0)]
 
 
index 423f15732a56a1243c128af34786bff44bd7ed10..2b172a0b939cdace66d0c3815dbf485cb1657a1e 100644 (file)
@@ -340,9 +340,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         self._execute_send(query, binary=False)
         results = yield from execute(self._pgconn)
         if len(results) != 1:
-            raise e.ProgrammingError(
-                "COPY cannot be mixed with other operations"
-            )
+            raise e.ProgrammingError("COPY cannot be mixed with other operations")
 
         result = results[0]
         self._check_copy_result(result)
@@ -428,9 +426,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
                 f" {ExecStatus(result.status).name}"
             )
 
-    def _set_current_result(
-        self, i: int, format: Optional[Format] = None
-    ) -> None:
+    def _set_current_result(self, i: int, format: Optional[Format] = None) -> None:
         """
         Select one of the results in the cursor as the active one.
         """
@@ -469,9 +465,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         if not res:
             raise e.ProgrammingError("no result available")
         elif res.status != ExecStatus.TUPLES_OK:
-            raise e.ProgrammingError(
-                "the last operation didn't produce a result"
-            )
+            raise e.ProgrammingError("the last operation didn't produce a result")
 
     def _check_copy_result(self, result: "PGresult") -> None:
         """
@@ -496,9 +490,7 @@ class BaseCursor(Generic[ConnectionType, Row]):
         elif mode == "absolute":
             newpos = value
         else:
-            raise ValueError(
-                f"bad mode: {mode}. It should be 'relative' or 'absolute'"
-            )
+            raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
         if not 0 <= newpos < self.pgresult.ntuples:
             raise IndexError("position out of bound")
         self._pos = newpos
@@ -515,9 +507,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
     __module__ = "psycopg"
     __slots__ = ()
 
-    def __init__(
-        self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]
-    ):
+    def __init__(self, connection: "Connection[Any]", *, row_factory: RowFactory[Row]):
         super().__init__(connection)
         self._row_factory = row_factory
 
@@ -566,9 +556,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         try:
             with self._conn.lock:
                 self._conn.wait(
-                    self._execute_gen(
-                        query, params, prepare=prepare, binary=binary
-                    )
+                    self._execute_gen(query, params, prepare=prepare, binary=binary)
                 )
         except e.Error as ex:
             raise ex.with_traceback(None)
@@ -586,9 +574,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         """
         try:
             with self._conn.lock:
-                self._conn.wait(
-                    self._executemany_gen(query, params_seq, returning)
-                )
+                self._conn.wait(self._executemany_gen(query, params_seq, returning))
         except e.Error as ex:
             raise ex.with_traceback(None)
 
@@ -603,9 +589,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         Iterate row-by-row on a result from the database.
         """
         with self._conn.lock:
-            self._conn.wait(
-                self._stream_send_gen(query, params, binary=binary)
-            )
+            self._conn.wait(self._stream_send_gen(query, params, binary=binary))
             first = True
             while self._conn.wait(self._stream_fetchone_gen(first)):
                 rec = self._tx.load_row(0, self._make_row)
@@ -656,9 +640,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
         """
         self._check_result_for_fetch()
         assert self.pgresult
-        records = self._tx.load_rows(
-            self._pos, self.pgresult.ntuples, self._make_row
-        )
+        records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
         self._pos = self.pgresult.ntuples
         return records
 
index 323b77a4de1611db19c66ed216e83b0225eba6da..1fef161fa4fca23729f136f2eb90c2f136bef9f5 100644 (file)
@@ -73,9 +73,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         try:
             async with self._conn.lock:
                 await self._conn.wait(
-                    self._execute_gen(
-                        query, params, prepare=prepare, binary=binary
-                    )
+                    self._execute_gen(query, params, prepare=prepare, binary=binary)
                 )
         except e.Error as ex:
             raise ex.with_traceback(None)
@@ -104,9 +102,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
         binary: Optional[bool] = None,
     ) -> AsyncIterator[Row]:
         async with self._conn.lock:
-            await self._conn.wait(
-                self._stream_send_gen(query, params, binary=binary)
-            )
+            await self._conn.wait(self._stream_send_gen(query, params, binary=binary))
             first = True
             while await self._conn.wait(self._stream_fetchone_gen(first)):
                 rec = self._tx.load_row(0, self._make_row)
@@ -138,9 +134,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     async def fetchall(self) -> List[Row]:
         self._check_result_for_fetch()
         assert self.pgresult
-        records = self._tx.load_rows(
-            self._pos, self.pgresult.ntuples, self._make_row
-        )
+        records = self._tx.load_rows(self._pos, self.pgresult.ntuples, self._make_row)
         self._pos = self.pgresult.ntuples
         return records
 
index 200c6017c174c88d8024a22852a4f90fac9ac361..3c3d8b72f9747dcd1aa1fd112c918e9dc0b8ebca 100644 (file)
@@ -39,9 +39,7 @@ BINARY = DBAPITypeObject("BINARY", ("bytea",))
 DATETIME = DBAPITypeObject(
     "DATETIME", "timestamp timestamptz date time timetz interval".split()
 )
-NUMBER = DBAPITypeObject(
-    "NUMBER", "int2 int4 int8 float4 float8 numeric".split()
-)
+NUMBER = DBAPITypeObject("NUMBER", "int2 int4 int8 float4 float8 numeric".split())
 ROWID = DBAPITypeObject("ROWID", ("oid",))
 STRING = DBAPITypeObject("STRING", "text varchar bpchar".split())
 
index b6e15edb4f8eef3b20a571a6e721e1f04e3753b6..0e0db5204ab5270facc91f17cd5fbfc4d57ad76d 100644 (file)
@@ -53,10 +53,7 @@ class Error(Exception):
     sqlstate: Optional[str] = None
 
     def __init__(
-        self,
-        *args: Sequence[Any],
-        info: ErrorInfo = None,
-        encoding: str = "utf-8"
+        self, *args: Sequence[Any], info: ErrorInfo = None, encoding: str = "utf-8"
     ):
         super().__init__(*args)
         self._info = info
@@ -347,9 +344,7 @@ _base_exc_map = {
 }
 
 
-def sqlcode(
-    const_name: str, code: str
-) -> Callable[[Type[Error]], Type[Error]]:
+def sqlcode(const_name: str, code: str) -> Callable[[Type[Error]], Type[Error]]:
     """
     Decorator to associate an exception class to a sqlstate.
     """
index 37b130421a21c0fdfea7b0b4442d46bcd90cef16..50ab959481a3870c2775769fc07a64f5a4fbc037 100644 (file)
@@ -88,24 +88,12 @@ for t in [
     RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
     RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
     RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
-    MultirangeInfo(
-        "datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082
-    ),
-    MultirangeInfo(
-        "int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23
-    ),
-    MultirangeInfo(
-        "int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20
-    ),
-    MultirangeInfo(
-        "nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700
-    ),
-    MultirangeInfo(
-        "tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114
-    ),
-    MultirangeInfo(
-        "tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184
-    ),
+    MultirangeInfo("datemultirange", 4535, 6155, range_oid=3912, subtype_oid=1082),
+    MultirangeInfo("int4multirange", 4451, 6150, range_oid=3904, subtype_oid=23),
+    MultirangeInfo("int8multirange", 4536, 6157, range_oid=3926, subtype_oid=20),
+    MultirangeInfo("nummultirange", 4532, 6151, range_oid=3906, subtype_oid=1700),
+    MultirangeInfo("tsmultirange", 4533, 6152, range_oid=3908, subtype_oid=1114),
+    MultirangeInfo("tstzmultirange", 4534, 6153, range_oid=3910, subtype_oid=1184),
     # autogenerated: end
 ]:
     types.add(t)
index bdf26d203d312004930f2ceca479c8c924caa81c..605c86f0a39befb0acd2684de98e82af3ee2a57a 100644 (file)
@@ -778,9 +778,7 @@ def generate_stub() -> None:
     )
 
     known = {
-        line[4:].split("(", 1)[0]
-        for line in lines[:istart]
-        if line.startswith("def ")
+        line[4:].split("(", 1)[0] for line in lines[:istart] if line.startswith("def ")
     }
 
     signatures = []
index c3eaa279be2b4208be781b0dce1333a6a9c59710..9f7e7cdac94110d87131edb5344a53a5d5ac6662 100644 (file)
@@ -88,9 +88,7 @@ def PQsendQueryPrepared(
     arg6: Optional[Array[c_int]],
     arg7: int,
 ) -> int: ...
-def PQcancel(
-    arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int
-) -> int: ...
+def PQcancel(arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int) -> int: ...
 def PQsetNoticeReceiver(
     arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
 ) -> Callable[[Any], PGresult_struct]: ...
@@ -101,14 +99,10 @@ def PQsetNoticeReceiver(
 def PQnotifies(
     arg1: Optional[PGconn_struct],
 ) -> Optional[pointer[PGnotify_struct]]: ...  # type: ignore
-def PQputCopyEnd(
-    arg1: Optional[PGconn_struct], arg2: Optional[bytes]
-) -> int: ...
+def PQputCopyEnd(arg1: Optional[PGconn_struct], arg2: Optional[bytes]) -> int: ...
 
 # Arg 2 is a pointer, reported as _CArgObject by mypy
-def PQgetCopyData(
-    arg1: Optional[PGconn_struct], arg2: Any, arg3: int
-) -> int: ...
+def PQgetCopyData(arg1: Optional[PGconn_struct], arg2: Any, arg3: int) -> int: ...
 def PQsetResultAttrs(
     arg1: Optional[PGresult_struct],
     arg2: int,
index db57478fff521f60e82d3c76d1f5cb43abb21b70..20da0a3a042c83378b6ecbcb434c13d09d030ae6 100644 (file)
@@ -316,9 +316,7 @@ class PGresult(Protocol):
     def binary_tuples(self) -> int:
         ...
 
-    def get_value(
-        self, row_number: int, column_number: int
-    ) -> Optional[bytes]:
+    def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
         ...
 
     @property
@@ -362,9 +360,7 @@ class Conninfo(Protocol):
         ...
 
     @classmethod
-    def _options_from_array(
-        cls, opts: Sequence[Any]
-    ) -> List["ConninfoOption"]:
+    def _options_from_array(cls, opts: Sequence[Any]) -> List["ConninfoOption"]:
         ...
 
 
index 217000cdb29b645bad147c5d244fdf912708c808..1822d604c3f29e456a33c6c7b2d5819acb8307f4 100644 (file)
@@ -70,9 +70,7 @@ def error_message(obj: Union[PGconn, PGresult], encoding: str = "utf8") -> str:
             bmsg = bmsg.split(b":", 1)[-1].strip()
 
     else:
-        raise TypeError(
-            f"PGconn or PGresult expected, got {type(obj).__name__}"
-        )
+        raise TypeError(f"PGconn or PGresult expected, got {type(obj).__name__}")
 
     if bmsg:
         msg = bmsg.decode(encoding, "replace")
index 8aa30b1c7da545a9dc89fce65d0d13e166c4b6b2..f286b7ec82b789da7355fc1ac99a3060993b19d5 100644 (file)
@@ -264,9 +264,7 @@ class PGconn:
             raise TypeError(f"bytes expected, got {type(command)} instead")
         self._ensure_pgconn()
         if not impl.PQsendQuery(self._pgconn_ptr, command):
-            raise e.OperationalError(
-                f"sending query failed: {error_message(self)}"
-            )
+            raise e.OperationalError(f"sending query failed: {error_message(self)}")
 
     def exec_params(
         self,
@@ -317,9 +315,7 @@ class PGconn:
             atypes = (impl.Oid * nparams)(*param_types)
 
         self._ensure_pgconn()
-        if not impl.PQsendPrepare(
-            self._pgconn_ptr, name, command, nparams, atypes
-        ):
+        if not impl.PQsendPrepare(self._pgconn_ptr, name, command, nparams, atypes):
             raise e.OperationalError(
                 f"sending query and params failed: {error_message(self)}"
             )
@@ -369,9 +365,7 @@ class PGconn:
                     for b in param_values
                 )
             )
-            alenghts = (c_int * nparams)(
-                *(len(p) if p else 0 for p in param_values)
-            )
+            alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
         else:
             nparams = 0
             aparams = alenghts = None
@@ -418,9 +412,7 @@ class PGconn:
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
 
         if not isinstance(command, bytes):
-            raise TypeError(
-                f"'command' must be bytes, got {type(command)} instead"
-            )
+            raise TypeError(f"'command' must be bytes, got {type(command)} instead")
 
         if not param_types:
             nparams = 0
@@ -450,9 +442,7 @@ class PGconn:
         if param_values:
             nparams = len(param_values)
             aparams = (c_char_p * nparams)(*param_values)
-            alenghts = (c_int * nparams)(
-                *(len(p) if p else 0 for p in param_values)
-            )
+            alenghts = (c_int * nparams)(*(len(p) if p else 0 for p in param_values))
         else:
             nparams = 0
             aparams = alenghts = None
@@ -523,9 +513,7 @@ class PGconn:
 
     def consume_input(self) -> None:
         if 1 != impl.PQconsumeInput(self._pgconn_ptr):
-            raise e.OperationalError(
-                f"consuming input failed: {error_message(self)}"
-            )
+            raise e.OperationalError(f"consuming input failed: {error_message(self)}")
 
     def is_busy(self) -> int:
         return impl.PQisBusy(self._pgconn_ptr)
@@ -543,9 +531,7 @@ class PGconn:
 
     def flush(self) -> int:
         if not self._pgconn_ptr:
-            raise e.OperationalError(
-                "flushing failed: the connection is closed"
-            )
+            raise e.OperationalError("flushing failed: the connection is closed")
         rv: int = impl.PQflush(self._pgconn_ptr)
         if rv < 0:
             raise e.OperationalError(f"flushing failed: {error_message(self)}")
@@ -581,24 +567,18 @@ class PGconn:
             buffer = bytes(buffer)
         rv = impl.PQputCopyData(self._pgconn_ptr, buffer, len(buffer))
         if rv < 0:
-            raise e.OperationalError(
-                f"sending copy data failed: {error_message(self)}"
-            )
+            raise e.OperationalError(f"sending copy data failed: {error_message(self)}")
         return rv
 
     def put_copy_end(self, error: Optional[bytes] = None) -> int:
         rv = impl.PQputCopyEnd(self._pgconn_ptr, error)
         if rv < 0:
-            raise e.OperationalError(
-                f"sending copy end failed: {error_message(self)}"
-            )
+            raise e.OperationalError(f"sending copy end failed: {error_message(self)}")
         return rv
 
     def get_copy_data(self, async_: int) -> Tuple[int, memoryview]:
         buffer_ptr = c_char_p()
-        nbytes = impl.PQgetCopyData(
-            self._pgconn_ptr, byref(buffer_ptr), async_
-        )
+        nbytes = impl.PQgetCopyData(self._pgconn_ptr, byref(buffer_ptr), async_)
         if nbytes == -2:
             raise e.OperationalError(
                 f"receiving copy data failed: {error_message(self)}"
@@ -626,9 +606,7 @@ class PGconn:
     def encrypt_password(
         self, passwd: bytes, user: bytes, algorithm: Optional[bytes] = None
     ) -> bytes:
-        out = impl.PQencryptPasswordConn(
-            self._pgconn_ptr, passwd, user, algorithm
-        )
+        out = impl.PQencryptPasswordConn(self._pgconn_ptr, passwd, user, algorithm)
         if not out:
             raise e.OperationalError(
                 f"password encryption failed: {error_message(self)}"
@@ -686,9 +664,7 @@ class PGconn:
         :raises ~e.OperationalError: if the flush request failed.
         """
         if impl.PQsendFlushRequest(self._pgconn_ptr) == 0:
-            raise e.OperationalError(
-                f"flush request failed: {error_message(self)}"
-            )
+            raise e.OperationalError(f"flush request failed: {error_message(self)}")
 
     def _call_bytes(
         self, func: Callable[[impl.PGconn_struct], Optional[bytes]]
@@ -805,12 +781,8 @@ class PGresult:
     def binary_tuples(self) -> int:
         return impl.PQbinaryTuples(self._pgresult_ptr)
 
-    def get_value(
-        self, row_number: int, column_number: int
-    ) -> Optional[bytes]:
-        length: int = impl.PQgetlength(
-            self._pgresult_ptr, row_number, column_number
-        )
+    def get_value(self, row_number: int, column_number: int) -> Optional[bytes]:
+        length: int = impl.PQgetlength(self._pgresult_ptr, row_number, column_number)
         if length:
             v = impl.PQgetvalue(self._pgresult_ptr, row_number, column_number)
             return string_at(v, length)
@@ -842,8 +814,7 @@ class PGresult:
 
     def set_attributes(self, descriptions: List[PGresAttDesc]) -> None:
         structs = [
-            impl.PGresAttDesc_struct(*desc)  # type: ignore
-            for desc in descriptions
+            impl.PGresAttDesc_struct(*desc) for desc in descriptions  # type: ignore
         ]
         array = (impl.PGresAttDesc_struct * len(structs))(*structs)  # type: ignore
         rv = impl.PQsetResultAttrs(self._pgresult_ptr, len(structs), array)
@@ -884,9 +855,7 @@ class PGcancel:
         See :pq:`PQcancel()` for details.
         """
         buf = create_string_buffer(256)
-        res = impl.PQcancel(
-            self.pgcancel_ptr, pointer(buf), len(buf)  # type: ignore
-        )
+        res = impl.PQcancel(self.pgcancel_ptr, pointer(buf), len(buf))  # type: ignore
         if not res:
             raise e.OperationalError(
                 f"cancel failed: {buf.value.decode('utf8', 'ignore')}"
@@ -956,9 +925,7 @@ class Escaping:
 
     def escape_literal(self, data: "abc.Buffer") -> memoryview:
         if not self.conn:
-            raise e.OperationalError(
-                "escape_literal failed: no connection provided"
-            )
+            raise e.OperationalError("escape_literal failed: no connection provided")
 
         self.conn._ensure_pgconn()
         # TODO: might be done without copy (however C does that)
@@ -975,9 +942,7 @@ class Escaping:
 
     def escape_identifier(self, data: "abc.Buffer") -> memoryview:
         if not self.conn:
-            raise e.OperationalError(
-                "escape_identifier failed: no connection provided"
-            )
+            raise e.OperationalError("escape_identifier failed: no connection provided")
 
         self.conn._ensure_pgconn()
 
@@ -1039,9 +1004,7 @@ class Escaping:
                 pointer(t_cast(c_ulong, len_out)),
             )
         else:
-            out = impl.PQescapeBytea(
-                data, len(data), pointer(t_cast(c_ulong, len_out))
-            )
+            out = impl.PQescapeBytea(data, len(data), pointer(t_cast(c_ulong, len_out)))
         if not out:
             raise MemoryError(
                 f"couldn't allocate for escape_bytea of {len(data)} bytes"
index 8032fb2d5c758e4a8470875893150d60817a17ab..d7acac4703e3e02605225b9bb5babb0f18b26050 100644 (file)
@@ -139,9 +139,7 @@ def namedtuple_row(
 
 
 # ascii except alnum and underscore
-_re_clean = re.compile(
-    "[" + re.escape(" !\"#$%&'()*+,-./:;<=>?@[\\]^`{|}~") + "]"
-)
+_re_clean = re.compile("[" + re.escape(" !\"#$%&'()*+,-./:;<=>?@[\\]^`{|}~") + "]")
 
 
 @functools.lru_cache(512)
index 7b6a4ec69d6de9245f4cb0fe29d32cd5102970d8..b10f066c564ffc6e49d3d7c0cb263fa495b28cad 100644 (file)
@@ -81,9 +81,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         # The above result only returned COMMAND_OK. Get the cursor shape
         yield from self._describe_gen(cur)
 
-    def _describe_gen(
-        self, cur: BaseCursor[ConnectionType, Row]
-    ) -> PQGen[None]:
+    def _describe_gen(self, cur: BaseCursor[ConnectionType, Row]) -> PQGen[None]:
         conn = cur._conn
         conn.pgconn.send_describe_portal(self.name.encode(cur._encoding))
         results = yield from execute(conn.pgconn)
@@ -130,9 +128,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
             sql.SQL("ALL") if num is None else sql.Literal(num),
             sql.Identifier(self.name),
         )
-        res = yield from cur._conn._exec_command(
-            query, result_format=self._format
-        )
+        res = yield from cur._conn._exec_command(query, result_format=self._format)
 
         cur.pgresult = res
         cur._tx.set_pgresult(res, set_loaders=False)
@@ -142,9 +138,7 @@ class ServerCursorHelper(Generic[ConnectionType, Row]):
         self, cur: BaseCursor[ConnectionType, Row], value: int, mode: str
     ) -> PQGen[None]:
         if mode not in ("relative", "absolute"):
-            raise ValueError(
-                f"bad mode: {mode}. It should be 'relative' or 'absolute'"
-            )
+            raise ValueError(f"bad mode: {mode}. It should be 'relative' or 'absolute'")
         query = sql.SQL("MOVE{} {} FROM {}").format(
             sql.SQL(" ABSOLUTE" if mode == "absolute" else ""),
             sql.Literal(value),
@@ -258,9 +252,7 @@ class ServerCursor(Cursor[Row]):
 
         try:
             with self._conn.lock:
-                self._conn.wait(
-                    self._helper._declare_gen(self, query, params, binary)
-                )
+                self._conn.wait(self._helper._declare_gen(self, query, params, binary))
         except e.Error as ex:
             raise ex.with_traceback(None)
 
@@ -274,9 +266,7 @@ class ServerCursor(Cursor[Row]):
         returning: bool = True,
     ) -> None:
         """Method not implemented for server-side cursors."""
-        raise e.NotSupportedError(
-            "executemany not supported on server-side cursors"
-        )
+        raise e.NotSupportedError("executemany not supported on server-side cursors")
 
     def fetchone(self) -> Optional[Row]:
         with self._conn.lock:
@@ -304,9 +294,7 @@ class ServerCursor(Cursor[Row]):
     def __iter__(self) -> Iterator[Row]:
         while True:
             with self._conn.lock:
-                recs = self._conn.wait(
-                    self._helper._fetch_gen(self, self.itersize)
-                )
+                recs = self._conn.wait(self._helper._fetch_gen(self, self.itersize))
             for rec in recs:
                 self._pos += 1
                 yield rec
@@ -399,9 +387,7 @@ class AsyncServerCursor(AsyncCursor[Row]):
         *,
         returning: bool = True,
     ) -> None:
-        raise e.NotSupportedError(
-            "executemany not supported on server-side cursors"
-        )
+        raise e.NotSupportedError("executemany not supported on server-side cursors")
 
     async def fetchone(self) -> Optional[Row]:
         async with self._conn.lock:
index 2f1a02ac4262bd2f2d00273075922448222eef74..acfeac0bf4fe8faf825ce3410c03b87b5efcf2c3 100644 (file)
@@ -125,9 +125,7 @@ class Composed(Composable):
     _obj: List[Composable]
 
     def __init__(self, seq: Sequence[Any]):
-        seq = [
-            obj if isinstance(obj, Composable) else Literal(obj) for obj in seq
-        ]
+        seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
         super().__init__(seq)
 
     def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
index 54212f8c91c06f8282a22bf403f4f477b4ba1e27..0c3b82c9a362f1d8730132f3e1338a1cf215ee6c 100644 (file)
@@ -116,9 +116,7 @@ class BaseTransaction(Generic[ConnectionType]):
                 # operational error that might arise in the block.
                 raise
             except Exception as exc2:
-                logger.warning(
-                    "error ignored in rollback of %s: %s", self, exc2
-                )
+                logger.warning("error ignored in rollback of %s: %s", self, exc2)
                 return False
 
     def _commit_gen(self) -> PQGen[PGresult]:
@@ -132,9 +130,7 @@ class BaseTransaction(Generic[ConnectionType]):
 
     def _rollback_gen(self, exc_val: Optional[BaseException]) -> PQGen[bool]:
         if isinstance(exc_val, Rollback):
-            logger.debug(
-                f"{self._conn}: Explicit rollback from: ", exc_info=True
-            )
+            logger.debug(f"{self._conn}: Explicit rollback from: ", exc_info=True)
 
         ex = self._pop_savepoint("rollback")
         self._exited = True
@@ -214,9 +210,7 @@ class BaseTransaction(Generic[ConnectionType]):
         else:
             # inner transaction: it always has a name
             if not self._savepoint_name:
-                self._savepoint_name = (
-                    f"_pg3_{self._conn._num_transactions + 1}"
-                )
+                self._savepoint_name = f"_pg3_{self._conn._num_transactions + 1}"
 
         self._stack_index = self._conn._num_transactions
         self._conn._num_transactions += 1
@@ -261,9 +255,7 @@ class Transaction(BaseTransaction["Connection[Any]"]):
     ) -> bool:
         if self._conn.pgconn.status == ConnStatus.OK:
             with self._conn.lock:
-                return self._conn.wait(
-                    self._exit_gen(exc_type, exc_val, exc_tb)
-                )
+                return self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
         else:
             return False
 
@@ -292,8 +284,6 @@ class AsyncTransaction(BaseTransaction["AsyncConnection[Any]"]):
     ) -> bool:
         if self._conn.pgconn.status == ConnStatus.OK:
             async with self._conn.lock:
-                return await self._conn.wait(
-                    self._exit_gen(exc_type, exc_val, exc_tb)
-                )
+                return await self._conn.wait(self._exit_gen(exc_type, exc_val, exc_tb))
         else:
             return False
index f1c6f621ac8d2c229ab138abe4271f0499f541bf..8d863545a4aa4ac9c270ac29d4c40503c45affbf 100644 (file)
@@ -22,14 +22,10 @@ from .._typeinfo import TypeInfo
 
 _struct_head = struct.Struct("!III")  # ndims, hasnull, elem oid
 _pack_head = cast(Callable[[int, int, int], bytes], _struct_head.pack)
-_unpack_head = cast(
-    Callable[[bytes], Tuple[int, int, int]], _struct_head.unpack_from
-)
+_unpack_head = cast(Callable[[bytes], Tuple[int, int, int]], _struct_head.unpack_from)
 _struct_dim = struct.Struct("!II")  # dim, lower bound
 _pack_dim = cast(Callable[[int, int], bytes], _struct_dim.pack)
-_unpack_dim = cast(
-    Callable[[bytes, int], Tuple[int, int]], _struct_dim.unpack_from
-)
+_unpack_dim = cast(Callable[[bytes, int], Tuple[int, int]], _struct_dim.unpack_from)
 
 TEXT_ARRAY_OID = postgres.types["text"].array_oid
 NoneType: type = type(None)
@@ -45,9 +41,7 @@ class BaseListDumper(RecursiveDumper):
         super().__init__(cls, context)
         self.sub_dumper: Optional[Dumper] = None
         if self.element_oid and context:
-            sdclass = context.adapters.get_dumper_by_oid(
-                self.element_oid, self.format
-            )
+            sdclass = context.adapters.get_dumper_by_oid(self.element_oid, self.format)
             self.sub_dumper = sdclass(NoneType, context)
 
     def _find_list_element(self, L: List[Any]) -> Any:
@@ -295,9 +289,7 @@ class ListBinaryDumper(BaseListDumper):
             else:
                 for item in L:
                     if not isinstance(item, self.cls):
-                        raise e.DataError(
-                            "nested lists have inconsistent depths"
-                        )
+                        raise e.DataError("nested lists have inconsistent depths")
                     dump_list(item, dim + 1)  # type: ignore
 
         dump_list(obj, 0)
@@ -350,9 +342,7 @@ class ArrayLoader(BaseArrayLoader):
             else:
                 if not stack:
                     wat = (
-                        t[:10].decode("utf8", "replace") + "..."
-                        if len(t) > 10
-                        else ""
+                        t[:10].decode("utf8", "replace") + "..." if len(t) > 10 else ""
                     )
                     raise e.DataError(f"malformed array, unexpected '{wat}'")
                 if t == b"NULL":
@@ -422,9 +412,7 @@ class ArrayBinaryLoader(BaseArrayLoader):
         return agg(dims)
 
 
-def register_array(
-    info: TypeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_array(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
     if not info.array_oid:
         raise ValueError(f"the type info {info} doesn't describe an array")
 
index 57881e6b710f996fb45be85133e44d401ff1b60c..a5340993e048091b7621678014a4459ee9232e2a 100644 (file)
@@ -227,9 +227,7 @@ def register_composite(
     # A friendly error warning instead of an AttributeError in case fetch()
     # failed and it wasn't noticed.
     if not info:
-        raise TypeError(
-            "no info passed. Is the requested composite available?"
-        )
+        raise TypeError("no info passed. Is the requested composite available?")
 
     # Register arrays and type info
     info.register(context)
@@ -268,9 +266,7 @@ def register_composite(
         adapters.register_dumper(factory, dumper)
 
         # Default to the text dumper because it is more flexible
-        dumper = type(
-            f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid}
-        )
+        dumper = type(f"{info.name.title()}Dumper", (TupleDumper,), {"oid": info.oid})
         adapters.register_dumper(factory, dumper)
 
         info.python_type = factory
index 342bda36158efdab001a55ce6650bfc1c775b0a7..8ab52b0ed661fc5c6b6d2867c35b770fa94a55b7 100644 (file)
@@ -22,9 +22,7 @@ if TYPE_CHECKING:
 
 _struct_timetz = struct.Struct("!qi")  # microseconds, sec tz offset
 _pack_timetz = cast(Callable[[int, int], bytes], _struct_timetz.pack)
-_unpack_timetz = cast(
-    Callable[[bytes], Tuple[int, int]], _struct_timetz.unpack
-)
+_unpack_timetz = cast(Callable[[bytes], Tuple[int, int]], _struct_timetz.unpack)
 
 _struct_interval = struct.Struct("!qii")  # microseconds, days, months
 _pack_interval = cast(Callable[[int, int, int], bytes], _struct_interval.pack)
@@ -168,9 +166,7 @@ class DatetimeBinaryDumper(_BaseDatetimeDumper):
 
     def dump(self, obj: datetime) -> bytes:
         delta = obj - _pg_datetimetz_epoch
-        micros = delta.microseconds + 1_000_000 * (
-            86_400 * delta.days + delta.seconds
-        )
+        micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
         return pack_int8(micros)
 
     def upgrade(self, obj: datetime, format: PyFormat) -> Dumper:
@@ -187,9 +183,7 @@ class DatetimeNoTzBinaryDumper(_BaseDatetimeDumper):
 
     def dump(self, obj: datetime) -> bytes:
         delta = obj - _pg_datetime_epoch
-        micros = delta.microseconds + 1_000_000 * (
-            86_400 * delta.days + delta.seconds
-        )
+        micros = delta.microseconds + 1_000_000 * (86_400 * delta.days + delta.seconds)
         return pack_int8(micros)
 
 
@@ -243,9 +237,7 @@ class DateLoader(Loader):
         elif ds.startswith(b"G"):  # German
             self._order = self._ORDER_DMY
         elif ds.startswith(b"S") or ds.startswith(b"P"):  # SQL or Postgres
-            self._order = (
-                self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
-            )
+            self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
         else:
             raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
 
@@ -326,9 +318,7 @@ class TimeBinaryLoader(Loader):
         try:
             return time(h, m, s, us)
         except ValueError:
-            raise DataError(
-                f"time not supported by Python: hour={h}"
-            ) from None
+            raise DataError(f"time not supported by Python: hour={h}") from None
 
 
 class TimetzLoader(Loader):
@@ -387,9 +377,7 @@ class TimetzBinaryLoader(Loader):
         try:
             return time(h, m, s, us, timezone(timedelta(seconds=-off)))
         except ValueError:
-            raise DataError(
-                f"time not supported by Python: hour={h}"
-            ) from None
+            raise DataError(f"time not supported by Python: hour={h}") from None
 
 
 class TimestampLoader(Loader):
@@ -432,13 +420,9 @@ class TimestampLoader(Loader):
         elif ds.startswith(b"G"):  # German
             self._order = self._ORDER_DMY
         elif ds.startswith(b"S"):  # SQL
-            self._order = (
-                self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
-            )
+            self._order = self._ORDER_DMY if ds.endswith(b"DMY") else self._ORDER_MDY
         elif ds.startswith(b"P"):  # Postgres
-            self._order = (
-                self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
-            )
+            self._order = self._ORDER_PGDM if ds.endswith(b"DMY") else self._ORDER_PGMD
             self._re_format = self._re_format_pg
         else:
             raise InterfaceError(f"unexpected DateStyle: {ds.decode('ascii')}")
@@ -480,9 +464,7 @@ class TimestampLoader(Loader):
             us = 0
 
         try:
-            return datetime(
-                int(ye), imo, int(da), int(ho), int(mi), int(se), us
-            )
+            return datetime(int(ye), imo, int(da), int(ho), int(mi), int(se), us)
         except ValueError as e:
             s = bytes(data).decode("utf8", "replace")
             raise DataError(f"can't parse timestamp {s!r}: {e}") from None
@@ -498,13 +480,9 @@ class TimestampBinaryLoader(Loader):
             return _pg_datetime_epoch + timedelta(microseconds=micros)
         except OverflowError:
             if micros <= 0:
-                raise DataError(
-                    "timestamp too small (before year 1)"
-                ) from None
+                raise DataError("timestamp too small (before year 1)") from None
             else:
-                raise DataError(
-                    "timestamp too large (after year 10K)"
-                ) from None
+                raise DataError("timestamp too large (after year 10K)") from None
 
 
 class TimestamptzLoader(Loader):
@@ -523,9 +501,7 @@ class TimestamptzLoader(Loader):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        self._timezone = get_tzinfo(
-            self.connection.pgconn if self.connection else None
-        )
+        self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
 
         ds = _get_datestyle(self.connection)
         if not ds.startswith(b"I"):  # not ISO
@@ -565,9 +541,7 @@ class TimestamptzLoader(Loader):
         dt = None
         ex: Exception
         try:
-            dt = datetime(
-                int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc
-            )
+            dt = datetime(int(ye), int(mo), int(da), int(ho), int(mi), int(se), us, utc)
             return (dt - tzoff).astimezone(self._timezone)
         except OverflowError as e:
             # If we have created the temporary 'dt' it means that we have a
@@ -597,9 +571,7 @@ class TimestamptzBinaryLoader(Loader):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        self._timezone = get_tzinfo(
-            self.connection.pgconn if self.connection else None
-        )
+        self._timezone = get_tzinfo(self.connection.pgconn if self.connection else None)
 
     def load(self, data: Buffer) -> datetime:
         micros = unpack_int8(data)[0]
@@ -618,22 +590,16 @@ class TimestamptzBinaryLoader(Loader):
                 if utcoff:
                     usoff = 1_000_000 * int(utcoff.total_seconds())
                     try:
-                        ts = _pg_datetime_epoch + timedelta(
-                            microseconds=micros + usoff
-                        )
+                        ts = _pg_datetime_epoch + timedelta(microseconds=micros + usoff)
                     except OverflowError:
                         pass  # will raise downstream
                     else:
                         return ts.replace(tzinfo=self._timezone)
 
             if micros <= 0:
-                raise DataError(
-                    "timestamp too small (before year 1)"
-                ) from None
+                raise DataError("timestamp too small (before year 1)") from None
             else:
-                raise DataError(
-                    "timestamp too large (after year 10K)"
-                ) from None
+                raise DataError("timestamp too large (after year 10K)") from None
 
 
 class IntervalLoader(Loader):
@@ -722,9 +688,7 @@ def _get_datestyle(conn: Optional["BaseConnection[Any]"]) -> bytes:
 
 _month_abbr = {
     n: i
-    for i, n in enumerate(
-        b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1
-    )
+    for i, n in enumerate(b"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split(), 1)
 }
 
 # Pad to get microseconds from a fraction of seconds
index 80129123c264fc3281986036789963b4e2c821c8..8c59931b776d265e435128dff3453c62074bee1d 100644 (file)
@@ -91,16 +91,12 @@ class HstoreLoader(RecursiveLoader):
             start = m.end()
 
         if start < len(s):
-            raise e.DataError(
-                f"error parsing hstore: unparsed data after char {start}"
-            )
+            raise e.DataError(f"error parsing hstore: unparsed data after char {start}")
 
         return rv
 
 
-def register_hstore(
-    info: TypeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_hstore(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
     """Register the adapters to load and dump hstore.
 
     :param info: The object with the information about the hstore type.
index 9d6763c6faa5903b676bad2239cef37cff578346..5f22a993f768d22e9dbadab00b94c611e4a7dbf1 100644 (file)
@@ -52,9 +52,7 @@ class Multirange(MutableSequence[Range[T]]):
     def __getitem__(self, index: slice) -> "Multirange[T]":
         ...
 
-    def __getitem__(
-        self, index: Union[int, slice]
-    ) -> "Union[Range[T],Multirange[T]]":
+    def __getitem__(self, index: Union[int, slice]) -> "Union[Range[T],Multirange[T]]":
         if isinstance(index, int):
             return self._ranges[index]
         else:
@@ -160,9 +158,7 @@ class BaseMultirangeDumper(RecursiveDumper):
         else:
             return (self.cls,)
 
-    def upgrade(
-        self, obj: Multirange[Any], format: PyFormat
-    ) -> "BaseMultirangeDumper":
+    def upgrade(self, obj: Multirange[Any], format: PyFormat) -> "BaseMultirangeDumper":
         # If we are a subclass whose oid is specified we don't need upgrade
         if self.cls is not Multirange:
             return self
@@ -261,9 +257,7 @@ class BaseMultirangeLoader(RecursiveLoader, Generic[T]):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        self._load = self._tx.get_loader(
-            self.subtype_oid, format=self.format
-        ).load
+        self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
 
 
 class MultirangeLoader(BaseMultirangeLoader[T]):
@@ -354,9 +348,7 @@ def register_multirange(
     # A friendly error warning instead of an AttributeError in case fetch()
     # failed and it wasn't noticed.
     if not info:
-        raise TypeError(
-            "no info passed. Is the requested multirange available?"
-        )
+        raise TypeError("no info passed. Is the requested multirange available?")
 
     # Register arrays and type info
     info.register(context)
@@ -501,19 +493,13 @@ def register_default_adapters(context: AdaptContext) -> None:
     adapters.register_dumper(NumericMultirange, NumericMultirangeDumper)
     adapters.register_dumper(DateMultirange, DateMultirangeDumper)
     adapters.register_dumper(TimestampMultirange, TimestampMultirangeDumper)
-    adapters.register_dumper(
-        TimestamptzMultirange, TimestamptzMultirangeDumper
-    )
+    adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeDumper)
     adapters.register_dumper(Int4Multirange, Int4MultirangeBinaryDumper)
     adapters.register_dumper(Int8Multirange, Int8MultirangeBinaryDumper)
     adapters.register_dumper(NumericMultirange, NumericMultirangeBinaryDumper)
     adapters.register_dumper(DateMultirange, DateMultirangeBinaryDumper)
-    adapters.register_dumper(
-        TimestampMultirange, TimestampMultirangeBinaryDumper
-    )
-    adapters.register_dumper(
-        TimestamptzMultirange, TimestamptzMultirangeBinaryDumper
-    )
+    adapters.register_dumper(TimestampMultirange, TimestampMultirangeBinaryDumper)
+    adapters.register_dumper(TimestamptzMultirange, TimestamptzMultirangeBinaryDumper)
     adapters.register_loader("int4multirange", Int4MultirangeLoader)
     adapters.register_loader("int8multirange", Int8MultirangeLoader)
     adapters.register_loader("nummultirange", NumericMultirangeLoader)
@@ -525,6 +511,4 @@ def register_default_adapters(context: AdaptContext) -> None:
     adapters.register_loader("nummultirange", NumericMultirangeBinaryLoader)
     adapters.register_loader("datemultirange", DateMultirangeBinaryLoader)
     adapters.register_loader("tsmultirange", TimestampMultirangeBinaryLoader)
-    adapters.register_loader(
-        "tstzmultirange", TimestampTZMultirangeBinaryLoader
-    )
+    adapters.register_loader("tstzmultirange", TimestampTZMultirangeBinaryLoader)
index 5045895d8dc53a09af565abf1eff0453545cce38..cc5ed1a7e1928355a7a791781007f487feac7a45 100644 (file)
@@ -341,9 +341,7 @@ class NumericBinaryLoader(Loader):
             try:
                 return _decimal_special[sign]
             except KeyError:
-                raise e.DataError(
-                    f"bad value for numeric sign: 0x{sign:X}"
-                ) from None
+                raise e.DataError(f"bad value for numeric sign: 0x{sign:X}") from None
 
 
 NUMERIC_NAN_BIN = _pack_numeric_head(0, 0, NUMERIC_NAN, 0)
index 85b270f9e266a94899f69fa11b769d2a635c13d6..fb4c6e8082111105a0c6ad92a6197d359e8e51e4 100644 (file)
@@ -209,9 +209,7 @@ class Range(Generic[T]):
 
     def __getstate__(self) -> Dict[str, Any]:
         return {
-            slot: getattr(self, slot)
-            for slot in self.__slots__
-            if hasattr(self, slot)
+            slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
         }
 
     def __setstate__(self, state: Dict[str, Any]) -> None:
@@ -373,9 +371,7 @@ class RangeBinaryDumper(BaseRangeDumper):
         return dump_range_binary(obj, dump)
 
 
-def dump_range_binary(
-    obj: Range[Any], dump: Callable[[Any], Buffer]
-) -> Buffer:
+def dump_range_binary(obj: Range[Any], dump: Callable[[Any], Buffer]) -> Buffer:
     if not obj:
         return _EMPTY_HEAD
 
@@ -419,9 +415,7 @@ class BaseRangeLoader(RecursiveLoader, Generic[T]):
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
-        self._load = self._tx.get_loader(
-            self.subtype_oid, format=self.format
-        ).load
+        self._load = self._tx.get_loader(self.subtype_oid, format=self.format).load
 
 
 class RangeLoader(BaseRangeLoader[T]):
@@ -492,9 +486,7 @@ class RangeBinaryLoader(BaseRangeLoader[T]):
         return load_range_binary(data, self._load)
 
 
-def load_range_binary(
-    data: Buffer, load: Callable[[Buffer], Any]
-) -> Range[Any]:
+def load_range_binary(data: Buffer, load: Callable[[Buffer], Any]) -> Range[Any]:
     head = data[0]
     if head & RANGE_EMPTY:
         return Range(empty=True)
@@ -522,9 +514,7 @@ def load_range_binary(
     return Range(min, max, lb + ub)
 
 
-def register_range(
-    info: RangeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_range(info: RangeInfo, context: Optional[AdaptContext] = None) -> None:
     """Register the adapters to load and dump a range type.
 
     :param info: The object with the information about the range to register.
index 12d5bcb0ef930e66c380749855b5cd45d46b1532..54dde629bf41cf9ecef9004a1039707013f46d17 100644 (file)
@@ -51,9 +51,7 @@ class BaseGeometryDumper(Dumper):
         return dumps(obj, hex=True).encode()  # type: ignore
 
 
-def register_shapely(
-    info: TypeInfo, context: Optional[AdaptContext] = None
-) -> None:
+def register_shapely(info: TypeInfo, context: Optional[AdaptContext] = None) -> None:
     """Register Shapely dumper and loaders.
 
     After invoking this function on an adapter, the queries retrieving
index f389e43c0ba6d5fb8e0973093cbe777383dbcec1..42cd319d14f97824558505c327e4f0a30d237df0 100644 (file)
@@ -44,9 +44,7 @@ class StrBinaryDumper(_BaseStrDumper):
 class _StrDumper(_BaseStrDumper):
     def dump(self, obj: str) -> bytes:
         if "\x00" in obj:
-            raise DataError(
-                "PostgreSQL text fields cannot contain NUL (0x00) bytes"
-            )
+            raise DataError("PostgreSQL text fields cannot contain NUL (0x00) bytes")
         else:
             return obj.encode(self._encoding)
 
@@ -112,9 +110,7 @@ class BytesDumper(Dumper):
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
-        self._esc = Escaping(
-            self.connection.pgconn if self.connection else None
-        )
+        self._esc = Escaping(self.connection.pgconn if self.connection else None)
 
     def dump(self, obj: bytes) -> Buffer:
         return self._esc.escape_bytea(obj)
index ed6fc1a598fee82b4561f28ab5ab307adfbf5fe1..e2a494430c82252fa1ed2f1dff8e18446e26350d 100644 (file)
@@ -32,9 +32,7 @@ class Ready(IntEnum):
     RW = EVENT_READ | EVENT_WRITE
 
 
-def wait_selector(
-    gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
-) -> RV:
+def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
     """
     Wait for a generator using the best strategy available.
 
@@ -150,9 +148,7 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
         return rv
 
 
-async def wait_conn_async(
-    gen: PQGenConn[RV], timeout: Optional[float] = None
-) -> RV:
+async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
     """
     Coroutine waiting for a connection generator to complete.
 
@@ -208,9 +204,7 @@ async def wait_conn_async(
         return rv
 
 
-def wait_epoll(
-    gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
-) -> RV:
+def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> RV:
     """
     Wait for a generator using epoll where supported.
 
index 80bad514b57cd50e70a52dd329ceff12b0249a0d..21e410c74b3e8c7fade8709ffad94c2cd1debd2c 100644 (file)
@@ -1,6 +1,3 @@
 [build-system]
 requires = ["setuptools>=49.2.0", "wheel>=0.37"]
 build-backend = "setuptools.build_meta"
-
-[tool.black]
-line-length = 79
index 6561ed7b15539808c6f439d117d4ed1370977a70..77733d4f74e39de342c1f27ac7ce1d7b85245700 100644 (file)
@@ -33,5 +33,5 @@ deps =
     shapely
 
 [flake8]
-max-line-length = 85
+max-line-length = 88
 ignore = W503, E203
index 457bd11bf6eaba0cf0cf896235ceeb1d29335aec..14db92bc24bfc01843607b65e1d7ccab1945e38c 100644 (file)
@@ -9,8 +9,6 @@ import sys
 # This package shouldn't be imported before psycopg itself, or weird things
 # will happen
 if "psycopg" not in sys.modules:
-    raise ImportError(
-        "the psycopg package should be imported before psycopg_c"
-    )
+    raise ImportError("the psycopg package should be imported before psycopg_c")
 
 from .version import __version__ as __version__  # noqa
index 2f5027ef8f600683ed195b6c2b86996fe34493aa..b992cb5e8a2bc1e8b96729f2d48ae596cf8923e4 100644 (file)
@@ -33,23 +33,15 @@ class Transformer(abc.AdaptContext):
         set_loaders: bool = True,
         format: Optional[pq.Format] = None,
     ) -> None: ...
-    def set_dumper_types(
-        self, types: Sequence[int], format: pq.Format
-    ) -> None: ...
-    def set_loader_types(
-        self, types: Sequence[int], format: pq.Format
-    ) -> None: ...
+    def set_dumper_types(self, types: Sequence[int], format: pq.Format) -> None: ...
+    def set_loader_types(self, types: Sequence[int], format: pq.Format) -> None: ...
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
     ) -> Sequence[Optional[abc.Buffer]]: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
-    def load_rows(
-        self, row0: int, row1: int, make_row: RowMaker[Row]
-    ) -> List[Row]: ...
+    def load_rows(self, row0: int, row1: int, make_row: RowMaker[Row]) -> List[Row]: ...
     def load_row(self, row: int, make_row: RowMaker[Row]) -> Optional[Row]: ...
-    def load_sequence(
-        self, record: Sequence[Optional[bytes]]
-    ) -> Tuple[Any, ...]: ...
+    def load_sequence(self, record: Sequence[Optional[bytes]]) -> Tuple[Any, ...]: ...
     def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
 
 # Generators
index a0aacce0c1af968bc296b7a617002fdbb4cb4600..515aa66c48ef58a1081d4d4674a1ff0b50c81f6e 100644 (file)
@@ -1,6 +1,3 @@
 [build-system]
 requires = ["setuptools>=49.2.0", "wheel>=0.37", "Cython>=3.0a5"]
 build-backend = "setuptools.build_meta"
-
-[tool.black]
-line-length = 79
index 3703e03076d7c6a1b90ac4fbd86c5b4ab435f143..a77cf45e107a0d9a85d2ab84352c03bfe6ca10e8 100644 (file)
@@ -14,5 +14,5 @@ deps =
     -e {toxinidir}/../psycopg_pool
 
 [flake8]
-max-line-length = 85
+max-line-length = 88
 ignore = W503, E203
index a5ffab9295dac7ff82592666edf5e57a6e86bcc6..298ea6837621c3c29ad968304c45b26aa471005d 100644 (file)
@@ -51,9 +51,7 @@ class BasePool(Generic[ConnectionType]):
         max_lifetime: float = 60 * 60.0,
         max_idle: float = 10 * 60.0,
         reconnect_timeout: float = 5 * 60.0,
-        reconnect_failed: Optional[
-            Callable[["BasePool[ConnectionType]"], None]
-        ] = None,
+        reconnect_failed: Optional[Callable[["BasePool[ConnectionType]"], None]] = None,
         num_workers: int = 3,
     ):
         min_size, max_size = self._check_size(min_size, max_size)
@@ -118,9 +116,7 @@ class BasePool(Generic[ConnectionType]):
         """`!True` if the pool is closed."""
         return self._closed
 
-    def _check_size(
-        self, min_size: int, max_size: Optional[int]
-    ) -> Tuple[int, int]:
+    def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
         if max_size is None:
             max_size = min_size
 
@@ -129,9 +125,7 @@ class BasePool(Generic[ConnectionType]):
         if max_size < min_size:
             raise ValueError("max_size must be greater or equal than min_size")
         if min_size == max_size == 0:
-            raise ValueError(
-                "if min_size is 0 max_size must be greater or than 0"
-            )
+            raise ValueError("if min_size is 0 max_size must be greater or than 0")
 
         return min_size, max_size
 
@@ -203,9 +197,7 @@ class BasePool(Generic[ConnectionType]):
 
         Add some randomness to avoid mass reconnection.
         """
-        conn._expire_at = monotonic() + self._jitter(
-            self.max_lifetime, -0.05, 0.0
-        )
+        conn._expire_at = monotonic() + self._jitter(self.max_lifetime, -0.05, 0.0)
 
 
 class ConnectionAttempt:
index 436485d7e8f381f464d4245f956ceb1894631b8b..c0a77c24672dd0830842b30342ef2912fdb2a767 100644 (file)
@@ -26,9 +26,7 @@ class _BaseNullConnectionPool:
             conninfo, *args, min_size=min_size, **kwargs
         )
 
-    def _check_size(
-        self, min_size: int, max_size: Optional[int]
-    ) -> Tuple[int, int]:
+    def _check_size(self, min_size: int, max_size: Optional[int]) -> Tuple[int, int]:
         if max_size is None:
             max_size = min_size
 
@@ -71,9 +69,7 @@ class NullConnectionPool(_BaseNullConnectionPool, ConnectionPool):
         self.run_task(AddConnection(self))
         if not self._pool_full_event.wait(timeout):
             self.close()  # stop all the threads
-            raise PoolTimeout(
-                f"pool initialization incomplete after {timeout} sec"
-            )
+            raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
 
         with self._lock:
             assert self._pool_full_event
index 3fa8a146240e7e78858ecd1ca8df77d112767103..ae9d207bca6ef229e9bf1f0854740d3d30d2b620 100644 (file)
@@ -62,9 +62,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
             )
         return conn
 
-    async def _maybe_close_connection(
-        self, conn: AsyncConnection[Any]
-    ) -> bool:
+    async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
         # Close the connection if no client is waiting for it, or if the pool
         # is closed. For extra refcare remove the pool reference from it.
         # Maintain the stats.
@@ -79,9 +77,7 @@ class AsyncNullConnectionPool(_BaseNullConnectionPool, AsyncConnectionPool):
             self._nconns -= 1
             return True
 
-    async def resize(
-        self, min_size: int, max_size: Optional[int] = None
-    ) -> None:
+    async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
         min_size, max_size = self._check_size(min_size, max_size)
 
         logger.info(
index 664ca123f318418112095c405b469855d786f978..d29092f15d2bd8d4ab2e5fbf47fa64d9d66fb586 100644 (file)
@@ -90,9 +90,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         logger.info("waiting for pool %r initialization", self.name)
         if not self._pool_full_event.wait(timeout):
             self.close()  # stop all the threads
-            raise PoolTimeout(
-                f"pool initialization incomplete after {timeout} sec"
-            )
+            raise PoolTimeout(f"pool initialization incomplete after {timeout} sec")
 
         with self._lock:
             assert self._pool_full_event
@@ -101,9 +99,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         logger.info("pool %r is ready to use", self.name)
 
     @contextmanager
-    def connection(
-        self, timeout: Optional[float] = None
-    ) -> Iterator[Connection[Any]]:
+    def connection(self, timeout: Optional[float] = None) -> Iterator[Connection[Any]]:
         """Context manager to obtain a connection from the pool.
 
         Return the connection immediately if available, otherwise wait up to
@@ -517,9 +513,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
         """
         now = monotonic()
         if not attempt:
-            attempt = ConnectionAttempt(
-                reconnect_timeout=self.reconnect_timeout
-            )
+            attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
 
         try:
             conn = self._connect()
@@ -548,9 +542,7 @@ class ConnectionPool(BasePool[Connection[Any]]):
             with self._lock:
                 if self._nconns < self._max_size and self._waiting:
                     self._nconns += 1
-                    logger.info(
-                        "growing pool %r to %s", self.name, self._nconns
-                    )
+                    logger.info("growing pool %r to %s", self.name, self._nconns)
                     self.run_task(AddConnection(self, growing=True))
                 else:
                     self._growing = False
@@ -748,9 +740,7 @@ class MaintenanceTask(ABC):
 
     def __init__(self, pool: "ConnectionPool"):
         self.pool = ref(pool)
-        logger.debug(
-            "task created in %s: %s", threading.current_thread().name, self
-        )
+        logger.debug("task created in %s: %s", threading.current_thread().name, self)
 
     def __repr__(self) -> str:
         pool = self.pool()
@@ -768,9 +758,7 @@ class MaintenanceTask(ABC):
             # Pool is no more working. Quietly discard the operation.
             return
 
-        logger.debug(
-            "task running in %s: %s", threading.current_thread().name, self
-        )
+        logger.debug("task running in %s: %s", threading.current_thread().name, self)
         self._run(pool)
 
     def tick(self) -> None:
index bb7be930ad6089026b673e909184f95a6cd6e95c..b8b6af93734edd5f1e6cd75d737e9606190ce8a7 100644 (file)
@@ -33,12 +33,8 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         *,
         open: bool = True,
         connection_class: Type[AsyncConnection[Any]] = AsyncConnection,
-        configure: Optional[
-            Callable[[AsyncConnection[Any]], Awaitable[None]]
-        ] = None,
-        reset: Optional[
-            Callable[[AsyncConnection[Any]], Awaitable[None]]
-        ] = None,
+        configure: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
+        reset: Optional[Callable[[AsyncConnection[Any]], Awaitable[None]]] = None,
         **kwargs: Any,
     ):
         self.connection_class = connection_class
@@ -99,9 +95,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             self._stats[self._USAGE_MS] += int(1000.0 * (t1 - t0))
             await self.putconn(conn)
 
-    async def getconn(
-        self, timeout: Optional[float] = None
-    ) -> AsyncConnection[Any]:
+    async def getconn(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
         logger.info("connection requested from %r", self.name)
         self._stats[self._REQUESTS_NUM] += 1
 
@@ -180,9 +174,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         else:
             await self._return_connection(conn)
 
-    async def _maybe_close_connection(
-        self, conn: AsyncConnection[Any]
-    ) -> bool:
+    async def _maybe_close_connection(self, conn: AsyncConnection[Any]) -> bool:
         # If the pool is closed just close the connection instead of returning
         # it to the pool. For extra refcare remove the pool reference from it.
         if not self._closed:
@@ -299,9 +291,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
     ) -> None:
         await self.close()
 
-    async def resize(
-        self, min_size: int, max_size: Optional[int] = None
-    ) -> None:
+    async def resize(self, min_size: int, max_size: Optional[int] = None) -> None:
         min_size, max_size = self._check_size(min_size, max_size)
 
         ngrow = max(0, min_size - self._min_size)
@@ -348,9 +338,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         """Run a maintenance task in a worker."""
         self._tasks.put_nowait(task)
 
-    async def schedule_task(
-        self, task: "MaintenanceTask", delay: float
-    ) -> None:
+    async def schedule_task(self, task: "MaintenanceTask", delay: float) -> None:
         """Run a maintenance task in a worker in the future."""
         await self._sched.enter(delay, task.tick)
 
@@ -381,9 +369,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
                     ex,
                 )
 
-    async def _connect(
-        self, timeout: Optional[float] = None
-    ) -> AsyncConnection[Any]:
+    async def _connect(self, timeout: Optional[float] = None) -> AsyncConnection[Any]:
         self._stats[self._CONNECTIONS_NUM] += 1
         kwargs = self.kwargs
         if timeout:
@@ -428,9 +414,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
         """
         now = monotonic()
         if not attempt:
-            attempt = ConnectionAttempt(
-                reconnect_timeout=self.reconnect_timeout
-            )
+            attempt = ConnectionAttempt(reconnect_timeout=self.reconnect_timeout)
 
         try:
             conn = await self._connect()
@@ -459,9 +443,7 @@ class AsyncConnectionPool(BasePool[AsyncConnection[Any]]):
             async with self._lock:
                 if self._nconns < self._max_size and self._waiting:
                     self._nconns += 1
-                    logger.info(
-                        "growing pool %r to %s", self.name, self._nconns
-                    )
+                    logger.info("growing pool %r to %s", self.name, self._nconns)
                     self.run_task(AddConnection(self, growing=True))
                 else:
                     self._growing = False
@@ -723,9 +705,7 @@ class AddConnection(MaintenanceTask):
 class ReturnConnection(MaintenanceTask):
     """Clean up and return a connection to the pool."""
 
-    def __init__(
-        self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"
-    ):
+    def __init__(self, pool: "AsyncConnectionPool", conn: "AsyncConnection[Any]"):
         super().__init__(pool)
         self.conn = conn
 
index 6c5b4bea52baf4c6d843330a429a38259dcfe18b..ca26007324a50a6267f33f28bcc1b3fb0defcbbf 100644 (file)
@@ -59,9 +59,7 @@ class Scheduler:
         time = monotonic() + delay
         return self.enterabs(time, action)
 
-    def enterabs(
-        self, time: float, action: Optional[Callable[[], Any]]
-    ) -> Task:
+    def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
         """Enter a new task in the queue at an absolute time.
 
         Schedule a `!None` to stop the execution.
@@ -119,9 +117,7 @@ class AsyncScheduler:
 
     EMPTY_QUEUE_TIMEOUT = 600.0
 
-    async def enter(
-        self, delay: float, action: Optional[Callable[[], Any]]
-    ) -> Task:
+    async def enter(self, delay: float, action: Optional[Callable[[], Any]]) -> Task:
         """Enter a new task in the queue delayed in the future.
 
         Schedule a `!None` to stop the execution.
@@ -129,9 +125,7 @@ class AsyncScheduler:
         time = monotonic() + delay
         return await self.enterabs(time, action)
 
-    async def enterabs(
-        self, time: float, action: Optional[Callable[[], Any]]
-    ) -> Task:
+    async def enterabs(self, time: float, action: Optional[Callable[[], Any]]) -> Task:
         """Enter a new task in the queue at an absolute time.
 
         Schedule a `!None` to stop the execution.
index 75809f6a2a82e22ac982434c57aa8e85f9ebe9a4..2ae629c2d4d3d195def647f96d7bce3cdbab83b6 100644 (file)
@@ -1,3 +1,3 @@
 [flake8]
-max-line-length = 85
+max-line-length = 88
 ignore = W503, E203
index 1e47972a0ea325176c0a4d448d94493cc2e9b2be..72bb520dc5b1dc18d2de3a47eb5335ec9172bc81 100644 (file)
@@ -2,9 +2,6 @@
 requires = ["setuptools>=49.2.0", "wheel>=0.37"]
 build-backend = "setuptools.build_meta"
 
-[tool.black]
-line-length = 79
-
 [tool.pytest.ini_options]
 asyncio_mode = "auto"
 filterwarnings = [
index 1797177def6569208e4e033e229cae5dd2c3c071..b5c932ad94c4b3e83e9f770d39dbaf29515268bd 100644 (file)
@@ -89,9 +89,7 @@ def maybe_trace(pgconn, tracefile, function):
 
     pgconn.trace(tracefile.fileno())
     try:
-        pgconn.set_trace_flags(
-            pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE
-        )
+        pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
     except psycopg.NotSupportedError:
         pass
     try:
@@ -107,9 +105,7 @@ def pgconn(dsn, request, tracefile):
 
     conn = pq.PGconn.connect(dsn.encode())
     if conn.status != pq.ConnStatus.OK:
-        pytest.fail(
-            f"bad connection: {conn.error_message.decode('utf8', 'replace')}"
-        )
+        pytest.fail(f"bad connection: {conn.error_message.decode('utf8', 'replace')}")
     msg = check_connection_version(conn.server_version, request.function)
     if msg:
         conn.finish()
@@ -222,9 +218,7 @@ def hstore(svcconn):
         pytest.skip(str(e))
 
 
-def warm_up_database(
-    dsn: str, __first_connection: List[bool] = [True]
-) -> None:
+def warm_up_database(dsn: str, __first_connection: List[bool] = [True]) -> None:
     """Connect to the database before returning a connection.
 
     In the CI sometimes, the first test fails with a timeout, probably because
index aebd553e5f94f30d0d68c0d58b01e24bb1c6a399..7befd5ebd8b4e7f0dd483e4f3efd81e868acdc22 100644 (file)
@@ -87,10 +87,7 @@ class Faker:
 
     @property
     def types_names(self):
-        types = [
-            t.as_string(self.conn).replace('"', "")
-            for t in self.types_names_sql
-        ]
+        types = [t.as_string(self.conn).replace('"', "") for t in self.types_names_sql]
         return types
 
     def _get_type_name(self, tx, schema, value):
@@ -118,16 +115,13 @@ class Faker:
             field_values.append(sql.SQL("{} {}").format(name, type))
 
         fields = sql.SQL(", ").join(field_values)
-        return sql.SQL(
-            "create table {table} (id serial primary key, {fields})"
-        ).format(table=self.table_name, fields=fields)
+        return sql.SQL("create table {table} (id serial primary key, {fields})").format(
+            table=self.table_name, fields=fields
+        )
 
     @property
     def insert_stmt(self):
-        phs = [
-            sql.Placeholder(format=self.format)
-            for i in range(len(self.schema))
-        ]
+        phs = [sql.Placeholder(format=self.format) for i in range(len(self.schema))]
         return sql.SQL("insert into {} ({}) values ({})").format(
             self.table_name,
             sql.SQL(", ").join(self.fields_names),
@@ -137,9 +131,7 @@ class Faker:
     @property
     def select_stmt(self):
         fields = sql.SQL(", ").join(self.fields_names)
-        return sql.SQL("select {} from {} order by id").format(
-            fields, self.table_name
-        )
+        return sql.SQL("select {} from {} order by id").format(fields, self.table_name)
 
     @contextmanager
     def find_insert_problem(self, conn):
@@ -216,8 +208,7 @@ class Faker:
             return tuple(self.example(spec) for spec in self.schema)
         else:
             return tuple(
-                self.make(spec) if random() > nulls else None
-                for spec in self.schema
+                self.make(spec) if random() > nulls else None for spec in self.schema
             )
 
     def assert_record(self, got, want):
@@ -233,10 +224,7 @@ class Faker:
         for cls in dumpers.keys():
             if isinstance(cls, str):
                 cls = deep_import(cls)
-            if (
-                issubclass(cls, Multirange)
-                and self.conn.info.server_version < 140000
-            ):
+            if issubclass(cls, Multirange) and self.conn.info.server_version < 140000:
                 continue
 
             rv.add(cls)
@@ -276,9 +264,7 @@ class Faker:
             self._makers[cls] = meth
             return meth
         else:
-            raise NotImplementedError(
-                f"cannot make fake objects of class {cls}"
-            )
+            raise NotImplementedError(f"cannot make fake objects of class {cls}")
 
     def get_matcher(self, spec):
         cls = spec if isinstance(spec, type) else spec[0]
@@ -393,9 +379,7 @@ class Faker:
                 else f"{choice('-+')}0.{randrange(1 << 22)}e{randrange(-37,38)}"
             )
         else:
-            return choice(
-                (0.0, -0.0, float("-inf"), float("inf"), float("nan"))
-            )
+            return choice((0.0, -0.0, float("-inf"), float("inf"), float("nan")))
 
     def match_float(self, spec, got, want, approx=False, rel=None):
         if got is not None and isnan(got):
@@ -476,9 +460,7 @@ class Faker:
     def make_JsonFloat(self, spec):
         # A float limited to what json accepts
         # this exponent should generate no inf
-        return float(
-            f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}"
-        )
+        return float(f"{choice('-+')}0.{randrange(1 << 20)}e{randrange(-15,15)}")
 
     def schema_list(self, cls):
         while True:
@@ -566,9 +548,7 @@ class Faker:
         return spec[0](sorted(out))
 
     def example_Multirange(self, spec):
-        return self.make_Multirange(
-            spec, length=1, empty_chance=0, no_bound_chance=0
-        )
+        return self.make_Multirange(spec, length=1, empty_chance=0, no_bound_chance=0)
 
     def make_Int4Multirange(self, spec):
         return self.make_Multirange((spec, Int4))
@@ -714,13 +694,9 @@ class Faker:
 
         if unit is not None:
             if want.lower is not None and not want.lower_inc:
-                want = type(want)(
-                    want.lower + unit, want.upper, "[" + want.bounds[1]
-                )
+                want = type(want)(want.lower + unit, want.upper, "[" + want.bounds[1])
             if want.upper_inc:
-                want = type(want)(
-                    want.lower, want.upper + unit, want.bounds[0] + ")"
-                )
+                want = type(want)(want.lower, want.upper + unit, want.bounds[0] + ")")
 
         if spec[1] == (dt.datetime, True) and not want.isempty:
             # work around https://bugs.python.org/issue45347
index 01e384f50e5f0c4007c091f68172d3b35e55f437..7101aac3df3406ddfb3a800bdc1d14870f6213d8 100644 (file)
@@ -87,9 +87,7 @@ class Proxy:
             raise ValueError("pproxy program not found")
         cmdline = [pproxy, "--reuse"]
         cmdline.extend(["-l", f"tunnel://:{self.client_port}"])
-        cmdline.extend(
-            ["-r", f"tunnel://{self.server_host}:{self.server_port}"]
-        )
+        cmdline.extend(["-r", f"tunnel://{self.server_host}:{self.server_port}"])
 
         self.proc = sp.Popen(cmdline, stdout=sp.DEVNULL)
         logging.info("proxy started")
index ff52c11f9eed118aec80bc568788b28c15bcae7d..dc703e5df1f5a0bdbddacde59f3e1eb2d22249af 100644 (file)
@@ -41,9 +41,7 @@ class Tpc:
         self.conn = conn
 
     def check_tpc(self):
-        val = int(
-            self.conn.execute("show max_prepared_transactions").fetchone()[0]
-        )
+        val = int(self.conn.execute("show max_prepared_transactions").fetchone()[0])
         if not val:
             pytest.skip("prepared transactions disabled in the database")
 
index 13768c0926f04b545938a6754d3b9eef3fb8e6a9..12e4f3941bfec22928fa997ecb768fcdf3bab50b 100644 (file)
@@ -2,9 +2,7 @@ import pytest
 
 
 def pytest_configure(config):
-    config.addinivalue_line(
-        "markers", "pool: test related to the psycopg_pool package"
-    )
+    config.addinivalue_line("markers", "pool: test related to the psycopg_pool package")
 
 
 def pytest_collection_modifyitems(items):
index 24d4ca4de79bc2e1fff56fd4d47ae10d0b907c3b..f7fe782cd8a4f461fea3f07e64a2e0a23187a43b 100644 (file)
@@ -33,9 +33,7 @@ def test_min_size_max_size(dsn):
         assert p.max_size == 2
 
 
-@pytest.mark.parametrize(
-    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
 def test_bad_size(dsn, min_size, max_size):
     with pytest.raises(ValueError):
         NullConnectionPool(min_size=min_size, max_size=max_size)
@@ -275,9 +273,7 @@ def test_reset_broken(dsn, caplog):
 @pytest.mark.slow
 @pytest.mark.skipif("ver(psycopg.__version__) < ver('3.0.8')")
 def test_no_queue_timeout(deaf_port):
-    with NullConnectionPool(
-        kwargs={"host": "localhost", "port": deaf_port}
-    ) as p:
+    with NullConnectionPool(kwargs={"host": "localhost", "port": deaf_port}) as p:
         with pytest.raises(PoolTimeout):
             with p.connection(timeout=1):
                 pass
@@ -536,9 +532,7 @@ def test_active_close(dsn, caplog):
         ensure_waiting(p)
 
         pids.append(conn.info.backend_pid)
-        conn.pgconn.exec_(
-            b"copy (select * from generate_series(1, 10)) to stdout"
-        )
+        conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
         assert conn.info.transaction_status == TransactionStatus.ACTIVE
         p.putconn(conn)
         t.join()
@@ -764,9 +758,7 @@ def test_reopen(dsn):
         p.open()
 
 
-@pytest.mark.parametrize(
-    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
 def test_bad_resize(dsn, min_size, max_size):
     with NullConnectionPool() as p:
         with pytest.raises(ValueError):
index 824ed21367bfa2db73104b5a7bb305ca89691786..4213accae09d59b0ce16f6c60e81b8d46adc062a 100644 (file)
@@ -42,9 +42,7 @@ async def test_min_size_max_size(dsn):
         assert p.max_size == 2
 
 
-@pytest.mark.parametrize(
-    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
 async def test_bad_size(dsn, min_size, max_size):
     with pytest.raises(ValueError):
         AsyncNullConnectionPool(min_size=min_size, max_size=max_size)
@@ -108,9 +106,7 @@ async def test_wait_closed(dsn):
 @pytest.mark.slow
 async def test_setup_no_timeout(dsn, proxy):
     with pytest.raises(PoolTimeout):
-        async with AsyncNullConnectionPool(
-            proxy.client_dsn, num_workers=1
-        ) as p:
+        async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
             await p.wait(0.2)
 
     async with AsyncNullConnectionPool(proxy.client_dsn, num_workers=1) as p:
@@ -297,9 +293,7 @@ async def test_queue(dsn):
     async def worker(n):
         t0 = time()
         async with p.connection() as conn:
-            cur = await conn.execute(
-                "select pg_backend_pid() from pg_sleep(0.2)"
-            )
+            cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
             (pid,) = await cur.fetchone()  # type: ignore[misc]
         t1 = time()
         results.append((n, t1 - t0, pid))
@@ -358,9 +352,7 @@ async def test_queue_timeout(dsn):
         t0 = time()
         try:
             async with p.connection() as conn:
-                cur = await conn.execute(
-                    "select pg_backend_pid() from pg_sleep(0.2)"
-                )
+                cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
                 (pid,) = await cur.fetchone()  # type: ignore[misc]
         except PoolTimeout as e:
             t1 = time()
@@ -414,9 +406,7 @@ async def test_queue_timeout_override(dsn):
         timeout = 0.25 if n == 3 else None
         try:
             async with p.connection(timeout=timeout) as conn:
-                cur = await conn.execute(
-                    "select pg_backend_pid() from pg_sleep(0.2)"
-                )
+                cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
                 (pid,) = await cur.fetchone()  # type: ignore[misc]
         except PoolTimeout as e:
             t1 = time()
@@ -527,9 +517,7 @@ async def test_active_close(dsn, caplog):
         await ensure_waiting(p)
 
         pids.append(conn.info.backend_pid)
-        conn.pgconn.exec_(
-            b"copy (select * from generate_series(1, 10)) to stdout"
-        )
+        conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
         assert conn.info.transaction_status == TransactionStatus.ACTIVE
         await p.putconn(conn)
         await asyncio.gather(t)
@@ -737,9 +725,7 @@ async def test_reopen(dsn):
         await p.open()
 
 
-@pytest.mark.parametrize(
-    "min_size, max_size", [(1, None), (-1, None), (0, -2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(1, None), (-1, None), (0, -2)])
 async def test_bad_resize(dsn, min_size, max_size):
     async with AsyncNullConnectionPool() as p:
         with pytest.raises(ValueError):
index 26525d0c53ff2d269dd9165a7d80ab243df4c15d..b3b67392e97eabb158544d3f34db038d6d8e1414 100644 (file)
@@ -43,9 +43,7 @@ def test_min_size_max_size(dsn, min_size, max_size):
         assert p.max_size == max_size if max_size is not None else min_size
 
 
-@pytest.mark.parametrize(
-    "min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
 def test_bad_size(dsn, min_size, max_size):
     with pytest.raises(ValueError):
         pool.ConnectionPool(min_size=min_size, max_size=max_size)
@@ -61,9 +59,7 @@ def test_connection_class(dsn):
 
 
 def test_kwargs(dsn):
-    with pool.ConnectionPool(
-        dsn, kwargs={"autocommit": True}, min_size=1
-    ) as p:
+    with pool.ConnectionPool(dsn, kwargs={"autocommit": True}, min_size=1) as p:
         with p.connection() as conn:
             assert conn.autocommit
 
@@ -149,9 +145,7 @@ def test_wait_closed(dsn):
 @pytest.mark.slow
 def test_setup_no_timeout(dsn, proxy):
     with pytest.raises(pool.PoolTimeout):
-        with pool.ConnectionPool(
-            proxy.client_dsn, min_size=1, num_workers=1
-        ) as p:
+        with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
             p.wait(0.2)
 
     with pool.ConnectionPool(proxy.client_dsn, min_size=1, num_workers=1) as p:
@@ -513,9 +507,7 @@ def test_active_close(dsn, caplog):
     with pool.ConnectionPool(dsn, min_size=1) as p:
         conn = p.getconn()
         pid = conn.info.backend_pid
-        conn.pgconn.exec_(
-            b"copy (select * from generate_series(1, 10)) to stdout"
-        )
+        conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
         assert conn.info.transaction_status == TransactionStatus.ACTIVE
         p.putconn(conn)
 
@@ -796,9 +788,7 @@ def test_grow(dsn, monkeypatch, min_size, want_times):
         t1 = time()
         results.append((n, t1 - t0))
 
-    with pool.ConnectionPool(
-        dsn, min_size=min_size, max_size=4, num_workers=3
-    ) as p:
+    with pool.ConnectionPool(dsn, min_size=min_size, max_size=4, num_workers=3) as p:
         p.wait(1.0)
         results: List[Tuple[int, float]] = []
 
index 7b663bea6f3ace2d5be7a0e9aac8dc9b5c92f9a1..79316106f36e5c25d7a0c2b9f2cf4455574f47e2 100644 (file)
@@ -29,16 +29,12 @@ async def test_defaults(dsn):
 
 @pytest.mark.parametrize("min_size, max_size", [(2, None), (0, 2), (2, 4)])
 async def test_min_size_max_size(dsn, min_size, max_size):
-    async with pool.AsyncConnectionPool(
-        dsn, min_size=min_size, max_size=max_size
-    ) as p:
+    async with pool.AsyncConnectionPool(dsn, min_size=min_size, max_size=max_size) as p:
         assert p.min_size == min_size
         assert p.max_size == max_size if max_size is not None else min_size
 
 
-@pytest.mark.parametrize(
-    "min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)]
-)
+@pytest.mark.parametrize("min_size, max_size", [(0, 0), (0, None), (-1, None), (4, 2)])
 async def test_bad_size(dsn, min_size, max_size):
     with pytest.raises(ValueError):
         pool.AsyncConnectionPool(min_size=min_size, max_size=max_size)
@@ -48,9 +44,7 @@ async def test_connection_class(dsn):
     class MyConn(psycopg.AsyncConnection[Any]):
         pass
 
-    async with pool.AsyncConnectionPool(
-        dsn, connection_class=MyConn, min_size=1
-    ) as p:
+    async with pool.AsyncConnectionPool(dsn, connection_class=MyConn, min_size=1) as p:
         async with p.connection() as conn:
             assert isinstance(conn, MyConn)
 
@@ -122,9 +116,7 @@ async def test_concurrent_filling(dsn, monkeypatch):
 async def test_wait_ready(dsn, monkeypatch):
     delay_connection(monkeypatch, 0.1)
     with pytest.raises(pool.PoolTimeout):
-        async with pool.AsyncConnectionPool(
-            dsn, min_size=4, num_workers=1
-        ) as p:
+        async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
             await p.wait(0.3)
 
     async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
@@ -171,9 +163,7 @@ async def test_configure(dsn):
         async with conn.transaction():
             await conn.execute("set default_transaction_read_only to on")
 
-    async with pool.AsyncConnectionPool(
-        dsn, min_size=1, configure=configure
-    ) as p:
+    async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
         await p.wait(timeout=1.0)
         async with p.connection() as conn:
             assert inits == 1
@@ -199,9 +189,7 @@ async def test_configure_badstate(dsn, caplog):
     async def configure(conn):
         await conn.execute("select 1")
 
-    async with pool.AsyncConnectionPool(
-        dsn, min_size=1, configure=configure
-    ) as p:
+    async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
         with pytest.raises(pool.PoolTimeout):
             await p.wait(timeout=0.5)
 
@@ -217,9 +205,7 @@ async def test_configure_broken(dsn, caplog):
         async with conn.transaction():
             await conn.execute("WAT")
 
-    async with pool.AsyncConnectionPool(
-        dsn, min_size=1, configure=configure
-    ) as p:
+    async with pool.AsyncConnectionPool(dsn, min_size=1, configure=configure) as p:
         with pytest.raises(pool.PoolTimeout):
             await p.wait(timeout=0.5)
 
@@ -303,9 +289,7 @@ async def test_queue(dsn):
     async def worker(n):
         t0 = time()
         async with p.connection() as conn:
-            cur = await conn.execute(
-                "select pg_backend_pid() from pg_sleep(0.2)"
-            )
+            cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
             (pid,) = await cur.fetchone()  # type: ignore[misc]
         t1 = time()
         results.append((n, t1 - t0, pid))
@@ -364,9 +348,7 @@ async def test_queue_timeout(dsn):
         t0 = time()
         try:
             async with p.connection() as conn:
-                cur = await conn.execute(
-                    "select pg_backend_pid() from pg_sleep(0.2)"
-                )
+                cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
                 (pid,) = await cur.fetchone()  # type: ignore[misc]
         except pool.PoolTimeout as e:
             t1 = time()
@@ -421,9 +403,7 @@ async def test_queue_timeout_override(dsn):
         timeout = 0.25 if n == 3 else None
         try:
             async with p.connection(timeout=timeout) as conn:
-                cur = await conn.execute(
-                    "select pg_backend_pid() from pg_sleep(0.2)"
-                )
+                cur = await conn.execute("select pg_backend_pid() from pg_sleep(0.2)")
                 (pid,) = await cur.fetchone()  # type: ignore[misc]
         except pool.PoolTimeout as e:
             t1 = time()
@@ -506,9 +486,7 @@ async def test_active_close(dsn, caplog):
     async with pool.AsyncConnectionPool(dsn, min_size=1) as p:
         conn = await p.getconn()
         pid = conn.info.backend_pid
-        conn.pgconn.exec_(
-            b"copy (select * from generate_series(1, 10)) to stdout"
-        )
+        conn.pgconn.exec_(b"copy (select * from generate_series(1, 10)) to stdout")
         assert conn.info.transaction_status == TransactionStatus.ACTIVE
         await p.putconn(conn)
 
@@ -706,9 +684,7 @@ async def test_open_no_op(dsn):
 async def test_open_wait(dsn, monkeypatch):
     delay_connection(monkeypatch, 0.1)
     with pytest.raises(pool.PoolTimeout):
-        p = pool.AsyncConnectionPool(
-            dsn, min_size=4, num_workers=1, open=False
-        )
+        p = pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1, open=False)
         try:
             await p.open(wait=True, timeout=0.3)
         finally:
@@ -726,9 +702,7 @@ async def test_open_wait(dsn, monkeypatch):
 async def test_open_as_wait(dsn, monkeypatch):
     delay_connection(monkeypatch, 0.1)
     with pytest.raises(pool.PoolTimeout):
-        async with pool.AsyncConnectionPool(
-            dsn, min_size=4, num_workers=1
-        ) as p:
+        async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
             await p.open(wait=True, timeout=0.3)
 
     async with pool.AsyncConnectionPool(dsn, min_size=4, num_workers=1) as p:
@@ -801,9 +775,7 @@ async def test_shrink(dsn, monkeypatch):
         async with p.connection() as conn:
             await conn.execute("select pg_sleep(0.1)")
 
-    async with pool.AsyncConnectionPool(
-        dsn, min_size=2, max_size=4, max_idle=0.2
-    ) as p:
+    async with pool.AsyncConnectionPool(dsn, min_size=2, max_size=4, max_idle=0.2) as p:
         await p.wait(5.0)
         assert p.max_idle == 0.2
 
@@ -953,9 +925,7 @@ async def test_bad_resize(dsn, min_size, max_size):
 
 
 def test_jitter():
-    rnds = [
-        pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)
-    ]
+    rnds = [pool.AsyncConnectionPool._jitter(30, -0.1, +0.2) for i in range(100)]
     assert 27 <= min(rnds) <= 28
     assert 35 < max(rnds) < 36
 
@@ -963,9 +933,7 @@ def test_jitter():
 @pytest.mark.slow
 @pytest.mark.timing
 async def test_max_lifetime(dsn):
-    async with pool.AsyncConnectionPool(
-        dsn, min_size=1, max_lifetime=0.2
-    ) as p:
+    async with pool.AsyncConnectionPool(dsn, min_size=1, max_lifetime=0.2) as p:
         await asyncio.sleep(0.1)
         pids = []
         for i in range(5):
index 3f8eb35e89d387086380cbbe50aec21b885b47dc..0d0e4aa7285578161b60b21a83ebd00ff4e4bdfc 100644 (file)
@@ -15,8 +15,7 @@ def test_send_query(pgconn):
 
     # Long query to make sure we have to wait on send
     pgconn.send_query(
-        b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
-        % (b"x" * 1_000_000)
+        b"/* %s */ select pg_sleep(0.01); select 1 as foo;" % (b"x" * 1_000_000)
     )
 
     # send loop
@@ -64,8 +63,7 @@ def test_send_query(pgconn):
 def test_send_query_compact_test(pgconn):
     # Like the above test but use psycopg facilities for compactness
     pgconn.send_query(
-        b"/* %s */ select pg_sleep(0.01); select 1 as foo;"
-        % (b"x" * 1_000_000)
+        b"/* %s */ select pg_sleep(0.01); select 1 as foo;" % (b"x" * 1_000_000)
     )
     results = execute_wait(pgconn)
 
@@ -157,18 +155,14 @@ def test_send_prepared_binary_in(pgconn):
         pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
 
 
-@pytest.mark.parametrize(
-    "fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]
-)
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
 def test_send_prepared_binary_out(pgconn, fmt, out):
     val = b"foo\00bar"
     pgconn.send_prepare(b"", b"select $1::bytea")
     (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
-    pgconn.send_query_prepared(
-        b"", [val], param_formats=[1], result_format=fmt
-    )
+    pgconn.send_query_prepared(b"", [val], param_formats=[1], result_format=fmt)
     (res,) = execute_wait(pgconn)
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == out
index 6d871247a4f162d56eeb86f377962037083a8545..d45416d27a5d9f184ea86b10b75db511a7e8b8de 100644 (file)
@@ -23,8 +23,7 @@ ff ff
 """
 
 sample_binary_rows = [
-    bytes.fromhex("".join(row.split()))
-    for row in sample_binary_value.split("\n\n")
+    bytes.fromhex("".join(row.split())) for row in sample_binary_value.split("\n\n")
 ]
 
 sample_binary = b"".join(sample_binary_rows)
index 6262d3f7178c1c2d3d38c50c3759380515f774d7..7db6248c69da9ecfd8fd86008baa104b722a23f3 100644 (file)
@@ -21,9 +21,7 @@ def test_escape_literal(pgconn, data, want):
 
 @pytest.mark.parametrize("scs", ["on", "off"])
 def test_escape_literal_1char(pgconn, scs):
-    res = pgconn.exec_(
-        f"set standard_conforming_strings to {scs}".encode("ascii")
-    )
+    res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
     assert res.status == pq.ExecStatus.COMMAND_OK
     esc = pq.Escaping(pgconn)
     special = {b"'": b"''''", b"\\": b" E'\\\\'"}
@@ -62,9 +60,7 @@ def test_escape_identifier(pgconn, data, want):
 
 @pytest.mark.parametrize("scs", ["on", "off"])
 def test_escape_identifier_1char(pgconn, scs):
-    res = pgconn.exec_(
-        f"set standard_conforming_strings to {scs}".encode("ascii")
-    )
+    res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
     assert res.status == pq.ExecStatus.COMMAND_OK
     esc = pq.Escaping(pgconn)
     special = {b'"': b'""""', b"\\": b'"\\"'}
@@ -104,9 +100,7 @@ def test_escape_string(pgconn, data, want):
 @pytest.mark.parametrize("scs", ["on", "off"])
 def test_escape_string_1char(pgconn, scs):
     esc = pq.Escaping(pgconn)
-    res = pgconn.exec_(
-        f"set standard_conforming_strings to {scs}".encode("ascii")
-    )
+    res = pgconn.exec_(f"set standard_conforming_strings to {scs}".encode("ascii"))
     assert res.status == pq.ExecStatus.COMMAND_OK
     special = {b"'": b"''", b"\\": b"\\" if scs == "on" else b"\\\\"}
     for c in range(1, 128):
@@ -167,9 +161,7 @@ def test_escape_noconn(pgconn):
     data = bytes(range(256))
     esc = pq.Escaping()
     escdata = esc.escape_bytea(data)
-    res = pgconn.exec_params(
-        b"select '%s'::bytea" % escdata, [], result_format=1
-    )
+    res = pgconn.exec_params(b"select '%s'::bytea" % escdata, [], result_format=1)
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == data
 
index ce21be2f3e87c746fec5eb2e40859b35cadb53b5..136879317defa00f2cf730e0037e6597af27265e 100644 (file)
@@ -47,9 +47,7 @@ def test_exec_params_types(pgconn):
 
 
 def test_exec_params_nulls(pgconn):
-    res = pgconn.exec_params(
-        b"select $1::text, $2::text, $3::text", [b"hi", b"", None]
-    )
+    res = pgconn.exec_params(b"select $1::text, $2::text, $3::text", [b"hi", b"", None])
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == b"hi"
     assert res.get_value(0, 1) == b""
@@ -71,9 +69,7 @@ def test_exec_params_binary_in(pgconn):
         pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
 
 
-@pytest.mark.parametrize(
-    "fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]
-)
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
 def test_exec_params_binary_out(pgconn, fmt, out):
     val = b"foo\00bar"
     res = pgconn.exec_params(
@@ -119,17 +115,13 @@ def test_exec_prepared_binary_in(pgconn):
         pgconn.exec_params(b"select $1::bytea", [val], param_formats=[1, 1])
 
 
-@pytest.mark.parametrize(
-    "fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")]
-)
+@pytest.mark.parametrize("fmt, out", [(0, b"\\x666f6f00626172"), (1, b"foo\00bar")])
 def test_exec_prepared_binary_out(pgconn, fmt, out):
     val = b"foo\00bar"
     res = pgconn.prepare(b"", b"select $1::bytea")
     assert res.status == pq.ExecStatus.COMMAND_OK, res.error_message
 
-    res = pgconn.exec_prepared(
-        b"", [val], param_formats=[1], result_format=fmt
-    )
+    res = pgconn.exec_prepared(b"", [val], param_formats=[1], result_format=fmt)
     assert res.status == pq.ExecStatus.TUPLES_OK
     assert res.get_value(0, 0) == out
 
index fefab51bbf52f1f551894df74d25a9ef5f8e8751..dfd0a6a807733e97df393370120109effd83e783 100644 (file)
@@ -52,9 +52,7 @@ def test_connect_async(dsn):
 
 
 def test_connect_async_bad(dsn):
-    parsed_dsn = {
-        e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val
-    }
+    parsed_dsn = {e.keyword: e.val for e in pq.Conninfo.parse(dsn.encode()) if e.val}
     parsed_dsn[b"dbname"] = b"psycopg_test_not_for_real"
     dsn = b" ".join(b"%s='%s'" % item for item in parsed_dsn.items())
     conn = pq.PGconn.connect_start(dsn)
@@ -493,9 +491,7 @@ def test_trace(pgconn, tmp_path):
     tracef = tmp_path / "trace"
     with tracef.open("w") as f:
         pgconn.trace(f.fileno())
-        pgconn.set_trace_flags(
-            pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE
-        )
+        pgconn.set_trace_flags(pq.Trace.SUPPRESS_TIMESTAMPS | pq.Trace.REGRESS_MODE)
         pgconn.exec_(b"select 1")
         pgconn.untrace()
         pgconn.exec_(b"select 2")
index a40b54feb82109d70d35d519ad299616a12addcc..40ca336886e11ed5ea19b36fa5da14b463cb86f2 100644 (file)
@@ -52,9 +52,7 @@ def test_error_message(pgconn):
 
 def test_error_field(pgconn):
     res = pgconn.exec_(b"select wat")
-    assert (
-        res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED) == b"ERROR"
-    )
+    assert res.error_field(pq.DiagnosticField.SEVERITY_NONLOCALIZED) == b"ERROR"
     assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"42703"
     assert b"wat" in res.error_field(pq.DiagnosticField.MESSAGE_PRIMARY)
     res.clear()
@@ -63,9 +61,7 @@ def test_error_field(pgconn):
 
 @pytest.mark.parametrize("n", range(4))
 def test_ntuples(pgconn, n):
-    res = pgconn.exec_params(
-        b"select generate_series(1, $1)", [str(n).encode("ascii")]
-    )
+    res = pgconn.exec_params(b"select generate_series(1, $1)", [str(n).encode("ascii")])
     assert res.ntuples == n
     res.clear()
     assert res.ntuples == 0
index c73f67bcc8baddd04fe0553d4d7670036d6dc1a4..855801b770039e2d94bf4048c06e7c80951862f3 100644 (file)
@@ -23,9 +23,7 @@ def test_work_in_progress(pgconn):
     assert pgconn.pipeline_status == pq.PipelineStatus.OFF
     pgconn.enter_pipeline_mode()
     pgconn.send_query_params(b"select $1", [b"1"])
-    with pytest.raises(
-        psycopg.OperationalError, match="cannot exit pipeline mode"
-    ):
+    with pytest.raises(psycopg.OperationalError, match="cannot exit pipeline mode"):
         pgconn.exit_pipeline_mode()
 
 
index 60f676a52ed7608578984161645638c88a81cc02..a49f11685b069c33d9f7958e3513308c312d75e0 100644 (file)
@@ -28,15 +28,10 @@ def main() -> None:
     cur = cnn.cursor()
 
     if test == "copy":
-        with cur.copy(
-            f"copy testdec from stdin (format {format.name})"
-        ) as copy:
+        with cur.copy(f"copy testdec from stdin (format {format.name})") as copy:
             for j in range(nrows):
                 copy.write_row(
-                    [
-                        Decimal(randrange(10000000000)) / 100
-                        for i in range(ncols)
-                    ]
+                    [Decimal(randrange(10000000000)) / 100 for i in range(ncols)]
                 )
 
     elif test == "insert":
index 6de342970a44fc8b63c46263859d96e2224781c7..2c9cc164a4c903a9cb846b077246ce0670767cde 100644 (file)
@@ -44,9 +44,7 @@ def main() -> None:
         # Create and start all the threads: they will get stuck on the event
         ev = threading.Event()
         threads = [
-            threading.Thread(
-                target=worker, args=(pool, 0.002, ev), daemon=True
-            )
+            threading.Thread(target=worker, args=(pool, 0.002, ev), daemon=True)
             for i in range(opt.num_clients)
         ]
         for t in threads:
@@ -90,9 +88,7 @@ class Measurer:
         self.measures = []
 
     def start(self, interval):
-        self.worker = threading.Thread(
-            target=self._run, args=(interval,), daemon=True
-        )
+        self.worker = threading.Thread(target=self._run, args=(interval,), daemon=True)
         self.worker.start()
 
     def stop(self):
@@ -125,9 +121,7 @@ def parse_cmdline():
     from argparse import ArgumentParser
 
     parser = ArgumentParser(description=__doc__)
-    parser.add_argument(
-        "--dsn", default="", help="connection string to the database"
-    )
+    parser.add_argument("--dsn", default="", help="connection string to the database")
     parser.add_argument(
         "--min_size",
         default=5,
index 863ceef59a8d4662e150b077b73d90e0571f082b..c4ade8ac47e6ba6a2cadc5353f15d27585b20491 100644 (file)
@@ -57,9 +57,7 @@ def test_register_dumper_by_class(conn):
 def test_register_dumper_by_class_name(conn):
     dumper = make_dumper("x")
     assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is not dumper
-    conn.adapters.register_dumper(
-        f"{MyStr.__module__}.{MyStr.__qualname__}", dumper
-    )
+    conn.adapters.register_dumper(f"{MyStr.__module__}.{MyStr.__qualname__}", dumper)
     assert conn.adapters.get_dumper(MyStr, PyFormat.TEXT) is dumper
 
 
@@ -118,9 +116,7 @@ def test_dump_subclass(conn):
         pass
 
     cur = conn.cursor()
-    cur.execute(
-        "select %s::text, %b::text", [MyString("hello"), MyString("world")]
-    )
+    cur.execute("select %s::text, %b::text", [MyString("hello"), MyString("world")])
     assert cur.fetchone() == ("hello", "world")
 
 
@@ -352,9 +348,7 @@ def test_last_dumper_registered_ctx(conn):
 def test_none_type_argument(conn, fmt_in):
     cur = conn.cursor()
     cur.execute("create table none_args (id serial primary key, num integer)")
-    cur.execute(
-        "insert into none_args (num) values (%s) returning id", (None,)
-    )
+    cur.execute("insert into none_args (num) values (%s) returning id", (None,))
     assert cur.fetchone()[0]
 
 
@@ -375,9 +369,7 @@ def test_return_untyped(conn, fmt_in):
     else:
         # Binary types cannot be passed as unknown oids.
         with pytest.raises(e.DatatypeMismatch):
-            cur.execute(
-                f"insert into testjson (data) values (%{fmt_in})", ["{}"]
-            )
+            cur.execute(f"insert into testjson (data) values (%{fmt_in})", ["{}"])
 
 
 @pytest.mark.parametrize("fmt_in", PyFormat.AUTO)
index a4b0d4e746236e916043835bd0796811ee1159c0..643c1f026a384a1d898ba4c4884cb91a496617b5 100644 (file)
@@ -183,9 +183,7 @@ def test_cancel(conn):
 def test_identify_closure(dsn):
     def closer():
         time.sleep(0.2)
-        conn2.execute(
-            "select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]
-        )
+        conn2.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
 
     conn = psycopg.connect(dsn)
     conn2 = psycopg.connect(dsn)
index a2710f1d73d5114f0938d63213defebb501d1835..6a1df9684f1d01581b17e536039d31a12765ae08 100644 (file)
@@ -16,9 +16,7 @@ async def test_commit_concurrency(aconn):
     # Because of a bad status check, we commit even when a commit is already on
     # its way. We can detect this condition by the warnings.
     notices = Queue()  # type: ignore[var-annotated]
-    aconn.add_notice_handler(
-        lambda diag: notices.put_nowait(diag.message_primary)
-    )
+    aconn.add_notice_handler(lambda diag: notices.put_nowait(diag.message_primary))
     stop = False
 
     async def committer():
index 6fcd62aa4036411312d5691f68990a97c3426306..4013a16a02a4ed39516705584422ce5c54f10953 100644 (file)
@@ -74,9 +74,7 @@ def test_close(conn):
 
 def test_broken(conn):
     with pytest.raises(psycopg.OperationalError):
-        conn.execute(
-            "select pg_terminate_backend(%s)", [conn.pgconn.backend_pid]
-        )
+        conn.execute("select pg_terminate_backend(%s)", [conn.pgconn.backend_pid])
     assert conn.closed
     assert conn.broken
     conn.close()
@@ -186,9 +184,7 @@ def test_context_active_rollback_no_clobber(dsn, caplog):
 
     with pytest.raises(ZeroDivisionError):
         with psycopg.connect(dsn) as conn:
-            conn.pgconn.exec_(
-                b"copy (select generate_series(1, 10)) to stdout"
-            )
+            conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
             status = conn.info.transaction_status
             assert status == conn.TransactionStatus.ACTIVE
             1 / 0
@@ -396,15 +392,11 @@ def test_notice_handlers(conn, caplog):
     conn.add_notice_handler(cb1)
     conn.add_notice_handler(cb2)
     conn.add_notice_handler("the wrong thing")
-    conn.add_notice_handler(
-        lambda diag: severities.append(diag.severity_nonlocalized)
-    )
+    conn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
 
     conn.pgconn.exec_(b"set client_min_messages to notice")
     cur = conn.cursor()
-    cur.execute(
-        "do $$begin raise notice 'hello notice'; end$$ language plpgsql"
-    )
+    cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
     assert messages == ["hello notice"]
     assert severities == ["NOTICE"]
 
@@ -418,9 +410,7 @@ def test_notice_handlers(conn, caplog):
 
     conn.remove_notice_handler(cb1)
     conn.remove_notice_handler("the wrong thing")
-    cur.execute(
-        "do $$begin raise warning 'hello warning'; end$$ language plpgsql"
-    )
+    cur.execute("do $$begin raise warning 'hello warning'; end$$ language plpgsql")
     assert len(caplog.records) == 3
     assert messages == ["hello notice"]
     assert severities == ["NOTICE", "WARNING"]
index a0606af7473f0bf59a97c836054c782dab4bcbad..2d471c7f9fe62b68d9019bf7140713a3e1d375ae 100644 (file)
@@ -185,9 +185,7 @@ async def test_context_active_rollback_no_clobber(dsn, caplog):
 
     with pytest.raises(ZeroDivisionError):
         async with await psycopg.AsyncConnection.connect(dsn) as conn:
-            conn.pgconn.exec_(
-                b"copy (select generate_series(1, 10)) to stdout"
-            )
+            conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
             status = conn.info.transaction_status
             assert status == conn.TransactionStatus.ACTIVE
             1 / 0
@@ -400,15 +398,11 @@ async def test_notice_handlers(aconn, caplog):
     aconn.add_notice_handler(cb1)
     aconn.add_notice_handler(cb2)
     aconn.add_notice_handler("the wrong thing")
-    aconn.add_notice_handler(
-        lambda diag: severities.append(diag.severity_nonlocalized)
-    )
+    aconn.add_notice_handler(lambda diag: severities.append(diag.severity_nonlocalized))
 
     aconn.pgconn.exec_(b"set client_min_messages to notice")
     cur = aconn.cursor()
-    await cur.execute(
-        "do $$begin raise notice 'hello notice'; end$$ language plpgsql"
-    )
+    await cur.execute("do $$begin raise notice 'hello notice'; end$$ language plpgsql")
     assert messages == ["hello notice"]
     assert severities == ["NOTICE"]
 
@@ -641,9 +635,7 @@ async def test_set_transaction_param_all(aconn):
 
     for attr in tx_params:
         guc = tx_params[attr]["guc"]
-        cur = await aconn.execute(
-            "select current_setting(%s)", [f"transaction_{guc}"]
-        )
+        cur = await aconn.execute("select current_setting(%s)", [f"transaction_{guc}"])
         pgval = (await cur.fetchone())[0]
         assert tx_values_map[pgval] == value
 
index 4d56b1fe8281885aaf35bca21b6a7d08e3f1f96b..63573d78b52ac06d2d787545d16e2480b6135b4a 100644 (file)
@@ -149,9 +149,7 @@ class TestConnectionInfo:
 
         monkeypatch.setenv("PGAPPNAME", "hello test")
         with psycopg.connect(**dsn) as conn:
-            assert (
-                conn.info.get_parameters()["application_name"] == "hello test"
-            )
+            assert conn.info.get_parameters()["application_name"] == "hello test"
 
     def test_dsn_env(self, dsn, monkeypatch):
         dsn = conninfo_to_dict(dsn)
@@ -220,9 +218,7 @@ class TestConnectionInfo:
         with pytest.raises(psycopg.OperationalError):
             conn.info.backend_pid
 
-    @pytest.mark.skipif(
-        sys.platform == "win32", reason="no IANA db on Windows"
-    )
+    @pytest.mark.skipif(sys.platform == "win32", reason="no IANA db on Windows")
     def test_timezone(self, conn):
         conn.execute("set timezone to 'Europe/Rome'")
         tz = conn.info.timezone
index 285b598b654ded26449a2be07c8117915093f557..af541a25639db45d960a10cae5d86ba2bb15e9ab 100644 (file)
@@ -42,8 +42,7 @@ ff ff
 """
 
 sample_binary_rows = [
-    bytes.fromhex("".join(row.split()))
-    for row in sample_binary_str.split("\n\n")
+    bytes.fromhex("".join(row.split())) for row in sample_binary_str.split("\n\n")
 ]
 
 sample_binary = b"".join(sample_binary_rows)
@@ -57,15 +56,11 @@ def test_copy_out_read(conn, format):
         want = sample_binary_rows
 
     cur = conn.cursor()
-    with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         for row in want:
             got = copy.read()
             assert got == row
-            assert (
-                conn.info.transaction_status == conn.TransactionStatus.ACTIVE
-            )
+            assert conn.info.transaction_status == conn.TransactionStatus.ACTIVE
 
         assert copy.read() == b""
         assert copy.read() == b""
@@ -82,9 +77,7 @@ def test_copy_out_iter(conn, format):
         want = sample_binary_rows
 
     cur = conn.cursor()
-    with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         assert list(copy) == want
 
     assert conn.info.transaction_status == conn.TransactionStatus.INTRANS
@@ -110,9 +103,7 @@ def test_read_rows(conn, format, typetype):
 @pytest.mark.parametrize("format", Format)
 def test_rows(conn, format):
     cur = conn.cursor()
-    with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         copy.set_types(["int4", "int4", "text"])
         rows = list(copy.rows())
 
@@ -143,9 +134,9 @@ def test_copy_out_allchars(conn, format):
     chars = list(map(chr, range(1, 256))) + [eur]
     conn.execute("set client_encoding to utf8")
     rows = []
-    query = sql.SQL(
-        "copy (select unnest({}::text[])) to stdout (format {})"
-    ).format(chars, sql.SQL(format.name))
+    query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+        chars, sql.SQL(format.name)
+    )
     with cur.copy(query) as copy:
         copy.set_types(["text"])
         while 1:
@@ -161,9 +152,7 @@ def test_copy_out_allchars(conn, format):
 @pytest.mark.parametrize("format", Format)
 def test_read_row_notypes(conn, format):
     cur = conn.cursor()
-    with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         rows = []
         while 1:
             row = copy.read_row()
@@ -171,24 +160,16 @@ def test_read_row_notypes(conn, format):
                 break
             rows.append(row)
 
-    ref = [
-        tuple(py_to_raw(i, format) for i in record)
-        for record in sample_records
-    ]
+    ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
     assert rows == ref
 
 
 @pytest.mark.parametrize("format", Format)
 def test_rows_notypes(conn, format):
     cur = conn.cursor()
-    with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         rows = list(copy.rows())
-    ref = [
-        tuple(py_to_raw(i, format) for i in record)
-        for record in sample_records
-    ]
+    ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
     assert rows == ref
 
 
@@ -196,9 +177,7 @@ def test_rows_notypes(conn, format):
 @pytest.mark.parametrize("format", Format)
 def test_copy_out_badntypes(conn, format, err):
     cur = conn.cursor()
-    with cur.copy(
-        f"copy ({sample_values}) to stdout (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy ({sample_values}) to stdout (format {format.name})") as copy:
         copy.set_types([0] * (len(sample_records[0]) + err))
         with pytest.raises(e.ProgrammingError):
             copy.read_row()
@@ -303,9 +282,7 @@ def test_subclass_adapter(conn, format):
     cur = conn.cursor()
     ensure_table(cur, sample_tabledef)
 
-    with cur.copy(
-        f"copy copy_in (data) from stdin (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy copy_in (data) from stdin (format {format.name})") as copy:
         copy.write_row(("hello",))
 
     rec = cur.execute("select data from copy_in").fetchone()
@@ -360,9 +337,7 @@ def test_copy_out_error_with_copy_finished(conn):
 def test_copy_out_error_with_copy_not_finished(conn):
     cur = conn.cursor()
     with pytest.raises(ZeroDivisionError):
-        with cur.copy(
-            "copy (select generate_series(1, 1000000)) to stdout"
-        ) as copy:
+        with cur.copy("copy (select generate_series(1, 1000000)) to stdout") as copy:
             copy.read_row()
             1 / 0
 
@@ -628,9 +603,7 @@ def test_copy_to_leaks(dsn, faker, fmt, set_types, method):
         gc_collect()
         n.append(len(gc.get_objects()))
 
-    assert (
-        n[0] == n[1] == n[2]
-    ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+    assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
 
 @pytest.mark.slow
@@ -673,9 +646,7 @@ def test_copy_from_leaks(dsn, faker, fmt, set_types):
         gc_collect()
         n.append(len(gc.get_objects()))
 
-    assert (
-        n[0] == n[1] == n[2]
-    ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+    assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
 
 def py_to_raw(item, fmt):
index 4b78ce59147c724b2399c6c4db8b1422d4e4cf49..a5cfd605ec5c938f00efbfeb5a31c472bb1677c6 100644 (file)
@@ -38,9 +38,7 @@ async def test_copy_out_read(aconn, format):
         for row in want:
             got = await copy.read()
             assert got == row
-            assert (
-                aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE
-            )
+            assert aconn.info.transaction_status == aconn.TransactionStatus.ACTIVE
 
         assert await copy.read() == b""
         assert await copy.read() == b""
@@ -118,9 +116,9 @@ async def test_copy_out_allchars(aconn, format):
     chars = list(map(chr, range(1, 256))) + [eur]
     await aconn.execute("set client_encoding to utf8")
     rows = []
-    query = sql.SQL(
-        "copy (select unnest({}::text[])) to stdout (format {})"
-    ).format(chars, sql.SQL(format.name))
+    query = sql.SQL("copy (select unnest({}::text[])) to stdout (format {})").format(
+        chars, sql.SQL(format.name)
+    )
     async with cur.copy(query) as copy:
         copy.set_types(["text"])
         while 1:
@@ -146,10 +144,7 @@ async def test_read_row_notypes(aconn, format):
                 break
             rows.append(row)
 
-    ref = [
-        tuple(py_to_raw(i, format) for i in record)
-        for record in sample_records
-    ]
+    ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
     assert rows == ref
 
 
@@ -160,10 +155,7 @@ async def test_rows_notypes(aconn, format):
         f"copy ({sample_values}) to stdout (format {format.name})"
     ) as copy:
         rows = await alist(copy.rows())
-    ref = [
-        tuple(py_to_raw(i, format) for i in record)
-        for record in sample_records
-    ]
+    ref = [tuple(py_to_raw(i, format) for i in record) for record in sample_records]
     assert rows == ref
 
 
@@ -186,9 +178,7 @@ async def test_copy_out_badntypes(aconn, format, err):
 async def test_copy_in_buffers(aconn, format, buffer):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)
-    async with cur.copy(
-        f"copy copy_in from stdin (format {format.name})"
-    ) as copy:
+    async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         await copy.write(globals()[buffer])
 
     await cur.execute("select * from copy_in order by 1")
@@ -330,9 +320,7 @@ async def test_copy_in_buffers_with_py_error(aconn):
 async def test_copy_out_error_with_copy_finished(aconn):
     cur = aconn.cursor()
     with pytest.raises(ZeroDivisionError):
-        async with cur.copy(
-            "copy (select generate_series(1, 2)) to stdout"
-        ) as copy:
+        async with cur.copy("copy (select generate_series(1, 2)) to stdout") as copy:
             await copy.read_row()
             1 / 0
 
@@ -368,9 +356,7 @@ async def test_copy_in_records(aconn, format):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)
 
-    async with cur.copy(
-        f"copy copy_in from stdin (format {format.name})"
-    ) as copy:
+    async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         for row in sample_records:
             if format == Format.BINARY:
                 row = tuple(
@@ -388,9 +374,7 @@ async def test_copy_in_records_set_types(aconn, format):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)
 
-    async with cur.copy(
-        f"copy copy_in from stdin (format {format.name})"
-    ) as copy:
+    async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         copy.set_types(["int4", "int4", "text"])
         for row in sample_records:
             await copy.write_row(row)
@@ -555,9 +539,7 @@ async def test_str(aconn):
 async def test_worker_life(aconn, format, buffer):
     cur = aconn.cursor()
     await ensure_table(cur, sample_tabledef)
-    async with cur.copy(
-        f"copy copy_in from stdin (format {format.name})"
-    ) as copy:
+    async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy:
         assert not copy._worker
         await copy.write(globals()[buffer])
         assert copy._worker
@@ -621,9 +603,7 @@ async def test_copy_to_leaks(dsn, faker, fmt, set_types, method):
         gc_collect()
         n.append(len(gc.get_objects()))
 
-    assert (
-        n[0] == n[1] == n[2]
-    ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+    assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
 
 @pytest.mark.slow
@@ -666,9 +646,7 @@ async def test_copy_from_leaks(dsn, faker, fmt, set_types):
         gc_collect()
         n.append(len(gc.get_objects()))
 
-    assert (
-        n[0] == n[1] == n[2]
-    ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+    assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
 
 async def ensure_table(cur, tabledef, name="copy_in"):
index e202fe6d3b8b1139296d7c59fa266b58a45b796d..3ace93bce534eaf6edfb64f929ec8973fc4e03b2 100644 (file)
@@ -326,9 +326,7 @@ def test_executemany_rowcount_no_hit(conn, execmany):
     assert cur.rowcount == 0
     cur.executemany("delete from execmany where id = %s", [])
     assert cur.rowcount == 0
-    cur.executemany(
-        "delete from execmany where id = %s returning num", [(-1,), (-2,)]
-    )
+    cur.executemany("delete from execmany where id = %s returning num", [(-1,), (-2,)])
     assert cur.rowcount == 0
 
 
@@ -373,9 +371,7 @@ def test_rowcount(conn):
     cur.execute("create table test_rowcount_notuples (id int primary key)")
     assert cur.rowcount == -1
 
-    cur.execute(
-        "insert into test_rowcount_notuples select generate_series(1, 42)"
-    )
+    cur.execute("insert into test_rowcount_notuples select generate_series(1, 42)")
     assert cur.rowcount == 42
 
 
@@ -756,9 +752,7 @@ def test_str(conn):
 @pytest.mark.parametrize("fmt", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
-@pytest.mark.parametrize(
-    "row_factory", ["tuple_row", "dict_row", "namedtuple_row"]
-)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
 def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
     faker.format = fmt
     faker.choose_schema(ncols=5)
@@ -797,9 +791,7 @@ def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
         work()
         gc_collect()
         n.append(len(gc.get_objects()))
-    assert (
-        n[0] == n[1] == n[2]
-    ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+    assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
 
 
 def my_row_factory(
@@ -809,10 +801,7 @@ def my_row_factory(
         titles = [c.name for c in cursor.description]
 
         def mkrow(values):
-            return [
-                f"{value.upper()}{title}"
-                for title, value in zip(titles, values)
-            ]
+            return [f"{value.upper()}{title}" for title, value in zip(titles, values)]
 
         return mkrow
     else:
index c55130298b72598d38b5900452c3678858802468..96ece0061b92324b00f88855a50f53814661f634 100644 (file)
@@ -134,9 +134,7 @@ async def test_execute_many_results(aconn):
 
 async def test_execute_sequence(aconn):
     cur = aconn.cursor()
-    rv = await cur.execute(
-        "select %s::int, %s::text, %s::text", [1, "foo", None]
-    )
+    rv = await cur.execute("select %s::int, %s::text, %s::text", [1, "foo", None])
     assert rv is cur
     assert len(cur._results) == 1
     assert cur.pgresult.get_value(0, 0) == b"1"
@@ -259,9 +257,7 @@ async def test_executemany_name(aconn, execmany):
 
 async def test_executemany_no_data(aconn, execmany):
     cur = aconn.cursor()
-    await cur.executemany(
-        "insert into execmany(num, data) values (%s, %s)", []
-    )
+    await cur.executemany("insert into execmany(num, data) values (%s, %s)", [])
     assert cur.rowcount == 0
 
 
@@ -362,9 +358,7 @@ async def test_rowcount(aconn):
     await cur.execute("select 1 from generate_series(1, 42)")
     assert cur.rowcount == 42
 
-    await cur.execute(
-        "create table test_rowcount_notuples (id int primary key)"
-    )
+    await cur.execute("create table test_rowcount_notuples (id int primary key)")
     assert cur.rowcount == -1
 
     await cur.execute(
@@ -637,9 +631,7 @@ async def test_str(aconn):
 @pytest.mark.parametrize("fmt", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
 @pytest.mark.parametrize("fetch", ["one", "many", "all", "iter"])
-@pytest.mark.parametrize(
-    "row_factory", ["tuple_row", "dict_row", "namedtuple_row"]
-)
+@pytest.mark.parametrize("row_factory", ["tuple_row", "dict_row", "namedtuple_row"])
 async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
     faker.format = fmt
     faker.choose_schema(ncols=5)
@@ -648,9 +640,7 @@ async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
 
     async def work():
         async with await psycopg.AsyncConnection.connect(dsn) as conn:
-            async with conn.cursor(
-                binary=fmt_out, row_factory=row_factory
-            ) as cur:
+            async with conn.cursor(binary=fmt_out, row_factory=row_factory) as cur:
                 await cur.execute(faker.drop_stmt)
                 await cur.execute(faker.create_stmt)
                 async with faker.find_insert_problem_async(conn):
@@ -680,6 +670,4 @@ async def test_leak(dsn, faker, fmt, fmt_out, fetch, row_factory):
         gc_collect()
         n.append(len(gc.get_objects()))
 
-    assert (
-        n[0] == n[1] == n[2]
-    ), f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
+    assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
index 480aec193068081444fd5dea75206c1905635de8..ca8202c219f68471d843f37d1f40ab696b808713 100644 (file)
@@ -112,17 +112,13 @@ async def test_resolve_hostaddr_async(conninfo, want, env, fake_resolve):
     ],
 )
 @pytest.mark.asyncio
-async def test_resolve_hostaddr_async_bad(
-    monkeypatch, conninfo, env, fake_resolve
-):
+async def test_resolve_hostaddr_async_bad(monkeypatch, conninfo, env, fake_resolve):
     if env:
         for k, v in env.items():
             monkeypatch.setenv(k, v)
     params = conninfo_to_dict(conninfo)
     with pytest.raises(psycopg.Error):
-        await psycopg._dns.resolve_hostaddr_async(  # type: ignore[attr-defined]
-            params
-        )
+        await psycopg._dns.resolve_hostaddr_async(params)  # type: ignore[attr-defined]
 
 
 @pytest.mark.asyncio
@@ -133,9 +129,7 @@ async def test_resolve_hostaddr_conn(monkeypatch, fake_resolve):
         got.append(conninfo)
         1 / 0
 
-    monkeypatch.setattr(
-        psycopg.AsyncConnection, "_connect_gen", fake_connect_gen
-    )
+    monkeypatch.setattr(psycopg.AsyncConnection, "_connect_gen", fake_connect_gen)
 
     # TODO: not enabled by default, but should be usable to make a subclass
     class AsyncDnsConnection(psycopg.AsyncConnection[Row]):
index 7d0856f2887844021a72980688b9d85777dd8559..d269a6d69ca5203de8f74745d9c68fa575347ccf 100644 (file)
@@ -153,9 +153,7 @@ def get_fake_srv_function(monkeypatch):
         else:
             for entry in ans:
                 pri, w, port, target = entry.split()
-                rv.append(
-                    SRV("IN", "SRV", int(pri), int(w), int(port), target)
-                )
+                rv.append(SRV("IN", "SRV", int(pri), int(w), int(port), target))
 
         return rv
 
index 131d660566ec2452b1a73dbeedf82e611203d9fe..aa61546aea997e1d77bf22d15ca185fb71cbdae1 100644 (file)
@@ -78,9 +78,7 @@ def test_diag_encoding(conn, enc):
     conn.add_notice_handler(lambda diag: msgs.append(diag.message_primary))
     conn.execute(f"set client_encoding to {enc}")
     cur = conn.cursor()
-    cur.execute(
-        "do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql"
-    )
+    cur.execute("do $$begin raise notice 'hello %', chr(8364); end$$ language plpgsql")
     assert msgs == [f"hello {eur}"]
 
 
index 42cbc46072f0ad7d35619c1b1b0cfba26eab3fcb..cfb4b0495236d5abc93b7d6aa97bc58119711168 100644 (file)
@@ -80,9 +80,7 @@ def test_prepare_disable(conn):
 def test_no_prepare_multi(conn):
     res = []
     for i in range(10):
-        cur = conn.execute(
-            "select count(*) from pg_prepared_statements; select 1"
-        )
+        cur = conn.execute("select count(*) from pg_prepared_statements; select 1")
         res.append(cur.fetchone()[0])
 
     assert res == [0] * 10
@@ -124,9 +122,7 @@ def test_misc_statement(conn, query):
     conn.execute("create table prepared_test (num int)", prepare=False)
     conn.prepare_threshold = 0
     conn.execute(query)
-    cur = conn.execute(
-        "select count(*) from pg_prepared_statements", prepare=False
-    )
+    cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False)
     assert cur.fetchone() == (1,)
 
 
@@ -240,9 +236,7 @@ def test_change_type(conn):
         {"enum_col": ["foo"]},
     )
 
-    cur = conn.execute(
-        "select count(*) from pg_prepared_statements", prepare=False
-    )
+    cur = conn.execute("select count(*) from pg_prepared_statements", prepare=False)
     assert cur.fetchone()[0] == 3
 
 
@@ -252,12 +246,8 @@ def test_change_type_savepoint(conn):
         for i in range(3):
             with pytest.raises(ZeroDivisionError):
                 with conn.transaction():
-                    conn.execute(
-                        "CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')"
-                    )
-                    conn.execute(
-                        "CREATE TABLE preptable(id integer, bar prepenum[])"
-                    )
+                    conn.execute("CREATE TYPE prepenum AS ENUM ('foo', 'bar', 'baz')")
+                    conn.execute("CREATE TABLE preptable(id integer, bar prepenum[])")
                     conn.cursor().execute(
                         "INSERT INTO preptable (bar) "
                         "VALUES (%(enum_col)s::prepenum[])",
index 4315ccb46d96e7e042e8a44702b31766379cefa5..330635cb41b0d3bd765ef845e97e38e42ca5210c 100644 (file)
@@ -63,9 +63,7 @@ async def test_do_prepare_conn(aconn):
 async def test_auto_prepare_conn(aconn):
     res = []
     for i in range(10):
-        cur = await aconn.execute(
-            "select count(*) from pg_prepared_statements"
-        )
+        cur = await aconn.execute("select count(*) from pg_prepared_statements")
         res.append((await cur.fetchone())[0])
 
     assert res == [0] * 5 + [1] * 5
@@ -75,9 +73,7 @@ async def test_prepare_disable(aconn):
     aconn.prepare_threshold = None
     res = []
     for i in range(10):
-        cur = await aconn.execute(
-            "select count(*) from pg_prepared_statements"
-        )
+        cur = await aconn.execute("select count(*) from pg_prepared_statements")
         res.append((await cur.fetchone())[0])
 
     assert res == [0] * 10
@@ -134,9 +130,7 @@ async def test_params_types(aconn):
         [dt.date(2020, 12, 10), 42, Decimal(42)],
         prepare=True,
     )
-    cur = await aconn.execute(
-        "select parameter_types from pg_prepared_statements"
-    )
+    cur = await aconn.execute("select parameter_types from pg_prepared_statements")
     (rec,) = await cur.fetchall()
     assert rec[0] == ["date", "smallint", "numeric"]
 
@@ -172,9 +166,7 @@ async def test_evict_lru_deallocate(aconn):
         "select statement from pg_prepared_statements order by prepare_time",
         prepare=False,
     )
-    assert await cur.fetchall() == [
-        (f"select {i}",) for i in ["'a'", 6, 7, 8, 9]
-    ]
+    assert await cur.fetchall() == [(f"select {i}",) for i in ["'a'", 6, 7, 8, 9]]
 
 
 async def test_different_types(aconn):
@@ -197,7 +189,5 @@ async def test_untyped_json(aconn):
     for i in range(2):
         await aconn.execute("insert into testjson (data) values (%s)", ["{}"])
 
-    cur = await aconn.execute(
-        "select parameter_types from pg_prepared_statements"
-    )
+    cur = await aconn.execute("select parameter_types from pg_prepared_statements")
     assert await cur.fetchall() == [(["jsonb"],)]
index 9c900c01568814306677f85cdedf0363c0a56675..fc6bec44ae9c3ca9feccc1abde1f76e6cf96a9fe 100644 (file)
@@ -41,9 +41,7 @@ class PsycopgTPCTests(dbapi20_tpc.TwoPhaseCommitTests):
 
 # Shut up warnings
 PsycopgTests.failUnless = PsycopgTests.assertTrue  # type: ignore[assignment]
-PsycopgTPCTests.assertEquals = (  # type: ignore[assignment]
-    PsycopgTPCTests.assertEqual
-)
+PsycopgTPCTests.assertEquals = PsycopgTPCTests.assertEqual  # type: ignore[assignment]
 
 
 @pytest.mark.parametrize(
index 9182947dfd060ae00e32c2fcf635801cbdde7dc7..75980067676aaff82af5cff86b08b30daebc0365 100644 (file)
@@ -113,9 +113,7 @@ def test_close(conn, recwarn):
     cur.close()
     assert cur.closed
 
-    assert not conn.execute(
-        "select * from pg_cursors where name = 'foo'"
-    ).fetchone()
+    assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
     del cur
     assert not recwarn, [str(w.message) for w in recwarn.list]
 
@@ -211,9 +209,7 @@ def test_context(conn, recwarn):
         cur.execute("select generate_series(1, 10) as bar")
 
     assert cur.closed
-    assert not conn.execute(
-        "select * from pg_cursors where name = 'foo'"
-    ).fetchone()
+    assert not conn.execute("select * from pg_cursors where name = 'foo'").fetchone()
     del cur
     assert not recwarn, [str(w.message) for w in recwarn.list]
 
@@ -237,9 +233,7 @@ def test_execute_reuse(conn):
         cur.execute("select generate_series(1, %s) as foo", (3,))
         assert cur.fetchone() == (1,)
 
-        cur.execute(
-            "select %s::text as bar, %s::text as baz", ("hello", "world")
-        )
+        cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
         assert cur.fetchone() == ("hello", "world")
         assert cur.description[0].name == "bar"
         assert cur.description[0].type_code == cur.adapters.types["text"].oid
index 730a39887626acb768d2327e299798bbf537fa5a..14e98b5cf3a6981db12942195167d1f149447ad3 100644 (file)
@@ -239,9 +239,7 @@ async def test_execute_reuse(aconn):
         await cur.execute("select generate_series(1, %s) as foo", (3,))
         assert await cur.fetchone() == (1,)
 
-        await cur.execute(
-            "select %s::text as bar, %s::text as baz", ("hello", "world")
-        )
+        await cur.execute("select %s::text as bar, %s::text as baz", ("hello", "world"))
         assert await cur.fetchone() == ("hello", "world")
         assert cur.description[0].name == "bar"
         assert cur.description[0].type_code == cur.adapters.types["text"].oid
@@ -303,9 +301,7 @@ async def test_nextset(aconn):
 
 async def test_no_result(aconn):
     async with aconn.cursor("foo") as cur:
-        await cur.execute(
-            "select generate_series(1, %s) as bar where false", (3,)
-        )
+        await cur.execute("select generate_series(1, %s) as bar where false", (3,))
         assert len(cur.description) == 1
         assert (await cur.fetchall()) == []
 
index d305375bbe614377c4ce1bd9a8bb7e4290a0f6cb..6550c3f8e158ba24c663099ef4cf25235e937168 100644 (file)
@@ -160,9 +160,7 @@ class TestSqlFormat:
             s.as_string(conn)
 
     def test_auto_literal(self, conn):
-        s = sql.SQL("select {}, {}, {}").format(
-            "he'lo", 10, dt.date(2020, 1, 1)
-        )
+        s = sql.SQL("select {}, {}, {}").format("he'lo", 10, dt.date(2020, 1, 1))
         assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'"
 
     def test_execute(self, conn):
@@ -177,9 +175,7 @@ class TestSqlFormat:
         cur.execute(
             sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
                 sql.Identifier("test_compose"),
-                sql.SQL(", ").join(
-                    map(sql.Identifier, ["foo", "bar", "ba'z"])
-                ),
+                sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
                 (sql.Placeholder() * 3).join(", "),
             ),
             (10, "a", "b", "c"),
@@ -200,9 +196,7 @@ class TestSqlFormat:
         cur.executemany(
             sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
                 sql.Identifier("test_compose"),
-                sql.SQL(", ").join(
-                    map(sql.Identifier, ["foo", "bar", "ba'z"])
-                ),
+                sql.SQL(", ").join(map(sql.Identifier, ["foo", "bar", "ba'z"])),
                 (sql.Placeholder() * 3).join(", "),
             ),
             [(10, "a", "b", "c"), (20, "d", "e", "f")],
@@ -317,17 +311,13 @@ class TestLiteral:
         assert sql.Literal(None).as_string(conn) == "NULL"
         assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
         assert sql.Literal(42).as_string(conn) == "42"
-        assert (
-            sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
-        )
+        assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
 
     def test_as_bytes(self, conn):
         assert sql.Literal(None).as_bytes(conn) == b"NULL"
         assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
         assert sql.Literal(42).as_bytes(conn) == b"42"
-        assert (
-            sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'"
-        )
+        assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'"
 
         conn.execute("set client_encoding to utf8")
         assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode()
@@ -395,9 +385,7 @@ class TestSQL:
         assert obj.as_string(conn) == '"foo", bar, 42'
 
         obj = sql.SQL(", ").join(
-            sql.Composed(
-                [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
-            )
+            sql.Composed([sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)])
         )
         assert isinstance(obj, sql.Composed)
         assert obj.as_string(conn) == '"foo", bar, 42'
@@ -424,9 +412,7 @@ class TestComposed:
 
     def test_repr(self):
         obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
-        assert (
-            repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])"""
-        )
+        assert repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])"""
         assert str(obj) == repr(obj)
 
     def test_eq(self):
index cde02624528dcfb966e4ae7271a6facb2feaee49..f964bb4ba78227866f2aa9a6198215f9cc03b4c7 100644 (file)
@@ -212,9 +212,7 @@ class TestTPC:
         conn.close()
 
         with psycopg.connect(dsn) as conn:
-            xids = [
-                x for x in conn.tpc_recover() if x.database == conn.info.dbname
-            ]
+            xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
 
             assert len(xids) == 1
             xid = xids[0]
@@ -238,9 +236,7 @@ class TestTPC:
         conn.close()
 
         with psycopg.connect(dsn) as conn:
-            xids = [
-                x for x in conn.tpc_recover() if x.database == conn.info.dbname
-            ]
+            xids = [x for x in conn.tpc_recover() if x.database == conn.info.dbname]
 
             assert len(xids) == 1
             xid = xids[0]
@@ -257,9 +253,7 @@ class TestTPC:
         conn.close()
 
         with psycopg.connect(dsn) as conn:
-            xid = [
-                x for x in conn.tpc_recover() if x.database == conn.info.dbname
-            ][0]
+            xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
         assert 10 == xid.format_id
         assert "uni" == xid.gtrid
         assert "code" == xid.bqual
@@ -276,9 +270,7 @@ class TestTPC:
         conn.close()
 
         with psycopg.connect(dsn) as conn:
-            xid = [
-                x for x in conn.tpc_recover() if x.database == conn.info.dbname
-            ][0]
+            xid = [x for x in conn.tpc_recover() if x.database == conn.info.dbname][0]
 
         assert xid.format_id is None
         assert xid.gtrid == "transaction-id"
index 6a80ecebfbfdb8de13b4ccdbffdd4c8332d81fef..e3e5dcf64f62d820b9018dee317e114e2abf3c97 100644 (file)
@@ -66,9 +66,7 @@ class TestTPC:
         assert aconn.info.transaction_status == TransactionStatus.INTRANS
 
         cur = aconn.cursor()
-        await cur.execute(
-            "insert into test_tpc values ('test_tpc_commit_rec')"
-        )
+        await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
         assert tpc.count_xacts() == 0
         assert tpc.count_test_records() == 0
 
@@ -115,9 +113,7 @@ class TestTPC:
         assert aconn.info.transaction_status == TransactionStatus.INTRANS
 
         cur = aconn.cursor()
-        await cur.execute(
-            "insert into test_tpc values ('test_tpc_rollback_1p')"
-        )
+        await cur.execute("insert into test_tpc values ('test_tpc_rollback_1p')")
         assert tpc.count_xacts() == 0
         assert tpc.count_test_records() == 0
 
@@ -134,9 +130,7 @@ class TestTPC:
         assert aconn.info.transaction_status == TransactionStatus.INTRANS
 
         cur = aconn.cursor()
-        await cur.execute(
-            "insert into test_tpc values ('test_tpc_commit_rec')"
-        )
+        await cur.execute("insert into test_tpc values ('test_tpc_commit_rec')")
         assert tpc.count_xacts() == 0
         assert tpc.count_test_records() == 0
 
@@ -222,9 +216,7 @@ class TestTPC:
 
         async with await psycopg.AsyncConnection.connect(dsn) as aconn:
             xids = [
-                x
-                for x in await aconn.tpc_recover()
-                if x.database == aconn.info.dbname
+                x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
             ]
             assert len(xids) == 1
             xid = xids[0]
@@ -249,9 +241,7 @@ class TestTPC:
 
         async with await psycopg.AsyncConnection.connect(dsn) as aconn:
             xids = [
-                x
-                for x in await aconn.tpc_recover()
-                if x.database == aconn.info.dbname
+                x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
             ]
             assert len(xids) == 1
             xid = xids[0]
@@ -269,9 +259,7 @@ class TestTPC:
 
         async with await psycopg.AsyncConnection.connect(dsn) as aconn:
             xid = [
-                x
-                for x in await aconn.tpc_recover()
-                if x.database == aconn.info.dbname
+                x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
             ][0]
 
         assert 10 == xid.format_id
@@ -291,9 +279,7 @@ class TestTPC:
 
         async with await psycopg.AsyncConnection.connect(dsn) as aconn:
             xid = [
-                x
-                for x in await aconn.tpc_recover()
-                if x.database == aconn.info.dbname
+                x for x in await aconn.tpc_recover() if x.database == aconn.info.dbname
             ][0]
 
         assert xid.format_id is None
index 7db3b1dbad82fd3de623a6958b86ef4af9689ec2..68d25df6b7a0bcba2af83b7de6218b0c73400f01 100644 (file)
@@ -158,9 +158,7 @@ def test_context_active_rollback_no_clobber(dsn, caplog):
     try:
         with pytest.raises(ZeroDivisionError):
             with conn.transaction():
-                conn.pgconn.exec_(
-                    b"copy (select generate_series(1, 10)) to stdout"
-                )
+                conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
                 status = conn.info.transaction_status
                 assert status == conn.TransactionStatus.ACTIVE
                 1 / 0
index 51bc023ddbbb29af4a5e1c9132614cab96416872..d6832b98f9dffc9ed53da3fee0383c8bff76f9b5 100644 (file)
@@ -101,9 +101,7 @@ async def test_context_active_rollback_no_clobber(dsn, caplog):
     try:
         with pytest.raises(ZeroDivisionError):
             async with conn.transaction():
-                conn.pgconn.exec_(
-                    b"copy (select generate_series(1, 10)) to stdout"
-                )
+                conn.pgconn.exec_(b"copy (select generate_series(1, 10)) to stdout")
                 status = conn.info.transaction_status
                 assert status == conn.TransactionStatus.ACTIVE
                 1 / 0
@@ -165,9 +163,7 @@ async def test_preserves_autocommit(aconn, autocommit):
     assert aconn.autocommit is autocommit
 
 
-async def test_autocommit_off_but_no_tx_started_successful_exit(
-    aconn, svcconn
-):
+async def test_autocommit_off_but_no_tx_started_successful_exit(aconn, svcconn):
     """
     Scenario:
     * Connection has autocommit off but no transaction has been initiated
@@ -211,9 +207,7 @@ async def test_autocommit_off_but_no_tx_started_exception_exit(aconn, svcconn):
     assert not inserted(svcconn)
 
 
-async def test_autocommit_off_and_tx_in_progress_successful_exit(
-    aconn, svcconn
-):
+async def test_autocommit_off_and_tx_in_progress_successful_exit(aconn, svcconn):
     """
     Scenario:
     * Connection has autocommit off and a transaction is already in
@@ -236,9 +230,7 @@ async def test_autocommit_off_and_tx_in_progress_successful_exit(
     assert not inserted(svcconn)
 
 
-async def test_autocommit_off_and_tx_in_progress_exception_exit(
-    aconn, svcconn
-):
+async def test_autocommit_off_and_tx_in_progress_exception_exit(aconn, svcconn):
     """
     Scenario:
     * Connection has autocommit off and a transaction is already in
@@ -308,9 +300,7 @@ async def test_nested_all_changes_discarded_on_inner_exception(aconn, svcconn):
     assert not inserted(svcconn)
 
 
-async def test_nested_inner_scope_exception_handled_in_outer_scope(
-    aconn, svcconn
-):
+async def test_nested_inner_scope_exception_handled_in_outer_scope(aconn, svcconn):
     """
     An exception escaping the inner transaction context causes changes made
     within that inner context to be discarded, but the error can then be
@@ -578,9 +568,7 @@ async def test_explicit_rollback_of_outer_transaction(aconn):
     assert not await inserted(aconn)
 
 
-async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(
-    aconn, svcconn
-):
+async def test_explicit_rollback_of_enclosing_tx_outer_tx_unaffected(aconn, svcconn):
     """
     Rolling-back an enclosing transaction does not impact an outer transaction.
     """
index 9937ff7d1a93b0d13f2fc0498b449d80461537e9..c8c664bb28cff5fb22009a981dccda74c0cf6d5d 100644 (file)
@@ -295,9 +295,7 @@ class MyCursor(psycopg.{cur_base_class}[Row]):
 
 
 def _test_reveal(stmts, type, mypy):
-    ignore = (
-        "" if type.startswith("Optional") else "# type: ignore[assignment]"
-    )
+    ignore = "" if type.startswith("Optional") else "# type: ignore[assignment]"
     stmts = "\n".join(f"    {line}" for line in stmts.splitlines())
 
     src = f"""\
index 38ec59ed8d947ef1218bc9186ae25a3c4b175cc8..855e2855de8b83991d777e945e88214d4787dae5 100644 (file)
@@ -139,9 +139,7 @@ def test_array_of_unknown_builtin(conn):
     assert res[1] == [val]
 
 
-@pytest.mark.parametrize(
-    "array, type", [([1, 32767], "int2"), ([1, 32768], "int4")]
-)
+@pytest.mark.parametrize("array, type", [([1, 32767], "int2"), ([1, 32768], "int4")])
 def test_array_mixed_numbers(array, type):
     tx = Transformer()
     dumper = tx.get_dumper(array, PyFormat.BINARY)
@@ -149,9 +147,7 @@ def test_array_mixed_numbers(array, type):
     assert dumper.oid == builtins[type].array_oid
 
 
-@pytest.mark.parametrize(
-    "wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split()
-)
+@pytest.mark.parametrize("wrapper", "Int2 Int4 Int8 Float4 Float8 Decimal".split())
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_list_number_wrapper(conn, wrapper, fmt_in, fmt_out):
@@ -213,9 +209,7 @@ def test_empty_list(conn, fmt_in):
 def test_empty_list_after_choice(conn, fmt_in):
     cur = conn.cursor()
     cur.execute("create table test (id serial primary key, data float[])")
-    cur.executemany(
-        f"insert into test (data) values (%{fmt_in})", [([1.0],), ([],)]
-    )
+    cur.executemany(f"insert into test (data) values (%{fmt_in})", [([1.0],), ([],)])
     cur.execute("select data from test order by id")
     assert cur.fetchall() == [([1.0],), ([],)]
 
index 85a62b48892b1e485b069d8d1461b9f97dfe60c7..9a6a8e8f5f60142306bbb7db8d6f26bc80d6505e 100644 (file)
@@ -28,9 +28,9 @@ def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
 def test_quote_bool(conn, val):
 
     tx = Transformer()
-    assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(
-        val
-    ).lower().encode("ascii")
+    assert tx.get_dumper(val, PyFormat.TEXT).quote(val) == str(val).lower().encode(
+        "ascii"
+    )
 
     cur = conn.cursor()
     cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
index 93ba6f44dc27ab5688227dae566246b0d795458e..a3f7d5284340f0abe7796b920c1214b32e9ef5bb 100644 (file)
@@ -53,9 +53,7 @@ def test_load_all_chars(conn, fmt_out):
         res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0]
         assert res == (chr(i),)
 
-    cur.execute(
-        "select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256))
-    )
+    cur.execute("select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256)))
     res = cur.fetchone()[0]
     assert res == tuple(map(chr, range(1, 256)))
 
@@ -98,8 +96,7 @@ def test_dump_builtin_empty_range(conn, fmt_in):
             (b"foo'", b"'foo", b'"bar', b'bar"'),
         ),
         (
-            "10::int, null::text, 20::float,"
-            " null::text, 'foo'::text, 'bar'::bytea ",
+            "10::int, null::text, 20::float, null::text, 'foo'::text, 'bar'::bytea ",
             (10, None, 20.0, None, "foo", b"bar"),
         ),
     ],
@@ -228,9 +225,7 @@ def test_load_composite(conn, testcomp, fmt_out):
     assert res.baz == 20.0
     assert isinstance(res.baz, float)
 
-    res = cur.execute(
-        "select array[row('hello', 10, 30)::testcomp]"
-    ).fetchone()[0]
+    res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
     assert len(res) == 1
     assert res[0].baz == 30.0
     assert isinstance(res[0].baz, float)
@@ -253,9 +248,7 @@ def test_load_composite_factory(conn, testcomp, fmt_out):
     assert res.baz == 20.0
     assert isinstance(res.baz, float)
 
-    res = cur.execute(
-        "select array[row('hello', 10, 30)::testcomp]"
-    ).fetchone()[0]
+    res = cur.execute("select array[row('hello', 10, 30)::testcomp]").fetchone()[0]
     assert len(res) == 1
     assert res[0].baz == 30.0
     assert isinstance(res[0].baz, float)
diff --git a/tests/types/test_datetime.py b/tests/types/test_datetime.py
index 8d50a85aa583f683a2bb967d2cd66ea3d393d7cf..79ceeefb341cdafec61ad7e339be973df47a14b3 100644 (file)
@@ -58,9 +58,7 @@ class TestDate:
         cur.execute(f"select '{expr}'::date")
         assert cur.fetchone()[0] == as_date(val)
 
-    @pytest.mark.parametrize(
-        "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
-    )
+    @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
     def test_load_date_datestyle(self, conn, datestyle_out):
         cur = conn.cursor(binary=False)
         cur.execute(f"set datestyle = {datestyle_out}, YMD")
@@ -68,24 +66,18 @@ class TestDate:
         assert cur.fetchone()[0] == dt.date(2000, 1, 2)
 
     @pytest.mark.parametrize("val", ["min", "max"])
-    @pytest.mark.parametrize(
-        "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
-    )
+    @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
     def test_load_date_overflow(self, conn, val, datestyle_out):
         cur = conn.cursor(binary=False)
         cur.execute(f"set datestyle = {datestyle_out}, YMD")
-        cur.execute(
-            "select %t + %s::int", (as_date(val), -1 if val == "min" else 1)
-        )
+        cur.execute("select %t + %s::int", (as_date(val), -1 if val == "min" else 1))
         with pytest.raises(DataError):
             cur.fetchone()[0]
 
     @pytest.mark.parametrize("val", ["min", "max"])
     def test_load_date_overflow_binary(self, conn, val):
         cur = conn.cursor(binary=True)
-        cur.execute(
-            "select %s + %s::int", (as_date(val), -1 if val == "min" else 1)
-        )
+        cur.execute("select %s + %s::int", (as_date(val), -1 if val == "min" else 1))
         with pytest.raises(DataError):
             cur.fetchone()[0]
 
@@ -148,9 +140,7 @@ class TestDatetime:
     ]
 
     @pytest.mark.parametrize("val, expr", load_datetime_samples)
-    @pytest.mark.parametrize(
-        "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
-    )
+    @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
     @pytest.mark.parametrize("datestyle_in", ["DMY", "MDY", "YMD"])
     def test_load_datetime(self, conn, val, expr, datestyle_in, datestyle_out):
         cur = conn.cursor(binary=False)
@@ -167,9 +157,7 @@ class TestDatetime:
         assert cur.fetchone()[0] == as_dt(val)
 
     @pytest.mark.parametrize("val", ["min", "max"])
-    @pytest.mark.parametrize(
-        "datestyle_out", ["ISO", "Postgres", "SQL", "German"]
-    )
+    @pytest.mark.parametrize("datestyle_out", ["ISO", "Postgres", "SQL", "German"])
     def test_load_datetime_overflow(self, conn, val, datestyle_out):
         cur = conn.cursor(binary=False)
         cur.execute(f"set datestyle = {datestyle_out}, YMD")
@@ -283,9 +271,7 @@ class TestDateTimeTz:
     @pytest.mark.parametrize("val, expr", [("2000,1,1~2", "2000-01-01")])
     @pytest.mark.parametrize("datestyle_out", ["SQL", "Postgres", "German"])
     @pytest.mark.parametrize("datestyle_in", ["DMY", "MDY", "YMD"])
-    def test_load_datetimetz_tzname(
-        self, conn, val, expr, datestyle_in, datestyle_out
-    ):
+    def test_load_datetimetz_tzname(self, conn, val, expr, datestyle_in, datestyle_out):
         cur = conn.cursor(binary=False)
         cur.execute(f"set datestyle = {datestyle_out}, {datestyle_in}")
         cur.execute("set timezone to '-02:00'")
@@ -627,9 +613,7 @@ class TestInterval:
             "SELECT %s::text, %s::text", [date(2020, 12, 31), date.max]
         ).fetchone()
         assert rec == ("2020-12-31", "infinity")
-        rec = cur.execute(
-            "select '2020-12-31'::date, 'infinity'::date"
-        ).fetchone()
+        rec = cur.execute("select '2020-12-31'::date, 'infinity'::date").fetchone()
         assert rec == (date(2020, 12, 31), date(9999, 12, 31))
 
     def test_load_copy(self, conn):
@@ -657,9 +641,7 @@ class TestInterval:
 
 
 def as_date(s):
-    return (
-        dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s)
-    )
+    return dt.date(*map(int, s.split(","))) if "," in s else getattr(dt.date, s)
 
 
 def as_time(s):
diff --git a/tests/types/test_multirange.py b/tests/types/test_multirange.py
index ba7acdb256ae36a5d388b74f92e4bb6218e10ac2..bdbd837ea0f329231cd9b1df5091839e689a0f0b 100644 (file)
@@ -238,9 +238,7 @@ def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
     wrapper = getattr(multirange, wrapper)
     mr1 = Multirange()  # type: ignore[var-annotated]
     mr2 = Multirange([Range(bounds="()")])  # type: ignore[var-annotated]
-    cur = conn.execute(
-        f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in}""", ([mr1, mr2],)
-    )
+    cur = conn.execute(f"""select '{{"{{}}","{{(,)}}"}}' = %{fmt_in}""", ([mr1, mr2],))
     assert cur.fetchone()[0] is True
 
 
@@ -300,9 +298,7 @@ def test_load_builtin_range(conn, pgtype, ranges, fmt_out):
 @pytest.mark.parametrize("format", pq.Format)
 def test_copy_in(conn, min, max, bounds, format):
     cur = conn.cursor()
-    cur.execute(
-        "create table copymr (id serial primary key, mr datemultirange)"
-    )
+    cur.execute("create table copymr (id serial primary key, mr datemultirange)")
 
     if bounds != "empty":
         min = dt.date(*map(int, min.split(","))) if min else None
@@ -313,15 +309,11 @@ def test_copy_in(conn, min, max, bounds, format):
 
     mr = Multirange([r])
     try:
-        with cur.copy(
-            f"copy copymr (mr) from stdin (format {format.name})"
-        ) as copy:
+        with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
             copy.write_row([mr])
     except e.InternalError_:
         if not min and not max and format == pq.Format.BINARY:
-            pytest.xfail(
-                "TODO: add annotation to dump multirange with no type info"
-            )
+            pytest.xfail("TODO: add annotation to dump multirange with no type info")
         else:
             raise
 
@@ -336,15 +328,11 @@ def test_copy_in(conn, min, max, bounds, format):
 @pytest.mark.parametrize("format", pq.Format)
 def test_copy_in_empty_wrappers(conn, wrapper, format):
     cur = conn.cursor()
-    cur.execute(
-        "create table copymr (id serial primary key, mr datemultirange)"
-    )
+    cur.execute("create table copymr (id serial primary key, mr datemultirange)")
 
     mr = getattr(multirange, wrapper)()
 
-    with cur.copy(
-        f"copy copymr (mr) from stdin (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
         copy.write_row([mr])
 
     rec = cur.execute("select mr from copymr order by id").fetchone()
@@ -359,9 +347,7 @@ def test_copy_in_empty_set_type(conn, pgtype, format):
 
     mr = Multirange()  # type: ignore[var-annotated]
 
-    with cur.copy(
-        f"copy copymr (mr) from stdin (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy copymr (mr) from stdin (format {format.name})") as copy:
         copy.set_types([pgtype])
         copy.write_row([mr])
 
diff --git a/tests/types/test_net.py b/tests/types/test_net.py
index 1deaf2d73e82298e8fa3ec508635f13f8d592f10..6a6fbee1ff50859bbaf676a0f3e196e6614ba8a8 100644 (file)
@@ -13,9 +13,7 @@ from psycopg.adapt import PyFormat
 @pytest.mark.parametrize("val", ["192.168.0.1", "2001:db8::"])
 def test_address_dump(conn, fmt_in, val):
     cur = conn.cursor()
-    cur.execute(
-        f"select %{fmt_in} = %s::inet", (ipaddress.ip_address(val), val)
-    )
+    cur.execute(f"select %{fmt_in} = %s::inet", (ipaddress.ip_address(val), val))
     assert cur.fetchone()[0] is True
     cur.execute(
         f"select %{fmt_in} = array[null, %s]::inet[]",
@@ -44,9 +42,7 @@ def test_interface_dump(conn, fmt_in, val):
 @pytest.mark.parametrize("val", ["127.0.0.0/24", "::ffff:102:300/128"])
 def test_network_dump(conn, fmt_in, val):
     cur = conn.cursor()
-    cur.execute(
-        f"select %{fmt_in} = %s::cidr", (ipaddress.ip_network(val), val)
-    )
+    cur.execute(f"select %{fmt_in} = %s::cidr", (ipaddress.ip_network(val), val))
     assert cur.fetchone()[0] is True
     cur.execute(
         f"select %{fmt_in} = array[NULL, %s]::cidr[]",
diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py
index ee8d39c7a04eeb80418e096a2f3cff9dcdd293e4..d8fd8e069c62529878e6f71031092e6897308cb8 100644 (file)
@@ -216,14 +216,10 @@ def test_quote_float(conn, val, expr):
 def test_dump_float_approx(conn, val, expr):
     assert isinstance(val, float)
     cur = conn.cursor()
-    cur.execute(
-        f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,)
-    )
+    cur.execute(f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,))
     assert cur.fetchone()[0] is True
 
-    cur.execute(
-        f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,)
-    )
+    cur.execute(f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,))
     assert cur.fetchone()[0] is True
 
 
@@ -417,9 +413,7 @@ def test_dump_numeric_exhaustive(conn, fmt_in):
         for f in funcs:
             expr = f(i)
             val = Decimal(expr)
-            cur.execute(
-                f"select %{fmt_in}::text, %s::decimal::text", [val, expr]
-            )
+            cur.execute(f"select %{fmt_in}::text, %s::decimal::text", [val, expr])
             want, got = cur.fetchone()
             assert got == want
 
diff --git a/tests/types/test_range.py b/tests/types/test_range.py
index d1934e799f88c4d2210a72ed8dd0509d7ed303c2..dbe8ad2027ad51c5e7d1975e972541005a13bcb4 100644 (file)
@@ -118,9 +118,7 @@ def test_dump_builtin_array_wrapper(conn, wrapper, fmt_in):
     wrapper = getattr(range_module, wrapper)
     r1 = wrapper(empty=True)
     r2 = wrapper(bounds="()")
-    cur = conn.execute(
-        f"""select '{{empty,"(,)"}}' = %{fmt_in}""", ([r1, r2],)
-    )
+    cur = conn.execute(f"""select '{{empty,"(,)"}}' = %{fmt_in}""", ([r1, r2],))
     assert cur.fetchone()[0] is True
 
 
@@ -168,9 +166,7 @@ def test_load_builtin_array(conn, pgtype, fmt_out):
     r1 = Range(empty=True)  # type: ignore[var-annotated]
     r2 = Range(bounds="()")  # type: ignore[var-annotated]
     cur = conn.cursor(binary=fmt_out)
-    (got,) = cur.execute(
-        f"select array['empty'::{pgtype}, '(,)'::{pgtype}]"
-    ).fetchone()
+    (got,) = cur.execute(f"select array['empty'::{pgtype}, '(,)'::{pgtype}]").fetchone()
     assert got == [r1, r2]
 
 
@@ -180,9 +176,7 @@ def test_load_builtin_range(conn, pgtype, min, max, bounds, fmt_out):
     r = Range(min, max, bounds)  # type: ignore[var-annotated]
     sub = type2sub[pgtype]
     cur = conn.cursor(binary=fmt_out)
-    cur.execute(
-        f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
-    )
+    cur.execute(f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds))
     # normalise discrete ranges
     if r.upper_inc and isinstance(r.upper, int):
         bounds = "[)" if r.lower_inc else "()"
@@ -213,15 +207,11 @@ def test_copy_in(conn, min, max, bounds, format):
         r = Range(empty=True)
 
     try:
-        with cur.copy(
-            f"copy copyrange (r) from stdin (format {format.name})"
-        ) as copy:
+        with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
             copy.write_row([r])
     except e.ProtocolViolation:
         if not min and not max and format == pq.Format.BINARY:
-            pytest.xfail(
-                "TODO: add annotation to dump ranges with no type info"
-            )
+            pytest.xfail("TODO: add annotation to dump ranges with no type info")
         else:
             raise
 
@@ -239,9 +229,7 @@ def test_copy_in_empty_wrappers(conn, bounds, wrapper, format):
     cls = getattr(range_module, wrapper)
     r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds)
 
-    with cur.copy(
-        f"copy copyrange (r) from stdin (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
         copy.write_row([r])
 
     rec = cur.execute("select r from copyrange order by id").fetchone()
@@ -260,9 +248,7 @@ def test_copy_in_empty_set_type(conn, bounds, pgtype, format):
     else:
         r = Range(None, None, bounds)
 
-    with cur.copy(
-        f"copy copyrange (r) from stdin (format {format.name})"
-    ) as copy:
+    with cur.copy(f"copy copyrange (r) from stdin (format {format.name})") as copy:
         copy.set_types([pgtype])
         copy.write_row([r])
 
diff --git a/tests/types/test_shapely.py b/tests/types/test_shapely.py
index 02bc210d27dba00a7b82f2201e1b3f485cbf98e8..a460c483d473ba31c8e3d49d7988e5ee7466c855 100644 (file)
@@ -71,9 +71,7 @@ def shapely_conn(conn, svcconn):
 
 def test_no_adapter(conn):
     point = Point(1.2, 3.4)
-    with pytest.raises(
-        psycopg.ProgrammingError, match="cannot adapt type 'Point'"
-    ):
+    with pytest.raises(psycopg.ProgrammingError, match="cannot adapt type 'Point'"):
         conn.execute("SELECT pg_typeof(%s)", [point]).fetchone()[0]
 
 
@@ -89,16 +87,12 @@ def test_with_adapter(shapely_conn):
     SAMPLE_POLYGON = Polygon([(0, 0), (1, 1), (1, 0)])
 
     assert (
-        shapely_conn.execute(
-            "SELECT pg_typeof(%s)", [SAMPLE_POINT]
-        ).fetchone()[0]
+        shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POINT]).fetchone()[0]
         == "geometry"
     )
 
     assert (
-        shapely_conn.execute(
-            "SELECT pg_typeof(%s)", [SAMPLE_POLYGON]
-        ).fetchone()[0]
+        shapely_conn.execute("SELECT pg_typeof(%s)", [SAMPLE_POLYGON]).fetchone()[0]
         == "geometry"
     )
 
diff --git a/tests/types/test_string.py b/tests/types/test_string.py
index 6070ed42c287461b5d81e2ea8a387019d4d5c568..ec04b0900fe3ca8745b8f801a3189dc3e519390d 100644 (file)
@@ -72,9 +72,7 @@ def test_quote_percent(conn):
     assert cur.fetchone()[0] is True
 
 
-@pytest.mark.parametrize(
-    "typename", ["text", "varchar", "name", "bpchar", '"char"']
-)
+@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar", '"char"'])
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_load_1char(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out)
diff --git a/tests/utils.py b/tests/utils.py
index 72e6fdad7e63d92249080833020f633cf1939226..a02827f985c003126769f3e79860d352b86bb987 100644 (file)
@@ -41,9 +41,7 @@ def _check_version(got, want, whose_version):
         got = (got_maj, got_min, got_fix)
 
     # Parse a spec like "> 9.6"
-    m = re.match(
-        r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want
-    )
+    m = re.match(r"^\s*(>=|<=|>|<)\s*(?:(\d+)(?:\.(\d+)(?:\.(\d+))?)?)?\s*$", want)
     if m is None:
         pytest.fail(f"bad wanted version spec: {want}")
 
@@ -58,9 +56,7 @@ def _check_version(got, want, whose_version):
     else:
         want = (want_maj, want_min, want_fix)
 
-    op = getattr(
-        operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)]
-    )
+    op = getattr(operator, {">=": "ge", "<=": "le", ">": "gt", "<": "lt"}[m.group(1)])
 
     if not op(got, want):
         revops = {">=": "<", "<=": ">", ">": "<=", "<": ">="}
diff --git a/tools/update_backer.py b/tools/update_backer.py
index de3c0014f360496382d93f8216bd2ec350027463..00885276d3250564fb5744d394c8daeca3ed17d2 100755 (executable)
@@ -9,9 +9,7 @@ from pathlib import Path
 from ruamel.yaml import YAML  # pip install ruamel.yaml
 
 logger = logging.getLogger()
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
 
 
 def fetch_user(username):
diff --git a/tools/update_errors.py b/tools/update_errors.py
index bc829bb8b0013e5032dfb057db96e61ab5c4f40d..56bda5855120e45a07063d095a93cf78b35a70f6 100755 (executable)
@@ -19,9 +19,7 @@ from collections import defaultdict, namedtuple
 from psycopg.errors import get_base_exception
 
 logger = logging.getLogger()
-logging.basicConfig(
-    level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s"
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
 
 
 def main():
@@ -56,9 +54,7 @@ def parse_errors_txt(url):
             continue
 
         # Parse an error
-        m = re.match(
-            r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line
-        )
+        m = re.match(r"(.....)\s+(?:E|W|S)\s+ERRCODE_(\S+)(?:\s+(\S+))?$", line)
         if m:
             sqlstate, macro, spec = m.groups()
             # skip sqlstates without specs as they are not publically visible
diff --git a/tools/update_oids.py b/tools/update_oids.py
index 4d5ac72c5a6d70f421180152882e7a4e2ddb5be2..743ad91adce5f59d2804614fcaf6f336d51bee03 100755 (executable)
@@ -105,9 +105,7 @@ def update_file(fn: Path, queries: List[str]) -> None:
 
     new = []
     for query in queries:
-        out = sp.run(
-            ["psql", "-AXqt", "-c", query], stdout=sp.PIPE, check=True
-        )
+        out = sp.run(["psql", "-AXqt", "-c", query], stdout=sp.PIPE, check=True)
         new.extend(out.stdout.splitlines())
 
     new = [b" " * 4 + line if line else b"" for line in new]  # indent
diff --git a/tox.ini b/tox.ini
index 58c5f7f0c232f208747256b9565e61d0800dcda5..50bda0c8fb0af47ed40e26b1e074fc596c35bcbe 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -33,6 +33,6 @@ deps =
     codespell
 
 [flake8]
-max-line-length = 85
+max-line-length = 88
 ignore = W503, E203
 extend-exclude = .venv