# they are empty strings, contain curly braces, delimiter characters,
# double quotes, backslashes, or white space, or match the word NULL.
# TODO: recognise only , as delimiter. Should be configured
- _re_needs_quote = re.compile(
+ _re_needs_quotes = re.compile(
br"""(?xi)
^$ # the empty string
| ["{},\\\s] # or a char to escape
ad = ad[0]
if ad is not None:
- if self._re_needs_quote.search(ad) is not None:
+ if self._re_needs_quotes.search(ad) is not None:
ad = (
b'"' + self._re_escape.sub(br"\\\1", ad) + b'"'
)
from typing import Optional, TYPE_CHECKING
from . import array
-from ..adapt import Format, TypeCaster, Transformer, AdaptContext
+from ..adapt import Format, Adapter, TypeCaster, Transformer, AdaptContext
from .oids import builtins, TypeInfo
if TYPE_CHECKING:
def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
cur = conn.cursor(binary=True)
- cur.execute(_type_info_query, (name,))
+ cur.execute(_type_info_query, {"name": name})
rec = cur.fetchone()
return CompositeTypeInfo(*rec) if rec is not None else None
conn: "AsyncConnection", name: str
) -> Optional[CompositeTypeInfo]:
cur = conn.cursor(binary=True)
- await cur.execute(_type_info_query, (name,))
+ await cur.execute(_type_info_query, {"name": name})
rec = await cur.fetchone()
return CompositeTypeInfo(*rec) if rec is not None else None
_type_info_query = """\
select
- name, oid, array_oid,
- array_agg(row(field_name, field_type)) as fields
-from (
- select
- typname as name,
- t.oid as oid,
- t.typarray as array_oid,
- a.attname as field_name,
- a.atttypid as field_type
- from pg_type t
- left join pg_attribute a on a.attrelid = t.typrelid
- where t.typname = %s
- and a.attnum > 0
- order by a.attnum
-) x
-group by name, oid, array_oid
+ t.typname as name,
+ t.oid as oid,
+ t.typarray as array_oid,
+ coalesce(a.fields, '{}') as fields
+from pg_type t
+left join (
+ select attrelid, array_agg(field) as fields
+ from (
+ select attrelid, row(attname, atttypid) field
+ from pg_attribute a
+ join pg_type t on t.typrelid = a.attrelid
+ where t.typname = %(name)s
+ and a.attnum > 0
+ and not a.attisdropped
+ order by a.attnum
+ ) x
+ group by attrelid
+) a on a.attrelid = t.typrelid
+where t.typname = %(name)s
"""
+@Adapter.text(tuple)
+class TextTupleAdapter(Adapter):
+ def __init__(self, src: type, context: AdaptContext = None):
+ super().__init__(src, context)
+ self._tx = Transformer(context)
+
+ def adapt(self, obj: Tuple[Any, ...]) -> Tuple[bytes, int]:
+ if not obj:
+ return b"()", TEXT_OID
+
+ parts = [b"("]
+
+ for item in obj:
+ if item is None:
+ parts.append(b",")
+ continue
+
+ ad = self._tx.adapt(item)
+ if isinstance(ad, tuple):
+ ad = ad[0]
+ if ad is None:
+ parts.append(b",")
+ continue
+
+ if self._re_needs_quotes.search(ad) is not None:
+ ad = b'"' + self._re_escape.sub(br"\1\1", ad) + b'"'
+
+ parts.append(ad)
+ parts.append(b",")
+
+ parts[-1] = b")"
+
+ return b"".join(parts), TEXT_OID
+
+ _re_needs_quotes = re.compile(
+ br"""(?xi)
+ ^$ # the empty string
+ | [",\\\s] # or a char to escape
+ """
+ )
+ _re_escape = re.compile(br"([\"])")
+
+
class BaseCompositeCaster(TypeCaster):
def __init__(self, oid: int, context: AdaptContext = None):
super().__init__(oid, context)
from psycopg3.types import builtins, composite
-@pytest.mark.parametrize(
- "rec, want",
- [
- ("", ()),
- # Funnily enough there's no way to represent (None,) in Postgres
- ("null", ()),
- ("null,null", (None, None)),
- ("null, ''", (None, "")),
- (
- "42,'foo','ba,r','ba''z','qu\"x'",
- ("42", "foo", "ba,r", "ba'z", 'qu"x'),
- ),
- (
- "'foo''', '''foo', '\"bar', 'bar\"' ",
- ("foo'", "'foo", '"bar', 'bar"'),
- ),
- ],
-)
+tests_str = [
+ ("", ()),
+ # Funnily enough there's no way to represent (None,) in Postgres
+ ("null", ()),
+ ("null,null", (None, None)),
+ ("null, ''", (None, "")),
+ (
+ "42,'foo','ba,r','ba''z','qu\"x'",
+ ("42", "foo", "ba,r", "ba'z", 'qu"x'),
+ ),
+ ("'foo''', '''foo', '\"bar', 'bar\"' ", ("foo'", "'foo", '"bar', 'bar"'),),
+]
+
+
+@pytest.mark.parametrize("rec, want", tests_str)
def test_cast_record(conn, want, rec):
cur = conn.cursor()
res = cur.execute(f"select row({rec})").fetchone()[0]
assert res == want
+@pytest.mark.parametrize("rec, obj", tests_str)
+def test_adapt_tuple(conn, rec, obj):
+ cur = conn.cursor()
+ fields = [f"f{i} text" for i in range(len(obj))]
+ cur.execute(
+ f"""
+ drop type if exists tmptype;
+ create type tmptype as ({', '.join(fields)});
+ """
+ )
+ info = composite.fetch_info(conn, "tmptype")
+ composite.register(info, context=conn)
+
+ res = cur.execute("select %s::tmptype", [obj]).fetchone()[0]
+ assert res == obj
+
+
@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
def test_cast_all_chars(conn, fmt_out):
cur = conn.cursor(binary=fmt_out == Format.BINARY)