From: Daniele Varrazzo Date: Fri, 12 Jan 2024 16:57:42 +0000 (+0100) Subject: feat: allow Identifier.as_string() and as_bytes() to take no connection X-Git-Tag: 3.2.0~91^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=26a30f7c05a9cde0e255609308fd96feec3186a6;p=thirdparty%2Fpsycopg.git feat: allow Identifier.as_string() and as_bytes() to take no connection --- diff --git a/psycopg/psycopg/sql.py b/psycopg/psycopg/sql.py index d793bf389..6eaabee7c 100644 --- a/psycopg/psycopg/sql.py +++ b/psycopg/psycopg/sql.py @@ -364,13 +364,20 @@ class Identifier(Composable): def as_bytes(self, context: Optional[AdaptContext]) -> bytes: conn = context.connection if context else None - if not conn: - raise ValueError("a connection is necessary for Identifier") - esc = Escaping(conn.pgconn) - enc = conn_encoding(conn) - escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + if conn: + esc = Escaping(conn.pgconn) + enc = conn_encoding(conn) + escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj] + else: + escs = [self._escape_identifier(s.encode()) for s in self._obj] return b".".join(escs) + def _escape_identifier(self, s: bytes) -> bytes: + """ + Approximation of PQescapeIdentifier taking no connection. + """ + return b'"' + s.replace(b'"', b'""') + b'"' + class Literal(Composable): """ diff --git a/tests/test_sql.py b/tests/test_sql.py index b1ec8d85d..4cd0b0c0e 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -267,35 +267,49 @@ class TestIdentifier: assert sql.Identifier("foo") != "foo" assert sql.Identifier("foo") != sql.SQL("foo") - @pytest.mark.parametrize( - "args, want", - [ - (("foo",), '"foo"'), - (("foo", "bar"), '"foo"."bar"'), - (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), - ], - ) + _as_string_params = [ + (("foo",), '"foo"'), + (("foo", "bar"), '"foo"."bar"'), + (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'), + ] + + @pytest.mark.parametrize("args, want", _as_string_params) def test_as_string(self, conn, args, want): assert sql.Identifier(*args).as_string(conn) == want - @pytest.mark.parametrize( - "args, want, enc", - [ - crdb_encoding(("foo",), '"foo"', "ascii"), - crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"), - crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), - (("foo", eur), f'"foo"."{eur}"', "utf8"), - crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"), - ], - ) + @pytest.mark.parametrize("args, want", _as_string_params) + def test_as_string_no_conn(self, args, want): + assert sql.Identifier(*args).as_string(None) == want + + _as_bytes_params = [ + crdb_encoding(("foo",), '"foo"', "ascii"), + crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"), + crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"), + (("foo", eur), f'"foo"."{eur}"', "utf8"), + crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"), + ] + + @pytest.mark.parametrize("args, want, enc", _as_bytes_params) def test_as_bytes(self, conn, args, want, enc): want = want.encode(enc) conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}") assert sql.Identifier(*args).as_bytes(conn) == want + @pytest.mark.parametrize("args, want, enc", _as_bytes_params) + def test_as_bytes_params(self, conn, args, want, enc): + want = want.encode() + assert sql.Identifier(*args).as_bytes(None) == want + def test_join(self): assert not hasattr(sql.Identifier("foo"), "join") + def test_escape_no_conn(self, conn): + conn.execute("set client_encoding to 'utf8'") + for c in range(1, 128): + s = chr(c) + want = sql.Identifier(s).as_bytes(conn) + assert want == sql.Identifier(s).as_bytes(None) + class TestLiteral: def test_class(self):