git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add raw query support with PostgreSQL native placeholders
author Joel Jakobsson <joel@compiler.org>
Tue, 9 May 2023 00:49:03 +0000 (02:49 +0200)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 15 Aug 2023 15:29:03 +0000 (16:29 +0100)
This commit introduces support for raw queries with PostgreSQL's native
placeholders ($1, $2, etc.) in psycopg3. Users enable them by using the new
RawCursor class (or AsyncRawCursor on async connections), which passes the
query text to PostgreSQL unmodified.

RawCursor can be set as the cursor_factory when connecting to the database, or
instantiated directly on an existing connection, allowing users to choose between
PostgreSQL's native placeholders and the standard %s placeholder in their queries,
so that both styles of placeholders can coexist in the same program. The existing
cursor test suites are extended to also run against the new cursor classes to
verify the correct behavior of the new feature.

docs/advanced/cursors.rst
psycopg/psycopg/__init__.py
psycopg/psycopg/_queries.py
psycopg/psycopg/raw_cursor.py [new file with mode: 0644]
tests/test_cursor.py
tests/test_cursor_async.py

index 954d665dfbf68afbff131250da2b8b5f1736494a..f5e369ed0964c409094b6425fc51ea8195929867 100644 (file)
@@ -190,3 +190,44 @@ directly call the fetch methods, skipping the `~ServerCursor.execute()` call:
     # no cur.execute()
     for record in cur:  # or cur.fetchone(), cur.fetchmany()...
         # do something with record
+
+.. _raw-query-cursors:
+
+Raw Query Cursors
+------------------
+
+.. versionadded:: 3.2
+
+Raw query cursors let you use PostgreSQL native placeholders (``$1``, ``$2``,
+etc.) in your queries instead of the standard ``%s`` placeholder. This is
+useful when you want to pass the query unmodified to PostgreSQL and rely on
+PostgreSQL's own placeholder handling, for example when dealing with a very
+complex query containing ``%s`` inside strings, dollar-quoted strings or elsewhere.
+
+Note that raw query cursors only accept positional parameters, passed as a
+sequence such as a list or a tuple. Named parameters (i.e. a mapping such as
+a dictionary) are not supported and raise ``TypeError``.
+
+There are two ways to use raw query cursors:
+
+1. Using the cursor factory:
+
+.. code:: python
+
+    from psycopg import connect, RawCursor
+
+    with connect(dsn, cursor_factory=RawCursor) as conn:
+        with conn.cursor() as cur:
+            cur.execute("SELECT $1, $2", [1, "Hello"])
+            assert cur.fetchone() == (1, "Hello")
+
+2. Instantiating a cursor:
+
+.. code:: python
+
+    from psycopg import connect, RawCursor
+
+    with connect(dsn) as conn:
+        with RawCursor(conn) as cur:
+            cur.execute("SELECT $1, $2", [1, "Hello"])
+            assert cur.fetchone() == (1, "Hello")
index 072edef5233a41ab6a49209cabae20542571744b..ff7d398a1d0494c424e6845bb7179cec4f76ef18 100644 (file)
@@ -24,6 +24,7 @@ from .transaction import Rollback, Transaction, AsyncTransaction
 from .cursor_async import AsyncCursor
 from .server_cursor import AsyncServerCursor, ServerCursor
 from .client_cursor import AsyncClientCursor, ClientCursor
+from .raw_cursor import AsyncRawCursor, RawCursor
 from .connection_async import AsyncConnection
 
 from . import dbapi20
@@ -64,6 +65,7 @@ __all__ = [
     "AsyncCopy",
     "AsyncCursor",
     "AsyncPipeline",
+    "AsyncRawCursor",
     "AsyncServerCursor",
     "AsyncTransaction",
     "BaseConnection",
@@ -76,6 +78,7 @@ __all__ = [
     "IsolationLevel",
     "Notify",
     "Pipeline",
+    "RawCursor",
     "Rollback",
     "ServerCursor",
     "Transaction",
index 2a7554c30cc2da6c36c4df7b4af264a631156796..d9bbaa8418ac0d35f3b658f5815ecb29fc7e94b9 100644 (file)
@@ -72,7 +72,7 @@ class PostgresQuery:
                 self._want_formats,
                 self._order,
                 self._parts,
-            ) = _query2pg(bquery, self._encoding)
+            ) = self.query2pg(bquery, self._encoding)
         else:
             self.query = bquery
             self._want_formats = self._order = None
