]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added Dumper.quote()
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 15:42:25 +0000 (16:42 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 28 Oct 2020 15:49:38 +0000 (16:49 +0100)
Return a representation of the object in a format useful for embedding
in SQL. For most objects this is only the representation in quotes and
escaped, for some it's a different value (numbers, bools, NULL).

None can be quoted but not dumped, so it can be passed to an
sql.Literal() without the need of a special case.

psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/sql.py
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/numeric.py
psycopg3/psycopg3/types/singletons.py [new file with mode: 0644]
tests/test_adapt.py
tests/test_sql.py
tests/types/test_date.py
tests/types/test_numeric.py
tests/types/test_singletons.py [new file with mode: 0644]
tests/types/test_text.py

index 1cb92e92a6e596e7501e3b77cb16930e8eb82897..f85c2b8c226d7749e8d29f6f42522d7d1cf04e89 100644 (file)
@@ -29,6 +29,16 @@ class Dumper:
     def dump(self, obj: Any) -> bytes:
         raise NotImplementedError()
 
+    def quote(self, obj: Any) -> bytes:
+        value = self.dump(obj)
+
+        if self.connection:
+            esc = pq.Escaping(self.connection.pgconn)
+            return esc.escape_literal(value)
+        else:
+            esc = pq.Escaping()
+            return b"'%s'" % esc.escape_string(value)
+
     @property
     def oid(self) -> int:
         return TEXT_OID
index 1be83b4eef85a93fa96e7c6eface027e47aac530..16f16b87c523f779865c279ab2ee5371ed184de2 100644 (file)
@@ -366,24 +366,13 @@ class Literal(Composable):
     """
 
     def as_string(self, context: AdaptContext) -> str:
-        if self._obj is None:
-            return "NULL"
-
         from .adapt import _connection_from_context, Transformer
 
         conn = _connection_from_context(context)
-        tx = Transformer(conn)
+        tx = context if isinstance(context, Transformer) else Transformer(conn)
         dumper = tx.get_dumper(self._obj, Format.TEXT)
-        value = dumper.dump(self._obj)
-
-        if conn:
-            esc = Escaping(conn.pgconn)
-            quoted = esc.escape_literal(value)
-            return conn.codec.decode(quoted)[0]
-        else:
-            esc = Escaping()
-            quoted = b"'%s'" % esc.escape_string(value)
-            return quoted.decode("utf8")
+        quoted = dumper.quote(self._obj)
+        return conn.codec.decode(quoted)[0] if conn else quoted.decode("utf8")
 
 
 class Placeholder(Composable):
index bbdd8a8b7ed204440cca6803f4c43ecc121dfcdc..0ee0378c17ee998565a5730d79d430a7d697a57c 100644 (file)
@@ -8,7 +8,7 @@ psycopg3 types package
 from .oids import builtins
 
 # Register default adapters
-from . import array, composite, date, json, numeric, text  # noqa
+from . import array, composite, date, json, numeric, singletons, text  # noqa
 
 # Register associations with array oids
 array.register_all_arrays()
index cae84c72f07c1931a64c306f724b6854d7b90867..9d70dfe4093f358f51a78bd8cd6e58f0b04985de 100644 (file)
@@ -6,7 +6,7 @@ Adapers for numeric types.
 
 import codecs
 import struct
-from typing import Callable, Dict, Tuple, cast
+from typing import Any, Callable, Dict, Tuple, cast
 from decimal import Decimal
 
 from ..adapt import Dumper, Loader
@@ -18,17 +18,28 @@ UnpackFloat = Callable[[bytes], Tuple[float]]
 
 FLOAT8_OID = builtins["float8"].oid
 NUMERIC_OID = builtins["numeric"].oid
-BOOL_OID = builtins["bool"].oid
 
 _encode_ascii = codecs.lookup("ascii").encode
 _decode_ascii = codecs.lookup("ascii").decode
 
 
-@Dumper.text(int)
-class TextIntDumper(Dumper):
-    def dump(self, obj: int, __encode: EncodeFunc = _encode_ascii) -> bytes:
+class NumberDumper(Dumper):
+    _special: Dict[bytes, bytes] = {}
+
+    def dump(self, obj: Any, __encode: EncodeFunc = _encode_ascii) -> bytes:
         return __encode(str(obj))[0]
 
+    def quote(self, obj: Any) -> bytes:
+        value = self.dump(obj)
+
+        if value in self._special:
+            return self._special[value]
+
+        return b" " + value if value.startswith(b"-") else value
+
+
+@Dumper.text(int)
+class TextIntDumper(NumberDumper):
     @property
     def oid(self) -> int:
         # We don't know the size of it, so we have to return a type big enough
@@ -36,9 +47,12 @@ class TextIntDumper(Dumper):
 
 
 @Dumper.text(float)
-class TextFloatDumper(Dumper):
-    def dump(self, obj: float, __encode: EncodeFunc = _encode_ascii) -> bytes:
-        return __encode(str(obj))[0]
+class TextFloatDumper(NumberDumper):
+    _special = {
+        b"inf": b"'Infinity'::float8",
+        b"-inf": b"'-Infinity'::float8",
+        b"nan": b"'NaN'::float8",
+    }
 
     @property
     def oid(self) -> int:
@@ -47,37 +61,18 @@ class TextFloatDumper(Dumper):
 
 
 @Dumper.text(Decimal)
-class TextDecimalDumper(Dumper):
-    def dump(
-        self, obj: Decimal, __encode: EncodeFunc = _encode_ascii
-    ) -> bytes:
-        return __encode(str(obj))[0]
+class TextDecimalDumper(NumberDumper):
+    _special = {
+        b"Infinity": b"'Infinity'::numeric",
+        b"-Infinity": b"'-Infinity'::numeric",
+        b"NaN": b"'NaN'::numeric",
+    }
 
     @property
     def oid(self) -> int:
         return NUMERIC_OID
 
 
-@Dumper.text(bool)
-class TextBoolDumper(Dumper):
-    def dump(self, obj: bool) -> bytes:
-        return b"t" if obj else b"f"
-
-    @property
-    def oid(self) -> int:
-        return BOOL_OID
-
-
-@Dumper.binary(bool)
-class BinaryBoolDumper(Dumper):
-    def dump(self, obj: bool) -> bytes:
-        return b"\x01" if obj else b"\x00"
-
-    @property
-    def oid(self) -> int:
-        return BOOL_OID
-
-
 @Loader.text(builtins["int2"].oid)
 @Loader.text(builtins["int4"].oid)
 @Loader.text(builtins["int8"].oid)
@@ -161,23 +156,3 @@ class TextNumericLoader(Loader):
         self, data: bytes, __decode: DecodeFunc = _decode_ascii
     ) -> Decimal:
         return Decimal(__decode(data)[0])
-
-
-@Loader.text(builtins["bool"].oid)
-class TextBoolLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __values: Dict[bytes, bool] = {b"t": True, b"f": False},
-    ) -> bool:
-        return __values[data]
-
-
-@Loader.binary(builtins["bool"].oid)
-class BinaryBoolLoader(Loader):
-    def load(
-        self,
-        data: bytes,
-        __values: Dict[bytes, bool] = {b"\x01": True, b"\x00": False},
-    ) -> bool:
-        return __values[data]
diff --git a/psycopg3/psycopg3/types/singletons.py b/psycopg3/psycopg3/types/singletons.py
new file mode 100644 (file)
index 0000000..9aaa57f
--- /dev/null
@@ -0,0 +1,66 @@
+"""
+Adapters for None and boolean.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Dict
+
+from ..adapt import Dumper, Loader
+from .oids import builtins
+
+BOOL_OID = builtins["bool"].oid
+
+
+@Dumper.text(bool)
+class TextBoolDumper(Dumper):
+    def dump(self, obj: bool) -> bytes:
+        return b"t" if obj else b"f"
+
+    def quote(self, obj: bool) -> bytes:
+        return b"true" if obj else b"false"
+
+    @property
+    def oid(self) -> int:
+        return BOOL_OID
+
+
+@Dumper.binary(bool)
+class BinaryBoolDumper(Dumper):
+    def dump(self, obj: bool) -> bytes:
+        return b"\x01" if obj else b"\x00"
+
+    @property
+    def oid(self) -> int:
+        return BOOL_OID
+
+
+@Dumper.text(type(None))
+class NoneDumper(Dumper):
+    """
+    Not a complete dumper as it doesn't implement dump(), but it implements
+    quote(), so it can be used in sql composition.
+    """
+
+    def quote(self, obj: None) -> bytes:
+        return b"NULL"
+
+
+@Loader.text(builtins["bool"].oid)
+class TextBoolLoader(Loader):
+    def load(
+        self,
+        data: bytes,
+        __values: Dict[bytes, bool] = {b"t": True, b"f": False},
+    ) -> bool:
+        return __values[data]
+
+
+@Loader.binary(builtins["bool"].oid)
+class BinaryBoolLoader(Loader):
+    def load(
+        self,
+        data: bytes,
+        __values: Dict[bytes, bool] = {b"\x01": True, b"\x00": False},
+    ) -> bool:
+        return __values[data]
index b566c1859990b582d7e3f40a09868864a006de1f..893b256424e50b2ff3c5b39b60921c15447ab9fb 100644 (file)
@@ -20,6 +20,22 @@ def test_dump(data, format, result, type):
     assert dumper.oid == builtins[type].oid
 
 
+@pytest.mark.parametrize(
+    "data, result",
+    [
+        (1, b"1"),
+        ("hello", b"'hello'"),
+        ("he'llo", b"'he''llo'"),
+        (True, b"true"),
+        (None, b"NULL"),
+    ],
+)
+def test_quote(data, result):
+    t = Transformer()
+    dumper = t.get_dumper(data, Format.TEXT)
+    assert dumper.quote(data) == result
+
+
 def test_dump_connection_ctx(conn):
     Dumper.register(str, make_dumper("t"), conn)
     Dumper.register_binary(str, make_dumper("b"), conn)
index 087e5d0f83dece644126716eeb82656ca943b4cd..36e3ea7c1bd86c086dbba7ce9baa7e9c918577d3 100755 (executable)
@@ -13,7 +13,7 @@ from psycopg3.pq import Format
 
 @pytest.mark.parametrize(
     "obj, quoted",
-    [("hello", "'hello'"), (42, "'42'"), (True, "'t'"), (None, "NULL")],
+    [("hello", "'hello'"), (42, "42"), (True, "true"), (None, "NULL")],
 )
 def test_quote(obj, quoted):
     assert sql.quote(obj) == quoted
@@ -72,13 +72,13 @@ class TestSqlFormat:
     def test_percent_escape(self, conn):
         s = sql.SQL("42 % {0}").format(sql.Literal(7))
         s1 = s.as_string(conn)
-        assert s1 == "42 % '7'"
+        assert s1 == "42 % 7"
 
     def test_braces_escape(self, conn):
         s = sql.SQL("{{{0}}}").format(sql.Literal(7))
-        assert s.as_string(conn) == "{'7'}"
+        assert s.as_string(conn) == "{7}"
         s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
-        assert s.as_string(conn) == "{1,'7'}"
+        assert s.as_string(conn) == "{1,7}"
 
     def test_compose_badnargs(self):
         with pytest.raises(IndexError):
@@ -250,7 +250,7 @@ class TestLiteral:
     def test_as_str(self, conn):
         assert sql.Literal(None).as_string(conn) == "NULL"
         assert noe(sql.Literal("foo").as_string(conn)) == "'foo'"
-        assert sql.Literal(42).as_string(conn) == "'42'"
+        assert sql.Literal(42).as_string(conn) == "42"
         assert (
             sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
         )
@@ -313,7 +313,7 @@ class TestSQL:
             [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
         )
         assert isinstance(obj, sql.Composed)
-        assert obj.as_string(conn) == "\"foo\", bar, '42'"
+        assert obj.as_string(conn) == '"foo", bar, 42'
 
         obj = sql.SQL(", ").join(
             sql.Composed(
@@ -321,7 +321,7 @@ class TestSQL:
             )
         )
         assert isinstance(obj, sql.Composed)
-        assert obj.as_string(conn) == "\"foo\", bar, '42'"
+        assert obj.as_string(conn) == '"foo", bar, 42'
 
         obj = sql.SQL(", ").join([])
         assert obj == sql.Composed([])
index 036e102b8b451f9408acd8d2e06fe634c3da43b4..75f612c48dd99c62df6f5ed5ae13454863f1351b 100644 (file)
@@ -3,7 +3,7 @@ import datetime as dt
 
 import pytest
 
-from psycopg3 import DataError
+from psycopg3 import DataError, sql
 from psycopg3.adapt import Format
 
 
@@ -24,8 +24,14 @@ from psycopg3.adapt import Format
     ],
 )
 def test_dump_date(conn, val, expr):
