--- /dev/null
+.. change::
+ :tags: bug, orm
+ :tickets: 4829
+
+ Added new entity-targeting capabilities to the :class:`.Query` object to
+ help with the case where the :class:`.Session` is using a bind dictionary
+ against mapped classes, rather than a single bind, and the :class:`.Query`
+ is against a Core statement that was ultimately generated from a method
+ such as :meth:`.Query.subquery`; a deep search is performed to locate
+ any ORM entity related to the query in order to locate a mapper if
+ one is not otherwise present.
+
else self._query_entity_zero().entity_zero
)
+ def _deep_entity_zero(self):
+ """Return a 'deep' entity; this is any entity we can find associated
+ with the first entity / column experssion. this is used only for
+ session.get_bind().
+
+ """
+
+ if (
+ self._select_from_entity is not None
+ and not self._select_from_entity.is_clause_element
+ ):
+ return self._select_from_entity.mapper
+ for ent in self._entities:
+ ezero = ent._deep_entity_zero()
+ if ezero is not None:
+ return ezero.mapper
+ else:
+ return None
+
@property
def _mapper_entities(self):
for ent in self._entities:
return self._joinpoint.get("_joinpoint_entity", self._entity_zero())
def _bind_mapper(self):
- ezero = self._entity_zero()
- if ezero is not None:
- insp = inspect(ezero)
- if not insp.is_clause_element:
- return insp.mapper
-
- return None
+ return self._deep_entity_zero()
def _only_full_mapper_zero(self, methname):
if self._entities != [self._primary_entity]:
else:
context.statement = self._simple_statement(context)
+ if for_statement:
+ ezero = self._mapper_zero()
+ if ezero is not None:
+ context.statement = context.statement._annotate(
+ {"deepentity": ezero}
+ )
return context
def _compound_eager_statement(self, context):
def entity_zero_or_selectable(self):
return self.entity_zero
+ def _deep_entity_zero(self):
+ return self.entity_zero
+
def corresponds_to(self, entity):
return _entity_corresponds_to(self.entity_zero, entity)
else:
return None
+ def _deep_entity_zero(self):
+ for ent in self._entities:
+ ezero = ent._deep_entity_zero()
+ if ezero is not None:
+ return ezero
+ else:
+ return None
+
def adapt_to_selectable(self, query, sel):
c = _BundleEntity(query, self.bundle, setup_entities=False)
# c._label_name = self._label_name
# of FROMs for the overall expression - this helps
# subqueries which were built from ORM constructs from
# leaking out their entities into the main select construct
- self.actual_froms = actual_froms = set(column._from_objects)
+ self.actual_froms = set(column._from_objects)
if not search_entities:
self.entity_zero = _entity
else:
self.entities = []
self.mapper = None
- self._from_entities = set(self.entities)
else:
all_elements = [
elem
]
self.entities = util.unique_list(
- [
- elem._annotations["parententity"]
- for elem in all_elements
- if "parententity" in elem._annotations
- ]
+ [elem._annotations["parententity"] for elem in all_elements]
)
- self._from_entities = set(
- [
- elem._annotations["parententity"]
- for elem in all_elements
- if "parententity" in elem._annotations
- and actual_froms.intersection(elem._from_objects)
- ]
- )
if self.entities:
self.entity_zero = self.entities[0]
self.mapper = self.entity_zero.mapper
supports_single_entity = False
+ def _deep_entity_zero(self):
+ if self.mapper is not None:
+ return self.mapper
+
+ else:
+ for obj in visitors.iterate(
+ self.column,
+ {"column_tables": True, "column_collections": False},
+ ):
+ if "parententity" in obj._annotations:
+ return obj._annotations["parententity"]
+ elif "deepentity" in obj._annotations:
+ return obj._annotations["deepentity"]
+ else:
+ return None
+
@property
def entity_zero_or_selectable(self):
if self.entity_zero is not None:
from .. import util
+class SupportsCloneAnnotations(object):
+ _annotations = util.immutabledict()
+
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = new._annotations.union(values)
+ return new
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ new = self._clone()
+ new._annotations = util.immutabledict(values)
+ return new
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone or self._annotations:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ new = self._clone()
+ new._annotations = {}
+ return new
+ else:
+ return self
+
+
+class SupportsWrappingAnnotations(object):
+ def _annotate(self, values):
+ """return a copy of this ClauseElement with annotations
+ updated by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _with_annotations(self, values):
+ """return a copy of this ClauseElement with annotations
+ replaced by the given dictionary.
+
+ """
+ return Annotated(self, values)
+
+ def _deannotate(self, values=None, clone=False):
+ """return a copy of this :class:`.ClauseElement` with annotations
+ removed.
+
+ :param values: optional tuple of individual values
+ to remove.
+
+ """
+ if clone:
+ # clone is used when we are also copying
+ # the expression for a deep deannotation
+ return self._clone()
+ else:
+ # if no clone, since we have no annotations we return
+ # self
+ return self
+
+
class Annotated(object):
- """clones a ClauseElement and applies an 'annotations' dictionary.
+ """clones a SupportsAnnotated and applies an 'annotations' dictionary.
Unlike regular clones, this clone also mimics __hash__() and
__cmp__() of the original element so that it takes its place
from . import roles
from . import type_api
from .annotation import Annotated
+from .annotation import SupportsWrappingAnnotations
from .base import _clone
from .base import _generative
from .base import Executable
@inspection._self_inspects
-class ClauseElement(roles.SQLRole, Visitable):
+class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable):
"""Base class for elements of a programmatically constructed SQL
expression.
d.pop("_is_clone_of", None)
return d
- def _annotate(self, values):
- """return a copy of this ClauseElement with annotations
- updated by the given dictionary.
-
- """
- return Annotated(self, values)
-
- def _with_annotations(self, values):
- """return a copy of this ClauseElement with annotations
- replaced by the given dictionary.
-
- """
- return Annotated(self, values)
-
- def _deannotate(self, values=None, clone=False):
- """return a copy of this :class:`.ClauseElement` with annotations
- removed.
-
- :param values: optional tuple of individual values
- to remove.
-
- """
- if clone:
- # clone is used when we are also copying
- # the expression for a deep deannotation
- return self._clone()
- else:
- # if no clone, since we have no annotations we return
- # self
- return self
-
def _execute_on_connection(self, connection, multiparams, params):
if self.supports_execution:
return connection._execute_clauseelement(self, multiparams, params)
self._memoized_property.expire_instance(self)
self.__dict__["table"] = table
+ def get_children(self, column_tables=False, **kw):
+ if column_tables and self.table is not None:
+ return [self.table]
+ else:
+ return []
+
table = property(_get_table, _set_table)
def _cache_key(self, **kw):
from . import roles
from . import type_api
from .annotation import Annotated
+from .annotation import SupportsCloneAnnotations
from .base import _clone
from .base import _cloned_difference
from .base import _cloned_intersection
roles.InElementRole,
HasCTE,
Executable,
+ SupportsCloneAnnotations,
Selectable,
):
"""Base class for SELECT statements.
assert row.id == 7
assert row.uname == "jack"
+ def test_deep_entity(self):
+ users, User = (self.tables.users, self.classes.User)
+
+ mapper(User, users)
+
+ sess = create_session()
+ bundle = Bundle("b1", User.id, User.name)
+ subq1 = sess.query(User.id).subquery()
+ subq2 = sess.query(bundle).subquery()
+ cte = sess.query(User.id).cte()
+ ex = sess.query(User).exists()
+
+ is_(sess.query(subq1)._deep_entity_zero(), inspect(User))
+ is_(sess.query(subq2)._deep_entity_zero(), inspect(User))
+ is_(sess.query(cte)._deep_entity_zero(), inspect(User))
+ is_(sess.query(ex)._deep_entity_zero(), inspect(User))
+
def test_column_metadata(self):
users, Address, addresses, User = (
self.tables.users,
fn = func.count(User.id)
name_label = User.name.label("uname")
bundle = Bundle("b1", User.id, User.name)
+ subq1 = sess.query(User.id).subquery()
+ subq2 = sess.query(bundle).subquery()
cte = sess.query(User.id).cte()
for q, asserted in [
(
}
],
),
+ (
+ sess.query(subq1.c.id),
+ [
+ {
+ "aliased": False,
+ "expr": subq1.c.id,
+ "type": subq1.c.id.type,
+ "name": "id",
+ "entity": None,
+ }
+ ],
+ ),
+ (
+ sess.query(subq2.c.id),
+ [
+ {
+ "aliased": False,
+ "expr": subq2.c.id,
+ "type": subq2.c.id.type,
+ "name": "id",
+ "entity": None,
+ }
+ ],
+ ),
(
sess.query(users),
[
class SessionBindTest(QueryTest):
@contextlib.contextmanager
- def _assert_bind_args(self, session):
+ def _assert_bind_args(self, session, expect_mapped_bind=True):
get_bind = mock.Mock(side_effect=session.get_bind)
with mock.patch.object(session, "get_bind", get_bind):
yield
for call_ in get_bind.mock_calls:
- is_(call_[1][0], inspect(self.classes.User))
+ if expect_mapped_bind:
+ is_(call_[1][0], inspect(self.classes.User))
+ else:
+ is_(call_[1][0], None)
is_not_(call_[2]["clause"], None)
def test_single_entity_q(self):
with self._assert_bind_args(session):
session.query(User).all()
+ def test_aliased_entity_q(self):
+ User = self.classes.User
+ u = aliased(User)
+ session = Session()
+ with self._assert_bind_args(session):
+ session.query(u).all()
+
def test_sql_expr_entity_q(self):
User = self.classes.User
session = Session()
with self._assert_bind_args(session):
session.query(User.id).all()
+ def test_sql_expr_subquery_from_entity(self):
+ User = self.classes.User
+ session = Session()
+ with self._assert_bind_args(session):
+ subq = session.query(User.id).subquery()
+ session.query(subq).all()
+
+ def test_sql_expr_cte_from_entity(self):
+ User = self.classes.User
+ session = Session()
+ with self._assert_bind_args(session):
+ cte = session.query(User.id).cte()
+ subq = session.query(cte).subquery()
+ session.query(subq).all()
+
+ def test_sql_expr_bundle_cte_from_entity(self):
+ User = self.classes.User
+ session = Session()
+ with self._assert_bind_args(session):
+ cte = session.query(User.id, User.name).cte()
+ subq = session.query(cte).subquery()
+ bundle = Bundle(subq.c.id, subq.c.name)
+ session.query(bundle).all()
+
def test_count(self):
User = self.classes.User
session = Session()
with self._assert_bind_args(session):
session.query(func.max(User.score)).scalar()
+ def test_plain_table(self):
+ User = self.classes.User
+
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(inspect(User).local_table).all()
+
+ def test_plain_table_from_self(self):
+ User = self.classes.User
+
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(inspect(User).local_table).from_self().all()
+
+ def test_plain_table_count(self):
+ User = self.classes.User
+
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(inspect(User).local_table).count()
+
+ def test_plain_table_select_from(self):
+ User = self.classes.User
+
+ table = inspect(User).local_table
+ session = Session()
+ with self._assert_bind_args(session, expect_mapped_bind=False):
+ session.query(table).select_from(table).all()
+
@testing.requires.nested_aggregates
def test_column_property_select(self):
User = self.classes.User
from sqlalchemy.testing import AssertsExecutionResults
from sqlalchemy.testing import eq_
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import in_
from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_not_
+from sqlalchemy.testing import ne_
metadata = MetaData()
t = table("t", column("x"))
a = t.alias()
+
+ for obj in [t, t.c.x, a, t.c.x > 1, (t.c.x > 1).label(None)]:
+ annot = obj._annotate({})
+ eq_(set([obj]), set([annot]))
+
+ def test_clone_annotations_dont_hash(self):
+ t = table("t", column("x"))
+
s = t.select()
+ a = t.alias()
s2 = a.select()
- for obj in [t, t.c.x, a, s, s2, t.c.x > 1, (t.c.x > 1).label(None)]:
+ for obj in [s, s2]:
annot = obj._annotate({})
- eq_(set([obj]), set([annot]))
+ ne_(set([obj]), set([annot]))
def test_compare(self):
t = table("t", column("x"), column("y"))
expected,
)
- def test_deannotate(self):
+ def test_deannotate_wrapping(self):
table1 = table("table1", column("col1"), column("col2"))
bin_ = table1.c.col1 == bindparam("foo", value=None)
b4 = sql_util._deep_deannotate(bin_)
for elem in (b2._annotations, b2.left._annotations):
- assert "_orm_adapt" in elem
+ in_("_orm_adapt", elem)
for elem in (
b3._annotations,
b4._annotations,
b4.left._annotations,
):
- assert elem == {}
+ eq_(elem, {})
- assert b2.left is not bin_.left
- assert b3.left is not b2.left and b2.left is not bin_.left
- assert b4.left is bin_.left # since column is immutable
+ is_not_(b2.left, bin_.left)
+ is_not_(b3.left, b2.left)
+ is_not_(b2.left, bin_.left)
+ is_(b4.left, bin_.left) # since column is immutable
# deannotate copies the element
- assert (
- bin_.right is not b2.right
- and b2.right is not b3.right
- and b3.right is not b4.right
+ is_not_(bin_.right, b2.right)
+ is_not_(b2.right, b3.right)
+ is_not_(b3.right, b4.right)
+
+ def test_deannotate_clone(self):
+ table1 = table("table1", column("col1"), column("col2"))
+
+ subq = (
+ select([table1])
+ .where(table1.c.col1 == bindparam("foo"))
+ .subquery()
)
+ stmt = select([subq])
+
+ s2 = sql_util._deep_annotate(stmt, {"_orm_adapt": True})
+ s3 = sql_util._deep_deannotate(s2)
+ s4 = sql_util._deep_deannotate(s3)
+
+ eq_(stmt._annotations, {})
+ eq_(subq._annotations, {})
+
+ eq_(s2._annotations, {"_orm_adapt": True})
+ eq_(s3._annotations, {})
+ eq_(s4._annotations, {})
+
+ # select._raw_columns[0] is the subq object
+ eq_(s2._raw_columns[0]._annotations, {"_orm_adapt": True})
+ eq_(s3._raw_columns[0]._annotations, {})
+ eq_(s4._raw_columns[0]._annotations, {})
+
+ is_not_(s3, s2)
+ is_not_(s4, s3) # deep deannotate makes a clone unconditionally
+
+ is_(s3._deannotate(), s3) # regular deannotate returns same object
def test_annotate_unique_traversal(self):
"""test that items are copied only once during