--- /dev/null
+.. change::
+ :tags: postgresql, reflection
+ :tickets: 10693
+
+ The PostgreSQL dialect now returns :class:`_postgresql.DOMAIN` instances
+ when reflecting a column that has a domain as type. Previously, the domain
+ data type was returned instead. As part of this change, the domain
+ reflection was improved to also return the collation of text types.
+ Pull request courtesy of Thomas Stephenson.
"""The constraints defined in the domain, if any.
The constraints are in order of evaluation by PostgreSQL.
"""
+ collation: Optional[str]
+ """The collation for the domain."""
class ReflectedEnum(ReflectedNamedType):
return columns.items()
- def _get_columns_info(self, rows, domains, enums, schema):
- array_type_pattern = re.compile(r"\[\]$")
- attype_pattern = re.compile(r"\(.*\)")
- charlen_pattern = re.compile(r"\(([\d,]+)\)")
- args_pattern = re.compile(r"\((.*)\)")
- args_split_pattern = re.compile(r"\s*,\s*")
-
- def _handle_array_type(attype):
- return (
- # strip '[]' from integer[], etc.
- array_type_pattern.sub("", attype),
- attype.endswith("[]"),
+ _format_type_args_pattern = re.compile(r"\((.*)\)")
+ _format_type_args_delim = re.compile(r"\s*,\s*")
+ _format_array_spec_pattern = re.compile(r"((?:\[\])*)$")
+
+ def _reflect_type(
+ self,
+ format_type: Optional[str],
+ domains: dict[str, ReflectedDomain],
+ enums: dict[str, ReflectedEnum],
+ type_description: str,
+ ) -> sqltypes.TypeEngine[Any]:
+ """
+ Attempts to reconstruct a column type defined in ischema_names based
+ on the information available in the format_type.
+
+ If the `format_type` cannot be associated with a known `ischema_names`,
+ it is treated as a reference to a known PostgreSQL named `ENUM` or
+ `DOMAIN` type.
+ """
+ type_description = type_description or "unknown type"
+ if format_type is None:
+ util.warn(
+ "PostgreSQL format_type() returned NULL for %s"
+ % type_description
+ )
+ return sqltypes.NULLTYPE
+
+ attype_args_match = self._format_type_args_pattern.search(format_type)
+ if attype_args_match and attype_args_match.group(1):
+ attype_args = self._format_type_args_delim.split(
+ attype_args_match.group(1)
)
+ else:
+ attype_args = ()
+
+ match_array_dim = self._format_array_spec_pattern.search(format_type)
+ # Each "[]" in array specs corresponds to an array dimension
+ array_dim = len(match_array_dim.group(1) or "") // 2
+
+ # Remove all parameters and array specs from format_type to obtain an
+ # ischema_name candidate
+ attype = self._format_type_args_pattern.sub("", format_type)
+ attype = self._format_array_spec_pattern.sub("", attype)
+
+ schema_type = self.ischema_names.get(attype.lower(), None)
+ args, kwargs = (), {}
+
+ if attype == "numeric":
+ if len(attype_args) == 2:
+ precision, scale = map(int, attype_args)
+ args = (precision, scale)
+
+ elif attype == "double precision":
+ args = (53,)
+
+ elif attype == "integer":
+ args = ()
+
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
+ if len(attype_args) == 1:
+ kwargs["precision"] = int(attype_args[0])
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
+ if len(attype_args) == 1:
+ kwargs["precision"] = int(attype_args[0])
+
+ elif attype == "bit varying":
+ kwargs["varying"] = True
+ if len(attype_args) == 1:
+ charlen = int(attype_args[0])
+ args = (charlen,)
+
+ elif attype.startswith("interval"):
+ schema_type = INTERVAL
+
+ field_match = re.match(r"interval (.+)", attype)
+ if field_match:
+ kwargs["fields"] = field_match.group(1)
+
+ if len(attype_args) == 1:
+ kwargs["precision"] = int(attype_args[0])
+
+ else:
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+ if enum_or_domain_key in enums:
+ schema_type = ENUM
+ enum = enums[enum_or_domain_key]
+
+ args = tuple(enum["labels"])
+ kwargs["name"] = enum["name"]
+
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
+ elif enum_or_domain_key in domains:
+ schema_type = DOMAIN
+ domain = domains[enum_or_domain_key]
+
+ data_type = self._reflect_type(
+ domain["type"],
+ domains,
+ enums,
+ type_description="DOMAIN '%s'" % domain["name"],
+ )
+ args = (domain["name"], data_type)
+
+ kwargs["collation"] = domain["collation"]
+ kwargs["default"] = domain["default"]
+ kwargs["not_null"] = not domain["nullable"]
+ kwargs["create_type"] = False
+
+ if domain["constraints"]:
+ # We only support a single constraint
+ check_constraint = domain["constraints"][0]
+
+ kwargs["constraint_name"] = check_constraint["name"]
+ kwargs["check"] = check_constraint["check"]
+
+ if not domain["visible"]:
+ kwargs["schema"] = domain["schema"]
+
+ else:
+ try:
+ charlen = int(attype_args[0])
+ args = (charlen, *attype_args[1:])
+ except (ValueError, IndexError):
+ args = attype_args
+
+ if not schema_type:
+ util.warn(
+ "Did not recognize type '%s' of %s"
+ % (attype, type_description)
+ )
+ return sqltypes.NULLTYPE
+
+ data_type = schema_type(*args, **kwargs)
+ if array_dim >= 1:
+ # postgres does not preserve dimensionality or size of array types.
+ data_type = _array.ARRAY(data_type)
+
+ return data_type
+
+ def _get_columns_info(self, rows, domains, enums, schema):
columns = defaultdict(list)
for row_dict in rows:
# ensure that each table has an entry, even if it has no columns
continue
table_cols = columns[(schema, row_dict["table_name"])]
- format_type = row_dict["format_type"]
+ coltype = self._reflect_type(
+ row_dict["format_type"],
+ domains,
+ enums,
+ type_description="column '%s'" % row_dict["name"],
+ )
+
default = row_dict["default"]
name = row_dict["name"]
generated = row_dict["generated"]
- identity = row_dict["identity_options"]
-
- if format_type is None:
- no_format_type = True
- attype = format_type = "no format_type()"
- is_array = False
- else:
- no_format_type = False
-
- # strip (*) from character varying(5), timestamp(5)
- # with time zone, geometry(POLYGON), etc.
- attype = attype_pattern.sub("", format_type)
-
- # strip '[]' from integer[], etc. and check if an array
- attype, is_array = _handle_array_type(attype)
-
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
-
nullable = not row_dict["not_null"]
- charlen = charlen_pattern.search(format_type)
- if charlen:
- charlen = charlen.group(1)
- args = args_pattern.search(format_type)
- if args and args.group(1):
- args = tuple(args_split_pattern.split(args.group(1)))
- else:
- args = ()
- kwargs = {}
+ if isinstance(coltype, DOMAIN):
+ if not default:
+ # domain can override the default value but
+ # can't set it to None
+ if coltype.default is not None:
+ default = coltype.default
- if attype == "numeric":
- if charlen:
- prec, scale = charlen.split(",")
- args = (int(prec), int(scale))
- else:
- args = ()
- elif attype == "double precision":
- args = (53,)
- elif attype == "integer":
- args = ()
- elif attype in ("timestamp with time zone", "time with time zone"):
- kwargs["timezone"] = True
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype in (
- "timestamp without time zone",
- "time without time zone",
- "time",
- ):
- kwargs["timezone"] = False
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype == "bit varying":
- kwargs["varying"] = True
- if charlen:
- args = (int(charlen),)
- else:
- args = ()
- elif attype.startswith("interval"):
- field_match = re.match(r"interval (.+)", attype, re.I)
- if charlen:
- kwargs["precision"] = int(charlen)
- if field_match:
- kwargs["fields"] = field_match.group(1)
- attype = "interval"
- args = ()
- elif charlen:
- args = (int(charlen),)
-
- while True:
- # looping here to suit nested domains
- if attype in self.ischema_names:
- coltype = self.ischema_names[attype]
- break
- elif enum_or_domain_key in enums:
- enum = enums[enum_or_domain_key]
- coltype = ENUM
- kwargs["name"] = enum["name"]
- if not enum["visible"]:
- kwargs["schema"] = enum["schema"]
- args = tuple(enum["labels"])
- break
- elif enum_or_domain_key in domains:
- domain = domains[enum_or_domain_key]
- attype = domain["type"]
- attype, is_array = _handle_array_type(attype)
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(
- util.quoted_token_parser(attype)
- )
- # A table can't override a not null on the domain,
- # but can override nullable
- nullable = nullable and domain["nullable"]
- if domain["default"] and not default:
- # It can, however, override the default
- # value, but can't set it to null.
- default = domain["default"]
- continue
- else:
- coltype = None
- break
-
- if coltype:
- coltype = coltype(*args, **kwargs)
- if is_array:
- coltype = self.ischema_names["_array"](coltype)
- elif no_format_type:
- util.warn(
- "PostgreSQL format_type() returned NULL for column '%s'"
- % (name,)
- )
- coltype = sqltypes.NULLTYPE
- else:
- util.warn(
- "Did not recognize type '%s' of column '%s'"
- % (attype, name)
- )
- coltype = sqltypes.NULLTYPE
+ nullable = nullable and not coltype.not_null
+
+ identity = row_dict["identity_options"]
# If a zero byte or blank string depending on driver (is also
# absent for older PG versions), then not a generated column.
pg_catalog.pg_namespace.c.nspname.label("schema"),
con_sq.c.condefs,
con_sq.c.connames,
+ pg_catalog.pg_collation.c.collname,
)
.join(
pg_catalog.pg_namespace,
pg_catalog.pg_namespace.c.oid
== pg_catalog.pg_type.c.typnamespace,
)
+ .outerjoin(
+ pg_catalog.pg_collation,
+ pg_catalog.pg_type.c.typcollation
+ == pg_catalog.pg_collation.c.oid,
+ )
.outerjoin(
con_sq,
pg_catalog.pg_type.c.oid == con_sq.c.contypid,
@reflection.cache
def _load_domains(self, connection, schema=None, **kw):
- # Load data types for domains:
result = connection.execute(self._domain_query(schema))
- domains = []
+ domains: List[ReflectedDomain] = []
for domain in result.mappings():
# strip (30) from character varying(30)
attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
- constraints = []
+ constraints: List[ReflectedDomainConstraint] = []
if domain["connames"]:
# When a domain has multiple CHECK constraints, they will
# be tested in alphabetical order by name.
check = def_[7:-1]
constraints.append({"name": name, "check": check})
- domain_rec = {
+ domain_rec: ReflectedDomain = {
"name": domain["name"],
"schema": domain["schema"],
"visible": domain["visible"],
"nullable": domain["nullable"],
"default": domain["default"],
"constraints": constraints,
+ "collation": domain["collname"],
}
domains.append(domain_rec)
data_type: _TypeEngineArgument[Any],
*,
collation: Optional[str] = None,
- default: Optional[Union[str, elements.TextClause]] = None,
+ default: Union[elements.TextClause, str, None] = None,
constraint_name: Optional[str] = None,
not_null: Optional[bool] = None,
- check: Optional[str] = None,
+ check: Union[elements.TextClause, str, None] = None,
create_type: bool = True,
**kw: Any,
):
self.default = default
self.collation = collation
self.constraint_name = constraint_name
- self.not_null = not_null
+ self.not_null = bool(not_null)
if check is not None:
check = coercions.expect(roles.DDLExpressionRole, check)
self.check = check
def __test_init__(cls):
return cls("name", sqltypes.Integer)
+ def adapt(self, impl, **kw):
+ if self.default:
+ kw["default"] = self.default
+ if self.constraint_name is not None:
+ kw["constraint_name"] = self.constraint_name
+ if self.not_null:
+ kw["not_null"] = self.not_null
+ if self.check is not None:
+ kw["check"] = str(self.check)
+ if self.create_type:
+ kw["create_type"] = self.create_type
+
+ return super().adapt(impl, **kw)
+
class CreateEnumType(schema._CreateDropBase):
__visit_name__ = "create_enum_type"
RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
# tables
-pg_catalog_meta = MetaData()
+pg_catalog_meta = MetaData(schema="pg_catalog")
pg_namespace = Table(
"pg_namespace",
Column("oid", OID),
Column("nspname", NAME),
Column("nspowner", OID),
- schema="pg_catalog",
)
pg_class = Table(
Column("relispartition", Boolean, info={"server_version": (10,)}),
Column("relrewrite", OID, info={"server_version": (11,)}),
Column("reloptions", ARRAY(Text)),
- schema="pg_catalog",
)
pg_type = Table(
Column("typndims", Integer),
Column("typcollation", OID, info={"server_version": (9, 1)}),
Column("typdefault", Text),
- schema="pg_catalog",
)
pg_index = Table(
Column("indoption", INT2VECTOR),
Column("indexprs", PG_NODE_TREE),
Column("indpred", PG_NODE_TREE),
- schema="pg_catalog",
)
pg_attribute = Table(
Column("attislocal", Boolean),
Column("attinhcount", Integer),
Column("attcollation", OID, info={"server_version": (9, 1)}),
- schema="pg_catalog",
)
pg_constraint = Table(
Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
Column("conkey", ARRAY(SmallInteger)),
Column("confkey", ARRAY(SmallInteger)),
- schema="pg_catalog",
)
pg_sequence = Table(
Column("seqmin", BigInteger),
Column("seqcache", BigInteger),
Column("seqcycle", Boolean),
- schema="pg_catalog",
info={"server_version": (10,)},
)
Column("adrelid", OID),
Column("adnum", SmallInteger),
Column("adbin", PG_NODE_TREE),
- schema="pg_catalog",
)
pg_description = Table(
Column("classoid", OID),
Column("objsubid", Integer),
Column("description", Text(collation="C")),
- schema="pg_catalog",
)
pg_enum = Table(
Column("enumtypid", OID),
Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
Column("enumlabel", NAME),
- schema="pg_catalog",
)
pg_am = Table(
Column("amname", NAME),
Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
Column("amtype", CHAR, info={"server_version": (9, 6)}),
- schema="pg_catalog",
+)
+
+pg_collation = Table(
+ "pg_collation",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("collname", NAME),
+ Column("collnamespace", OID),
+ Column("collowner", OID),
+ Column("collprovider", CHAR, info={"server_version": (10,)}),
+ Column("collisdeterministic", Boolean, info={"server_version": (12,)}),
+ Column("collencoding", Integer),
+ Column("collcollate", Text),
+ Column("collctype", Text),
+ Column("colliculocale", Text),
+ Column("collicurules", Text, info={"server_version": (16,)}),
+ Column("collversion", Text, info={"server_version": (10,)}),
)
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.dialects.postgresql import base as postgresql
+from sqlalchemy.dialects.postgresql import DOMAIN
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.dialects.postgresql import INTERVAL
def setup_test_class(cls):
with testing.db.begin() as con:
for ddl in [
- 'CREATE SCHEMA "SomeSchema"',
+ 'CREATE SCHEMA IF NOT EXISTS "SomeSchema"',
"CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42",
"CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0",
"CREATE TYPE testtype AS ENUM ('test')",
"CREATE DOMAIN enumdomain AS testtype",
"CREATE DOMAIN arraydomain AS INTEGER[]",
+ "CREATE DOMAIN arraydomain_2d AS INTEGER[][]",
+ "CREATE DOMAIN arraydomain_3d AS INTEGER[][][]",
'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0',
- "CREATE DOMAIN nullable_domain AS TEXT CHECK "
+ 'CREATE DOMAIN nullable_domain AS TEXT COLLATE "C" CHECK '
"(VALUE IN('FOO', 'BAR'))",
"CREATE DOMAIN not_nullable_domain AS TEXT NOT NULL",
"CREATE DOMAIN my_int AS int CONSTRAINT b_my_int_one CHECK "
"(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) "
"CHECK(VALUE != 22)",
]:
- try:
- con.exec_driver_sql(ddl)
- except exc.DBAPIError as e:
- if "already exists" not in str(e):
- raise e
+ con.exec_driver_sql(ddl)
+
con.exec_driver_sql(
"CREATE TABLE testtable (question integer, answer "
"testdomain)"
)
con.exec_driver_sql(
- "CREATE TABLE array_test (id integer, data arraydomain)"
+ "CREATE TABLE array_test ("
+ "id integer, "
+ "datas arraydomain, "
+ "datass arraydomain_2d, "
+ "datasss arraydomain_3d"
+ ")"
)
con.exec_driver_sql(
con.exec_driver_sql("DROP TYPE testtype")
con.exec_driver_sql("DROP TABLE array_test")
con.exec_driver_sql("DROP DOMAIN arraydomain")
+ con.exec_driver_sql("DROP DOMAIN arraydomain_2d")
+ con.exec_driver_sql("DROP DOMAIN arraydomain_3d")
con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"')
con.exec_driver_sql('DROP SCHEMA "SomeSchema"')
{"question", "answer"},
"Columns of reflected table didn't equal expected columns",
)
- assert isinstance(table.c.answer.type, Integer)
+ assert isinstance(table.c.answer.type, DOMAIN)
+ eq_(table.c.answer.type.name, "testdomain")
+ assert isinstance(table.c.answer.type.data_type, Integer)
def test_nullable_from_domain(self, connection):
metadata = MetaData()
def test_enum_domain_is_reflected(self, connection):
metadata = MetaData()
table = Table("enum_test", metadata, autoload_with=connection)
- eq_(table.c.data.type.enums, ["test"])
+ assert isinstance(table.c.data.type, DOMAIN)
+ eq_(table.c.data.type.data_type.enums, ["test"])
def test_array_domain_is_reflected(self, connection):
metadata = MetaData()
table = Table("array_test", metadata, autoload_with=connection)
- eq_(table.c.data.type.__class__, ARRAY)
- eq_(table.c.data.type.item_type.__class__, INTEGER)
+
+ def assert_is_integer_array_domain(domain, name):
+ # Postgres does not persist the dimensionality of the array.
+ # It's always treated as integer[]
+ assert isinstance(domain, DOMAIN)
+ assert domain.name == name
+ assert isinstance(domain.data_type, ARRAY)
+ assert isinstance(domain.data_type.item_type, INTEGER)
+
+ array_domain = table.c.datas.type
+ assert_is_integer_array_domain(array_domain, "arraydomain")
+
+ array_domain_2d = table.c.datass.type
+ assert_is_integer_array_domain(array_domain_2d, "arraydomain_2d")
+
+ array_domain_3d = table.c.datasss.type
+ assert_is_integer_array_domain(array_domain_3d, "arraydomain_3d")
def test_quoted_remote_schema_domain_is_reflected(self, connection):
metadata = MetaData()
table = Table("quote_test", metadata, autoload_with=connection)
- eq_(table.c.data.type.__class__, INTEGER)
+ assert isinstance(table.c.data.type, DOMAIN)
+ eq_(table.c.data.type.name, "Quoted.Domain")
+ assert isinstance(table.c.data.type.data_type, Integer)
def test_table_is_reflected_test_schema(self, connection):
metadata = MetaData()
"type": "integer[]",
"default": None,
"constraints": [],
+ "collation": None,
+ },
+ {
+ "visible": True,
+ "name": "arraydomain_2d",
+ "schema": "public",
+ "nullable": True,
+ "type": "integer[]",
+ "default": None,
+ "constraints": [],
+ "collation": None,
+ },
+ {
+ "visible": True,
+ "name": "arraydomain_3d",
+ "schema": "public",
+ "nullable": True,
+ "type": "integer[]",
+ "default": None,
+ "constraints": [],
+ "collation": None,
},
{
"visible": True,
"type": "testtype",
"default": None,
"constraints": [],
+ "collation": None,
},
{
"visible": True,
# autogenerated name by pg
{"check": "VALUE <> 22", "name": "my_int_check"},
],
+ "collation": None,
},
{
"visible": True,
"type": "text",
"default": None,
"constraints": [],
+ "collation": "default",
},
{
"visible": True,
"name": "nullable_domain_check",
}
],
+ "collation": "C",
},
{
"visible": True,
"type": "integer",
"default": "42",
"constraints": [],
+ "collation": None,
},
],
"test_schema": [
"type": "integer",
"default": "0",
"constraints": [],
+ "collation": None,
}
],
"SomeSchema": [
"type": "integer",
"default": "0",
"constraints": [],
+ "collation": None,
}
],
}
def test_inspect_domains(self, connection):
inspector = inspect(connection)
- eq_(inspector.get_domains(), self.all_domains["public"])
+ domains = inspector.get_domains()
+
+ domain_names = {d["name"] for d in domains}
+ expect_domain_names = {d["name"] for d in self.all_domains["public"]}
+ eq_(domain_names, expect_domain_names)
+
+ eq_(domains, self.all_domains["public"])
def test_inspect_domains_schema(self, connection):
inspector = inspect(connection)
all_ = [d for dl in self.all_domains.values() for d in dl]
all_ += inspector.get_domains("information_schema")
exp = sorted(all_, key=lambda d: (d["schema"], d["name"]))
- eq_(inspector.get_domains("*"), exp)
+ domains = inspector.get_domains("*")
+
+ eq_(domains, exp)
+
+
+class ArrayReflectionTest(fixtures.TablesTest):
+ __only_on__ = "postgresql >= 10"
+ __backend__ = True
+
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ "array_table",
+ metadata,
+ Column("id", INTEGER, primary_key=True),
+ Column("datas", ARRAY(INTEGER)),
+ Column("datass", ARRAY(INTEGER, dimensions=2)),
+ Column("datasss", ARRAY(INTEGER, dimensions=3)),
+ )
+
+ def test_array_table_is_reflected(self, connection):
+ metadata = MetaData()
+ table = Table("array_table", metadata, autoload_with=connection)
+
+ def assert_is_integer_array(data_type):
+ assert isinstance(data_type, ARRAY)
+ # postgres treats all arrays as one-dimensional arrays
+ assert isinstance(data_type.item_type, INTEGER)
+
+ assert_is_integer_array(table.c.datas.type)
+ assert_is_integer_array(table.c.datass.type)
+ assert_is_integer_array(table.c.datasss.type)
class ReflectionTest(
from sqlalchemy.dialects.postgresql import TSRANGE
from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
from sqlalchemy.dialects.postgresql import TSTZRANGE
+from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql.ranges import MultiRange
from sqlalchemy.exc import CompileError
from sqlalchemy.exc import DBAPIError
"check": r"VALUE ~ '[^@]+@[^@]+\.[^@]+'::text",
}
],
+ "collation": "default",
}
],
)
connection, "fourfivesixtype"
)
- def test_reflection(self, metadata, connection):
+ def test_enum_type_reflection(self, metadata, connection):
etype = Enum(
"four", "five", "six", name="fourfivesixtype", metadata=metadata
)
]
+class DomainTest(
+ AssertsCompiledSQL, fixtures.TestBase, AssertsExecutionResults
+):
+ __backend__ = True
+ __only_on__ = "postgresql > 8.3"
+
+ def test_domain_type_reflection(self, metadata, connection):
+ positive_int = DOMAIN(
+ "positive_int", Integer(), check="value > 0", not_null=True
+ )
+ my_str = DOMAIN("my_string", Text(), collation="C", default="~~")
+ Table(
+ "table",
+ metadata,
+ Column("value", positive_int),
+ Column("str", my_str),
+ )
+
+ metadata.create_all(connection)
+ m2 = MetaData()
+ t2 = Table("table", m2, autoload_with=connection)
+
+ vt = t2.c.value.type
+ is_true(isinstance(vt, DOMAIN))
+ is_true(isinstance(vt.data_type, Integer))
+ eq_(vt.name, "positive_int")
+ eq_(str(vt.check), "VALUE > 0")
+ is_(vt.default, None)
+ is_(vt.collation, None)
+ is_true(vt.constraint_name is not None)
+ is_true(vt.not_null)
+ is_false(vt.create_type)
+
+ st = t2.c.str.type
+ is_true(isinstance(st, DOMAIN))
+ is_true(isinstance(st.data_type, Text))
+ eq_(st.name, "my_string")
+ is_(st.check, None)
+ is_true("~~" in st.default)
+ eq_(st.collation, "C")
+ is_(st.constraint_name, None)
+ is_false(st.not_null)
+ is_false(st.create_type)
+
+ def test_domain_create_table(self, metadata, connection):
+ metadata = self.metadata
+ Email = DOMAIN(
+ name="email",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ )
+ PosInt = DOMAIN(
+ name="pos_int",
+ data_type=Integer,
+ not_null=True,
+ check=r"VALUE > 0",
+ )
+ t1 = Table(
+ "table",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("email", Email),
+ Column("number", PosInt),
+ )
+ t1.create(connection)
+ t1.create(connection, checkfirst=True) # check the create
+ connection.execute(
+ t1.insert(), {"email": "test@example.com", "number": 42}
+ )
+ connection.execute(t1.insert(), {"email": "a@b.c", "number": 1})
+ connection.execute(
+ t1.insert(), {"email": "example@gmail.co.uk", "number": 99}
+ )
+ eq_(
+ connection.execute(t1.select().order_by(t1.c.id)).fetchall(),
+ [
+ (1, "test@example.com", 42),
+ (2, "a@b.c", 1),
+ (3, "example@gmail.co.uk", 99),
+ ],
+ )
+
+ @testing.combinations(
+ tuple(
+ [
+ DOMAIN(
+ name="mytype",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ create_type=True,
+ ),
+ ]
+ ),
+ tuple(
+ [
+ DOMAIN(
+ name="mytype",
+ data_type=Text,
+ check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'",
+ create_type=False,
+ ),
+ ]
+ ),
+ argnames="domain",
+ )
+ def test_create_drop_domain_with_table(self, connection, metadata, domain):
+ table = Table("e1", metadata, Column("e1", domain))
+
+ def _domain_names():
+ return {d["name"] for d in inspect(connection).get_domains()}
+
+ assert "mytype" not in _domain_names()
+
+ if domain.create_type:
+ table.create(connection)
+ assert "mytype" in _domain_names()
+ else:
+ with expect_raises(exc.ProgrammingError):
+ table.create(connection)
+ connection.rollback()
+
+ domain.create(connection)
+ assert "mytype" in _domain_names()
+ table.create(connection)
+
+ table.drop(connection)
+ if domain.create_type:
+ assert "mytype" not in _domain_names()
+
+ @testing.combinations(
+ (Integer, "value > 0", 4),
+ (String, "value != ''", "hello world"),
+ (
+ UUID,
+ "value != '{00000000-0000-0000-0000-000000000000}'",
+ uuid.uuid4(),
+ ),
+ (
+ DateTime,
+ "value >= '2020-01-01T00:00:00'",
+ datetime.datetime.fromisoformat("2021-01-01T00:00:00.000"),
+ ),
+ argnames="domain_datatype, domain_check, value",
+ )
+ def test_domain_roundtrip(
+ self, metadata, connection, domain_datatype, domain_check, value
+ ):
+ table = Table(
+ "domain_roundtrip_test",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column(
+ "value",
+ DOMAIN("valuedomain", domain_datatype, check=domain_check),
+ ),
+ )
+ table.create(connection)
+
+ connection.execute(table.insert(), {"value": value})
+
+ results = connection.execute(
+ table.select().order_by(table.c.id)
+ ).fetchall()
+ eq_(results, [(1, value)])
+
+ @testing.combinations(
+ (DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True), 4, -4),
+ (
+ DOMAIN("email", String, check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'"),
+ "e@xample.com",
+ "fred",
+ ),
+ argnames="domain,pass_value,fail_value",
+ )
+ def test_check_constraint(
+ self, metadata, connection, domain, pass_value, fail_value
+ ):
+ table = Table("table", metadata, Column("value", domain))
+ table.create(connection)
+
+ connection.execute(table.insert(), {"value": pass_value})
+
+ # psycopg/psycopg2 raise IntegrityError, while pg8000 raises
+ # ProgrammingError
+ with expect_raises(exc.DatabaseError):
+ connection.execute(table.insert(), {"value": fail_value})
+
+ @testing.combinations(
+ (DOMAIN("nullable_domain", Integer, not_null=True), 1),
+ (DOMAIN("non_nullable_domain", Integer, not_null=False), 1),
+ argnames="domain,pass_value",
+ )
+ def test_domain_nullable(self, metadata, connection, domain, pass_value):
+ table = Table("table", metadata, Column("value", domain))
+ table.create(connection)
+ connection.execute(table.insert(), {"value": pass_value})
+
+ if domain.not_null:
+ # psycopg/psycopg2 raise IntegrityError, while pg8000 raises
+ # ProgrammingError
+ with expect_raises(exc.DatabaseError):
+ connection.execute(table.insert(), {"value": None})
+ else:
+ connection.execute(table.insert(), {"value": None})
+
+
class DomainDDLEventTest(DDLEventWCreateHarness, fixtures.TestBase):
__backend__ = True
t1.create(connection)
m2 = MetaData()
t2 = Table("t1", m2, autoload_with=connection)
+
+ eq_(t1.c.c1.type.__class__, postgresql.TIME)
+ eq_(t1.c.c4.type.__class__, postgresql.TIMESTAMP)
+
eq_(t2.c.c1.type.precision, None)
eq_(t2.c.c2.type.precision, 5)
eq_(t2.c.c3.type.precision, 5)