+    val = as_date(val)
     cur = conn.cursor()
-    cur.execute(f"select '{expr}'::date = %s", (as_date(val),))
+    cur.execute(f"select '{expr}'::date = %s", (val,))
+    assert cur.fetchone()[0] is True
+
+    cur.execute(
+        sql.SQL("select {val}::date = %s").format(val=sql.Literal(val)), (val,)
+    )
     assert cur.fetchone()[0] is True
 
 
index 02c36558d1e1d8e348e341c28aac720af1176396..f491e6aec021af25d995bcd07c76af79232f6ef5 100644 (file)
@@ -3,6 +3,7 @@ from math import isnan, isinf, exp
 
 import pytest
 
+from psycopg3 import sql
 from psycopg3.adapt import Loader, Transformer, Format
 from psycopg3.types import builtins
 from psycopg3.types.numeric import TextFloatLoader
@@ -28,8 +29,29 @@ from psycopg3.types.numeric import TextFloatLoader
 def test_dump_int(conn, val, expr):
     assert isinstance(val, int)
     cur = conn.cursor()
-    cur.execute("select %s = %%s" % expr, (val,))
-    assert cur.fetchone()[0]
+    cur.execute(f"select {expr} = %s", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (0, b"0"),
+        (1, b"1"),
+        (-1, b" -1"),
+        (42, b"42"),
+        (-42, b" -42"),
+        (int(2 ** 63 - 1), b"9223372036854775807"),
+        (int(-(2 ** 63)), b" -9223372036854775808"),
+    ],
+)
+def test_quote_int(conn, val, expr):
+    tx = Transformer()
+    assert tx.get_dumper(val, 0).quote(val) == expr
+
+    cur = conn.cursor()
+    cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+    assert cur.fetchone() == (val, -val)
 
 
 @pytest.mark.xfail
