]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add pg DOMAIN type reflection
author Thomas Stephenson <t.stephenson@cqu.edu.au>
Wed, 21 Feb 2024 20:17:01 +0000 (15:17 -0500)
committer Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 22 Mar 2024 13:27:01 +0000 (09:27 -0400)
The PostgreSQL dialect now returns :class:`_postgresql.DOMAIN` instances
when reflecting a column that has a domain as type.
Previously the domain data type was returned instead.
As part of this change, the domain reflection was improved to also
return the collation of the text types.

Fixes: #10693
Closes: #10729
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/10729
Pull-request-sha: adac164d191138265ecd64a28be91254a53a9c25

Change-Id: I8730840de2e7e9649067191430eefa086bcf5e7b
(cherry picked from commit 0b6a54811d9cf4943ba2ae4b5a0eaa718b1e848e)

doc/build/changelog/unreleased_20/10693.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/named_types.py
lib/sqlalchemy/dialects/postgresql/pg_catalog.py
test/dialect/postgresql/test_reflection.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/10693.rst b/doc/build/changelog/unreleased_20/10693.rst
new file mode 100644 (file)
index 0000000..c5044b9
--- /dev/null
@@ -0,0 +1,9 @@
+.. change::
+    :tags: postgresql, reflection
+    :tickets: 10693
+
+    The PostgreSQL dialect now returns :class:`_postgresql.DOMAIN` instances
+    when reflecting a column that has a domain as type. Previously, the domain
+    data type was returned instead. As part of this change, the domain
+    reflection was improved to also return the collation of the text types.
+    Pull request courtesy of Thomas Stephenson.
index ebb7d546db0b705d3aab7e858cb5472b0cdcb7d5..4ab3ca24d16a76268116408ce1bec29228d3e6f8 100644 (file)
@@ -2776,6 +2776,8 @@ class ReflectedDomain(ReflectedNamedType):
     """The constraints defined in the domain, if any.
     The constraint are in order of evaluation by postgresql.
     """
+    collation: Optional[str]
+    """The collation for the domain."""
 
 
 class ReflectedEnum(ReflectedNamedType):
@@ -3707,20 +3709,156 @@ class PGDialect(default.DefaultDialect):
 
         return columns.items()
 
