]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Text dumping raises DataError if it contains a NUL byte
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Nov 2020 01:17:44 +0000 (02:17 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 4 Nov 2020 17:02:13 +0000 (18:02 +0100)
Note that psycopg2 was raising ValueError, but DataError is more
appropriate. If the param is binary, Postgres will raise
.CharacterNotInRepertoire, which is a DataError subclass.

psycopg3/psycopg3/types/text.py
tests/types/test_text.py

index 1c3073cf80577baf6adf805ca35521471b0295ec..8d82c963d69a1105a56d71eea4f586f5fa5afd98 100644 (file)
@@ -10,15 +10,14 @@ from typing import Optional, Union, TYPE_CHECKING
 from ..oids import builtins, INVALID_OID
 from ..adapt import Dumper, Loader
 from ..proto import AdaptContext, EncodeFunc, DecodeFunc
+from ..errors import DataError
 from ..pq import Escaping
 
 if TYPE_CHECKING:
     from ..pq.proto import Escaping as EscapingProto
 
 
-@Dumper.text(str)
-@Dumper.binary(str)
-class StringDumper(Dumper):
+class _StringDumper(Dumper):
     def __init__(self, src: type, context: AdaptContext):
         super().__init__(src, context)
 
@@ -31,10 +30,24 @@ class StringDumper(Dumper):
         else:
             self._encode = codecs.lookup("utf8").encode
 
+
+@Dumper.binary(str)
+class StringBinaryDumper(_StringDumper):
     def dump(self, obj: str) -> bytes:
         return self._encode(obj)[0]
 
 
+@Dumper.text(str)
+class StringDumper(_StringDumper):
+    def dump(self, obj: str) -> bytes:
+        if "\x00" in obj:
+            raise DataError(
+                "PostgreSQL text fields cannot contain NUL (0x00) bytes"
+            )
+        else:
+            return self._encode(obj)[0]
+
+
 @Loader.text(builtins["text"].oid)
 @Loader.binary(builtins["text"].oid)
 @Loader.text(builtins["varchar"].oid)
index d15db13a9bc1ca1de68bd875a6c9f69db504078b..3c5f47ac3843fcfbdcfa0d126fc3fa4d9fcf5d77 100644 (file)
@@ -1,6 +1,7 @@
 import pytest
 
-from psycopg3 import DatabaseError, sql
+import psycopg3
+from psycopg3 import sql
 from psycopg3.adapt import Format
 
 eur = "\u20ac"
@@ -30,6 +31,22 @@ def test_quote_1char(conn):
         assert cur.fetchone()[0] is True, chr(i)
 
 
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+def test_dump_zero(conn, fmt_in):
+    cur = conn.cursor()
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
+    s = "foo\x00bar"
+    with pytest.raises(psycopg3.DataError):
+        cur.execute(f"select {ph}", (s,))
+
+
+def test_quote_zero(conn):
+    cur = conn.cursor()
+    s = "foo\x00bar"
+    with pytest.raises(psycopg3.DataError):
+        cur.execute(sql.SQL("select {}").format(sql.Literal(s)))
+
+
 # the only way to make this pass is to reduce %% -> % every time
 # not only when there are query arguments
 # see https://github.com/psycopg/psycopg2/issues/825
@@ -108,7 +125,7 @@ def test_load_badenc(conn, typename, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
     conn.client_encoding = "latin1"
-    with pytest.raises(DatabaseError):
+    with pytest.raises(psycopg3.DatabaseError):
         cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))