@@ -86,8 +86,11 @@ class PostgresQuery:
         This method updates `params` and `types`.
         """
         if vars is not None:
-            params = _validate_and_reorder_params(self._parts, vars, self._order)
-            assert self._want_formats is not None
+            params = self.validate_and_reorder_params(self._parts, vars, self._order)
+            num_params = len(params)
+            if self._want_formats is None:
+                self._want_formats = [PyFormat.AUTO] * num_params
+            assert len(self._want_formats) == num_params
             self.params = self._tx.dump_sequence(params, self._want_formats)
             self.types = self._tx.types or ()
             self.formats = self._tx.formats
@@ -96,6 +99,110 @@ class PostgresQuery:
             self.types = ()
             self.formats = None
 
+    @staticmethod
+    def is_params_sequence(vars: Params) -> bool:
+        # Try concrete types, then abstract types
+        t = type(vars)
+        if t is list or t is tuple:
+            sequence = True
+        elif t is dict:
+            sequence = False
+        elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
+            sequence = True
+        elif isinstance(vars, Mapping):
+            sequence = False
+        else:
+            raise TypeError(
+                "query parameters should be a sequence or a mapping,"
+                f" got {type(vars).__name__}"
+            )
+        return sequence
+
+    @staticmethod
+    def validate_and_reorder_params(
+        parts: List[QueryPart], vars: Params, order: Optional[List[str]]
+    ) -> Sequence[Any]:
+        """
+        Verify the compatibility between a query and a set of params.
+        """
+        sequence = PostgresQuery.is_params_sequence(vars)
+
+        if sequence:
+            if len(vars) != len(parts) - 1:
+                raise e.ProgrammingError(
+                    f"the query has {len(parts) - 1} placeholders but"
+                    f" {len(vars)} parameters were passed"
+                )
+            if vars and not isinstance(parts[0].item, int):
+                raise TypeError("named placeholders require a mapping of parameters")
+            return vars  # type: ignore[return-value]
+
+        else:
+            if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
+                raise TypeError(
+                    "positional placeholders (%s) require a sequence of parameters"
+                )
+            try:
+                return [
+                    vars[item] for item in order or ()  # type: ignore[call-overload]
+                ]
+            except KeyError:
+                raise e.ProgrammingError(
+                    "query parameter missing:"
+                    f" {', '.join(sorted(i for i in order or () if i not in vars))}"
+                )
+
+    @staticmethod
+    @lru_cache()
+    def query2pg(
+        query: bytes, encoding: str
+    ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
+        """
+        Convert Python query and params into something Postgres understands.
+
+        - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
+        format (``$1``, ``$2``)
+        - placeholders can be %s, %t, or %b (auto, text or binary)
+        - return ``query`` (bytes), ``formats`` (list of formats) ``order``
+        (sequence of names used in the query, in the position they appear)
+        ``parts`` (splits of queries and placeholders).
+        """
+        parts = _split_query(query, encoding)
+        order: Optional[List[str]] = None
+        chunks: List[bytes] = []
+        formats = []
+
+        if isinstance(parts[0].item, int):
+            for part in parts[:-1]:
+                assert isinstance(part.item, int)
+                chunks.append(part.pre)
+                chunks.append(b"$%d" % (part.item + 1))
+                formats.append(part.format)
+
+        elif isinstance(parts[0].item, str):
+            seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+            order = []
+            for part in parts[:-1]:
+                assert isinstance(part.item, str)
+                chunks.append(part.pre)
+                if part.item not in seen:
+                    ph = b"$%d" % (len(seen) + 1)
+                    seen[part.item] = (ph, part.format)
+                    order.append(part.item)
+                    chunks.append(ph)
+                    formats.append(part.format)
+                else:
+                    if seen[part.item][1] != part.format:
+                        raise e.ProgrammingError(
+                            f"placeholder '{part.item}' cannot have different formats"
+                        )
+                    chunks.append(seen[part.item][0])
+
+        # last part
+        chunks.append(parts[-1].pre)
+
+        return b"".join(chunks), formats, order, parts
+
 
 class PostgresClientQuery(PostgresQuery):
     """
@@ -119,7 +226,7 @@ class PostgresClientQuery(PostgresQuery):
             bquery = query
 
         if vars is not None:
-            (self.template, self._order, self._parts) = _query2pg_client(
+            (self.template, _, self._order, self._parts) = PostgresClientQuery.query2pg(
                 bquery, self._encoding
             )
         else:
@@ -135,7 +242,7 @@ class PostgresClientQuery(PostgresQuery):
         This method updates `params` and `types`.
         """
         if vars is not None:
-            params = _validate_and_reorder_params(self._parts, vars, self._order)
+            params = self.validate_and_reorder_params(self._parts, vars, self._order)
             self.params = tuple(
                 self._tx.as_literal(p) if p is not None else b"NULL" for p in params
             )
@@ -143,140 +250,43 @@ class PostgresClientQuery(PostgresQuery):
         else:
             self.params = None
 
-
-@lru_cache()
-def _query2pg(
-    query: bytes, encoding: str
-) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]:
-    """
-    Convert Python query and params into something Postgres understands.
-
-    - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres
-      format (``$1``, ``$2``)
-    - placeholders can be %s, %t, or %b (auto, text or binary)
-    - return ``query`` (bytes), ``formats`` (list of formats) ``order``
-      (sequence of names used in the query, in the position they appear)
-      ``parts`` (splits of queries and placeholders).
-    """
-    parts = _split_query(query, encoding)
-    order: Optional[List[str]] = None
-    chunks: List[bytes] = []
-    formats = []
-
-    if isinstance(parts[0].item, int):
-        for part in parts[:-1]:
-            assert isinstance(part.item, int)
-            chunks.append(part.pre)
-            chunks.append(b"$%d" % (part.item + 1))
-            formats.append(part.format)
-
-    elif isinstance(parts[0].item, str):
-        seen: Dict[str, Tuple[bytes, PyFormat]] = {}
-        order = []
-        for part in parts[:-1]:
-            assert isinstance(part.item, str)
-            chunks.append(part.pre)
-            if part.item not in seen:
-                ph = b"$%d" % (len(seen) + 1)
-                seen[part.item] = (ph, part.format)
-                order.append(part.item)
-                chunks.append(ph)
-                formats.append(part.format)
-            else:
-                if seen[part.item][1] != part.format:
-                    raise e.ProgrammingError(
-                        f"placeholder '{part.item}' cannot have different formats"
-                    )
-                chunks.append(seen[part.item][0])
-
-    # last part
-    chunks.append(parts[-1].pre)
-
-    return b"".join(chunks), formats, order, parts
-
-
-@lru_cache()
-def _query2pg_client(
-    query: bytes, encoding: str
-) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]:
-    """
-    Convert Python query and params into a template to perform client-side binding
-    """
-    parts = _split_query(query, encoding, collapse_double_percent=False)
-    order: Optional[List[str]] = None
-    chunks: List[bytes] = []
-
-    if isinstance(parts[0].item, int):
-        for part in parts[:-1]:
-            assert isinstance(part.item, int)
-            chunks.append(part.pre)
-            chunks.append(b"%s")
-
-    elif isinstance(parts[0].item, str):
-        seen: Dict[str, Tuple[bytes, PyFormat]] = {}
-        order = []
-        for part in parts[:-1]:
-            assert isinstance(part.item, str)
-            chunks.append(part.pre)
-            if part.item not in seen:
-                ph = b"%s"
-                seen[part.item] = (ph, part.format)
-                order.append(part.item)
-                chunks.append(ph)
-            else:
-                chunks.append(seen[part.item][0])
-                order.append(part.item)
-
-    # last part
-    chunks.append(parts[-1].pre)
-
-    return b"".join(chunks), order, parts
-
-
-def _validate_and_reorder_params(
-    parts: List[QueryPart], vars: Params, order: Optional[List[str]]
-) -> Sequence[Any]:
-    """
-    Verify the compatibility between a query and a set of params.
-    """
-    # Try concrete types, then abstract types
-    t = type(vars)
-    if t is list or t is tuple:
-        sequence = True
-    elif t is dict:
-        sequence = False
-    elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)):
-        sequence = True
-    elif isinstance(vars, Mapping):
-        sequence = False
-    else:
-        raise TypeError(
-            "query parameters should be a sequence or a mapping,"
-            f" got {type(vars).__name__}"
-        )
-
-    if sequence:
-        if len(vars) != len(parts) - 1:
-            raise e.ProgrammingError(
-                f"the query has {len(parts) - 1} placeholders but"
-                f" {len(vars)} parameters were passed"
-            )
-        if vars and not isinstance(parts[0].item, int):
-            raise TypeError("named placeholders require a mapping of parameters")
-        return vars  # type: ignore[return-value]
-
-    else:
-        if vars and len(parts) > 1 and not isinstance(parts[0][1], str):
-            raise TypeError(
-                "positional placeholders (%s) require a sequence of parameters"
-            )
-        try:
-            return [vars[item] for item in order or ()]  # type: ignore[call-overload]
-        except KeyError:
-            raise e.ProgrammingError(
-                "query parameter missing:"
-                f" {', '.join(sorted(i for i in order or () if i not in vars))}"
-            )
+    @staticmethod
+    @lru_cache()
+    def query2pg(
+        query: bytes, encoding: str
+    ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
+        """
+        Convert Python query and params into a template to perform client-side binding
+        """
+        parts = _split_query(query, encoding, collapse_double_percent=False)
+        order: Optional[List[str]] = None
+        chunks: List[bytes] = []
+
+        if isinstance(parts[0].item, int):
+            for part in parts[:-1]:
+                assert isinstance(part.item, int)
+                chunks.append(part.pre)
+                chunks.append(b"%s")
+
+        elif isinstance(parts[0].item, str):
+            seen: Dict[str, Tuple[bytes, PyFormat]] = {}
+            order = []
+            for part in parts[:-1]:
+                assert isinstance(part.item, str)
+                chunks.append(part.pre)
+                if part.item not in seen:
+                    ph = b"%s"
+                    seen[part.item] = (ph, part.format)
+                    order.append(part.item)
+                    chunks.append(ph)
+                else:
+                    chunks.append(seen[part.item][0])
+                    order.append(part.item)
+
+        # last part
+        chunks.append(parts[-1].pre)
+
+        return b"".join(chunks), None, order, parts
 
 
 _re_placeholder = re.compile(
diff --git a/psycopg/psycopg/raw_cursor.py b/psycopg/psycopg/raw_cursor.py
new file mode 100644 (file)
index 0000000..1a584c4
--- /dev/null
@@ -0,0 +1,61 @@
+"""
+psycopg raw queries cursors
+"""
+
+# Copyright (C) 2023 The Psycopg Team
+
+from typing import Any, Optional, Sequence, Tuple, List, TYPE_CHECKING
+from functools import lru_cache
+
+from ._queries import PostgresQuery, QueryPart
+
+from .abc import ConnectionType, Query, Params
+from .rows import Row
+from .cursor import BaseCursor, Cursor
+from ._enums import PyFormat
+from .cursor_async import AsyncCursor
+
+if TYPE_CHECKING:
+    from .connection import Connection  # noqa: F401
+    from .connection_async import AsyncConnection  # noqa: F401
+
+
+class RawPostgresQuery(PostgresQuery):
+    @staticmethod
+    @lru_cache()
+    def query2pg(
+        query: bytes, encoding: str
+    ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]:
+        """
+        Noop; Python raw query is already in the format Postgres understands.
+        """
+        return query, None, None, []
+
+    @staticmethod
+    def validate_and_reorder_params(
+        parts: List[QueryPart], vars: Params, order: Optional[List[str]]
+    ) -> Sequence[Any]:
+        """
+        Verify the compatibility; params must be a sequence for raw query.
+        """
+        sequence = PostgresQuery.is_params_sequence(vars)
+        if not sequence:
+            raise TypeError("raw query require a sequence of parameters")
+        return vars  # type: ignore[return-value]
+
+
+class RawCursorMixin(BaseCursor[ConnectionType, Row]):
+    def _convert_query(
+        self, query: Query, params: Optional[Params] = None
+    ) -> PostgresQuery:
+        pgq = RawPostgresQuery(self._tx)
+        pgq.convert(query, params)
+        return pgq
+
+
+class RawCursor(RawCursorMixin["Connection[Any]", Row], Cursor[Row]):
+    __module__ = "psycopg"
+
+
+class AsyncRawCursor(RawCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]):
+    __module__ = "psycopg"
index b7d3c051a678fa870919d2b26cd4ccc40245c5ec..f2dcb6ea3a76071dca85ab6a208abba21a8354d7 100644 (file)
@@ -21,7 +21,7 @@ from .utils import gc_collect, raiseif
 from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision
 
 
-@pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor])
+@pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor, psycopg.RawCursor])
 def conn(conn, request):
     conn.cursor_factory = request.param
     return conn
index d85a929d949ce5b1fe271a42d52e1d7ea6474528..1a05d966a6e6b42621cbb7dd02335ce7e4f90962 100644 (file)
@@ -16,7 +16,9 @@ from .fix_crdb import crdb_encoding
 execmany = execmany  # avoid F811 underneath
 
 
-@pytest.fixture(params=[psycopg.AsyncCursor, psycopg.AsyncClientCursor])
+@pytest.fixture(
+    params=[psycopg.AsyncCursor, psycopg.AsyncClientCursor, psycopg.AsyncRawCursor]
+)
 def aconn(aconn, request, anyio_backend):
     aconn.cursor_factory = request.param
     return aconn