From: Daniele Varrazzo Date: Wed, 28 Oct 2020 03:01:48 +0000 (+0100) Subject: Added psycopg3.sql module X-Git-Tag: 3.0.dev0~424^2~5 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=08962963c70f2667abac73017b9509809786828e;p=thirdparty%2Fpsycopg.git Added psycopg3.sql module A straight porting from psycopg2, with the addition of an utility `quote()` function to provide a simple entry point to solve the problem "adapt this object here". --- diff --git a/psycopg3/psycopg3/cursor.py b/psycopg3/psycopg3/cursor.py index 81bc86289..238e4f3a8 100644 --- a/psycopg3/psycopg3/cursor.py +++ b/psycopg3/psycopg3/cursor.py @@ -13,6 +13,7 @@ from . import proto from .proto import Query, Params, DumpersMap, LoadersMap, PQGen from .utils.queries import PostgresQuery from .copy import Copy, AsyncCopy +from .sql import Composable if TYPE_CHECKING: from .connection import BaseConnection, Connection, AsyncConnection @@ -161,6 +162,9 @@ class BaseCursor: """ Implement part of execute() before waiting common to sync and async """ + if isinstance(query, Composable): + query = query.as_string(self) + pgq = PostgresQuery(self._transformer) pgq.convert(query, vars) @@ -213,6 +217,9 @@ class BaseCursor: """ Implement part of execute() before waiting common to sync and async """ + if isinstance(query, Composable): + query = query.as_string(self) + pgq = PostgresQuery(self._transformer) pgq.convert(query, vars) diff --git a/psycopg3/psycopg3/proto.py b/psycopg3/psycopg3/proto.py index 93c0f2bce..dfdd9f2a6 100644 --- a/psycopg3/psycopg3/proto.py +++ b/psycopg3/psycopg3/proto.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: EncodeFunc = Callable[[str], Tuple[bytes, int]] DecodeFunc = Callable[[bytes], Tuple[str, int]] -Query = Union[str, bytes] +Query = Union[str, bytes, "Composable"] Params = Union[Sequence[Any], Mapping[str, Any]] @@ -45,6 +45,11 @@ LoaderType = Type["Loader"] LoadersMap = Dict[Tuple[int, Format], LoaderType] +class Composable(Protocol): + def as_string(self, context: AdaptContext) -> str: + ... + + class Transformer(Protocol): def __init__(self, context: AdaptContext = None): ... diff --git a/psycopg3/psycopg3/sql.py b/psycopg3/psycopg3/sql.py new file mode 100644 index 000000000..434f2ed2e --- /dev/null +++ b/psycopg3/psycopg3/sql.py @@ -0,0 +1,440 @@ +""" +SQL composition utility module +""" + +# Copyright (C) 2020 The Psycopg Team + +import string +from typing import Any, Iterator, List, Optional, Sequence, Union + +from .pq import Escaping, Format +from .proto import AdaptContext + + +def quote(obj: Any, context: AdaptContext = None) -> str: + """ + Adapt a Python object to a quoted SQL string. + + Use this function only if you absolutely want to convert a Python string to + an SQL quoted literal to use e.g. to generate batch SQL and you won't have + a connection avaliable when you will need to use it. + + This function is relatively inefficient, because it doesn't cache the + adaptation rules. If you pass a *context* you can adapt the adaptation + rules used, otherwise only global rules are used. + + """ + return Literal(obj).as_string(context) + + +class Composable(object): + """ + Abstract base class for objects that can be used to compose an SQL string. + + `!Composable` objects can be passed directly to `~cursor.execute()`, + `~cursor.executemany()`, `~cursor.copy_expert()` in place of the query + string. + + `!Composable` objects can be joined using the ``+`` operator: the result + will be a `Composed` instance containing the objects joined. The operator + ``*`` is also supported with an integer argument: the result is a + `!Composed` instance containing the left argument repeated as many times as + requested. + """ + + def __init__(self, obj: Any): + self._obj = obj + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._obj!r})" + + def as_string(self, context: AdaptContext) -> str: + """ + Return the string value of the object. + + :param context: the context to evaluate the string into. + :type context: `connection` or `cursor` + + The method is automatically invoked by `~cursor.execute()`, + `~cursor.executemany()`, `~cursor.copy_expert()` if a `!Composable` is + passed instead of the query string. + """ + raise NotImplementedError + + def __add__(self, other: "Composable") -> "Composed": + if isinstance(other, Composed): + return Composed([self]) + other + if isinstance(other, Composable): + return Composed([self]) + Composed([other]) + else: + return NotImplemented + + def __mul__(self, n: int) -> "Composed": + return Composed([self] * n) + + def __eq__(self, other: Any) -> bool: + return type(self) is type(other) and self._obj == other._obj + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + + +class Composed(Composable): + """ + A `Composable` object made of a sequence of `!Composable`. + + The object is usually created using `!Composable` operators and methods. + However it is possible to create a `!Composed` directly specifying a + sequence of `!Composable` as arguments. + + Example:: + + >>> comp = sql.Composed( + ... [sql.SQL("insert into "), sql.Identifier("table")]) + >>> print(comp.as_string(conn)) + insert into "table" + + `!Composed` objects are iterable (so they can be used in `SQL.join` for + instance). + """ + + _obj: List[Composable] + + def __init__(self, seq: Sequence[Any]): + wrapped = [] + for obj in seq: + if not isinstance(obj, Composable): + raise TypeError( + f"Composed elements must be Composable, got {obj!r} instead" + ) + wrapped.append(obj) + + super(Composed, self).__init__(wrapped) + + def as_string(self, context: AdaptContext) -> str: + rv = [] + for obj in self._obj: + rv.append(obj.as_string(context)) + return "".join(rv) + + def __iter__(self) -> Iterator[Composable]: + return iter(self._obj) + + def __add__(self, other: Composable) -> "Composed": + if isinstance(other, Composed): + return Composed(self._obj + other._obj) + if isinstance(other, Composable): + return Composed(self._obj + [other]) + else: + return NotImplemented + + def join(self, joiner: Union["SQL", str]) -> "Composed": + """ + Return a new `!Composed` interposing the *joiner* with the `!Composed` items. + + The *joiner* must be a `SQL` or a string which will be interpreted as + an `SQL`. + + Example:: + + >>> fields = sql.Identifier('foo') + sql.Identifier('bar') # a Composed + >>> print(fields.join(', ').as_string(conn)) + "foo", "bar" + + """ + if isinstance(joiner, str): + joiner = SQL(joiner) + elif not isinstance(joiner, SQL): + raise TypeError( + "Composed.join() argument must be a string or an SQL" + ) + + return joiner.join(self._obj) + + +class SQL(Composable): + """ + A `Composable` representing a snippet of SQL statement. + + `!SQL` exposes `join()` and `format()` methods useful to create a template + where to merge variable parts of a query (for instance field or table + names). + + The *string* doesn't undergo any form of escaping, so it is not suitable to + represent variable identifiers or values: you should only use it to pass + constant strings representing templates or snippets of SQL statements; use + other objects such as `Identifier` or `Literal` to represent variable + parts. + + Example:: + + >>> query = sql.SQL("select {0} from {1}").format( + ... sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]), + ... sql.Identifier('table')) + >>> print(query.as_string(conn)) + select "foo", "bar" from "table" + """ + + _obj: str + _formatter = string.Formatter() + + def __init__(self, obj: str): + if not isinstance(obj, str): + raise TypeError("SQL values must be strings") + super(SQL, self).__init__(obj) + + def as_string(self, context: AdaptContext) -> str: + return self._obj + + def format( + self, *args: Composable, **kwargs: Composable + ) -> Composed: + """ + Merge `Composable` objects into a template. + + :param `Composable` args: parameters to replace to numbered + (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders + :param `Composable` kwargs: parameters to replace to named (``{name}``) + placeholders + :return: the union of the `!SQL` string with placeholders replaced + :rtype: `Composed` + + The method is similar to the Python `str.format()` method: the string + template supports auto-numbered (``{}``), numbered (``{0}``, + ``{1}``...), and named placeholders (``{name}``), with positional + arguments replacing the numbered placeholders and keywords replacing + the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``) + are not supported. Only `!Composable` objects can be passed to the + template. + + Example:: + + >>> print(sql.SQL("select * from {} where {} = %s") + ... .format(sql.Identifier('people'), sql.Identifier('id')) + ... .as_string(conn)) + select * from "people" where "id" = %s + + >>> print(sql.SQL("select * from {tbl} where {pkey} = %s") + ... .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id')) + ... .as_string(conn)) + select * from "people" where "id" = %s + + """ + rv: List[Composable] = [] + autonum: Optional[int] = 0 + for pre, name, spec, conv in self._formatter.parse(self._obj): + if spec: + raise ValueError("no format specification supported by SQL") + if conv: + raise ValueError("no format conversion supported by SQL") + if pre: + rv.append(SQL(pre)) + + if name is None: + continue + + if name.isdigit(): + if autonum: + raise ValueError( + "cannot switch from automatic field numbering to manual" + ) + rv.append(args[int(name)]) + autonum = None + + elif not name: + if autonum is None: + raise ValueError( + "cannot switch from manual field numbering to automatic" + ) + rv.append(args[autonum]) + autonum += 1 + + else: + rv.append(kwargs[name]) + + return Composed(rv) + + def join(self, seq: Sequence[Composable]) -> Composed: + """ + Join a sequence of `Composable`. + + :param seq: the elements to join. + :type seq: iterable of `!Composable` + + Use the `!SQL` object's *string* to separate the elements in *seq*. + Note that `Composed` objects are iterable too, so they can be used as + argument for this method. + + Example:: + + >>> snip = sql.SQL(', ').join( + ... sql.Identifier(n) for n in ['foo', 'bar', 'baz']) + >>> print(snip.as_string(conn)) + "foo", "bar", "baz" + """ + rv = [] + it = iter(seq) + try: + rv.append(next(it)) + except StopIteration: + pass + else: + for i in it: + rv.append(self) + rv.append(i) + + return Composed(rv) + + +class Identifier(Composable): + """ + A `Composable` representing an SQL identifier or a dot-separated sequence. + + Identifiers usually represent names of database objects, such as tables or + fields. PostgreSQL identifiers follow `different rules`__ than SQL string + literals for escaping (e.g. they use double quotes instead of single). + + .. __: https://www.postgresql.org/docs/current/static/sql-syntax-lexical.html# \ + SQL-SYNTAX-IDENTIFIERS + + Example:: + + >>> t1 = sql.Identifier("foo") + >>> t2 = sql.Identifier("ba'r") + >>> t3 = sql.Identifier('ba"z') + >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn)) + "foo", "ba'r", "ba""z" + + Multiple strings can be passed to the object to represent a qualified name, + i.e. a dot-separated sequence of identifiers. + + Example:: + + >>> query = sql.SQL("select {} from {}").format( + ... sql.Identifier("table", "field"), + ... sql.Identifier("schema", "table")) + >>> print(query.as_string(conn)) + select "table"."field" from "schema"."table" + + """ + + _obj: Sequence[str] + + def __init__(self, *strings: str): + if not strings: + raise TypeError("Identifier cannot be empty") + + for s in strings: + if not isinstance(s, str): + raise TypeError("SQL identifier parts must be strings") + + super(Identifier, self).__init__(strings) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})" + + def as_string(self, context: AdaptContext) -> str: + from .adapt import _connection_from_context + + conn = _connection_from_context(context) + if not conn: + raise ValueError(f"no connection in the context: {context}") + + esc = Escaping(conn.pgconn) + codec = conn.codec + escs = [esc.escape_identifier(codec.encode(s)[0]) for s in self._obj] + return codec.decode(b".".join(escs))[0] + + +class Literal(Composable): + """ + A `Composable` representing an SQL value to include in a query. + + Usually you will want to include placeholders in the query and pass values + as `~cursor.execute()` arguments. If however you really really need to + include a literal value in the query you can use this object. + + The string returned by `!as_string()` follows the normal :ref:`adaptation + rules ` for Python objects. + + Example:: + + >>> s1 = sql.Literal("foo") + >>> s2 = sql.Literal("ba'r") + >>> s3 = sql.Literal(42) + >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn)) + 'foo', 'ba''r', 42 + + """ + + def as_string(self, context: AdaptContext) -> str: + from .adapt import _connection_from_context, Transformer + + conn = _connection_from_context(context) + tx = Transformer(conn) + dumper = tx.get_dumper(self._obj, Format.TEXT) + value = dumper.dump(self._obj) + + if conn: + esc = Escaping(conn.pgconn) + quoted = esc.escape_literal(value) + return conn.codec.decode(quoted)[0] + else: + esc = Escaping() + quoted = b"'%s'" % esc.escape_string(value) + return quoted.decode("utf8") + + +class Placeholder(Composable): + """A `Composable` representing a placeholder for query parameters. + + If the name is specified, generate a named placeholder (e.g. ``%(name)s``), + otherwise generate a positional placeholder (e.g. ``%s``). + + The object is useful to generate SQL queries with a variable number of + arguments. + + Examples:: + + >>> names = ['foo', 'bar', 'baz'] + + >>> q1 = sql.SQL("insert into table ({}) values ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(sql.Placeholder() * len(names))) + >>> print(q1.as_string(conn)) + insert into table ("foo", "bar", "baz") values (%s, %s, %s) + + >>> q2 = sql.SQL("insert into table ({}) values ({})").format( + ... sql.SQL(', ').join(map(sql.Identifier, names)), + ... sql.SQL(', ').join(map(sql.Placeholder, names))) + >>> print(q2.as_string(conn)) + insert into table ("foo", "bar", "baz") values (%(foo)s, %(bar)s, %(baz)s) + + """ + + def __init__(self, name: Optional[str] = None): + if isinstance(name, str): + if ")" in name: + raise ValueError("invalid name: %r" % name) + + elif name is not None: + raise TypeError("expected string or None as name, got %r" % name) + + super(Placeholder, self).__init__(name) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}" + f"({self._obj if self._obj is not None else ''})" + ) + + def as_string(self, context: AdaptContext) -> str: + if self._obj is not None: + return "%%(%s)s" % self._obj + else: + return "%s" + + +# Literals +NULL = SQL("NULL") +DEFAULT = SQL("DEFAULT") diff --git a/tests/test_sql.py b/tests/test_sql.py new file mode 100755 index 000000000..6c95516b3 --- /dev/null +++ b/tests/test_sql.py @@ -0,0 +1,409 @@ +# test_sql.py - tests for the psycopg2.sql module + +# Copyright (C) 2020 The Psycopg Team + +import re +import datetime as dt + +import pytest + +from psycopg3 import sql, ProgrammingError + + +class TestSqlFormat: + def test_pos(self, conn): + s = sql.SQL("select {} from {}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_pos_spec(self, conn): + s = sql.SQL("select {0} from {1}").format( + sql.Identifier("field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + s = sql.SQL("select {1} from {0}").format( + sql.Identifier("table"), sql.Identifier("field") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_dict(self, conn): + s = sql.SQL("select {f} from {t}").format( + f=sql.Identifier("field"), t=sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == 'select "field" from "table"' + + def test_unicode(self, conn): + s = sql.SQL(u"select {0} from {1}").format( + sql.Identifier(u"field"), sql.Identifier("table") + ) + s1 = s.as_string(conn) + assert isinstance(s1, str) + assert s1 == u'select "field" from "table"' + + def test_compose_literal(self, conn): + s = sql.SQL("select {0};").format(sql.Literal(dt.date(2016, 12, 31))) + s1 = s.as_string(conn) + assert s1 == "select '2016-12-31';" + + def test_compose_empty(self, conn): + s = sql.SQL("select foo;").format() + s1 = s.as_string(conn) + assert s1 == "select foo;" + + def test_percent_escape(self, conn): + s = sql.SQL("42 % {0}").format(sql.Literal(7)) + s1 = s.as_string(conn) + assert s1 == "42 % '7'" + + def test_braces_escape(self, conn): + s = sql.SQL("{{{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{'7'}" + s = sql.SQL("{{1,{0}}}").format(sql.Literal(7)) + assert s.as_string(conn) == "{1,'7'}" + + def test_compose_badnargs(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format() + + def test_compose_badnargs_auto(self): + with pytest.raises(IndexError): + sql.SQL("select {};").format() + with pytest.raises(ValueError): + sql.SQL("select {} {1};").format(10, 20) + with pytest.raises(ValueError): + sql.SQL("select {0} {};").format(10, 20) + + def test_compose_bad_args_type(self): + with pytest.raises(IndexError): + sql.SQL("select {0};").format(a=10) + with pytest.raises(KeyError): + sql.SQL("select {x};").format(10) + + def test_must_be_composable(self): + with pytest.raises(TypeError): + sql.SQL("select {0};").format("foo") + with pytest.raises(TypeError): + sql.SQL("select {0};").format(10) + + def test_no_modifiers(self): + with pytest.raises(ValueError): + sql.SQL("select {a!r};").format(a=10) + with pytest.raises(ValueError): + sql.SQL("select {a:<};").format(a=10) + + def test_must_be_adaptable(self, conn): + class Foo(object): + pass + + s = sql.SQL("select {0};").format(sql.Literal(Foo())) + with pytest.raises(ProgrammingError): + s.as_string(conn) + + def test_execute(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.execute( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join( + map(sql.Identifier, ["foo", "bar", "ba'z"]) + ), + (sql.Placeholder() * 3).join(", "), + ), + (10, "a", "b", "c"), + ) + + cur.execute("select * from test_compose") + assert cur.fetchall() == [(10, "a", "b", "c")] + + def test_executemany(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + cur.executemany( + sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( + sql.Identifier("test_compose"), + sql.SQL(", ").join( + map(sql.Identifier, ["foo", "bar", "ba'z"]) + ), + (sql.Placeholder() * 3).join(", "), + ), + [(10, "a", "b", "c"), (20, "d", "e", "f")], + ) + + cur.execute("select * from test_compose") + assert cur.fetchall(), [(10, "a", "b", "c"), (20, "d", "e", "f")] + + def test_copy(self, conn): + cur = conn.cursor() + cur.execute( + """ + create table test_compose ( + id serial primary key, + foo text, bar text, "ba'z" text) + """ + ) + + with cur.copy( + sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ), + ) as copy: + copy.write_row((10, "a", "b", "c")) + copy.write_row((20, "d", "e", "f")) + + copy = cur.copy( + sql.SQL("copy (select {f} from {t} order by id) to stdout").format( + t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z") + ) + ) + assert list(copy) == [b"c\n", b"f\n"] + + +class TestIdentifier: + def test_class(self): + assert issubclass(sql.Identifier, sql.Composable) + + def test_init(self): + assert isinstance(sql.Identifier("foo"), sql.Identifier) + assert isinstance(sql.Identifier(u"foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier) + with pytest.raises(TypeError): + sql.Identifier() + with pytest.raises(TypeError): + sql.Identifier(10) + with pytest.raises(TypeError): + sql.Identifier(dt.date(2016, 12, 31)) + + def test_repr(self): + obj = sql.Identifier("fo'o") + assert repr(obj) == 'Identifier("fo\'o")' + assert repr(obj) == str(obj) + + obj = sql.Identifier("fo'o", 'ba"r') + assert repr(obj) == "Identifier(\"fo'o\", 'ba\"r')" + assert repr(obj) == str(obj) + + def test_eq(self): + assert sql.Identifier("foo") == sql.Identifier("foo") + assert sql.Identifier("foo", "bar") == sql.Identifier("foo", "bar") + assert sql.Identifier("foo") != sql.Identifier("bar") + assert sql.Identifier("foo") != "foo" + assert sql.Identifier("foo") != sql.SQL("foo") + + def test_as_str(self, conn): + assert sql.Identifier("foo").as_string(conn) == '"foo"' + assert sql.Identifier("foo", "bar").as_string(conn), '"foo"."bar"' + assert ( + sql.Identifier("fo'o", 'ba"r').as_string(conn) == '"fo\'o"."ba""r"' + ) + + def test_join(self): + assert not hasattr(sql.Identifier("foo"), "join") + + +class TestLiteral: + def test_class(self): + assert issubclass(sql.Literal, sql.Composable) + + def test_init(self): + assert isinstance(sql.Literal("foo"), sql.Literal) + assert isinstance(sql.Literal(u"foo"), sql.Literal) + assert isinstance(sql.Literal(b"foo"), sql.Literal) + assert isinstance(sql.Literal(42), sql.Literal) + assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal) + + def test_repr(self, conn): + assert repr(sql.Literal("foo")) == "Literal('foo')" + assert str(sql.Literal("foo")) == "Literal('foo')" + assert noe(sql.Literal("foo").as_string(conn)) == "'foo'" + assert sql.Literal(42).as_string(conn) == "'42'" + assert ( + sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'" + ) + + def test_eq(self): + assert sql.Literal("foo") == sql.Literal("foo") + assert sql.Literal("foo") != sql.Literal("bar") + assert sql.Literal("foo") != "foo" + assert sql.Literal("foo") != sql.SQL("foo") + + def test_must_be_adaptable(self, conn): + class Foo(object): + pass + + with pytest.raises(ProgrammingError): + sql.Literal(Foo()).as_string(conn) + + +class TestSQL: + def test_class(self): + assert issubclass(sql.SQL, sql.Composable) + + def test_init(self): + assert isinstance(sql.SQL("foo"), sql.SQL) + assert isinstance(sql.SQL(u"foo"), sql.SQL) + with pytest.raises(TypeError): + sql.SQL(10) + with pytest.raises(TypeError): + sql.SQL(dt.date(2016, 12, 31)) + + def test_repr(self, conn): + assert repr(sql.SQL("foo")) == "SQL('foo')" + assert str(sql.SQL("foo")) == "SQL('foo')" + assert sql.SQL("foo").as_string(conn) == "foo" + + def test_eq(self): + assert sql.SQL("foo") == sql.SQL("foo") + assert sql.SQL("foo") != sql.SQL("bar") + assert sql.SQL("foo") != "foo" + assert sql.SQL("foo") != sql.Literal("foo") + + def test_sum(self, conn): + obj = sql.SQL("foo") + sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_sum_inplace(self, conn): + obj = sql.SQL("foo") + obj += sql.SQL("bar") + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foobar" + + def test_multiply(self, conn): + obj = sql.SQL("foo") * 3 + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "foofoofoo" + + def test_join(self, conn): + obj = sql.SQL(", ").join( + [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)] + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "\"foo\", bar, '42'" + + obj = sql.SQL(", ").join( + sql.Composed( + [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)] + ) + ) + assert isinstance(obj, sql.Composed) + assert obj.as_string(conn) == "\"foo\", bar, '42'" + + obj = sql.SQL(", ").join([]) + assert obj == sql.Composed([]) + + +class TestComposed: + def test_class(self): + assert issubclass(sql.Composed, sql.Composable) + + def test_repr(self): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + assert ( + repr(obj) == """Composed([Literal('foo'), Identifier("b'ar")])""" + ) + assert str(obj) == repr(obj) + + def test_eq(self): + L = [sql.Literal("foo"), sql.Identifier("b'ar")] + l2 = [sql.Literal("foo"), sql.Literal("b'ar")] + assert sql.Composed(L) == sql.Composed(list(L)) + assert sql.Composed(L) != L + assert sql.Composed(L) != sql.Composed(l2) + + def test_join(self, conn): + obj = sql.Composed([sql.Literal("foo"), sql.Identifier("b'ar")]) + obj = obj.join(", ") + assert isinstance(obj, sql.Composed) + assert noe(obj.as_string(conn)) == "'foo', \"b'ar\"" + + def test_sum(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj = obj + sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert noe(obj.as_string(conn)), "foo 'bar'" + + def test_sum_inplace(self, conn): + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Literal("bar") + assert isinstance(obj, sql.Composed) + assert noe(obj.as_string(conn)) == "foo 'bar'" + + obj = sql.Composed([sql.SQL("foo ")]) + obj += sql.Composed([sql.Literal("bar")]) + assert isinstance(obj, sql.Composed) + assert noe(obj.as_string(conn)) == "foo 'bar'" + + def test_iter(self): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + it = iter(obj) + i = next(it) + assert i == sql.SQL("foo") + i = next(it) + assert i == sql.SQL("bar") + with pytest.raises(StopIteration): + next(it) + + +class TestPlaceholder: + def test_class(self): + assert issubclass(sql.Placeholder, sql.Composable) + + def test_repr(self, conn): + assert str(sql.Placeholder()), "Placeholder()" + assert repr(sql.Placeholder()), "Placeholder()" + assert sql.Placeholder().as_string(conn), "%s" + + def test_repr_name(self, conn): + assert str(sql.Placeholder("foo")), "Placeholder('foo')" + assert repr(sql.Placeholder("foo")), "Placeholder('foo')" + assert sql.Placeholder("foo").as_string(conn), "%(foo)s" + + def test_bad_name(self): + with pytest.raises(ValueError): + sql.Placeholder(")") + + def test_eq(self): + assert sql.Placeholder("foo") == sql.Placeholder("foo") + assert sql.Placeholder("foo") != sql.Placeholder("bar") + assert sql.Placeholder("foo") != "foo" + assert sql.Placeholder() == sql.Placeholder() + assert sql.Placeholder("foo") != sql.Placeholder() + assert sql.Placeholder("foo") != sql.Literal("foo") + + +class TestValues: + def test_null(self, conn): + assert isinstance(sql.NULL, sql.SQL) + assert sql.NULL.as_string(conn) == "NULL" + + def test_default(self, conn): + assert isinstance(sql.DEFAULT, sql.SQL) + assert sql.DEFAULT.as_string(conn) == "DEFAULT" + + +def noe(s): + """Drop an eventual E from E'' quotes""" + return re.sub(r"\bE'", "'", s)