]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
sql.Placeholder can represent binary placeholders
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 12:00:35 +0000 (13:00 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 12:01:49 +0000 (13:01 +0100)
psycopg3/psycopg3/sql.py
tests/test_sql.py

index 2dc0cb737df6ced39cfdbbb72485976d034727c0..1be83b4eef85a93fa96e7c6eface027e47aac530 100644 (file)
@@ -186,9 +186,7 @@ class SQL(Composable):
     def as_string(self, context: AdaptContext) -> str:
         return self._obj
 
-    def format(
-        self, *args: Composable, **kwargs: Composable
-    ) -> Composed:
+    def format(self, *args: Composable, **kwargs: Composable) -> Composed:
         """
         Merge `Composable` objects into a template.
 
@@ -391,8 +389,9 @@ class Literal(Composable):
 class Placeholder(Composable):
     """A `Composable` representing a placeholder for query parameters.
 
-    If the name is specified, generate a named placeholder (e.g. ``%(name)s``),
-    otherwise generate a positional placeholder (e.g. ``%s``).
+    If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
+    ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
+    ``%b``).
 
     The object is useful to generate SQL queries with a variable number of
     arguments.
@@ -415,7 +414,9 @@ class Placeholder(Composable):
 
     """
 
-    def __init__(self, name: Optional[str] = None):
+    def __init__(
+        self, name: Optional[str] = None, format: Format = Format.TEXT
+    ):
         if isinstance(name, str):
             if ")" in name:
                 raise ValueError("invalid name: %r" % name)
@@ -424,18 +425,23 @@ class Placeholder(Composable):
             raise TypeError("expected string or None as name, got %r" % name)
 
         super(Placeholder, self).__init__(name)
+        self._format = format
 
     def __repr__(self) -> str:
-        return (
-            f"{self.__class__.__name__}"
-            f"({self._obj if self._obj is not None else ''})"
-        )
+        parts = []
+        if self._obj:
+            parts.append(repr(self._obj))
+        if self._format != Format.TEXT:
+            parts.append(f"format={Format(self._format).name}")
+
+        return f"{self.__class__.__name__}({', '.join(parts)})"
 
     def as_string(self, context: AdaptContext) -> str:
+        code = "s" if self._format == Format.TEXT else "b"
         if self._obj is not None:
-            return "%%(%s)s" % self._obj
+            return f"%({self._obj}){code}"
         else:
-            return "%s"
+            return f"%{code}"
 
 
 # Literals
index 267bd4fd5d0d34f6eb5f94fe80a766c409083322..087e5d0f83dece644126716eeb82656ca943b4cd 100755 (executable)
@@ -8,6 +8,7 @@ import datetime as dt
 import pytest
 
 from psycopg3 import sql, ProgrammingError
+from psycopg3.pq import Format
 
 
 @pytest.mark.parametrize(
@@ -161,7 +162,7 @@ class TestSqlFormat:
         )
 
         cur.execute("select * from test_compose")
-        assert cur.fetchall(), [(10, "a", "b", "c"), (20, "d", "e", "f")]
+        assert cur.fetchall() == [(10, "a", "b", "c"), (20, "d", "e", "f")]
 
     def test_copy(self, conn):
         cur = conn.cursor()
@@ -222,7 +223,7 @@ class TestIdentifier:
 
     def test_as_str(self, conn):
         assert sql.Identifier("foo").as_string(conn) == '"foo"'
-        assert sql.Identifier("foo", "bar").as_string(conn), '"foo"."bar"'
+        assert sql.Identifier("foo", "bar").as_string(conn) == '"foo"."bar"'
         assert (
             sql.Identifier("fo'o", 'ba"r').as_string(conn) == '"fo\'o"."ba""r"'
         )
@@ -354,7 +355,7 @@ class TestComposed:
         obj = sql.Composed([sql.SQL("foo ")])
         obj = obj + sql.Literal("bar")
         assert isinstance(obj, sql.Composed)
-        assert noe(obj.as_string(conn)), "foo 'bar'"
+        assert noe(obj.as_string(conn)) == "foo 'bar'"
 
     def test_sum_inplace(self, conn):
         obj = sql.Composed([sql.SQL("foo ")])
@@ -383,14 +384,24 @@ class TestPlaceholder:
         assert issubclass(sql.Placeholder, sql.Composable)
 
     def test_repr(self, conn):
-        assert str(sql.Placeholder()), "Placeholder()"
-        assert repr(sql.Placeholder()), "Placeholder()"
-        assert sql.Placeholder().as_string(conn), "%s"
+        ph = sql.Placeholder()
+        assert str(ph) == repr(ph) == "Placeholder()"
+        assert ph.as_string(conn) == "%s"
+
+    def test_repr_binary(self, conn):
+        ph = sql.Placeholder(format=Format.BINARY)
+        assert str(ph) == repr(ph) == "Placeholder(format=BINARY)"
+        assert ph.as_string(conn) == "%b"
 
     def test_repr_name(self, conn):
-        assert str(sql.Placeholder("foo")), "Placeholder('foo')"
-        assert repr(sql.Placeholder("foo")), "Placeholder('foo')"
-        assert sql.Placeholder("foo").as_string(conn), "%(foo)s"
+        ph = sql.Placeholder("foo")
+        assert str(ph) == repr(ph) == "Placeholder('foo')"
+        assert ph.as_string(conn) == "%(foo)s"
+
+    def test_repr_name_binary(self, conn):
+        ph = sql.Placeholder("foo", format=Format.BINARY)
+        assert str(ph) == repr(ph) == "Placeholder('foo', format=BINARY)"
+        assert ph.as_string(conn) == "%(foo)b"
 
     def test_bad_name(self):
         with pytest.raises(ValueError):