]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(sql): prefer names without space to cast builtins 264/head
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 20 Mar 2022 22:31:50 +0000 (23:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 10 May 2022 17:13:26 +0000 (19:13 +0200)
Certain types have a regtype with spaces, such as "timestamp with time
zone". Prefer the type short name, in this case ("timestamptz").

psycopg/psycopg/_transform.py
psycopg_c/psycopg_c/_psycopg/transform.pyx
tests/test_sql.py

index 8509c93e3ae99e13a0bbd2526b71553619b13f1b..803c3cfc700fd2050dfc3a7f2519038579d9f103 100644 (file)
@@ -190,18 +190,23 @@ class Transformer(AdaptContext):
         # If the result is quoted, and the oid not unknown,
         # add an explicit type cast.
         # Check the last char because the first one might be 'E'.
-        if dumper.oid and rv and rv[-1] == b"'"[0]:
+        oid = dumper.oid
+        if oid and rv and rv[-1] == b"'"[0]:
             try:
-                type_sql = self._oid_types[dumper.oid]
+                type_sql = self._oid_types[oid]
             except KeyError:
-                ti = self.adapters.types.get(dumper.oid)
+                ti = self.adapters.types.get(oid)
                 if ti:
-                    type_sql = ti.regtype.encode(self.encoding)
-                    if dumper.oid == ti.array_oid:
+                    if oid < 8192:
+                        # builtin: prefer "timestamptz" to "timestamp with time zone"
+                        type_sql = ti.name.encode(self.encoding)
+                    else:
+                        type_sql = ti.regtype.encode(self.encoding)
+                    if oid == ti.array_oid:
                         type_sql += b"[]"
                 else:
                     type_sql = b""
-                self._oid_types[dumper.oid] = type_sql
+                self._oid_types[oid] = type_sql
 
             if type_sql:
                 rv = b"%s::%s" % (rv, type_sql)
index c6e98355cccd64d9744a868cbb95c63c67b918f0..e290c3adbf67fd291bcbd853fddcbf175e9eb3d5 100644 (file)
@@ -226,7 +226,11 @@ cdef class Transformer:
                 type_sql = b""
                 ti = self.adapters.types.get(oid)
                 if ti is not None:
-                    type_sql = ti.regtype.encode(self.encoding)
+                    if oid < 8192:
+                        # builtin: prefer "timestamptz" to "timestamp with time zone"
+                        type_sql = ti.name.encode(self.encoding)
+                    else:
+                        type_sql = ti.regtype.encode(self.encoding)
                     if oid == ti.array_oid:
                         type_sql += b"[]"
 
index 21d0018c5707ad2c1d297a683a1a8f21edf03bd6..90c39da25967ff3d39d0bb7f975aba11f8fccb6c 100644 (file)
@@ -345,7 +345,18 @@ class TestLiteral:
             == "'{2000-01-01}'::date[]"
         )
 
-    @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"])
+    def test_short_name_builtin(self, conn):
+        assert sql.Literal(dt.time(0, 0)).as_string(conn) == "'00:00:00'::time"
+        assert (
+            sql.Literal(dt.datetime(2000, 1, 1)).as_string(conn)
+            == "'2000-01-01 00:00:00'::timestamp"
+        )
+        assert (
+            sql.Literal([dt.datetime(2000, 1, 1)]).as_string(conn)
+            == "'{\"2000-01-01 00:00:00\"}'::timestamp[]"
+        )
+
+    @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"])
     def test_invalid_name(self, conn, name):
         conn.execute(
             f"""