]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(sql): fetch correct type names in TypeInfo subclasses
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 20 Mar 2022 15:07:56 +0000 (16:07 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 10 May 2022 17:13:26 +0000 (19:13 +0200)
Needed to fix `Literal.as_string()` of types using a `TypeInfo` subclass
to fetch info (composite, range, multirange).

psycopg/psycopg/_typeinfo.py
tests/types/test_composite.py
tests/types/test_multirange.py
tests/types/test_range.py

index 8ca0896fafd5324349cc311cc8577254d0d6fb0c..c69642a6b09a760c019222534d456f4a984dbf0d 100644 (file)
@@ -176,8 +176,16 @@ class RangeInfo(TypeInfo):
 
     __module__ = "psycopg.types.range"
 
-    def __init__(self, name: str, oid: int, array_oid: int, subtype_oid: int):
-        super().__init__(name, oid, array_oid)
+    def __init__(
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        *,
+        alt_name: str = "",
+        subtype_oid: int,
+    ):
+        super().__init__(name, oid, array_oid, alt_name=alt_name)
         self.subtype_oid = subtype_oid
 
     @classmethod
@@ -186,6 +194,7 @@ class RangeInfo(TypeInfo):
     ) -> str:
         return """\
 SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+    t.oid::regtype::text AS alt_name,
     r.rngsubtype AS subtype_oid
 FROM pg_type t
 JOIN pg_range r ON t.oid = r.rngtypid
@@ -208,10 +217,12 @@ class MultirangeInfo(TypeInfo):
         name: str,
         oid: int,
         array_oid: int,
+        *,
+        alt_name: str = "",
         range_oid: int,
         subtype_oid: int,
     ):
-        super().__init__(name, oid, array_oid)
+        super().__init__(name, oid, array_oid, alt_name=alt_name)
         self.range_oid = range_oid
         self.subtype_oid = subtype_oid
 
@@ -225,6 +236,7 @@ class MultirangeInfo(TypeInfo):
             )
         return """\
 SELECT t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+    t.oid::regtype::text AS alt_name,
     r.rngtypid AS range_oid, r.rngsubtype AS subtype_oid
 FROM pg_type t
 JOIN pg_range r ON t.oid = r.rngmultitypid
@@ -247,10 +259,12 @@ class CompositeInfo(TypeInfo):
         name: str,
         oid: int,
         array_oid: int,
+        *,
+        alt_name: str = "",
         field_names: Sequence[str],
         field_types: Sequence[int],
     ):
-        super().__init__(name, oid, array_oid)
+        super().__init__(name, oid, array_oid, alt_name=alt_name)
         self.field_names = field_names
         self.field_types = field_types
         # Will be set by register() if the `factory` is a type
