]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(c): make sure to return bytes from `Literal.as_bytes()`
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 27 Nov 2025 01:02:22 +0000 (02:02 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 27 Nov 2025 02:00:14 +0000 (03:00 +0100)
In some case it might have returned a bytearray or a memoryview,
according to the dumper involved.

psycopg_c/psycopg_c/_psycopg/transform.pyx
tests/test_sql.py

index 6509a73d7fc53aed97f423d72106e3cbcc48d767..2b61c6443d42e7a209f532959edcad214ca60e25 100644 (file)
@@ -238,6 +238,9 @@ cdef class Transformer:
             if <object>type_ptr:
                 rv = b"%s::%s" % (rv, <object>type_ptr)
 
+        if not isinstance(rv, bytes):
+            rv = bytes(rv)
+
         return rv
 
     def get_dumper(self, obj, format) -> "Dumper":
index edb3fd7f929e369671de610d33c0b0c13372ed9d..b6c10bd845d6707e7e66a695c10c01c86d21b863 100644 (file)
@@ -294,7 +294,9 @@ class TestIdentifier:
     def test_as_bytes(self, conn, args, want, enc):
         want = want.encode(enc)
         conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
-        assert sql.Identifier(*args).as_bytes(conn) == want
+        got = sql.Identifier(*args).as_bytes(conn)
+        assert isinstance(got, bytes)
+        assert got == want
 
     @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
     def test_as_bytes_no_conn(self, conn, args, want, enc):
@@ -345,6 +347,7 @@ class TestLiteral:
     @pytest.mark.parametrize("obj, want", _params)
     def test_as_bytes(self, conn, obj, want):
         got = sql.Literal(obj).as_bytes(conn)
+        assert isinstance(got, bytes)
         if isinstance(obj, str):
             got = no_e(got)
         assert got == want.encode()
@@ -509,7 +512,9 @@ class TestSQL:
         if encoding:
             conn.execute(f"set client_encoding to {encoding}")
 
-        assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
+        got = sql.SQL(eur).as_bytes(conn)
+        assert isinstance(got, bytes)
+        assert got == eur.encode(encoding)
 
     def test_no_conn(self):
         assert sql.SQL(eur).as_string() == eur
@@ -629,6 +634,7 @@ class TestPlaceholder:
     @pytest.mark.parametrize("format", PyFormat)
     def test_as_bytes(self, conn, format):
         ph = sql.Placeholder(format=format)
+        assert isinstance(ph.as_bytes(), bytes)
         assert ph.as_bytes(conn) == f"%{format.value}".encode()
         assert ph.as_bytes() == f"%{format.value}".encode()