]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add typing support for template strings
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 29 Apr 2025 20:32:18 +0000 (22:32 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 8 Sep 2025 09:46:55 +0000 (11:46 +0200)
15 files changed:
.flake8
psycopg/psycopg/_compat.py
psycopg/psycopg/_connection_base.py
psycopg/psycopg/_queries.py
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/abc.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
psycopg/psycopg/types/composite.py
psycopg/psycopg/types/enum.py
psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/range.py
tests/test_tstring.py

diff --git a/.flake8 b/.flake8
index 42c8600685838f51e21b8f52e8cca638c353cc93..bd0e12c2953b509fcb4e4d613a595a829b6427f7 100644 (file)
--- a/.flake8
+++ b/.flake8
@@ -14,6 +14,7 @@ per-file-ignores =
     # Allow concatenated string literals from async_to_sync
     psycopg/psycopg/connection.py: E501
     psycopg_pool/psycopg_pool/pool.py: E501
+    psycopg/psycopg/connection.py: E501
 
     # Pytest's importorskip() getting in the way
     tests/types/test_numpy.py: E402
index 0a82def5a0d65ab4bc66bc13e69af403fa98b9cb..4ab1f2c5c25ce3895632f1785ba9fd83d73426eb 100644 (file)
@@ -27,8 +27,21 @@ if sys.version_info >= (3, 13):
 else:
     from typing_extensions import TypeVar
 
+if sys.version_info >= (3, 14):
+    from string.templatelib import Interpolation, Template
+else:
+
+    class Template:
+        pass
+
+    class Interpolation:
+        pass
+
+
 __all__ = [
+    "Interpolation",
     "LiteralString",
     "Self",
+    "Template",
     "TypeVar",
 ]
index b14f75bceda5a5d98c221858184dbfad8d5af883..07b7e523adfe1e3f9c0c3f979725265b98d62413 100644 (file)
@@ -17,7 +17,7 @@ from collections.abc import Callable
 
 from . import errors as e
 from . import generators, postgres, pq
-from .abc import PQGen, PQGenConn, Query
+from .abc import PQGen, PQGenConn, QueryNoTemplate
 from .sql import SQL, Composable
 from ._tpc import Xid
 from .rows import Row
@@ -441,7 +441,7 @@ class BaseConnection(Generic[Row]):
         return conn
 
     def _exec_command(
-        self, command: Query, result_format: pq.Format = TEXT
+        self, command: QueryNoTemplate, result_format: pq.Format = TEXT
     ) -> PQGen[PGresult | None]:
         """
         Generator to send a command and receive the result to the backend.
index b25d1635c74c85644a5fcdbe2fc12e895dd0cbc7..e9d6815b5b08fabb7207d1b05f8ab0698e7deffa 100644 (file)
@@ -16,6 +16,7 @@ from . import pq
 from .abc import Buffer, Params, Query
 from .sql import Composable
 from ._enums import PyFormat
+from ._compat import Template
 from ._encodings import conn_encoding
 
 if TYPE_CHECKING:
@@ -64,7 +65,7 @@ class PostgresQuery:
         The results of this function can be obtained accessing the object
         attributes (`query`, `params`, `types`, `formats`).
         """
-        query = self._ensure_bytes(query)
+        query = self._ensure_bytes(query, vars)
 
         if vars is not None:
             # Avoid caching queries extremely long or with a huge number of
@@ -160,9 +161,18 @@ class PostgresQuery:
                     f" {', '.join(sorted(i for i in order or () if i not in vars))}"
                 )
 
-    def _ensure_bytes(self, query: Query) -> bytes:
+    def from_template(self, query: Template) -> bytes:
+        raise NotImplementedError
+
+    def _ensure_bytes(self, query: Query, vars: Params | None) -> bytes:
         if isinstance(query, str):
-            return query.encode(self._tx.encoding)
+            return query.encode(self._encoding)
+        elif isinstance(query, Template):
+            if vars is not None:
+                raise TypeError(
+                    "'execute()' with string template query doesn't support parameters"
+                )
+            return self.from_template(query)
         elif isinstance(query, Composable):
             return query.as_bytes(self._tx)
         else:
