From 9a78b301fcbc9fd62880c642b8a7f00b70458177 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 29 Apr 2025 22:32:18 +0200 Subject: [PATCH] feat: add typing support for template strings --- .flake8 | 1 + psycopg/psycopg/_compat.py | 13 ++++++++++++ psycopg/psycopg/_connection_base.py | 4 ++-- psycopg/psycopg/_queries.py | 21 +++++++++++++------ psycopg/psycopg/_typeinfo.py | 4 ++-- psycopg/psycopg/abc.py | 5 +++-- psycopg/psycopg/connection.py | 27 ++++++++++++++++++++++-- psycopg/psycopg/connection_async.py | 32 +++++++++++++++++++++++++++-- psycopg/psycopg/cursor.py | 23 +++++++++++++++++++-- psycopg/psycopg/cursor_async.py | 23 +++++++++++++++++++-- psycopg/psycopg/types/composite.py | 2 +- psycopg/psycopg/types/enum.py | 4 ++-- psycopg/psycopg/types/multirange.py | 4 ++-- psycopg/psycopg/types/range.py | 5 +++-- tests/test_tstring.py | 14 +++++++++++-- 15 files changed, 153 insertions(+), 29 deletions(-) diff --git a/.flake8 b/.flake8 index 42c860068..bd0e12c29 100644 --- a/.flake8 +++ b/.flake8 @@ -14,6 +14,7 @@ per-file-ignores = # Allow concatenated string literals from async_to_sync psycopg/psycopg/connection.py: E501 psycopg_pool/psycopg_pool/pool.py: E501 + psycopg/psycopg/connection.py: E501 # Pytest's importorskip() getting in the way tests/types/test_numpy.py: E402 diff --git a/psycopg/psycopg/_compat.py b/psycopg/psycopg/_compat.py index 0a82def5a..4ab1f2c5c 100644 --- a/psycopg/psycopg/_compat.py +++ b/psycopg/psycopg/_compat.py @@ -27,8 +27,21 @@ if sys.version_info >= (3, 13): else: from typing_extensions import TypeVar +if sys.version_info >= (3, 14): + from string.templatelib import Interpolation, Template +else: + + class Template: + pass + + class Interpolation: + pass + + __all__ = [ + "Interpolation", "LiteralString", "Self", + "Template", "TypeVar", ] diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py index b14f75bce..07b7e523a 100644 --- a/psycopg/psycopg/_connection_base.py +++ b/psycopg/psycopg/_connection_base.py @@ -17,7 +17,7 @@ from collections.abc import Callable from . import errors as e from . import generators, postgres, pq -from .abc import PQGen, PQGenConn, Query +from .abc import PQGen, PQGenConn, QueryNoTemplate from .sql import SQL, Composable from ._tpc import Xid from .rows import Row @@ -441,7 +441,7 @@ class BaseConnection(Generic[Row]): return conn def _exec_command( - self, command: Query, result_format: pq.Format = TEXT + self, command: QueryNoTemplate, result_format: pq.Format = TEXT ) -> PQGen[PGresult | None]: """ Generator to send a command and receive the result to the backend. diff --git a/psycopg/psycopg/_queries.py b/psycopg/psycopg/_queries.py index b25d1635c..e9d6815b5 100644 --- a/psycopg/psycopg/_queries.py +++ b/psycopg/psycopg/_queries.py @@ -16,6 +16,7 @@ from . import pq from .abc import Buffer, Params, Query from .sql import Composable from ._enums import PyFormat +from ._compat import Template from ._encodings import conn_encoding if TYPE_CHECKING: @@ -64,7 +65,7 @@ class PostgresQuery: The results of this function can be obtained accessing the object attributes (`query`, `params`, `types`, `formats`). """ - query = self._ensure_bytes(query) + query = self._ensure_bytes(query, vars) if vars is not None: # Avoid caching queries extremely long or with a huge number of @@ -160,9 +161,18 @@ class PostgresQuery: f" {', '.join(sorted(i for i in order or () if i not in vars))}" ) - def _ensure_bytes(self, query: Query) -> bytes: + def from_template(self, query: Template) -> bytes: + raise NotImplementedError + + def _ensure_bytes(self, query: Query, vars: Params | None) -> bytes: if isinstance(query, str): - return query.encode(self._tx.encoding) + return query.encode(self._encoding) + elif isinstance(query, Template): + if vars is not None: + raise TypeError( + "'execute()' with string template query doesn't support parameters" + ) + return self.from_template(query) elif isinstance(query, Composable): return query.as_bytes(self._tx) else: @@ -247,7 +257,7 @@ class PostgresClientQuery(PostgresQuery): The results of this function can be obtained accessing the object attributes (`query`, `params`, `types`, `formats`). """ - query = self._ensure_bytes(query) + query = self._ensure_bytes(query, vars) if vars is not None: if ( @@ -423,8 +433,7 @@ _ph_to_fmt = { class PostgresRawQuery(PostgresQuery): def convert(self, query: Query, vars: Params | None) -> None: - query = self._ensure_bytes(query) - self.query = query + self.query = self._ensure_bytes(query, vars) self._want_formats = self._order = None self.dump(vars) diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py index 5586dc026..4c1879690 100644 --- a/psycopg/psycopg/_typeinfo.py +++ b/psycopg/psycopg/_typeinfo.py @@ -14,7 +14,7 @@ from collections.abc import Iterator, Sequence from . import errors as e from . import sql -from .abc import AdaptContext, Query +from .abc import AdaptContext, QueryNoTemplate from .rows import dict_row from ._compat import TypeVar from ._typemod import TypeModifier @@ -157,7 +157,7 @@ class TypeInfo: register_array(self, context) @classmethod - def _get_info_query(cls, conn: BaseConnection[Any]) -> Query: + def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate: return sql.SQL( """\ SELECT diff --git a/psycopg/psycopg/abc.py b/psycopg/psycopg/abc.py index 764b9a1d5..5af64413d 100644 --- a/psycopg/psycopg/abc.py +++ b/psycopg/psycopg/abc.py @@ -11,7 +11,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence from . import pq from ._enums import PyFormat as PyFormat -from ._compat import LiteralString, TypeVar +from ._compat import LiteralString, Template, TypeVar if TYPE_CHECKING: from . import sql @@ -26,7 +26,8 @@ NoneType: type = type(None) # An object implementing the buffer protocol Buffer: TypeAlias = Union[bytes, bytearray, memoryview] -Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"] +QueryNoTemplate: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"] +Query: TypeAlias = Union[QueryNoTemplate, Template] Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]] ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]") PipelineCommand: TypeAlias = Callable[[], None] diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index 8c2b32dbe..9e7872ae3 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -20,12 +20,13 @@ from collections.abc import Generator, Iterator from . import errors as e from . import pq, waiting from .abc import RV, AdaptContext, ConnDict, ConnParam, Params, PQGen, Query +from .abc import QueryNoTemplate from ._tpc import Xid from .rows import Row, RowFactory, args_row, tuple_row from .adapt import AdaptersMap from ._enums import IsolationLevel from .cursor import Cursor -from ._compat import Self +from ._compat import Self, Template from ._acompat import Lock from .conninfo import conninfo_attempts, conninfo_to_dict, make_conninfo from .conninfo import timeout_from_conninfo @@ -251,6 +252,21 @@ class Connection(BaseConnection[Row]): return cur + @overload + def execute( + self, + query: QueryNoTemplate, + params: Params | None = None, + *, + prepare: bool | None = None, + binary: bool = False, + ) -> Cursor[Row]: ... + + @overload + def execute( + self, query: Template, *, prepare: bool | None = None, binary: bool = False + ) -> Cursor[Row]: ... + def execute( self, query: Query, @@ -265,7 +281,14 @@ class Connection(BaseConnection[Row]): if binary: cur.format = BINARY - return cur.execute(query, params, prepare=prepare) + if isinstance(query, Template): + if params is not None: + raise TypeError( + "'execute()' with string template query doesn't support parameters" + ) + return cur.execute(query, prepare=prepare) + else: + return cur.execute(query, params, prepare=prepare) except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index d9eead5e6..85eb134b7 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -17,11 +17,12 @@ from collections.abc import AsyncGenerator, AsyncIterator from . import errors as e from . import pq, waiting from .abc import RV, AdaptContext, ConnDict, ConnParam, Params, PQGen, Query +from .abc import QueryNoTemplate from ._tpc import Xid from .rows import AsyncRowFactory, Row, args_row, tuple_row from .adapt import AdaptersMap from ._enums import IsolationLevel -from ._compat import Self +from ._compat import Self, Template from ._acompat import ALock from .conninfo import conninfo_attempts_async, conninfo_to_dict, make_conninfo from .conninfo import timeout_from_conninfo @@ -272,6 +273,25 @@ class AsyncConnection(BaseConnection[Row]): return cur + @overload + async def execute( + self, + query: QueryNoTemplate, + params: Params | None = None, + *, + prepare: bool | None = None, + binary: bool = False, + ) -> AsyncCursor[Row]: ... + + @overload + async def execute( + self, + query: Template, + *, + prepare: bool | None = None, + binary: bool = False, + ) -> AsyncCursor[Row]: ... + async def execute( self, query: Query, @@ -286,7 +306,15 @@ class AsyncConnection(BaseConnection[Row]): if binary: cur.format = BINARY - return await cur.execute(query, params, prepare=prepare) + if isinstance(query, Template): + if params is not None: + raise TypeError( + "'execute()' with string template query" + " doesn't support parameters" + ) + return await cur.execute(query, prepare=prepare) + else: + return await cur.execute(query, params, prepare=prepare) except e._NO_TRACEBACK as ex: raise ex.with_traceback(None) diff --git a/psycopg/psycopg/cursor.py b/psycopg/psycopg/cursor.py index 02a4bbbe5..fd5600bad 100644 --- a/psycopg/psycopg/cursor.py +++ b/psycopg/psycopg/cursor.py @@ -16,10 +16,10 @@ from collections.abc import Iterable, Iterator from . import errors as e from . import pq -from .abc import Params, Query +from .abc import Params, Query, QueryNoTemplate from .copy import Copy, Writer from .rows import Row, RowFactory, RowMaker -from ._compat import Self +from ._compat import Self, Template from ._pipeline import Pipeline from ._cursor_base import BaseCursor @@ -78,6 +78,25 @@ class Cursor(BaseCursor["Connection[Any]", Row]): def _make_row_maker(self) -> RowMaker[Row]: return self._row_factory(self) + @overload + def execute( + self, + query: QueryNoTemplate, + params: Params | None = None, + *, + prepare: bool | None = None, + binary: bool | None = None, + ) -> Self: ... + + @overload + def execute( + self, + query: Template, + *, + prepare: bool | None = None, + binary: bool | None = None, + ) -> Self: ... + def execute( self, query: Query, diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 06c715858..ae7d6cec3 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -13,10 +13,10 @@ from collections.abc import AsyncIterator, Iterable from . import errors as e from . import pq -from .abc import Params, Query +from .abc import Params, Query, QueryNoTemplate from .copy import AsyncCopy, AsyncWriter from .rows import AsyncRowFactory, Row, RowMaker -from ._compat import Self +from ._compat import Self, Template from ._cursor_base import BaseCursor from ._pipeline_async import AsyncPipeline @@ -78,6 +78,25 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]): def _make_row_maker(self) -> RowMaker[Row]: return self._row_factory(self) + @overload + async def execute( + self, + query: QueryNoTemplate, + params: Params | None = None, + *, + prepare: bool | None = None, + binary: bool | None = None, + ) -> Self: ... + + @overload + async def execute( + self, + query: Template, + *, + prepare: bool | None = None, + binary: bool | None = None, + ) -> Self: ... + async def execute( self, query: Query, diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index 22b14449c..1890fa6dc 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -50,7 +50,7 @@ class CompositeInfo(TypeInfo): self.python_type: type | None = None @classmethod - def _get_info_query(cls, conn: BaseConnection[Any]) -> abc.Query: + def _get_info_query(cls, conn: BaseConnection[Any]) -> abc.QueryNoTemplate: return sql.SQL( """\ SELECT diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index 9c718675a..0a5b61091 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -12,7 +12,7 @@ from collections.abc import Mapping, Sequence from .. import errors as e from .. import postgres, sql from ..pq import Format -from ..abc import AdaptContext, Query +from ..abc import AdaptContext, QueryNoTemplate from ..adapt import Buffer, Dumper, Loader from .._compat import TypeVar from .._typeinfo import TypeInfo @@ -51,7 +51,7 @@ class EnumInfo(TypeInfo): self.enum: type[Enum] | None = None @classmethod - def _get_info_query(cls, conn: BaseConnection[Any]) -> Query: + def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate: return sql.SQL( """\ SELECT name, oid, array_oid, array_agg(label) AS labels diff --git a/psycopg/psycopg/types/multirange.py b/psycopg/psycopg/types/multirange.py index f842b196d..1933a90d4 100644 --- a/psycopg/psycopg/types/multirange.py +++ b/psycopg/psycopg/types/multirange.py @@ -16,7 +16,7 @@ from .. import _oids from .. import errors as e from .. import postgres, sql from ..pq import Format -from ..abc import AdaptContext, Buffer, Dumper, DumperKey, Query +from ..abc import AdaptContext, Buffer, Dumper, DumperKey, QueryNoTemplate from .range import Range, T, dump_range_binary, dump_range_text, fail_dump from .range import load_range_binary, load_range_text from .._oids import INVALID_OID, TEXT_OID @@ -46,7 +46,7 @@ class MultirangeInfo(TypeInfo): self.subtype_oid = subtype_oid @classmethod - def _get_info_query(cls, conn: BaseConnection[Any]) -> Query: + def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate: if conn.info.server_version < 140000: raise e.NotSupportedError( "multirange types are only available from PostgreSQL 14" diff --git a/psycopg/psycopg/types/range.py b/psycopg/psycopg/types/range.py index 8b1d3a0ba..b410088e1 100644 --- a/psycopg/psycopg/types/range.py +++ b/psycopg/psycopg/types/range.py @@ -16,7 +16,8 @@ from .. import _oids from .. import errors as e from .. import postgres, sql from ..pq import Format -from ..abc import AdaptContext, Buffer, Dumper, DumperKey, DumpFunc, LoadFunc, Query +from ..abc import AdaptContext, Buffer, Dumper, DumperKey, DumpFunc, LoadFunc +from ..abc import QueryNoTemplate from .._oids import INVALID_OID, TEXT_OID from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader from .._compat import TypeVar @@ -53,7 +54,7 @@ class RangeInfo(TypeInfo): self.subtype_oid = subtype_oid @classmethod - def _get_info_query(cls, conn: BaseConnection[Any]) -> Query: + def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate: return sql.SQL( """\ SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid, diff --git a/tests/test_tstring.py b/tests/test_tstring.py index 167c095a7..2ab890303 100644 --- a/tests/test_tstring.py +++ b/tests/test_tstring.py @@ -1,2 +1,12 @@ -def test_tstring(): - t"" +import pytest + + +async def test_connection_no_params(aconn): + with pytest.raises(TypeError): + await aconn.execute(t"select 1", []) + + +async def test_cursor_no_params(aconn): + cur = aconn.cursor() + with pytest.raises(TypeError): + await cur.execute(t"select 1", []) -- 2.47.3