]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Make wider use of Composite.as_bytes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 17 Dec 2020 04:13:25 +0000 (05:13 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 17 Dec 2020 04:57:09 +0000 (05:57 +0100)
psycopg3/psycopg3/_queries.py
psycopg3/psycopg3/transaction.py

index b26ee2f50f999b80d53c51162faaad1b9d2e6d5c..15085332db48d7336417e7f245d741f8bc294467 100644 (file)
@@ -48,7 +48,7 @@ class PostgresQuery:
         attributes (`query`, `params`, `types`, `formats`).
         """
         if isinstance(query, Composable):
-            query = query.as_string(self._tx)
+            query = query.as_bytes(self._tx)
 
         if vars is not None:
             self.query, self.formats, self._order, self._parts = _query2pg(
index 9e279688d786f35123c2e734e146075b39f13a9e..21c28fda01bb325fe6f6adbe1f2cfbb49524f1c4 100644 (file)
@@ -74,7 +74,7 @@ class BaseTransaction(Generic[ConnectionType]):
             args.append("force_rollback=True")
         return f"{self.__class__.__qualname__}({', '.join(args)})"
 
-    def _enter_commands(self) -> List[str]:
+    def _enter_commands(self) -> List[bytes]:
         if not self._yolo:
             raise TypeError("transaction blocks can be used only once")
         else:
@@ -97,19 +97,19 @@ class BaseTransaction(Generic[ConnectionType]):
         commands = []
         if self._outer_transaction:
             assert not self._conn._savepoints, self._conn._savepoints
-            commands.append("begin")
+            commands.append(b"begin")
 
         if self._savepoint_name:
             commands.append(
                 sql.SQL("savepoint {}")
                 .format(sql.Identifier(self._savepoint_name))
-                .as_string(self._conn)
+                .as_bytes(self._conn)
             )
 
         self._conn._savepoints.append(self._savepoint_name)
         return commands
 
-    def _commit_commands(self) -> List[str]:
+    def _commit_commands(self) -> List[bytes]:
         assert self._conn._savepoints[-1] == self._savepoint_name
         self._conn._savepoints.pop()
 
@@ -118,16 +118,16 @@ class BaseTransaction(Generic[ConnectionType]):
             commands.append(
                 sql.SQL("release {}")
                 .format(sql.Identifier(self._savepoint_name))
-                .as_string(self._conn)
+                .as_bytes(self._conn)
             )
 
         if self._outer_transaction:
             assert not self._conn._savepoints
-            commands.append("commit")
+            commands.append(b"commit")
 
         return commands
 
-    def _rollback_commands(self) -> List[str]:
+    def _rollback_commands(self) -> List[bytes]:
         assert self._conn._savepoints[-1] == self._savepoint_name
         self._conn._savepoints.pop()
 
@@ -136,12 +136,12 @@ class BaseTransaction(Generic[ConnectionType]):
             commands.append(
                 sql.SQL("rollback to {n}; release {n}")
                 .format(n=sql.Identifier(self._savepoint_name))
-                .as_string(self._conn)
+                .as_bytes(self._conn)
             )
 
         if self._outer_transaction:
             assert not self._conn._savepoints
-            commands.append("rollback")
+            commands.append(b"rollback")
 
         return commands
 
@@ -190,8 +190,8 @@ class Transaction(BaseTransaction["Connection"]):
 
         return False
 
-    def _execute(self, commands: List[str]) -> None:
-        self._conn._exec_command("; ".join(commands))
+    def _execute(self, commands: List[bytes]) -> None:
+        self._conn._exec_command(b"; ".join(commands))
 
 
 class AsyncTransaction(BaseTransaction["AsyncConnection"]):
@@ -239,5 +239,5 @@ class AsyncTransaction(BaseTransaction["AsyncConnection"]):
 
         return False
 
-    async def _execute(self, commands: List[str]) -> None:
-        await self._conn._exec_command("; ".join(commands))
+    async def _execute(self, commands: List[bytes]) -> None:
+        await self._conn._exec_command(b"; ".join(commands))