@@ -91,16 +113,46 @@ def test_load_int(conn, val, pgtype, want, fmt_out):
         (0.0, "'0'"),
         (1.0, "'1'"),
         (-1.0, "'-1'"),
-        (float("nan"), "'nan'"),
-        (float("inf"), "'inf'"),
-        (float("-inf"), "'-inf'"),
+        (float("nan"), "'NaN'"),
+        (float("inf"), "'Infinity'"),
+        (float("-inf"), "'-Infinity'"),
     ],
 )
 def test_dump_float(conn, val, expr):
     assert isinstance(val, float)
     cur = conn.cursor()
-    cur.execute("select %%s = %s::float8" % expr, (val,))
-    assert cur.fetchone()[0]
+    cur.execute(f"select %s = {expr}::float8", (val,))
+    assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        (0.0, b"0.0"),
+        (1.0, b"1.0"),
+        (10000000000000000.0, b"1e+16"),
+        (1000000.1, b"1000000.1"),
+        (-100000.000001, b" -100000.000001"),
+        (-1.0, b" -1.0"),
+        (float("nan"), b"'NaN'::float8"),
+        (float("inf"), b"'Infinity'::float8"),
+        (float("-inf"), b"'-Infinity'::float8"),
+    ],
+)
+def test_quote_float(conn, val, expr):
+    tx = Transformer()
+    assert tx.get_dumper(val, 0).quote(val) == expr
+
+    cur = conn.cursor()
+    cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+    r = cur.fetchone()
+    if isnan(val):
+        assert isnan(r[0]) and isnan(r[1])
+    else:
+        if isinstance(r[0], Decimal):
+            r = tuple(map(float, r))
+
+        assert r == (val, -val)
 
 
 @pytest.mark.parametrize(
@@ -118,15 +170,14 @@ def test_dump_float_approx(conn, val, expr):
     assert isinstance(val, float)
     cur = conn.cursor()
     cur.execute(
-        "select abs((%s::float8 - %%s) / %s::float8) <= 1e-15" % (expr, expr),
-        (val,),
+        f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,)
     )
