# that of the inheriting (unless concrete or explicit)
self.primary_key = self.inherits.primary_key
else:
- # determine primary key from argument or persist_selectable pks -
- # reduce to the minimal set of columns
+ # determine primary key from argument or persist_selectable pks
if self._primary_key_argument:
- primary_key = sql_util.reduce_columns(
- [
- self.persist_selectable.corresponding_column(c)
- for c in self._primary_key_argument
- ],
- ignore_nonexistent_tables=True,
- )
+ primary_key = [
+ self.persist_selectable.corresponding_column(c)
+ for c in self._primary_key_argument
+ ]
else:
+ # if heuristically determined PKs, reduce to the minimal set
+ # of columns by eliminating FK->PK pairs for a multi-table
+ # expression. May over-reduce for some kinds of UNIONs
+ # / CTEs; use explicit PK argument for these special cases
primary_key = sql_util.reduce_columns(
self._pks_by_table[self.persist_selectable],
ignore_nonexistent_tables=True,
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
+from sqlalchemy import literal
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy.testing import is_false
from sqlalchemy.testing import is_true
from sqlalchemy.testing import ne_
+from sqlalchemy.testing.fixtures import ComparableEntity
from sqlalchemy.testing.fixtures import ComparableMixin
from sqlalchemy.testing.fixtures import fixture_session
from sqlalchemy.testing.schema import Column
],
)
+ @testing.requires.ctes
+ def test_mapping_to_union_dont_overlimit_pk(self, registry, connection):
+ """test #7842"""
+ Base = registry.generate_base()
+
+ class Node(Base):
+ __tablename__ = "cte_nodes"
+
+ id = Column(Integer, primary_key=True)
+ parent = Column(Integer, ForeignKey("cte_nodes.id"))
+
+ # so we dont have to deal with NULLS FIRST
+ sort_key = Column(Integer)
+
+ class NodeRel(ComparableEntity, Base):
+ table = select(
+ Node.id, Node.parent, Node.sort_key, literal(0).label("depth")
+ ).cte(recursive=True)
+ __table__ = table.union_all(
+ select(
+ Node.id,
+ table.c.parent,
+ table.c.sort_key,
+ table.c.depth + literal(1),
+ )
+ .select_from(Node)
+ .join(table, Node.parent == table.c.id)
+ )
+
+ __mapper_args__ = {
+ "primary_key": (__table__.c.id, __table__.c.parent)
+ }
+
+ nt = NodeRel.__table__
+
+ eq_(NodeRel.__mapper__.primary_key, (nt.c.id, nt.c.parent))
+
+ registry.metadata.create_all(connection)
+ with Session(connection) as session:
+ n1, n2, n3, n4 = (
+ Node(id=1, sort_key=1),
+ Node(id=2, parent=1, sort_key=2),
+ Node(id=3, parent=2, sort_key=3),
+ Node(id=4, parent=3, sort_key=4),
+ )
+ session.add_all([n1, n2, n3, n4])
+ session.commit()
+
+ q_rel = select(NodeRel).filter_by(id=4).order_by(NodeRel.sort_key)
+ eq_(
+ session.scalars(q_rel).all(),
+ [
+ NodeRel(id=4, parent=None),
+ NodeRel(id=4, parent=1),
+ NodeRel(id=4, parent=2),
+ NodeRel(id=4, parent=3),
+ ],
+ )
+
def test_scalar_pk_arg(self):
users, Keyword, items, Item, User, keywords = (
self.tables.users,