]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: clean internal use of LiteralString
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 27 Jun 2022 21:32:20 +0000 (22:32 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 3 Jul 2022 02:52:59 +0000 (03:52 +0100)
Tested one-off using pyre 0.9.13; however it gives too many differences
compared to mypy to use it at the moment.

The pyre run doesn't currently find LiteralString-related problems,
except the todo at psycopg/sql.py:251, because string.Formatter.parse()
doesn't return a LiteralString upon LiteralString input.

psycopg/psycopg/connection.py
psycopg/psycopg/sql.py

index 4da145b3c02f4d95f4c50da70a0fc0cf9b8bef79..0cc24ea20bc07a97e47fd164ce4e3e0b440291dc 100644 (file)
@@ -27,7 +27,7 @@ from .rows import Row, RowFactory, tuple_row, TupleRow, args_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
-from ._compat import TypeAlias
+from ._compat import TypeAlias, LiteralString
 from ._cmodule import _psycopg
 from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
 from ._pipeline import BasePipeline, Pipeline
@@ -614,8 +614,10 @@ class BaseConnection(Generic[Row]):
         if self._pipeline:
             yield from self._pipeline._sync_gen()
 
-    def _tpc_finish_gen(self, action: str, xid: Union[Xid, str, None]) -> PQGen[None]:
-        fname = f"tpc_{action}()"
+    def _tpc_finish_gen(
+        self, action: LiteralString, xid: Union[Xid, str, None]
+    ) -> PQGen[None]:
+        fname = f"tpc_{action.lower()}()"
         if xid is None:
             if not self._tpc:
                 raise e.ProgrammingError(
@@ -634,12 +636,12 @@ class BaseConnection(Generic[Row]):
 
         if self._tpc and not self._tpc[1]:
             meth: Callable[[], PQGen[None]]
-            meth = getattr(self, f"_{action}_gen")
+            meth = getattr(self, f"_{action.lower()}_gen")
             self._tpc = None
             yield from meth()
         else:
             yield from self._exec_command(
-                SQL("{} PREPARED {}").format(SQL(action.upper()), str(xid))
+                SQL("{} PREPARED {}").format(SQL(action), str(xid))
             )
             self._tpc = None
 
@@ -1006,14 +1008,14 @@ class Connection(BaseConnection[Row]):
         Commit a prepared two-phase transaction.
         """
         with self.lock:
-            self.wait(self._tpc_finish_gen("commit", xid))
+            self.wait(self._tpc_finish_gen("COMMIT", xid))
 
     def tpc_rollback(self, xid: Union[Xid, str, None] = None) -> None:
         """
         Roll back a prepared two-phase transaction.
         """
         with self.lock:
-            self.wait(self._tpc_finish_gen("rollback", xid))
+            self.wait(self._tpc_finish_gen("ROLLBACK", xid))
 
     def tpc_recover(self) -> List[Xid]:
         status = self.info.transaction_status
index df44fdf51ccb6fedc061f68403dad7dbf71dac38..45c4f6bbf972d6c90a3aded78a561f79ec063186 100644 (file)
@@ -143,7 +143,7 @@ class Composed(Composable):
         else:
             return NotImplemented
 
-    def join(self, joiner: Union["SQL", str]) -> "Composed":
+    def join(self, joiner: Union["SQL", LiteralString]) -> "Composed":
         """
         Return a new `!Composed` interposing the *joiner* with the `!Composed` items.
 
@@ -191,7 +191,7 @@ class SQL(Composable):
         SELECT "foo", "bar" FROM "table"
     """
 
-    _obj: str
+    _obj: LiteralString
     _formatter = string.Formatter()
 
     def __init__(self, obj: LiteralString):
@@ -245,6 +245,9 @@ class SQL(Composable):
         """
         rv: List[Composable] = []
         autonum: Optional[int] = 0
+        # TODO: this is probably not the right way to whitelist pre
+        # pyre complains. Will wait for mypy to complain too to fix.
+        pre: LiteralString
         for pre, name, spec, conv in self._formatter.parse(self._obj):
             if spec:
                 raise ValueError("no format specification supported by SQL")