]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(connection): harmless manipulation to minimize async conversion diff
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 2 Sep 2023 21:10:05 +0000 (22:10 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 11 Oct 2023 21:45:38 +0000 (23:45 +0200)
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py

index 4075c1fe493fdaa2799aed6392f6fddc9a965761..ad8ff9e49c14e17d8d466b05659fb216f6ea2d75 100644 (file)
@@ -7,16 +7,14 @@ psycopg connection objects
 import logging
 import threading
 from types import TracebackType
-from typing import Any, cast, Dict, Generator, Iterator
-from typing import List, Optional, Type, TypeVar, Union
-from typing import overload, TYPE_CHECKING
+from typing import Any, Generator, Iterator, Dict, List, Optional
+from typing import Type, TypeVar, Union, cast, overload, TYPE_CHECKING
 from contextlib import contextmanager
 
 from . import pq
 from . import errors as e
 from . import waiting
-from .abc import AdaptContext, Params, Query, RV
-from .abc import PQGen, PQGenConn
+from .abc import AdaptContext, Params, PQGen, PQGenConn, Query, RV
 from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
@@ -24,8 +22,8 @@ from ._enums import IsolationLevel
 from .cursor import Cursor
 from .conninfo import make_conninfo, conninfo_to_dict
 from ._pipeline import Pipeline
-from .generators import notifies
 from ._encodings import pgconn_encoding
+from .generators import notifies
 from .transaction import Transaction
 from .server_cursor import ServerCursor
 from ._connection_base import BaseConnection, CursorRow, Notify
@@ -111,6 +109,7 @@ class Connection(BaseConnection[Row]):
         """
         Connect to a database server and return a new `Connection` instance.
         """
+
         params = cls._get_connection_params(conninfo, **kwargs)
         conninfo = make_conninfo(**params)
 
@@ -149,11 +148,7 @@ class Connection(BaseConnection[Row]):
             try:
                 self.rollback()
             except Exception as exc2:
-                logger.warning(
-                    "error ignored in rollback on %s: %s",
-                    self,
-                    exc2,
-                )
+                logger.warning("error ignored in rollback on %s: %s", self, exc2)
         else:
             self.commit()
 
@@ -233,7 +228,7 @@ class Connection(BaseConnection[Row]):
         withhold: bool = False,
     ) -> Union[Cursor[Any], ServerCursor[Any]]:
         """
-        Return a new cursor to send commands and queries to the connection.
+        Return a new `Cursor` to send commands and queries to the connection.
         """
         self._check_connection_ok()
 
@@ -272,7 +267,6 @@ class Connection(BaseConnection[Row]):
                 cur.format = BINARY
 
             return cur.execute(query, params, prepare=prepare)
-
         except e._NO_TRACEBACK as ex:
             raise ex.with_traceback(None)
 
@@ -288,9 +282,7 @@ class Connection(BaseConnection[Row]):
 
     @contextmanager
     def transaction(
-        self,
-        savepoint_name: Optional[str] = None,
-        force_rollback: bool = False,
+        self, savepoint_name: Optional[str] = None, force_rollback: bool = False
     ) -> Iterator[Transaction]:
         """
         Start a context block with a new transaction or nested transaction.
@@ -326,7 +318,7 @@ class Connection(BaseConnection[Row]):
 
     @contextmanager
     def pipeline(self) -> Iterator[Pipeline]:
-        """Switch the connection into pipeline mode."""
+        """Context manager to switch the connection into pipeline mode."""
         with self.lock:
             self._check_connection_ok()
 
index 412b6bd4431a22a28e8683dd27a081a38a6cb2fa..92709003f802338262134ec691304f94bd26dd26 100644 (file)
@@ -155,11 +155,7 @@ class AsyncConnection(BaseConnection[Row]):
             try:
                 await self.rollback()
             except Exception as exc2:
-                logger.warning(
-                    "error ignored in rollback on %s: %s",
-                    self,
-                    exc2,
-                )
+                logger.warning("error ignored in rollback on %s: %s", self, exc2)
         else:
             await self.commit()
 
@@ -299,9 +295,7 @@ class AsyncConnection(BaseConnection[Row]):
 
     @asynccontextmanager
     async def transaction(
-        self,
-        savepoint_name: Optional[str] = None,
-        force_rollback: bool = False,
+        self, savepoint_name: Optional[str] = None, force_rollback: bool = False
     ) -> AsyncIterator[AsyncTransaction]:
         """
         Start a context block with a new transaction or nested transaction.
@@ -439,14 +433,14 @@ class AsyncConnection(BaseConnection[Row]):
         Commit a prepared two-phase transaction.
         """
         async with self.lock:
-            await self.wait(self._tpc_finish_gen("commit", xid))
+            await self.wait(self._tpc_finish_gen("COMMIT", xid))
 
     async def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
         """
         Roll back a prepared two-phase transaction.
         """
         async with self.lock:
-            await self.wait(self._tpc_finish_gen("rollback", xid))
+            await self.wait(self._tpc_finish_gen("ROLLBACK", xid))
 
     async def tpc_recover(self) -> List[Xid]:
         self._check_tpc()