]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: allow no connection parameter in sql.Composible.as_string()/as_bytes()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 19 Jan 2024 16:40:37 +0000 (17:40 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 27 Jan 2024 02:12:02 +0000 (02:12 +0000)
Close #716

docs/api/sql.rst
docs/news.rst
psycopg/psycopg/sql.py
tests/test_sql.py

index 6959fee4dba87df92ec7f2c74a0fe9f236f21a16..5e7000b269be7d3b954fbfdf6f228cc9425b0535 100644 (file)
@@ -108,9 +108,25 @@ The `!sql` objects are in the following inheritance hierarchy:
 
 .. autoclass:: Composable()
 
-    .. automethod:: as_bytes
     .. automethod:: as_string
 
+    .. versionchanged:: 3.2
+
+        The `!context` parameter is optional.
+
+        .. warning::
+
+            If a context is not specified, the results are "generic" and not
+            tailored for a specific target connection. Details such as the
+            connection encoding and escaping style will not be taken into
+            account.
+
+    .. automethod:: as_bytes
+
+    .. versionchanged:: 3.2
+
+        The `!context` parameter is optional. See `as_string` for details.
+
 
 .. autoclass:: SQL
 
index f42bb29b7de411bfbc8dd54bfa6f3e600880ad4b..839f1d8c74ee04e2aa1335a6e09951bb33acf57b 100644 (file)
@@ -21,6 +21,8 @@ Psycopg 3.2 (unreleased)
   transaction control methods available on the async connections.
 - Add support for libpq functions to close prepared statements and portals
   introduced in libpq v17 (:ticket:`#603`).
+- The `!context` parameter of `sql` objects `~sql.Composable.as_string()` and
+  `~sql.Composable.as_bytes()` methods is not optional (:ticket:`#716`).
 - Disable receiving more than one result on the same cursor in pipeline mode,
   to iterate through `~Cursor.nextset()`. The behaviour was different than
   in non-pipeline mode and not totally reliable (:ticket:`#604`).
index 6eaabee7cc2d642557f55e29836651a2713cba7c..a94f77f6efcf517703c56a25b6db96b26396c5bd 100644 (file)
@@ -55,7 +55,7 @@ class Composable(ABC):
         return f"{self.__class__.__name__}({self._obj!r})"
 
     @abstractmethod
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         """
         Return the value of the object as bytes.
 
@@ -69,7 +69,7 @@ class Composable(ABC):
         """
         raise NotImplementedError
 
-    def as_string(self, context: Optional[AdaptContext]) -> str:
+    def as_string(self, context: Optional[AdaptContext] = None) -> str:
         """
         Return the value of the object as string.
 
@@ -130,7 +130,7 @@ class Composed(Composable):
         seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
         super().__init__(seq)
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         return b"".join(obj.as_bytes(context) for obj in self._obj)
 
     def __iter__(self) -> Iterator[Composable]:
@@ -200,10 +200,10 @@ class SQL(Composable):
         if not isinstance(obj, str):
             raise TypeError(f"SQL values must be strings, got {obj!r} instead")
 
-    def as_string(self, context: Optional[AdaptContext]) -> str:
+    def as_string(self, context: Optional[AdaptContext] = None) -> str:
         return self._obj
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         conn = context.connection if context else None
         enc = conn_encoding(conn)
         return self._obj.encode(enc)
@@ -362,7 +362,7 @@ class Identifier(Composable):
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         conn = context.connection if context else None
         if conn:
             esc = Escaping(conn.pgconn)
@@ -400,7 +400,7 @@ class Literal(Composable):
 
     """
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         tx = Transformer.from_context(context)
         return tx.as_literal(self._obj)
 
@@ -459,11 +459,11 @@ class Placeholder(Composable):
 
         return f"{self.__class__.__name__}({', '.join(parts)})"
 
-    def as_string(self, context: Optional[AdaptContext]) -> str:
+    def as_string(self, context: Optional[AdaptContext] = None) -> str:
         code = self._format.value
         return f"%({self._obj}){code}" if self._obj else f"%{code}"
 
-    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+    def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
         conn = context.connection if context else None
         enc = conn_encoding(conn)
         return self.as_string(context).encode(enc)
index 4cd0b0c0e0cee9fd95918de21e4c743a5cc567a0..b5f1b37cad08a3008521caf7d4614519cf40682e 100644 (file)
@@ -280,6 +280,7 @@ class TestIdentifier:
     @pytest.mark.parametrize("args, want", _as_string_params)
     def test_as_string_no_conn(self, args, want):
         assert sql.Identifier(*args).as_string(None) == want
+        assert sql.Identifier(*args).as_string() == want
 
     _as_bytes_params = [
         crdb_encoding(("foo",), '"foo"', "ascii"),
@@ -296,9 +297,10 @@ class TestIdentifier:
         assert sql.Identifier(*args).as_bytes(conn) == want
 
     @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
-    def test_as_bytes_params(self, conn, args, want, enc):
+    def test_as_bytes_no_conn(self, conn, args, want, enc):
         want = want.encode()
         assert sql.Identifier(*args).as_bytes(None) == want
+        assert sql.Identifier(*args).as_bytes() == want
 
     def test_join(self):
         assert not hasattr(sql.Identifier("foo"), "join")
@@ -326,17 +328,40 @@ class TestLiteral:
         assert repr(sql.Literal("foo")) == "Literal('foo')"
         assert str(sql.Literal("foo")) == "Literal('foo')"
 
-    def test_as_string(self, conn):
-        assert sql.Literal(None).as_string(conn) == "NULL"
-        assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
-        assert sql.Literal(42).as_string(conn) == "42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
+    _params = [
+        (None, "NULL"),
+        ("foo", "'foo'"),
+        (42, "42"),
+        (dt.date(2017, 1, 1), "'2017-01-01'::date"),
+    ]
 
-    def test_as_bytes(self, conn):
-        assert sql.Literal(None).as_bytes(conn) == b"NULL"
-        assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
-        assert sql.Literal(42).as_bytes(conn) == b"42"
-        assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_string(self, conn, obj, want):
+        got = sql.Literal(obj).as_string(conn)
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_bytes(self, conn, obj, want):
+        got = sql.Literal(obj).as_bytes(conn)
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want.encode()
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_string_no_conn(self, obj, want):
+        got = sql.Literal(obj).as_string()
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want
+
+    @pytest.mark.parametrize("obj, want", _params)
+    def test_as_bytes_no_conn(self, obj, want):
+        got = sql.Literal(obj).as_bytes()
+        if isinstance(obj, str):
+            got = no_e(got)
+        assert got == want.encode()
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes_encoding(self, conn, encoding):
@@ -473,6 +498,7 @@ class TestSQL:
 
     def test_as_string(self, conn):
         assert sql.SQL("foo").as_string(conn) == "foo"
+        assert sql.SQL("foo").as_string() == "foo"
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes(self, conn, encoding):
@@ -481,6 +507,10 @@ class TestSQL:
 
         assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
 
+    def test_no_conn(self):
+        assert sql.SQL(eur).as_string() == eur
+        assert sql.SQL(eur).as_bytes() == eur.encode()
+
 
 class TestComposed:
     def test_class(self):
@@ -540,10 +570,12 @@ class TestComposed:
     def test_as_string(self, conn):
         obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
         assert obj.as_string(conn) == "foobar"
+        assert obj.as_string() == "foobar"
 
     def test_as_bytes(self, conn):
         obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
         assert obj.as_bytes(conn) == b"foobar"
+        assert obj.as_bytes() == b"foobar"
 
     @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
     def test_as_bytes_encoding(self, conn, encoding):
@@ -584,17 +616,21 @@ class TestPlaceholder:
     def test_as_string(self, conn, format):
         ph = sql.Placeholder(format=format)
         assert ph.as_string(conn) == f"%{format.value}"
+        assert ph.as_string() == f"%{format.value}"
 
         ph = sql.Placeholder(name="foo", format=format)
         assert ph.as_string(conn) == f"%(foo){format.value}"
+        assert ph.as_string() == f"%(foo){format.value}"
 
     @pytest.mark.parametrize("format", PyFormat)
     def test_as_bytes(self, conn, format):
         ph = sql.Placeholder(format=format)
-        assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii")
+        assert ph.as_bytes(conn) == f"%{format.value}".encode()
+        assert ph.as_bytes() == f"%{format.value}".encode()
 
         ph = sql.Placeholder(name="foo", format=format)
-        assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii")
+        assert ph.as_bytes(conn) == f"%(foo){format.value}".encode()
+        assert ph.as_bytes() == f"%(foo){format.value}".encode()
 
 
 class TestValues:
@@ -609,7 +645,7 @@ class TestValues:
 
 def no_e(s):
     """Drop an eventual E from E'' quotes"""
-    if isinstance(s, memoryview):
+    if isinstance(s, (memoryview, bytearray)):
         s = bytes(s)
 
     if isinstance(s, str):