-    assert cur.fetchone()[0]
+    assert cur.fetchone()[0] is True
+
     cur.execute(
-        "select abs((%s::float4 - %%s) / %s::float4) <= 1e-6" % (expr, expr),
-        (val,),
+        f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,)
     )
-    assert cur.fetchone()[0]
+    assert cur.fetchone()[0] is True
 
 
 @pytest.mark.xfail
@@ -233,6 +284,31 @@ def test_roundtrip_numeric(conn, val):
         assert result == val
 
 
+@pytest.mark.parametrize(
+    "val, expr",
+    [
+        ("0", b"0"),
+        ("0.0", b"0.0"),
+        ("0.00000000000000001", b"1E-17"),
+        ("-0.00000000000000001", b" -1E-17"),
+        ("nan", b"'NaN'::numeric"),
+    ],
+)
+def test_quote_numeric(conn, val, expr):
+    val = Decimal(val)
+    tx = Transformer()
+    assert tx.get_dumper(val, 0).quote(val) == expr
+
+    cur = conn.cursor()
+    cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+    r = cur.fetchone()
+
+    if val.is_nan():
+        assert isnan(r[0]) and isnan(r[1])
+    else:
+        assert r == (val, -val)
+
+
 @pytest.mark.xfail
 def test_dump_numeric_binary():
     # TODO: numeric binary adaptation