-    def _get_columns_info(self, rows, domains, enums, schema):
-        array_type_pattern = re.compile(r"\[\]$")
-        attype_pattern = re.compile(r"\(.*\)")
-        charlen_pattern = re.compile(r"\(([\d,]+)\)")
-        args_pattern = re.compile(r"\((.*)\)")
-        args_split_pattern = re.compile(r"\s*,\s*")
-
-        def _handle_array_type(attype):
-            return (
-                # strip '[]' from integer[], etc.
-                array_type_pattern.sub("", attype),
-                attype.endswith("[]"),
+    _format_type_args_pattern = re.compile(r"\((.*)\)")
+    _format_type_args_delim = re.compile(r"\s*,\s*")
+    _format_array_spec_pattern = re.compile(r"((?:\[\])*)$")
+
+    def _reflect_type(
+        self,
+        format_type: Optional[str],
+        domains: dict[str, ReflectedDomain],
+        enums: dict[str, ReflectedEnum],
+        type_description: str,
+    ) -> sqltypes.TypeEngine[Any]:
+        """
+        Attempts to reconstruct a column type defined in ischema_names based
+        on the information available in the format_type.
+
+        If the `format_type` cannot be associated with a known `ischema_names`,
+        it is treated as a reference to a known PostgreSQL named `ENUM` or
+        `DOMAIN` type.
+        """
+        type_description = type_description or "unknown type"
+        if format_type is None:
+            util.warn(
+                "PostgreSQL format_type() returned NULL for %s"
+                % type_description
+            )
+            return sqltypes.NULLTYPE
+
+        attype_args_match = self._format_type_args_pattern.search(format_type)
+        if attype_args_match and attype_args_match.group(1):
+            attype_args = self._format_type_args_delim.split(
+                attype_args_match.group(1)
             )
+        else:
+            attype_args = ()
+
+        match_array_dim = self._format_array_spec_pattern.search(format_type)
+        # Each "[]" in array specs corresponds to an array dimension
+        array_dim = len(match_array_dim.group(1) or "") // 2
+
+        # Remove all parameters and array specs from format_type to obtain an
+        # ischema_name candidate
+        attype = self._format_type_args_pattern.sub("", format_type)
+        attype = self._format_array_spec_pattern.sub("", attype)
+
+        schema_type = self.ischema_names.get(attype.lower(), None)
+        args, kwargs = (), {}
+
+        if attype == "numeric":
+            if len(attype_args) == 2:
+                precision, scale = map(int, attype_args)
+                args = (precision, scale)
+
+        elif attype == "double precision":
+            args = (53,)
+
+        elif attype == "integer":
+            args = ()
+
+        elif attype in ("timestamp with time zone", "time with time zone"):
+            kwargs["timezone"] = True
+            if len(attype_args) == 1:
+                kwargs["precision"] = int(attype_args[0])
 
+        elif attype in (
+            "timestamp without time zone",
+            "time without time zone",
+            "time",
+        ):
+            kwargs["timezone"] = False
+            if len(attype_args) == 1:
+                kwargs["precision"] = int(attype_args[0])
+
+        elif attype == "bit varying":
+            kwargs["varying"] = True
+            if len(attype_args) == 1:
+                charlen = int(attype_args[0])
+                args = (charlen,)
+
+        elif attype.startswith("interval"):
+            schema_type = INTERVAL
+
+            field_match = re.match(r"interval (.+)", attype)
+            if field_match:
+                kwargs["fields"] = field_match.group(1)
+
+            if len(attype_args) == 1:
+                kwargs["precision"] = int(attype_args[0])
+
+        else:
+            enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+            if enum_or_domain_key in enums:
+                schema_type = ENUM
+                enum = enums[enum_or_domain_key]
+
+                args = tuple(enum["labels"])
+                kwargs["name"] = enum["name"]
+
+                if not enum["visible"]:
+                    kwargs["schema"] = enum["schema"]
+                args = tuple(enum["labels"])
+            elif enum_or_domain_key in domains:
+                schema_type = DOMAIN
+                domain = domains[enum_or_domain_key]
+
+                data_type = self._reflect_type(
+                    domain["type"],
+                    domains,
+                    enums,
+                    type_description="DOMAIN '%s'" % domain["name"],
+                )
+                args = (domain["name"], data_type)
+
+                kwargs["collation"] = domain["collation"]
+                kwargs["default"] = domain["default"]
+                kwargs["not_null"] = not domain["nullable"]
+                kwargs["create_type"] = False
+
+                if domain["constraints"]:
+                    # We only support a single constraint
+                    check_constraint = domain["constraints"][0]
+
+                    kwargs["constraint_name"] = check_constraint["name"]
+                    kwargs["check"] = check_constraint["check"]
+
+                if not domain["visible"]:
+                    kwargs["schema"] = domain["schema"]
+
+            else:
+                try:
+                    charlen = int(attype_args[0])
+                    args = (charlen, *attype_args[1:])
+                except (ValueError, IndexError):
+                    args = attype_args
+
+        if not schema_type:
+            util.warn(
+                "Did not recognize type '%s' of %s"
+                % (attype, type_description)
+            )
+            return sqltypes.NULLTYPE
+
+        data_type = schema_type(*args, **kwargs)
+        if array_dim >= 1:
+            # postgres does not preserve dimensionality or size of array types.
+            data_type = _array.ARRAY(data_type)
+
+        return data_type
+
+    def _get_columns_info(self, rows, domains, enums, schema):
         columns = defaultdict(list)
         for row_dict in rows:
             # ensure that each table has an entry, even if it has no columns
@@ -3731,131 +3869,28 @@ class PGDialect(default.DefaultDialect):
                 continue
             table_cols = columns[(schema, row_dict["table_name"])]
 
-            format_type = row_dict["format_type"]
+            coltype = self._reflect_type(
+                row_dict["format_type"],
+                domains,
+                enums,
+                type_description="column '%s'" % row_dict["name"],
+            )
+
             default = row_dict["default"]
             name = row_dict["name"]
             generated = row_dict["generated"]
-            identity = row_dict["identity_options"]
-
-            if format_type is None:
-                no_format_type = True
-                attype = format_type = "no format_type()"
-                is_array = False
-            else:
-                no_format_type = False
-
-                # strip (*) from character varying(5), timestamp(5)
-                # with time zone, geometry(POLYGON), etc.
-                attype = attype_pattern.sub("", format_type)
-
-                # strip '[]' from integer[], etc. and check if an array
-                attype, is_array = _handle_array_type(attype)
-
-            # strip quotes from case sensitive enum or domain names
-            enum_or_domain_key = tuple(util.quoted_token_parser(attype))
-
             nullable = not row_dict["not_null"]
 
-            charlen = charlen_pattern.search(format_type)
-            if charlen:
-                charlen = charlen.group(1)
-            args = args_pattern.search(format_type)
-            if args and args.group(1):
-                args = tuple(args_split_pattern.split(args.group(1)))
-            else:
-                args = ()
-            kwargs = {}
+            if isinstance(coltype, DOMAIN):
+                if not default:
+                    # domain can override the default value but
+                    # can't set it to None
+                    if coltype.default is not None:
+                        default = coltype.default
 
-            if attype == "numeric":
-                if charlen:
-                    prec, scale = charlen.split(",")
-                    args = (int(prec), int(scale))
-                else:
-                    args = ()
-            elif attype == "double precision":
-                args = (53,)
-            elif attype == "integer":
-                args = ()
-            elif attype in ("timestamp with time zone", "time with time zone"):
-                kwargs["timezone"] = True
-                if charlen:
-                    kwargs["precision"] = int(charlen)
-                args = ()
-            elif attype in (
-                "timestamp without time zone",
-                "time without time zone",
-                "time",
-            ):
-                kwargs["timezone"] = False
-                if charlen:
-                    kwargs["precision"] = int(charlen)
-                args = ()
-            elif attype == "bit varying":
-                kwargs["varying"] = True
-                if charlen:
-                    args = (int(charlen),)
-                else:
-                    args = ()
-            elif attype.startswith("interval"):
-                field_match = re.match(r"interval (.+)", attype, re.I)
-                if charlen:
-                    kwargs["precision"] = int(charlen)
-                if field_match:
-                    kwargs["fields"] = field_match.group(1)
-                attype = "interval"
-                args = ()
-            elif charlen:
-                args = (int(charlen),)
-
-            while True:
-                # looping here to suit nested domains
-                if attype in self.ischema_names:
-                    coltype = self.ischema_names[attype]
-                    break
-                elif enum_or_domain_key in enums:
-                    enum = enums[enum_or_domain_key]
-                    coltype = ENUM
-                    kwargs["name"] = enum["name"]
-                    if not enum["visible"]:
-                        kwargs["schema"] = enum["schema"]
-                    args = tuple(enum["labels"])
-                    break
-                elif enum_or_domain_key in domains:
-                    domain = domains[enum_or_domain_key]
-                    attype = domain["type"]
-                    attype, is_array = _handle_array_type(attype)
-                    # strip quotes from case sensitive enum or domain names
-                    enum_or_domain_key = tuple(
-                        util.quoted_token_parser(attype)
-                    )
-                    # A table can't override a not null on the domain,
-                    # but can override nullable
-                    nullable = nullable and domain["nullable"]
-                    if domain["default"] and not default:
-                        # It can, however, override the default
-                        # value, but can't set it to null.
-                        default = domain["default"]
-                    continue
-                else:
-                    coltype = None
-                    break
-
-            if coltype:
-                coltype = coltype(*args, **kwargs)
-                if is_array:
-                    coltype = self.ischema_names["_array"](coltype)
-            elif no_format_type:
-                util.warn(
-                    "PostgreSQL format_type() returned NULL for column '%s'"
-                    % (name,)
-                )
-                coltype = sqltypes.NULLTYPE
-            else:
-                util.warn(
-                    "Did not recognize type '%s' of column '%s'"
-                    % (attype, name)
-                )
-                coltype = sqltypes.NULLTYPE
+                nullable = nullable and not coltype.not_null
+
+            identity = row_dict["identity_options"]
 
             # If a zero byte or blank string depending on driver (is also
             # absent for older PG versions), then not a generated column.
@@ -4904,12 +4939,18 @@ class PGDialect(default.DefaultDialect):
                 pg_catalog.pg_namespace.c.nspname.label("schema"),
                 con_sq.c.condefs,
                 con_sq.c.connames,
+                pg_catalog.pg_collation.c.collname,
             )
             .join(
                 pg_catalog.pg_namespace,
                 pg_catalog.pg_namespace.c.oid
                 == pg_catalog.pg_type.c.typnamespace,
             )