@@ -247,7 +257,7 @@ class PostgresClientQuery(PostgresQuery):
         The results of this function can be obtained accessing the object
         attributes (`query`, `params`, `types`, `formats`).
         """
-        query = self._ensure_bytes(query)
+        query = self._ensure_bytes(query, vars)
 
         if vars is not None:
             if (
@@ -423,8 +433,7 @@ _ph_to_fmt = {
 
 class PostgresRawQuery(PostgresQuery):
     def convert(self, query: Query, vars: Params | None) -> None:
-        query = self._ensure_bytes(query)
-        self.query = query
+        self.query = self._ensure_bytes(query, vars)
         self._want_formats = self._order = None
         self.dump(vars)
 
index 5586dc026cebcf30f72d78e1eb6f0ca058473c9e..4c187969084a520ae44b6d96d4f2444f7ecc7ffe 100644 (file)
@@ -14,7 +14,7 @@ from collections.abc import Iterator, Sequence
 
 from . import errors as e
 from . import sql
-from .abc import AdaptContext, Query
+from .abc import AdaptContext, QueryNoTemplate
 from .rows import dict_row
 from ._compat import TypeVar
 from ._typemod import TypeModifier
@@ -157,7 +157,7 @@ class TypeInfo:
             register_array(self, context)
 
     @classmethod
-    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
+    def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate:
         return sql.SQL(
             """\
 SELECT
index 764b9a1d5711ffc750731373eb3eb9501653d726..5af64413da9987ef970316b593395b889fdc5475 100644 (file)
@@ -11,7 +11,7 @@ from collections.abc import Callable, Generator, Mapping, Sequence
 
 from . import pq
 from ._enums import PyFormat as PyFormat
-from ._compat import LiteralString, TypeVar
+from ._compat import LiteralString, Template, TypeVar
 
 if TYPE_CHECKING:
     from . import sql
@@ -26,7 +26,8 @@ NoneType: type = type(None)
 # An object implementing the buffer protocol
 Buffer: TypeAlias = Union[bytes, bytearray, memoryview]
 
-Query: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
+QueryNoTemplate: TypeAlias = Union[LiteralString, bytes, "sql.SQL", "sql.Composed"]
+Query: TypeAlias = Union[QueryNoTemplate, Template]
 Params: TypeAlias = Union[Sequence[Any], Mapping[str, Any]]
 ConnectionType = TypeVar("ConnectionType", bound="BaseConnection[Any]")
 PipelineCommand: TypeAlias = Callable[[], None]
index 8c2b32dbe5ed2ac57070e5c3fbbe005417414973..9e7872ae31dfb7614dd699ddd4afb2681b3c06c1 100644 (file)
@@ -20,12 +20,13 @@ from collections.abc import Generator, Iterator
 from . import errors as e
 from . import pq, waiting
 from .abc import RV, AdaptContext, ConnDict, ConnParam, Params, PQGen, Query
+from .abc import QueryNoTemplate
 from ._tpc import Xid
 from .rows import Row, RowFactory, args_row, tuple_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
 from .cursor import Cursor
-from ._compat import Self
+from ._compat import Self, Template
 from ._acompat import Lock
 from .conninfo import conninfo_attempts, conninfo_to_dict, make_conninfo
 from .conninfo import timeout_from_conninfo
@@ -251,6 +252,21 @@ class Connection(BaseConnection[Row]):
 
         return cur
 
