text = "CREATE "
if index.unique:
text += "UNIQUE "
+
text += "INDEX "
if self.dialect._supports_create_index_concurrently:
[preparer.quote(c.name) for c in inclusions]
)
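+ # the option is tri-state: True renders NULLS NOT DISTINCT,
+ # False renders NULLS DISTINCT explicitly, and None (the
+ # default) renders nothing, deferring to the server default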
+ nulls_not_distinct = index.dialect_options["postgresql"][
+ "nulls_not_distinct"
+ ]
+ if nulls_not_distinct is True:
+ text += " NULLS NOT DISTINCT"
+ elif nulls_not_distinct is False:
+ text += " NULLS DISTINCT"
+
withclause = index.dialect_options["postgresql"]["with"]
if withclause:
text += " WITH (%s)" % (
return text
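+ # hook consumed by the base DDLCompiler when rendering a UNIQUE
+ # constraint; the returned string is placed between "UNIQUE" and
+ # the column list, e.g. "UNIQUE NULLS NOT DISTINCT (col)"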
+ def define_unique_constraint_distinct(self, constraint, **kw):
+ nulls_not_distinct = constraint.dialect_options["postgresql"][
+ "nulls_not_distinct"
+ ]
+ if nulls_not_distinct is True:
+ nulls_not_distinct_param = "NULLS NOT DISTINCT "
+ elif nulls_not_distinct is False:
+ nulls_not_distinct_param = "NULLS DISTINCT "
+ else:
+ nulls_not_distinct_param = ""
+ return nulls_not_distinct_param
+
def visit_drop_index(self, drop, **kw):
index = drop.element
"concurrently": False,
"with": {},
"tablespace": None,
+ "nulls_not_distinct": None,
},
),
(
"not_valid": False,
},
),
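+ # register the option on UniqueConstraint as well, so that
+ # postgresql_nulls_not_distinct= is accepted as a dialect keyword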
+ (
+ schema.UniqueConstraint,
+ {"nulls_not_distinct": None},
+ ),
]
reflection_options = ("postgresql_ignore_search_path",)
result = connection.execute(oid_q, params)
return result.all()
- @util.memoized_property
- def _constraint_query(self):
+ @lru_cache()
+ def _constraint_query(self, is_unique):
con_sq = (
select(
pg_catalog.pg_constraint.c.conrelid,
pg_catalog.pg_constraint.c.conname,
+ pg_catalog.pg_constraint.c.conindid,
sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
"attnum"
),
select(
con_sq.c.conrelid,
con_sq.c.conname,
+ con_sq.c.conindid,
con_sq.c.description,
con_sq.c.ord,
pg_catalog.pg_attribute.c.attname,
pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid,
),
)
+ .where(
+ # NOTE: restate the condition here, since pg15 otherwise
+ # seems to get confused on psycopg2 sometimes, doing
+ # a sequential scan of pg_attribute.
+ # The condition in the con_sq subquery is not actually needed
+ # in pg15, but it may be needed in older versions. Keeping it
+ # does not seem to have any impact in any case.
+ con_sq.c.conrelid.in_(bindparam("oids"))
+ )
.subquery("attr")
)
- return (
+ constraint_query = (
select(
attr_sq.c.conrelid,
sql.func.array_agg(
.order_by(attr_sq.c.conrelid, attr_sq.c.conname)
)
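+ # a fourth column is always selected so result rows unpack
+ # uniformly: the pg15 indnullsnotdistinct flag for unique
+ # constraints, a NULL placeholder otherwise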
+ if is_unique:
+ if self.server_version_info >= (15,):
+ constraint_query = constraint_query.join(
+ pg_catalog.pg_index,
+ attr_sq.c.conindid == pg_catalog.pg_index.c.indexrelid,
+ ).add_columns(
+ sql.func.bool_and(
+ pg_catalog.pg_index.c.indnullsnotdistinct
+ ).label("indnullsnotdistinct")
+ )
+ else:
+ constraint_query = constraint_query.add_columns(
+ sql.false().label("indnullsnotdistinct")
+ )
+ else:
+ constraint_query = constraint_query.add_columns(
+ sql.null().label("extra")
+ )
+ return constraint_query
+
def _reflect_constraint(
self, connection, contype, schema, filter_names, scope, kind, **kw
):
+ # used to reflect primary and unique constraints
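+ # yields 5-tuples (tablename, cols, constraint_name, comment, extra);
+ # extra is an options dict for unique constraints, otherwise None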
table_oids = self._get_table_oids(
connection, schema, filter_names, scope, kind, **kw
)
batches = list(table_oids)
+ is_unique = contype == "u"
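+ # pg_constraint.contype: "u" = unique constraint, "p" = primary key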
while batches:
batch = batches[0:3000]
batches[0:3000] = []
result = connection.execute(
- self._constraint_query,
+ self._constraint_query(is_unique),
{"oids": [r[0] for r in batch], "contype": contype},
)
result_by_oid = defaultdict(list)
- for oid, cols, constraint_name, comment in result:
- result_by_oid[oid].append((cols, constraint_name, comment))
+ for oid, cols, constraint_name, comment, extra in result:
+ result_by_oid[oid].append(
+ (cols, constraint_name, comment, extra)
+ )
for oid, tablename in batch:
for_oid = result_by_oid.get(oid, ())
if for_oid:
- for cols, constraint, comment in for_oid:
- yield tablename, cols, constraint, comment
+ for cols, constraint, comment, extra in for_oid:
+ if is_unique:
+ yield tablename, cols, constraint, comment, {
+ "nullsnotdistinct": extra
+ }
+ else:
+ yield tablename, cols, constraint, comment, None
else:
- yield tablename, None, None, None
+ yield tablename, None, None, None, None
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
if pk_name is not None
else default(),
)
- for table_name, cols, pk_name, comment in result
+ for table_name, cols, pk_name, comment, _ in result
)
@reflection.cache
else:
indnkeyatts = sql.null().label("indnkeyatts")
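+ # pg_index.indnullsnotdistinct exists only on PostgreSQL 15+; older
+ # servers select a constant False so the column list stays uniform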
+ if self.server_version_info >= (15,):
+ nulls_not_distinct = pg_catalog.pg_index.c.indnullsnotdistinct
+ else:
+ nulls_not_distinct = sql.false().label("indnullsnotdistinct")
+
return (
select(
pg_catalog.pg_index.c.indrelid,
else_=None,
).label("filter_definition"),
indnkeyatts,
+ nulls_not_distinct,
cols_sq.c.elements,
cols_sq.c.elements_is_expr,
)
dialect_options["postgresql_where"] = row[
"filter_definition"
]
- if self.server_version_info >= (11, 0):
+ if self.server_version_info >= (11,):
# NOTE: this is legacy, this is part of
# dialect_options now as of #7382
index["include_columns"] = inc_cols
dialect_options["postgresql_include"] = inc_cols
+ if row["indnullsnotdistinct"]:
+ # the option defaults to False, so only reflect it when True
+ dialect_options["postgresql_nulls_not_distinct"] = row[
+ "indnullsnotdistinct"
+ ]
+
if dialect_options:
index["dialect_options"] = dialect_options
# each table can have multiple unique constraints
uniques = defaultdict(list)
default = ReflectionDefaults.unique_constraints
- for table_name, cols, con_name, comment in result:
+ for table_name, cols, con_name, comment, options in result:
# ensure a list is created for each table. leave it empty if
# the table has no unique constraint
if con_name is None:
uniques[(schema, table_name)] = default()
continue
- uniques[(schema, table_name)].append(
- {
- "column_names": cols,
- "name": con_name,
- "comment": comment,
- }
- )
+ uc_dict = {
+ "column_names": cols,
+ "name": con_name,
+ "comment": comment,
+ }
+ if options and options["nullsnotdistinct"]:
+ uc_dict["dialect_options"] = {
+ "postgresql_nulls_not_distinct": options[
+ "nullsnotdistinct"
+ ]
+ }
+
+ uniques[(schema, table_name)].append(uc_dict)
return uniques.items()
@reflection.cache
dialect=postgresql.dialect(),
)
+ @testing.combinations(
+ (
+ lambda tbl: schema.CreateIndex(
+ Index(
+ "test_idx1",
+ tbl.c.data,
+ unique=True,
+ postgresql_nulls_not_distinct=True,
+ )
+ ),
+ "CREATE UNIQUE INDEX test_idx1 ON test_tbl "
+ "(data) NULLS NOT DISTINCT",
+ ),
+ (
+ lambda tbl: schema.CreateIndex(
+ Index(
+ "test_idx2",
+ tbl.c.data2,
+ unique=True,
+ postgresql_nulls_not_distinct=False,
+ )
+ ),
+ "CREATE UNIQUE INDEX test_idx2 ON test_tbl "
+ "(data2) NULLS DISTINCT",
+ ),
+ (
+ lambda tbl: schema.CreateIndex(
+ Index(
+ "test_idx3",
+ tbl.c.data3,
+ unique=True,
+ )
+ ),
+ "CREATE UNIQUE INDEX test_idx3 ON test_tbl " "(data3)",
+ ),
+ (
+ lambda tbl: schema.CreateIndex(
+ Index(
+ "test_idx3_complex",
+ tbl.c.data3,
+ postgresql_nulls_not_distinct=True,
+ postgresql_include=["data2"],
+ postgresql_where=and_(tbl.c.data3 > 5),
+ postgresql_with={"fillfactor": 50},
+ )
+ ),
+ "CREATE INDEX test_idx3_complex ON test_tbl "
+ "(data3) INCLUDE (data2) NULLS NOT DISTINCT WITH "
+ "(fillfactor = 50) WHERE data3 > 5",
+ ),
+ (
+ lambda tbl: schema.AddConstraint(
+ schema.UniqueConstraint(
+ tbl.c.data,
+ name="uq_data1",
+ postgresql_nulls_not_distinct=True,
+ )
+ ),
+ "ALTER TABLE test_tbl ADD CONSTRAINT uq_data1 UNIQUE "
+ "NULLS NOT DISTINCT (data)",
+ ),
+ (
+ lambda tbl: schema.AddConstraint(
+ schema.UniqueConstraint(
+ tbl.c.data2,
+ name="uq_data2",
+ postgresql_nulls_not_distinct=False,
+ )
+ ),
+ "ALTER TABLE test_tbl ADD CONSTRAINT uq_data2 UNIQUE "
+ "NULLS DISTINCT (data2)",
+ ),
+ (
+ lambda tbl: schema.AddConstraint(
+ schema.UniqueConstraint(
+ tbl.c.data3,
+ name="uq_data3",
+ )
+ ),
+ "ALTER TABLE test_tbl ADD CONSTRAINT uq_data3 UNIQUE (data3)",
+ ),
+ )
+ def test_nulls_not_distinct(self, expr_fn, expected):
+ dd = PGDialect()
+ m = MetaData()
+ tbl = Table(
+ "test_tbl",
+ m,
+ Column("data", String),
+ Column("data2", Integer),
+ Column("data3", Integer),
+ )
+
+ expr = testing.resolve_lambda(expr_fn, tbl=tbl)
+ self.assert_compile(expr, expected, dialect=dd)
+
def test_create_index_with_labeled_ops(self):
m = MetaData()
tbl = Table(
where name != 'foo'
"""
)
+ version = connection.dialect.server_version_info
+ if version >= (15,):
+ connection.exec_driver_sql(
+ """
+ create unique index zz_idx5 on party
+ (name desc, upper(other))
+ nulls not distinct
+ """
+ )
expected = [
{
"dialect_options": {"postgresql_include": []},
},
]
- if connection.dialect.server_version_info < (11,):
+ if version >= (15,):
+ expected.append(
+ {
+ "name": "zz_idx5",
+ "column_names": ["name", None],
+ "expressions": ["name", "upper(other::text)"],
+ "unique": True,
+ "include_columns": [],
+ "dialect_options": {
+ "postgresql_include": [],
+ "postgresql_nulls_not_distinct": True,
+ },
+ "column_sorting": {"name": ("desc",)},
+ },
+ )
+
+ if version < (11,):
for index in expected:
index.pop("include_columns")
index["dialect_options"].pop("postgresql_include")
"gin",
)
+ @testing.skip_if("postgresql < 15.0", "nullsnotdistinct not supported")
+ def test_nullsnotdistinct(self, metadata, connection):
+ Table(
+ "t",
+ metadata,
+ Column("id", Integer, primary_key=True),
+ Column("x", ARRAY(Integer)),
+ Column("y", ARRAY(Integer)),
+ Index(
+ "idx1", "x", unique=True, postgresql_nulls_not_distinct=True
+ ),
+ UniqueConstraint(
+ "y", name="unq1", postgresql_nulls_not_distinct=True
+ ),
+ )
+ metadata.create_all(connection)
+
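+ # PostgreSQL implements a UNIQUE constraint with a unique index,
+ # so unq1 is also expected from get_indexes(), flagged with
+ # "duplicates_constraint"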
+ ind = inspect(connection).get_indexes("t", None)
+ expected_ind = [
+ {
+ "unique": True,
+ "column_names": ["x"],
+ "name": "idx1",
+ "dialect_options": {
+ "postgresql_nulls_not_distinct": True,
+ "postgresql_include": [],
+ },
+ "include_columns": [],
+ },
+ {
+ "unique": True,
+ "column_names": ["y"],
+ "name": "unq1",
+ "dialect_options": {
+ "postgresql_nulls_not_distinct": True,
+ "postgresql_include": [],
+ },
+ "include_columns": [],
+ "duplicates_constraint": "unq1",
+ },
+ ]
+ eq_(ind, expected_ind)
+
+ unq = inspect(connection).get_unique_constraints("t", None)
+ expected_unq = [
+ {
+ "column_names": ["y"],
+ "name": "unq1",
+ "dialect_options": {
+ "postgresql_nulls_not_distinct": True,
+ },
+ "comment": None,
+ }
+ ]
+ eq_(unq, expected_unq)
+
+ m = MetaData()
+ t1 = Table("t", m, autoload_with=connection)
+ eq_(len(t1.indexes), 1)
+ idx_options = list(t1.indexes)[0].dialect_options["postgresql"]
+ eq_(idx_options["nulls_not_distinct"], True)
+
+ cst = {c.name: c for c in t1.constraints}
+ cst_options = cst["unq1"].dialect_options["postgresql"]
+ eq_(cst_options["nulls_not_distinct"], True)
+
@testing.skip_if("postgresql < 11.0", "indnkeyatts not supported")
def test_index_reflection_with_include(self, metadata, connection):
"""reflect indexes with include set"""