@@ -263,6 +277,7 @@ class CompositeInfo(TypeInfo):
         return """\
 SELECT
     t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+    t.oid::regtype::text AS alt_name,
     coalesce(a.fnames, '{}') AS field_names,
     coalesce(a.ftypes, '{}') AS field_types
 FROM pg_type t
index aee61c8dbb585ae450f2fd2926d8aab58db27df5..c31822f0b684ec53afa04f167394357113ba32da 100644 (file)
@@ -1,7 +1,6 @@
 import pytest
 
-from psycopg import pq, postgres
-from psycopg.sql import Identifier
+from psycopg import pq, postgres, sql
 from psycopg.adapt import PyFormat
 from psycopg.postgres import types as builtins
 from psycopg.types.range import Range
@@ -138,11 +137,11 @@ fetch_cases = [
         [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
     ),
     (
-        Identifier("testcomp"),
+        sql.Identifier("testcomp"),
         [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
     ),
     (
-        Identifier("testschema", "testcomp"),
+        sql.Identifier("testschema", "testcomp"),
         [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
     ),
 ]
@@ -345,3 +344,17 @@ def test_invalid_fields_names(conn):
     conn.execute("insert into meh values (%s)", [obj])
     got = conn.execute("select wat from meh").fetchone()[0]
     assert obj == got
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"])
+def test_literal_invalid_name(conn, name):
+    conn.execute("set client_encoding to utf8")
+    conn.execute(f'create type "{name}" as (foo text)')
+    info = CompositeInfo.fetch(conn, f'"{name}"')
+    register_composite(info, conn)
+    obj = info.python_type("hello")
+    assert sql.Literal(obj).as_string(conn) == f"'(hello)'::\"{name}\""
+    cur = conn.execute(sql.SQL("select {}").format(obj))
+    got = cur.fetchone()[0]
+    assert got == obj
+    assert type(got) is type(obj)
index bdbd837ea0f329231cd9b1df5091839e689a0f0b..7d6bd3a9a8d415c8b387b00a0a198b0a728a87b3 100644 (file)
@@ -4,16 +4,15 @@ from decimal import Decimal
 
 import pytest
 
-from psycopg import pq
+from psycopg import pq, sql
 from psycopg import errors as e
-from psycopg.sql import Identifier
 from psycopg.adapt import PyFormat
 from psycopg.types.range import Range
 from psycopg.types import multirange
 from psycopg.types.multirange import Multirange, MultirangeInfo
 from psycopg.types.multirange import register_multirange
 
-from .test_range import create_test_range
+from .test_range import create_test_range, eur
 
 pytestmark = pytest.mark.pg(">= 14")
 
@@ -363,8 +362,8 @@ def testmr(svcconn):
 fetch_cases = [
     ("testmultirange", "text"),
     ("testschema.testmultirange", "float8"),
-    (Identifier("testmultirange"), "text"),
-    (Identifier("testschema", "testmultirange"), "float8"),
+    (sql.Identifier("testmultirange"), "text"),
+    (sql.Identifier("testschema", "testmultirange"), "float8"),
 ]
 
 
@@ -414,3 +413,15 @@ def test_load_custom_empty(conn, testmr, fmt_out):
     (got,) = cur.execute("select '{}'::testmultirange").fetchone()
     assert isinstance(got, Multirange)
     assert not got
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}"])
+def test_literal_invalid_name(conn, name):
+    conn.execute("set client_encoding to utf8")
+    conn.execute(f'create type "{name}" as range (subtype = text)')
+    info = MultirangeInfo.fetch(conn, f'"{name}_multirange"')
+    register_multirange(info, conn)
+    obj = Multirange([Range("a", "z", "[]")])
+    assert sql.Literal(obj).as_string(conn) == f"'{{[a,z]}}'::\"{name}_multirange\""
+    cur = conn.execute(sql.SQL("select {}").format(obj))
+    assert cur.fetchone()[0] == obj
index dbe8ad2027ad51c5e7d1975e972541005a13bcb4..c00124f2b76374f660bc27d967211ac9199ddeac 100644 (file)
@@ -4,13 +4,13 @@ from decimal import Decimal
 
 import pytest
 
-from psycopg import pq
+from psycopg import pq, sql
 from psycopg import errors as e
-from psycopg.sql import Identifier
 from psycopg.adapt import PyFormat
 from psycopg.types import range as range_module
 from psycopg.types.range import Range, RangeInfo, register_range
 
+eur = "\u20ac"
 
 type2sub = {
     "int4range": "int4",
@@ -278,8 +278,8 @@ def create_test_range(conn):
 fetch_cases = [
     ("testrange", "text"),
     ("testschema.testrange", "float8"),
-    (Identifier("testrange"), "text"),
-    (Identifier("testschema", "testrange"), "float8"),
+    (sql.Identifier("testrange"), "text"),
+    (sql.Identifier("testschema", "testrange"), "float8"),
 ]
 
 
@@ -656,3 +656,15 @@ class TestRangeObject:
 def test_no_info_error(conn):
     with pytest.raises(TypeError, match="range"):
         register_range(None, conn)  # type: ignore[arg-type]
+
+
+@pytest.mark.parametrize("name", ["a-b", f"{eur}", "order"])
+def test_literal_invalid_name(conn, name):
+    conn.execute("set client_encoding to utf8")
+    conn.execute(f'create type "{name}" as range (subtype = text)')
+    info = RangeInfo.fetch(conn, f'"{name}"')
+    register_range(info, conn)
+    obj = Range("a", "z", "[]")
+    assert sql.Literal(obj).as_string(conn) == f"'[a,z]'::\"{name}\""
+    cur = conn.execute(sql.SQL("select {}").format(obj))
+    assert cur.fetchone()[0] == obj