+            .outerjoin(
+                pg_catalog.pg_collation,
+                pg_catalog.pg_type.c.typcollation
+                == pg_catalog.pg_collation.c.oid,
+            )
             .outerjoin(
                 con_sq,
                 pg_catalog.pg_type.c.oid == con_sq.c.contypid,
@@ -4923,14 +4964,13 @@ class PGDialect(default.DefaultDialect):
 
     @reflection.cache
     def _load_domains(self, connection, schema=None, **kw):
-        # Load data types for domains:
         result = connection.execute(self._domain_query(schema))
 
-        domains = []
+        domains: List[ReflectedDomain] = []
         for domain in result.mappings():
             # strip (30) from character varying(30)
             attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
-            constraints = []
+            constraints: List[ReflectedDomainConstraint] = []
             if domain["connames"]:
                 # When a domain has multiple CHECK constraints, they will
                 # be tested in alphabetical order by name.
@@ -4944,7 +4984,7 @@ class PGDialect(default.DefaultDialect):
                     check = def_[7:-1]
                     constraints.append({"name": name, "check": check})
 
-            domain_rec = {
+            domain_rec: ReflectedDomain = {
                 "name": domain["name"],
                 "schema": domain["schema"],
                 "visible": domain["visible"],
@@ -4952,6 +4992,7 @@ class PGDialect(default.DefaultDialect):
                 "nullable": domain["nullable"],
                 "default": domain["default"],
                 "constraints": constraints,
+                "collation": domain["collname"],
             }
             domains.append(domain_rec)
 
index 56bec1dc732c3d997cbffb914167919735acc187..16e5c867efc26bb4998db2c92c40aa4d5011a261 100644 (file)
@@ -416,10 +416,10 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
         data_type: _TypeEngineArgument[Any],
         *,
         collation: Optional[str] = None,
-        default: Optional[Union[str, elements.TextClause]] = None,
+        default: Union[elements.TextClause, str, None] = None,
         constraint_name: Optional[str] = None,
         not_null: Optional[bool] = None,
-        check: Optional[str] = None,
+        check: Union[elements.TextClause, str, None] = None,
         create_type: bool = True,
         **kw: Any,
     ):
@@ -463,7 +463,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
         self.default = default
         self.collation = collation
         self.constraint_name = constraint_name
-        self.not_null = not_null
+        self.not_null = bool(not_null)
         if check is not None:
             check = coercions.expect(roles.DDLExpressionRole, check)
         self.check = check
@@ -474,6 +474,20 @@ class DOMAIN(NamedType, sqltypes.SchemaType):
     def __test_init__(cls):
         return cls("name", sqltypes.Integer)
 
+    def adapt(self, impl, **kw):
+        if self.default:
+            kw["default"] = self.default
+        if self.constraint_name is not None:
+            kw["constraint_name"] = self.constraint_name
+        if self.not_null:
+            kw["not_null"] = self.not_null
+        if self.check is not None:
+            kw["check"] = str(self.check)
+        if self.create_type:
+            kw["create_type"] = self.create_type
+
+        return super().adapt(impl, **kw)
+
 
 class CreateEnumType(schema._CreateDropBase):
     __visit_name__ = "create_enum_type"
index 7b44bc93f7bd02f7e79ffea21e87a72ddf3f43ae..9b5562c13fcaf777633d5b79bfab410da261b8e5 100644 (file)
@@ -77,7 +77,7 @@ RELKINDS_MAT_VIEW = ("m",)
 RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
 
 # tables
-pg_catalog_meta = MetaData()
+pg_catalog_meta = MetaData(schema="pg_catalog")
 
 pg_namespace = Table(
     "pg_namespace",
@@ -85,7 +85,6 @@ pg_namespace = Table(
     Column("oid", OID),
     Column("nspname", NAME),
     Column("nspowner", OID),
-    schema="pg_catalog",
 )
 
 pg_class = Table(
@@ -120,7 +119,6 @@ pg_class = Table(
     Column("relispartition", Boolean, info={"server_version": (10,)}),
     Column("relrewrite", OID, info={"server_version": (11,)}),
     Column("reloptions", ARRAY(Text)),
-    schema="pg_catalog",
 )
 
 pg_type = Table(
@@ -155,7 +153,6 @@ pg_type = Table(
     Column("typndims", Integer),
     Column("typcollation", OID, info={"server_version": (9, 1)}),
     Column("typdefault", Text),
-    schema="pg_catalog",
 )
 
 pg_index = Table(
@@ -182,7 +179,6 @@ pg_index = Table(
     Column("indoption", INT2VECTOR),
     Column("indexprs", PG_NODE_TREE),
     Column("indpred", PG_NODE_TREE),
-    schema="pg_catalog",
 )
 
 pg_attribute = Table(
@@ -209,7 +205,6 @@ pg_attribute = Table(
     Column("attislocal", Boolean),
     Column("attinhcount", Integer),
     Column("attcollation", OID, info={"server_version": (9, 1)}),
-    schema="pg_catalog",
 )
 
 pg_constraint = Table(
@@ -235,7 +230,6 @@ pg_constraint = Table(
     Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
     Column("conkey", ARRAY(SmallInteger)),
     Column("confkey", ARRAY(SmallInteger)),
-    schema="pg_catalog",
 )
 
 pg_sequence = Table(
@@ -249,7 +243,6 @@ pg_sequence = Table(
     Column("seqmin", BigInteger),
     Column("seqcache", BigInteger),
     Column("seqcycle", Boolean),
-    schema="pg_catalog",
     info={"server_version": (10,)},
 )
 
@@ -260,7 +253,6 @@ pg_attrdef = Table(
     Column("adrelid", OID),
     Column("adnum", SmallInteger),
     Column("adbin", PG_NODE_TREE),
-    schema="pg_catalog",
 )
 
 pg_description = Table(
@@ -270,7 +262,6 @@ pg_description = Table(
     Column("classoid", OID),
     Column("objsubid", Integer),
     Column("description", Text(collation="C")),
-    schema="pg_catalog",
 )
 
 pg_enum = Table(
@@ -280,7 +271,6 @@ pg_enum = Table(
     Column("enumtypid", OID),
     Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
     Column("enumlabel", NAME),
-    schema="pg_catalog",
 )
 
 pg_am = Table(
@@ -290,5 +280,21 @@ pg_am = Table(
     Column("amname", NAME),
     Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
     Column("amtype", CHAR, info={"server_version": (9, 6)}),
-    schema="pg_catalog",
+)
+
+pg_collation = Table(
+    "pg_collation",
+    pg_catalog_meta,
+    Column("oid", OID, info={"server_version": (9, 3)}),
+    Column("collname", NAME),
+    Column("collnamespace", OID),
+    Column("collowner", OID),
+    Column("collprovider", CHAR, info={"server_version": (10,)}),
+    Column("collisdeterministic", Boolean, info={"server_version": (12,)}),
+    Column("collencoding", Integer),
+    Column("collcollate", Text),
+    Column("collctype", Text),
+    Column("colliculocale", Text),
+    Column("collicurules", Text, info={"server_version": (16,)}),
+    Column("collversion", Text, info={"server_version": (10,)}),
 )
index dd6c8aa88ee35567b0a5fa9e1560234322456311..3d29a89de7b9b9b6a309a0550d6131da92281276 100644 (file)
@@ -23,6 +23,7 @@ from sqlalchemy import Text
 from sqlalchemy import UniqueConstraint
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.dialects.postgresql import base as postgresql
+from sqlalchemy.dialects.postgresql import DOMAIN
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import INTEGER
 from sqlalchemy.dialects.postgresql import INTERVAL
@@ -408,25 +409,24 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
     def setup_test_class(cls):
         with testing.db.begin() as con:
             for ddl in [
-                'CREATE SCHEMA "SomeSchema"',
+                'CREATE SCHEMA IF NOT EXISTS "SomeSchema"',
                 "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42",
                 "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0",
                 "CREATE TYPE testtype AS ENUM ('test')",
                 "CREATE DOMAIN enumdomain AS testtype",
                 "CREATE DOMAIN arraydomain AS INTEGER[]",
+                "CREATE DOMAIN arraydomain_2d AS INTEGER[][]",
+                "CREATE DOMAIN arraydomain_3d AS  INTEGER[][][]",
                 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0',
-                "CREATE DOMAIN nullable_domain AS TEXT CHECK "
+                'CREATE DOMAIN nullable_domain AS TEXT COLLATE "C" CHECK '
                 "(VALUE IN('FOO', 'BAR'))",
                 "CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL",
                 "CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK "
                 "(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) "
                 "CHECK(VALUE != 22)",
             ]:
-                try:
-                    con.exec_driver_sql(ddl)
-                except exc.DBAPIError as e:
-                    if "already exists" not in str(e):
-                        raise e
+                con.exec_driver_sql(ddl)
+
             con.exec_driver_sql(
                 "CREATE TABLE testtable (question integer, answer "
                 "testdomain)"
@@ -446,7 +446,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             )
 
             con.exec_driver_sql(
-                "CREATE TABLE array_test (id integer, data arraydomain)"
+                "CREATE TABLE array_test ("
+                "id integer, "
+                "datas arraydomain, "
+                "datass arraydomain_2d, "
+                "datasss arraydomain_3d"
+                ")"
             )
 
             con.exec_driver_sql(
@@ -473,6 +478,8 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             con.exec_driver_sql("DROP TYPE testtype")
             con.exec_driver_sql("DROP TABLE array_test")
             con.exec_driver_sql("DROP DOMAIN arraydomain")
+            con.exec_driver_sql("DROP DOMAIN arraydomain_2d")
+            con.exec_driver_sql("DROP DOMAIN arraydomain_3d")
             con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
             con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
 
@@ -489,7 +496,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             {"question", "answer"},
             "Columns of reflected table didn't equal expected columns",
         )
-        assert isinstance(table.c.answer.type, Integer)
+        assert isinstance(table.c.answer.type, DOMAIN)
+        assert table.c.answer.type.name, "testdomain"
+        assert isinstance(table.c.answer.type.data_type, Integer)
 
     def test_nullable_from_domain(self, connection):
         metadata = MetaData()
@@ -514,18 +523,36 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
     def test_enum_domain_is_reflected(self, connection):
         metadata = MetaData()
         table = Table("enum_test", metadata, autoload_with=connection)
-        eq_(table.c.data.type.enums, ["test"])
+        assert isinstance(table.c.data.type, DOMAIN)
+        eq_(table.c.data.type.data_type.enums, ["test"])
 
     def test_array_domain_is_reflected(self, connection):
         metadata = MetaData()
         table = Table("array_test", metadata, autoload_with=connection)
-        eq_(table.c.data.type.__class__, ARRAY)
-        eq_(table.c.data.type.item_type.__class__, INTEGER)
+
+        def assert_is_integer_array_domain(domain, name):
+            # Postgres does not persist the dimensionality of the array.
+            # It's always treated as integer[]
+            assert isinstance(domain, DOMAIN)
+            assert domain.name == name
+            assert isinstance(domain.data_type, ARRAY)
+            assert isinstance(domain.data_type.item_type, INTEGER)
+
+        array_domain = table.c.datas.type
+        assert_is_integer_array_domain(array_domain, "arraydomain")
+
+        array_domain_2d = table.c.datass.type
+        assert_is_integer_array_domain(array_domain_2d, "arraydomain_2d")
+
+        array_domain_3d = table.c.datasss.type
+        assert_is_integer_array_domain(array_domain_3d, "arraydomain_3d")
 
     def test_quoted_remote_schema_domain_is_reflected(self, connection):
         metadata = MetaData()
         table = Table("quote_test", metadata, autoload_with=connection)
-        eq_(table.c.data.type.__class__, INTEGER)
+        assert isinstance(table.c.data.type, DOMAIN)
+        assert table.c.data.type.name, "Quoted.Domain"
+        assert isinstance(table.c.data.type.data_type, Integer)
 
     def test_table_is_reflected_test_schema(self, connection):
         metadata = MetaData()
@@ -603,6 +630,27 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                     "type": "integer[]",
                     "default": None,
                     "constraints": [],
+                    "collation": None,
+                },
+                {
+                    "visible": True,
+                    "name": "arraydomain_2d",
+                    "schema": "public",
+                    "nullable": True,
+                    "type": "integer[]",
+                    "default": None,
+                    "constraints": [],
+                    "collation": None,
+                },
+                {
+                    "visible": True,
+                    "name": "arraydomain_3d",
+                    "schema": "public",
+                    "nullable": True,
+                    "type": "integer[]",
+                    "default": None,
+                    "constraints": [],
+                    "collation": None,
                 },
                 {
                     "visible": True,
@@ -612,6 +660,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                     "type": "testtype",
                     "default": None,
                     "constraints": [],
+                    "collation": None,
                 },
                 {
                     "visible": True,
@@ -626,6 +675,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                         # autogenerated name by pg
                         {"check": "VALUE <> 22", "name": "my_int_check"},
                     ],
+                    "collation": None,
                 },
                 {
                     "visible": True,
@@ -635,6 +685,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                     "type": "text",
                     "default": None,
                     "constraints": [],
+                    "collation": "default",
                 },
                 {
                     "visible": True,
@@ -651,6 +702,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                             "name": "nullable_domain_check",
                         }
                     ],
+                    "collation": "C",
                 },
                 {
                     "visible": True,
@@ -660,6 +712,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                     "type": "integer",
                     "default": "42",
                     "constraints": [],
+                    "collation": None,
                 },
             ],
             "test_schema": [
@@ -671,6 +724,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                     "type": "integer",
                     "default": "0",
                     "constraints": [],
+                    "collation": None,
                 }
             ],
             "SomeSchema": [
@@ -682,13 +736,20 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                     "type": "integer",
                     "default": "0",
                     "constraints": [],
+                    "collation": None,
                 }
             ],
         }
 
     def test_inspect_domains(self, connection):
         inspector = inspect(connection)
-        eq_(inspector.get_domains(), self.all_domains["public"])
+        domains = inspector.get_domains()
+
+        domain_names = {d["name"] for d in domains}
+        expect_domain_names = {d["name"] for d in self.all_domains["public"]}
+        eq_(domain_names, expect_domain_names)
+
+        eq_(domains, self.all_domains["public"])
 
     def test_inspect_domains_schema(self, connection):
         inspector = inspect(connection)
@@ -705,7 +766,38 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
         all_ = [d for dl in self.all_domains.values() for d in dl]
         all_ += inspector.get_domains("information_schema")
         exp = sorted(all_, key=lambda d: (d["schema"], d["name"]))
-        eq_(inspector.get_domains("*"), exp)
+        domains = inspector.get_domains("*")
+
+        eq_(domains, exp)
+
+
+class ArrayReflectionTest(fixtures.TablesTest):
+    __only_on__ = "postgresql >= 10"
+    __backend__ = True
+
+    @classmethod
+    def define_tables(cls, metadata):
+        Table(
+            "array_table",
+            metadata,
+            Column("id", INTEGER, primary_key=True),
+            Column("datas", ARRAY(INTEGER)),
+            Column("datass", ARRAY(INTEGER, dimensions=2)),
+            Column("datasss", ARRAY(INTEGER, dimensions=3)),
+        )
+
+    def test_array_table_is_reflected(self, connection):
+        metadata = MetaData()
+        table = Table("array_table", metadata, autoload_with=connection)
+
+        def assert_is_integer_array(data_type):
+            assert isinstance(data_type, ARRAY)
+            # postgres treats all arrays as one-dimensional arrays
+            assert isinstance(data_type.item_type, INTEGER)
+
+        assert_is_integer_array(table.c.datas.type)
+        assert_is_integer_array(table.c.datass.type)
+        assert_is_integer_array(table.c.datasss.type)
 
 
 class ReflectionTest(
index 08479b445f5133e1236d86e5fec413cf974b76b5..65c5fdbf7f64c5e2fde37b6a109a8db59de5b636 100644 (file)
@@ -73,6 +73,7 @@ from sqlalchemy.dialects.postgresql import TSMULTIRANGE
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
 from sqlalchemy.dialects.postgresql import TSTZRANGE
+from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.dialects.postgresql.ranges import MultiRange
 from sqlalchemy.exc import CompileError
 from sqlalchemy.exc import DBAPIError
@@ -531,6 +532,7 @@ class NamedTypeTest(
                                 "check": r"VALUE ~ '[^@]+@[^@]+\.[^@]+'::text",
                             }
                         ],
+                        "collation": "default",
                     }
                 ],
             )
