]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: allow Identifier.as_string() and as_bytes() to take no connection
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 12 Jan 2024 16:57:42 +0000 (17:57 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 27 Jan 2024 02:12:02 +0000 (02:12 +0000)
psycopg/psycopg/sql.py
tests/test_sql.py

index d793bf389245a6140664318ec1c056d817678e1f..6eaabee7cc2d642557f55e29836651a2713cba7c 100644 (file)
@@ -364,13 +364,20 @@ class Identifier(Composable):
 
     def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
         conn = context.connection if context else None
-        if not conn:
-            raise ValueError("a connection is necessary for Identifier")
-        esc = Escaping(conn.pgconn)
-        enc = conn_encoding(conn)
-        escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+        if conn:
+            esc = Escaping(conn.pgconn)
+            enc = conn_encoding(conn)
+            escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
+        else:
+            escs = [self._escape_identifier(s.encode()) for s in self._obj]
         return b".".join(escs)
 
+    def _escape_identifier(self, s: bytes) -> bytes:
+        """
+        Approximation of PQescapeIdentifier taking no connection.
+        """
+        return b'"' + s.replace(b'"', b'""') + b'"'
+
 
 class Literal(Composable):
     """
index b1ec8d85df3d0841966c0b17875184e3c8f1404d..4cd0b0c0e0cee9fd95918de21e4c743a5cc567a0 100644 (file)
@@ -267,35 +267,49 @@ class TestIdentifier:
         assert sql.Identifier("foo") != "foo"
         assert sql.Identifier("foo") != sql.SQL("foo")
 
-    @pytest.mark.parametrize(
-        "args, want",
-        [
-            (("foo",), '"foo"'),
-            (("foo", "bar"), '"foo"."bar"'),
-            (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
-        ],
-    )
+    _as_string_params = [
+        (("foo",), '"foo"'),
+        (("foo", "bar"), '"foo"."bar"'),
+        (("fo'o", 'ba"r'), '"fo\'o"."ba""r"'),
+    ]
+
+    @pytest.mark.parametrize("args, want", _as_string_params)
     def test_as_string(self, conn, args, want):
         assert sql.Identifier(*args).as_string(conn) == want
 
-    @pytest.mark.parametrize(
-        "args, want, enc",
-        [
-            crdb_encoding(("foo",), '"foo"', "ascii"),
-            crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
-            crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
-            (("foo", eur), f'"foo"."{eur}"', "utf8"),
-            crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
-        ],
-    )
+    @pytest.mark.parametrize("args, want", _as_string_params)
+    def test_as_string_no_conn(self, args, want):
+        assert sql.Identifier(*args).as_string(None) == want
+
+    _as_bytes_params = [
+        crdb_encoding(("foo",), '"foo"', "ascii"),
+        crdb_encoding(("foo", "bar"), '"foo"."bar"', "ascii"),
+        crdb_encoding(("fo'o", 'ba"r'), '"fo\'o"."ba""r"', "ascii"),
+        (("foo", eur), f'"foo"."{eur}"', "utf8"),
+        crdb_encoding(("foo", eur), f'"foo"."{eur}"', "latin9"),
+    ]
+
+    @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
     def test_as_bytes(self, conn, args, want, enc):
         want = want.encode(enc)
         conn.execute(f"set client_encoding to {py2pgenc(enc).decode()}")
         assert sql.Identifier(*args).as_bytes(conn) == want
 
+    @pytest.mark.parametrize("args, want, enc", _as_bytes_params)
+    def test_as_bytes_params(self, conn, args, want, enc):
+        want = want.encode()
+        assert sql.Identifier(*args).as_bytes(None) == want
+
     def test_join(self):
         assert not hasattr(sql.Identifier("foo"), "join")
 
+    def test_escape_no_conn(self, conn):
+        conn.execute("set client_encoding to 'utf8'")
+        for c in range(1, 128):
+            s = chr(c)
+            want = sql.Identifier(s).as_bytes(conn)
+            assert want == sql.Identifier(s).as_bytes(None)
+
 
 class TestLiteral:
     def test_class(self):