+    @overload
+    def execute(
+        self,
+        query: QueryNoTemplate,
+        params: Params | None = None,
+        *,
+        prepare: bool | None = None,
+        binary: bool = False,
+    ) -> Cursor[Row]: ...
+
+    @overload
+    def execute(
+        self, query: Template, *, prepare: bool | None = None, binary: bool = False
+    ) -> Cursor[Row]: ...
+
     def execute(
         self,
         query: Query,
@@ -265,7 +281,14 @@ class Connection(BaseConnection[Row]):
             if binary:
                 cur.format = BINARY
 
-            return cur.execute(query, params, prepare=prepare)
+            if isinstance(query, Template):
+                if params is not None:
+                    raise TypeError(
+                        "'execute()' with string template query doesn't support parameters"
+                    )
+                return cur.execute(query, prepare=prepare)
+            else:
+                return cur.execute(query, params, prepare=prepare)
         except e._NO_TRACEBACK as ex:
             raise ex.with_traceback(None)
 
index d9eead5e67c989f5c2a371e36ba1b070efad1d94..85eb134b74850819b2b99ba4408aba14da2204e1 100644 (file)
@@ -17,11 +17,12 @@ from collections.abc import AsyncGenerator, AsyncIterator
 from . import errors as e
 from . import pq, waiting
 from .abc import RV, AdaptContext, ConnDict, ConnParam, Params, PQGen, Query
+from .abc import QueryNoTemplate
 from ._tpc import Xid
 from .rows import AsyncRowFactory, Row, args_row, tuple_row
 from .adapt import AdaptersMap
 from ._enums import IsolationLevel
-from ._compat import Self
+from ._compat import Self, Template
 from ._acompat import ALock
 from .conninfo import conninfo_attempts_async, conninfo_to_dict, make_conninfo
 from .conninfo import timeout_from_conninfo
@@ -272,6 +273,25 @@ class AsyncConnection(BaseConnection[Row]):
 
         return cur
 
+    @overload
+    async def execute(
+        self,
+        query: QueryNoTemplate,
+        params: Params | None = None,
+        *,
+        prepare: bool | None = None,
+        binary: bool = False,
+    ) -> AsyncCursor[Row]: ...
+
+    @overload
+    async def execute(
+        self,
+        query: Template,
+        *,
+        prepare: bool | None = None,
+        binary: bool = False,
+    ) -> AsyncCursor[Row]: ...
+
     async def execute(
         self,
         query: Query,
@@ -286,7 +306,15 @@ class AsyncConnection(BaseConnection[Row]):
             if binary:
                 cur.format = BINARY
 
-            return await cur.execute(query, params, prepare=prepare)
+            if isinstance(query, Template):
+                if params is not None:
+                    raise TypeError(
+                        "'execute()' with string template query"
+                        " doesn't support parameters"
+                    )
+                return await cur.execute(query, prepare=prepare)
+            else:
+                return await cur.execute(query, params, prepare=prepare)
 
         except e._NO_TRACEBACK as ex:
             raise ex.with_traceback(None)
index 02a4bbbe54ee9e6a58a876cffdb4ac16da53493a..fd5600badcf6270628469c9a61f816b3b7e11af5 100644 (file)
@@ -16,10 +16,10 @@ from collections.abc import Iterable, Iterator
 
 from . import errors as e
 from . import pq
-from .abc import Params, Query
+from .abc import Params, Query, QueryNoTemplate
 from .copy import Copy, Writer
 from .rows import Row, RowFactory, RowMaker
-from ._compat import Self
+from ._compat import Self, Template
 from ._pipeline import Pipeline
 from ._cursor_base import BaseCursor
 
@@ -78,6 +78,25 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
     def _make_row_maker(self) -> RowMaker[Row]:
         return self._row_factory(self)
 
+    @overload
+    def execute(
+        self,
+        query: QueryNoTemplate,
+        params: Params | None = None,
+        *,
+        prepare: bool | None = None,
+        binary: bool | None = None,
+    ) -> Self: ...
+
+    @overload
+    def execute(
+        self,
+        query: Template,
+        *,
+        prepare: bool | None = None,
+        binary: bool | None = None,
+    ) -> Self: ...
+
     def execute(
         self,
         query: Query,
index 06c715858d497ea50850c6e08ed5200955e21698..ae7d6cec36c9e70430f0ce819be845b8a7e22260 100644 (file)
@@ -13,10 +13,10 @@ from collections.abc import AsyncIterator, Iterable
 
 from . import errors as e
 from . import pq
-from .abc import Params, Query
+from .abc import Params, Query, QueryNoTemplate
 from .copy import AsyncCopy, AsyncWriter
 from .rows import AsyncRowFactory, Row, RowMaker
-from ._compat import Self
+from ._compat import Self, Template
 from ._cursor_base import BaseCursor
 from ._pipeline_async import AsyncPipeline
 
@@ -78,6 +78,25 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
     def _make_row_maker(self) -> RowMaker[Row]:
         return self._row_factory(self)
 
+    @overload
+    async def execute(
+        self,
+        query: QueryNoTemplate,
+        params: Params | None = None,
+        *,
+        prepare: bool | None = None,
+        binary: bool | None = None,
+    ) -> Self: ...
+
+    @overload
+    async def execute(
+        self,
+        query: Template,
+        *,
+        prepare: bool | None = None,
+        binary: bool | None = None,
+    ) -> Self: ...
+
     async def execute(
         self,
         query: Query,
index 22b14449cac7916ba0f0d4a567bbf1755337b8e3..1890fa6dc3853cfe7352fb81af780f27924f7463 100644 (file)
@@ -50,7 +50,7 @@ class CompositeInfo(TypeInfo):
         self.python_type: type | None = None
 
     @classmethod
-    def _get_info_query(cls, conn: BaseConnection[Any]) -> abc.Query:
+    def _get_info_query(cls, conn: BaseConnection[Any]) -> abc.QueryNoTemplate:
         return sql.SQL(
             """\
 SELECT
index 9c718675aa3b277ae1595b0cbbfa5760f4be0627..0a5b61091535fceda28f133b3ce40e0e536c6382 100644 (file)
@@ -12,7 +12,7 @@ from collections.abc import Mapping, Sequence
 from .. import errors as e
 from .. import postgres, sql
 from ..pq import Format
-from ..abc import AdaptContext, Query
+from ..abc import AdaptContext, QueryNoTemplate
 from ..adapt import Buffer, Dumper, Loader
 from .._compat import TypeVar
 from .._typeinfo import TypeInfo
@@ -51,7 +51,7 @@ class EnumInfo(TypeInfo):
         self.enum: type[Enum] | None = None
 
     @classmethod
-    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
+    def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate:
         return sql.SQL(
             """\
 SELECT name, oid, array_oid, array_agg(label) AS labels
index f842b196d797d0547f08c75fd8ba9a9e75f5fd1b..1933a90d494b20cfcb4c6ce77fcdad6ea87825d7 100644 (file)
@@ -16,7 +16,7 @@ from .. import _oids
 from .. import errors as e
 from .. import postgres, sql
 from ..pq import Format
-from ..abc import AdaptContext, Buffer, Dumper, DumperKey, Query
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey, QueryNoTemplate
 from .range import Range, T, dump_range_binary, dump_range_text, fail_dump
 from .range import load_range_binary, load_range_text
 from .._oids import INVALID_OID, TEXT_OID
@@ -46,7 +46,7 @@ class MultirangeInfo(TypeInfo):
         self.subtype_oid = subtype_oid
 
     @classmethod
-    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
+    def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate:
         if conn.info.server_version < 140000:
             raise e.NotSupportedError(
                 "multirange types are only available from PostgreSQL 14"
index 8b1d3a0ba16bd48cc34815f1ec16fdbf50721eae..b410088e1dc8fd360c5b8db40cb74dfa68176ba4 100644 (file)
@@ -16,7 +16,8 @@ from .. import _oids
 from .. import errors as e
 from .. import postgres, sql
 from ..pq import Format
-from ..abc import AdaptContext, Buffer, Dumper, DumperKey, DumpFunc, LoadFunc, Query
+from ..abc import AdaptContext, Buffer, Dumper, DumperKey, DumpFunc, LoadFunc
+from ..abc import QueryNoTemplate
 from .._oids import INVALID_OID, TEXT_OID
 from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
 from .._compat import TypeVar
@@ -53,7 +54,7 @@ class RangeInfo(TypeInfo):
         self.subtype_oid = subtype_oid
 
     @classmethod
-    def _get_info_query(cls, conn: BaseConnection[Any]) -> Query:
+    def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate:
         return sql.SQL(
             """\
 SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
index 167c095a71a087cb4707b01e3a6bea23cfa28e16..2ab8903035679b4385530cb06d7c37b21f651845 100644 (file)
@@ -1,2 +1,12 @@
-def test_tstring():
-    t""
+import pytest
+
+
+async def test_connection_no_params(aconn):
+    with pytest.raises(TypeError):
+        await aconn.execute(t"select 1", [])
+
+
+async def test_cursor_no_params(aconn):
+    cur = aconn.cursor()
+    with pytest.raises(TypeError):
+        await cur.execute(t"select 1", [])