@@ -307,6 +383,17 @@ def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
     assert result[0] is b
 
 
+@pytest.mark.parametrize("val", [True, False])
+def test_quote_bool(conn, val):
+
+    tx = Transformer()
+    assert tx.get_dumper(val, 0).quote(val) == str(val).lower().encode("ascii")
+
+    cur = conn.cursor()
+    cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
+    assert cur.fetchone()[0] is val
+
+
 @pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"])
 def test_minus_minus(conn, pgtype):
     cur = conn.cursor()
diff --git a/tests/types/test_singletons.py b/tests/types/test_singletons.py
new file mode 100644 (file)
index 0000000..14ff018
--- /dev/null
@@ -0,0 +1,45 @@
+import pytest
+
+from psycopg3 import sql
+from psycopg3.adapt import Transformer, Format
+from psycopg3.types import builtins
+
+
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("b", [True, False, None])
+def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
+    cur = conn.cursor(format=fmt_out)
+    ph = "%s" if fmt_in == Format.TEXT else "%b"
+    result = cur.execute(f"select {ph}", (b,)).fetchone()[0]
+    assert cur.pgresult.fformat(0) == fmt_out
+    if b is not None:
+        assert cur.pgresult.ftype(0) == builtins["bool"].oid
+    assert result is b
+
+    result = cur.execute(f"select {ph}", ([b],)).fetchone()[0]
+    assert cur.pgresult.fformat(0) == fmt_out
+    if b is not None:
+        assert cur.pgresult.ftype(0) == builtins["bool"].array_oid
+    assert result[0] is b
+
+
+@pytest.mark.parametrize("val", [True, False])
+def test_quote_bool(conn, val):
+
+    tx = Transformer()
+    assert tx.get_dumper(val, 0).quote(val) == str(val).lower().encode("ascii")
+
+    cur = conn.cursor()
+    cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
+    assert cur.fetchone()[0] is val
+
+
+def test_quote_none(conn):
+
+    tx = Transformer()
+    assert tx.get_dumper(None, 0).quote(None) == b"NULL"
+
+    cur = conn.cursor()
+    cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None)))
+    assert cur.fetchone()[0] is None
index 90da912625d291b130ad18d6331e431831dddaee..d15db13a9bc1ca1de68bd875a6c9f69db504078b 100644 (file)
@@ -1,6 +1,6 @@
 import pytest
 
