From: Daniele Varrazzo Date: Tue, 24 Aug 2021 14:22:16 +0000 (+0200) Subject: Add test for as_bytes method of sql objects X-Git-Tag: 3.0.beta1~53 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e6677c27fb1c541421f3e35bcf1f1b18e035cfd3;p=thirdparty%2Fpsycopg.git Add test for as_bytes method of sql objects --- diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index caf903025..fa2e3b1a2 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -64,7 +64,6 @@ class Composable(ABC): `!Composable` is passed instead of the query string. """ - # TODO: add tests for as_bytes raise NotImplementedError def as_string(self, context: Optional[AdaptContext]) -> str: diff --git a/tests/test_sql.py b/tests/test_sql.py index ee8411520..4382a0de8 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -10,6 +10,8 @@ import pytest from psycopg import pq, sql, ProgrammingError from psycopg.adapt import PyFormat as Format +eur = "\u20ac" + @pytest.mark.parametrize( "obj, quoted", @@ -265,12 +267,31 @@ class TestIdentifier: assert sql.Identifier("foo") != "foo" assert sql.Identifier("foo") != sql.SQL("foo") - def test_as_str(self, conn): - assert sql.Identifier("foo").as_string(conn) == '"foo"' - assert sql.Identifier("foo", "bar").as_string(conn) == '"foo"."bar"' - assert ( - sql.Identifier("fo'o", 'ba"r').as_string(conn) == '"fo\'o"."ba""r"' - ) + @pytest.mark.parametrize( + "args, want", + [ + (("foo",), '"foo"'), + (("foo", "bar"), '"foo"."bar"'), + (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), + ], + ) + def test_as_string(self, conn, args, want): + assert sql.Identifier(*args).as_string(conn) == want + + @pytest.mark.parametrize( + "args, want, enc", + [ + (("foo",), '"foo"', "ascii"), + (("foo", "bar"), '"foo"."bar"', "ascii"), + (("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), + (("foo", eur), f'"foo"."{eur}"', "utf8"), + (("foo", eur), f'"foo"."{eur}"', "latin9"), + ], + ) + def test_as_bytes(self, conn, args, want, enc): + want = want.encode(enc) + conn.client_encoding = enc + assert sql.Identifier(*args).as_bytes(conn) == want def test_join(self): assert not hasattr(sql.Identifier("foo"), "join") @@ -291,7 +312,7 @@ class TestLiteral: assert repr(sql.Literal("foo")) == "Literal('foo')" assert str(sql.Literal("foo")) == "Literal('foo')" - def test_as_str(self, conn): + def test_as_string(self, conn): assert sql.Literal(None).as_string(conn) == "NULL" assert noe(sql.Literal("foo").as_string(conn)) == "'foo'" assert sql.Literal(42).as_string(conn) == "42" @@ -299,6 +320,19 @@ class TestLiteral: sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'" ) + def test_as_bytes(self, conn): + assert sql.Literal(None).as_bytes(conn) == b"NULL" + assert noe(sql.Literal("foo").as_bytes(conn)) == b"'foo'" + assert sql.Literal(42).as_bytes(conn) == b"42" + assert ( + sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'" + ) + + conn.client_encoding = "utf8" + assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode("utf8") + conn.client_encoding = "latin9" + assert sql.Literal(eur).as_bytes(conn) == f"'{eur}'".encode("latin9") + def test_eq(self): assert sql.Literal("foo") == sql.Literal("foo") assert sql.Literal("foo") != sql.Literal("bar") @@ -370,6 +404,18 @@ class TestSQL: obj = sql.SQL(", ").join([]) assert obj == sql.Composed([]) + def test_as_string(self, conn): + assert sql.SQL("foo").as_string(conn) == "foo" + + def test_as_bytes(self, conn): + assert sql.SQL("foo").as_bytes(conn) == b"foo" + + conn.client_encoding = "utf8" + assert sql.SQL(eur).as_bytes(conn) == eur.encode("utf8") + + conn.client_encoding = "latin9" + assert sql.SQL(eur).as_bytes(conn) == eur.encode("latin9") + class TestComposed: def test_class(self): @@ -428,30 +474,38 @@ class TestComposed: with pytest.raises(StopIteration): next(it) + def test_as_string(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_string(conn) == "foobar" + + def test_as_bytes(self, conn): + obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) + assert obj.as_bytes(conn) == b"foobar" + + obj = sql.Composed([sql.SQL("foo"), sql.SQL(eur)]) + + conn.client_encoding = "utf8" + assert obj.as_bytes(conn) == ("foo" + eur).encode("utf8") + + conn.client_encoding = "latin9" + assert obj.as_bytes(conn) == ("foo" + eur).encode("latin9") + class TestPlaceholder: def test_class(self): assert issubclass(sql.Placeholder, sql.Composable) - def test_repr(self, conn): - ph = sql.Placeholder() - assert str(ph) == repr(ph) == "Placeholder()" - assert ph.as_string(conn) == "%s" + @pytest.mark.parametrize("format", Format) + def test_repr_format(self, conn, format): + ph = sql.Placeholder(format=format) + add = f"format={format.name}" if format != Format.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder({add})" - def test_repr_binary(self, conn): - ph = sql.Placeholder(format=Format.BINARY) - assert str(ph) == repr(ph) == "Placeholder(format=BINARY)" - assert ph.as_string(conn) == "%b" - - def test_repr_name(self, conn): - ph = sql.Placeholder("foo") - assert str(ph) == repr(ph) == "Placeholder('foo')" - assert ph.as_string(conn) == "%(foo)s" - - def test_repr_name_binary(self, conn): - ph = sql.Placeholder("foo", format=Format.BINARY) - assert str(ph) == repr(ph) == "Placeholder('foo', format=BINARY)" - assert ph.as_string(conn) == "%(foo)b" + @pytest.mark.parametrize("format", Format) + def test_repr_name_format(self, conn, format): + ph = sql.Placeholder("foo", format=format) + add = f", format={format.name}" if format != Format.AUTO else "" + assert str(ph) == repr(ph) == f"Placeholder('foo'{add})" def test_bad_name(self): with pytest.raises(ValueError): @@ -465,6 +519,22 @@ class TestPlaceholder: assert sql.Placeholder("foo") != sql.Placeholder() assert sql.Placeholder("foo") != sql.Literal("foo") + @pytest.mark.parametrize("format", Format) + def test_as_string(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_string(conn) == f"%{format}" + + ph = sql.Placeholder(name="foo", format=format) + assert ph.as_string(conn) == f"%(foo){format}" + + @pytest.mark.parametrize("format", Format) + def test_as_bytes(self, conn, format): + ph = sql.Placeholder(format=format) + assert ph.as_bytes(conn) == f"%{format}".encode("ascii") + + ph = sql.Placeholder(name="foo", format=format) + assert ph.as_bytes(conn) == f"%(foo){format}".encode("ascii") + class TestValues: def test_null(self, conn): @@ -478,4 +548,12 @@ class TestValues: def noe(s): """Drop an eventual E from E'' quotes""" - return re.sub(r"\bE'", "'", s) + if isinstance(s, memoryview): + s = bytes(s) + + if isinstance(s, str): + return re.sub(r"\bE'", "'", s) + elif isinstance(s, bytes): + return re.sub(br"\bE'", "'", s) + else: + raise TypeError(f"not dealing with {type(s).__name__}: {s}")