From: Joel Jakobsson Date: Tue, 9 May 2023 00:49:03 +0000 (+0200) Subject: Add raw query support with PostgreSQL native placeholders X-Git-Tag: pool-3.2.0~66^2~4 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=5babce4c498566730a04eff12d7e39e9e072ad02;p=thirdparty%2Fpsycopg.git Add raw query support with PostgreSQL native placeholders This commit introduces support for raw queries with PostgreSQL's native placeholders ($1, $2, etc.) in psycopg3. By setting the use_raw_query attribute to True in a custom cursor class, users can enable the use of raw queries with native placeholders. The code demonstrates how to create a custom RawQueryCursor class that sets the use_raw_query attribute to True. This custom cursor class can be set as the cursor_factory when connecting to the database, allowing users to choose between PostgreSQL's native placeholders or the standard %s placeholder in their queries. The code also demonstrates how both styles of placeholders can coexist. Test cases are included to verify the correct behavior of the new feature. --- diff --git a/docs/advanced/cursors.rst b/docs/advanced/cursors.rst index 954d665df..f5e369ed0 100644 --- a/docs/advanced/cursors.rst +++ b/docs/advanced/cursors.rst @@ -190,3 +190,44 @@ directly call the fetch methods, skipping the `~ServerCursor.execute()` call: # no cur.execute() for record in cur: # or cur.fetchone(), cur.fetchmany()... # do something with record + +.. _raw-query-cursors: + +Raw Query Cursors +------------------ + +.. versionadded:: 3.2 + +Raw query cursors allow users to use PostgreSQL native placeholders ($1, $2, +etc.) in their queries instead of the standard %s placeholder. This can be +useful when it's desirable to pass the query unmodified to PostgreSQL and rely +on PostgreSQL's placeholder functionality, such as when dealing with a very +complex query containing %s inside strings, dollar-quoted strings or elsewhere. + +One important note is that raw query cursors only accept positional arguments +in the form of a list or tuple. This means you cannot use named arguments +(i.e., dictionaries). + +There are two ways to use raw query cursors: + +1. Using the cursor factory: + +.. code:: python + + from psycopg import connect, RawCursor + + with connect(dsn, cursor_factory=RawCursor) as conn: + with conn.cursor() as cur: + cur.execute("SELECT $1, $2", [1, "Hello"]) + assert cur.fetchone() == (1, "Hello") + +2. Instantiating a cursor: + +.. code:: python + + from psycopg import connect, RawCursor + + with connect(dsn) as conn: + with RawCursor(conn) as cur: + cur.execute("SELECT $1, $2", [1, "Hello"]) + assert cur.fetchone() == (1, "Hello") diff --git a/psycopg/psycopg/__init__.py b/psycopg/psycopg/__init__.py index 072edef52..ff7d398a1 100644 --- a/psycopg/psycopg/__init__.py +++ b/psycopg/psycopg/__init__.py @@ -24,6 +24,7 @@ from .transaction import Rollback, Transaction, AsyncTransaction from .cursor_async import AsyncCursor from .server_cursor import AsyncServerCursor, ServerCursor from .client_cursor import AsyncClientCursor, ClientCursor +from .raw_cursor import AsyncRawCursor, RawCursor from .connection_async import AsyncConnection from . import dbapi20 @@ -64,6 +65,7 @@ __all__ = [ "AsyncCopy", "AsyncCursor", "AsyncPipeline", + "AsyncRawCursor", "AsyncServerCursor", "AsyncTransaction", "BaseConnection", @@ -76,6 +78,7 @@ __all__ = [ "IsolationLevel", "Notify", "Pipeline", + "RawCursor", "Rollback", "ServerCursor", "Transaction", diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index 2a7554c30..d9bbaa841 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -72,7 +72,7 @@ class PostgresQuery: self._want_formats, self._order, self._parts, - ) = _query2pg(bquery, self._encoding) + ) = self.query2pg(bquery, self._encoding) else: self.query = bquery self._want_formats = self._order = None @@ -86,8 +86,11 @@ class PostgresQuery: This method updates `params` and `types`. """ if vars is not None: - params = _validate_and_reorder_params(self._parts, vars, self._order) - assert self._want_formats is not None + params = self.validate_and_reorder_params(self._parts, vars, self._order) + num_params = len(params) + if self._want_formats is None: + self._want_formats = [PyFormat.AUTO] * num_params + assert len(self._want_formats) == num_params self.params = self._tx.dump_sequence(params, self._want_formats) self.types = self._tx.types or () self.formats = self._tx.formats @@ -96,6 +99,110 @@ class PostgresQuery: self.types = () self.formats = None + @staticmethod + def is_params_sequence(vars: Params) -> bool: + # Try concrete types, then abstract types + t = type(vars) + if t is list or t is tuple: + sequence = True + elif t is dict: + sequence = False + elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): + sequence = True + elif isinstance(vars, Mapping): + sequence = False + else: + raise TypeError( + "query parameters should be a sequence or a mapping," + f" got {type(vars).__name__}" + ) + return sequence + + @staticmethod + def validate_and_reorder_params( + parts: List[QueryPart], vars: Params, order: Optional[List[str]] + ) -> Sequence[Any]: + """ + Verify the compatibility between a query and a set of params. + """ + sequence = PostgresQuery.is_params_sequence(vars) + + if sequence: + if len(vars) != len(parts) - 1: + raise e.ProgrammingError( + f"the query has {len(parts) - 1} placeholders but" + f" {len(vars)} parameters were passed" + ) + if vars and not isinstance(parts[0].item, int): + raise TypeError("named placeholders require a mapping of parameters") + return vars # type: ignore[return-value] + + else: + if vars and len(parts) > 1 and not isinstance(parts[0][1], str): + raise TypeError( + "positional placeholders (%s) require a sequence of parameters" + ) + try: + return [ + vars[item] for item in order or () # type: ignore[call-overload] + ] + except KeyError: + raise e.ProgrammingError( + "query parameter missing:" + f" {', '.join(sorted(i for i in order or () if i not in vars))}" + ) + + @staticmethod + @lru_cache() + def query2pg( + query: bytes, encoding: str + ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into something Postgres understands. + + - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres + format (``$1``, ``$2``) + - placeholders can be %s, %t, or %b (auto, text or binary) + - return ``query`` (bytes), ``formats`` (list of formats) ``order`` + (sequence of names used in the query, in the position they appear) + ``parts`` (splits of queries and placeholders). + """ + parts = _split_query(query, encoding) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + formats = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"$%d" % (part.item + 1)) + formats.append(part.format) + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"$%d" % (len(seen) + 1) + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + formats.append(part.format) + else: + if seen[part.item][1] != part.format: + raise e.ProgrammingError( + f"placeholder '{part.item}' cannot have different formats" + ) + chunks.append(seen[part.item][0]) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), formats, order, parts + class PostgresClientQuery(PostgresQuery): """ @@ -119,7 +226,7 @@ class PostgresClientQuery(PostgresQuery): bquery = query if vars is not None: - (self.template, self._order, self._parts) = _query2pg_client( + (self.template, _, self._order, self._parts) = PostgresClientQuery.query2pg( bquery, self._encoding ) else: @@ -135,7 +242,7 @@ class PostgresClientQuery(PostgresQuery): This method updates `params` and `types`. """ if vars is not None: - params = _validate_and_reorder_params(self._parts, vars, self._order) + params = self.validate_and_reorder_params(self._parts, vars, self._order) self.params = tuple( self._tx.as_literal(p) if p is not None else b"NULL" for p in params ) @@ -143,140 +250,43 @@ class PostgresClientQuery(PostgresQuery): else: self.params = None - -@lru_cache() -def _query2pg( - query: bytes, encoding: str -) -> Tuple[bytes, List[PyFormat], Optional[List[str]], List[QueryPart]]: - """ - Convert Python query and params into something Postgres understands. - - - Convert Python placeholders (``%s``, ``%(name)s``) into Postgres - format (``$1``, ``$2``) - - placeholders can be %s, %t, or %b (auto, text or binary) - - return ``query`` (bytes), ``formats`` (list of formats) ``order`` - (sequence of names used in the query, in the position they appear) - ``parts`` (splits of queries and placeholders). - """ - parts = _split_query(query, encoding) - order: Optional[List[str]] = None - chunks: List[bytes] = [] - formats = [] - - if isinstance(parts[0].item, int): - for part in parts[:-1]: - assert isinstance(part.item, int) - chunks.append(part.pre) - chunks.append(b"$%d" % (part.item + 1)) - formats.append(part.format) - - elif isinstance(parts[0].item, str): - seen: Dict[str, Tuple[bytes, PyFormat]] = {} - order = [] - for part in parts[:-1]: - assert isinstance(part.item, str) - chunks.append(part.pre) - if part.item not in seen: - ph = b"$%d" % (len(seen) + 1) - seen[part.item] = (ph, part.format) - order.append(part.item) - chunks.append(ph) - formats.append(part.format) - else: - if seen[part.item][1] != part.format: - raise e.ProgrammingError( - f"placeholder '{part.item}' cannot have different formats" - ) - chunks.append(seen[part.item][0]) - - # last part - chunks.append(parts[-1].pre) - - return b"".join(chunks), formats, order, parts - - -@lru_cache() -def _query2pg_client( - query: bytes, encoding: str -) -> Tuple[bytes, Optional[List[str]], List[QueryPart]]: - """ - Convert Python query and params into a template to perform client-side binding - """ - parts = _split_query(query, encoding, collapse_double_percent=False) - order: Optional[List[str]] = None - chunks: List[bytes] = [] - - if isinstance(parts[0].item, int): - for part in parts[:-1]: - assert isinstance(part.item, int) - chunks.append(part.pre) - chunks.append(b"%s") - - elif isinstance(parts[0].item, str): - seen: Dict[str, Tuple[bytes, PyFormat]] = {} - order = [] - for part in parts[:-1]: - assert isinstance(part.item, str) - chunks.append(part.pre) - if part.item not in seen: - ph = b"%s" - seen[part.item] = (ph, part.format) - order.append(part.item) - chunks.append(ph) - else: - chunks.append(seen[part.item][0]) - order.append(part.item) - - # last part - chunks.append(parts[-1].pre) - - return b"".join(chunks), order, parts - - -def _validate_and_reorder_params( - parts: List[QueryPart], vars: Params, order: Optional[List[str]] -) -> Sequence[Any]: - """ - Verify the compatibility between a query and a set of params. - """ - # Try concrete types, then abstract types - t = type(vars) - if t is list or t is tuple: - sequence = True - elif t is dict: - sequence = False - elif isinstance(vars, Sequence) and not isinstance(vars, (bytes, str)): - sequence = True - elif isinstance(vars, Mapping): - sequence = False - else: - raise TypeError( - "query parameters should be a sequence or a mapping," - f" got {type(vars).__name__}" - ) - - if sequence: - if len(vars) != len(parts) - 1: - raise e.ProgrammingError( - f"the query has {len(parts) - 1} placeholders but" - f" {len(vars)} parameters were passed" - ) - if vars and not isinstance(parts[0].item, int): - raise TypeError("named placeholders require a mapping of parameters") - return vars # type: ignore[return-value] - - else: - if vars and len(parts) > 1 and not isinstance(parts[0][1], str): - raise TypeError( - "positional placeholders (%s) require a sequence of parameters" - ) - try: - return [vars[item] for item in order or ()] # type: ignore[call-overload] - except KeyError: - raise e.ProgrammingError( - "query parameter missing:" - f" {', '.join(sorted(i for i in order or () if i not in vars))}" - ) + @staticmethod + @lru_cache() + def query2pg( + query: bytes, encoding: str + ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]: + """ + Convert Python query and params into a template to perform client-side binding + """ + parts = _split_query(query, encoding, collapse_double_percent=False) + order: Optional[List[str]] = None + chunks: List[bytes] = [] + + if isinstance(parts[0].item, int): + for part in parts[:-1]: + assert isinstance(part.item, int) + chunks.append(part.pre) + chunks.append(b"%s") + + elif isinstance(parts[0].item, str): + seen: Dict[str, Tuple[bytes, PyFormat]] = {} + order = [] + for part in parts[:-1]: + assert isinstance(part.item, str) + chunks.append(part.pre) + if part.item not in seen: + ph = b"%s" + seen[part.item] = (ph, part.format) + order.append(part.item) + chunks.append(ph) + else: + chunks.append(seen[part.item][0]) + order.append(part.item) + + # last part + chunks.append(parts[-1].pre) + + return b"".join(chunks), None, order, parts _re_placeholder = re.compile( diff --git a/psycopg/psycopg/raw_cursor.py b/psycopg/psycopg/raw_cursor.py new file mode 100644 index 000000000..1a584c47e --- /dev/null +++ b/psycopg/psycopg/raw_cursor.py @@ -0,0 +1,61 @@ +""" +psycopg raw queries cursors +""" + +# Copyright (C) 2023 The Psycopg Team + +from typing import Any, Optional, Sequence, Tuple, List, TYPE_CHECKING +from functools import lru_cache + +from ._queries import PostgresQuery, QueryPart + +from .abc import ConnectionType, Query, Params +from .rows import Row +from .cursor import BaseCursor, Cursor +from ._enums import PyFormat +from .cursor_async import AsyncCursor + +if TYPE_CHECKING: + from .connection import Connection # noqa: F401 + from .connection_async import AsyncConnection # noqa: F401 + + +class RawPostgresQuery(PostgresQuery): + @staticmethod + @lru_cache() + def query2pg( + query: bytes, encoding: str + ) -> Tuple[bytes, Optional[List[PyFormat]], Optional[List[str]], List[QueryPart]]: + """ + Noop; Python raw query is already in the format Postgres understands. + """ + return query, None, None, [] + + @staticmethod + def validate_and_reorder_params( + parts: List[QueryPart], vars: Params, order: Optional[List[str]] + ) -> Sequence[Any]: + """ + Verify the compatibility; params must be a sequence for raw query. + """ + sequence = PostgresQuery.is_params_sequence(vars) + if not sequence: + raise TypeError("raw query require a sequence of parameters") + return vars # type: ignore[return-value] + + +class RawCursorMixin(BaseCursor[ConnectionType, Row]): + def _convert_query( + self, query: Query, params: Optional[Params] = None + ) -> PostgresQuery: + pgq = RawPostgresQuery(self._tx) + pgq.convert(query, params) + return pgq + + +class RawCursor(RawCursorMixin["Connection[Any]", Row], Cursor[Row]): + __module__ = "psycopg" + + +class AsyncRawCursor(RawCursorMixin["AsyncConnection[Any]", Row], AsyncCursor[Row]): + __module__ = "psycopg" diff --git a/tests/test_cursor.py b/tests/test_cursor.py index b7d3c051a..f2dcb6ea3 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -21,7 +21,7 @@ from .utils import gc_collect, raiseif from .fix_crdb import is_crdb, crdb_encoding, crdb_time_precision -@pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor]) +@pytest.fixture(params=[psycopg.Cursor, psycopg.ClientCursor, psycopg.RawCursor]) def conn(conn, request): conn.cursor_factory = request.param return conn diff --git a/tests/test_cursor_async.py b/tests/test_cursor_async.py index d85a929d9..1a05d966a 100644 --- a/tests/test_cursor_async.py +++ b/tests/test_cursor_async.py @@ -16,7 +16,9 @@ from .fix_crdb import crdb_encoding execmany = execmany # avoid F811 underneath -@pytest.fixture(params=[psycopg.AsyncCursor, psycopg.AsyncClientCursor]) +@pytest.fixture( + params=[psycopg.AsyncCursor, psycopg.AsyncClientCursor, psycopg.AsyncRawCursor] +) def aconn(aconn, request, anyio_backend): aconn.cursor_factory = request.param return aconn