]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
sql.SQL.format() accepts any Python object, making it a Literal
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 9 Nov 2020 02:21:47 +0000 (02:21 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 9 Nov 2020 02:21:47 +0000 (02:21 +0000)
psycopg3/psycopg3/sql.py
tests/test_sql.py

index 3b57d68cf10bec05b9e19006b521b211fb062823..9d16936f4d229160e071aa0382d5055bb9b09e41 100644 (file)
@@ -85,7 +85,8 @@ class Composed(Composable):
 
     The object is usually created using `!Composable` operators and methods.
     However it is possible to create a `!Composed` directly specifying a
-    sequence of `!Composable` as arguments.
+    sequence of objects as arguments: if they are not `!Composable` they will
+    be wrapped in a `Literal`.
 
     Example::
 
@@ -101,12 +102,10 @@ class Composed(Composable):
     _obj: List[Composable]
 
     def __init__(self, seq: Sequence[Any]):
+        seq = [
+            obj if isinstance(obj, Composable) else Literal(obj) for obj in seq
+        ]
         super().__init__(seq)
-        for obj in seq:
-            if not isinstance(obj, Composable):
-                raise TypeError(
-                    f"Composed elements must be Composable, got {obj!r} instead"
-                )
 
     def as_string(self, context: AdaptContext) -> str:
         rv = []
@@ -184,14 +183,13 @@ class SQL(Composable):
     def as_string(self, context: AdaptContext) -> str:
         return self._obj
 
-    def format(self, *args: Composable, **kwargs: Composable) -> Composed:
+    def format(self, *args: Any, **kwargs: Any) -> Composed:
         """
         Merge `Composable` objects into a template.
 
-        :param `Composable` args: parameters to replace to numbered
-            (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders
-        :param `Composable` kwargs: parameters to replace to named (``{name}``)
-            placeholders
+        :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
+            auto-numbered (``{}``) placeholders
+        :param kwargs: parameters to replace to named (``{name}``) placeholders
         :return: the union of the `!SQL` string with placeholders replaced
         :rtype: `Composed`
 
@@ -200,8 +198,12 @@ class SQL(Composable):
         ``{1}``...), and named placeholders (``{name}``), with positional
         arguments replacing the numbered placeholders and keywords replacing
         the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
-        are not supported. Only `!Composable` objects can be passed to the
-        template.
+        are not supported.
+
+        If a `!Composable` objects is passed to the template it will be merged
+        according to its `as_string()` method. If any other Python object is
+        passed, it will be wrapped in a `Literal` object and so escacaped
+        according to SQL rules.
 
         Example::
 
@@ -210,10 +212,10 @@ class SQL(Composable):
             ...     .as_string(conn))
             select * from "people" where "id" = %s
 
-            >>> print(sql.SQL("select * from {tbl} where {pkey} = %s")
-            ...     .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id'))
+            >>> print(sql.SQL("select * from {tbl} where name = {name}")
+            ...     .format(tbl=sql.Identifier('people'), name="O'Rourke"))
             ...     .as_string(conn))
-            select * from "people" where "id" = %s
+            select * from "people" where name = 'O''Rourke'
 
         """
         rv: List[Composable] = []
index 36e3ea7c1bd86c086dbba7ce9baa7e9c918577d3..62ff15d0bcaff594a16385e13cbc81879961bad5 100755 (executable)
@@ -98,12 +98,6 @@ class TestSqlFormat:
         with pytest.raises(KeyError):
             sql.SQL("select {x};").format(10)
 
-    def test_must_be_composable(self):
-        with pytest.raises(TypeError):
-            sql.SQL("select {0};").format("foo")
-        with pytest.raises(TypeError):
-            sql.SQL("select {0};").format(10)
-
     def test_no_modifiers(self):
         with pytest.raises(ValueError):
             sql.SQL("select {a!r};").format(a=10)
@@ -118,6 +112,12 @@ class TestSqlFormat:
         with pytest.raises(ProgrammingError):
             s.as_string(conn)
 
+    def test_auto_literal(self, conn):
+        s = sql.SQL("select {}, {}, {}").format(
+            "he'lo", 10, dt.date(2020, 1, 1)
+        )
+        assert s.as_string(conn) == "select 'he''lo', 10, '2020-01-01'"
+
     def test_execute(self, conn):
         cur = conn.cursor()
         cur.execute(
@@ -351,6 +351,12 @@ class TestComposed:
         assert isinstance(obj, sql.Composed)
         assert noe(obj.as_string(conn)) == "'foo', \"b'ar\""
 
+    def test_auto_literal(self, conn):
+        obj = sql.Composed(["fo'o", dt.date(2020, 1, 1)])
+        obj = obj.join(", ")
+        assert isinstance(obj, sql.Composed)
+        assert noe(obj.as_string(conn)) == "'fo''o', '2020-01-01'"
+
     def test_sum(self, conn):
         obj = sql.Composed([sql.SQL("foo ")])
         obj = obj + sql.Literal("bar")