def dump(self, obj: Any) -> bytes:
raise NotImplementedError()
+ def quote(self, obj: Any) -> bytes:
+ value = self.dump(obj)
+
+ if self.connection:
+ esc = pq.Escaping(self.connection.pgconn)
+ return esc.escape_literal(value)
+ else:
+ esc = pq.Escaping()
+ return b"'%s'" % esc.escape_string(value)
+
@property
def oid(self) -> int:
return TEXT_OID
"""
def as_string(self, context: AdaptContext) -> str:
- if self._obj is None:
- return "NULL"
-
from .adapt import _connection_from_context, Transformer
conn = _connection_from_context(context)
- tx = Transformer(conn)
+ tx = context if isinstance(context, Transformer) else Transformer(conn)
dumper = tx.get_dumper(self._obj, Format.TEXT)
- value = dumper.dump(self._obj)
-
- if conn:
- esc = Escaping(conn.pgconn)
- quoted = esc.escape_literal(value)
- return conn.codec.decode(quoted)[0]
- else:
- esc = Escaping()
- quoted = b"'%s'" % esc.escape_string(value)
- return quoted.decode("utf8")
+ quoted = dumper.quote(self._obj)
+ return conn.codec.decode(quoted)[0] if conn else quoted.decode("utf8")
class Placeholder(Composable):
from .oids import builtins
# Register default adapters
-from . import array, composite, date, json, numeric, text # noqa
+from . import array, composite, date, json, numeric, singletons, text # noqa
# Register associations with array oids
array.register_all_arrays()
import codecs
import struct
-from typing import Callable, Dict, Tuple, cast
+from typing import Any, Callable, Dict, Tuple, cast
from decimal import Decimal
from ..adapt import Dumper, Loader
FLOAT8_OID = builtins["float8"].oid
NUMERIC_OID = builtins["numeric"].oid
-BOOL_OID = builtins["bool"].oid
_encode_ascii = codecs.lookup("ascii").encode
_decode_ascii = codecs.lookup("ascii").decode
-@Dumper.text(int)
-class TextIntDumper(Dumper):
- def dump(self, obj: int, __encode: EncodeFunc = _encode_ascii) -> bytes:
+class NumberDumper(Dumper):
+ _special: Dict[bytes, bytes] = {}
+
+ def dump(self, obj: Any, __encode: EncodeFunc = _encode_ascii) -> bytes:
return __encode(str(obj))[0]
+ def quote(self, obj: Any) -> bytes:
+ value = self.dump(obj)
+
+ if value in self._special:
+ return self._special[value]
+
+ return b" " + value if value.startswith(b"-") else value
+
+
+@Dumper.text(int)
+class TextIntDumper(NumberDumper):
@property
def oid(self) -> int:
# We don't know the size of it, so we have to return a type big enough
@Dumper.text(float)
-class TextFloatDumper(Dumper):
- def dump(self, obj: float, __encode: EncodeFunc = _encode_ascii) -> bytes:
- return __encode(str(obj))[0]
+class TextFloatDumper(NumberDumper):
+ _special = {
+ b"inf": b"'Infinity'::float8",
+ b"-inf": b"'-Infinity'::float8",
+ b"nan": b"'NaN'::float8",
+ }
@property
def oid(self) -> int:
@Dumper.text(Decimal)
-class TextDecimalDumper(Dumper):
- def dump(
- self, obj: Decimal, __encode: EncodeFunc = _encode_ascii
- ) -> bytes:
- return __encode(str(obj))[0]
+class TextDecimalDumper(NumberDumper):
+ _special = {
+ b"Infinity": b"'Infinity'::numeric",
+ b"-Infinity": b"'-Infinity'::numeric",
+ b"NaN": b"'NaN'::numeric",
+ }
@property
def oid(self) -> int:
return NUMERIC_OID
-@Dumper.text(bool)
-class TextBoolDumper(Dumper):
- def dump(self, obj: bool) -> bytes:
- return b"t" if obj else b"f"
-
- @property
- def oid(self) -> int:
- return BOOL_OID
-
-
-@Dumper.binary(bool)
-class BinaryBoolDumper(Dumper):
- def dump(self, obj: bool) -> bytes:
- return b"\x01" if obj else b"\x00"
-
- @property
- def oid(self) -> int:
- return BOOL_OID
-
-
@Loader.text(builtins["int2"].oid)
@Loader.text(builtins["int4"].oid)
@Loader.text(builtins["int8"].oid)
self, data: bytes, __decode: DecodeFunc = _decode_ascii
) -> Decimal:
return Decimal(__decode(data)[0])
-
-
-@Loader.text(builtins["bool"].oid)
-class TextBoolLoader(Loader):
- def load(
- self,
- data: bytes,
- __values: Dict[bytes, bool] = {b"t": True, b"f": False},
- ) -> bool:
- return __values[data]
-
-
-@Loader.binary(builtins["bool"].oid)
-class BinaryBoolLoader(Loader):
- def load(
- self,
- data: bytes,
- __values: Dict[bytes, bool] = {b"\x01": True, b"\x00": False},
- ) -> bool:
- return __values[data]
--- /dev/null
+"""
+Adapters for None and boolean.
+"""
+
+# Copyright (C) 2020 The Psycopg Team
+
+from typing import Dict
+
+from ..adapt import Dumper, Loader
+from .oids import builtins
+
+BOOL_OID = builtins["bool"].oid
+
+
+@Dumper.text(bool)
+class TextBoolDumper(Dumper):
+ def dump(self, obj: bool) -> bytes:
+ return b"t" if obj else b"f"
+
+ def quote(self, obj: bool) -> bytes:
+ return b"true" if obj else b"false"
+
+ @property
+ def oid(self) -> int:
+ return BOOL_OID
+
+
+@Dumper.binary(bool)
+class BinaryBoolDumper(Dumper):
+ def dump(self, obj: bool) -> bytes:
+ return b"\x01" if obj else b"\x00"
+
+ @property
+ def oid(self) -> int:
+ return BOOL_OID
+
+
+@Dumper.text(type(None))
+class NoneDumper(Dumper):
+ """
+ Not a complete dumper as it doesn't implement dump(), but it implements
+ quote(), so it can be used in sql composition.
+ """
+
+ def quote(self, obj: None) -> bytes:
+ return b"NULL"
+
+
+@Loader.text(builtins["bool"].oid)
+class TextBoolLoader(Loader):
+ def load(
+ self,
+ data: bytes,
+ __values: Dict[bytes, bool] = {b"t": True, b"f": False},
+ ) -> bool:
+ return __values[data]
+
+
+@Loader.binary(builtins["bool"].oid)
+class BinaryBoolLoader(Loader):
+ def load(
+ self,
+ data: bytes,
+ __values: Dict[bytes, bool] = {b"\x01": True, b"\x00": False},
+ ) -> bool:
+ return __values[data]
assert dumper.oid == builtins[type].oid
+@pytest.mark.parametrize(
+ "data, result",
+ [
+ (1, b"1"),
+ ("hello", b"'hello'"),
+ ("he'llo", b"'he''llo'"),
+ (True, b"true"),
+ (None, b"NULL"),
+ ],
+)
+def test_quote(data, result):
+ t = Transformer()
+ dumper = t.get_dumper(data, Format.TEXT)
+ assert dumper.quote(data) == result
+
+
def test_dump_connection_ctx(conn):
Dumper.register(str, make_dumper("t"), conn)
Dumper.register_binary(str, make_dumper("b"), conn)
@pytest.mark.parametrize(
"obj, quoted",
- [("hello", "'hello'"), (42, "'42'"), (True, "'t'"), (None, "NULL")],
+ [("hello", "'hello'"), (42, "42"), (True, "true"), (None, "NULL")],
)
def test_quote(obj, quoted):
assert sql.quote(obj) == quoted
def test_percent_escape(self, conn):
s = sql.SQL("42 % {0}").format(sql.Literal(7))
s1 = s.as_string(conn)
- assert s1 == "42 % '7'"
+ assert s1 == "42 % 7"
def test_braces_escape(self, conn):
s = sql.SQL("{{{0}}}").format(sql.Literal(7))
- assert s.as_string(conn) == "{'7'}"
+ assert s.as_string(conn) == "{7}"
s = sql.SQL("{{1,{0}}}").format(sql.Literal(7))
- assert s.as_string(conn) == "{1,'7'}"
+ assert s.as_string(conn) == "{1,7}"
def test_compose_badnargs(self):
with pytest.raises(IndexError):
def test_as_str(self, conn):
assert sql.Literal(None).as_string(conn) == "NULL"
assert noe(sql.Literal("foo").as_string(conn)) == "'foo'"
- assert sql.Literal(42).as_string(conn) == "'42'"
+ assert sql.Literal(42).as_string(conn) == "42"
assert (
sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'"
)
[sql.Identifier("foo"), sql.SQL("bar"), sql.Literal(42)]
)
assert isinstance(obj, sql.Composed)
- assert obj.as_string(conn) == "\"foo\", bar, '42'"
+ assert obj.as_string(conn) == '"foo", bar, 42'
obj = sql.SQL(", ").join(
sql.Composed(
)
)
assert isinstance(obj, sql.Composed)
- assert obj.as_string(conn) == "\"foo\", bar, '42'"
+ assert obj.as_string(conn) == '"foo", bar, 42'
obj = sql.SQL(", ").join([])
assert obj == sql.Composed([])
import pytest
-from psycopg3 import DataError
+from psycopg3 import DataError, sql
from psycopg3.adapt import Format
],
)
def test_dump_date(conn, val, expr):
+ val = as_date(val)
cur = conn.cursor()
- cur.execute(f"select '{expr}'::date = %s", (as_date(val),))
+ cur.execute(f"select '{expr}'::date = %s", (val,))
+ assert cur.fetchone()[0] is True
+
+ cur.execute(
+ sql.SQL("select {val}::date = %s").format(val=sql.Literal(val)), (val,)
+ )
assert cur.fetchone()[0] is True
import pytest
+from psycopg3 import sql
from psycopg3.adapt import Loader, Transformer, Format
from psycopg3.types import builtins
from psycopg3.types.numeric import TextFloatLoader
def test_dump_int(conn, val, expr):
assert isinstance(val, int)
cur = conn.cursor()
- cur.execute("select %s = %%s" % expr, (val,))
- assert cur.fetchone()[0]
+ cur.execute(f"select {expr} = %s", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0, b"0"),
+ (1, b"1"),
+ (-1, b" -1"),
+ (42, b"42"),
+ (-42, b" -42"),
+ (int(2 ** 63 - 1), b"9223372036854775807"),
+ (int(-(2 ** 63)), b" -9223372036854775808"),
+ ],
+)
+def test_quote_int(conn, val, expr):
+ tx = Transformer()
+ assert tx.get_dumper(val, 0).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ assert cur.fetchone() == (val, -val)
@pytest.mark.xfail
(0.0, "'0'"),
(1.0, "'1'"),
(-1.0, "'-1'"),
- (float("nan"), "'nan'"),
- (float("inf"), "'inf'"),
- (float("-inf"), "'-inf'"),
+ (float("nan"), "'NaN'"),
+ (float("inf"), "'Infinity'"),
+ (float("-inf"), "'-Infinity'"),
],
)
def test_dump_float(conn, val, expr):
assert isinstance(val, float)
cur = conn.cursor()
- cur.execute("select %%s = %s::float8" % expr, (val,))
- assert cur.fetchone()[0]
+ cur.execute(f"select %s = {expr}::float8", (val,))
+ assert cur.fetchone()[0] is True
+
+
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ (0.0, b"0.0"),
+ (1.0, b"1.0"),
+ (10000000000000000.0, b"1e+16"),
+ (1000000.1, b"1000000.1"),
+ (-100000.000001, b" -100000.000001"),
+ (-1.0, b" -1.0"),
+ (float("nan"), b"'NaN'::float8"),
+ (float("inf"), b"'Infinity'::float8"),
+ (float("-inf"), b"'-Infinity'::float8"),
+ ],
+)
+def test_quote_float(conn, val, expr):
+ tx = Transformer()
+ assert tx.get_dumper(val, 0).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ r = cur.fetchone()
+ if isnan(val):
+ assert isnan(r[0]) and isnan(r[1])
+ else:
+ if isinstance(r[0], Decimal):
+ r = tuple(map(float, r))
+
+ assert r == (val, -val)
@pytest.mark.parametrize(
assert isinstance(val, float)
cur = conn.cursor()
cur.execute(
- "select abs((%s::float8 - %%s) / %s::float8) <= 1e-15" % (expr, expr),
- (val,),
+ f"select abs(({expr}::float8 - %s) / {expr}::float8) <= 1e-15", (val,)
)
- assert cur.fetchone()[0]
+ assert cur.fetchone()[0] is True
+
cur.execute(
- "select abs((%s::float4 - %%s) / %s::float4) <= 1e-6" % (expr, expr),
- (val,),
+ f"select abs(({expr}::float4 - %s) / {expr}::float4) <= 1e-6", (val,)
)
- assert cur.fetchone()[0]
+ assert cur.fetchone()[0] is True
@pytest.mark.xfail
assert result == val
+@pytest.mark.parametrize(
+ "val, expr",
+ [
+ ("0", b"0"),
+ ("0.0", b"0.0"),
+ ("0.00000000000000001", b"1E-17"),
+ ("-0.00000000000000001", b" -1E-17"),
+ ("nan", b"'NaN'::numeric"),
+ ],
+)
+def test_quote_numeric(conn, val, expr):
+ val = Decimal(val)
+ tx = Transformer()
+ assert tx.get_dumper(val, 0).quote(val) == expr
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}, -{v}").format(v=sql.Literal(val)))
+ r = cur.fetchone()
+
+ if val.is_nan():
+ assert isnan(r[0]) and isnan(r[1])
+ else:
+ assert r == (val, -val)
+
+
@pytest.mark.xfail
def test_dump_numeric_binary():
# TODO: numeric binary adaptation
assert result[0] is b
+@pytest.mark.parametrize("val", [True, False])
+def test_quote_bool(conn, val):
+
+ tx = Transformer()
+ assert tx.get_dumper(val, 0).quote(val) == str(val).lower().encode("ascii")
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
+ assert cur.fetchone()[0] is val
+
+
@pytest.mark.parametrize("pgtype", [None, "float8", "int8", "numeric"])
def test_minus_minus(conn, pgtype):
cur = conn.cursor()
--- /dev/null
+import pytest
+
+from psycopg3 import sql
+from psycopg3.adapt import Transformer, Format
+from psycopg3.types import builtins
+
+
+@pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+@pytest.mark.parametrize("b", [True, False, None])
+def test_roundtrip_bool(conn, b, fmt_in, fmt_out):
+ cur = conn.cursor(format=fmt_out)
+ ph = "%s" if fmt_in == Format.TEXT else "%b"
+ result = cur.execute(f"select {ph}", (b,)).fetchone()[0]
+ assert cur.pgresult.fformat(0) == fmt_out
+ if b is not None:
+ assert cur.pgresult.ftype(0) == builtins["bool"].oid
+ assert result is b
+
+ result = cur.execute(f"select {ph}", ([b],)).fetchone()[0]
+ assert cur.pgresult.fformat(0) == fmt_out
+ if b is not None:
+ assert cur.pgresult.ftype(0) == builtins["bool"].array_oid
+ assert result[0] is b
+
+
+@pytest.mark.parametrize("val", [True, False])
+def test_quote_bool(conn, val):
+
+ tx = Transformer()
+ assert tx.get_dumper(val, 0).quote(val) == str(val).lower().encode("ascii")
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(val)))
+ assert cur.fetchone()[0] is val
+
+
+def test_quote_none(conn):
+
+ tx = Transformer()
+ assert tx.get_dumper(None, 0).quote(None) == b"NULL"
+
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {v}").format(v=sql.Literal(None)))
+ assert cur.fetchone()[0] is None
import pytest
-import psycopg3
+from psycopg3 import DatabaseError, sql
from psycopg3.adapt import Format
eur = "\u20ac"
ph = "%s" if fmt_in == Format.TEXT else "%b"
for i in range(1, 256):
cur.execute(f"select {ph} = chr(%s::int)", (chr(i), i))
- assert cur.fetchone()[0], chr(i)
+ assert cur.fetchone()[0] is True, chr(i)
+
+
+def test_quote_1char(conn):
+ cur = conn.cursor()
+ query = sql.SQL("select {ch} = chr(%s::int)")
+ for i in range(1, 256):
+ if chr(i) == "%":
+ continue
+ cur.execute(query.format(ch=sql.Literal(chr(i))), (i,))
+ assert cur.fetchone()[0] is True, chr(i)
+
+
+# the only way to make this pass is to reduce %% -> % every time
+# not only when there are query arguments
+# see https://github.com/psycopg/psycopg2/issues/825
+@pytest.mark.xfail
+def test_quote_percent(conn):
+ cur = conn.cursor()
+ cur.execute(sql.SQL("select {ch}").format(ch=sql.Literal("%")))
+ assert cur.fetchone()[0] == "%"
+
+ cur.execute(
+ sql.SQL("select {ch} = chr(%s::int)").format(ch=sql.Literal("%")),
+ (ord("%"),),
+ )
+ assert cur.fetchone()[0] is True
@pytest.mark.parametrize("typename", ["text", "varchar", "name", "bpchar"])
cur = conn.cursor(format=fmt_out)
conn.client_encoding = "latin1"
- with pytest.raises(psycopg3.DatabaseError):
+ with pytest.raises(DatabaseError):
cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
ph = "%s" if fmt_in == Format.TEXT else "%b"
for i in range(0, 256):
cur.execute(f"select {ph} = %s::bytea", (bytes([i]), fr"\x{i:02x}"))
- assert cur.fetchone()[0], i
+ assert cur.fetchone()[0] is True, i
+
+
+def test_quote_1byte(conn):
+ cur = conn.cursor()
+ query = sql.SQL("select {ch} = %s::bytea")
+ for i in range(0, 256):
+ cur.execute(query.format(ch=sql.Literal(bytes([i]))), (fr"\x{i:02x}",))
+ assert cur.fetchone()[0] is True, i
@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])