]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add a Domain type to the Postgresql dialect 7317/head
authorDavid Baumgold <david@davidbaumgold.com>
Thu, 11 Nov 2021 12:29:48 +0000 (13:29 +0100)
committerDavid Baumgold <david@davidbaumgold.com>
Fri, 11 Feb 2022 16:21:20 +0000 (17:21 +0100)
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/base.py
test/dialect/postgresql/test_types.py

index b1fd2a34212939a8ee9d587879f29de64dae71ee..fdbf530d5e305399cb2ce397ad2e8792cc87a250 100644 (file)
@@ -24,6 +24,7 @@ from .base import CHAR
 from .base import CIDR
 from .base import CreateEnumType
 from .base import DATE
+from .base import DOMAIN
 from .base import DOUBLE_PRECISION
 from .base import DropEnumType
 from .base import ENUM
@@ -96,6 +97,7 @@ __all__ = (
     "INTERVAL",
     "ARRAY",
     "ENUM",
+    "DOMAIN",
     "dialect",
     "array",
     "HSTORE",
index 698ea277f5d2c21efc99673759fa1d58f1550c55..9c9e0fa9d35067eceae0d1d165ab8ac62c43766b 100644 (file)
@@ -1840,7 +1840,136 @@ class TSVECTOR(sqltypes.TypeEngine):
     __visit_name__ = "TSVECTOR"
 
 
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
+class _NamedType(sqltypes.TypeEngine, sqltypes.SchemaType):
+    """Base for named types."""
+
+    def create(self, bind=None, checkfirst=True, **kw):
+        """Emit ``CREATE`` DDL for this type.
+
+        :param bind: a connectable :class:`_engine.Engine`,
+         :class:`_engine.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will be first performed to see
+         if the type does not exist already before
+         creating.
+
+        """
+        bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
+
+    def drop(self, bind=None, checkfirst=True, **kw):
+        """Emit ``DROP`` DDL for this type.
+
+        :param bind: a connectable :class:`_engine.Engine`,
+         :class:`_engine.Connection`, or similar object to emit
+         SQL.
+        :param checkfirst: if ``True``, a query against
+         the PG catalog will be first performed to see
+         if the type actually exists before dropping.
+
+        """
+        bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
+
+    def _check_for_name_in_memos(self, checkfirst, kw):
+        """Look in the 'ddl runner' for 'memos', then
+        note our name in that collection.
+
+        This to ensure a particular named type is operated
+        upon only once within any kind of create/drop
+        sequence without relying upon "checkfirst".
+
+        """
+        if "_ddl_runner" in kw:
+            ddl_runner = kw["_ddl_runner"]
+            type_name = "pg_%s" % (self.__visit_name__,)
+            if type_name in ddl_runner.memo:
+                existing = ddl_runner.memo[type_name]
+            else:
+                existing = ddl_runner.memo[type_name] = set()
+            present = (self.schema, self.name) in existing
+            existing.add((self.schema, self.name))
+            return present
+        else:
+            return False
+
+    def _on_table_create(self, target, bind, checkfirst=False, **kw):
+        if (
+            checkfirst
+            or (
+                not self.metadata
+                and not kw.get("_is_metadata_operation", False)
+            )
+        ) and not self._check_for_name_in_memos(checkfirst, kw):
+            self.create(bind=bind, checkfirst=checkfirst)
+
+    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+        if (
+            not self.metadata
+            and not kw.get("_is_metadata_operation", False)
+            and not self._check_for_name_in_memos(checkfirst, kw)
+        ):
+            self.drop(bind=bind, checkfirst=checkfirst)
+
+    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+        if not self._check_for_name_in_memos(checkfirst, kw):
+            self.create(bind=bind, checkfirst=checkfirst)
+
+    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+        if not self._check_for_name_in_memos(checkfirst, kw):
+            self.drop(bind=bind, checkfirst=checkfirst)
+
+
+class NamedTypeGenerator(DDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+        super(NamedTypeGenerator, self).__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+        self.preparer = dialect.identifier_preparer
+        self.dialect = dialect
+
+    def _can_create_type(self, type_):
+        if not self.checkfirst:
+            return True
+
+        effective_schema = self.connection.schema_for_object(type_)
+        return not self.connection.dialect.has_type(
+            self.connection, type_.name, schema=effective_schema
+        )
+
+
+class NamedTypeDropper(DDLBase):
+    def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+        super(NamedTypeDropper, self).__init__(connection, **kwargs)
+        self.checkfirst = checkfirst
+        self.preparer = dialect.identifier_preparer
+        self.dialect = dialect
+
+    def _can_drop_type(self, type_):
+        if not self.checkfirst:
+            return True
+
+        effective_schema = self.connection.schema_for_object(type_)
+        return self.connection.dialect.has_type(
+            self.connection, type_.name, schema=effective_schema
+        )
+
+
+class EnumGenerator(NamedTypeGenerator):
+    def visit_enum(self, enum):
+        if not self._can_create_type(enum):
+            return
+
+        self.connection.execute(CreateEnumType(enum))
+
+
+class EnumDropper(NamedTypeDropper):
+    def visit_enum(self, enum):
+        if not self._can_drop_type(enum):
+            return
+
+        self.connection.execute(DropEnumType(enum))
+
+
+class ENUM(_NamedType, sqltypes.NativeForEmulated, sqltypes.Enum):
 
     """PostgreSQL ENUM type.
 
@@ -1920,6 +2049,8 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
     """
 
     native_enum = True
+    DDLGenerator = EnumGenerator
+    DDLDropper = EnumDropper
 
     def __init__(self, *enums, **kw):
         """Construct an :class:`_postgresql.ENUM`.
@@ -1994,7 +2125,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
         if not bind.dialect.supports_native_enum:
             return
 
-        bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
+        super(ENUM, self).create(bind, checkfirst=checkfirst)
 
     def drop(self, bind=None, checkfirst=True):
         """Emit ``DROP TYPE`` for this
@@ -2014,104 +2145,52 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
         if not bind.dialect.supports_native_enum:
             return
 
-        bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
-
-    class EnumGenerator(DDLBase):
-        def __init__(self, dialect, connection, checkfirst=False, **kwargs):
-            super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
-            self.checkfirst = checkfirst
-
-        def _can_create_enum(self, enum):
-            if not self.checkfirst:
-                return True
-
-            effective_schema = self.connection.schema_for_object(enum)
-
-            return not self.connection.dialect.has_type(
-                self.connection, enum.name, schema=effective_schema
-            )
-
-        def visit_enum(self, enum):
-            if not self._can_create_enum(enum):
-                return
-
-            self.connection.execute(CreateEnumType(enum))
-
-    class EnumDropper(DDLBase):
-        def __init__(self, dialect, connection, checkfirst=False, **kwargs):
-            super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
-            self.checkfirst = checkfirst
+        super(ENUM, self).drop(bind, checkfirst=checkfirst)
 
-        def _can_drop_enum(self, enum):
-            if not self.checkfirst:
-                return True
 
-            effective_schema = self.connection.schema_for_object(enum)
-
-            return self.connection.dialect.has_type(
-                self.connection, enum.name, schema=effective_schema
-            )
+class DomainGenerator(NamedTypeGenerator):
+    def visit_domain(self, domain):
+        if not self._can_create_type(domain):
+            return
+        self.connection.execute(CreateDomainType(domain))
 
-        def visit_enum(self, enum):
-            if not self._can_drop_enum(enum):
-                return
 
-            self.connection.execute(DropEnumType(enum))
+class DomainDropper(NamedTypeDropper):
+    def visit_domain(self, domain):
+        if not self._can_drop_type(domain):
+            return
 
-    def get_dbapi_type(self, dbapi):
-        """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
-        a different type"""
+        self.connection.execute(DropDomainType(domain))
 
-        return None
 
-    def _check_for_name_in_memos(self, checkfirst, kw):
-        """Look in the 'ddl runner' for 'memos', then
-        note our name in that collection.
+class DOMAIN(_NamedType):
+    """Represent the DOMAIN type."""
 
-        This to ensure a particular named enum is operated
-        upon only once within any kind of create/drop
-        sequence without relying upon "checkfirst".
+    __visit_name__ = "domain"
+    DDLGenerator = DomainGenerator
+    DDLDropper = DomainDropper
 
+    def __init__(
+        self,
+        name=None,
+        data_type=TEXT,
+        default=None,
+        constraint=None,
+        collation=None,
+        **kw
+    ):
         """
-        if not self.create_type:
-            return True
-        if "_ddl_runner" in kw:
-            ddl_runner = kw["_ddl_runner"]
-            if "_pg_enums" in ddl_runner.memo:
-                pg_enums = ddl_runner.memo["_pg_enums"]
-            else:
-                pg_enums = ddl_runner.memo["_pg_enums"] = set()
-            present = (self.schema, self.name) in pg_enums
-            pg_enums.add((self.schema, self.name))
-            return present
+        Construct a DOMAIN.
+        """
+        self.name = name
+        if isinstance(data_type, sqltypes.TypeEngine):
+            self.data_type = data_type
         else:
-            return False
-
-    def _on_table_create(self, target, bind, checkfirst=False, **kw):
-        if (
-            checkfirst
-            or (
-                not self.metadata
-                and not kw.get("_is_metadata_operation", False)
-            )
-        ) and not self._check_for_name_in_memos(checkfirst, kw):
-            self.create(bind=bind, checkfirst=checkfirst)
-
-    def _on_table_drop(self, target, bind, checkfirst=False, **kw):
-        if (
-            not self.metadata
-            and not kw.get("_is_metadata_operation", False)
-            and not self._check_for_name_in_memos(checkfirst, kw)
-        ):
-            self.drop(bind=bind, checkfirst=checkfirst)
-
-    def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
-        if not self._check_for_name_in_memos(checkfirst, kw):
-            self.create(bind=bind, checkfirst=checkfirst)
-
-    def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
-        if not self._check_for_name_in_memos(checkfirst, kw):
-            self.drop(bind=bind, checkfirst=checkfirst)
+            self.data_type = data_type()
+        self.default = default
+        self.constraint = constraint
+        self.collation = collation
+        super(DOMAIN, self).__init__(name=name, **kw)
 
 
 colspecs = {
@@ -2677,6 +2756,31 @@ class PGDDLCompiler(compiler.DDLCompiler):
 
         return "DROP TYPE %s" % (self.preparer.format_type(type_))
 
+    def visit_create_domain_type(self, create):
+        type_ = create.element
+
+        options = []
+        if type_.collation is not None:
+            options.append(
+                "COLLATE %s" % (self.preparer.quote(type_.collation),)
+            )
+        if type_.default is not None:
+            options.append(
+                "DEFAULT %s" % (type_.default,)
+            )
+        if type_.constraint is not None:
+            options.append("CHECK (%s)" % type_.constraint)
+
+        return "CREATE DOMAIN %s AS %s %s" % (
+            self.preparer.format_type(type_),
+            self.type_compiler.process(type_.data_type),
+            " ".join(options),
+        )
+
+    def visit_drop_domain_type(self, domain_type):
+        domain = domain_type.element
+        return "DROP DOMAIN %s" % (self.preparer.format_type(domain))
+
     def visit_create_index(self, create):
         preparer = self.preparer
         index = create.element
@@ -2949,6 +3053,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
             identifier_preparer = self.dialect.identifier_preparer
         return identifier_preparer.format_type(type_)
 
+    def visit_domain(self, type_, **kw):
+        return type_.name
+
     def visit_TIMESTAMP(self, type_, **kw):
         return "TIMESTAMP%s %s" % (
             "(%d)" % type_.precision
@@ -3021,7 +3128,10 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
 
     def format_type(self, type_, use_schema=True):
         if not type_.name:
-            raise exc.CompileError("PostgreSQL ENUM type requires a name.")
+            raise exc.CompileError(
+                "PostgreSQL %s type requires a name."
+                % (type_.__class__.__name__,)
+            )
 
         name = self.quote(type_.name)
         effective_schema = self.schema_for_object(type_)
@@ -3044,6 +3154,17 @@ class PGInspector(reflection.Inspector):
                 conn, table_name, schema, info_cache=self.info_cache
             )
 
+    def get_domains(self, schema=None):
+        """Return a list of DOMAIN objects.
+
+        :param schema: schema name.  If None, the default schema
+         (typically 'public') is used.  May also be set to '*' to
+         indicate load enums for all schemas.
+        """
+        schema = schema or self.default_schema_name
+        with self._operation_context() as conn:
+            return self.dialect._load_domains(conn)
+
     def get_enums(self, schema=None):
         """Return a list of ENUM objects.
 
@@ -3109,6 +3230,18 @@ class DropEnumType(schema._CreateDropBase):
     __visit_name__ = "drop_enum_type"
 
 
+class CreateDomainType(schema._CreateDropBase):
+    """Represent a CREATE DOMAIN statement."""
+
+    __visit_name__ = "create_domain_type"
+
+
+class DropDomainType(schema._CreateDropBase):
+    """Represent a DROP DOMAIN statement."""
+
+    __visit_name__ = "drop_domain_type"
+
+
 class PGExecutionContext(default.DefaultExecutionContext):
     def fire_sequence(self, seq, type_):
         return self._execute_scalar(
index 0c00c76333c1164c9c3eb55293efb737433b969e..50869fbb36b579b535d8993bd0b41b477ef02b02 100644 (file)
@@ -36,6 +36,7 @@ from sqlalchemy import util
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.dialects.postgresql import array
 from sqlalchemy.dialects.postgresql import DATERANGE
+from sqlalchemy.dialects.postgresql import DOMAIN
 from sqlalchemy.dialects.postgresql import HSTORE
 from sqlalchemy.dialects.postgresql import hstore
 from sqlalchemy.dialects.postgresql import INT4RANGE
@@ -857,6 +858,62 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults):
         ]
 
 
+class DomainTest(fixtures.TestBase, AssertsExecutionResults):
+    __backend__ = True
+    __only_on__ = "postgresql"
+
+    def test_create_table(self, metadata, connection):
+        metadata = self.metadata
+        Email = DOMAIN(
+            name="email",
+            data_type=Text,
+            constraint=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+        )
+        t1 = Table(
+            "table",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("email", Email),
+        )
+        t1.create(connection)
+        t1.create(connection, checkfirst=True)  # check the create
+        connection.execute(t1.insert(), dict(email="test@example.com"))
+        connection.execute(t1.insert(), dict(email="a@b.c"))
+        connection.execute(t1.insert(), dict(email="example@gmail.co.uk"))
+        eq_(
+            connection.execute(t1.select().order_by(t1.c.id)).fetchall(),
+            [
+                (1, "test@example.com"),
+                (2, "a@b.c"),
+                (3, "example@gmail.co.uk"),
+            ],
+        )
+
+    def test_name_required(self, metadata, connection):
+        dtype = DOMAIN(metadata=metadata)
+        assert_raises(exc.CompileError, dtype.create, connection)
+        # not sure why this doesn't work...
+        # assert_raises(
+        #     exc.CompileError, dtype.compile, dialect=connection.dialect
+        # )
+
+    def test_drops_on_table(self, connection, metadata):
+        Email = DOMAIN(
+            name="email",
+            data_type=Text,
+            constraint=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+        )
+        table = Table("e1", metadata, Column("e1", Email))
+
+        table.create(connection)
+        table.drop(connection)
+        assert ("email",) not in inspect(connection).get_domains().keys()
+        table.create(connection)
+        assert ("email",) in inspect(connection).get_domains().keys()
+        table.drop(connection)
+        assert ("email",) not in inspect(connection).get_domains().keys()
+
+
 class OIDTest(fixtures.TestBase):
     __only_on__ = "postgresql"
     __backend__ = True