From adac164d191138265ecd64a28be91254a53a9c25 Mon Sep 17 00:00:00 2001 From: Thomas Stephenson Date: Tue, 28 Nov 2023 18:52:55 +1100 Subject: [PATCH] Fix #10693: postgresql dialect should reflect DOMAIN types Also, refactor postgresql dialect type reflection into new protected dialect instance method `PGDialect._reflect_type` PGDialect._reflect_type(format_type, enums, domains, type_description) test #10693: - psycopg and pg8000 raise different dbapi exceptions when a domain constraint fails. Weaken test assertion to catch both review #10693: - type_description is required parameter to _reflect_type - fix docstring typo - make arrayspec regex a class variable - Use equality checks on attype_args - reflected array dimensionality is always 1 --- lib/sqlalchemy/dialects/postgresql/base.py | 298 ++++++++++-------- .../dialects/postgresql/named_types.py | 20 +- test/dialect/postgresql/test_reflection.py | 110 ++++++- test/dialect/postgresql/test_types.py | 188 ++++++++++- 4 files changed, 465 insertions(+), 151 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 6fe2aebadb..fbbd2d0a30 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -3661,20 +3661,158 @@ class PGDialect(default.DefaultDialect): return columns.items() - def _get_columns_info(self, rows, domains, enums, schema): - array_type_pattern = re.compile(r"\[\]$") - attype_pattern = re.compile(r"\(.*\)") - charlen_pattern = re.compile(r"\(([\d,]+)\)") - args_pattern = re.compile(r"\((.*)\)") - args_split_pattern = re.compile(r"\s*,\s*") - - def _handle_array_type(attype): - return ( - # strip '[]' from integer[], etc. - array_type_pattern.sub("", attype), - attype.endswith("[]"), + _format_type_args_pattern = re.compile(r"\((.*)\)") + _format_type_args_delim = re.compile(r"\s*,\s*") + _format_array_spec_pattern = re.compile(r"((?:\[\])*)$") + + def _reflect_type( + self, + format_type: Optional[str], + domains: dict[str, ReflectedDomain], + enums: dict[str, ReflectedEnum], + type_description: str, + ) -> sqltypes.TypeEngine[Any]: + """ + Attempts to reconstruct a column type defined in ischema_names based + on the information available in the format_type. + + If the `format_type` cannot be associated with a known `ischema_names`, + it is treated as a reference to a known PostgreSQL named `ENUM` or + `DOMAIN` type. + """ + type_description = type_description or "unknown type" + if format_type is None: + util.warn( + "PostgreSQL format_type() returned NULL for %s" + % type_description + ) + return sqltypes.NULLTYPE + + attype_args_match = self._format_type_args_pattern.search(format_type) + if attype_args_match and attype_args_match.group(1): + attype_args = self._format_type_args_delim.split( + attype_args_match.group(1) ) + else: + attype_args = () + + match_array_dim = self._format_array_spec_pattern.search(format_type) + # Each "[]" in array specs corresponds to an array dimension + array_dim = len(match_array_dim.group(1) or "") // 2 + + # Remove all parameters and array specs from format_type to obtain an + # ischema_name candidate + attype = self._format_type_args_pattern.sub("", format_type) + attype = self._format_array_spec_pattern.sub("", attype) + + schema_type = self.ischema_names.get(attype.lower(), None) + args, kwargs = (), {} + + print("attype", attype, "schema_type", schema_type) + + if attype == "numeric": + if len(attype_args) == 2: + precision, scale = map(int, attype_args) + args = (precision, scale) + + elif attype == "double precision": + args = (53,) + + elif attype == "integer": + args = () + + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + elif attype == "bit varying": + kwargs["varying"] = True + if len(attype_args) == 1: + charlen = int(attype_args[0]) + args = (charlen,) + + elif attype.startswith("interval"): + schema_type = INTERVAL + + field_match = re.match(r"interval (.+)", attype) + if field_match: + kwargs["fields"] = field_match.group(1) + + if len(attype_args) == 1: + kwargs["precision"] = int(attype_args[0]) + + else: + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + if enum_or_domain_key in enums: + schema_type = ENUM + enum = enums[enum_or_domain_key] + + args = tuple(enum["labels"]) + kwargs["name"] = enum["name"] + + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + elif enum_or_domain_key in domains: + schema_type = DOMAIN + domain = domains[enum_or_domain_key] + + data_type = self._reflect_type( + domain["type"], + domains, + enums, + type_description="DOMAIN '%s'" % domain["name"], + ) + args = (domain["name"], data_type) + + # TODO: collation not reflected + # kwargs["collation"] = "???" + kwargs["default"] = domain["default"] + kwargs["not_null"] = not domain["nullable"] + + if domain["constraints"]: + # We only support a single constraint + check_constraint = domain["constraints"][0] + + kwargs["constraint_name"] = check_constraint["name"] + kwargs["check"] = check_constraint["check"] + + if not domain["visible"]: + kwargs["schema"] = domain["schema"] + else: + try: + charlen = int(attype_args[0]) + args = (charlen, *attype_args[1:]) + except (ValueError, IndexError): + args = attype_args + + if not schema_type: + util.warn( + "Did not recognize type '%s' of %s" + % (attype, type_description) + ) + return sqltypes.NULLTYPE + + data_type = schema_type(*args, **kwargs) + if array_dim >= 1: + # postgres does not preserve dimensionality or size of array types. + data_type = _array.ARRAY(data_type) + + return data_type + + def _get_columns_info(self, rows, domains, enums, schema): columns = defaultdict(list) for row_dict in rows: # ensure that each table has an entry, even if it has no columns @@ -3685,131 +3823,26 @@ class PGDialect(default.DefaultDialect): continue table_cols = columns[(schema, row_dict["table_name"])] - format_type = row_dict["format_type"] + coltype = self._reflect_type( + row_dict["format_type"], + domains, + enums, + type_description="column '%s'" % row_dict["name"], + ) + default = row_dict["default"] + if isinstance(coltype, DOMAIN) and not default: + # domain can override the default value but cant set it to None + if coltype.default is not None: + default = coltype.default + name = row_dict["name"] generated = row_dict["generated"] - identity = row_dict["identity_options"] - - if format_type is None: - no_format_type = True - attype = format_type = "no format_type()" - is_array = False - else: - no_format_type = False - - # strip (*) from character varying(5), timestamp(5) - # with time zone, geometry(POLYGON), etc. - attype = attype_pattern.sub("", format_type) - - # strip '[]' from integer[], etc. and check if an array - attype, is_array = _handle_array_type(attype) - - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) - nullable = not row_dict["not_null"] + if isinstance(coltype, DOMAIN): + nullable = nullable and not coltype.not_null - charlen = charlen_pattern.search(format_type) - if charlen: - charlen = charlen.group(1) - args = args_pattern.search(format_type) - if args and args.group(1): - args = tuple(args_split_pattern.split(args.group(1))) - else: - args = () - kwargs = {} - - if attype == "numeric": - if charlen: - prec, scale = charlen.split(",") - args = (int(prec), int(scale)) - else: - args = () - elif attype == "double precision": - args = (53,) - elif attype == "integer": - args = () - elif attype in ("timestamp with time zone", "time with time zone"): - kwargs["timezone"] = True - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype in ( - "timestamp without time zone", - "time without time zone", - "time", - ): - kwargs["timezone"] = False - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype == "bit varying": - kwargs["varying"] = True - if charlen: - args = (int(charlen),) - else: - args = () - elif attype.startswith("interval"): - field_match = re.match(r"interval (.+)", attype, re.I) - if charlen: - kwargs["precision"] = int(charlen) - if field_match: - kwargs["fields"] = field_match.group(1) - attype = "interval" - args = () - elif charlen: - args = (int(charlen),) - - while True: - # looping here to suit nested domains - if attype in self.ischema_names: - coltype = self.ischema_names[attype] - break - elif enum_or_domain_key in enums: - enum = enums[enum_or_domain_key] - coltype = ENUM - kwargs["name"] = enum["name"] - if not enum["visible"]: - kwargs["schema"] = enum["schema"] - args = tuple(enum["labels"]) - break - elif enum_or_domain_key in domains: - domain = domains[enum_or_domain_key] - attype = domain["type"] - attype, is_array = _handle_array_type(attype) - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple( - util.quoted_token_parser(attype) - ) - # A table can't override a not null on the domain, - # but can override nullable - nullable = nullable and domain["nullable"] - if domain["default"] and not default: - # It can, however, override the default - # value, but can't set it to null. - default = domain["default"] - continue - else: - coltype = None - break - - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = self.ischema_names["_array"](coltype) - elif no_format_type: - util.warn( - "PostgreSQL format_type() returned NULL for column '%s'" - % (name,) - ) - coltype = sqltypes.NULLTYPE - else: - util.warn( - "Did not recognize type '%s' of column '%s'" - % (attype, name) - ) - coltype = sqltypes.NULLTYPE + identity = row_dict["identity_options"] # If a zero byte or blank string depending on driver (is also # absent for older PG versions), then not a generated column. @@ -4877,7 +4910,6 @@ class PGDialect(default.DefaultDialect): @reflection.cache def _load_domains(self, connection, schema=None, **kw): - # Load data types for domains: result = connection.execute(self._domain_query(schema)) domains = [] diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index 56bec1dc73..16e5c867ef 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -416,10 +416,10 @@ class DOMAIN(NamedType, sqltypes.SchemaType): data_type: _TypeEngineArgument[Any], *, collation: Optional[str] = None, - default: Optional[Union[str, elements.TextClause]] = None, + default: Union[elements.TextClause, str, None] = None, constraint_name: Optional[str] = None, not_null: Optional[bool] = None, - check: Optional[str] = None, + check: Union[elements.TextClause, str, None] = None, create_type: bool = True, **kw: Any, ): @@ -463,7 +463,7 @@ class DOMAIN(NamedType, sqltypes.SchemaType): self.default = default self.collation = collation self.constraint_name = constraint_name - self.not_null = not_null + self.not_null = bool(not_null) if check is not None: check = coercions.expect(roles.DDLExpressionRole, check) self.check = check @@ -474,6 +474,20 @@ class DOMAIN(NamedType, sqltypes.SchemaType): def __test_init__(cls): return cls("name", sqltypes.Integer) + def adapt(self, impl, **kw): + if self.default: + kw["default"] = self.default + if self.constraint_name is not None: + kw["constraint_name"] = self.constraint_name + if self.not_null: + kw["not_null"] = self.not_null + if self.check is not None: + kw["check"] = str(self.check) + if self.create_type: + kw["create_type"] = self.create_type + + return super().adapt(impl, **kw) + class CreateEnumType(schema._CreateDropBase): __visit_name__ = "create_enum_type" diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index dd6c8aa88e..f5f1c07beb 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -23,6 +23,7 @@ from sqlalchemy import Text from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.dialects.postgresql import base as postgresql +from sqlalchemy.dialects.postgresql import DOMAIN from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.dialects.postgresql import INTERVAL @@ -408,12 +409,14 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): def setup_test_class(cls): with testing.db.begin() as con: for ddl in [ - 'CREATE SCHEMA "SomeSchema"', + 'CREATE SCHEMA IF NOT EXISTS "SomeSchema"', "CREATE DOMAIN testdomain INTEGER NOT NULL DEFAULT 42", "CREATE DOMAIN test_schema.testdomain INTEGER DEFAULT 0", "CREATE TYPE testtype AS ENUM ('test')", "CREATE DOMAIN enumdomain AS testtype", "CREATE DOMAIN arraydomain AS INTEGER[]", + "CREATE DOMAIN arraydomain_2d AS INTEGER[][]", + "CREATE DOMAIN arraydomain_3d AS INTEGER[][][]", 'CREATE DOMAIN "SomeSchema"."Quoted.Domain" INTEGER DEFAULT 0', "CREATE DOMAIN nullable_domain AS TEXT CHECK " "(VALUE IN('FOO', 'BAR'))", @@ -422,11 +425,8 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): "(VALUE > 1) CONSTRAINT a_my_int_two CHECK (VALUE < 42) " "CHECK(VALUE != 22)", ]: - try: - con.exec_driver_sql(ddl) - except exc.DBAPIError as e: - if "already exists" not in str(e): - raise e + con.exec_driver_sql(ddl) + con.exec_driver_sql( "CREATE TABLE testtable (question integer, answer " "testdomain)" @@ -446,7 +446,12 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): ) con.exec_driver_sql( - "CREATE TABLE array_test (id integer, data arraydomain)" + "CREATE TABLE array_test (" + "id integer, " + "datas arraydomain, " + "datass arraydomain_2d, " + "datasss arraydomain_3d" + ")" ) con.exec_driver_sql( @@ -473,6 +478,8 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): con.exec_driver_sql("DROP TYPE testtype") con.exec_driver_sql("DROP TABLE array_test") con.exec_driver_sql("DROP DOMAIN arraydomain") + con.exec_driver_sql("DROP DOMAIN arraydomain_2d"), + con.exec_driver_sql("DROP DOMAIN arraydomain_3d"), con.exec_driver_sql('DROP DOMAIN "SomeSchema"."Quoted.Domain"') con.exec_driver_sql('DROP SCHEMA "SomeSchema"') @@ -489,7 +496,9 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): {"question", "answer"}, "Columns of reflected table didn't equal expected columns", ) - assert isinstance(table.c.answer.type, Integer) + assert isinstance(table.c.answer.type, DOMAIN) + assert table.c.answer.type.name, "testdomain" + assert isinstance(table.c.answer.type.data_type, Integer) def test_nullable_from_domain(self, connection): metadata = MetaData() @@ -514,18 +523,36 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): def test_enum_domain_is_reflected(self, connection): metadata = MetaData() table = Table("enum_test", metadata, autoload_with=connection) - eq_(table.c.data.type.enums, ["test"]) + assert isinstance(table.c.data.type, DOMAIN) + eq_(table.c.data.type.data_type.enums, ["test"]) def test_array_domain_is_reflected(self, connection): metadata = MetaData() table = Table("array_test", metadata, autoload_with=connection) - eq_(table.c.data.type.__class__, ARRAY) - eq_(table.c.data.type.item_type.__class__, INTEGER) + + def assert_is_integer_array_domain(domain, name): + # Postgres does not persist the dimensionality of the array. + # It's always treated as integer[] + assert isinstance(domain, DOMAIN) + assert domain.name == name + assert isinstance(domain.data_type, ARRAY) + assert isinstance(domain.data_type.item_type, INTEGER) + + array_domain = table.c.datas.type + assert_is_integer_array_domain(array_domain, "arraydomain") + + array_domain_2d = table.c.datass.type + assert_is_integer_array_domain(array_domain_2d, "arraydomain_2d") + + array_domain_3d = table.c.datasss.type + assert_is_integer_array_domain(array_domain_3d, "arraydomain_3d") def test_quoted_remote_schema_domain_is_reflected(self, connection): metadata = MetaData() table = Table("quote_test", metadata, autoload_with=connection) - eq_(table.c.data.type.__class__, INTEGER) + assert isinstance(table.c.data.type, DOMAIN) + assert table.c.data.type.name, "Quoted.Domain" + assert isinstance(table.c.data.type.data_type, Integer) def test_table_is_reflected_test_schema(self, connection): metadata = MetaData() @@ -604,6 +631,24 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): "default": None, "constraints": [], }, + { + "visible": True, + "name": "arraydomain_2d", + "schema": "public", + "nullable": True, + "type": "integer[]", + "default": None, + "constraints": [], + }, + { + "visible": True, + "name": "arraydomain_3d", + "schema": "public", + "nullable": True, + "type": "integer[]", + "default": None, + "constraints": [], + }, { "visible": True, "name": "enumdomain", @@ -688,7 +733,13 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): def test_inspect_domains(self, connection): inspector = inspect(connection) - eq_(inspector.get_domains(), self.all_domains["public"]) + domains = inspector.get_domains() + + domain_names = {d["name"] for d in domains} + expect_domain_names = {d["name"] for d in self.all_domains["public"]} + eq_(domain_names, expect_domain_names) + + eq_(domains, self.all_domains["public"]) def test_inspect_domains_schema(self, connection): inspector = inspect(connection) @@ -705,7 +756,38 @@ class DomainReflectionTest(fixtures.TestBase, AssertsExecutionResults): all_ = [d for dl in self.all_domains.values() for d in dl] all_ += inspector.get_domains("information_schema") exp = sorted(all_, key=lambda d: (d["schema"], d["name"])) - eq_(inspector.get_domains("*"), exp) + domains = inspector.get_domains("*") + + eq_(domains, exp) + + +class ArrayReflectionTest(fixtures.TablesTest): + __only_on__ = "postgresql >= 10" + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "array_table", + metadata, + Column("id", INTEGER, primary_key=True), + Column("datas", ARRAY(INTEGER)), + Column("datass", ARRAY(INTEGER, dimensions=2)), + Column("datasss", ARRAY(INTEGER, dimensions=3)), + ) + + def test_array_table_is_reflected(self, connection): + metadata = MetaData() + table = Table("array_table", metadata, autoload_with=connection) + + def assert_is_integer_array(data_type): + assert isinstance(data_type, ARRAY) + # posgres treats all arrays as one-dimensional arrays + assert isinstance(data_type.item_type, INTEGER) + + assert_is_integer_array(table.c.datas.type) + assert_is_integer_array(table.c.datass.type) + assert_is_integer_array(table.c.datasss.type) class ReflectionTest( diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index a5093c0bc9..d14639be09 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -73,6 +73,7 @@ from sqlalchemy.dialects.postgresql import TSMULTIRANGE from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE +from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql.ranges import MultiRange from sqlalchemy.exc import CompileError from sqlalchemy.exc import DBAPIError @@ -1075,7 +1076,7 @@ class NamedTypeTest( connection, "fourfivesixtype" ) - def test_reflection(self, metadata, connection): + def test_enum_type_reflection(self, metadata, connection): etype = Enum( "four", "five", "six", name="fourfivesixtype", metadata=metadata ) @@ -1096,6 +1097,19 @@ class NamedTypeTest( eq_(t2.c.value2.type.enums, ["four", "five", "six"]) eq_(t2.c.value2.type.name, "fourfivesixtype") + def test_domain_type_reflection(self, metadata, connection): + positive_int = DOMAIN( + "positive_int", Integer(), check="value > 0", not_null=True + ) + Table("table", metadata, Column("value", positive_int)) + + metadata.create_all(connection) + m2 = MetaData() + t2 = Table("table", m2, autoload_with=connection) + + eq_(t2.c.value.type.name, "positive_int") + eq_(str(t2.c.value.type.check), "VALUE > 0") + def test_schema_reflection(self, metadata, connection): etype = Enum( "four", @@ -1229,6 +1243,174 @@ class NamedTypeTest( ] +class DomainTest( + AssertsCompiledSQL, fixtures.TestBase, AssertsExecutionResults +): + __backend__ = True + __only_on__ = "postgresql > 8.3" + + def test_domain_create_table(self, metadata, connection): + metadata = self.metadata + Email = DOMAIN( + name="email", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + ) + PosInt = DOMAIN( + name="pos_int", + data_type=Integer, + not_null=True, + check=r"VALUE > 0", + ) + t1 = Table( + "table", + metadata, + Column("id", Integer, primary_key=True), + Column("email", Email), + Column("number", PosInt), + ) + t1.create(connection) + t1.create(connection, checkfirst=True) # check the create + connection.execute( + t1.insert(), {"email": "test@example.com", "number": 42} + ) + connection.execute(t1.insert(), {"email": "a@b.c", "number": 1}) + connection.execute( + t1.insert(), {"email": "example@gmail.co.uk", "number": 99} + ) + eq_( + connection.execute(t1.select().order_by(t1.c.id)).fetchall(), + [ + (1, "test@example.com", 42), + (2, "a@b.c", 1), + (3, "example@gmail.co.uk", 99), + ], + ) + + @testing.combinations( + tuple( + [ + DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + create_type=True, + ), + ] + ), + tuple( + [ + DOMAIN( + name="mytype", + data_type=Text, + check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'", + create_type=False, + ), + ] + ), + argnames="domain", + ) + def test_create_drop_domain_with_table(self, connection, metadata, domain): + table = Table("e1", metadata, Column("e1", domain)) + + def _domain_names(): + return {d["name"] for d in inspect(connection).get_domains()} + + assert "mytype" not in _domain_names() + + if domain.create_type: + table.create(connection) + assert "mytype" in _domain_names() + else: + with expect_raises(exc.ProgrammingError): + table.create(connection) + connection.rollback() + + domain.create(connection) + assert "mytype" in _domain_names() + table.create(connection) + + table.drop(connection) + if domain.create_type: + assert "mytype" not in _domain_names() + + @testing.combinations( + (Integer, "value > 0", 4), + (String, "value != ''", "hello world"), + ( + UUID, + "value != '{00000000-0000-0000-0000-000000000000}'", + uuid.uuid4(), + ), + ( + DateTime, + "value >= '2020-01-01T00:00:00'", + datetime.datetime.fromisoformat("2021-01-01T00:00:00.000"), + ), + argnames="domain_datatype, domain_check, value", + ) + def test_domain_roundtrip( + self, metadata, connection, domain_datatype, domain_check, value + ): + table = Table( + "domain_roundtrip_test", + metadata, + Column("id", Integer, primary_key=True), + Column( + "value", + DOMAIN("valuedomain", domain_datatype, check=domain_check), + ), + ) + table.create(connection) + + connection.execute(table.insert(), {"value": value}) + + results = connection.execute( + table.select().order_by(table.c.id) + ).fetchall() + eq_(results, [(1, value)]) + + @testing.combinations( + (DOMAIN("pos_int", Integer, check="VALUE > 0", not_null=True), 4, -4), + ( + DOMAIN("email", String, check=r"VALUE ~ '[^@]+@[^@]+\.[^@]+'"), + "e@xample.com", + "fred", + ), + argnames="domain,pass_value,fail_value", + ) + def test_check_constraint( + self, metadata, connection, domain, pass_value, fail_value + ): + table = Table("table", metadata, Column("value", domain)) + table.create(connection) + + connection.execute(table.insert(), {"value": pass_value}) + + # psycopg/psycopg2 raise IntegrityError, while pg8000 raises + # ProgrammingError + with expect_raises(exc.DatabaseError): + connection.execute(table.insert(), {"value": fail_value}) + + @testing.combinations( + (DOMAIN("nullable_domain", Integer, not_null=True), 1), + (DOMAIN("non_nullable_domain", Integer, not_null=False), 1), + argnames="domain,pass_value", + ) + def test_domain_nullable(self, metadata, connection, domain, pass_value): + table = Table("table", metadata, Column("value", domain)) + table.create(connection) + connection.execute(table.insert(), {"value": pass_value}) + + if domain.not_null: + # psycopg/psycopg2 raise IntegrityError, while pg8000 raises + # ProgrammingError + with expect_raises(exc.DatabaseError): + connection.execute(table.insert(), {"value": None}) + else: + connection.execute(table.insert(), {"value": None}) + + class DomainDDLEventTest(DDLEventWCreateHarness, fixtures.TestBase): __backend__ = True @@ -1557,6 +1739,10 @@ class TimePrecisionTest(fixtures.TestBase): t1.create(connection) m2 = MetaData() t2 = Table("t1", m2, autoload_with=connection) + + eq_(t1.c.c1.type.__class__, postgresql.TIME) + eq_(t1.c.c4.type.__class__, postgresql.TIMESTAMP) + eq_(t2.c.c1.type.precision, None) eq_(t2.c.c2.type.precision, 5) eq_(t2.c.c3.type.precision, 5) -- 2.47.2