]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Domain type
authorDavid Baumgold <david@davidbaumgold.com>
Fri, 11 Feb 2022 17:30:24 +0000 (12:30 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 21 Jun 2022 14:17:40 +0000 (10:17 -0400)
Added a new Postgresql :class:`_postgresql.DOMAIN` datatype, which follows
the same CREATE TYPE / DROP TYPE behaviors as that of PostgreSQL
:class:`_postgresql.ENUM`. Much thanks to David Baumgold for the efforts on
this.

Fixes: #7316
Closes: #7317
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/7317
Pull-request-sha: bc9a82f010e6ca2f70a6e8a7620b748e483c26c3

Change-Id: Id8d7e48843a896de17d20cc466b115b3cc065132

12 files changed:
doc/build/changelog/unreleased_20/7316.rst [new file with mode: 0644]
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/named_types.py [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/types.py
lib/sqlalchemy/sql/compiler.py
test/aaa_profiling/test_memusage.py
test/dialect/postgresql/test_compiler.py
test/dialect/postgresql/test_reflection.py
test/dialect/postgresql/test_types.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/7316.rst b/doc/build/changelog/unreleased_20/7316.rst
new file mode 100644 (file)
index 0000000..817d994
--- /dev/null
@@ -0,0 +1,21 @@
+.. change::
+    :tags: feature, postgresql
+    :tickets: 7316
+
+    Added a new PostgreSQL :class:`_postgresql.DOMAIN` datatype, which follows
+    the same CREATE TYPE / DROP TYPE behaviors as that of PostgreSQL
+    :class:`_postgresql.ENUM`. Much thanks to David Baumgold for the efforts on
+    this.
+
+    .. seealso::
+
+        :class:`_postgresql.DOMAIN`
+
+.. change::
+    :tags: change, postgresql
+
+    The :paramref:`_postgresql.ENUM.name` parameter for the PostgreSQL-specific
+    :class:`_postgresql.ENUM` datatype is now a required keyword argument. The
+    "name" is necessary in any case in order for the :class:`_postgresql.ENUM`
+    to be usable as an error would be raised at SQL/DDL render time if "name"
+    were not present.
\ No newline at end of file
index 81f0a0c4e13c38b35d543a2f1bdc71676594bb7e..ea0c2aa42a034c533019f7da3f4c5a20caf7469b 100644 (file)
@@ -49,6 +49,9 @@ construction arguments, are as follows:
 .. autoclass:: CIDR
 
 
+.. autoclass:: DOMAIN
+    :members: __init__, create, drop
+
 .. autoclass:: DOUBLE_PRECISION
     :members: __init__
     :noindex:
index 85bbf8c5bb4fff5584b6e988c6f6c31a763460b0..62195f59e6b7d432fdcaefca74746f6f48cf9e40 100644 (file)
@@ -22,6 +22,7 @@ from .base import BIGINT
 from .base import BOOLEAN
 from .base import CHAR
 from .base import DATE
+from .base import DOMAIN
 from .base import DOUBLE_PRECISION
 from .base import FLOAT
 from .base import INTEGER
@@ -40,6 +41,12 @@ from .hstore import HSTORE
 from .hstore import hstore
 from .json import JSON
 from .json import JSONB
+from .named_types import CreateDomainType
+from .named_types import CreateEnumType
+from .named_types import DropDomainType
+from .named_types import DropEnumType
+from .named_types import ENUM
+from .named_types import NamedType
 from .ranges import DATERANGE
 from .ranges import INT4RANGE
 from .ranges import INT8RANGE
@@ -49,9 +56,6 @@ from .ranges import TSTZRANGE
 from .types import BIT
 from .types import BYTEA
 from .types import CIDR
-from .types import CreateEnumType
-from .types import DropEnumType
-from .types import ENUM
 from .types import INET
 from .types import INTERVAL
 from .types import MACADDR
@@ -97,6 +101,7 @@ __all__ = (
     "INTERVAL",
     "ARRAY",
     "ENUM",
+    "DOMAIN",
     "dialect",
     "array",
     "HSTORE",
@@ -113,6 +118,9 @@ __all__ = (
     "Any",
     "All",
     "DropEnumType",
+    "DropDomainType",
+    "CreateDomainType",
+    "NamedType",
     "CreateEnumType",
     "ExcludeConstraint",
     "aggregate_order_by",
index 8402341f6409c69ee36f15ade53462b7e00cb09f..8fc24c933e1fbd27afed39def788e1f4fbb0f1ce 100644 (file)
@@ -1450,6 +1450,9 @@ from __future__ import annotations
 from collections import defaultdict
 from functools import lru_cache
 import re
+from typing import Any
+from typing import List
+from typing import Optional
 
 from . import array as _array
 from . import dml
@@ -1457,30 +1460,34 @@ from . import hstore as _hstore
 from . import json as _json
 from . import pg_catalog
 from . import ranges as _ranges
-from .types import _DECIMAL_TYPES  # noqa
-from .types import _FLOAT_TYPES  # noqa
-from .types import _INT_TYPES  # noqa
-from .types import BIT
-from .types import BYTEA
-from .types import CIDR
-from .types import CreateEnumType  # noqa
-from .types import DropEnumType  # noqa
-from .types import ENUM
-from .types import INET
-from .types import INTERVAL
-from .types import MACADDR
-from .types import MONEY
-from .types import OID
-from .types import PGBit  # noqa
-from .types import PGCidr  # noqa
-from .types import PGInet  # noqa
-from .types import PGInterval  # noqa
-from .types import PGMacAddr  # noqa
-from .types import PGUuid
-from .types import REGCLASS
-from .types import TIME
-from .types import TIMESTAMP
-from .types import TSVECTOR
+from .named_types import CreateDomainType as CreateDomainType  # noqa: F401
+from .named_types import CreateEnumType as CreateEnumType  # noqa: F401
+from .named_types import DOMAIN as DOMAIN  # noqa: F401
+from .named_types import DropDomainType as DropDomainType  # noqa: F401
+from .named_types import DropEnumType as DropEnumType  # noqa: F401
+from .named_types import ENUM as ENUM  # noqa: F401
+from .named_types import NamedType as NamedType  # noqa: F401
+from .types import _DECIMAL_TYPES  # noqa: F401
+from .types import _FLOAT_TYPES  # noqa: F401
+from .types import _INT_TYPES  # noqa: F401
+from .types import BIT as BIT
+from .types import BYTEA as BYTEA
+from .types import CIDR as CIDR
+from .types import INET as INET
+from .types import INTERVAL as INTERVAL
+from .types import MACADDR as MACADDR
+from .types import MONEY as MONEY
+from .types import OID as OID
+from .types import PGBit as PGBit  # noqa: F401
+from .types import PGCidr as PGCidr  # noqa: F401
+from .types import PGInet as PGInet  # noqa: F401
+from .types import PGInterval as PGInterval  # noqa: F401
+from .types import PGMacAddr as PGMacAddr  # noqa: F401
+from .types import PGUuid as PGUuid
+from .types import REGCLASS as REGCLASS
+from .types import TIME as TIME
+from .types import TIMESTAMP as TIMESTAMP
+from .types import TSVECTOR as TSVECTOR
 from ... import exc
 from ... import schema
 from ... import select
@@ -1515,6 +1522,7 @@ from ...types import SMALLINT
 from ...types import TEXT
 from ...types import UUID as UUID
 from ...types import VARCHAR
+from ...util.typing import TypedDict
 
 IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I)
 
@@ -2198,6 +2206,38 @@ class PGDDLCompiler(compiler.DDLCompiler):
 
         return "DROP TYPE %s" % (self.preparer.format_type(type_))
 
+    def visit_create_domain_type(self, create):
+        domain: DOMAIN = create.element
+
+        options = []
+        if domain.collation is not None:
+            options.append(f"COLLATE {self.preparer.quote(domain.collation)}")
+        if domain.default is not None:
+            default = self.render_default_string(domain.default)
+            options.append(f"DEFAULT {default}")
+        if domain.constraint_name is not None:
+            name = self.preparer.truncate_and_render_constraint_name(
+                domain.constraint_name
+            )
+            options.append(f"CONSTRAINT {name}")
+        if domain.not_null:
+            options.append("NOT NULL")
+        if domain.check is not None:
+            check = self.sql_compiler.process(
+                domain.check, include_table=False, literal_binds=True
+            )
+            options.append(f"CHECK ({check})")
+
+        return (
+            f"CREATE DOMAIN {self.preparer.format_type(domain)} AS "
+            f"{self.type_compiler.process(domain.data_type)} "
+            f"{' '.join(options)}"
+        )
+
+    def visit_drop_domain_type(self, drop):
+        domain = drop.element
+        return f"DROP DOMAIN {self.preparer.format_type(domain)}"
+
     def visit_create_index(self, create):
         preparer = self.preparer
         index = create.element
@@ -2470,6 +2510,11 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
             identifier_preparer = self.dialect.identifier_preparer
         return identifier_preparer.format_type(type_)
 
+    def visit_DOMAIN(self, type_, identifier_preparer=None, **kw):
+        if identifier_preparer is None:
+            identifier_preparer = self.dialect.identifier_preparer
+        return identifier_preparer.format_type(type_)
+
     def visit_TIMESTAMP(self, type_, **kw):
         return "TIMESTAMP%s %s" % (
             "(%d)" % type_.precision
@@ -2548,7 +2593,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
 
     def format_type(self, type_, use_schema=True):
         if not type_.name:
-            raise exc.CompileError("PostgreSQL ENUM type requires a name.")
+            raise exc.CompileError(
+                f"PostgreSQL {type_.__class__.__name__} type requires a name."
+            )
 
         name = self.quote(type_.name)
         effective_schema = self.schema_for_object(type_)
@@ -2558,14 +2605,60 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
             and use_schema
             and effective_schema is not None
         ):
-            name = self.quote_schema(effective_schema) + "." + name
+            name = f"{self.quote_schema(effective_schema)}.{name}"
         return name
 
 
+class ReflectedNamedType(TypedDict):
+    """Represents a reflected named type."""
+
+    name: str
+    """Name of the type."""
+    schema: str
+    """The schema of the type."""
+    visible: bool
+    """Indicates if this type is in the current search path."""
+
+
+class ReflectedDomainConstraint(TypedDict):
+    """Represents a reflect check constraint of a domain."""
+
+    name: str
+    """Name of the constraint."""
+    check: str
+    """The check constraint text."""
+
+
+class ReflectedDomain(ReflectedNamedType):
+    """Represents a reflected enum."""
+
+    type: str
+    """The string name of the underlying data type of the domain."""
+    nullable: bool
+    """Indicates if the domain allows null or not."""
+    default: Optional[str]
+    """The string representation of the default value of this domain
+    or ``None`` if none present.
+    """
+    constraints: List[ReflectedDomainConstraint]
+    """The constraints defined in the domain, if any.
+    The constraint are in order of evaluation by postgresql.
+    """
+
+
+class ReflectedEnum(ReflectedNamedType):
+    """Represents a reflected enum."""
+
+    labels: List[str]
+    """The labels that compose the enum."""
+
+
 class PGInspector(reflection.Inspector):
     dialect: PGDialect
 
-    def get_table_oid(self, table_name, schema=None):
+    def get_table_oid(
+        self, table_name: str, schema: Optional[str] = None
+    ) -> int:
         """Return the OID for the given table name.
 
         :param table_name: string name of the table.  For special quoting,
@@ -2582,7 +2675,38 @@ class PGInspector(reflection.Inspector):
                 conn, table_name, schema, info_cache=self.info_cache
             )
 
-    def get_enums(self, schema=None):
+    def get_domains(
+        self, schema: Optional[str] = None
+    ) -> List[ReflectedDomain]:
+        """Return a list of DOMAIN objects.
+
+        Each member is a dictionary containing these fields:
+
+            * name - name of the domain
+            * schema - the schema name for the domain.
+            * visible - boolean, whether or not this domain is visible
+              in the default search path.
+            * type - the type defined by this domain.
+            * nullable - Indicates if this domain can be ``NULL``.
+            * default - The default value of the domain or ``None`` if the
+              domain has no default.
+            * constraints - A list of dict wit the constraint defined by this
+              domain. Each element constaints two keys: ``name`` of the
+              constraint and ``check`` with the constraint text.
+
+        :param schema: schema name.  If None, the default schema
+         (typically 'public') is used.  May also be set to ``'*'`` to
+         indicate load domains for all schemas.
+
+        .. versionadded:: 2.0
+
+        """
+        with self._operation_context() as conn:
+            return self.dialect._load_domains(
+                conn, schema, info_cache=self.info_cache
+            )
+
+    def get_enums(self, schema: Optional[str] = None) -> List[ReflectedEnum]:
         """Return a list of ENUM objects.
 
         Each member is a dictionary containing these fields:
@@ -2594,7 +2718,7 @@ class PGInspector(reflection.Inspector):
             * labels - a list of string labels that apply to the enum.
 
         :param schema: schema name.  If None, the default schema
-         (typically 'public') is used.  May also be set to '*' to
+         (typically 'public') is used.  May also be set to ``'*'`` to
          indicate load enums for all schemas.
 
         .. versionadded:: 1.0.0
@@ -2605,7 +2729,9 @@ class PGInspector(reflection.Inspector):
                 conn, schema, info_cache=self.info_cache
             )
 
-    def get_foreign_table_names(self, schema=None):
+    def get_foreign_table_names(
+        self, schema: Optional[str] = None
+    ) -> List[str]:
         """Return a list of FOREIGN TABLE names.
 
         Behavior is similar to that of
@@ -2621,13 +2747,15 @@ class PGInspector(reflection.Inspector):
                 conn, schema, info_cache=self.info_cache
             )
 
-    def has_type(self, type_name, schema=None, **kw):
+    def has_type(
+        self, type_name: str, schema: Optional[str] = None, **kw: Any
+    ) -> bool:
         """Return if the database has the specified type in the provided
         schema.
 
         :param type_name: the type to check.
         :param schema: schema name.  If None, the default schema
-         (typically 'public') is used.  May also be set to '*' to
+         (typically 'public') is used.  May also be set to ``'*'`` to
          check in all schemas.
 
         .. versionadded:: 2.0
@@ -2941,10 +3069,12 @@ class PGDialect(default.DefaultDialect):
             pg_catalog.pg_namespace,
             pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace,
         )
+
         if scope is ObjectScope.DEFAULT:
             query = query.where(pg_class_table.c.relpersistence != "t")
         elif scope is ObjectScope.TEMPORARY:
             query = query.where(pg_class_table.c.relpersistence == "t")
+
         if schema is None:
             query = query.where(
                 pg_catalog.pg_table_is_visible(pg_class_table.c.oid),
@@ -3319,9 +3449,12 @@ class PGDialect(default.DefaultDialect):
 
         # dictionary with (name, ) if default search path or (schema, name)
         # as keys
-        domains = self._load_domains(
-            connection, info_cache=kw.get("info_cache")
-        )
+        domains = {
+            ((d["schema"], d["name"]) if not d["visible"] else (d["name"],)): d
+            for d in self._load_domains(
+                connection, schema="*", info_cache=kw.get("info_cache")
+            )
+        }
 
         # dictionary with (name, ) if default search path or (schema, name)
         # as keys
@@ -3446,7 +3579,7 @@ class PGDialect(default.DefaultDialect):
                     break
                 elif enum_or_domain_key in domains:
                     domain = domains[enum_or_domain_key]
-                    attype = domain["attype"]
+                    attype = domain["type"]
                     attype, is_array = _handle_array_type(attype)
                     # strip quotes from case sensitive enum or domain names
                     enum_or_domain_key = tuple(
@@ -3736,7 +3869,7 @@ class PGDialect(default.DefaultDialect):
 
     @util.memoized_property
     def _fk_regex_pattern(self):
-        # https://www.postgresql.org/docs/14.0/static/sql-createtable.html
+        # https://www.postgresql.org/docs/current/static/sql-createtable.html
         return re.compile(
             r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
             r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
@@ -4201,7 +4334,7 @@ class PGDialect(default.DefaultDialect):
                     (
                         pg_catalog.pg_constraint.c.oid.is_not(None),
                         pg_catalog.pg_get_constraintdef(
-                            pg_catalog.pg_constraint.c.oid
+                            pg_catalog.pg_constraint.c.oid, True
                         ),
                     ),
                     else_=None,
@@ -4265,6 +4398,17 @@ class PGDialect(default.DefaultDialect):
             check_constraints[(schema, table_name)].append(entry)
         return check_constraints.items()
 
+    def _pg_type_filter_schema(self, query, schema):
+        if schema is None:
+            query = query.where(
+                pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+                # ignore pg_catalog schema
+                pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+            )
+        elif schema != "*":
+            query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+        return query
+
     @lru_cache()
     def _enum_query(self, schema):
         lbl_sq = (
@@ -4310,15 +4454,7 @@ class PGDialect(default.DefaultDialect):
             )
         )
 
-        if schema is None:
-            query = query.where(
-                pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
-                # ignore pg_catalog schema
-                pg_catalog.pg_namespace.c.nspname != "pg_catalog",
-            )
-        elif schema != "*":
-            query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
-        return query
+        return self._pg_type_filter_schema(query, schema)
 
     @reflection.cache
     def _load_enums(self, connection, schema=None, **kw):
@@ -4339,9 +4475,27 @@ class PGDialect(default.DefaultDialect):
             )
         return enums
 
-    @util.memoized_property
-    def _domain_query(self):
-        return (
+    @lru_cache()
+    def _domain_query(self, schema):
+        con_sq = (
+            select(
+                pg_catalog.pg_constraint.c.contypid,
+                sql.func.array_agg(
+                    pg_catalog.pg_get_constraintdef(
+                        pg_catalog.pg_constraint.c.oid, True
+                    )
+                ).label("condefs"),
+                sql.func.array_agg(pg_catalog.pg_constraint.c.conname).label(
+                    "connames"
+                ),
+            )
+            # The domain this constraint is on; zero if not a domain constraint
+            .where(pg_catalog.pg_constraint.c.contypid != 0)
+            .group_by(pg_catalog.pg_constraint.c.contypid)
+            .subquery("domain_constraints")
+        )
+
+        query = (
             select(
                 pg_catalog.pg_type.c.typname.label("name"),
                 pg_catalog.format_type(
@@ -4354,38 +4508,57 @@ class PGDialect(default.DefaultDialect):
                     "visible"
                 ),
                 pg_catalog.pg_namespace.c.nspname.label("schema"),
+                con_sq.c.condefs,
+                con_sq.c.connames,
             )
             .join(
                 pg_catalog.pg_namespace,
                 pg_catalog.pg_namespace.c.oid
                 == pg_catalog.pg_type.c.typnamespace,
             )
+            .outerjoin(
+                con_sq,
+                pg_catalog.pg_type.c.oid == con_sq.c.contypid,
+            )
             .where(pg_catalog.pg_type.c.typtype == "d")
+            .order_by(
+                pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
+            )
         )
+        return self._pg_type_filter_schema(query, schema)
 
     @reflection.cache
-    def _load_domains(self, connection, **kw):
+    def _load_domains(self, connection, schema=None, **kw):
         # Load data types for domains:
-        result = connection.execute(self._domain_query)
+        result = connection.execute(self._domain_query(schema))
 
-        domains = {}
+        domains = []
         for domain in result.mappings():
-            domain = domain
             # strip (30) from character varying(30)
             attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
-            # 'visible' just means whether or not the domain is in a
-            # schema that's on the search path -- or not overridden by
-            # a schema with higher precedence. If it's not visible,
-            # it will be prefixed with the schema-name when it's used.
-            if domain["visible"]:
-                key = (domain["name"],)
-            else:
-                key = (domain["schema"], domain["name"])
-
-            domains[key] = {
-                "attype": attype,
+            constraints = []
+            if domain["connames"]:
+                # When a domain has multiple CHECK constraints, they will
+                # be tested in alphabetical order by name.
+                sorted_constraints = sorted(
+                    zip(domain["connames"], domain["condefs"]),
+                    key=lambda t: t[0],
+                )
+                for name, def_ in sorted_constraints:
+                    # constraint is in the form "CHECK (expression)".
+                    # remove "CHECK (" and the tailing ")".
+                    check = def_[7:-1]
+                    constraints.append({"name": name, "check": check})
+
+            domain_rec = {
+                "name": domain["name"],
+                "schema": domain["schema"],
+                "visible": domain["visible"],
+                "type": attype,
                 "nullable": domain["nullable"],
                 "default": domain["default"],
+                "constraints": constraints,
             }
+            domains.append(domain_rec)
 
         return domains
diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py
new file mode 100644 (file)
index 0000000..b2f274b
--- /dev/null
@@ -0,0 +1,476 @@
+# postgresql/named_types.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+from __future__ import annotations
+
+from typing import Any
+from typing import Optional
+from typing import Type
+from typing import TYPE_CHECKING
+from typing import Union
+
+from ... import schema
+from ... import util
+from ...sql import coercions
+from ...sql import elements
+from ...sql import roles
+from ...sql import sqltypes
+from ...sql import type_api
+from ...sql.ddl import InvokeDDLBase
+
+if TYPE_CHECKING:
+    from ...sql._typing import _TypeEngineArgument
+
+
+class NamedType(sqltypes.TypeEngine):
+    """Base for named types."""
+
+    __abstract__ = True
+    DDLGenerator: Type["NamedTypeGenerator"]
+    DDLDropper: Type["NamedTypeDropper"]
+    create_type: bool
+
+    def create(self, bind, checkfirst=True, **kw):
+        """Emit ``CREATE`` DDL for this type.
+
+        :param bind: a connectable :class:`_engine.Engine`,
+         :class:`_engine.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will be first performed to see
+         if the type does not exist already before
+         creating.
+
+        """
+        bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
+
+    def drop(self, bind, checkfirst=True, **kw):
+        """Emit ``DROP`` DDL for this type.
+
+        :param bind: a connectable :class:`_engine.Engine`,
+         :class:`_engine.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will be first performed to see
+         if the type actually exists before dropping.
+
+        """
+        bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
+
+    def _check_for_name_in_memos(self, checkfirst, kw):
+        """Look in the 'ddl runner' for 'memos', then
+        note our name in that collection.
+
+        This to ensure a particular named type is operated
+        upon only once within any kind of create/drop
+        sequence without relying upon "checkfirst".
+
+        """
+        if not self.create_type:
+            return True
+        if "_ddl_runner" in kw:
+            ddl_runner = kw["_ddl_runner"]
+            type_name = f"pg_{self.__visit_name__}"
+            if type_name in ddl_runner.memo:
+                existing = ddl_runner.memo[type_name]
+            else:
+                existing = ddl_runner.memo[type_name] = set()
+            present = (self.schema, self.name) in existing
+            existing.add((self.schema, self.name))
+            return present
+        else:
+            return False
+
+    def _on_table_create(self, target, bind, checkfirst=False, **kw):
+        if (
+            checkfirst
+            or (
+                not self.metadata
+                and not kw.get("_is_metadata_operation", False)
+            )
+        ) and not self._check_for_name_in_memos(checkfirst, kw):
+            self.create(bind=bind, checkfirst=checkfirst)
+
+    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+        if (
+            not self.metadata
+            and not kw.get("_is_metadata_operation", False)
+            and not self._check_for_name_in_memos(checkfirst, kw)
+        ):
+            self.drop(bind=bind, checkfirst=checkfirst)
+
+    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+        if not self._check_for_name_in_memos(checkfirst, kw):
+            self.create(bind=bind, checkfirst=checkfirst)
+
+    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+        if not self._check_for_name_in_memos(checkfirst, kw):
+            self.drop(bind=bind, checkfirst=checkfirst)
+
+
+class NamedTypeGenerator(InvokeDDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+        super().__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+
+    def _can_create_type(self, type_):
+        if not self.checkfirst:
+            return True
+
+        effective_schema = self.connection.schema_for_object(type_)
+        return not self.connection.dialect.has_type(
+            self.connection, type_.name, schema=effective_schema
+        )
+
+
+class NamedTypeDropper(InvokeDDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+        super().__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+
+    def _can_drop_type(self, type_):
+        if not self.checkfirst:
+            return True
+
+        effective_schema = self.connection.schema_for_object(type_)
+        return self.connection.dialect.has_type(
+            self.connection, type_.name, schema=effective_schema
+        )
+
+
+class EnumGenerator(NamedTypeGenerator):
+    def visit_enum(self, enum):
+        if not self._can_create_type(enum):
+            return
+
+        self.connection.execute(CreateEnumType(enum))
+
+
+class EnumDropper(NamedTypeDropper):
+    def visit_enum(self, enum):
+        if not self._can_drop_type(enum):
+            return
+
+        self.connection.execute(DropEnumType(enum))
+
+
+class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum):
+
+    """PostgreSQL ENUM type.
+
+    This is a subclass of :class:`_types.Enum` which includes
+    support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
+
+    When the builtin type :class:`_types.Enum` is used and the
+    :paramref:`.Enum.native_enum` flag is left at its default of
+    True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
+    type as the implementation, so the special create/drop rules
+    will be used.
+
+    The create/drop behavior of ENUM is necessarily intricate, due to the
+    awkward relationship the ENUM type has in relationship to the
+    parent table, in that it may be "owned" by just a single table, or
+    may be shared among many tables.
+
+    When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
+    in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
+    corresponding to when the :meth:`_schema.Table.create` and
+    :meth:`_schema.Table.drop`
+    methods are called::
+
+        table = Table('sometable', metadata,
+            Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
+        )
+
+        table.create(engine)  # will emit CREATE ENUM and CREATE TABLE
+        table.drop(engine)  # will emit DROP TABLE and DROP ENUM
+
+    To use a common enumerated type between multiple tables, the best
+    practice is to declare the :class:`_types.Enum` or
+    :class:`_postgresql.ENUM` independently, and associate it with the
+    :class:`_schema.MetaData` object itself::
+
+        my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
+
+        t1 = Table('sometable_one', metadata,
+            Column('some_enum', myenum)
+        )
+
+        t2 = Table('sometable_two', metadata,
+            Column('some_enum', myenum)
+        )
+
+    When this pattern is used, care must still be taken at the level
+    of individual table creates.  Emitting CREATE TABLE without also
+    specifying ``checkfirst=True`` will still cause issues::
+
+        t1.create(engine) # will fail: no such type 'myenum'
+
+    If we specify ``checkfirst=True``, the individual table-level create
+    operation will check for the ``ENUM`` and create if not exists::
+
+        # will check if enum exists, and emit CREATE TYPE if not
+        t1.create(engine, checkfirst=True)
+
+    When using a metadata-level ENUM type, the type will always be created
+    and dropped if either the metadata-wide create/drop is called::
+
+        metadata.create_all(engine)  # will emit CREATE TYPE
+        metadata.drop_all(engine)  # will emit DROP TYPE
+
+    The type can also be created and dropped directly::
+
+        my_enum.create(engine)
+        my_enum.drop(engine)
+
+    .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
+       now behaves more strictly with regards to CREATE/DROP.  A metadata-level
+       ENUM type will only be created and dropped at the metadata level,
+       not the table level, with the exception of
+       ``table.create(checkfirst=True)``.
+       The ``table.drop()`` call will now emit a DROP TYPE for a table-level
+       enumerated type.
+
+    """
+
+    native_enum = True
+    DDLGenerator = EnumGenerator
+    DDLDropper = EnumDropper
+
+    def __init__(self, *enums, name: str, create_type: bool = True, **kw):
+        """Construct an :class:`_postgresql.ENUM`.
+
+        Arguments are the same as that of
+        :class:`_types.Enum`, but also including
+        the following parameters.
+
+        :param create_type: Defaults to True.
+         Indicates that ``CREATE TYPE`` should be
+         emitted, after optionally checking for the
+         presence of the type, when the parent
+         table is being created; and additionally
+         that ``DROP TYPE`` is called when the table
+         is dropped.    When ``False``, no check
+         will be performed and no ``CREATE TYPE``
+         or ``DROP TYPE`` is emitted, unless
+         :meth:`~.postgresql.ENUM.create`
+         or :meth:`~.postgresql.ENUM.drop`
+         are called directly.
+         Setting to ``False`` is helpful
+         when invoking a creation scheme to a SQL file
+         without access to the actual database -
+         the :meth:`~.postgresql.ENUM.create` and
+         :meth:`~.postgresql.ENUM.drop` methods can
+         be used to emit SQL to a target bind.
+
+        """
+        native_enum = kw.pop("native_enum", None)
+        if native_enum is False:
+            util.warn(
+                "the native_enum flag does not apply to the "
+                "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
+                "always refers to ENUM.   Use sqlalchemy.types.Enum for "
+                "non-native enum."
+            )
+        self.create_type = create_type
+        super().__init__(*enums, name=name, **kw)
+
+    @classmethod
+    def __test_init__(cls):
+        return cls(name="name")
+
+    @classmethod
+    def adapt_emulated_to_native(cls, impl, **kw):
+        """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
+        :class:`.Enum`.
+
+        """
+        kw.setdefault("validate_strings", impl.validate_strings)
+        kw.setdefault("name", impl.name)
+        kw.setdefault("schema", impl.schema)
+        kw.setdefault("inherit_schema", impl.inherit_schema)
+        kw.setdefault("metadata", impl.metadata)
+        kw.setdefault("_create_events", False)
+        kw.setdefault("values_callable", impl.values_callable)
+        kw.setdefault("omit_aliases", impl._omit_aliases)
+        return cls(**kw)
+
+    def create(self, bind=None, checkfirst=True):
+        """Emit ``CREATE TYPE`` for this
+        :class:`_postgresql.ENUM`.
+
+        If the underlying dialect does not support
+        PostgreSQL CREATE TYPE, no action is taken.
+
+        :param bind: a connectable :class:`_engine.Engine`,
+         :class:`_engine.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will be first performed to see
+         if the type does not exist already before
+         creating.
+
+        """
+        if not bind.dialect.supports_native_enum:
+            return
+
+        super().create(bind, checkfirst=checkfirst)
+
+    def drop(self, bind=None, checkfirst=True):
+        """Emit ``DROP TYPE`` for this
+        :class:`_postgresql.ENUM`.
+
+        If the underlying dialect does not support
+        PostgreSQL DROP TYPE, no action is taken.
+
+        :param bind: a connectable :class:`_engine.Engine`,
+         :class:`_engine.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will be first performed to see
+         if the type actually exists before dropping.
+
+        """
+        if not bind.dialect.supports_native_enum:
+            return
+
+        super().drop(bind, checkfirst=checkfirst)
+
+    def get_dbapi_type(self, dbapi):
+        """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
+        a different type"""
+
+        return None
+
+
+class DomainGenerator(NamedTypeGenerator):
+    def visit_DOMAIN(self, domain):
+        if not self._can_create_type(domain):
+            return
+        self.connection.execute(CreateDomainType(domain))
+
+
+class DomainDropper(NamedTypeDropper):
+    def visit_DOMAIN(self, domain):
+        if not self._can_drop_type(domain):
+            return
+
+        self.connection.execute(DropDomainType(domain))
+
+
+class DOMAIN(NamedType, sqltypes.SchemaType):
+    r"""Represent the DOMAIN PostgreSQL type.
+
+    A domain is essentially a data type with optional constraints
+    that restrict the allowed set of values. E.g.::
+
+        PositiveInt = Domain(
+            "pos_int", Integer, check="VALUE > 0", not_null=True
+        )
+
+        UsPostalCode = Domain(
+            "us_postal_code",
+            Text,
+            check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'"
+        )
+
+    See the `PostgreSQL documentation`__ for additional details
+
+    __ https://www.postgresql.org/docs/current/sql-createdomain.html
+
+    .. versionadded:: 2.0
+
+    """
+
+    DDLGenerator = DomainGenerator
+    DDLDropper = DomainDropper
+
+    __visit_name__ = "DOMAIN"
+
+    def __init__(
+        self,
+        name: str,
+        data_type: _TypeEngineArgument[Any],
+        *,
+        collation: Optional[str] = None,
+        default: Optional[Union[str, elements.TextClause]] = None,
+        constraint_name: Optional[str] = None,
+        not_null: Optional[bool] = None,
+        check: Optional[str] = None,
+        create_type: bool = True,
+        **kw: Any,
+    ):
+        """
+        Construct a DOMAIN.
+
+        :param name: the name of the domain
+        :param data_type: The underlying data type of the domain.
+          This can include array specifiers.
+        :param collation: An optional collation for the domain.
+          If no collation is specified, the underlying data type's default
+          collation is used. The underlying type must be collatable if
+          ``collation`` is specified.
+        :param default: The DEFAULT clause specifies a default value for
+          columns of the domain data type. The default should be a string
+          or a :func:`_expression.text` value.
+          If no default value is specified, then the default value is
+          the null value.
+        :param constraint_name: An optional name for a constraint.
+          If not specified, the backend generates a name.
+        :param not_null: Values of this domain are prevented from being null.
+          By default domain are allowed to be null. If not specified
+          no nullability clause will be emitted.
+        :param check: CHECK clause specify integrity constraint or test
+          which values of the domain must satisfy. A constraint must be
+          an expression producing a Boolean result that can use the key
+          word VALUE to refer to the value being tested.
+          Differently from PostgreSQL, only a single check clause is
+          currently allowed in SQLAlchemy.
+        :param schema: optional schema name
+        :param metadata: optional :class:`_schema.MetaData` object which
+         this :class:`_postgresql.DOMAIN` will be directly associated
+        :param create_type: Defaults to True.
+         Indicates that ``CREATE TYPE`` should be emitted, after optionally
+         checking for the presence of the type, when the parent table is
+         being created; and additionally that ``DROP TYPE`` is called
+         when the table is dropped.
+
+        """
+        self.data_type = type_api.to_instance(data_type)
+        self.default = default
+        self.collation = collation
+        self.constraint_name = constraint_name
+        self.not_null = not_null
+        if check is not None:
+            check = coercions.expect(roles.DDLExpressionRole, check)
+        self.check = check
+        self.create_type = create_type
+        super().__init__(name=name, **kw)
+
+    @classmethod
+    def __test_init__(cls):
+        return cls("name", sqltypes.Integer)
+
+
+class CreateEnumType(schema._CreateDropBase):
+    __visit_name__ = "create_enum_type"
+
+
+class DropEnumType(schema._CreateDropBase):
+    __visit_name__ = "drop_enum_type"
+
+
+class CreateDomainType(schema._CreateDropBase):
+    """Represent a CREATE DOMAIN statement."""
+
+    __visit_name__ = "create_domain_type"
+
+
+class DropDomainType(schema._CreateDropBase):
+    """Represent a DROP DOMAIN statement."""
+
+    __visit_name__ = "drop_domain_type"
index 55735953b508e7592a2d68b5e1ad9c3b7f12115a..374adcac1ff88d6def31b15538e9e815617625e6 100644 (file)
@@ -8,10 +8,7 @@
 import datetime as dt
 from typing import Any
 
-from ... import schema
-from ... import util
 from ...sql import sqltypes
-from ...sql.ddl import InvokeDDLBase
 
 
 _DECIMAL_TYPES = (1231, 1700)
@@ -201,285 +198,3 @@ class TSVECTOR(sqltypes.TypeEngine[Any]):
     """
 
     __visit_name__ = "TSVECTOR"
-
-
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
-
-    """PostgreSQL ENUM type.
-
-    This is a subclass of :class:`_types.Enum` which includes
-    support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
-
-    When the builtin type :class:`_types.Enum` is used and the
-    :paramref:`.Enum.native_enum` flag is left at its default of
-    True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
-    type as the implementation, so the special create/drop rules
-    will be used.
-
-    The create/drop behavior of ENUM is necessarily intricate, due to the
-    awkward relationship the ENUM type has in relationship to the
-    parent table, in that it may be "owned" by just a single table, or
-    may be shared among many tables.
-
-    When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
-    in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
-    corresponding to when the :meth:`_schema.Table.create` and
-    :meth:`_schema.Table.drop`
-    methods are called::
-
-        table = Table('sometable', metadata,
-            Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
-        )
-
-        table.create(engine)  # will emit CREATE ENUM and CREATE TABLE
-        table.drop(engine)  # will emit DROP TABLE and DROP ENUM
-
-    To use a common enumerated type between multiple tables, the best
-    practice is to declare the :class:`_types.Enum` or
-    :class:`_postgresql.ENUM` independently, and associate it with the
-    :class:`_schema.MetaData` object itself::
-
-        my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
-
-        t1 = Table('sometable_one', metadata,
-            Column('some_enum', myenum)
-        )
-
-        t2 = Table('sometable_two', metadata,
-            Column('some_enum', myenum)
-        )
-
-    When this pattern is used, care must still be taken at the level
-    of individual table creates.  Emitting CREATE TABLE without also
-    specifying ``checkfirst=True`` will still cause issues::
-
-        t1.create(engine) # will fail: no such type 'myenum'
-
-    If we specify ``checkfirst=True``, the individual table-level create
-    operation will check for the ``ENUM`` and create if not exists::
-
-        # will check if enum exists, and emit CREATE TYPE if not
-        t1.create(engine, checkfirst=True)
-
-    When using a metadata-level ENUM type, the type will always be created
-    and dropped if either the metadata-wide create/drop is called::
-
-        metadata.create_all(engine)  # will emit CREATE TYPE
-        metadata.drop_all(engine)  # will emit DROP TYPE
-
-    The type can also be created and dropped directly::
-
-        my_enum.create(engine)
-        my_enum.drop(engine)
-
-    .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
-       now behaves more strictly with regards to CREATE/DROP.  A metadata-level
-       ENUM type will only be created and dropped at the metadata level,
-       not the table level, with the exception of
-       ``table.create(checkfirst=True)``.
-       The ``table.drop()`` call will now emit a DROP TYPE for a table-level
-       enumerated type.
-
-    """
-
-    native_enum = True
-
-    def __init__(self, *enums, **kw):
-        """Construct an :class:`_postgresql.ENUM`.
-
-        Arguments are the same as that of
-        :class:`_types.Enum`, but also including
-        the following parameters.
-
-        :param create_type: Defaults to True.
-         Indicates that ``CREATE TYPE`` should be
-         emitted, after optionally checking for the
-         presence of the type, when the parent
-         table is being created; and additionally
-         that ``DROP TYPE`` is called when the table
-         is dropped.    When ``False``, no check
-         will be performed and no ``CREATE TYPE``
-         or ``DROP TYPE`` is emitted, unless
-         :meth:`~.postgresql.ENUM.create`
-         or :meth:`~.postgresql.ENUM.drop`
-         are called directly.
-         Setting to ``False`` is helpful
-         when invoking a creation scheme to a SQL file
-         without access to the actual database -
-         the :meth:`~.postgresql.ENUM.create` and
-         :meth:`~.postgresql.ENUM.drop` methods can
-         be used to emit SQL to a target bind.
-
-        """
-        native_enum = kw.pop("native_enum", None)
-        if native_enum is False:
-            util.warn(
-                "the native_enum flag does not apply to the "
-                "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
-                "always refers to ENUM.   Use sqlalchemy.types.Enum for "
-                "non-native enum."
-            )
-        self.create_type = kw.pop("create_type", True)
-        super(ENUM, self).__init__(*enums, **kw)
-
-    @classmethod
-    def adapt_emulated_to_native(cls, impl, **kw):
-        """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
-        :class:`.Enum`.
-
-        """
-        kw.setdefault("validate_strings", impl.validate_strings)
-        kw.setdefault("name", impl.name)
-        kw.setdefault("schema", impl.schema)
-        kw.setdefault("inherit_schema", impl.inherit_schema)
-        kw.setdefault("metadata", impl.metadata)
-        kw.setdefault("_create_events", False)
-        kw.setdefault("values_callable", impl.values_callable)
-        kw.setdefault("omit_aliases", impl._omit_aliases)
-        return cls(**kw)
-
-    def create(self, bind=None, checkfirst=True):
-        """Emit ``CREATE TYPE`` for this
-        :class:`_postgresql.ENUM`.
-
-        If the underlying dialect does not support
-        PostgreSQL CREATE TYPE, no action is taken.
-
-        :param bind: a connectable :class:`_engine.Engine`,
-         :class:`_engine.Connection`, or similar object to emit
-         SQL.
-        :param checkfirst: if ``True``, a query against
-         the PG catalog will be first performed to see
-         if the type does not exist already before
-         creating.
-
-        """
-        if not bind.dialect.supports_native_enum:
-            return
-
-        bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
-
-    def drop(self, bind=None, checkfirst=True):
-        """Emit ``DROP TYPE`` for this
-        :class:`_postgresql.ENUM`.
-
-        If the underlying dialect does not support
-        PostgreSQL DROP TYPE, no action is taken.
-
-        :param bind: a connectable :class:`_engine.Engine`,
-         :class:`_engine.Connection`, or similar object to emit
-         SQL.
-        :param checkfirst: if ``True``, a query against
-         the PG catalog will be first performed to see
-         if the type actually exists before dropping.
-
-        """
-        if not bind.dialect.supports_native_enum:
-            return
-
-        bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
-
-    class EnumGenerator(InvokeDDLBase):
-        def __init__(self, dialect, connection, checkfirst=False, **kwargs):
-            super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
-            self.checkfirst = checkfirst
-
-        def _can_create_enum(self, enum):
-            if not self.checkfirst:
-                return True
-
-            effective_schema = self.connection.schema_for_object(enum)
-
-            return not self.connection.dialect.has_type(
-                self.connection, enum.name, schema=effective_schema
-            )
-
-        def visit_enum(self, enum):
-            if not self._can_create_enum(enum):
-                return
-
-            self.connection.execute(CreateEnumType(enum))
-
-    class EnumDropper(InvokeDDLBase):
-        def __init__(self, dialect, connection, checkfirst=False, **kwargs):
-            super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
-            self.checkfirst = checkfirst
-
-        def _can_drop_enum(self, enum):
-            if not self.checkfirst:
-                return True
-
-            effective_schema = self.connection.schema_for_object(enum)
-
-            return self.connection.dialect.has_type(
-                self.connection, enum.name, schema=effective_schema
-            )
-
-        def visit_enum(self, enum):
-            if not self._can_drop_enum(enum):
-                return
-
-            self.connection.execute(DropEnumType(enum))
-
-    def get_dbapi_type(self, dbapi):
-        """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
-        a different type"""
-
-        return None
-
-    def _check_for_name_in_memos(self, checkfirst, kw):
-        """Look in the 'ddl runner' for 'memos', then
-        note our name in that collection.
-
-        This to ensure a particular named enum is operated
-        upon only once within any kind of create/drop
-        sequence without relying upon "checkfirst".
-
-        """
-        if not self.create_type:
-            return True
-        if "_ddl_runner" in kw:
-            ddl_runner = kw["_ddl_runner"]
-            if "_pg_enums" in ddl_runner.memo:
-                pg_enums = ddl_runner.memo["_pg_enums"]
-            else:
-                pg_enums = ddl_runner.memo["_pg_enums"] = set()
-            present = (self.schema, self.name) in pg_enums
-            pg_enums.add((self.schema, self.name))
-            return present
-        else:
-            return False
-
-    def _on_table_create(self, target, bind, checkfirst=False, **kw):
-        if (
-            checkfirst
-            or (
-                not self.metadata
-                and not kw.get("_is_metadata_operation", False)
-            )
-        ) and not self._check_for_name_in_memos(checkfirst, kw):
-            self.create(bind=bind, checkfirst=checkfirst)
-
-    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
-        if (
-            not self.metadata
-            and not kw.get("_is_metadata_operation", False)
-            and not self._check_for_name_in_memos(checkfirst, kw)
-        ):
-            self.drop(bind=bind, checkfirst=checkfirst)
-
-    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
-        if not self._check_for_name_in_memos(checkfirst, kw):
-            self.create(bind=bind, checkfirst=checkfirst)
-
-    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
-        if not self._check_for_name_in_memos(checkfirst, kw):
-            self.drop(bind=bind, checkfirst=checkfirst)
-
-
-class CreateEnumType(schema._CreateDropBase):
-    __visit_name__ = "create_enum_type"
-
-
-class DropEnumType(schema._CreateDropBase):
-    __visit_name__ = "drop_enum_type"
index 8ce0c65e42ce865ad72ad06c9ac1367c2be9ea42..1ad547b79448f1b297e8418dab1f5c17eb741f9c 100644 (file)
@@ -5251,17 +5251,18 @@ class DDLCompiler(Compiled):
 
     def get_column_default_string(self, column):
         if isinstance(column.server_default, schema.DefaultClause):
-            if isinstance(column.server_default.arg, str):
-                return self.sql_compiler.render_literal_value(
-                    column.server_default.arg, sqltypes.STRINGTYPE
-                )
-            else:
-                return self.sql_compiler.process(
-                    column.server_default.arg, literal_binds=True
-                )
+            return self.render_default_string(column.server_default.arg)
         else:
             return None
 
+    def render_default_string(self, default):
+        if isinstance(default, str):
+            return self.sql_compiler.render_literal_value(
+                default, sqltypes.STRINGTYPE
+            )
+        else:
+            return self.sql_compiler.process(default, literal_binds=True)
+
     def visit_table_or_column_check_constraint(self, constraint, **kw):
         if constraint.is_column_level:
             return self.visit_column_check_constraint(constraint)
index e90c428f70584380f5d91e1b35dcb01c5c9d688d..dc6dcf060a61d390edf9cadd07545e191ed381ab 100644 (file)
@@ -312,20 +312,22 @@ class MemUsageTest(EnsureZeroed):
 
         eng = engines.testing_engine()
         for args in (
-            (types.Integer,),
-            (types.String,),
-            (types.PickleType,),
-            (types.Enum, "a", "b", "c"),
-            (sqlite.DATETIME,),
-            (postgresql.ENUM, "a", "b", "c"),
-            (types.Interval,),
-            (postgresql.INTERVAL,),
-            (mysql.VARCHAR,),
+            (types.Integer, {}),
+            (types.String, {}),
+            (types.PickleType, {}),
+            (types.Enum, "a", "b", "c", {}),
+            (sqlite.DATETIME, {}),
+            (postgresql.ENUM, "a", "b", "c", {"name": "pgenum"}),
+            (types.Interval, {}),
+            (postgresql.INTERVAL, {}),
+            (mysql.VARCHAR, {}),
         ):
 
             @profile_memory()
             def go():
-                type_ = args[0](*args[1:])
+                kwargs = args[-1]
+                posargs = args[1:-1]
+                type_ = args[0](*posargs, **kwargs)
                 bp = type_._cached_bind_processor(eng.dialect)
                 rp = type_._cached_result_processor(eng.dialect, 0)
                 bp, rp  # strong reference
index 25550afe1442f80e1c9383e4a3decdcd80c9b5ad..9be76130d5ed7dc6327e2465939aa0c7e999f543 100644 (file)
@@ -38,6 +38,7 @@ from sqlalchemy.dialects.postgresql import aggregate_order_by
 from sqlalchemy.dialects.postgresql import ARRAY as PG_ARRAY
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import array_agg as pg_array_agg
+from sqlalchemy.dialects.postgresql import DOMAIN
 from sqlalchemy.dialects.postgresql import ExcludeConstraint
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.dialects.postgresql import TSRANGE
@@ -270,7 +271,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             render_schema_translate=True,
         )
 
-    def test_create_type_schema_translate(self):
+    def test_create_enum_schema_translate(self):
         e1 = Enum("x", "y", "z", name="somename")
         e2 = Enum("x", "y", "z", name="somename", schema="someschema")
         schema_translate_map = {None: "foo", "someschema": "bar"}
@@ -289,6 +290,79 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             render_schema_translate=True,
         )
 
+    def test_domain(self):
+        self.assert_compile(
+            postgresql.CreateDomainType(
+                DOMAIN(
+                    "x",
+                    Integer,
+                    default=text("11"),
+                    not_null=True,
+                    check="VALUE < 0",
+                )
+            ),
+            "CREATE DOMAIN x AS INTEGER DEFAULT 11 NOT NULL CHECK (VALUE < 0)",
+        )
+        self.assert_compile(
+            postgresql.CreateDomainType(
+                DOMAIN(
+                    "sOmEnAmE",
+                    Text,
+                    collation="utf8",
+                    constraint_name="a constraint",
+                    not_null=True,
+                )
+            ),
+            'CREATE DOMAIN "sOmEnAmE" AS TEXT COLLATE utf8 CONSTRAINT '
+            '"a constraint" NOT NULL',
+        )
+        self.assert_compile(
+            postgresql.CreateDomainType(
+                DOMAIN(
+                    "foo",
+                    Text,
+                    collation="utf8",
+                    default="foobar",
+                    constraint_name="no_bar",
+                    not_null=True,
+                    check="VALUE != 'bar'",
+                )
+            ),
+            "CREATE DOMAIN foo AS TEXT COLLATE utf8 DEFAULT 'foobar' "
+            "CONSTRAINT no_bar NOT NULL CHECK (VALUE != 'bar')",
+        )
+
+    def test_cast_domain_schema(self):
+        """test #6739"""
+        d1 = DOMAIN("somename", Integer)
+        d2 = DOMAIN("somename", Integer, schema="someschema")
+
+        stmt = select(cast(column("foo"), d1), cast(column("bar"), d2))
+        self.assert_compile(
+            stmt,
+            "SELECT CAST(foo AS somename) AS foo, "
+            "CAST(bar AS someschema.somename) AS bar",
+        )
+
+    def test_create_domain_schema_translate(self):
+        d1 = DOMAIN("somename", Integer)
+        d2 = DOMAIN("somename", Integer, schema="someschema")
+        schema_translate_map = {None: "foo", "someschema": "bar"}
+
+        self.assert_compile(
+            postgresql.CreateDomainType(d1),
+            "CREATE DOMAIN foo.somename AS INTEGER ",
+            schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
+        )
+
+        self.assert_compile(
+            postgresql.CreateDomainType(d2),
+            "CREATE DOMAIN bar.somename AS INTEGER ",
+            schema_translate_map=schema_translate_map,
+            render_schema_translate=True,
+        )
+
     def test_create_table_with_schema_type_schema_translate(self):
         e1 = Enum("x", "y", "z", name="somename")
         e2 = Enum("x", "y", "z", name="somename", schema="someschema")
index 21b4149bc0e48a121850b1d305a3c9fe101410e6..99bc14d7841367cdce98a39447da9b5a36ac3201 100644 (file)
@@ -410,6 +410,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
                 "CREATE DOMAIN nullable_domain AS TEXT CHECK "
                 "(VALUE IN('FOO', 'BAR'))",
                 "CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL",
+                "CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK "
+                "(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) "
+                "CHECK(VALUE != 22)",
             ]:
                 try:
                     con.exec_driver_sql(ddl)
@@ -468,6 +471,7 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
             con.exec_driver_sql("DROP TABLE nullable_domain_test")
             con.exec_driver_sql("DROP DOMAIN nullable_domain")
             con.exec_driver_sql("DROP DOMAIN not_nullable_domain")
+            con.exec_driver_sql("DROP DOMAIN my_int")
 
     def test_table_is_reflected(self, connection):
         metadata = MetaData()
@@ -579,6 +583,122 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults):
         finally:
             base.PGDialect.ischema_names = ischema_names
 
+    @property
+    def all_domains(self):
+        return {
+            "public": [
+                {
+                    "visible": True,
+                    "name": "arraydomain",
+                    "schema": "public",
+                    "nullable": True,
+                    "type": "integer[]",
+                    "default": None,
+                    "constraints": [],
+                },
+                {
+                    "visible": True,
+                    "name": "enumdomain",
+                    "schema": "public",
+                    "nullable": True,
+                    "type": "testtype",
+                    "default": None,
+                    "constraints": [],
+                },
+                {
+                    "visible": True,
+                    "name": "my_int",
+                    "schema": "public",
+                    "nullable": True,
+                    "type": "integer",
+                    "default": None,
+                    "constraints": [
+                        {"check": "VALUE < 42", "name": "a_my_int_two"},
+                        {"check": "VALUE > 1", "name": "b_my_int_one"},
+                        # autogenerated name by pg
+                        {"check": "VALUE <> 22", "name": "my_int_check"},
+                    ],
+                },
+                {
+                    "visible": True,
+                    "name": "not_nullable_domain",
+                    "schema": "public",
+                    "nullable": False,
+                    "type": "text",
+                    "default": None,
+                    "constraints": [],
+                },
+                {
+                    "visible": True,
+                    "name": "nullable_domain",
+                    "schema": "public",
+                    "nullable": True,
+                    "type": "text",
+                    "default": None,
+                    "constraints": [
+                        {
+                            "check": "VALUE = ANY (ARRAY['FOO'::text, "
+                            "'BAR'::text])",
+                            # autogenerated name by pg
+                            "name": "nullable_domain_check",
+                        }
+                    ],
+                },
+                {
+                    "visible": True,
+                    "name": "testdomain",
+                    "schema": "public",
+                    "nullable": False,
+                    "type": "integer",
+                    "default": "42",
+                    "constraints": [],
+                },
+            ],
+            "test_schema": [
+                {
+                    "visible": False,
+                    "name": "testdomain",
+                    "schema": "test_schema",
+                    "nullable": True,
+                    "type": "integer",
+                    "default": "0",
+                    "constraints": [],
+                }
+            ],
+            "SomeSchema": [
+                {
+                    "visible": False,
+                    "name": "Quoted.Domain",
+                    "schema": "SomeSchema",
+                    "nullable": True,
+                    "type": "integer",
+                    "default": "0",
+                    "constraints": [],
+                }
+            ],
+        }
+
+    def test_inspect_domains(self, connection):
+        inspector = inspect(connection)
+        eq_(inspector.get_domains(), self.all_domains["public"])
+
+    def test_inspect_domains_schema(self, connection):
+        inspector = inspect(connection)
+        eq_(
+            inspector.get_domains("test_schema"),
+            self.all_domains["test_schema"],
+        )
+        eq_(
+            inspector.get_domains("SomeSchema"), self.all_domains["SomeSchema"]
+        )
+
+    def test_inspect_domains_star(self, connection):
+        inspector = inspect(connection)
+        all_ = [d for dl in self.all_domains.values() for d in dl]
+        all_ += inspector.get_domains("information_schema")
+        exp = sorted(all_, key=lambda d: (d["schema"], d["name"]))
+        eq_(inspector.get_domains("*"), exp)
+
 
 class ReflectionTest(
     ReflectionFixtures, AssertsCompiledSQL, fixtures.TestBase
@@ -1800,10 +1920,10 @@ class ReflectionTest(
         eq_(
             check_constraints,
             {
-                "cc1": "(a > 1) AND (a < 5)",
-                "cc2": "(a = 1) OR ((a > 2) AND (a < 5))",
+                "cc1": "a > 1 AND a < 5",
+                "cc2": "a = 1 OR a > 2 AND a < 5",
                 "cc3": "is_positive(a)",
-                "cc4": "(b)::text <> 'hi\nim a name   \nyup\n'::text",
+                "cc4": "b::text <> 'hi\nim a name   \nyup\n'::text",
             },
         )
 
index 79f029391b2ae12600f3b0d0518244dc0571e98f..5c3935d44614a03960d5277f34af74eeaa42d9c3 100644 (file)
@@ -38,12 +38,15 @@ from sqlalchemy import util
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import DATERANGE
+from sqlalchemy.dialects.postgresql import DOMAIN
+from sqlalchemy.dialects.postgresql import ENUM
 from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.postgresql import hstore
 from sqlalchemy.dialects.postgresql import INT4RANGE
 from sqlalchemy.dialects.postgresql import INT8RANGE
 from sqlalchemy.dialects.postgresql import JSON
 from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import NamedType
 from sqlalchemy.dialects.postgresql import NUMRANGE
 from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql import TSTZRANGE
@@ -161,7 +164,7 @@ class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults):
         eq_(row, ([5], [5], [6], [7], [decimal.Decimal("6.4")]))
 
 
-class EnumTest(fixtures.TestBase, AssertsExecutionResults):
+class NamedTypeTest(fixtures.TestBase, AssertsExecutionResults):
     __backend__ = True
 
     __only_on__ = "postgresql > 8.3"
@@ -173,16 +176,18 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             "the native_enum flag does not apply to the "
             "sqlalchemy.dialects.postgresql.ENUM datatype;"
         ):
-            e1 = postgresql.ENUM("a", "b", "c", native_enum=False)
+            e1 = postgresql.ENUM(
+                "a", "b", "c", name="pgenum", native_enum=False
+            )
 
-        e2 = postgresql.ENUM("a", "b", "c", native_enum=True)
-        e3 = postgresql.ENUM("a", "b", "c")
+        e2 = postgresql.ENUM("a", "b", "c", name="pgenum", native_enum=True)
+        e3 = postgresql.ENUM("a", "b", "c", name="pgenum")
 
         is_(e1.native_enum, True)
         is_(e2.native_enum, True)
         is_(e3.native_enum, True)
 
-    def test_create_table(self, metadata, connection):
+    def test_enum_create_table(self, metadata, connection):
         metadata = self.metadata
         t1 = Table(
             "table",
@@ -202,50 +207,147 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             [(1, "two"), (2, "three"), (3, "three")],
         )
 
+    def test_domain_create_table(self, metadata, connection):
+        metadata = self.metadata
+        Email = DOMAIN(
+            name="email",
+            data_type=Text,
+            check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+        )
+        PosInt = DOMAIN(
+            name="pos_int",
+            data_type=Integer,
+            not_null=True,
+            check=r"VALUE > 0",
+        )
+        t1 = Table(
+            "table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("email", Email),
+            Column("number", PosInt),
+        )
+        t1.create(connection)
+        t1.create(connection, checkfirst=True)  # check the create
+        connection.execute(
+            t1.insert(), {"email": "test@example.com", "number": 42}
+        )
+        connection.execute(t1.insert(), {"email": "a@b.c", "number": 1})
+        connection.execute(
+            t1.insert(), {"email": "example@gmail.co.uk", "number": 99}
+        )
+        eq_(
+            connection.execute(t1.select().order_by(t1.c.id)).fetchall(),
+            [
+                (1, "test@example.com", 42),
+                (2, "a@b.c", 1),
+                (3, "example@gmail.co.uk", 99),
+            ],
+        )
+
+    @testing.combinations(
+        (ENUM("one", "two", "three", name="mytype"), "get_enums"),
+        (
+            DOMAIN(
+                name="mytype",
+                data_type=Text,
+                check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+            ),
+            "get_domains",
+        ),
+        argnames="datatype, method",
+    )
+    def test_drops_on_table(
+        self, connection, metadata, datatype: "NamedType", method
+    ):
+        table = Table("e1", metadata, Column("e1", datatype))
+
+        table.create(connection)
+        table.drop(connection)
+
+        assert "mytype" not in [
+            e["name"] for e in getattr(inspect(connection), method)()
+        ]
+        table.create(connection)
+        assert "mytype" in [
+            e["name"] for e in getattr(inspect(connection), method)()
+        ]
+        table.drop(connection)
+        assert "mytype" not in [
+            e["name"] for e in getattr(inspect(connection), method)()
+        ]
+
+    @testing.combinations(
+        (
+            lambda symbol_name: ENUM(
+                "one", "two", "three", name="schema_mytype", schema=symbol_name
+            ),
+            ["two", "three", "three"],
+            "get_enums",
+        ),
+        (
+            lambda symbol_name: DOMAIN(
+                name="schema_mytype",
+                data_type=Text,
+                check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+                schema=symbol_name,
+            ),
+            ["test@example.com", "a@b.c", "example@gmail.co.uk"],
+            "get_domains",
+        ),
+        argnames="datatype,data,method",
+    )
     @testing.combinations(None, "foo", argnames="symbol_name")
-    def test_create_table_schema_translate_map(self, connection, symbol_name):
+    def test_create_table_schema_translate_map(
+        self, connection, symbol_name, datatype, data, method
+    ):
         # note we can't use the fixture here because it will not drop
         # from the correct schema
         metadata = MetaData()
 
+        dt = datatype(symbol_name)
+
         t1 = Table(
             "table",
             metadata,
             Column("id", Integer, primary_key=True),
-            Column(
-                "value",
-                Enum(
-                    "one",
-                    "two",
-                    "three",
-                    name="schema_enum",
-                    schema=symbol_name,
-                ),
-            ),
+            Column("value", dt),
             schema=symbol_name,
         )
         conn = connection.execution_options(
             schema_translate_map={symbol_name: testing.config.test_schema}
         )
         t1.create(conn)
-        assert "schema_enum" in [
+        assert "schema_mytype" in [
             e["name"]
-            for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+            for e in getattr(inspect(conn), method)(
+                schema=testing.config.test_schema
+            )
         ]
         t1.create(conn, checkfirst=True)
 
-        conn.execute(t1.insert(), dict(value="two"))
-        conn.execute(t1.insert(), dict(value="three"))
-        conn.execute(t1.insert(), dict(value="three"))
+        conn.execute(
+            t1.insert(),
+            dict(value=data[0]),
+        )
+        conn.execute(t1.insert(), dict(value=data[1]))
+        conn.execute(t1.insert(), dict(value=data[2]))
         eq_(
             conn.execute(t1.select().order_by(t1.c.id)).fetchall(),
-            [(1, "two"), (2, "three"), (3, "three")],
+            [
+                (1, data[0]),
+                (2, data[1]),
+                (3, data[2]),
+            ],
         )
 
         t1.drop(conn)
-        assert "schema_enum" not in [
+
+        assert "schema_mytype" not in [
             e["name"]
-            for e in inspect(conn).get_enums(schema=testing.config.test_schema)
+            for e in getattr(inspect(conn), method)(
+                schema=testing.config.test_schema
+            )
         ]
         t1.drop(conn, checkfirst=True)
 
@@ -256,40 +358,48 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         ("override_metadata_schema",),
         argnames="test_case",
     )
+    @testing.combinations("enum", "domain", argnames="datatype")
     @testing.requires.schemas
-    def test_schema_inheritance(self, test_case, metadata, connection):
+    def test_schema_inheritance(
+        self, test_case, metadata, connection, datatype
+    ):
         """test #6373"""
 
         metadata.schema = testing.config.test_schema
 
+        def make_type(**kw):
+            if datatype == "enum":
+                return Enum("four", "five", "six", name="mytype", **kw)
+            elif datatype == "domain":
+                return DOMAIN(
+                    name="mytype",
+                    data_type=Text,
+                    check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+                    **kw,
+                )
+            else:
+                assert False
+
         if test_case == "metadata_schema_only":
-            enum = Enum(
-                "four", "five", "six", metadata=metadata, name="myenum"
-            )
+            enum = make_type(metadata=metadata)
             assert_schema = testing.config.test_schema
         elif test_case == "override_metadata_schema":
-            enum = Enum(
-                "four",
-                "five",
-                "six",
+            enum = make_type(
                 metadata=metadata,
                 schema=testing.config.test_schema_2,
-                name="myenum",
             )
             assert_schema = testing.config.test_schema_2
         elif test_case == "inherit_table_schema":
-            enum = Enum(
-                "four",
-                "five",
-                "six",
+            enum = make_type(
                 metadata=metadata,
                 inherit_schema=True,
-                name="myenum",
             )
             assert_schema = testing.config.test_schema_2
         elif test_case == "local_schema":
-            enum = Enum("four", "five", "six", name="myenum")
+            enum = make_type()
             assert_schema = testing.config.db.dialect.default_schema_name
+        else:
+            assert False
 
         Table(
             "t",
@@ -300,27 +410,62 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
 
         metadata.create_all(connection)
 
-        eq_(
-            inspect(connection).get_enums(schema=assert_schema),
-            [
-                {
-                    "labels": ["four", "five", "six"],
-                    "name": "myenum",
-                    "schema": assert_schema,
-                    "visible": assert_schema
-                    == testing.config.db.dialect.default_schema_name,
-                }
-            ],
-        )
+        if datatype == "enum":
+            eq_(
+                inspect(connection).get_enums(schema=assert_schema),
+                [
+                    {
+                        "labels": ["four", "five", "six"],
+                        "name": "mytype",
+                        "schema": assert_schema,
+                        "visible": assert_schema
+                        == testing.config.db.dialect.default_schema_name,
+                    }
+                ],
+            )
+        elif datatype == "domain":
+
+            def_schame = testing.config.db.dialect.default_schema_name
+            eq_(
+                inspect(connection).get_domains(schema=assert_schema),
+                [
+                    {
+                        "name": "mytype",
+                        "type": "text",
+                        "nullable": True,
+                        "default": None,
+                        "schema": assert_schema,
+                        "visible": assert_schema == def_schame,
+                        "constraints": [
+                            {
+                                "name": "mytype_check",
+                                "check": r"VALUE ~ '[^@]+@[^@]+\.[^@]+'::text",
+                            }
+                        ],
+                    }
+                ],
+            )
+        else:
+            assert False
 
-    def test_name_required(self, metadata, connection):
-        etype = Enum("four", "five", "six", metadata=metadata)
-        assert_raises(exc.CompileError, etype.create, connection)
+    @testing.combinations(
+        (ENUM("one", "two", "three", name=None)),
+        (
+            DOMAIN(
+                name=None,
+                data_type=Text,
+                check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+            ),
+        ),
+        argnames="datatype",
+    )
+    def test_name_required(self, metadata, connection, datatype):
+        assert_raises(exc.CompileError, datatype.create, connection)
         assert_raises(
-            exc.CompileError, etype.compile, dialect=connection.dialect
+            exc.CompileError, datatype.compile, dialect=connection.dialect
         )
 
-    def test_unicode_labels(self, connection, metadata):
+    def test_enum_unicode_labels(self, connection, metadata):
         t1 = Table(
             "table",
             metadata,
@@ -426,22 +571,30 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         connection.execute(t1.insert(), {"bar": "Ü"})
         eq_(connection.scalar(select(t1.c.bar)), "Ü")
 
-    def test_disable_create(self, metadata, connection):
+    @testing.combinations(
+        (ENUM("one", "two", "three", name="mytype", create_type=False),),
+        (
+            DOMAIN(
+                name="mytype",
+                data_type=Text,
+                check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+                create_type=False,
+            ),
+        ),
+        argnames="datatype",
+    )
+    def test_disable_create(self, metadata, connection, datatype):
         metadata = self.metadata
 
-        e1 = postgresql.ENUM(
-            "one", "two", "three", name="myenum", create_type=False
-        )
-
-        t1 = Table("e1", metadata, Column("c1", e1))
+        t1 = Table("e1", metadata, Column("c1", datatype))
         # table can be created separately
         # without conflict
-        e1.create(bind=connection)
+        datatype.create(bind=connection)
         t1.create(connection)
         t1.drop(connection)
-        e1.drop(bind=connection)
+        datatype.drop(bind=connection)
 
-    def test_dont_keep_checking(self, metadata, connection):
+    def test_enum_dont_keep_checking(self, metadata, connection):
         metadata = self.metadata
 
         e1 = postgresql.ENUM("one", "two", "three", name="myenum")
@@ -486,7 +639,36 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             RegexSQL("DROP TYPE myenum", dialect="postgresql"),
         )
 
-    def test_generate_multiple(self, metadata, connection):
+    @testing.combinations(
+        (
+            Enum(
+                "one",
+                "two",
+                "three",
+                name="mytype",
+            ),
+            "get_enums",
+        ),
+        (
+            ENUM(
+                "one",
+                "two",
+                "three",
+                name="mytype",
+            ),
+            "get_enums",
+        ),
+        (
+            DOMAIN(
+                name="mytype",
+                data_type=Text,
+                check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+            ),
+            "get_domains",
+        ),
+        argnames="datatype, method",
+    )
+    def test_generate_multiple(self, metadata, connection, datatype, method):
         """Test that the same enum twice only generates once
         for the create_all() call, without using checkfirst.
 
@@ -494,15 +676,20 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         now handles this.
 
         """
-        e1 = Enum("one", "two", "three", name="myenum")
-        Table("e1", metadata, Column("c1", e1))
+        Table("e1", metadata, Column("c1", datatype))
 
-        Table("e2", metadata, Column("c1", e1))
+        Table("e2", metadata, Column("c1", datatype))
 
         metadata.create_all(connection, checkfirst=False)
+
+        assert "mytype" in [
+            e["name"] for e in getattr(inspect(connection), method)()
+        ]
+
         metadata.drop_all(connection, checkfirst=False)
-        assert "myenum" not in [
-            e["name"] for e in inspect(connection).get_enums()
+
+        assert "mytype" not in [
+            e["name"] for e in getattr(inspect(connection), method)()
         ]
 
     def test_generate_alone_on_metadata(self, connection, metadata):
@@ -571,23 +758,6 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
             for e in inspect(connection).get_enums(schema="test_schema")
         ]
 
-    def test_drops_on_table(self, connection, metadata):
-
-        e1 = Enum("one", "two", "three", name="myenum")
-        table = Table("e1", metadata, Column("c1", e1))
-
-        table.create(connection)
-        table.drop(connection)
-        assert "myenum" not in [
-            e["name"] for e in inspect(connection).get_enums()
-        ]
-        table.create(connection)
-        assert "myenum" in [e["name"] for e in inspect(connection).get_enums()]
-        table.drop(connection)
-        assert "myenum" not in [
-            e["name"] for e in inspect(connection).get_enums()
-        ]
-
     def test_create_drop_schema_translate_map(self, connection):
 
         conn = connection.execution_options(
@@ -1445,15 +1615,16 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase):
             array_agg,
         )
 
-        element_type = ENUM if with_enum else Integer
+        element = ENUM(name="pgenum") if with_enum else Integer()
+        element_type = type(element)
         expr = (
             array_agg(
                 aggregate_order_by(
-                    column("q", element_type), column("idx", Integer)
+                    column("q", element), column("idx", Integer)
                 )
             )
             if using_aggregate_order_by
-            else array_agg(column("q", element_type))
+            else array_agg(column("q", element))
         )
         is_(expr.type.__class__, postgresql.ARRAY)
         is_(expr.type.item_type.__class__, element_type)
@@ -2081,10 +2252,13 @@ class ArrayRoundTripTest:
                 ],
                 testing.requires.hstore,
             ),
-            (postgresql.ENUM(AnEnum), enum_values),
+            (postgresql.ENUM(AnEnum, name="pgenum"), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=True), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=False), enum_values),
-            (postgresql.ENUM(AnEnum, native_enum=True), enum_values),
+            (
+                postgresql.ENUM(AnEnum, name="pgenum", native_enum=True),
+                enum_values,
+            ),
             (
                 make_difficult_enum(sqltypes.Enum, native=True),
                 difficult_enum_values,
@@ -2102,10 +2276,15 @@ class ArrayRoundTripTest:
         if not exclude_empty_lists:
             elements.extend(
                 [
-                    (postgresql.ENUM(AnEnum), empty_list),
+                    (postgresql.ENUM(AnEnum, name="pgenum"), empty_list),
                     (sqltypes.Enum(AnEnum, native_enum=True), empty_list),
                     (sqltypes.Enum(AnEnum, native_enum=False), empty_list),
-                    (postgresql.ENUM(AnEnum, native_enum=True), empty_list),
+                    (
+                        postgresql.ENUM(
+                            AnEnum, name="pgenum", native_enum=True
+                        ),
+                        empty_list,
+                    ),
                 ]
             )
         if not exclude_json:
@@ -2410,7 +2589,7 @@ class ArrayEnum(fixtures.TestBase):
                 ),
                 Column(
                     "pyenum_col",
-                    array_cls(enum_cls(MyEnum)),
+                    array_cls(enum_cls(MyEnum, name="pgenum")),
                 ),
             )
 
index 04aa4e000e32e87ce3ce9eedaafda1d26903004a..623688b83ea5fe8872e5db595cd294ef6364b7d2 100644 (file)
@@ -111,7 +111,11 @@ def _all_dialects():
 def _types_for_mod(mod):
     for key in dir(mod):
         typ = getattr(mod, key)
-        if not isinstance(typ, type) or not issubclass(typ, types.TypeEngine):
+        if (
+            not isinstance(typ, type)
+            or not issubclass(typ, types.TypeEngine)
+            or typ.__dict__.get("__abstract__")
+        ):
             continue
         yield typ
 
@@ -143,6 +147,17 @@ def _all_types(omit_special_types=False):
             yield typ
 
 
+def _get_instance(type_):
+    if issubclass(type_, ARRAY):
+        return type_(String)
+    elif hasattr(type_, "__test_init__"):
+        t1 = type_.__test_init__()
+        is_(isinstance(t1, type_), True)
+        return t1
+    else:
+        return type_()
+
+
 class AdaptTest(fixtures.TestBase):
     @testing.combinations(((t,) for t in _types_for_mod(types)), id_="n")
     def test_uppercase_importable(self, typ):
@@ -240,11 +255,8 @@ class AdaptTest(fixtures.TestBase):
         adapt() beyond their defaults.
 
         """
+        t1 = _get_instance(typ)
 
-        if issubclass(typ, ARRAY):
-            t1 = typ(String)
-        else:
-            t1 = typ()
         for cls in target_adaptions:
             if (is_down_adaption and issubclass(typ, sqltypes.Emulated)) or (
                 not is_down_adaption and issubclass(cls, sqltypes.Emulated)
@@ -301,19 +313,13 @@ class AdaptTest(fixtures.TestBase):
     @testing.uses_deprecated()
     @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
     def test_repr(self, typ):
-        if issubclass(typ, ARRAY):
-            t1 = typ(String)
-        else:
-            t1 = typ()
+        t1 = _get_instance(typ)
         repr(t1)
 
     @testing.uses_deprecated()
     @testing.combinations(*[(t,) for t in _all_types(omit_special_types=True)])
     def test_str(self, typ):
-        if issubclass(typ, ARRAY):
-            t1 = typ(String)
-        else:
-            t1 = typ()
+        t1 = _get_instance(typ)
         str(t1)
 
     def test_str_third_party(self):
@@ -400,7 +406,7 @@ class AsGenericTest(fixtures.TestBase):
         (pg.JSON(), sa.JSON()),
         (pg.ARRAY(sa.String), sa.ARRAY(sa.String)),
         (Enum("a", "b", "c"), Enum("a", "b", "c")),
-        (pg.ENUM("a", "b", "c"), Enum("a", "b", "c")),
+        (pg.ENUM("a", "b", "c", name="pgenum"), Enum("a", "b", "c")),
         (mysql.ENUM("a", "b", "c"), Enum("a", "b", "c")),
         (pg.INTERVAL(precision=5), Interval(native=True, second_precision=5)),
         (
@@ -419,11 +425,7 @@ class AsGenericTest(fixtures.TestBase):
         ]
     )
     def test_as_generic_all_types_heuristic(self, type_):
-        if issubclass(type_, ARRAY):
-            t1 = type_(String)
-        else:
-            t1 = type_()
-
+        t1 = _get_instance(type_)
         try:
             gentype = t1.as_generic()
         except NotImplementedError:
@@ -445,10 +447,7 @@ class AsGenericTest(fixtures.TestBase):
         ]
     )
     def test_as_generic_all_types_custom(self, type_):
-        if issubclass(type_, ARRAY):
-            t1 = type_(String)
-        else:
-            t1 = type_()
+        t1 = _get_instance(type_)
 
         gentype = t1.as_generic(allow_nulltype=False)
         assert isinstance(gentype, TypeEngine)