From: pavelserchenia Date: Fri, 26 May 2023 11:16:54 +0000 (-0400) Subject: PG nulls not distinct support X-Git-Tag: rel_2_0_16~7^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f6dc8b872d6f8b1eaa114c7128125bc4edf646d3;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git PG nulls not distinct support Added support for PostgreSQL 15 ``NULLS NOT DISTINCT`` feature of unique indexes and unique constraints using the dialect option ``postgresql_nulls_not_distinct``. Updated the reflection logic to also correctly take this option into account. Fixes: #8240 Closes: #9834 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9834 Pull-request-sha: 825a4ff13a1f428470e184944a167c9d4c57e604 Change-Id: I6585fbb05ad32a131cf9ba25a59f7b229bce5b52 --- diff --git a/doc/build/changelog/unreleased_20/8240.rst b/doc/build/changelog/unreleased_20/8240.rst new file mode 100644 index 0000000000..15e119135e --- /dev/null +++ b/doc/build/changelog/unreleased_20/8240.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 8240 + + Added support for PostgreSQL 15 ``NULLS NOT DISTINCT`` feature of + unique indexes and unique constraints using the dialect option + ``postgresql_nulls_not_distinct``. + Updated the reflection logic to also correctly take this option + into account. + Pull request courtesy of Pavel Siarchenia. 
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index ebf2ce2d3e..99c3abe7a4 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -2688,8 +2688,9 @@ class MSDDLCompiler(compiler.DDLCompiler): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "UNIQUE " - + text += "UNIQUE %s" % self.define_unique_constraint_distinct( + constraint, **kw + ) clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 8ac04d6548..00443d79ba 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -2226,6 +2226,7 @@ class PGDDLCompiler(compiler.DDLCompiler): text = "CREATE " if index.unique: text += "UNIQUE " + text += "INDEX " if self.dialect._supports_create_index_concurrently: @@ -2279,6 +2280,14 @@ class PGDDLCompiler(compiler.DDLCompiler): [preparer.quote(c.name) for c in inclusions] ) + nulls_not_distinct = index.dialect_options["postgresql"][ + "nulls_not_distinct" + ] + if nulls_not_distinct is True: + text += " NULLS NOT DISTINCT" + elif nulls_not_distinct is False: + text += " NULLS DISTINCT" + withclause = index.dialect_options["postgresql"]["with"] if withclause: text += " WITH (%s)" % ( @@ -2307,6 +2316,18 @@ class PGDDLCompiler(compiler.DDLCompiler): return text + def define_unique_constraint_distinct(self, constraint, **kw): + nulls_not_distinct = constraint.dialect_options["postgresql"][ + "nulls_not_distinct" + ] + if nulls_not_distinct is True: + nulls_not_distinct_param = "NULLS NOT DISTINCT " + elif nulls_not_distinct is False: + nulls_not_distinct_param = "NULLS DISTINCT " + else: + nulls_not_distinct_param = "" + return nulls_not_distinct_param + def visit_drop_index(self, drop, **kw): index 
= drop.element @@ -2969,6 +2990,7 @@ class PGDialect(default.DefaultDialect): "concurrently": False, "with": {}, "tablespace": None, + "nulls_not_distinct": None, }, ), ( @@ -2994,6 +3016,10 @@ class PGDialect(default.DefaultDialect): "not_valid": False, }, ), + ( + schema.UniqueConstraint, + {"nulls_not_distinct": None}, + ), ] reflection_options = ("postgresql_ignore_search_path",) @@ -3747,12 +3773,13 @@ class PGDialect(default.DefaultDialect): result = connection.execute(oid_q, params) return result.all() - @util.memoized_property - def _constraint_query(self): + @lru_cache() + def _constraint_query(self, is_unique): con_sq = ( select( pg_catalog.pg_constraint.c.conrelid, pg_catalog.pg_constraint.c.conname, + pg_catalog.pg_constraint.c.conindid, sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( "attnum" ), @@ -3777,6 +3804,7 @@ class PGDialect(default.DefaultDialect): select( con_sq.c.conrelid, con_sq.c.conname, + con_sq.c.conindid, con_sq.c.description, con_sq.c.ord, pg_catalog.pg_attribute.c.attname, @@ -3789,10 +3817,19 @@ class PGDialect(default.DefaultDialect): pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid, ), ) + .where( + # NOTE: restate the condition here, since pg15 otherwise + # seems to get confused on psycopg2 sometimes, doing + # a sequential scan of pg_attribute. + # The condition in the con_sq subquery is not actually needed + # in pg15, but it may be needed in older versions. Keeping it + # does not seem to have any impact in any case. 
+ con_sq.c.conrelid.in_(bindparam("oids")) + ) .subquery("attr") ) - return ( + constraint_query = ( select( attr_sq.c.conrelid, sql.func.array_agg( @@ -3809,34 +3846,63 @@ class PGDialect(default.DefaultDialect): .order_by(attr_sq.c.conrelid, attr_sq.c.conname) ) + if is_unique: + if self.server_version_info >= (15,): + constraint_query = constraint_query.join( + pg_catalog.pg_index, + attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid, + ).add_columns( + sql.func.bool_and( + pg_catalog.pg_index.c.indnullsnotdistinct + ).label("indnullsnotdistinct") + ) + else: + constraint_query = constraint_query.add_columns( + sql.false().label("indnullsnotdistinct") + ) + else: + constraint_query = constraint_query.add_columns( + sql.null().label("extra") + ) + return constraint_query + def _reflect_constraint( self, connection, contype, schema, filter_names, scope, kind, **kw ): + # used to reflect primary and unique constraint table_oids = self._get_table_oids( connection, schema, filter_names, scope, kind, **kw ) batches = list(table_oids) + is_unique = contype == "u" while batches: batch = batches[0:3000] batches[0:3000] = [] result = connection.execute( - self._constraint_query, + self._constraint_query(is_unique), {"oids": [r[0] for r in batch], "contype": contype}, ) result_by_oid = defaultdict(list) - for oid, cols, constraint_name, comment in result: - result_by_oid[oid].append((cols, constraint_name, comment)) + for oid, cols, constraint_name, comment, extra in result: + result_by_oid[oid].append( + (cols, constraint_name, comment, extra) + ) for oid, tablename in batch: for_oid = result_by_oid.get(oid, ()) if for_oid: - for cols, constraint, comment in for_oid: - yield tablename, cols, constraint, comment + for cols, constraint, comment, extra in for_oid: + if is_unique: + yield tablename, cols, constraint, comment, { + "nullsnotdistinct": extra + } + else: + yield tablename, cols, constraint, comment, None else: - yield tablename, None, None, None + yield 
tablename, None, None, None, None @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): @@ -3871,7 +3937,7 @@ class PGDialect(default.DefaultDialect): if pk_name is not None else default(), ) - for table_name, cols, pk_name, comment in result + for table_name, cols, pk_name, comment, _ in result ) @reflection.cache @@ -4151,6 +4217,11 @@ class PGDialect(default.DefaultDialect): else: indnkeyatts = sql.null().label("indnkeyatts") + if self.server_version_info >= (15,): + nulls_not_distinct = pg_catalog.pg_index.c.indnullsnotdistinct + else: + nulls_not_distinct = sql.false().label("indnullsnotdistinct") + return ( select( pg_catalog.pg_index.c.indrelid, @@ -4175,6 +4246,7 @@ class PGDialect(default.DefaultDialect): else_=None, ).label("filter_definition"), indnkeyatts, + nulls_not_distinct, cols_sq.c.elements, cols_sq.c.elements_is_expr, ) @@ -4318,11 +4390,17 @@ class PGDialect(default.DefaultDialect): dialect_options["postgresql_where"] = row[ "filter_definition" ] - if self.server_version_info >= (11, 0): + if self.server_version_info >= (11,): # NOTE: this is legacy, this is part of # dialect_options now as of #7382 index["include_columns"] = inc_cols dialect_options["postgresql_include"] = inc_cols + if row["indnullsnotdistinct"]: + # the default is False, so ignore it. + dialect_options["postgresql_nulls_not_distinct"] = row[ + "indnullsnotdistinct" + ] + if dialect_options: index["dialect_options"] = dialect_options @@ -4359,20 +4437,27 @@ class PGDialect(default.DefaultDialect): # each table can have multiple unique constraints uniques = defaultdict(list) default = ReflectionDefaults.unique_constraints - for table_name, cols, con_name, comment in result: + for table_name, cols, con_name, comment, options in result: # ensure a list is created for each table. 
leave it empty if # the table has no unique cosntraint if con_name is None: uniques[(schema, table_name)] = default() continue - uniques[(schema, table_name)].append( - { - "column_names": cols, - "name": con_name, - "comment": comment, - } - ) + uc_dict = { + "column_names": cols, + "name": con_name, + "comment": comment, + } + if options: + if options["nullsnotdistinct"]: + uc_dict["dialect_options"] = { + "postgresql_nulls_not_distinct": options[ + "nullsnotdistinct" + ] + } + + uniques[(schema, table_name)].append(uc_dict) return uniques.items() @reflection.cache diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py index ed8926a26e..fa4b30f03f 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py +++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -166,6 +166,7 @@ pg_index = Table( Column("indnatts", SmallInteger), Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}), Column("indisunique", Boolean), + Column("indnullsnotdistinct", Boolean, info={"server_version": (15,)}), Column("indisprimary", Boolean), Column("indisexclusion", Boolean, info={"server_version": (9, 1)}), Column("indimmediate", Boolean), diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 775abc4332..4035901aed 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -1893,6 +1893,7 @@ class Inspector(inspection.Inspectable["Inspector"]): columns = const_d["column_names"] comment = const_d.get("comment") duplicates = const_d.get("duplicates_index") + dialect_options = const_d.get("dialect_options", {}) if include_columns and not set(columns).issubset(include_columns): continue if duplicates: @@ -1916,7 +1917,10 @@ class Inspector(inspection.Inspectable["Inspector"]): constrained_cols.append(constrained_col) table.append_constraint( sa_schema.UniqueConstraint( - *constrained_cols, name=conname, comment=comment + *constrained_cols, + 
name=conname, + comment=comment, + **dialect_options, ) ) diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index f12de97632..198bba3584 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -6858,12 +6858,16 @@ class DDLCompiler(Compiled): formatted_name = self.preparer.format_constraint(constraint) if formatted_name is not None: text += "CONSTRAINT %s " % formatted_name - text += "UNIQUE (%s)" % ( - ", ".join(self.preparer.quote(c.name) for c in constraint) + text += "UNIQUE %s(%s)" % ( + self.define_unique_constraint_distinct(constraint, **kw), + ", ".join(self.preparer.quote(c.name) for c in constraint), ) text += self.define_constraint_deferrability(constraint) return text + def define_unique_constraint_distinct(self, constraint, **kw): + return "" + def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py index 2f57c69315..bd944849de 100644 --- a/test/dialect/postgresql/test_compiler.py +++ b/test/dialect/postgresql/test_compiler.py @@ -676,6 +676,102 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): dialect=postgresql.dialect(), ) + @testing.combinations( + ( + lambda tbl: schema.CreateIndex( + Index( + "test_idx1", + tbl.c.data, + unique=True, + postgresql_nulls_not_distinct=True, + ) + ), + "CREATE UNIQUE INDEX test_idx1 ON test_tbl " + "(data) NULLS NOT DISTINCT", + ), + ( + lambda tbl: schema.CreateIndex( + Index( + "test_idx2", + tbl.c.data2, + unique=True, + postgresql_nulls_not_distinct=False, + ) + ), + "CREATE UNIQUE INDEX test_idx2 ON test_tbl " + "(data2) NULLS DISTINCT", + ), + ( + lambda tbl: schema.CreateIndex( + Index( + "test_idx3", + tbl.c.data3, + unique=True, + ) + ), + "CREATE UNIQUE INDEX test_idx3 ON test_tbl " "(data3)", + ), + ( + lambda tbl: schema.CreateIndex( + Index( + "test_idx3_complex", + tbl.c.data3, + 
postgresql_nulls_not_distinct=True, + postgresql_include=["data2"], + postgresql_where=and_(tbl.c.data3 > 5), + postgresql_with={"fillfactor": 50}, + ) + ), + "CREATE INDEX test_idx3_complex ON test_tbl " + "(data3) INCLUDE (data2) NULLS NOT DISTINCT WITH " + "(fillfactor = 50) WHERE data3 > 5", + ), + ( + lambda tbl: schema.AddConstraint( + schema.UniqueConstraint( + tbl.c.data, + name="uq_data1", + postgresql_nulls_not_distinct=True, + ) + ), + "ALTER TABLE test_tbl ADD CONSTRAINT uq_data1 UNIQUE " + "NULLS NOT DISTINCT (data)", + ), + ( + lambda tbl: schema.AddConstraint( + schema.UniqueConstraint( + tbl.c.data2, + name="uq_data2", + postgresql_nulls_not_distinct=False, + ) + ), + "ALTER TABLE test_tbl ADD CONSTRAINT uq_data2 UNIQUE " + "NULLS DISTINCT (data2)", + ), + ( + lambda tbl: schema.AddConstraint( + schema.UniqueConstraint( + tbl.c.data3, + name="uq_data3", + ) + ), + "ALTER TABLE test_tbl ADD CONSTRAINT uq_data3 UNIQUE (data3)", + ), + ) + def test_nulls_not_distinct(self, expr_fn, expected): + dd = PGDialect() + m = MetaData() + tbl = Table( + "test_tbl", + m, + Column("data", String), + Column("data2", Integer), + Column("data3", Integer), + ) + + expr = testing.resolve_lambda(expr_fn, tbl=tbl) + self.assert_compile(expr, expected, dialect=dd) + def test_create_index_with_labeled_ops(self): m = MetaData() tbl = Table( diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index 49838ec6ab..f7f86a79c3 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -1179,6 +1179,15 @@ class ReflectionTest( where name != 'foo' """ ) + version = connection.dialect.server_version_info + if version >= (15,): + connection.exec_driver_sql( + """ + create unique index zz_idx5 on party + (name desc, upper(other)) + nulls not distinct + """ + ) expected = [ { @@ -1238,7 +1247,23 @@ class ReflectionTest( "dialect_options": {"postgresql_include": []}, }, ] - if 
connection.dialect.server_version_info < (11,): + if version > (15,): + expected.append( + { + "name": "zz_idx5", + "column_names": ["name", None], + "expressions": ["name", "upper(other::text)"], + "unique": True, + "include_columns": [], + "dialect_options": { + "postgresql_include": [], + "postgresql_nulls_not_distinct": True, + }, + "column_sorting": {"name": ("desc",)}, + }, + ) + + if version < (11,): for index in expected: index.pop("include_columns") index["dialect_options"].pop("postgresql_include") @@ -1462,6 +1487,72 @@ class ReflectionTest( "gin", ) + @testing.skip_if("postgresql < 15.0", "nullsnotdistinct not supported") + def test_nullsnotdistinct(self, metadata, connection): + Table( + "t", + metadata, + Column("id", Integer, primary_key=True), + Column("x", ARRAY(Integer)), + Column("y", ARRAY(Integer)), + Index( + "idx1", "x", unique=True, postgresql_nulls_not_distinct=True + ), + UniqueConstraint( + "y", name="unq1", postgresql_nulls_not_distinct=True + ), + ) + metadata.create_all(connection) + + ind = inspect(connection).get_indexes("t", None) + expected_ind = [ + { + "unique": True, + "column_names": ["x"], + "name": "idx1", + "dialect_options": { + "postgresql_nulls_not_distinct": True, + "postgresql_include": [], + }, + "include_columns": [], + }, + { + "unique": True, + "column_names": ["y"], + "name": "unq1", + "dialect_options": { + "postgresql_nulls_not_distinct": True, + "postgresql_include": [], + }, + "include_columns": [], + "duplicates_constraint": "unq1", + }, + ] + eq_(ind, expected_ind) + + unq = inspect(connection).get_unique_constraints("t", None) + expected_unq = [ + { + "column_names": ["y"], + "name": "unq1", + "dialect_options": { + "postgresql_nulls_not_distinct": True, + }, + "comment": None, + } + ] + eq_(unq, expected_unq) + + m = MetaData() + t1 = Table("t", m, autoload_with=connection) + eq_(len(t1.indexes), 1) + idx_options = list(t1.indexes)[0].dialect_options["postgresql"] + eq_(idx_options["nulls_not_distinct"], 
True) + + cst = {c.name: c for c in t1.constraints} + cst_options = cst["unq1"].dialect_options["postgresql"] + eq_(cst_options["nulls_not_distinct"], True) + @testing.skip_if("postgresql < 11.0", "indnkeyatts not supported") def test_index_reflection_with_include(self, metadata, connection): """reflect indexes with include set"""