def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
conn = context.connection if context else None
- if not conn:
- raise ValueError("a connection is necessary for Identifier")
- esc = Escaping(conn.pgconn)
- enc = conn_encoding(conn)
- escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+ if conn:
+ esc = Escaping(conn.pgconn)
+ enc = conn_encoding(conn)
+ escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+ else:
+ escs = [self._escape_identifier(s.encode()) for s in self._obj]
return b".".join(escs)
+ def _escape_identifier(self, s: bytes) -> bytes:
+ """
+ Approximation of PQescapeIdentifier taking no connection.
+ """
+ return b'"' + s.replace(b'"', b'""') + b'"'
+
class Literal(Composable):
"""
assert sql.Identifier("foo") != "foo"
assert sql.Identifier("foo") != sql.SQL("foo")
- @pytest.mark.parametrize(
- "args, want",
- [
- (("foo",), '"foo"'),
- (("foo", "bar"), '"foo"."bar"'),
- (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
- ],
- )
+ _as_string_params = [
+ (("foo",), '"foo"'),
+ (("foo", "bar"), '"foo"."bar"'),
+ (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+ ]
+
+ @pytest.mark.parametrize("args, want", _as_string_params)
def test_as_string(self, conn, args, want):
assert sql.Identifier(*args).as_string(conn) == want
- @pytest.mark.parametrize(
- "args, want, enc",
- [
- crdb_encoding(("foo",), '"foo"', "ascii"),
- crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
- crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
- (("foo", eur), f'"foo"."{eur}"', "utf8"),
- crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
- ],
- )
+ @pytest.mark.parametrize("args, want", _as_string_params)
+ def test_as_string_no_conn(self, args, want):
+ assert sql.Identifier(*args).as_string(None) == want
+
+ _as_bytes_params = [
+ crdb_encoding(("foo",), '"foo"', "ascii"),
+ crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
+ crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+ (("foo", eur), f'"foo"."{eur}"', "utf8"),
+ crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
+ ]
+
+ @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
def test_as_bytes(self, conn, args, want, enc):
want = want.encode(enc)
conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
assert sql.Identifier(*args).as_bytes(conn) == want
+ @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
+ def test_as_bytes_params(self, conn, args, want, enc):
+ want = want.encode()
+ assert sql.Identifier(*args).as_bytes(None) == want
+
def test_join(self):
assert not hasattr(sql.Identifier("foo"), "join")
+ def test_escape_no_conn(self, conn):
+ conn.execute("set client_encoding to 'utf8'")
+ for c in range(1, 128):
+ s = chr(c)
+ want = sql.Identifier(s).as_bytes(conn)
+ assert want == sql.Identifier(s).as_bytes(None)
+
class TestLiteral:
def test_class(self):