]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added helper object to convert Python query into a Postgres query
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 12 Apr 2020 08:44:43 +0000 (20:44 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 12 Apr 2020 08:54:38 +0000 (20:54 +1200)
The object helps keeping some state between preparing and executing,
useful in executemany so preparation can happen only once.

psycopg3/adapt.py
psycopg3/cursor.py
psycopg3/pq/pq_ctypes.py
psycopg3/utils/queries.py

index 6ae8697430050196d754dff763cb7e317bc6be7f..2702b6e37217a32297d57bce5a1e0eb504a73a85 100644 (file)
@@ -167,6 +167,7 @@ class Transformer:
 
     def __init__(self, context: AdaptContext = None):
         self.connection: Optional[BaseConnection]
+        self.codec: codecs.CodecInfo
         self.dumpers: DumpersMap
         self.loaders: LoadersMap
         self._dumpers_maps: List[DumpersMap] = []
@@ -186,6 +187,7 @@ class Transformer:
     def _setup_context(self, context: AdaptContext) -> None:
         if context is None:
             self.connection = None
+            self.codec = codecs.lookup("utf8")
             self.dumpers = {}
             self.loaders = {}
             self._dumpers_maps = [self.dumpers]
@@ -195,6 +197,7 @@ class Transformer:
             # A transformer created from a transformers: usually it happens
             # for nested types: share the entire state of the parent
             self.connection = context.connection
+            self.codec = context.codec
             self.dumpers = context.dumpers
             self.loaders = context.loaders
             self._dumpers_maps.extend(context._dumpers_maps)
@@ -204,6 +207,7 @@ class Transformer:
 
         elif isinstance(context, BaseCursor):
             self.connection = context.connection
+            self.codec = context.connection.codec
             self.dumpers = {}
             self._dumpers_maps.extend(
                 (self.dumpers, context.dumpers, self.connection.dumpers)
@@ -215,6 +219,7 @@ class Transformer:
 
         elif isinstance(context, BaseConnection):
             self.connection = context
+            self.codec = context.codec
             self.dumpers = {}
             self._dumpers_maps.extend((self.dumpers, context.dumpers))
             self.loaders = {}
index 6dfad11521ef8bf1ff42280f8e98ae36e3993fd8..dcac4cbf190b51db348d828e4ee89d8da707c3ad 100644 (file)
@@ -6,18 +6,17 @@ psycopg3 cursor objects
 
 import codecs
 from operator import attrgetter
-from typing import Any, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING
+from typing import Any, List, Optional, Sequence, Tuple, TYPE_CHECKING
 
 from . import errors as e
 from . import pq
 from . import generators
-from .utils.queries import query2pg, reorder_params
+from .utils.queries import PostgresQuery
 from .utils.typing import Query, Params
 
 if TYPE_CHECKING:
     from .adapt import DumpersMap, LoadersMap, Transformer
     from .connection import BaseConnection, Connection, AsyncConnection
-    from .generators import PQGen
 
 
 class Column(Sequence[Any]):
@@ -135,12 +134,7 @@ class BaseCursor:
         # no-op
         pass
 
-    def _execute_send(
-        self, query: Query, vars: Optional[Params]
-    ) -> "PQGen[List[pq.PGresult]]":
-        """
-        Implement part of execute() before waiting common to sync and async
-        """
+    def _start_query(self) -> None:
         from .adapt import Transformer
 
         if self.closed:
@@ -158,38 +152,31 @@ class BaseCursor:
         self._reset()
         self._transformer = Transformer(self)
 
-        codec = self.connection.codec
-
-        if isinstance(query, str):
-            query = codec.encode(query)[0]
+    def _execute_send(self, query: Query, vars: Optional[Params]) -> None:
+        """
+        Implement part of execute() before waiting common to sync and async
+        """
+        pgq = PostgresQuery(self._transformer)
+        pgq.convert(query, vars)
 
-        # process %% -> % only if there are paramters, even if empty list
-        if vars is not None:
-            query, formats, order = query2pg(query, vars, codec)
-        if vars:
-            if order is not None:
-                assert isinstance(vars, Mapping)
-                vars = reorder_params(vars, order)
-            assert isinstance(vars, Sequence)
-            params, types = self._transformer.dump_sequence(vars, formats)
+        if pgq.params:
             self.connection.pgconn.send_query_params(
-                query,
-                params,
-                param_formats=formats,
-                param_types=types,
+                pgq.query,
+                pgq.params,
+                param_formats=pgq.formats,
+                param_types=pgq.types,
                 result_format=pq.Format(self.binary),
             )
+
         else:
             # if we don't have to, let's use exec_ as it can run more than
             # one query in one go
             if self.binary:
                 self.connection.pgconn.send_query_params(
-                    query, (), result_format=pq.Format(self.binary)
+                    pgq.query, None, result_format=pq.Format(self.binary)
                 )
             else:
-                self.connection.pgconn.send_query(query)
-
-        return generators.execute(self.connection.pgconn)
+                self.connection.pgconn.send_query(pgq.query)
 
     def _execute_results(self, results: Sequence[pq.PGresult]) -> None:
         """
@@ -221,72 +208,26 @@ class BaseCursor:
 
     def _send_prepare(
         self, name: bytes, query: Query, vars: Optional[Params]
-    ) -> "PQGen[List[pq.PGresult]]":
+    ) -> PostgresQuery:
         """
         Implement part of execute() before waiting common to sync and async
         """
-        from .adapt import Transformer
-
-        if self.closed:
-            raise e.OperationalError("the cursor is closed")
-
-        if self.connection.closed:
-            raise e.OperationalError("the connection is closed")
-
-        if self.connection.status != self.connection.ConnStatus.OK:
-            raise e.InterfaceError(
-                f"cannot execute operations: the connection is"
-                f" in status {self.connection.status}"
-            )
-
-        self._reset()
-        self._transformer = Transformer(self)
-
-        codec = self.connection.codec
-
-        if isinstance(query, str):
-            query = codec.encode(query)[0]
+        pgq = PostgresQuery(self._transformer)
+        pgq.convert(query, vars)
 
-        # process %% -> % only if there are paramters, even if empty list
-        if vars is not None:
-            query, formats, order = query2pg(query, vars, codec)
-
-        if order is not None:
-            assert isinstance(vars, Mapping)
-            vars = reorder_params(vars, order)
-        assert isinstance(vars, Sequence)
-        params, types = self._transformer.dump_sequence(vars, formats)
         self.connection.pgconn.send_prepare(
-            name, query, param_types=types,
+            name, pgq.query, param_types=pgq.types,
         )
-        self._order = order
-        self._formats = formats
-        return generators.execute(self.connection.pgconn)
 
-    def _send_query_prepared(
-        self, name: bytes, vars: Optional[Params]
-    ) -> "PQGen[List[pq.PGresult]]":
-        if self.connection.closed:
-            raise e.OperationalError("the connection is closed")
-
-        if self.connection.status != self.connection.ConnStatus.OK:
-            raise e.InterfaceError(
-                f"cannot execute operations: the connection is"
-                f" in status {self.connection.status}"
-            )
+        return pgq
 
-        if self._order is not None:
-            assert isinstance(vars, Mapping)
-            vars = reorder_params(vars, self._order)
-        assert isinstance(vars, Sequence)
-        params, types = self._transformer.dump_sequence(vars, self._formats)
+    def _send_query_prepared(self, name: bytes, pgq: PostgresQuery) -> None:
         self.connection.pgconn.send_query_prepared(
             name,
-            params,
-            param_formats=self._formats,
+            pgq.params,
+            param_formats=pgq.formats,
             result_format=pq.Format(self.binary),
         )
-        return generators.execute(self.connection.pgconn)
 
     def nextset(self) -> Optional[bool]:
         self._iresult += 1
@@ -324,7 +265,9 @@ class Cursor(BaseCursor):
 
     def execute(self, query: Query, vars: Optional[Params] = None) -> "Cursor":
         with self.connection.lock:
-            gen = self._execute_send(query, vars)
+            self._start_query()
+            self._execute_send(query, vars)
+            gen = generators.execute(self.connection.pgconn)
             results = self.connection.wait(gen)
             self._execute_results(results)
         return self
@@ -333,14 +276,19 @@ class Cursor(BaseCursor):
         self, query: Query, vars_seq: Sequence[Params]
     ) -> "Cursor":
         with self.connection.lock:
+            self._start_query()
             for i, vars in enumerate(vars_seq):
                 if i == 0:
-                    gen = self._send_prepare(b"", query, vars)
+                    pgq = self._send_prepare(b"", query, vars)
+                    gen = generators.execute(self.connection.pgconn)
                     (result,) = self.connection.wait(gen)
                     if result.status == self.ExecStatus.FATAL_ERROR:
                         raise e.error_from_result(result)
+                else:
+                    pgq.dump(vars)
 
-                gen = self._send_query_prepared(b"", vars)
+                self._send_query_prepared(b"", pgq)
+                gen = generators.execute(self.connection.pgconn)
                 (result,) = self.connection.wait(gen)
                 self._execute_results((result,))
 
@@ -388,7 +336,9 @@ class AsyncCursor(BaseCursor):
         self, query: Query, vars: Optional[Params] = None
     ) -> "AsyncCursor":
         async with self.connection.lock:
-            gen = self._execute_send(query, vars)
+            self._start_query()
+            self._execute_send(query, vars)
+            gen = generators.execute(self.connection.pgconn)
             results = await self.connection.wait(gen)
             self._execute_results(results)
         return self
index f36cf9f34b7c4d880532793a5803be4e24a2011a..ae0ce825318954330fc3f2e98612bb27b5553352 100644 (file)
@@ -201,7 +201,7 @@ class PGconn:
     def exec_params(
         self,
         command: bytes,
-        param_values: Sequence[Optional[bytes]],
+        param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[Format]] = None,
         result_format: Format = Format.TEXT,
@@ -218,7 +218,7 @@ class PGconn:
     def send_query_params(
         self,
         command: bytes,
-        param_values: Sequence[Optional[bytes]],
+        param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[Format]] = None,
         result_format: Format = Format.TEXT,
@@ -257,7 +257,7 @@ class PGconn:
     def send_query_prepared(
         self,
         name: bytes,
-        param_values: Sequence[Optional[bytes]],
+        param_values: Optional[Sequence[Optional[bytes]]],
         param_formats: Optional[Sequence[Format]] = None,
         result_format: Format = Format.TEXT,
     ) -> None:
@@ -277,7 +277,7 @@ class PGconn:
     def _query_params_args(
         self,
         command: bytes,
-        param_values: Sequence[Optional[bytes]],
+        param_values: Optional[Sequence[Optional[bytes]]],
         param_types: Optional[Sequence[int]] = None,
         param_formats: Optional[Sequence[Format]] = None,
         result_format: Format = Format.TEXT,
@@ -285,10 +285,10 @@ class PGconn:
         if not isinstance(command, bytes):
             raise TypeError(f"bytes expected, got {type(command)} instead")
 
-        nparams = len(param_values)
+        nparams = len(param_values) if param_values is not None else 0
         aparams: Optional[Array[c_char_p]] = None
         alenghts: Optional[Array[c_int]] = None
-        if nparams:
+        if param_values:
             aparams = (c_char_p * nparams)(*param_values)
             alenghts = (c_int * nparams)(
                 *(len(p) if p is not None else 0 for p in param_values)
@@ -356,17 +356,17 @@ class PGconn:
     def exec_prepared(
         self,
         name: bytes,
-        param_values: Sequence[bytes],
+        param_values: Optional[Sequence[bytes]],
         param_formats: Optional[Sequence[int]] = None,
         result_format: int = 0,
     ) -> "PGresult":
         if not isinstance(name, bytes):
             raise TypeError(f"'name' must be bytes, got {type(name)} instead")
 
-        nparams = len(param_values)
+        nparams = len(param_values) if param_values is not None else 0
         aparams: Optional[Array[c_char_p]] = None
         alenghts: Optional[Array[c_int]] = None
-        if nparams:
+        if param_values:
             aparams = (c_char_p * nparams)(*param_values)
             alenghts = (c_int * nparams)(
                 *(len(p) if p is not None else 0 for p in param_values)
index 5fcbf8d46b9c55cd6e5fe35b76fe583a0f8ae779..41c7fa88e49af79b5eac9e455237a9cb6ba3912a 100644 (file)
@@ -7,11 +7,66 @@ Utility module to manipulate queries
 import re
 from codecs import CodecInfo
 from typing import Any, Dict, List, Mapping, Match, NamedTuple, Optional
-from typing import Sequence, Tuple, Union
+from typing import Sequence, Tuple, Union, TYPE_CHECKING
 
 from .. import errors as e
 from ..pq import Format
-from .typing import Params
+from .typing import Query, Params
+
+if TYPE_CHECKING:
+    from ..adapt import Transformer
+
+
+class PostgresQuery:
+    """
+    Helper to convert a Python query and parameters into Postgres format.
+    """
+
+    def __init__(self, transformer: "Transformer"):
+        self._tx = transformer
+        self.query: bytes = b""
+        self.params: Optional[List[Optional[bytes]]] = None
+        self.types: Optional[List[int]] = None
+        self.formats: Optional[List[Format]] = None
+
+        self._order: Optional[List[str]] = None
+
+    def convert(self, query: Query, vars: Optional[Params]) -> None:
+        """
+        Set up the query and parameters to convert.
+
+        The results of this function can be obtained accessing the object
+        attributes (`query`, `params`, `types`, `formats`).
+        """
+        codec = self._tx.codec
+        if isinstance(query, str):
+            query = codec.encode(query)[0]
+        if vars is not None:
+            self.query, self.formats, self._order = query2pg(
+                query, vars, codec
+            )
+        else:
+            self.query = query
+            self.formats = self._order = None
+
+        self.dump(vars)
+
+    def dump(self, vars: Optional[Params]) -> None:
+        """
+        Process a new set of variables on the same query as before.
+
+        This method updates `params` and `types`.
+        """
+        if vars:
+            if self._order is not None:
+                assert isinstance(vars, Mapping)
+                vars = reorder_params(vars, self._order)
+            assert isinstance(vars, Sequence)
+            self.params, self.types = self._tx.dump_sequence(
+                vars, self.formats or ()
+            )
+        else:
+            self.params = self.types = None
 
 
 def query2pg(