From: Daniele Varrazzo Date: Thu, 27 Nov 2025 01:02:22 +0000 (+0100) Subject: fix(c): make sure to return bytes from `Literal.as_bytes()` X-Git-Tag: 3.3.0~8 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=cca68f06438399b7555baad4834d19a7f273cb3f;p=thirdparty%2Fpsycopg.git fix(c): make sure to return bytes from `Literal.as_bytes()` In some case it might have returned a bytearray or a memoryview, according to the dumper involved. --- diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index 6509a73d7..2b61c6443 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -238,6 +238,9 @@ cdef class Transformer: if type_ptr: rv = b"%s::%s" % (rv, type_ptr) + if not isinstance(rv, bytes): + rv = bytes(rv) + return rv def get_dumper(self, obj, format) -> "Dumper": diff --git a/tests/test_sql.py b/tests/test_sql.py index edb3fd7f9..b6c10bd84 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -294,7 +294,9 @@ class TestIdentifier: def test_as_bytes(self, conn, args, want, enc): want = want.encode(enc) conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}") - assert sql.Identifier(*args).as_bytes(conn) == want + got = sql.Identifier(*args).as_bytes(conn) + assert isinstance(got, bytes) + assert got == want @pytest.mark.parametrize("args, want, enc", _as_bytes_params) def test_as_bytes_no_conn(self, conn, args, want, enc): @@ -345,6 +347,7 @@ class TestLiteral: @pytest.mark.parametrize("obj, want", _params) def test_as_bytes(self, conn, obj, want): got = sql.Literal(obj).as_bytes(conn) + assert isinstance(got, bytes) if isinstance(obj, str): got = no_e(got) assert got == want.encode() @@ -509,7 +512,9 @@ class TestSQL: if encoding: conn.execute(f"set client_encoding to {encoding}") - assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding) + got = sql.SQL(eur).as_bytes(conn) + assert isinstance(got, bytes) + assert got == eur.encode(encoding) def test_no_conn(self): assert sql.SQL(eur).as_string() == eur @@ -629,6 +634,7 @@ class TestPlaceholder: @pytest.mark.parametrize("format", PyFormat) def test_as_bytes(self, conn, format): ph = sql.Placeholder(format=format) + assert isinstance(ph.as_bytes(), bytes) assert ph.as_bytes(conn) == f"%{format.value}".encode() assert ph.as_bytes() == f"%{format.value}".encode()