From: David Baumgold Date: Thu, 11 Nov 2021 12:29:48 +0000 (+0100) Subject: Add a Domain type to the Postgresql dialect X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bc9a82f010e6ca2f70a6e8a7620b748e483c26c3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add a Domain type to the Postgresql dialect --- diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index b1fd2a3421..fdbf530d5e 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -24,6 +24,7 @@ from .base import CHAR from .base import CIDR from .base import CreateEnumType from .base import DATE +from .base import DOMAIN from .base import DOUBLE_PRECISION from .base import DropEnumType from .base import ENUM @@ -96,6 +97,7 @@ __all__ = ( "INTERVAL", "ARRAY", "ENUM", + "DOMAIN", "dialect", "array", "HSTORE", diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 698ea277f5..9c9e0fa9d3 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1840,7 +1840,136 @@ class TSVECTOR(sqltypes.TypeEngine): __visit_name__ = "TSVECTOR" -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): +class _NamedType(sqltypes.TypeEngine, sqltypes.SchemaType): + """Base for named types.""" + + def create(self, bind=None, checkfirst=True, **kw): + """Emit ``CREATE`` DDL for this type. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True, **kw): + """Emit ``DROP`` DDL for this type. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. + + """ + bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst) + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named type is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". + + """ + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + type_name = "pg_%s" % (self.__visit_name__,) + if type_name in ddl_runner.memo: + existing = ddl_runner.memo[type_name] + else: + existing = ddl_runner.memo[type_name] = set() + present = (self.schema, self.name) in existing + existing.add((self.schema, self.name)) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + ) and not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + +class NamedTypeGenerator(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(NamedTypeGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def _can_create_type(self, type_): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(type_) + return not self.connection.dialect.has_type( + self.connection, type_.name, schema=effective_schema + ) + + +class NamedTypeDropper(DDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(NamedTypeDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + self.preparer = dialect.identifier_preparer + self.dialect = dialect + + def _can_drop_type(self, type_): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(type_) + return self.connection.dialect.has_type( + self.connection, type_.name, schema=effective_schema + ) + + +class EnumGenerator(NamedTypeGenerator): + def visit_enum(self, enum): + if not self._can_create_type(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + +class EnumDropper(NamedTypeDropper): + def visit_enum(self, enum): + if not self._can_drop_type(enum): + return + + self.connection.execute(DropEnumType(enum)) + + +class ENUM(_NamedType, sqltypes.NativeForEmulated, sqltypes.Enum): """PostgreSQL ENUM type. @@ -1920,6 +2049,8 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): """ native_enum = True + DDLGenerator = EnumGenerator + DDLDropper = EnumDropper def __init__(self, *enums, **kw): """Construct an :class:`_postgresql.ENUM`. @@ -1994,7 +2125,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) + super(ENUM, self).create(bind, checkfirst=checkfirst) def drop(self, bind=None, checkfirst=True): """Emit ``DROP TYPE`` for this @@ -2014,104 +2145,52 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) - - class EnumGenerator(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_create_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return not self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_create_enum(enum): - return - - self.connection.execute(CreateEnumType(enum)) - - class EnumDropper(DDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst + super(ENUM, self).drop(bind, checkfirst=checkfirst) - def _can_drop_enum(self, enum): - if not self.checkfirst: - return True - effective_schema = self.connection.schema_for_object(enum) - - return self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) +class DomainGenerator(NamedTypeGenerator): + def visit_domain(self, domain): + if not self._can_create_type(domain): + return + self.connection.execute(CreateDomainType(domain)) - def visit_enum(self, enum): - if not self._can_drop_enum(enum): - return - self.connection.execute(DropEnumType(enum)) +class DomainDropper(NamedTypeDropper): + def visit_domain(self, domain): + if not self._can_drop_type(domain): + return - def get_dbapi_type(self, dbapi): - """dont return dbapi.STRING for ENUM in PostgreSQL, since that's - a different type""" + self.connection.execute(DropDomainType(domain)) - return None - def _check_for_name_in_memos(self, checkfirst, kw): - """Look in the 'ddl runner' for 'memos', then - note our name in that collection. +class DOMAIN(_NamedType): + """Represent the DOMAIN type.""" - This to ensure a particular named enum is operated - upon only once within any kind of create/drop - sequence without relying upon "checkfirst". + __visit_name__ = "domain" + DDLGenerator = DomainGenerator + DDLDropper = DomainDropper + def __init__( + self, + name=None, + data_type=TEXT, + default=None, + constraint=None, + collation=None, + **kw + ): """ - if not self.create_type: - return True - if "_ddl_runner" in kw: - ddl_runner = kw["_ddl_runner"] - if "_pg_enums" in ddl_runner.memo: - pg_enums = ddl_runner.memo["_pg_enums"] - else: - pg_enums = ddl_runner.memo["_pg_enums"] = set() - present = (self.schema, self.name) in pg_enums - pg_enums.add((self.schema, self.name)) - return present + Construct a DOMAIN. + """ + self.name = name + if isinstance(data_type, sqltypes.TypeEngine): + self.data_type = data_type else: - return False - - def _on_table_create(self, target, bind, checkfirst=False, **kw): - if ( - checkfirst - or ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - ) - ) and not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_table_drop(self, target, bind, checkfirst=False, **kw): - if ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - and not self._check_for_name_in_memos(checkfirst, kw) - ): - self.drop(bind=bind, checkfirst=checkfirst) - - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.drop(bind=bind, checkfirst=checkfirst) + self.data_type = data_type() + self.default = default + self.constraint = constraint + self.collation = collation + super(DOMAIN, self).__init__(name=name, **kw) colspecs = { @@ -2677,6 +2756,31 @@ class PGDDLCompiler(compiler.DDLCompiler): return "DROP TYPE %s" % (self.preparer.format_type(type_)) + def visit_create_domain_type(self, create): + type_ = create.element + + options = [] + if type_.collation is not None: + options.append( + "COLLATE %s" % (self.preparer.quote(type_.collation),) + ) + if type_.default is not None: + options.append( + "DEFAULT %s" % (type_.default,) + ) + if type_.constraint is not None: + options.append("CHECK (%s)" % type_.constraint) + + return "CREATE DOMAIN %s AS %s %s" % ( + self.preparer.format_type(type_), + self.type_compiler.process(type_.data_type), + " ".join(options), + ) + + def visit_drop_domain_type(self, domain_type): + domain = domain_type.element + return "DROP DOMAIN %s" % (self.preparer.format_type(domain)) + def visit_create_index(self, create): preparer = self.preparer index = create.element @@ -2949,6 +3053,9 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): identifier_preparer = self.dialect.identifier_preparer return identifier_preparer.format_type(type_) + def visit_domain(self, type_, **kw): + return type_.name + def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( "(%d)" % type_.precision @@ -3021,7 +3128,10 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): def format_type(self, type_, use_schema=True): if not type_.name: - raise exc.CompileError("PostgreSQL ENUM type requires a name.") + raise exc.CompileError( + "PostgreSQL %s type requires a name." + % (type_.__class__.__name__,) + ) name = self.quote(type_.name) effective_schema = self.schema_for_object(type_) @@ -3044,6 +3154,17 @@ class PGInspector(reflection.Inspector): conn, table_name, schema, info_cache=self.info_cache ) + def get_domains(self, schema=None): + """Return a list of DOMAIN objects. + + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to '*' to + indicate load enums for all schemas. + """ + schema = schema or self.default_schema_name + with self._operation_context() as conn: + return self.dialect._load_domains(conn) + def get_enums(self, schema=None): """Return a list of ENUM objects. @@ -3109,6 +3230,18 @@ class DropEnumType(schema._CreateDropBase): __visit_name__ = "drop_enum_type" +class CreateDomainType(schema._CreateDropBase): + """Represent a CREATE DOMAIN statement.""" + + __visit_name__ = "create_domain_type" + + +class DropDomainType(schema._CreateDropBase): + """Represent a DROP DOMAIN statement.""" + + __visit_name__ = "drop_domain_type" + + class PGExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 0c00c76333..50869fbb36 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -36,6 +36,7 @@ from sqlalchemy import util from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import array from sqlalchemy.dialects.postgresql import DATERANGE +from sqlalchemy.dialects.postgresql import DOMAIN from sqlalchemy.dialects.postgresql import HSTORE from sqlalchemy.dialects.postgresql import hstore from sqlalchemy.dialects.postgresql import INT4RANGE @@ -857,6 +858,62 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): ] +class DomainTest(fixtures.TestBase, AssertsExecutionResults): + __backend__ = True + __only_on__ = "postgresql" + + def test_create_table(self, metadata, connection): + metadata = self.metadata + Email = DOMAIN( + name="email", + data_type=Text, + constraint=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ) + t1 = Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("email", Email), + ) + t1.create(connection) + t1.create(connection, checkfirst=True) # check the create + connection.execute(t1.insert(), dict(email="test@example.com")) + connection.execute(t1.insert(), dict(email="a@b.c")) + connection.execute(t1.insert(), dict(email="example@gmail.co.uk")) + eq_( + connection.execute(t1.select().order_by(t1.c.id)).fetchall(), + [ + (1, "test@example.com"), + (2, "a@b.c"), + (3, "example@gmail.co.uk"), + ], + ) + + def test_name_required(self, metadata, connection): + dtype = DOMAIN(metadata=metadata) + assert_raises(exc.CompileError, dtype.create, connection) + # not sure why this doesn't work... + # assert_raises( + # exc.CompileError, dtype.compile, dialect=connection.dialect + # ) + + def test_drops_on_table(self, connection, metadata): + Email = DOMAIN( + name="email", + data_type=Text, + constraint=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ) + table = Table("e1", metadata, Column("e1", Email)) + + table.create(connection) + table.drop(connection) + assert ("email",) not in inspect(connection).get_domains().keys() + table.create(connection) + assert ("email",) in inspect(connection).get_domains().keys() + table.drop(connection) + assert ("email",) not in inspect(connection).get_domains().keys() + + class OIDTest(fixtures.TestBase): __only_on__ = "postgresql" __backend__ = True