]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add test for as_bytes method of sql objects
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 14:22:16 +0000 (16:22 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 24 Aug 2021 14:22:16 +0000 (16:22 +0200)
psycopg/psycopg/sql.py
tests/test_sql.py

index caf903025265104388185ccb43b1c58736fba671..fa2e3b1a24cd952cc05ed6fff0423ce966e4e1b1 100644 (file)
@@ -64,7 +64,6 @@ class Composable(ABC):
         `!Composable` is passed instead of the query string.
 
         """
-        # TODO: add tests for as_bytes
         raise NotImplementedError
 
     def as_string(self, context: Optional[AdaptContext]) -> str:
index ee8411520d712c0081aa6319a88ff6629131b110..4382a0de8d49c82cb3edd51c7e285e1b8910d9c2 100644 (file)
@@ -10,6 +10,8 @@ import pytest
 from psycopg import pq, sql, ProgrammingError
 from psycopg.adapt import PyFormat as Format
 
+eur = "\u20ac"
+
 
 @pytest.mark.parametrize(
     "obj, quoted",
@@ -265,12 +267,31 @@ class TestIdentifier:
         assert sql.Identifier("foo") != "foo"
         assert sql.Identifier("foo") != sql.SQL("foo")
 
-    def test_as_str(self, conn):
-        assert sql.Identifier("foo").as_string(conn) == '"foo"'
-        assert sql.Identifier("foo", "bar").as_string(conn) == '"foo"."bar"'
-        assert (
-            sql.Identifier("fo'o", 'ba"r').as_string(conn) == '"fo\'o"."ba""r"'
-        )
+    @pytest.mark.parametrize(
+        "args, want",
+        [
+            (("foo",), '"foo"'),
+            (("foo", "bar"), '"foo"."bar"'),
+            (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+        ],
+    )
+    def test_as_string(self, conn, args, want):
+        assert sql.Identifier(*args).as_string(conn) == want
+
+    @pytest.mark.parametrize(
+        "args, want, enc",
+        [
+            (("foo",), '"foo"', "ascii"),
+            (("foo", "bar"), '"foo"."bar"', "ascii"),
+            (("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+            (("foo", eur), f'"foo"."{eur}"', "utf8"),
+            (("foo", eur), f'"foo"."{eur}"', "latin9"),
+        ],
+    )
+    def test_as_bytes(self, conn, args, want, enc):
+        want = want.encode(enc)
+        conn.client_encoding = enc
+        assert sql.Identifier(*args).as_bytes(conn) == want
 
     def test_join(self):
         assert not hasattr(sql.Identifier("foo"), "join")
@@ -291,7 +312,7 @@ class TestLiteral:
         assert repr(sql.Literal("foo")) == "Literal('foo')"
         assert str(sql.Literal("foo")) == "Literal('foo')"
 
-    def test_as_str(self, conn):
+    def test_as_string(self, conn):
         assert sql.Literal(None).as_string(conn) == "NULL"
         assert noe(sql.Literal("foo").as_string(conn)) == "'foo'"
         assert sql.Literal(42).as_string(conn) == "42"
@@ -299,6 +320,19 @@ class TestLiteral:
             sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
         )
 
+    def test_as_bytes(self, conn):
+        assert sql.Literal(None).as_bytes(conn) == b"NULL"
+        assert noe(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
+        assert sql.Literal(42).as_bytes(conn) == b"42"
+        assert (
+            sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'"
+        )
+
+        conn.client_encoding = "utf8"
+        assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode("utf8")
+        conn.client_encoding = "latin9"
+        assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode("latin9")
+
     def test_eq(self):
         assert sql.Literal("foo") == sql.Literal("foo")
         assert sql.Literal("foo") != sql.Literal("bar")
@@ -370,6 +404,18 @@ class TestSQL:
         obj = sql.SQL(", ").join([])
         assert obj == sql.Composed([])
 
+    def test_as_string(self, conn):
+        assert sql.SQL("foo").as_string(conn) == "foo"
+
+    def test_as_bytes(self, conn):
+        assert sql.SQL("foo").as_bytes(conn) == b"foo"
+
+        conn.client_encoding = "utf8"
+        assert sql.SQL(eur).as_bytes(conn) == eur.encode("utf8")
+
+        conn.client_encoding = "latin9"
+        assert sql.SQL(eur).as_bytes(conn) == eur.encode("latin9")
+
 
 class TestComposed:
     def test_class(self):
@@ -428,30 +474,38 @@ class TestComposed:
         with pytest.raises(StopIteration):
             next(it)
 
+    def test_as_string(self, conn):
+        obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+        assert obj.as_string(conn) == "foobar"
+
+    def test_as_bytes(self, conn):
+        obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
+        assert obj.as_bytes(conn) == b"foobar"
+
+        obj = sql.Composed([sql.SQL("foo"), sql.SQL(eur)])
+
+        conn.client_encoding = "utf8"
+        assert obj.as_bytes(conn) == ("foo" + eur).encode("utf8")
+
+        conn.client_encoding = "latin9"
+        assert obj.as_bytes(conn) == ("foo" + eur).encode("latin9")
+
 
 class TestPlaceholder:
     def test_class(self):
         assert issubclass(sql.Placeholder, sql.Composable)
 
-    def test_repr(self, conn):
-        ph = sql.Placeholder()
-        assert str(ph) == repr(ph) == "Placeholder()"
-        assert ph.as_string(conn) == "%s"
+    @pytest.mark.parametrize("format", Format)
+    def test_repr_format(self, conn, format):
+        ph = sql.Placeholder(format=format)
+        add = f"format={format.name}" if format != Format.AUTO else ""
+        assert str(ph) == repr(ph) == f"Placeholder({add})"
 
-    def test_repr_binary(self, conn):
-        ph = sql.Placeholder(format=Format.BINARY)
-        assert str(ph) == repr(ph) == "Placeholder(format=BINARY)"
-        assert ph.as_string(conn) == "%b"
-
-    def test_repr_name(self, conn):
-        ph = sql.Placeholder("foo")
-        assert str(ph) == repr(ph) == "Placeholder('foo')"
-        assert ph.as_string(conn) == "%(foo)s"
-
-    def test_repr_name_binary(self, conn):
-        ph = sql.Placeholder("foo", format=Format.BINARY)
-        assert str(ph) == repr(ph) == "Placeholder('foo', format=BINARY)"
-        assert ph.as_string(conn) == "%(foo)b"
+    @pytest.mark.parametrize("format", Format)
+    def test_repr_name_format(self, conn, format):
+        ph = sql.Placeholder("foo", format=format)
+        add = f", format={format.name}" if format != Format.AUTO else ""
+        assert str(ph) == repr(ph) == f"Placeholder('foo'{add})"
 
     def test_bad_name(self):
         with pytest.raises(ValueError):
@@ -465,6 +519,22 @@ class TestPlaceholder:
         assert sql.Placeholder("foo") != sql.Placeholder()
         assert sql.Placeholder("foo") != sql.Literal("foo")
 
+    @pytest.mark.parametrize("format", Format)
+    def test_as_string(self, conn, format):
+        ph = sql.Placeholder(format=format)
+        assert ph.as_string(conn) == f"%{format}"
+
+        ph = sql.Placeholder(name="foo", format=format)
+        assert ph.as_string(conn) == f"%(foo){format}"
+
+    @pytest.mark.parametrize("format", Format)
+    def test_as_bytes(self, conn, format):
+        ph = sql.Placeholder(format=format)
+        assert ph.as_bytes(conn) == f"%{format}".encode("ascii")
+
+        ph = sql.Placeholder(name="foo", format=format)
+        assert ph.as_bytes(conn) == f"%(foo){format}".encode("ascii")
+
 
 class TestValues:
     def test_null(self, conn):
@@ -478,4 +548,12 @@ class TestValues:
 
 def noe(s):
     """Drop an eventual E from E'' quotes"""
-    return re.sub(r"\bE'", "'", s)
+    if isinstance(s, memoryview):
+        s = bytes(s)
+
+    if isinstance(s, str):
+        return re.sub(r"\bE'", "'", s)
+    elif isinstance(s, bytes):
+        return re.sub(br"\bE'", "'", s)
+    else:
+        raise TypeError(f"not dealing with {type(s).__name__}: {s}")