@@ -1075,7 +1077,7 @@ class NamedTypeTest(
                 connection, "fourfivesixtype"
             )
 
-    def test_reflection(self, metadata, connection):
+    def test_enum_type_reflection(self, metadata, connection):
         etype = Enum(
             "four", "five", "six", name="fourfivesixtype", metadata=metadata
         )
@@ -1229,6 +1231,212 @@ class NamedTypeTest(
         ]
 
 
+class DomainTest(
+    AssertsCompiledSQL, fixtures.TestBase, AssertsExecutionResults
+):
+    __backend__ = True
+    __only_on__ = "postgresql > 8.3"
+
+    def test_domain_type_reflection(self, metadata, connection):
+        positive_int = DOMAIN(
+            "positive_int", Integer(), check="value > 0", not_null=True
+        )
+        my_str = DOMAIN("my_string", Text(), collation="C", default="~~")
+        Table(
+            "table",
+            metadata,
+            Column("value", positive_int),
+            Column("str", my_str),
+        )
+
+        metadata.create_all(connection)
+        m2 = MetaData()
+        t2 = Table("table", m2, autoload_with=connection)
+
+        vt = t2.c.value.type
+        is_true(isinstance(vt, DOMAIN))
+        is_true(isinstance(vt.data_type, Integer))
+        eq_(vt.name, "positive_int")
+        eq_(str(vt.check), "VALUE > 0")
+        is_(vt.default, None)
+        is_(vt.collation, None)
+        is_true(vt.constraint_name is not None)
+        is_true(vt.not_null)
+        is_false(vt.create_type)
+
+        st = t2.c.str.type
+        is_true(isinstance(st, DOMAIN))
+        is_true(isinstance(st.data_type, Text))
+        eq_(st.name, "my_string")
+        is_(st.check, None)
+        is_true("~~" in st.default)
+        eq_(st.collation, "C")
+        is_(st.constraint_name, None)
+        is_false(st.not_null)
+        is_false(st.create_type)
+
+    def test_domain_create_table(self, metadata, connection):
+        metadata = self.metadata
+        Email = DOMAIN(
+            name="email",
+            data_type=Text,
+            check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+        )
+        PosInt = DOMAIN(
+            name="pos_int",
+            data_type=Integer,
+            not_null=True,
+            check=r"VALUE > 0",
+        )
+        t1 = Table(
+            "table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("email", Email),
+            Column("number", PosInt),
+        )
+        t1.create(connection)
+        t1.create(connection, checkfirst=True)  # no-op: already created
+        connection.execute(
+            t1.insert(), {"email": "test@example.com", "number": 42}
+        )
+        connection.execute(t1.insert(), {"email": "a@b.c", "number": 1})
+        connection.execute(
+            t1.insert(), {"email": "example@gmail.co.uk", "number": 99}
+        )
+        eq_(
+            connection.execute(t1.select().order_by(t1.c.id)).fetchall(),
+            [
+                (1, "test@example.com", 42),
+                (2, "a@b.c", 1),
+                (3, "example@gmail.co.uk", 99),
+            ],
+        )
+
+    @testing.combinations(
+        tuple(
+            [
+                DOMAIN(
+                    name="mytype",
+                    data_type=Text,
+                    check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+                    create_type=True,
+                ),
+            ]
+        ),
+        tuple(
+            [
+                DOMAIN(
+                    name="mytype",
+                    data_type=Text,
+                    check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+                    create_type=False,
+                ),
+            ]
+        ),
+        argnames="domain",
+    )
+    def test_create_drop_domain_with_table(self, connection, metadata, domain):
+        table = Table("e1", metadata, Column("e1", domain))
+
+        def _domain_names():
+            return {d["name"] for d in inspect(connection).get_domains()}
+
+        assert "mytype" not in _domain_names()
+
+        if domain.create_type:
+            table.create(connection)
+            assert "mytype" in _domain_names()
+        else:
+            with expect_raises(exc.ProgrammingError):
+                table.create(connection)
+            connection.rollback()
+
+            domain.create(connection)
+            assert "mytype" in _domain_names()
+            table.create(connection)
+
+        table.drop(connection)
+        if domain.create_type:
+            assert "mytype" not in _domain_names()
+
+    @testing.combinations(
+        (Integer, "value > 0", 4),
+        (String, "value != ''", "hello world"),
+        (
+            UUID,
+            "value != '{00000000-0000-0000-0000-000000000000}'",
+            uuid.uuid4(),
+        ),
+        (
+            DateTime,
+            "value >= '2020-01-01T00:00:00'",
+            datetime.datetime.fromisoformat("2021-01-01T00:00:00.000"),
+        ),
+        argnames="domain_datatype, domain_check, value",
+    )
+    def test_domain_roundtrip(
+        self, metadata, connection, domain_datatype, domain_check, value
+    ):
+        table = Table(
+            "domain_roundtrip_test",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column(
+                "value",
+                DOMAIN("valuedomain", domain_datatype, check=domain_check),
+            ),
+        )
+        table.create(connection)
+
+        connection.execute(table.insert(), {"value": value})
+
+        results = connection.execute(
+            table.select().order_by(table.c.id)
+        ).fetchall()
+        eq_(results, [(1, value)])
+
+    @testing.combinations(
+        (DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True), 4, -4),
+        (
+            DOMAIN("email", String, check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'"),
+            "e@xample.com",
+            "fred",
+        ),
+        argnames="domain,pass_value,fail_value",
+    )
+    def test_check_constraint(
+        self, metadata, connection, domain, pass_value, fail_value
+    ):
+        table = Table("table", metadata, Column("value", domain))
+        table.create(connection)
+
+        connection.execute(table.insert(), {"value": pass_value})
+
+        # psycopg/psycopg2 raise IntegrityError, while pg8000 raises
+        # ProgrammingError
+        with expect_raises(exc.DatabaseError):
+            connection.execute(table.insert(), {"value": fail_value})
+
+    @testing.combinations(
+        (DOMAIN("nullable_domain", Integer, not_null=True), 1),
+        (DOMAIN("non_nullable_domain", Integer, not_null=False), 1),
+        argnames="domain,pass_value",
+    )
+    def test_domain_nullable(self, metadata, connection, domain, pass_value):
+        table = Table("table", metadata, Column("value", domain))
+        table.create(connection)
+        connection.execute(table.insert(), {"value": pass_value})
+
+        if domain.not_null:
+            # psycopg/psycopg2 raise IntegrityError, while pg8000 raises
+            # ProgrammingError
+            with expect_raises(exc.DatabaseError):
+                connection.execute(table.insert(), {"value": None})
+        else:
+            connection.execute(table.insert(), {"value": None})
+
+
 class DomainDDLEventTest(DDLEventWCreateHarness, fixtures.TestBase):
     __backend__ = True
 
@@ -1557,6 +1765,10 @@ class TimePrecisionTest(fixtures.TestBase):
         t1.create(connection)
         m2 = MetaData()
         t2 = Table("t1", m2, autoload_with=connection)
+
+        eq_(t1.c.c1.type.__class__, postgresql.TIME)
+        eq_(t1.c.c4.type.__class__, postgresql.TIMESTAMP)
+
         eq_(t2.c.c1.type.precision, None)
         eq_(t2.c.c2.type.precision, 5)
         eq_(t2.c.c3.type.precision, 5)