]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added tuple text adaptation
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 15:52:15 +0000 (03:52 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 8 Apr 2020 15:52:15 +0000 (03:52 +1200)
It cannot return a record oid, Postgres is not happy to parse it, so
return it as text.

psycopg3/types/array.py
psycopg3/types/composite.py
tests/types/test_composite.py

index ad4697c3eca614d8b6dcd7fee49957a7cc3d4ec2..bc314d9766aed5d38a30d1d4f9fc69fc9c1f2925 100644 (file)
@@ -47,7 +47,7 @@ class TextListAdapter(BaseListAdapter):
     # they are empty strings, contain curly braces, delimiter characters,
     # double quotes, backslashes, or white space, or match the word NULL.
     # TODO: recognise only , as delimiter. Should be configured
-    _re_needs_quote = re.compile(
+    _re_needs_quotes = re.compile(
         br"""(?xi)
           ^$              # the empty string
         | ["{},\\\s]      # or a char to escape
@@ -91,7 +91,7 @@ class TextListAdapter(BaseListAdapter):
                         ad = ad[0]
 
                     if ad is not None:
-                        if self._re_needs_quote.search(ad) is not None:
+                        if self._re_needs_quotes.search(ad) is not None:
                             ad = (
                                 b'"' + self._re_escape.sub(br"\\\1", ad) + b'"'
                             )
index 8ff8d33f76afda1bf606b46fa64a3d1484a91761..001feb378aeb2780e0d02e562b5c47199d536222 100644 (file)
@@ -9,7 +9,7 @@ from typing import Any, Callable, Generator, List, Sequence, Tuple, Union
 from typing import Optional, TYPE_CHECKING
 
 from . import array
-from ..adapt import Format, TypeCaster, Transformer, AdaptContext
+from ..adapt import Format, Adapter, TypeCaster, Transformer, AdaptContext
 from .oids import builtins, TypeInfo
 
 if TYPE_CHECKING:
@@ -46,7 +46,7 @@ class CompositeTypeInfo(TypeInfo):
 
 def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
     cur = conn.cursor(binary=True)
-    cur.execute(_type_info_query, (name,))
+    cur.execute(_type_info_query, {"name": name})
     rec = cur.fetchone()
     return CompositeTypeInfo(*rec) if rec is not None else None
 
@@ -55,7 +55,7 @@ async def fetch_info_async(
     conn: "AsyncConnection", name: str
 ) -> Optional[CompositeTypeInfo]:
     cur = conn.cursor(binary=True)
-    await cur.execute(_type_info_query, (name,))
+    await cur.execute(_type_info_query, {"name": name})
     rec = await cur.fetchone()
     return CompositeTypeInfo(*rec) if rec is not None else None
 
@@ -99,25 +99,71 @@ def register(
 
 _type_info_query = """\
 select
-    name, oid, array_oid,
-    array_agg(row(field_name, field_type)) as fields
-from (
-    select
-        typname as name,
-        t.oid as oid,
-        t.typarray as array_oid,
-        a.attname as field_name,
-        a.atttypid as field_type
-    from pg_type t
-    left join pg_attribute a on a.attrelid = t.typrelid
-    where t.typname = %s
-    and a.attnum > 0
-    order by a.attnum
-) x
-group by name, oid, array_oid
+    t.typname as name,
+    t.oid as oid,
+    t.typarray as array_oid,
+    coalesce(a.fields, '{}') as fields
+from pg_type t
+left join (
+    select attrelid, array_agg(field) as fields
+    from (
+        select attrelid, row(attname, atttypid) field
+        from pg_attribute a
+        join pg_type t on t.typrelid = a.attrelid
+        where t.typname = %(name)s
+        and a.attnum > 0
+        and not a.attisdropped
+        order by a.attnum
+    ) x
+    group by attrelid
+) a on a.attrelid = t.typrelid
+where t.typname = %(name)s
 """
 
 
+@Adapter.text(tuple)
+class TextTupleAdapter(Adapter):
+    def __init__(self, src: type, context: AdaptContext = None):
+        super().__init__(src, context)
+        self._tx = Transformer(context)
+
+    def adapt(self, obj: Tuple[Any, ...]) -> Tuple[bytes, int]:
+        if not obj:
+            return b"()", TEXT_OID
+
+        parts = [b"("]
+
+        for item in obj:
+            if item is None:
+                parts.append(b",")
+                continue
+
+            ad = self._tx.adapt(item)
+            if isinstance(ad, tuple):
+                ad = ad[0]
+            if ad is None:
+                parts.append(b",")
+                continue
+
+            if self._re_needs_quotes.search(ad) is not None:
+                ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
+
+            parts.append(ad)
+            parts.append(b",")
+
+        parts[-1] = b")"
+
+        return b"".join(parts), TEXT_OID
+
+    _re_needs_quotes = re.compile(
+        br"""(?xi)
+          ^$            # the empty string
+        | [",\\\s]      # or a char to escape
+        """
+    )
+    _re_escape = re.compile(br"([\"])")
+
+
 class BaseCompositeCaster(TypeCaster):
     def __init__(self, oid: int, context: AdaptContext = None):
         super().__init__(oid, context)
index acaef811cfb7f8742e783b74e875b39a491e9508..95c39a32e54a33ff94de15253ec285e3445ce340 100644 (file)
@@ -4,30 +4,44 @@ from psycopg3.adapt import Format, TypeCaster
 from psycopg3.types import builtins, composite
 
 
-@pytest.mark.parametrize(
-    "rec, want",
-    [
-        ("", ()),
-        # Funnily enough there's no way to represent (None,) in Postgres
-        ("null", ()),
-        ("null,null", (None, None)),
-        ("null, ''", (None, "")),
-        (
-            "42,'foo','ba,r','ba''z','qu\"x'",
-            ("42", "foo", "ba,r", "ba'z", 'qu"x'),
-        ),
-        (
-            "'foo''', '''foo', '\"bar', 'bar\"' ",
-            ("foo'", "'foo", '"bar', 'bar"'),
-        ),
-    ],
-)
+tests_str = [
+    ("", ()),
+    # Funnily enough there's no way to represent (None,) in Postgres
+    ("null", ()),
+    ("null,null", (None, None)),
+    ("null, ''", (None, "")),
+    (
+        "42,'foo','ba,r','ba''z','qu\"x'",
+        ("42", "foo", "ba,r", "ba'z", 'qu"x'),
+    ),
+    ("'foo''', '''foo', '\"bar', 'bar\"' ", ("foo'", "'foo", '"bar', 'bar"'),),
+]
+
+
+@pytest.mark.parametrize("rec, want", tests_str)
 def test_cast_record(conn, want, rec):
     cur = conn.cursor()
     res = cur.execute(f"select row({rec})").fetchone()[0]
     assert res == want
 
 
+@pytest.mark.parametrize("rec, obj", tests_str)
+def test_adapt_tuple(conn, rec, obj):
+    cur = conn.cursor()
+    fields = [f"f{i} text" for i in range(len(obj))]
+    cur.execute(
+        f"""
+        drop type if exists tmptype;
+        create type tmptype as ({', '.join(fields)});
+        """
+    )
+    info = composite.fetch_info(conn, "tmptype")
+    composite.register(info, context=conn)
+
+    res = cur.execute("select %s::tmptype", [obj]).fetchone()[0]
+    assert res == obj
+
+
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_cast_all_chars(conn, fmt_out):
     cur = conn.cursor(binary=fmt_out == Format.BINARY)