From: Daniele Varrazzo Date: Wed, 28 Oct 2020 15:42:25 +0000 (+0100) Subject: Added Dumper.quote() X-Git-Tag: 3.0.dev0~424^2~1 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=b8fad3b9d83109b2ef8d050d542a856fdc084e53;p=thirdparty%2Fpsycopg.git Added Dumper.quote() Return a representation of the object in a format useful for embedding in SQL. For most objects this is only the representation in quotes and escaped, for some it's a different value (numbers, bools, NULL). None can be quoted but not dumped, so it can be passed to an sql.Literal() without the need of a special case. --- diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index 1cb92e92a..f85c2b8c2 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -29,6 +29,16 @@ class Dumper: def dump(self, obj: Any) -> bytes: raise NotImplementedError() + def quote(self, obj: Any) -> bytes: + value = self.dump(obj) + + if self.connection: + esc = pq.Escaping(self.connection.pgconn) + return esc.escape_literal(value) + else: + esc = pq.Escaping() + return b"'%s'" % esc.escape_string(value) + @property def oid(self) -> int: return TEXT_OID diff --git a/psycopg3/psycopg3/sql.py b/psycopg3/psycopg3/sql.py index 1be83b4ee..16f16b87c 100644 --- a/psycopg3/psycopg3/sql.py +++ b/psycopg3/psycopg3/sql.py @@ -366,24 +366,13 @@ class Literal(Composable): """ def as_string(self, context: AdaptContext) -> str: - if self._obj is None: - return "NULL" - from .adapt import _connection_from_context, Transformer conn = _connection_from_context(context) - tx = Transformer(conn) + tx = context if isinstance(context, Transformer) else Transformer(conn) dumper = tx.get_dumper(self._obj, Format.TEXT) - value = dumper.dump(self._obj) - - if conn: - esc = Escaping(conn.pgconn) - quoted = esc.escape_literal(value) - return conn.codec.decode(quoted)[0] - else: - esc = Escaping() - quoted = b"'%s'" % esc.escape_string(value) - return quoted.decode("utf8") + quoted = dumper.quote(self._obj) + return conn.codec.decode(quoted)[0] if conn else quoted.decode("utf8") class Placeholder(Composable): diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index bbdd8a8b7..0ee0378c1 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -8,7 +8,7 @@ psycopg3 types package from .oids import builtins # Register default adapters -from . import array, composite, date, json, numeric, text # noqa +from . import array, composite, date, json, numeric, singletons, text # noqa # Register associations with array oids array.register_all_arrays() diff --git a/psycopg3/psycopg3/types/numeric.py b/psycopg3/psycopg3/types/numeric.py index cae84c72f..9d70dfe40 100644 --- a/psycopg3/psycopg3/types/numeric.py +++ b/psycopg3/psycopg3/types/numeric.py @@ -6,7 +6,7 @@ Adapers for numeric types. import codecs import struct -from typing import Callable, Dict, Tuple, cast +from typing import Any, Callable, Dict, Tuple, cast from decimal import Decimal from ..adapt import Dumper, Loader @@ -18,17 +18,28 @@ UnpackFloat = Callable[[bytes], Tuple[float]] FLOAT8_OID = builtins["float8"].oid NUMERIC_OID = builtins["numeric"].oid -BOOL_OID = builtins["bool"].oid _encode_ascii = codecs.lookup("ascii").encode _decode_ascii = codecs.lookup("ascii").decode -@Dumper.text(int) -class TextIntDumper(Dumper): - def dump(self, obj: int, __encode: EncodeFunc = _encode_ascii) -> bytes: +class NumberDumper(Dumper): + _special: Dict[bytes, bytes] = {} + + def dump(self, obj: Any, __encode: EncodeFunc = _encode_ascii) -> bytes: return __encode(str(obj))[0] + def quote(self, obj: Any) -> bytes: + value = self.dump(obj) + + if value in self._special: + return self._special[value] + + return b" " + value if value.startswith(b"-") else value + + +@Dumper.text(int) +class TextIntDumper(NumberDumper): @property def oid(self) -> int: # We don't know the size of it, so we have to return a type big enough @@ -36,9 +47,12 @@ class TextIntDumper(Dumper): @Dumper.text(float) -class TextFloatDumper(Dumper): - def dump(self, obj: float, __encode: EncodeFunc = _encode_ascii) -> bytes: - return __encode(str(obj))[0] +class TextFloatDumper(NumberDumper): + _special = { + b"inf": b"'Infinity'::float8", + b"-inf": b"'-Infinity'::float8", + b"nan": b"'NaN'::float8", + } @property def oid(self) -> int: @@ -47,37 +61,18 @@ class TextFloatDumper(Dumper): @Dumper.text(Decimal) -class TextDecimalDumper(Dumper): - def dump( - self, obj: Decimal, __encode: EncodeFunc = _encode_ascii - ) -> bytes: - return __encode(str(obj))[0] +class TextDecimalDumper(NumberDumper): + _special = { + b"Infinity": b"'Infinity'::numeric", + b"-Infinity": b"'-Infinity'::numeric", + b"NaN": b"'NaN'::numeric", + } @property def oid(self) -> int: return NUMERIC_OID -@Dumper.text(bool) -class TextBoolDumper(Dumper): - def dump(self, obj: bool) -> bytes: - return b"t" if obj else b"f" - - @property - def oid(self) -> int: - return BOOL_OID - - -@Dumper.binary(bool) -class BinaryBoolDumper(Dumper): - def dump(self, obj: bool) -> bytes: - return b"\x01" if obj else b"\x00" - - @property - def oid(self) -> int: - return BOOL_OID - - @Loader.text(builtins["int2"].oid) @Loader.text(builtins["int4"].oid) @Loader.text(builtins["int8"].oid) @@ -161,23 +156,3 @@ class TextNumericLoader(Loader): self, data: bytes, __decode: DecodeFunc = _decode_ascii ) -> Decimal: return Decimal(__decode(data)[0]) - - -@Loader.text(builtins["bool"].oid) -class TextBoolLoader(Loader): - def load( - self, - data: bytes, - __values: Dict[bytes, bool] = {b"t": True, b"f": False}, - ) -> bool: - return __values[data] - - -@Loader.binary(builtins["bool"].oid) -class BinaryBoolLoader(Loader): - def load( - self, - data: bytes, - __values: Dict[bytes, bool] = {b"\x01": True, b"\x00": False}, - ) -> bool: - return __values[data] diff --git a/psycopg3/psycopg3/types/singletons.py b/psycopg3/psycopg3/types/singletons.py new file mode 100644 index 000000000..9aaa57f53 --- /dev/null +++ b/psycopg3/psycopg3/types/singletons.py @@ -0,0 +1,66 @@ +""" +Adapters for None and boolean. +""" + +# Copyright (C) 2020 The Psycopg Team + +from typing import Dict + +from ..adapt import Dumper, Loader +from .oids import builtins + +BOOL_OID = builtins["bool"].oid + + +@Dumper.text(bool) +class TextBoolDumper(Dumper): + def dump(self, obj: bool) -> bytes: + return b"t" if obj else b"f" + + def quote(self, obj: bool) -> bytes: + return b"true" if obj else b"false" + + @property + def oid(self) -> int: + return BOOL_OID + + +@Dumper.binary(bool) +class BinaryBoolDumper(Dumper): + def dump(self, obj: bool) -> bytes: + return b"\x01" if obj else b"\x00" + + @property + def oid(self) -> int: + return BOOL_OID + + +@Dumper.text(type(None)) +class NoneDumper(Dumper): + """ + Not a complete dumper as it doesn't implement dump(), but it implements + quote(), so it can be used in sql composition. + """ + + def quote(self, obj: None) -> bytes: + return b"NULL" + + +@Loader.text(builtins["bool"].oid) +class TextBoolLoader(Loader): + def load( + self, + data: bytes, + __values: Dict[bytes, bool] = {b"t": True, b"f": False}, + ) -> bool: + return __values[data] + + +@Loader.binary(builtins["bool"].oid) +class BinaryBoolLoader(Loader): + def load( + self, + data: bytes, + __values: Dict[bytes, bool] = {b"\x01": True, b"\x00": False}, + ) -> bool: + return __values[data] diff --git a/tests/test_adapt.py b/tests/test_adapt.py index b566c1859..893b25642 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -20,6 +20,22 @@ def test_dump(data, format, result, type): assert dumper.oid == builtins[type].oid +@pytest.mark.parametrize( + "data, result", + [ + (1, b"1"), + ("hello", b"'hello'"), + ("he'llo", b"'he''llo'"), + (True, b"true"), + (None, b"NULL"), + ], +) +def test_quote(data, result): + t = Transformer() + dumper = t.get_dumper(data, Format.TEXT) + assert dumper.quote(data) == result + + def test_dump_connection_ctx(conn): Dumper.register(str, make_dumper("t"), conn) Dumper.register_binary(str, make_dumper("b"), conn) diff --git a/tests/test_sql.py b/tests/test_sql.py index 087e5d0f8..36e3ea7c1 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -13,7 +13,7 @@ from psycopg3.pq import Format @pytest.mark.parametrize( "obj, quoted", - [("hello", "'hello'"), (42, "'42'"), (True, "'t'"), (None, "NULL")], + [("hello", "'hello'"), (42, "42"), (True, "true"), (None, "NULL")], ) def test_quote(obj, quoted): assert sql.quote(obj) == quoted @@ -72,13 +72,13 @@ class TestSqlFormat: def test_percent_escape(self, conn): s = sql.SQL("42 % {0}").format(sql.Literal(7)) s1 = s.as_string(conn) - assert s1 == "42 % '7'" + assert s1 == "42 % 7" def test_braces_escape(self, conn): s = sql.SQL("{{{0}}}").format(sql.Literal(7)) - assert s.as_string(conn) == "{'7'}" + assert s.as_string(conn) == "{7}" s = sql.SQL("{{1,{0}}}").format(sql.Literal(7)) - assert s.as_string(conn) == "{1,'7'}" + assert s.as_string(conn) == "{1,7}" def test_compose_badnargs(self): with pytest.raises(IndexError): @@ -250,7 +250,7 @@ class TestLiteral: def test_as_str(self, conn): assert sql.Literal(None).as_string(conn) == "NULL" assert noe(sql.Literal("foo").as_string(conn)) == "'foo'" - assert sql.Literal(42).as_string(conn) == "'42'" + assert sql.Literal(42).as_string(conn) == "42" assert ( sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'" ) @@ -313,7 +313,7 @@ class TestSQL: [sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)] ) assert isinstance(obj, sql.Composed) - assert obj.as_string(conn) == "\"foo\", bar, '42'" + assert obj.as_string(conn) == '"foo", bar, 42' obj = sql.SQL(", ").join( sql.Composed( @@ -321,7 +321,7 @@ class TestSQL: ) ) assert isinstance(obj, sql.Composed) - assert obj.as_string(conn) == "\"foo\", bar, '42'" + assert obj.as_string(conn) == '"foo", bar, 42' obj = sql.SQL(", ").join([]) assert obj == sql.Composed([]) diff --git a/tests/types/test_date.py b/tests/types/test_date.py index 036e102b8..75f612c48 100644 --- a/tests/types/test_date.py +++ b/tests/types/test_date.py @@ -3,7 +3,7 @@ import datetime as dt import pytest -from psycopg3 import DataError +from psycopg3 import DataError, sql from psycopg3.adapt import Format @@ -24,8 +24,14 @@ from psycopg3.adapt import Format ], ) def test_dump_date(conn, val, expr): + val = as_date(val) cur = conn.cursor() - cur.execute(f"select '{expr}'::date = %s", (as_date(val),)) + cur.execute(f"select '{expr}'::date = %s", (val,)) + assert cur.fetchone()[0] is True + + cur.execute( + sql.SQL("select {val}::date = %s").format(val=sql.Literal(val)), (val,) + ) assert cur.fetchone()[0] is True diff --git a/tests/types/test_numeric.py b/tests/types/test_numeric.py index 02c36558d..f491e6aec 100644 --- a/tests/types/test_numeric.py +++ b/tests/types/test_numeric.py @@ -3,6 +3,7 @@ from math import isnan, isinf, exp import pytest +from psycopg3 import sql from psycopg3.adapt import Loader, Transformer, Format from psycopg3.types import builtins from psycopg3.types.numeric import TextFloatLoader @@ -28,8 +29,29 @@ from psycopg3.types.numeric import TextFloatLoader def test_dump_int(conn, val, expr): assert isinstance(val, int) cur = conn.cursor() - cur.execute("select %s = %%s" % expr, (val,)) - assert cur.fetchone()[0] + cur.execute(f"select {expr} = %s", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0, b"0"), + (1, b"1"), + (-1, b" -1"), + (42, b"42"), + (-42, b" -42"), + (int(2 ** 63 - 1), b"9223372036854775807"), + (int(-(2 ** 63)), b" -9223372036854775808"), + ], +) +def test_quote_int(conn, val, expr): + tx = Transformer() + assert tx.get_dumper(val, 0).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + assert cur.fetchone() == (val, -val) @pytest.mark.xfail @@ -91,16 +113,46 @@ def test_load_int(conn, val, pgtype, want, fmt_out): (0.0, "'0'"), (1.0, "'1'"), (-1.0, "'-1'"), - (float("nan"), "'nan'"), - (float("inf"), "'inf'"), - (float("-inf"), "'-inf'"), + (float("nan"), "'NaN'"), + (float("inf"), "'Infinity'"), + (float("-inf"), "'-Infinity'"), ], ) def test_dump_float(conn, val, expr): assert isinstance(val, float) cur = conn.cursor() - cur.execute("select %%s = %s::float8" % expr, (val,)) - assert cur.fetchone()[0] + cur.execute(f"select %s = {expr}::float8", (val,)) + assert cur.fetchone()[0] is True + + +@pytest.mark.parametrize( + "val, expr", + [ + (0.0, b"0.0"), + (1.0, b"1.0"), + (10000000000000000.0, b"1e+16"), + (1000000.1, b"1000000.1"), + (-100000.000001, b" -100000.000001"), + (-1.0, b" -1.0"), + (float("nan"), b"'NaN'::float8"), + (float("inf"), b"'Infinity'::float8"), + (float("-inf"), b"'-Infinity'::float8"), + ], +) +def test_quote_float(conn, val, expr): + tx = Transformer() + assert tx.get_dumper(val, 0).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + r = cur.fetchone() + if isnan(val): + assert isnan(r[0]) and isnan(r[1]) + else: + if isinstance(r[0], Decimal): + r = tuple(map(float, r)) + + assert r == (val, -val) @pytest.mark.parametrize( @@ -118,15 +170,14 @@ def test_dump_float_approx(conn, val, expr): assert isinstance(val, float) cur = conn.cursor() cur.execute( - "select abs((%s::float8 - %%s) / %s::float8) <= 1e-15" % (expr, expr), - (val,), + f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,) ) - assert cur.fetchone()[0] + assert cur.fetchone()[0] is True + cur.execute( - "select abs((%s::float4 - %%s) / %s::float4) <= 1e-6" % (expr, expr), - (val,), + f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,) ) - assert cur.fetchone()[0] + assert cur.fetchone()[0] is True @pytest.mark.xfail @@ -233,6 +284,31 @@ def test_roundtrip_numeric(conn, val): assert result == val +@pytest.mark.parametrize( + "val, expr", + [ + ("0", b"0"), + ("0.0", b"0.0"), + ("0.00000000000000001", b"1E-17"), + ("-0.00000000000000001", b" -1E-17"), + ("nan", b"'NaN'::numeric"), + ], +) +def test_quote_numeric(conn, val, expr): + val = Decimal(val) + tx = Transformer() + assert tx.get_dumper(val, 0).quote(val) == expr + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val))) + r = cur.fetchone() + + if val.is_nan(): + assert isnan(r[0]) and isnan(r[1]) + else: + assert r == (val, -val) + + @pytest.mark.xfail def test_dump_numeric_binary(): # TODO: numeric binary adaptation @@ -307,6 +383,17 @@ def test_roundtrip_bool(conn, b, fmt_in, fmt_out): assert result[0] is b +@pytest.mark.parametrize("val", [True, False]) +def test_quote_bool(conn, val): + + tx = Transformer() + assert tx.get_dumper(val, 0).quote(val) == str(val).lower().encode("ascii") + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val))) + assert cur.fetchone()[0] is val + + @pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"]) def test_minus_minus(conn, pgtype): cur = conn.cursor() diff --git a/tests/types/test_singletons.py b/tests/types/test_singletons.py new file mode 100644 index 000000000..14ff018d8 --- /dev/null +++ b/tests/types/test_singletons.py @@ -0,0 +1,45 @@ +import pytest + +from psycopg3 import sql +from psycopg3.adapt import Transformer, Format +from psycopg3.types import builtins + + +@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +@pytest.mark.parametrize("b", [True, False, None]) +def test_roundtrip_bool(conn, b, fmt_in, fmt_out): + cur = conn.cursor(format=fmt_out) + ph = "%s" if fmt_in == Format.TEXT else "%b" + result = cur.execute(f"select {ph}", (b,)).fetchone()[0] + assert cur.pgresult.fformat(0) == fmt_out + if b is not None: + assert cur.pgresult.ftype(0) == builtins["bool"].oid + assert result is b + + result = cur.execute(f"select {ph}", ([b],)).fetchone()[0] + assert cur.pgresult.fformat(0) == fmt_out + if b is not None: + assert cur.pgresult.ftype(0) == builtins["bool"].array_oid + assert result[0] is b + + +@pytest.mark.parametrize("val", [True, False]) +def test_quote_bool(conn, val): + + tx = Transformer() + assert tx.get_dumper(val, 0).quote(val) == str(val).lower().encode("ascii") + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val))) + assert cur.fetchone()[0] is val + + +def test_quote_none(conn): + + tx = Transformer() + assert tx.get_dumper(None, 0).quote(None) == b"NULL" + + cur = conn.cursor() + cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None))) + assert cur.fetchone()[0] is None diff --git a/tests/types/test_text.py b/tests/types/test_text.py index 90da91262..d15db13a9 100644 --- a/tests/types/test_text.py +++ b/tests/types/test_text.py @@ -1,6 +1,6 @@ import pytest -import psycopg3 +from psycopg3 import DatabaseError, sql from psycopg3.adapt import Format eur = "\u20ac" @@ -17,7 +17,33 @@ def test_dump_1char(conn, fmt_in): ph = "%s" if fmt_in == Format.TEXT else "%b" for i in range(1, 256): cur.execute(f"select {ph} = chr(%s::int)", (chr(i), i)) - assert cur.fetchone()[0], chr(i) + assert cur.fetchone()[0] is True, chr(i) + + +def test_quote_1char(conn): + cur = conn.cursor() + query = sql.SQL("select {ch} = chr(%s::int)") + for i in range(1, 256): + if chr(i) == "%": + continue + cur.execute(query.format(ch=sql.Literal(chr(i))), (i,)) + assert cur.fetchone()[0] is True, chr(i) + + +# the only way to make this pass is to reduce %% -> % every time +# not only when there are query arguments +# see https://github.com/psycopg/psycopg2/issues/825 +@pytest.mark.xfail +def test_quote_percent(conn): + cur = conn.cursor() + cur.execute(sql.SQL("select {ch}").format(ch=sql.Literal("%"))) + assert cur.fetchone()[0] == "%" + + cur.execute( + sql.SQL("select {ch} = chr(%s::int)").format(ch=sql.Literal("%")), + (ord("%"),), + ) + assert cur.fetchone()[0] is True @pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"]) @@ -82,7 +108,7 @@ def test_load_badenc(conn, typename, fmt_out): cur = conn.cursor(format=fmt_out) conn.client_encoding = "latin1" - with pytest.raises(psycopg3.DatabaseError): + with pytest.raises(DatabaseError): cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) @@ -143,7 +169,15 @@ def test_dump_1byte(conn, fmt_in): ph = "%s" if fmt_in == Format.TEXT else "%b" for i in range(0, 256): cur.execute(f"select {ph} = %s::bytea", (bytes([i]), fr"\x{i:02x}")) - assert cur.fetchone()[0], i + assert cur.fetchone()[0] is True, i + + +def test_quote_1byte(conn): + cur = conn.cursor() + query = sql.SQL("select {ch} = %s::bytea") + for i in range(0, 256): + cur.execute(query.format(ch=sql.Literal(bytes([i]))), (fr"\x{i:02x}",)) + assert cur.fetchone()[0] is True, i @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])