From 98c3c65a4b4af1d9473b6e73bb52d6b017557372 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 20 Mar 2022 23:31:50 +0100 Subject: [PATCH] fix(sql): prefer names without space to cast builtins Certain types have a regtype with spaces, such as "timestamp with time zone". Prefer the type short name, in this case ("timestamptz"). --- psycopg/psycopg/_transform.py | 17 +++++++++++------ psycopg_c/psycopg_c/_psycopg/transform.pyx | 6 +++++- tests/test_sql.py | 13 ++++++++++++- 3 files changed, 28 insertions(+), 8 deletions(-) diff --git a/psycopg/psycopg/_transform.py b/psycopg/psycopg/_transform.py index 8509c93e3..803c3cfc7 100644 --- a/psycopg/psycopg/_transform.py +++ b/psycopg/psycopg/_transform.py @@ -190,18 +190,23 @@ class Transformer(AdaptContext): # If the result is quoted, and the oid not unknown, # add an explicit type cast. # Check the last char because the first one might be 'E'. - if dumper.oid and rv and rv[-1] == b"'"[0]: + oid = dumper.oid + if oid and rv and rv[-1] == b"'"[0]: try: - type_sql = self._oid_types[dumper.oid] + type_sql = self._oid_types[oid] except KeyError: - ti = self.adapters.types.get(dumper.oid) + ti = self.adapters.types.get(oid) if ti: - type_sql = ti.regtype.encode(self.encoding) - if dumper.oid == ti.array_oid: + if oid < 8192: + # builtin: prefer "timestamptz" to "timestamp with time zone" + type_sql = ti.name.encode(self.encoding) + else: + type_sql = ti.regtype.encode(self.encoding) + if oid == ti.array_oid: type_sql += b"[]" else: type_sql = b"" - self._oid_types[dumper.oid] = type_sql + self._oid_types[oid] = type_sql if type_sql: rv = b"%s::%s" % (rv, type_sql) diff --git a/psycopg_c/psycopg_c/_psycopg/transform.pyx b/psycopg_c/psycopg_c/_psycopg/transform.pyx index c6e98355c..e290c3adb 100644 --- a/psycopg_c/psycopg_c/_psycopg/transform.pyx +++ b/psycopg_c/psycopg_c/_psycopg/transform.pyx @@ -226,7 +226,11 @@ cdef class Transformer: type_sql = b"" ti = self.adapters.types.get(oid) if ti is not None: - type_sql = ti.regtype.encode(self.encoding) + if oid < 8192: + # builtin: prefer "timestamptz" to "timestamp with time zone" + type_sql = ti.name.encode(self.encoding) + else: + type_sql = ti.regtype.encode(self.encoding) if oid == ti.array_oid: type_sql += b"[]" diff --git a/tests/test_sql.py b/tests/test_sql.py index 21d0018c5..90c39da25 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -345,7 +345,18 @@ class TestLiteral: == "'{2000-01-01}'::date[]" ) - @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"]) + def test_short_name_builtin(self, conn): + assert sql.Literal(dt.time(0, 0)).as_string(conn) == "'00:00:00'::time" + assert ( + sql.Literal(dt.datetime(2000, 1, 1)).as_string(conn) + == "'2000-01-01 00:00:00'::timestamp" + ) + assert ( + sql.Literal([dt.datetime(2000, 1, 1)]).as_string(conn) + == "'{\"2000-01-01 00:00:00\"}'::timestamp[]" + ) + + @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"]) def test_invalid_name(self, conn, name): conn.execute( f""" -- 2.47.2