return f"{self.__class__.__name__}({self._obj!r})"
@abstractmethod
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
"""
Return the value of the object as bytes.
"""
raise NotImplementedError
- def as_string(self, context: Optional[AdaptContext]) -> str:
+ def as_string(self, context: Optional[AdaptContext] = None) -> str:
"""
Return the value of the object as string.
seq = [obj if isinstance(obj, Composable) else Literal(obj) for obj in seq]
super().__init__(seq)
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
return b"".join(obj.as_bytes(context) for obj in self._obj)
def __iter__(self) -> Iterator[Composable]:
if not isinstance(obj, str):
raise TypeError(f"SQL values must be strings, got {obj!r} instead")
- def as_string(self, context: Optional[AdaptContext]) -> str:
+ def as_string(self, context: Optional[AdaptContext] = None) -> str:
return self._obj
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
conn = context.connection if context else None
enc = conn_encoding(conn)
return self._obj.encode(enc)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
conn = context.connection if context else None
if conn:
esc = Escaping(conn.pgconn)
"""
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
tx = Transformer.from_context(context)
return tx.as_literal(self._obj)
return f"{self.__class__.__name__}({', '.join(parts)})"
- def as_string(self, context: Optional[AdaptContext]) -> str:
+ def as_string(self, context: Optional[AdaptContext] = None) -> str:
code = self._format.value
return f"%({self._obj}){code}" if self._obj else f"%{code}"
- def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
+ def as_bytes(self, context: Optional[AdaptContext] = None) -> bytes:
conn = context.connection if context else None
enc = conn_encoding(conn)
return self.as_string(context).encode(enc)
@pytest.mark.parametrize("args, want", _as_string_params)
def test_as_string_no_conn(self, args, want):
assert sql.Identifier(*args).as_string(None) == want
+ assert sql.Identifier(*args).as_string() == want
_as_bytes_params = [
crdb_encoding(("foo",), '"foo"', "ascii"),
assert sql.Identifier(*args).as_bytes(conn) == want
@pytest.mark.parametrize("args, want, enc", _as_bytes_params)
- def test_as_bytes_params(self, conn, args, want, enc):
+ def test_as_bytes_no_conn(self, conn, args, want, enc):
want = want.encode()
assert sql.Identifier(*args).as_bytes(None) == want
+ assert sql.Identifier(*args).as_bytes() == want
def test_join(self):
assert not hasattr(sql.Identifier("foo"), "join")
assert repr(sql.Literal("foo")) == "Literal('foo')"
assert str(sql.Literal("foo")) == "Literal('foo')"
- def test_as_string(self, conn):
- assert sql.Literal(None).as_string(conn) == "NULL"
- assert no_e(sql.Literal("foo").as_string(conn)) == "'foo'"
- assert sql.Literal(42).as_string(conn) == "42"
- assert sql.Literal(dt.date(2017, 1, 1)).as_string(conn) == "'2017-01-01'::date"
+ _params = [
+ (None, "NULL"),
+ ("foo", "'foo'"),
+ (42, "42"),
+ (dt.date(2017, 1, 1), "'2017-01-01'::date"),
+ ]
- def test_as_bytes(self, conn):
- assert sql.Literal(None).as_bytes(conn) == b"NULL"
- assert no_e(sql.Literal("foo").as_bytes(conn)) == b"'foo'"
- assert sql.Literal(42).as_bytes(conn) == b"42"
- assert sql.Literal(dt.date(2017, 1, 1)).as_bytes(conn) == b"'2017-01-01'::date"
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_string(self, conn, obj, want):
+ got = sql.Literal(obj).as_string(conn)
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want
+
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_bytes(self, conn, obj, want):
+ got = sql.Literal(obj).as_bytes(conn)
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want.encode()
+
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_string_no_conn(self, obj, want):
+ got = sql.Literal(obj).as_string()
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want
+
+ @pytest.mark.parametrize("obj, want", _params)
+ def test_as_bytes_no_conn(self, obj, want):
+ got = sql.Literal(obj).as_bytes()
+ if isinstance(obj, str):
+ got = no_e(got)
+ assert got == want.encode()
@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
def test_as_bytes_encoding(self, conn, encoding):
def test_as_string(self, conn):
assert sql.SQL("foo").as_string(conn) == "foo"
+ assert sql.SQL("foo").as_string() == "foo"
@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
def test_as_bytes(self, conn, encoding):
assert sql.SQL(eur).as_bytes(conn) == eur.encode(encoding)
+ def test_no_conn(self):
+ assert sql.SQL(eur).as_string() == eur
+ assert sql.SQL(eur).as_bytes() == eur.encode()
+
class TestComposed:
def test_class(self):
def test_as_string(self, conn):
obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
assert obj.as_string(conn) == "foobar"
+ assert obj.as_string() == "foobar"
def test_as_bytes(self, conn):
obj = sql.Composed([sql.SQL("foo"), sql.SQL("bar")])
assert obj.as_bytes(conn) == b"foobar"
+ assert obj.as_bytes() == b"foobar"
@pytest.mark.parametrize("encoding", ["utf8", crdb_encoding("latin9")])
def test_as_bytes_encoding(self, conn, encoding):
def test_as_string(self, conn, format):
ph = sql.Placeholder(format=format)
assert ph.as_string(conn) == f"%{format.value}"
+ assert ph.as_string() == f"%{format.value}"
ph = sql.Placeholder(name="foo", format=format)
assert ph.as_string(conn) == f"%(foo){format.value}"
+ assert ph.as_string() == f"%(foo){format.value}"
@pytest.mark.parametrize("format", PyFormat)
def test_as_bytes(self, conn, format):
ph = sql.Placeholder(format=format)
- assert ph.as_bytes(conn) == f"%{format.value}".encode("ascii")
+ assert ph.as_bytes(conn) == f"%{format.value}".encode()
+ assert ph.as_bytes() == f"%{format.value}".encode()
ph = sql.Placeholder(name="foo", format=format)
- assert ph.as_bytes(conn) == f"%(foo){format.value}".encode("ascii")
+ assert ph.as_bytes(conn) == f"%(foo){format.value}".encode()
+ assert ph.as_bytes() == f"%(foo){format.value}".encode()
class TestValues:
def no_e(s):
"""Drop an eventual E from E'' quotes"""
- if isinstance(s, memoryview):
+ if isinstance(s, (memoryview, bytearray)):
s = bytes(s)
if isinstance(s, str):