-import psycopg3
+from psycopg3 import DatabaseError, sql
 from psycopg3.adapt import Format
 
 eur = "\u20ac"
@@ -17,7 +17,33 @@ def test_dump_1char(conn, fmt_in):
     ph = "%s" if fmt_in == Format.TEXT else "%b"
     for i in range(1, 256):
         cur.execute(f"select {ph} = chr(%s::int)", (chr(i), i))
-        assert cur.fetchone()[0], chr(i)
+        assert cur.fetchone()[0] is True, chr(i)
+
+
+def test_quote_1char(conn):
+    cur = conn.cursor()
+    query = sql.SQL("select {ch} = chr(%s::int)")
+    for i in range(1, 256):
+        if chr(i) == "%":
+            continue
+        cur.execute(query.format(ch=sql.Literal(chr(i))), (i,))
+        assert cur.fetchone()[0] is True, chr(i)
+
+
+# the only way to make this pass is to reduce %% -> % every time
+# not only when there are query arguments
+# see https://github.com/psycopg/psycopg2/issues/825
+@pytest.mark.xfail
+def test_quote_percent(conn):
+    cur = conn.cursor()
+    cur.execute(sql.SQL("select {ch}").format(ch=sql.Literal("%")))
+    assert cur.fetchone()[0] == "%"
+
+    cur.execute(
+        sql.SQL("select {ch} = chr(%s::int)").format(ch=sql.Literal("%")),
+        (ord("%"),),
+    )
+    assert cur.fetchone()[0] is True
 
 
 @pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
@@ -82,7 +108,7 @@ def test_load_badenc(conn, typename, fmt_out):
     cur = conn.cursor(format=fmt_out)
 
     conn.client_encoding = "latin1"
-    with pytest.raises(psycopg3.DatabaseError):
+    with pytest.raises(DatabaseError):
         cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
 
 
@@ -143,7 +169,15 @@ def test_dump_1byte(conn, fmt_in):
     ph = "%s" if fmt_in == Format.TEXT else "%b"
     for i in range(0, 256):
         cur.execute(f"select {ph} = %s::bytea", (bytes([i]), fr"\x{i:02x}"))
-        assert cur.fetchone()[0], i
+        assert cur.fetchone()[0] is True, i
+
+
+def test_quote_1byte(conn):
+    cur = conn.cursor()
+    query = sql.SQL("select {ch} = %s::bytea")
+    for i in range(0, 256):
+        cur.execute(query.format(ch=sql.Literal(bytes([i]))), (fr"\x{i:02x}",))
+        assert cur.fetchone()[0] is True, i
 
 
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])