From: Mike Bayer Date: Thu, 29 Aug 2019 16:09:17 +0000 (-0400) Subject: Annotate session-bind-lookup entity in Query-produced selectables X-Git-Tag: rel_1_4_0b1~738^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f6c9b20a04d183d86078252048563b14e27fb6d2;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Annotate session-bind-lookup entity in Query-produced selectables Added new entity-targeting capabilities to the :class:`.Query` object to help with the case where the :class:`.Session` is using a bind dictionary against mapped classes, rather than a single bind, and the :class:`.Query` is against a Core statement that was ultimately generated from a method such as :meth:`.Query.subquery`; a deep search is performed to locate any ORM entity related to the query in order to locate a mapper if one is not otherwise present. Fixes: #4829 Change-Id: I95cf325a5aba21baec4b313246c6f4d692284820 --- diff --git a/doc/build/changelog/unreleased_14/4829.rst b/doc/build/changelog/unreleased_14/4829.rst new file mode 100644 index 0000000000..93c582fa21 --- /dev/null +++ b/doc/build/changelog/unreleased_14/4829.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, orm + :tickets: 4829 + + Added new entity-targeting capabilities to the :class:`.Query` object to + help with the case where the :class:`.Session` is using a bind dictionary + against mapped classes, rather than a single bind, and the :class:`.Query` + is against a Core statement that was ultimately generated from a method + such as :meth:`.Query.subquery`; a deep search is performed to locate + any ORM entity related to the query in order to locate a mapper if + one is not otherwise present. + diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 9369297033..d4ff35d2e5 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -384,6 +384,25 @@ class Query(object): else self._query_entity_zero().entity_zero ) + def _deep_entity_zero(self): + """Return a 'deep' entity; this is any entity we can find associated + with the first entity / column experssion. this is used only for + session.get_bind(). + + """ + + if ( + self._select_from_entity is not None + and not self._select_from_entity.is_clause_element + ): + return self._select_from_entity.mapper + for ent in self._entities: + ezero = ent._deep_entity_zero() + if ezero is not None: + return ezero.mapper + else: + return None + @property def _mapper_entities(self): for ent in self._entities: @@ -394,13 +413,7 @@ class Query(object): return self._joinpoint.get("_joinpoint_entity", self._entity_zero()) def _bind_mapper(self): - ezero = self._entity_zero() - if ezero is not None: - insp = inspect(ezero) - if not insp.is_clause_element: - return insp.mapper - - return None + return self._deep_entity_zero() def _only_full_mapper_zero(self, methname): if self._entities != [self._primary_entity]: @@ -3900,6 +3913,12 @@ class Query(object): else: context.statement = self._simple_statement(context) + if for_statement: + ezero = self._mapper_zero() + if ezero is not None: + context.statement = context.statement._annotate( + {"deepentity": ezero} + ) return context def _compound_eager_statement(self, context): @@ -4161,6 +4180,9 @@ class _MapperEntity(_QueryEntity): def entity_zero_or_selectable(self): return self.entity_zero + def _deep_entity_zero(self): + return self.entity_zero + def corresponds_to(self, entity): return _entity_corresponds_to(self.entity_zero, entity) @@ -4430,6 +4452,14 @@ class _BundleEntity(_QueryEntity): else: return None + def _deep_entity_zero(self): + for ent in self._entities: + ezero = ent._deep_entity_zero() + if ezero is not None: + return ezero + else: + return None + def adapt_to_selectable(self, query, sel): c = _BundleEntity(query, self.bundle, setup_entities=False) # c._label_name = self._label_name @@ -4530,7 +4560,7 @@ class _ColumnEntity(_QueryEntity): # of FROMs for the overall expression - this helps # subqueries which were built from ORM constructs from # leaking out their entities into the main select construct - self.actual_froms = actual_froms = set(column._from_objects) + self.actual_froms = set(column._from_objects) if not search_entities: self.entity_zero = _entity @@ -4540,7 +4570,6 @@ class _ColumnEntity(_QueryEntity): else: self.entities = [] self.mapper = None - self._from_entities = set(self.entities) else: all_elements = [ elem @@ -4551,21 +4580,9 @@ class _ColumnEntity(_QueryEntity): ] self.entities = util.unique_list( - [ - elem._annotations["parententity"] - for elem in all_elements - if "parententity" in elem._annotations - ] + [elem._annotations["parententity"] for elem in all_elements] ) - self._from_entities = set( - [ - elem._annotations["parententity"] - for elem in all_elements - if "parententity" in elem._annotations - and actual_froms.intersection(elem._from_objects) - ] - ) if self.entities: self.entity_zero = self.entities[0] self.mapper = self.entity_zero.mapper @@ -4578,6 +4595,22 @@ class _ColumnEntity(_QueryEntity): supports_single_entity = False + def _deep_entity_zero(self): + if self.mapper is not None: + return self.mapper + + else: + for obj in visitors.iterate( + self.column, + {"column_tables": True, "column_collections": False}, + ): + if "parententity" in obj._annotations: + return obj._annotations["parententity"] + elif "deepentity" in obj._annotations: + return obj._annotations["deepentity"] + else: + return None + @property def entity_zero_or_selectable(self): if self.entity_zero is not None: diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py index 7fc9245ab5..a0264845e3 100644 --- a/lib/sqlalchemy/sql/annotation.py +++ b/lib/sqlalchemy/sql/annotation.py @@ -15,8 +15,80 @@ from . import operators from .. import util +class SupportsCloneAnnotations(object): + _annotations = util.immutabledict() + + def _annotate(self, values): + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + new = self._clone() + new._annotations = new._annotations.union(values) + return new + + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + new = self._clone() + new._annotations = util.immutabledict(values) + return new + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone or self._annotations: + # clone is used when we are also copying + # the expression for a deep deannotation + new = self._clone() + new._annotations = {} + return new + else: + return self + + +class SupportsWrappingAnnotations(object): + def _annotate(self, values): + """return a copy of this ClauseElement with annotations + updated by the given dictionary. + + """ + return Annotated(self, values) + + def _with_annotations(self, values): + """return a copy of this ClauseElement with annotations + replaced by the given dictionary. + + """ + return Annotated(self, values) + + def _deannotate(self, values=None, clone=False): + """return a copy of this :class:`.ClauseElement` with annotations + removed. + + :param values: optional tuple of individual values + to remove. + + """ + if clone: + # clone is used when we are also copying + # the expression for a deep deannotation + return self._clone() + else: + # if no clone, since we have no annotations we return + # self + return self + + class Annotated(object): - """clones a ClauseElement and applies an 'annotations' dictionary. + """clones a SupportsAnnotated and applies an 'annotations' dictionary. Unlike regular clones, this clone also mimics __hash__() and __cmp__() of the original element so that it takes its place diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e2df1adc2d..19d26f1382 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -22,6 +22,7 @@ from . import operators from . import roles from . import type_api from .annotation import Annotated +from .annotation import SupportsWrappingAnnotations from .base import _clone from .base import _generative from .base import Executable @@ -161,7 +162,7 @@ def not_(clause): @inspection._self_inspects -class ClauseElement(roles.SQLRole, Visitable): +class ClauseElement(roles.SQLRole, SupportsWrappingAnnotations, Visitable): """Base class for elements of a programmatically constructed SQL expression. @@ -267,37 +268,6 @@ class ClauseElement(roles.SQLRole, Visitable): d.pop("_is_clone_of", None) return d - def _annotate(self, values): - """return a copy of this ClauseElement with annotations - updated by the given dictionary. - - """ - return Annotated(self, values) - - def _with_annotations(self, values): - """return a copy of this ClauseElement with annotations - replaced by the given dictionary. - - """ - return Annotated(self, values) - - def _deannotate(self, values=None, clone=False): - """return a copy of this :class:`.ClauseElement` with annotations - removed. - - :param values: optional tuple of individual values - to remove. - - """ - if clone: - # clone is used when we are also copying - # the expression for a deep deannotation - return self._clone() - else: - # if no clone, since we have no annotations we return - # self - return self - def _execute_on_connection(self, connection, multiparams, params): if self.supports_execution: return connection._execute_clauseelement(self, multiparams, params) @@ -4136,6 +4106,12 @@ class ColumnClause(roles.LabeledColumnExprRole, Immutable, ColumnElement): self._memoized_property.expire_instance(self) self.__dict__["table"] = table + def get_children(self, column_tables=False, **kw): + if column_tables and self.table is not None: + return [self.table] + else: + return [] + table = property(_get_table, _set_table) def _cache_key(self, **kw): diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 03dbcd449a..97c49f8fcc 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -19,6 +19,7 @@ from . import operators from . import roles from . import type_api from .annotation import Annotated +from .annotation import SupportsCloneAnnotations from .base import _clone from .base import _cloned_difference from .base import _cloned_intersection @@ -2068,6 +2069,7 @@ class SelectBase( roles.InElementRole, HasCTE, Executable, + SupportsCloneAnnotations, Selectable, ): """Base class for SELECT statements. diff --git a/test/orm/test_query.py b/test/orm/test_query.py index f5283ff443..4dff6fe56d 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -139,6 +139,23 @@ class RowTupleTest(QueryTest): assert row.id == 7 assert row.uname == "jack" + def test_deep_entity(self): + users, User = (self.tables.users, self.classes.User) + + mapper(User, users) + + sess = create_session() + bundle = Bundle("b1", User.id, User.name) + subq1 = sess.query(User.id).subquery() + subq2 = sess.query(bundle).subquery() + cte = sess.query(User.id).cte() + ex = sess.query(User).exists() + + is_(sess.query(subq1)._deep_entity_zero(), inspect(User)) + is_(sess.query(subq2)._deep_entity_zero(), inspect(User)) + is_(sess.query(cte)._deep_entity_zero(), inspect(User)) + is_(sess.query(ex)._deep_entity_zero(), inspect(User)) + def test_column_metadata(self): users, Address, addresses, User = ( self.tables.users, @@ -156,6 +173,8 @@ class RowTupleTest(QueryTest): fn = func.count(User.id) name_label = User.name.label("uname") bundle = Bundle("b1", User.id, User.name) + subq1 = sess.query(User.id).subquery() + subq2 = sess.query(bundle).subquery() cte = sess.query(User.id).cte() for q, asserted in [ ( @@ -275,6 +294,30 @@ class RowTupleTest(QueryTest): } ], ), + ( + sess.query(subq1.c.id), + [ + { + "aliased": False, + "expr": subq1.c.id, + "type": subq1.c.id.type, + "name": "id", + "entity": None, + } + ], + ), + ( + sess.query(subq2.c.id), + [ + { + "aliased": False, + "expr": subq2.c.id, + "type": subq2.c.id.type, + "name": "id", + "entity": None, + } + ], + ), ( sess.query(users), [ @@ -5518,12 +5561,15 @@ class BooleanEvalTest(fixtures.TestBase, testing.AssertsCompiledSQL): class SessionBindTest(QueryTest): @contextlib.contextmanager - def _assert_bind_args(self, session): + def _assert_bind_args(self, session, expect_mapped_bind=True): get_bind = mock.Mock(side_effect=session.get_bind) with mock.patch.object(session, "get_bind", get_bind): yield for call_ in get_bind.mock_calls: - is_(call_[1][0], inspect(self.classes.User)) + if expect_mapped_bind: + is_(call_[1][0], inspect(self.classes.User)) + else: + is_(call_[1][0], None) is_not_(call_[2]["clause"], None) def test_single_entity_q(self): @@ -5532,12 +5578,43 @@ class SessionBindTest(QueryTest): with self._assert_bind_args(session): session.query(User).all() + def test_aliased_entity_q(self): + User = self.classes.User + u = aliased(User) + session = Session() + with self._assert_bind_args(session): + session.query(u).all() + def test_sql_expr_entity_q(self): User = self.classes.User session = Session() with self._assert_bind_args(session): session.query(User.id).all() + def test_sql_expr_subquery_from_entity(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session): + subq = session.query(User.id).subquery() + session.query(subq).all() + + def test_sql_expr_cte_from_entity(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session): + cte = session.query(User.id).cte() + subq = session.query(cte).subquery() + session.query(subq).all() + + def test_sql_expr_bundle_cte_from_entity(self): + User = self.classes.User + session = Session() + with self._assert_bind_args(session): + cte = session.query(User.id, User.name).cte() + subq = session.query(cte).subquery() + bundle = Bundle(subq.c.id, subq.c.name) + session.query(bundle).all() + def test_count(self): User = self.classes.User session = Session() @@ -5594,6 +5671,35 @@ class SessionBindTest(QueryTest): with self._assert_bind_args(session): session.query(func.max(User.score)).scalar() + def test_plain_table(self): + User = self.classes.User + + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=False): + session.query(inspect(User).local_table).all() + + def test_plain_table_from_self(self): + User = self.classes.User + + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=False): + session.query(inspect(User).local_table).from_self().all() + + def test_plain_table_count(self): + User = self.classes.User + + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=False): + session.query(inspect(User).local_table).count() + + def test_plain_table_select_from(self): + User = self.classes.User + + table = inspect(User).local_table + session = Session() + with self._assert_bind_args(session, expect_mapped_bind=False): + session.query(table).select_from(table).all() + @testing.requires.nested_aggregates def test_column_property_select(self): User = self.classes.User diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index 189436192d..c54f27c23b 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -41,7 +41,10 @@ from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import AssertsExecutionResults from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures +from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_not_ +from sqlalchemy.testing import ne_ metadata = MetaData() @@ -2196,12 +2199,21 @@ class AnnotationsTest(fixtures.TestBase): t = table("t", column("x")) a = t.alias() + + for obj in [t, t.c.x, a, t.c.x > 1, (t.c.x > 1).label(None)]: + annot = obj._annotate({}) + eq_(set([obj]), set([annot])) + + def test_clone_annotations_dont_hash(self): + t = table("t", column("x")) + s = t.select() + a = t.alias() s2 = a.select() - for obj in [t, t.c.x, a, s, s2, t.c.x > 1, (t.c.x > 1).label(None)]: + for obj in [s, s2]: annot = obj._annotate({}) - eq_(set([obj]), set([annot])) + ne_(set([obj]), set([annot])) def test_compare(self): t = table("t", column("x"), column("y")) @@ -2423,7 +2435,7 @@ class AnnotationsTest(fixtures.TestBase): expected, ) - def test_deannotate(self): + def test_deannotate_wrapping(self): table1 = table("table1", column("col1"), column("col2")) bin_ = table1.c.col1 == bindparam("foo", value=None) @@ -2433,7 +2445,7 @@ class AnnotationsTest(fixtures.TestBase): b4 = sql_util._deep_deannotate(bin_) for elem in (b2._annotations, b2.left._annotations): - assert "_orm_adapt" in elem + in_("_orm_adapt", elem) for elem in ( b3._annotations, @@ -2441,17 +2453,47 @@ class AnnotationsTest(fixtures.TestBase): b4._annotations, b4.left._annotations, ): - assert elem == {} + eq_(elem, {}) - assert b2.left is not bin_.left - assert b3.left is not b2.left and b2.left is not bin_.left - assert b4.left is bin_.left # since column is immutable + is_not_(b2.left, bin_.left) + is_not_(b3.left, b2.left) + is_not_(b2.left, bin_.left) + is_(b4.left, bin_.left) # since column is immutable # deannotate copies the element - assert ( - bin_.right is not b2.right - and b2.right is not b3.right - and b3.right is not b4.right + is_not_(bin_.right, b2.right) + is_not_(b2.right, b3.right) + is_not_(b3.right, b4.right) + + def test_deannotate_clone(self): + table1 = table("table1", column("col1"), column("col2")) + + subq = ( + select([table1]) + .where(table1.c.col1 == bindparam("foo")) + .subquery() ) + stmt = select([subq]) + + s2 = sql_util._deep_annotate(stmt, {"_orm_adapt": True}) + s3 = sql_util._deep_deannotate(s2) + s4 = sql_util._deep_deannotate(s3) + + eq_(stmt._annotations, {}) + eq_(subq._annotations, {}) + + eq_(s2._annotations, {"_orm_adapt": True}) + eq_(s3._annotations, {}) + eq_(s4._annotations, {}) + + # select._raw_columns[0] is the subq object + eq_(s2._raw_columns[0]._annotations, {"_orm_adapt": True}) + eq_(s3._raw_columns[0]._annotations, {}) + eq_(s4._raw_columns[0]._annotations, {}) + + is_not_(s3, s2) + is_not_(s4, s3) # deep deannotate makes a clone unconditionally + + is_(s3._deannotate(), s3) # regular deannotate returns same object def test_annotate_unique_traversal(self): """test that items are copied only once during