--- /dev/null
+"""
+SQL composition utility module
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+import string
+from typing import Any, Iterator, List, Optional, Sequence, Union
+
+from .pq import Escaping, Format
+from .proto import AdaptContext
+
+
+def quote(obj: Any, context: AdaptContext = None) -> str:
+ """
+ Adapt a Python object to a quoted SQL string.
+
+ Use this function only if you absolutely want to convert a Python string to
+ an SQL quoted literal to use e.g. to generate batch SQL and you won't have
+ a connection avaliable when you will need to use it.
+
+ This function is relatively inefficient, because it doesn't cache the
+ adaptation rules. If you pass a *context* you can adapt the adaptation
+ rules used, otherwise only global rules are used.
+
+ """
+ return Literal(obj).as_string(context)
+
+
+class Composable(object):
+ """
+ Abstract base class for objects that can be used to compose an SQL string.
+
+ `!Composable` objects can be passed directly to `~cursor.execute()`,
+ `~cursor.executemany()`, `~cursor.copy_expert()` in place of the query
+ string.
+
+ `!Composable` objects can be joined using the ``+`` operator: the result
+ will be a `Composed` instance containing the objects joined. The operator
+ ``*`` is also supported with an integer argument: the result is a
+ `!Composed` instance containing the left argument repeated as many times as
+ requested.
+ """
+
+ def __init__(self, obj: Any):
+ self._obj = obj
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({self._obj!r})"
+
+ def as_string(self, context: AdaptContext) -> str:
+ """
+ Return the string value of the object.
+
+ :param context: the context to evaluate the string into.
+ :type context: `connection` or `cursor`
+
+ The method is automatically invoked by `~cursor.execute()`,
+ `~cursor.executemany()`, `~cursor.copy_expert()` if a `!Composable` is
+ passed instead of the query string.
+ """
+ raise NotImplementedError
+
+ def __add__(self, other: "Composable") -> "Composed":
+ if isinstance(other, Composed):
+ return Composed([self]) + other
+ if isinstance(other, Composable):
+ return Composed([self]) + Composed([other])
+ else:
+ return NotImplemented
+
+ def __mul__(self, n: int) -> "Composed":
+ return Composed([self] * n)
+
+ def __eq__(self, other: Any) -> bool:
+ return type(self) is type(other) and self._obj == other._obj
+
+ def __ne__(self, other: Any) -> bool:
+ return not self.__eq__(other)
+
+
+class Composed(Composable):
+ """
+ A `Composable` object made of a sequence of `!Composable`.
+
+ The object is usually created using `!Composable` operators and methods.
+ However it is possible to create a `!Composed` directly specifying a
+ sequence of `!Composable` as arguments.
+
+ Example::
+
+ >>> comp = sql.Composed(
+ ... [sql.SQL("insert into "), sql.Identifier("table")])
+ >>> print(comp.as_string(conn))
+ insert into "table"
+
+ `!Composed` objects are iterable (so they can be used in `SQL.join` for
+ instance).
+ """
+
+ _obj: List[Composable]
+
+ def __init__(self, seq: Sequence[Any]):
+ wrapped = []
+ for obj in seq:
+ if not isinstance(obj, Composable):
+ raise TypeError(
+ f"Composed elements must be Composable, got {obj!r} instead"
+ )
+ wrapped.append(obj)
+
+ super(Composed, self).__init__(wrapped)
+
+ def as_string(self, context: AdaptContext) -> str:
+ rv = []
+ for obj in self._obj:
+ rv.append(obj.as_string(context))
+ return "".join(rv)
+
+ def __iter__(self) -> Iterator[Composable]:
+ return iter(self._obj)
+
+ def __add__(self, other: Composable) -> "Composed":
+ if isinstance(other, Composed):
+ return Composed(self._obj + other._obj)
+ if isinstance(other, Composable):
+ return Composed(self._obj + [other])
+ else:
+ return NotImplemented
+
+ def join(self, joiner: Union["SQL", str]) -> "Composed":
+ """
+ Return a new `!Composed` interposing the *joiner* with the `!Composed` items.
+
+ The *joiner* must be a `SQL` or a string which will be interpreted as
+ an `SQL`.
+
+ Example::
+
+ >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed
+ >>> print(fields.join(', ').as_string(conn))
+ "foo", "bar"
+
+ """
+ if isinstance(joiner, str):
+ joiner = SQL(joiner)
+ elif not isinstance(joiner, SQL):
+ raise TypeError(
+ "Composed.join() argument must be a string or an SQL"
+ )
+
+ return joiner.join(self._obj)
+
+
+class SQL(Composable):
+ """
+ A `Composable` representing a snippet of SQL statement.
+
+ `!SQL` exposes `join()` and `format()` methods useful to create a template
+ where to merge variable parts of a query (for instance field or table
+ names).
+
+ The *string* doesn't undergo any form of escaping, so it is not suitable to
+ represent variable identifiers or values: you should only use it to pass
+ constant strings representing templates or snippets of SQL statements; use
+ other objects such as `Identifier` or `Literal` to represent variable
+ parts.
+
+ Example::
+
+ >>> query = sql.SQL("select {0} from {1}").format(
+ ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
+ ... sql.Identifier('table'))
+ >>> print(query.as_string(conn))
+ select "foo", "bar" from "table"
+ """
+
+ _obj: str
+ _formatter = string.Formatter()
+
+ def __init__(self, obj: str):
+ if not isinstance(obj, str):
+ raise TypeError("SQL values must be strings")
+ super(SQL, self).__init__(obj)
+
+ def as_string(self, context: AdaptContext) -> str:
+ return self._obj
+
+ def format(
+ self, *args: Composable, **kwargs: Composable
+ ) -> Composed:
+ """
+ Merge `Composable` objects into a template.
+
+ :param `Composable` args: parameters to replace to numbered
+ (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders
+ :param `Composable` kwargs: parameters to replace to named (``{name}``)
+ placeholders
+ :return: the union of the `!SQL` string with placeholders replaced
+ :rtype: `Composed`
+
+ The method is similar to the Python `str.format()` method: the string
+ template supports auto-numbered (``{}``), numbered (``{0}``,
+ ``{1}``...), and named placeholders (``{name}``), with positional
+ arguments replacing the numbered placeholders and keywords replacing
+ the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
+ are not supported. Only `!Composable` objects can be passed to the
+ template.
+
+ Example::
+
+ >>> print(sql.SQL("select * from {} where {} = %s")
+ ... .format(sql.Identifier('people'), sql.Identifier('id'))
+ ... .as_string(conn))
+ select * from "people" where "id" = %s
+
+ >>> print(sql.SQL("select * from {tbl} where {pkey} = %s")
+ ... .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id'))
+ ... .as_string(conn))
+ select * from "people" where "id" = %s
+
+ """
+ rv: List[Composable] = []
+ autonum: Optional[int] = 0
+ for pre, name, spec, conv in self._formatter.parse(self._obj):
+ if spec:
+ raise ValueError("no format specification supported by SQL")
+ if conv:
+ raise ValueError("no format conversion supported by SQL")
+ if pre:
+ rv.append(SQL(pre))
+
+ if name is None:
+ continue
+
+ if name.isdigit():
+ if autonum:
+ raise ValueError(
+ "cannot switch from automatic field numbering to manual"
+ )
+ rv.append(args[int(name)])
+ autonum = None
+
+ elif not name:
+ if autonum is None:
+ raise ValueError(
+ "cannot switch from manual field numbering to automatic"
+ )
+ rv.append(args[autonum])
+ autonum += 1
+
+ else:
+ rv.append(kwargs[name])
+
+ return Composed(rv)
+
+ def join(self, seq: Sequence[Composable]) -> Composed:
+ """
+ Join a sequence of `Composable`.
+
+ :param seq: the elements to join.
+ :type seq: iterable of `!Composable`
+
+ Use the `!SQL` object's *string* to separate the elements in *seq*.
+ Note that `Composed` objects are iterable too, so they can be used as
+ argument for this method.
+
+ Example::
+
+ >>> snip = sql.SQL(', ').join(
+ ... sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
+ >>> print(snip.as_string(conn))
+ "foo", "bar", "baz"
+ """
+ rv = []
+ it = iter(seq)
+ try:
+ rv.append(next(it))
+ except StopIteration:
+ pass
+ else:
+ for i in it:
+ rv.append(self)
+ rv.append(i)
+
+ return Composed(rv)
+
+
+class Identifier(Composable):
+ """
+ A `Composable` representing an SQL identifier or a dot-separated sequence.
+
+ Identifiers usually represent names of database objects, such as tables or
+ fields. PostgreSQL identifiers follow `different rules`__ than SQL string
+ literals for escaping (e.g. they use double quotes instead of single).
+
+ .. __: https://www.postgresql.org/docs/current/static/sql-syntax-lexical.html# \
+ SQL-SYNTAX-IDENTIFIERS
+
+ Example::
+
+ >>> t1 = sql.Identifier("foo")
+ >>> t2 = sql.Identifier("ba'r")
+ >>> t3 = sql.Identifier('ba"z')
+ >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
+ "foo", "ba'r", "ba""z"
+
+ Multiple strings can be passed to the object to represent a qualified name,
+ i.e. a dot-separated sequence of identifiers.
+
+ Example::
+
+ >>> query = sql.SQL("select {} from {}").format(
+ ... sql.Identifier("table", "field"),
+ ... sql.Identifier("schema", "table"))
+ >>> print(query.as_string(conn))
+ select "table"."field" from "schema"."table"
+
+ """
+
+ _obj: Sequence[str]
+
+ def __init__(self, *strings: str):
+ if not strings:
+ raise TypeError("Identifier cannot be empty")
+
+ for s in strings:
+ if not isinstance(s, str):
+ raise TypeError("SQL identifier parts must be strings")
+
+ super(Identifier, self).__init__(strings)
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
+
+ def as_string(self, context: AdaptContext) -> str:
+ from .adapt import _connection_from_context
+
+ conn = _connection_from_context(context)
+ if not conn:
+ raise ValueError(f"no connection in the context: {context}")
+
+ esc = Escaping(conn.pgconn)
+ codec = conn.codec
+ escs = [esc.escape_identifier(codec.encode(s)[0]) for s in self._obj]
+ return codec.decode(b".".join(escs))[0]
+
+
+class Literal(Composable):
+ """
+ A `Composable` representing an SQL value to include in a query.
+
+ Usually you will want to include placeholders in the query and pass values
+ as `~cursor.execute()` arguments. If however you really really need to
+ include a literal value in the query you can use this object.
+
+ The string returned by `!as_string()` follows the normal :ref:`adaptation
+ rules <python-types-adaptation>` for Python objects.
+
+ Example::
+
+ >>> s1 = sql.Literal("foo")
+ >>> s2 = sql.Literal("ba'r")
+ >>> s3 = sql.Literal(42)
+ >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
+ 'foo', 'ba''r', 42
+
+ """
+
+ def as_string(self, context: AdaptContext) -> str:
+ from .adapt import _connection_from_context, Transformer
+
+ conn = _connection_from_context(context)
+ tx = Transformer(conn)
+ dumper = tx.get_dumper(self._obj, Format.TEXT)
+ value = dumper.dump(self._obj)
+
+ if conn:
+ esc = Escaping(conn.pgconn)
+ quoted = esc.escape_literal(value)
+ return conn.codec.decode(quoted)[0]
+ else:
+ esc = Escaping()
+ quoted = b"'%s'" % esc.escape_string(value)
+ return quoted.decode("utf8")
+
+
+class Placeholder(Composable):
+ """A `Composable` representing a placeholder for query parameters.
+
+ If the name is specified, generate a named placeholder (e.g. ``%(name)s``),
+ otherwise generate a positional placeholder (e.g. ``%s``).
+
+ The object is useful to generate SQL queries with a variable number of
+ arguments.
+
+ Examples::
+
+ >>> names = ['foo', 'bar', 'baz']
+
+ >>> q1 = sql.SQL("insert into table ({}) values ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(sql.Placeholder() * len(names)))
+ >>> print(q1.as_string(conn))
+ insert into table ("foo", "bar", "baz") values (%s, %s, %s)
+
+ >>> q2 = sql.SQL("insert into table ({}) values ({})").format(
+ ... sql.SQL(', ').join(map(sql.Identifier, names)),
+ ... sql.SQL(', ').join(map(sql.Placeholder, names)))
+ >>> print(q2.as_string(conn))
+ insert into table ("foo", "bar", "baz") values (%(foo)s, %(bar)s, %(baz)s)
+
+ """
+
+ def __init__(self, name: Optional[str] = None):
+ if isinstance(name, str):
+ if ")" in name:
+ raise ValueError("invalid name: %r" % name)
+
+ elif name is not None:
+ raise TypeError("expected string or None as name, got %r" % name)
+
+ super(Placeholder, self).__init__(name)
+
+ def __repr__(self) -> str:
+ return (
+ f"{self.__class__.__name__}"
+ f"({self._obj if self._obj is not None else ''})"
+ )
+
+ def as_string(self, context: AdaptContext) -> str:
+ if self._obj is not None:
+ return "%%(%s)s" % self._obj
+ else:
+ return "%s"
+
+
+# Literals
+NULL = SQL("NULL")
+DEFAULT = SQL("DEFAULT")
--- /dev/null
+# test_sql.py - tests for the psycopg2.sql module
+
+# Copyright (C) 2020 The Psycopg Team
+
+import re
+import datetime as dt
+
+import pytest
+
+from psycopg3 import sql, ProgrammingError
+
+
+class TestSqlFormat:
+ def test_pos(self, conn):
+ s = sql.SQL("select {} from {}").format(
+ sql.Identifier("field"), sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_pos_spec(self, conn):
+ s = sql.SQL("select {0} from {1}").format(
+ sql.Identifier("field"), sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ s = sql.SQL("select {1} from {0}").format(
+ sql.Identifier("table"), sql.Identifier("field")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_dict(self, conn):
+ s = sql.SQL("select {f} from {t}").format(
+ f=sql.Identifier("field"), t=sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == 'select "field" from "table"'
+
+ def test_unicode(self, conn):
+ s = sql.SQL(u"select {0} from {1}").format(
+ sql.Identifier(u"field"), sql.Identifier("table")
+ )
+ s1 = s.as_string(conn)
+ assert isinstance(s1, str)
+ assert s1 == u'select "field" from "table"'
+
+ def test_compose_literal(self, conn):
+ s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31)))
+ s1 = s.as_string(conn)
+ assert s1 == "select '2016-12-31';"
+
+ def test_compose_empty(self, conn):
+ s = sql.SQL("select foo;").format()
+ s1 = s.as_string(conn)
+ assert s1 == "select foo;"
+
+ def test_percent_escape(self, conn):
+ s = sql.SQL("42 % {0}").format(sql.Literal(7))
+ s1 = s.as_string(conn)
+ assert s1 == "42 % '7'"
+
+ def test_braces_escape(self, conn):
+ s = sql.SQL("{{{0}}}").format(sql.Literal(7))
+ assert s.as_string(conn) == "{'7'}"
+ s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
+ assert s.as_string(conn) == "{1,'7'}"
+
+ def test_compose_badnargs(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {0};").format()
+
+ def test_compose_badnargs_auto(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {};").format()
+ with pytest.raises(ValueError):
+ sql.SQL("select {} {1};").format(10, 20)
+ with pytest.raises(ValueError):
+ sql.SQL("select {0} {};").format(10, 20)
+
+ def test_compose_bad_args_type(self):
+ with pytest.raises(IndexError):
+ sql.SQL("select {0};").format(a=10)
+ with pytest.raises(KeyError):
+ sql.SQL("select {x};").format(10)
+
+ def test_must_be_composable(self):
+ with pytest.raises(TypeError):
+ sql.SQL("select {0};").format("foo")
+ with pytest.raises(TypeError):
+ sql.SQL("select {0};").format(10)
+
+ def test_no_modifiers(self):
+ with pytest.raises(ValueError):
+ sql.SQL("select {a!r};").format(a=10)
+ with pytest.raises(ValueError):
+ sql.SQL("select {a:<};").format(a=10)
+
+ def test_must_be_adaptable(self, conn):
+ class Foo(object):
+ pass
+
+ s = sql.SQL("select {0};").format(sql.Literal(Foo()))
+ with pytest.raises(ProgrammingError):
+ s.as_string(conn)
+
+ def test_execute(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+ cur.execute(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier("test_compose"),
+ sql.SQL(", ").join(
+ map(sql.Identifier, ["foo", "bar", "ba'z"])
+ ),
+ (sql.Placeholder() * 3).join(", "),
+ ),
+ (10, "a", "b", "c"),
+ )
+
+ cur.execute("select * from test_compose")
+ assert cur.fetchall() == [(10, "a", "b", "c")]
+
+ def test_executemany(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+ cur.executemany(
+ sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
+ sql.Identifier("test_compose"),
+ sql.SQL(", ").join(
+ map(sql.Identifier, ["foo", "bar", "ba'z"])
+ ),
+ (sql.Placeholder() * 3).join(", "),
+ ),
+ [(10, "a", "b", "c"), (20, "d", "e", "f")],
+ )
+
+ cur.execute("select * from test_compose")
+ assert cur.fetchall(), [(10, "a", "b", "c"), (20, "d", "e", "f")]
+
+ def test_copy(self, conn):
+ cur = conn.cursor()
+ cur.execute(
+ """
+ create table test_compose (
+ id serial primary key,
+ foo text, bar text, "ba'z" text)
+ """
+ )
+
+ with cur.copy(
+ sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format(
+ t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")
+ ),
+ ) as copy:
+ copy.write_row((10, "a", "b", "c"))
+ copy.write_row((20, "d", "e", "f"))
+
+ copy = cur.copy(
+ sql.SQL("copy (select {f} from {t} order by id) to stdout").format(
+ t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")
+ )
+ )
+ assert list(copy) == [b"c\n", b"f\n"]
+
+
+class TestIdentifier:
+ def test_class(self):
+ assert issubclass(sql.Identifier, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.Identifier("foo"), sql.Identifier)
+ assert isinstance(sql.Identifier(u"foo"), sql.Identifier)
+ assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier)
+ with pytest.raises(TypeError):
+ sql.Identifier()
+ with pytest.raises(TypeError):
+ sql.Identifier(10)
+ with pytest.raises(TypeError):
+ sql.Identifier(dt.date(2016, 12, 31))
+
+ def test_repr(self):
+ obj = sql.Identifier("fo'o")
+ assert repr(obj) == 'Identifier("fo\'o")'
+ assert repr(obj) == str(obj)
+
+ obj = sql.Identifier("fo'o", 'ba"r')
+ assert repr(obj) == "Identifier(\"fo'o\", 'ba\"r')"
+ assert repr(obj) == str(obj)
+
+ def test_eq(self):
+ assert sql.Identifier("foo") == sql.Identifier("foo")
+ assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar")
+ assert sql.Identifier("foo") != sql.Identifier("bar")
+ assert sql.Identifier("foo") != "foo"
+ assert sql.Identifier("foo") != sql.SQL("foo")
+
+ def test_as_str(self, conn):
+ assert sql.Identifier("foo").as_string(conn) == '"foo"'
+ assert sql.Identifier("foo", "bar").as_string(conn), '"foo"."bar"'
+ assert (
+ sql.Identifier("fo'o", 'ba"r').as_string(conn) == '"fo\'o"."ba""r"'
+ )
+
+ def test_join(self):
+ assert not hasattr(sql.Identifier("foo"), "join")
+
+
+class TestLiteral:
+ def test_class(self):
+ assert issubclass(sql.Literal, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.Literal("foo"), sql.Literal)
+ assert isinstance(sql.Literal(u"foo"), sql.Literal)
+ assert isinstance(sql.Literal(b"foo"), sql.Literal)
+ assert isinstance(sql.Literal(42), sql.Literal)
+ assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal)
+
+ def test_repr(self, conn):
+ assert repr(sql.Literal("foo")) == "Literal('foo')"
+ assert str(sql.Literal("foo")) == "Literal('foo')"
+ assert noe(sql.Literal("foo").as_string(conn)) == "'foo'"
+ assert sql.Literal(42).as_string(conn) == "'42'"
+ assert (
+ sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
+ )
+
+ def test_eq(self):
+ assert sql.Literal("foo") == sql.Literal("foo")
+ assert sql.Literal("foo") != sql.Literal("bar")
+ assert sql.Literal("foo") != "foo"
+ assert sql.Literal("foo") != sql.SQL("foo")
+
+ def test_must_be_adaptable(self, conn):
+ class Foo(object):
+ pass
+
+ with pytest.raises(ProgrammingError):
+ sql.Literal(Foo()).as_string(conn)
+
+
+class TestSQL:
+ def test_class(self):
+ assert issubclass(sql.SQL, sql.Composable)
+
+ def test_init(self):
+ assert isinstance(sql.SQL("foo"), sql.SQL)
+ assert isinstance(sql.SQL(u"foo"), sql.SQL)
+ with pytest.raises(TypeError):
+ sql.SQL(10)
+ with pytest.raises(TypeError):
+ sql.SQL(dt.date(2016, 12, 31))
+
+ def test_repr(self, conn):
+ assert repr(sql.SQL("foo")) == "SQL('foo')"
+ assert str(sql.SQL("foo")) == "SQL('foo')"
+ assert sql.SQL("foo").as_string(conn) == "foo"
+
+ def test_eq(self):
+ assert sql.SQL("foo") == sql.SQL("foo")
+ assert sql.SQL("foo") != sql.SQL("bar")
+ assert sql.SQL("foo") != "foo"
+ assert sql.SQL("foo") != sql.Literal("foo")
+
+ def test_sum(self, conn):
+ obj = sql.SQL("foo") + sql.SQL("bar")
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foobar"
+
+ def test_sum_inplace(self, conn):
+ obj = sql.SQL("foo")
+ obj += sql.SQL("bar")
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foobar"
+
+ def test_multiply(self, conn):
+ obj = sql.SQL("foo") * 3
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "foofoofoo"
+
+ def test_join(self, conn):
+ obj = sql.SQL(", ").join(
+ [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
+ )
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "\"foo\", bar, '42'"
+
+ obj = sql.SQL(", ").join(
+ sql.Composed(
+ [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
+ )
+ )
+ assert isinstance(obj, sql.Composed)
+ assert obj.as_string(conn) == "\"foo\", bar, '42'"
+
+ obj = sql.SQL(", ").join([])
+ assert obj == sql.Composed([])
+
+
+class TestComposed:
+ def test_class(self):
+ assert issubclass(sql.Composed, sql.Composable)
+
+ def test_repr(self):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ assert (
+ repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])"""
+ )
+ assert str(obj) == repr(obj)
+
+ def test_eq(self):
+ L = [sql.Literal("foo"), sql.Identifier("b'ar")]
+ l2 = [sql.Literal("foo"), sql.Literal("b'ar")]
+ assert sql.Composed(L) == sql.Composed(list(L))
+ assert sql.Composed(L) != L
+ assert sql.Composed(L) != sql.Composed(l2)
+
+ def test_join(self, conn):
+ obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")])
+ obj = obj.join(", ")
+ assert isinstance(obj, sql.Composed)
+ assert noe(obj.as_string(conn)) == "'foo', \"b'ar\""
+
+ def test_sum(self, conn):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj = obj + sql.Literal("bar")
+ assert isinstance(obj, sql.Composed)
+ assert noe(obj.as_string(conn)), "foo 'bar'"
+
+ def test_sum_inplace(self, conn):
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Literal("bar")
+ assert isinstance(obj, sql.Composed)
+ assert noe(obj.as_string(conn)) == "foo 'bar'"
+
+ obj = sql.Composed([sql.SQL("foo ")])
+ obj += sql.Composed([sql.Literal("bar")])
+ assert isinstance(obj, sql.Composed)
+ assert noe(obj.as_string(conn)) == "foo 'bar'"
+
+ def test_iter(self):
+ obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+ it = iter(obj)
+ i = next(it)
+ assert i == sql.SQL("foo")
+ i = next(it)
+ assert i == sql.SQL("bar")
+ with pytest.raises(StopIteration):
+ next(it)
+
+
+class TestPlaceholder:
+ def test_class(self):
+ assert issubclass(sql.Placeholder, sql.Composable)
+
+ def test_repr(self, conn):
+ assert str(sql.Placeholder()), "Placeholder()"
+ assert repr(sql.Placeholder()), "Placeholder()"
+ assert sql.Placeholder().as_string(conn), "%s"
+
+ def test_repr_name(self, conn):
+ assert str(sql.Placeholder("foo")), "Placeholder('foo')"
+ assert repr(sql.Placeholder("foo")), "Placeholder('foo')"
+ assert sql.Placeholder("foo").as_string(conn), "%(foo)s"
+
+ def test_bad_name(self):
+ with pytest.raises(ValueError):
+ sql.Placeholder(")")
+
+ def test_eq(self):
+ assert sql.Placeholder("foo") == sql.Placeholder("foo")
+ assert sql.Placeholder("foo") != sql.Placeholder("bar")
+ assert sql.Placeholder("foo") != "foo"
+ assert sql.Placeholder() == sql.Placeholder()
+ assert sql.Placeholder("foo") != sql.Placeholder()
+ assert sql.Placeholder("foo") != sql.Literal("foo")
+
+
+class TestValues:
+ def test_null(self, conn):
+ assert isinstance(sql.NULL, sql.SQL)
+ assert sql.NULL.as_string(conn) == "NULL"
+
+ def test_default(self, conn):
+ assert isinstance(sql.DEFAULT, sql.SQL)
+ assert sql.DEFAULT.as_string(conn) == "DEFAULT"
+
+
+def noe(s):
+ """Drop an eventual E from E'' quotes"""
+ return re.sub(r"\bE'", "'", s)