# Copyright (C) 2021 The Psycopg Team
+from __future__ import annotations
+
import sys
+from typing import Any, Iterator
if sys.version_info >= (3, 11):
from typing import LiteralString, Self
if sys.version_info >= (3, 14):
from string.templatelib import Interpolation, Template
else:
+ from dataclasses import dataclass
class Template:
- pass
+ strings: tuple[str]
+ interpolations: tuple[Interpolation]
+
+ def __new__(cls, *args: str | Interpolation) -> Self:
+ return cls()
+
+ def __iter__(self) -> Iterator[str | Interpolation]:
+ return
+ yield
+ @dataclass
class Interpolation:
- pass
+ value: Any
+ expression: str
+ conversion: str | None
+ format_spec: str
__all__ = [
from .abc import ConnectionType, Params, PQGen, Query
from .rows import Row, RowMaker
from ._column import Column
+from ._compat import Template
from .pq.misc import connection_summary
from ._queries import PostgresClientQuery, PostgresQuery
from ._preparing import Prepare
yield from self._start_query()
# Merge the params client-side
- if params:
+ if params or isinstance(statement, Template):
pgq = PostgresClientQuery(self._tx)
pgq.convert(statement, params)
statement = pgq.query
from collections.abc import Callable, Mapping, Sequence
from . import errors as e
-from . import pq
-from .abc import Buffer, Params, Query
-from .sql import Composable
+from . import pq, sql
+from .abc import Buffer, Params, Query, QueryNoTemplate
from ._enums import PyFormat
from ._compat import Template
-from ._encodings import conn_encoding
+from ._tstrings import TemplateProcessor
if TYPE_CHECKING:
from .abc import Transformer
self._want_formats: list[PyFormat] | None = None
self.formats: Sequence[pq.Format] | None = None
- self._encoding = conn_encoding(transformer.connection)
self._parts: list[QueryPart]
self.query = b""
self._order: list[str] | None = None
The results of this function can be obtained accessing the object
attributes (`query`, `params`, `types`, `formats`).
"""
- query = self._ensure_bytes(query, vars)
+ if isinstance(query, Template):
+ return self._convert_template(query, vars)
+
+ query = self._ensure_bytes(query)
if vars is not None:
# Avoid caching queries extremely long or with a huge number of
f = _query2pg_nocache
(self.query, self._want_formats, self._order, self._parts) = f(
- query, self._encoding
+ query, self._tx.encoding
)
else:
self.query = query
f" {', '.join(sorted(i for i in order or () if i not in vars))}"
)
- def from_template(self, query: Template) -> bytes:
- raise NotImplementedError
-
- def _ensure_bytes(self, query: Query, vars: Params | None) -> bytes:
+ def _ensure_bytes(self, query: QueryNoTemplate) -> bytes:
if isinstance(query, str):
- return query.encode(self._encoding)
- elif isinstance(query, Template):
- if vars is not None:
- raise TypeError(
- "'execute()' with string template query doesn't support parameters"
- )
- return self.from_template(query)
- elif isinstance(query, Composable):
+ return query.encode(self._tx.encoding)
+ elif isinstance(query, sql.Composable):
return query.as_bytes(self._tx)
else:
return query
+ def _convert_template(self, query: Template, vars: Params | None) -> None:
+ if vars is not None:
+ raise TypeError(
+ "'execute()' with string template query doesn't support parameters"
+ )
+
+ tp = TemplateProcessor(query, server_params=True, tx=self._tx)
+ tp.process()
+ self.query = tp.query
+ if tp.params:
+ self.params = self._tx.dump_sequence(tp.params, tp.formats)
+ self.types = self._tx.types or ()
+ self.formats = self._tx.formats
+ else:
+ self.params = None
+ self.types = ()
+ self.formats = None
+
# The type of the _query2pg() and _query2pg_nocache() methods
_Query2Pg: TypeAlias = Callable[
The results of this function can be obtained accessing the object
attributes (`query`, `params`, `types`, `formats`).
"""
- query = self._ensure_bytes(query, vars)
+ if isinstance(query, Template):
+ return self._convert_template(query, vars)
+
+ query = self._ensure_bytes(query)
if vars is not None:
if (
else:
f = _query2pg_client_nocache
- (self.template, self._order, self._parts) = f(query, self._encoding)
+ (self.template, self._order, self._parts) = f(query, self._tx.encoding)
else:
self.query = query
self._order = None
else:
self.params = None
+ def _convert_template(self, query: Template, vars: Params | None) -> None:
+ if vars is not None:
+ raise TypeError(
+ "'execute()' with string template query doesn't support parameters"
+ )
+
+ tp = TemplateProcessor(query, server_params=False, tx=self._tx)
+ tp.process()
+ self.query = tp.query
+ self.params = tp.params
+
_Query2PgClient: TypeAlias = Callable[
[bytes, str], "tuple[bytes, list[str] | None, list[QueryPart]]"
class PostgresRawQuery(PostgresQuery):
def convert(self, query: Query, vars: Params | None) -> None:
- self.query = self._ensure_bytes(query, vars)
+ if isinstance(query, Template):
+ return self._convert_template(query, vars)
+
+ self.query = self._ensure_bytes(query)
self._want_formats = self._order = None
self.dump(vars)
self.params = None
self.types = ()
self.formats = None
+
+ def _convert_template(self, query: Template, vars: Params | None) -> None:
+ raise e.NotSupportedError(
+ f"{type(self).__name__} doesn't support template strings"
+ )
from . import pq, sql
from .abc import ConnectionType, Params, PQGen, Query
from .rows import Row
+from ._compat import Interpolation, Template
from .generators import execute
from ._cursor_base import BaseCursor
)
yield from self._conn._exec_command(query)
- def _make_declare_statement(self, query: Query) -> sql.Composed:
- if isinstance(query, bytes):
- query = query.decode(self._encoding)
- if not isinstance(query, sql.Composable):
- query = sql.SQL(query)
-
+ def _make_declare_statement(self, query: Query) -> Query:
parts = [sql.SQL("DECLARE"), sql.Identifier(self._name)]
if self._scrollable is not None:
parts.append(sql.SQL("SCROLL" if self._scrollable else "NO SCROLL"))
parts.append(sql.SQL("CURSOR"))
if self._withhold:
parts.append(sql.SQL("WITH HOLD"))
- parts.append(sql.SQL("FOR"))
- parts.append(query)
+ parts.append(sql.SQL("FOR "))
+ declare = sql.SQL(" ").join(parts)
+
+ if isinstance(query, Template):
+ # t"{declare:q}{query:q}" but compatible with Python < 3.14
+ return Template(
+ Interpolation(declare, "declare", None, "q"),
+ Interpolation(query, "query", None, "q"),
+ )
+
+ if isinstance(query, bytes):
+ query = query.decode(self._encoding)
+ if not isinstance(query, sql.Composable):
+ query = sql.SQL(query)
- return sql.SQL(" ").join(parts)
+ return declare + query
--- /dev/null
+"""
+Template strings support in queries.
+"""
+
+# Copyright (C) 2025 The Psycopg Team
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from . import errors as e
+from . import sql
+from ._enums import PyFormat
+from ._compat import Interpolation, Template
+
+if TYPE_CHECKING:
+ from .abc import Transformer
+
+# Formats supported by template strings
+FMT_AUTO = PyFormat.AUTO.value
+FMT_TEXT = PyFormat.TEXT.value
+FMT_BINARY = PyFormat.BINARY.value
+FMT_IDENT = "i"
+FMT_LITERAL = "l"
+FMT_SQL = "q"
+
+
+class TemplateProcessor:
+ def __init__(self, template: Template, *, tx: Transformer, server_params: bool):
+ self.template = template
+ self._tx = tx
+ self._server_params = server_params
+
+ self.query = b""
+ self.formats: list[PyFormat] = []
+ self.params: list[Any] = []
+
+ self._chunks: list[bytes] = []
+
+ def process(self) -> None:
+ self._process_template(self.template)
+ self.query = b"".join(self._chunks)
+
+ def _check_template_format(self, item: Interpolation, want_fmt: str) -> None:
+ if item.format_spec == want_fmt:
+ return
+ fmt = f":{item.format_spec}" if item.format_spec else ""
+ cls = type(item.value)
+ msg = f"{cls.__module__}.{cls.__qualname__} require format ':{want_fmt}'"
+ raise e.ProgrammingError(f"{msg}; got '{{{item.expression}{fmt}}}'")
+
+ def _process_template(self, t: Template) -> None:
+ for item in t:
+ if isinstance(item, str):
+ self._chunks.append(item.encode(self._tx.encoding))
+ continue
+
+ assert isinstance(item, Interpolation)
+ if item.conversion:
+ raise TypeError(
+ "conversions not supported in query; got"
+ f" '{{{item.expression}!{item.conversion}}}'"
+ )
+
+ if isinstance(item.value, Template):
+ self._check_template_format(item, FMT_SQL)
+ self._process_template(item.value)
+
+ elif isinstance(item.value, sql.Composable):
+ self._process_composable(item)
+
+ elif (fmt := item.format_spec or FMT_AUTO) == FMT_IDENT:
+ if not isinstance(item.value, str):
+ raise e.ProgrammingError(
+ "identifier values must be strings; got"
+ f" {type(item.value).__qualname__}"
+ f" in {{{item.expression}:{fmt}}}"
+ )
+ self._chunks.append(sql.Identifier(item.value).as_bytes(self._tx))
+
+ elif fmt == FMT_LITERAL:
+ self._chunks.append(sql.Literal(item.value).as_bytes(self._tx))
+
+ elif fmt == FMT_SQL:
+ # It must have been processed already
+ raise e.ProgrammingError(
+ "sql values must be sql.Composite, sql.SQL, or Template;"
+ f" got {type(item.value).__qualname__}"
+ f" in {{{item.expression}:{fmt}}}"
+ )
+
+ else:
+ if self._server_params:
+ self._process_server_variable(item, fmt)
+ else:
+ self._process_client_variable(item, fmt)
+
+ def _process_server_variable(self, item: Interpolation, fmt: str) -> None:
+ try:
+ pyfmt = PyFormat(fmt)
+ except ValueError:
+ raise e.ProgrammingError(
+ f"format '{fmt}' not supported in query;"
+ f" got '{{{item.expression}:{fmt}}}'"
+ )
+
+ self.formats.append(pyfmt)
+ self.params.append(item.value)
+ self._chunks.append(b"$%d" % len(self.params))
+
+ def _process_client_variable(self, item: Interpolation, fmt: str) -> None:
+ try:
+ PyFormat(fmt)
+ except ValueError:
+ raise e.ProgrammingError(
+ f"format '{fmt}' not supported in query;"
+ f" got '{{{item.expression}:{fmt}}}'"
+ )
+
+ param = sql.Literal(item.value).as_bytes(self._tx)
+ self._chunks.append(param)
+ self.params.append(param)
+
+ def _process_composable(self, item: Interpolation) -> None:
+ if isinstance(item.value, sql.Identifier):
+ self._check_template_format(item, FMT_IDENT)
+ self._chunks.append(item.value.as_bytes(self._tx))
+ return
+
+ elif isinstance(item.value, sql.Literal):
+ self._check_template_format(item, FMT_LITERAL)
+ self._chunks.append(item.value.as_bytes(self._tx))
+ return
+
+ elif isinstance(item.value, (sql.SQL, sql.Composed)):
+ self._check_template_format(item, FMT_SQL)
+ self._chunks.append(item.value.as_bytes(self._tx))
+ return
+
+ else:
+ raise e.ProgrammingError(
+ f"{type(item.value).__qualname__} not supported in string templates"
+ )
import codecs
import string
from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, overload
from collections.abc import Iterable, Iterator, Sequence
from .pq import Escaping
from .abc import AdaptContext
from ._enums import PyFormat
-from ._compat import LiteralString
+from ._compat import LiteralString, Template
from ._encodings import conn_encoding
from ._transformer import Transformer
return Composed(rv)
- def join(self, seq: Iterable[Any]) -> Composed:
+ @overload
+ def join(self, seq: Iterable[Template]) -> Template: ...
+
+ @overload
+ def join(self, seq: Iterable[Any]) -> Composed: ...
+
+ def join(self, seq: Iterable[Any]) -> Composed | Template:
"""
Join a sequence of `Composable`.
- :param seq: the elements to join. Elements that are not `Composable`
- will be considered `Literal`.
+ :param seq: the elements to join.
Use the `!SQL` object's string to separate the elements in `!seq`.
+ Elements that are not `Composable` will be considered `Literal`.
+
+ If the arguments are `Template` instance, return a `Template` joining
+ all the items. Note that arguments must either be all templates or
+ none should be.
+
Note that `Composed` objects are iterable too, so they can be used as
argument for this method.
>>> print(snip.as_string(conn))
"foo", "bar", "baz"
"""
- rv = []
+
it = iter(seq)
try:
- rv.append(next(it))
+ first = next(it)
except StopIteration:
- pass
- else:
- for i in it:
- rv.append(self)
- rv.append(i)
-
- return Composed(rv)
+ return Composed([])
+
+ if isinstance(first, Template):
+ items = list(first)
+ for t in it:
+ if not isinstance(t, Template):
+ raise TypeError(f"can't mix Template and {type(t).__name__}")
+ items.append(self._obj)
+ items.extend(t)
+ return Template(*items)
+
+ cs = [first]
+ for i in it:
+ if isinstance(i, Template):
+ raise TypeError(f"can't mix Template and {type(i).__name__}")
+ cs.append(self)
+ cs.append(i)
+
+ return Composed(cs)
class Identifier(Composable):
assert isinstance(obj, sql.Composed)
assert obj.as_string(conn) == """'foo', 'bar', 42"""
- obj = sql.SQL(", ").join([])
- assert obj == sql.Composed([])
+ obj2 = sql.SQL(", ").join([])
+ assert obj2 == sql.Composed([])
def test_as_string(self, conn):
assert sql.SQL("foo").as_string(conn) == "foo"
+from random import random
+
import pytest
+import psycopg
+from psycopg import sql
+from psycopg.pq import Format
+
+from .acompat import alist
+
+vstr = "hello"
+vint = 16
+
async def test_connection_no_params(aconn):
with pytest.raises(TypeError):
- await aconn.execute(t"select 1", [])
+ await aconn.execute(t"select 1", []) # noqa: F542
async def test_cursor_no_params(aconn):
cur = aconn.cursor()
with pytest.raises(TypeError):
- await cur.execute(t"select 1", [])
+ await cur.execute(t"select 1", []) # noqa: F542
+
+
+async def test_connection_execute(aconn):
+ cur = await aconn.execute(t"select {vstr}")
+ assert await cur.fetchone() == ("hello",)
+ assert cur._query.query == b"select $1"
+ assert cur._query.params == [b"hello"]
+ assert cur._query.types == (0,)
+
+
+@pytest.mark.parametrize(
+ "t", [t"select {vstr!a}", t"select {vstr!r}", t"select {vstr!s}"]
+)
+async def test_no_conversion(aconn, t):
+ with pytest.raises(TypeError):
+ await aconn.execute(t)
+
+
+@pytest.mark.parametrize(
+ "t, fmt",
+ [
+ (t"select {vint}", Format.BINARY),
+ (t"select {vint:s}", Format.BINARY),
+ (t"select {vint:t}", Format.TEXT),
+ (t"select {vint:b}", Format.BINARY),
+ ],
+)
+async def test_format(aconn, t, fmt):
+ cur = await aconn.execute(t)
+ assert await cur.fetchone() == (16,)
+ assert cur._query.query == b"select $1"
+ assert cur._query.types == (psycopg.adapters.types["smallint"].oid,)
+ assert cur._query.params == [b"\x00\x10" if fmt == Format.BINARY else b"16"]
+ assert cur._query.formats == [fmt]
+
+
+async def test_format_bad(aconn):
+ with pytest.raises(psycopg.ProgrammingError, match="format 'x' not supported"):
+ await aconn.execute(t"select {vint:x}")
+
+
+async def test_expression(aconn):
+ cur = await aconn.execute(t"select {vint * 2}")
+ assert await cur.fetchone() == (32,)
+ assert cur._query.query == b"select $1"
+ assert cur._query.types == (psycopg.adapters.types["smallint"].oid,)
+ assert cur._query.params == [b"\x00\x20"]
+ assert cur._query.formats == [Format.BINARY]
+
+
+async def test_format_identifier(aconn):
+ f1 = "foo-bar"
+ f2 = "baz"
+ cur = await aconn.execute(t"select {vint} as {f1:i}, {vint * 2:t} as {f2:i}")
+ assert await cur.fetchone() == (16, 32)
+ assert cur._query.query == b'select $1 as "foo-bar", $2 as "baz"'
+ assert cur._query.types == (psycopg.adapters.types["smallint"].oid,) * 2
+ assert cur._query.params == [b"\x00\x10", b"32"]
+ assert cur._query.formats == [Format.BINARY, Format.TEXT]
+
+
+async def test_format_literal(aconn):
+ f1 = "foo-bar"
+ f2 = "baz"
+ cur = await aconn.execute(t"select {vint * 2:l} as {f1:i}, {vint:t} as {f2:i}")
+ assert await cur.fetchone() == (32, 16)
+ assert cur._query.query == b'select 32 as "foo-bar", $1 as "baz"'
+ assert cur._query.types == (psycopg.adapters.types["smallint"].oid,)
+ assert cur._query.params == [b"16"]
+ assert cur._query.formats == [Format.TEXT]
+
+
+async def test_nested(aconn):
+ part = t"{vint} as foo"
+ cur = await aconn.execute(t"select {part:q}")
+ assert await cur.fetchone() == (16,)
+ assert cur._query.query == b"select $1 as foo"
+ assert cur._query.types == (psycopg.adapters.types["smallint"].oid,)
+ assert cur._query.params == [b"\x00\x10"]
+ assert cur._query.formats == [Format.BINARY]
+
+ with pytest.raises(psycopg.ProgrammingError, match="Template.*':q'"):
+ cur = await aconn.execute(t"select {part}")
+
+
+async def test_scope(aconn):
+ t = t"select " # noqa: F542
+ for i, name in enumerate(("foo", "bar", "baz")):
+ if i:
+ t += t", " # noqa: F542
+ t += t"{i} as {name:i}"
+
+ cur = await aconn.execute(t)
+ assert await cur.fetchone() == (0, 1, 2)
+ assert cur.description[0].name == "foo"
+ assert cur.description[2].name == "baz"
+
+
+async def test_no_reuse(aconn):
+ t = t"select {vint}, {vint}"
+ cur = await aconn.execute(t)
+ assert await cur.fetchone() == (vint, vint)
+ assert b"$2" in cur._query.query
+
+
+async def test_volatile(aconn):
+ t = t"select {random()}, {random()}"
+ cur = await aconn.execute(t)
+ rec = await cur.fetchone()
+ assert rec[0] != rec[1]
+ assert b"$2" in cur._query.query
+
+
+async def test_sql(aconn):
+ part = sql.SQL("foo")
+ cur = await aconn.execute(t"select {vint} as {part:q}")
+ assert await cur.fetchone() == (16,)
+ assert cur._query.query == b"select $1 as foo"
+
+ with pytest.raises(psycopg.ProgrammingError, match=r"sql\.SQL.*':q'"):
+ await aconn.execute(t"select {vint} as {part:i}")
+
+
+async def test_sql_composed(aconn):
+ part = sql.SQL("{} as {}").format(vint, sql.Identifier("foo"))
+ cur = await aconn.execute(t"select {part:q}")
+ assert await cur.fetchone() == (16,)
+ assert cur._query.query == b'select 16 as "foo"'
+
+ with pytest.raises(psycopg.ProgrammingError, match=r"sql\.Composed.*':q'"):
+ await aconn.execute(t"select {part}")
+
+
+async def test_sql_identifier(aconn):
+ part = sql.Identifier("foo")
+ cur = await aconn.execute(t"select {vint} as {part:i}")
+ assert await cur.fetchone() == (16,)
+ assert cur._query.query == b'select $1 as "foo"'
+
+ with pytest.raises(psycopg.ProgrammingError, match=r"sql\.Identifier.*':i'"):
+ await aconn.execute(t"select {vint} as {part}")
+
+
+async def test_sql_literal(aconn):
+ lit = sql.Literal(42)
+ cur = await aconn.execute(t"select {lit:l} as foo")
+ assert await cur.fetchone() == (42,)
+ assert cur._query.query == b'select 42 as foo'
+
+ with pytest.raises(psycopg.ProgrammingError, match=r"sql\.Literal.*':l'"):
+ await aconn.execute(t"select {lit} as foo")
+
+
+async def test_sql_placeholder(aconn):
+ part = sql.Placeholder("foo")
+ with pytest.raises(psycopg.ProgrammingError, match="Placeholder not supported"):
+ await aconn.execute(t"select {part}")
+
+
+@pytest.mark.xfail(reason="Template.join() needed")
+async def test_template_join(aconn):
+ ts = [t"{i} as {name:i}" for i, name in enumerate(("foo", "bar", "baz"))]
+ fields = t','.join(ts) # noqa: F542
+ cur = await aconn.execute(t"select {fields}")
+ assert await cur.fetchone() == (0, 1, 2)
+ assert cur.description[0].name == "foo"
+ assert cur.description[2].name == "baz"
+
+
+async def test_sql_join(aconn):
+ ts = [t"{i} as {name:i}" for i, name in enumerate(("foo", "bar", "baz"))]
+ fields = sql.SQL(',').join(ts)
+ cur = await aconn.execute(t"select {fields:q}")
+ assert await cur.fetchone() == (0, 1, 2)
+ assert cur.description[0].name == "foo"
+ assert cur.description[2].name == "baz"
+
+
+async def test_copy(aconn):
+ cur = aconn.cursor()
+ async with cur.copy(
+ t"copy (select * from generate_series(1, {3})) to stdout"
+ ) as copy:
+ data = await alist(copy.rows())
+ assert data == [("1",), ("2",), ("3",)]
+
+
+async def test_client_cursor(aconn):
+ cur = psycopg.AsyncClientCursor(aconn)
+ await cur.execute(t"select {vint}, {vstr} as {vstr:i}")
+ assert await cur.fetchone() == (vint, vstr)
+ assert cur.description[1].name == vstr
+ assert str(vint) in cur._query.query.decode()
+ assert str(vint) == cur._query.params[0].decode()
+ assert f"'{vstr}'" in cur._query.query.decode()
+ assert f"'{vstr}'" in cur._query.params[1].decode()
+
+
+async def test_mogrify(aconn):
+ cur = psycopg.AsyncClientCursor(aconn)
+ res = cur.mogrify(t"select {vint}, {vstr} as {vstr:i}")
+ assert res == "select 16, 'hello' as \"hello\""
+
+
+async def test_raw_cursor(aconn):
+ cur = psycopg.AsyncRawCursor(aconn)
+ with pytest.raises(psycopg.NotSupportedError):
+ await cur.execute(t"select {vint}, {vstr} as {vstr:i}")
+
+
+async def test_server_cursor(aconn):
+ async with psycopg.AsyncServerCursor(aconn, "test") as cur:
+ await cur.execute(t"select {vint}, {vstr} as {vstr:i}")
+ assert await cur.fetchone() == (vint, vstr)
+ assert cur.description[1].name == vstr
+ assert b"$2" in cur._query.query
+ assert b"$3" not in cur._query.query