From: Daniele Varrazzo Date: Fri, 19 Jan 2024 16:40:37 +0000 (+0100) Subject: feat: allow no connection parameter in sql.Composible.as_string()/as_bytes() X-Git-Tag: 3.2.0~91^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=2884e17ba9da296c04655f1618e56a8f7f576cdf;p=thirdparty%2Fpsycopg.git feat: allow no connection parameter in sql.Composible.as_string()/as_bytes() Close #716 --- diff --git a/docs/api/sql.rst b/docs/api/sql.rst index 6959fee4d..5e7000b26 100644 --- a/docs/api/sql.rst +++ b/docs/api/sql.rst @@ -108,9 +108,25 @@ The `!sql` objects are in the following inheritance hierarchy: .. autoclass:: Composable() - .. automethod:: as_bytes .. automethod:: as_string + .. versionchanged:: 3.2 + + The `!context` parameter is optional. + + .. warning:: + + If a context is not specified, the results are "generic" and not + tailored for a specific target connection. Details such as the + connection encoding and escaping style will not be taken into + account. + + .. automethod:: as_bytes + + .. versionchanged:: 3.2 + + The `!context` parameter is optional. See `as_string` for details. + .. autoclass:: SQL diff --git a/docs/news.rst b/docs/news.rst index f42bb29b7..839f1d8c7 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -21,6 +21,8 @@ Psycopg 3.2 (unreleased) transaction control methods available on the async connections. - Add support for libpq functions to close prepared statements and portals introduced in libpq v17 (:ticket:`#603`). +- The `!context` parameter of `sql` objects `~sql.Composable.as_string()` and + `~sql.Composable.as_bytes()` methods is not optional (:ticket:`#716`). - Disable receiving more than one result on the same cursor in pipeline mode, to iterate through `~Cursor.nextset()`. The behaviour was different than in non-pipeline mode and not totally reliable (:ticket:`#604`). diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index 6eaabee7c..a94f77f6e 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -55,7 +55,7 @@ class Composable(ABC): return f"{self.__class__.__name__}({self._obj!r})" @abstractmethod - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: """ Return the value of the object as bytes. @@ -69,7 +69,7 @@ class Composable(ABC): """ raise NotImplementedError - def as_string(self, context: Optional[AdaptContext]) -> str: + def as_string(self, context: Optional[AdaptContext] = None) -> str: """ Return the value of the object as string. @@ -130,7 +130,7 @@ class Composed(Composable): seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq] super().__init__(seq) - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: return b"".join(obj.as_bytes(context) for obj in self._obj) def __iter__(self) -> Iterator[Composable]: @@ -200,10 +200,10 @@ class SQL(Composable): if not isinstance(obj, str): raise TypeError(f"SQL values must be strings, got {obj!r} instead") - def as_string(self, context: Optional[AdaptContext]) -> str: + def as_string(self, context: Optional[AdaptContext] = None) -> str: return self._obj - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: conn = context.connection if context else None enc = conn_encoding(conn) return self._obj.encode(enc) @@ -362,7 +362,7 @@ class Identifier(Composable): def __repr__(self) -> str: return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})" - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: conn = context.connection if context else None if conn: esc = Escaping(conn.pgconn) @@ -400,7 +400,7 @@ class Literal(Composable): """ - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: tx = Transformer.from_context(context) return tx.as_literal(self._obj) @@ -459,11 +459,11 @@ class Placeholder(Composable): return f"{self.__class__.__name__}({', '.join(parts)})" - def as_string(self, context: Optional[AdaptContext]) -> str: + def as_string(self, context: Optional[AdaptContext] = None) -> str: code = self._format.value return f"%({self._obj}){code}" if self._obj else f"%{code}" - def as_bytes(self, context: Optional[AdaptContext]) -> bytes: + def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes: conn = context.connection if context else None enc = conn_encoding(conn) return self.as_string(context).encode(enc) diff --git a/tests/test_sql.py b/tests/test_sql.py index 4cd0b0c0e..b5f1b37ca 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -280,6 +280,7 @@ class TestIdentifier: @pytest.mark.parametrize("args, want", _as_string_params) def test_as_string_no_conn(self, args, want): assert sql.Identifier(*args).as_string(None) == want + assert sql.Identifier(*args).as_string() == want _as_bytes_params = [ crdb_encoding(("foo",), '"foo"', "ascii"), @@ -296,9 +297,10 @@ class TestIdentifier: assert sql.Identifier(*args).as_bytes(conn) == want @pytest.mark.parametrize("args, want, enc", _as_bytes_params) - def test_as_bytes_params(self, conn, args, want, enc): + def test_as_bytes_no_conn(self, conn, args, want, enc): want = want.encode() assert sql.Identifier(*args).as_bytes(None) == want + assert sql.Identifier(*args).as_bytes() == want def test_join(self): assert not hasattr(sql.Identifier("foo"), "join") @@ -326,17 +328,40 @@ class TestLiteral: assert repr(sql.Literal("foo")) == "Literal('foo')" assert str(sql.Literal("foo")) == "Literal('foo')" - def test_as_string(self, conn): - assert sql.Literal(None).as_string(conn) == "NULL" - assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'" - assert sql.Literal(42).as_string(conn) == "42" - assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date" + _params = [ + (None, "NULL"), + ("foo", "'foo'"), + (42, "42"), + (dt.date(2017, 1, 1), "'2017-01-01'::date"), + ] - def test_as_bytes(self, conn): - assert sql.Literal(None).as_bytes(conn) == b"NULL" - assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'" - assert sql.Literal(42).as_bytes(conn) == b"42" - assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date" + @pytest.mark.parametrize("obj, want", _params) + def test_as_string(self, conn, obj, want): + got = sql.Literal(obj).as_string(conn) + if isinstance(obj, str): + got = no_e(got) + assert got == want + + @pytest.mark.parametrize("obj, want", _params) + def test_as_bytes(self, conn, obj, want): + got = sql.Literal(obj).as_bytes(conn) + if isinstance(obj, str): + got = no_e(got) + assert got == want.encode() + + @pytest.mark.parametrize("obj, want", _params) + def test_as_string_no_conn(self, obj, want): + got = sql.Literal(obj).as_string() + if isinstance(obj, str): + got = no_e(got) + assert got == want + + @pytest.mark.parametrize("obj, want", _params) + def test_as_bytes_no_conn(self, obj, want): + got = sql.Literal(obj).as_bytes() + if isinstance(obj, str): + got = no_e(got) + assert got == want.encode() @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) def test_as_bytes_encoding(self, conn, encoding): @@ -473,6 +498,7 @@ class TestSQL: def test_as_string(self, conn): assert sql.SQL("foo").as_string(conn) == "foo" + assert sql.SQL("foo").as_string() == "foo" @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) def test_as_bytes(self, conn, encoding): @@ -481,6 +507,10 @@ class TestSQL: assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding) + def test_no_conn(self): + assert sql.SQL(eur).as_string() == eur + assert sql.SQL(eur).as_bytes() == eur.encode() + class TestComposed: def test_class(self): @@ -540,10 +570,12 @@ class TestComposed: def test_as_string(self, conn): obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) assert obj.as_string(conn) == "foobar" + assert obj.as_string() == "foobar" def test_as_bytes(self, conn): obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")]) assert obj.as_bytes(conn) == b"foobar" + assert obj.as_bytes() == b"foobar" @pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")]) def test_as_bytes_encoding(self, conn, encoding): @@ -584,17 +616,21 @@ class TestPlaceholder: def test_as_string(self, conn, format): ph = sql.Placeholder(format=format) assert ph.as_string(conn) == f"%{format.value}" + assert ph.as_string() == f"%{format.value}" ph = sql.Placeholder(name="foo", format=format) assert ph.as_string(conn) == f"%(foo){format.value}" + assert ph.as_string() == f"%(foo){format.value}" @pytest.mark.parametrize("format", PyFormat) def test_as_bytes(self, conn, format): ph = sql.Placeholder(format=format) - assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii") + assert ph.as_bytes(conn) == f"%{format.value}".encode() + assert ph.as_bytes() == f"%{format.value}".encode() ph = sql.Placeholder(name="foo", format=format) - assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii") + assert ph.as_bytes(conn) == f"%(foo){format.value}".encode() + assert ph.as_bytes() == f"%(foo){format.value}".encode() class TestValues: @@ -609,7 +645,7 @@ class TestValues: def no_e(s): """Drop an eventual E from E'' quotes""" - if isinstance(s, memoryview): + if isinstance(s, (memoryview, bytearray)): s = bytes(s) if isinstance(s, str):