From: Mike Bayer Date: Mon, 18 May 2020 20:08:33 +0000 (-0400) Subject: Streamline visitors.iterate X-Git-Tag: rel_1_4_0b1~326 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=53af60b3536221f2503af29c1e90cf9db1295faf;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Streamline visitors.iterate This method might be used more significantly in the ORM refactor, so further refine it. * all get_children() methods now work entirely based on iterators. Basically only select() was sensitive to this anymore and it now chains the iterators together * remove all kinds of flags like column_collections, schema_visitor that apparently aren't used anymore. * remove the "depthfirst" visitors as these don't seem to be used either. * make sure select() yields its columns first as these will be used to determine the current mapper. Change-Id: I05273a2d5306a57c2d1b0979050748cf3ac964bf --- diff --git a/doc/build/changelog/unreleased_14/removed_depthfirst.rst b/doc/build/changelog/unreleased_14/removed_depthfirst.rst new file mode 100644 index 0000000000..147b2c9d20 --- /dev/null +++ b/doc/build/changelog/unreleased_14/removed_depthfirst.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: change, sql + + Removed the ``sqlalchemy.sql.visitors.iterate_depthfirst`` and + ``sqlalchemy.sql.visitors.traverse_depthfirst`` functions. These functions + were unused by any part of SQLAlchemy. The + :func:`_sa.sql.visitors.iterate` and :func:`_sa.sql.visitors.traverse` + functions are commonly used for these functions. Also removed unused + options from the remaining functions including "column_collections", + "schema_visitor". + diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 6e22a69044..db1fbea2c1 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -4660,10 +4660,7 @@ class _ColumnEntity(_QueryEntity): return self.mapper else: - for obj in visitors.iterate( - self.column, - {"column_tables": True, "column_collections": False}, - ): + for obj in visitors.iterate(self.column, {"column_tables": True},): if "parententity" in obj._annotations: return obj._annotations["parententity"] elif "deepentity" in obj._annotations: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e7c1f3f77e..43115f117e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -388,19 +388,18 @@ class ClauseElement( clause-level). """ - result = [] try: traverse_internals = self._traverse_internals except AttributeError: - return result + return [] - for attrname, obj, meth in _get_children.run_generated_dispatch( - self, traverse_internals, "_generated_get_children_traversal" - ): - if obj is None or attrname in omit_attrs: - continue - result.extend(meth(obj, **kw)) - return result + return itertools.chain.from_iterable( + meth(obj, **kw) + for attrname, obj, meth in _get_children.run_generated_dispatch( + self, traverse_internals, "_generated_get_children_traversal" + ) + if attrname not in omit_attrs and obj is not None + ) def self_group(self, against=None): # type: (Optional[Any]) -> ClauseElement @@ -4302,8 +4301,14 @@ class ColumnClause( def get_children(self, column_tables=False, **kw): if column_tables and self.table is not None: + # TODO: this is only used by ORM query deep_entity_zero. + # this is being removed in a later release so remove + # column_tables also at that time. return [self.table] else: + # override base get_children() to not return the Table + # or selectable that is parent to this column. Traversals + # expect the columns of tables and subqueries to be leaf nodes. return [] @HasMemoized.memoized_attribute diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 08dc487d49..8d28d6309b 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -117,10 +117,6 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable): else: spwd(self) - def get_children(self, **kwargs): - """used to allow SchemaVisitor access""" - return [] - def __repr__(self): return util.generic_repr(self, omit_kwarg=["info"]) @@ -820,21 +816,6 @@ class Table(DialectKWArgs, SchemaItem, TableClause): metadata._add_table(self.name, self.schema, self) self.metadata = metadata - def get_children( - self, column_collections=True, schema_visitor=False, **kw - ): - # TODO: consider that we probably don't need column_collections=True - # at all, it does not seem to impact anything - if not schema_visitor: - return TableClause.get_children( - self, column_collections=column_collections, **kw - ) - else: - if column_collections: - return list(self.columns) - else: - return [] - @util.deprecated( "1.4", "The :meth:`_schema.Table.exists` method is deprecated and will be " @@ -1656,16 +1637,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause): selectable.foreign_keys.update(fk) return c.key, c - def get_children(self, schema_visitor=False, **kwargs): - if schema_visitor: - return ( - [x for x in (self.default, self.onupdate) if x is not None] - + list(self.foreign_keys) - + list(self.constraints) - ) - else: - return ColumnClause.get_children(self, **kwargs) - class ForeignKey(DialectKWArgs, SchemaItem): """Defines a dependency between two columns. diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 27b9425ec0..0f3d241916 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -3454,8 +3454,8 @@ class Select( _traverse_internals = ( [ - ("_from_obj", InternalTraversal.dp_clauseelement_list), ("_raw_columns", InternalTraversal.dp_clauseelement_list), + ("_from_obj", InternalTraversal.dp_clauseelement_list), ("_where_criteria", InternalTraversal.dp_clauseelement_list), ("_having_criteria", InternalTraversal.dp_clauseelement_list), ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,), @@ -3944,10 +3944,11 @@ class Select( self._assert_no_memoizations() def get_children(self, **kwargs): - return list(set(self._iterate_from_elements())) + super( - Select, self - ).get_children( - omit_attrs=["_from_obj", "_correlate", "_correlate_except"] + return itertools.chain( + super(Select, self).get_children( + omit_attrs=["_from_obj", "_correlate", "_correlate_except"] + ), + self._iterate_from_elements(), ) @_generative diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 4a135538e9..8c63fcba14 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -1,5 +1,6 @@ from collections import deque from collections import namedtuple +import itertools import operator from . import operators @@ -589,28 +590,22 @@ class _GetChildren(InternalTraversal): return (element,) def visit_clauseelement_list(self, element, **kw): - return tuple(element) + return element def visit_clauseelement_tuples(self, element, **kw): - tup = () - for elem in element: - tup += elem - return tup + return itertools.chain.from_iterable(element) def visit_fromclause_canonical_column_collection(self, element, **kw): - if kw.get("column_collections", False): - return tuple(element) - else: - return () + return () def visit_string_clauseelement_dict(self, element, **kw): - return tuple(element.values()) + return element.values() def visit_fromclause_ordered_set(self, element, **kw): - return tuple(element) + return element def visit_clauseelement_unordered_set(self, element, **kw): - return tuple(element) + return element def visit_dml_ordered_values(self, element, **kw): for k, v in element: diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py index ca5bde091f..0a67ff9bf2 100644 --- a/lib/sqlalchemy/sql/util.py +++ b/lib/sqlalchemy/sql/util.py @@ -267,7 +267,7 @@ def find_tables( _visitors["table"] = tables.append - visitors.traverse(clause, {"column_collections": False}, _visitors) + visitors.traverse(clause, {}, _visitors) return tables diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py index 8f6bb2333d..574896cc7c 100644 --- a/lib/sqlalchemy/sql/visitors.py +++ b/lib/sqlalchemy/sql/visitors.py @@ -32,10 +32,8 @@ from ..util import symbol __all__ = [ "iterate", - "iterate_depthfirst", "traverse_using", "traverse", - "traverse_depthfirst", "cloned_traverse", "replacement_traverse", "Traversible", @@ -568,23 +566,20 @@ CloningVisitor = CloningExternalTraversal ReplacingCloningVisitor = ReplacingExternalTraversal -def iterate(obj, opts): +def iterate(obj, opts=util.immutabledict()): r"""traverse the given expression structure, returning an iterator. traversal is configured to be breadth-first. - The central API feature used by the :func:`.visitors.iterate` and - :func:`.visitors.iterate_depthfirst` functions is the + The central API feature used by the :func:`.visitors.iterate` + function is the :meth:`_expression.ClauseElement.get_children` method of - :class:`_expression.ClauseElement` - objects. This method should return all the - :class:`_expression.ClauseElement` objects - which are associated with a particular :class:`_expression.ClauseElement` - object. - For example, a :class:`.Case` structure will refer to a series of - :class:`_expression.ColumnElement` - objects within its "whens" and "else\_" member - variables. + :class:`_expression.ClauseElement` objects. This method should return all + the :class:`_expression.ClauseElement` objects which are associated with a + particular :class:`_expression.ClauseElement` object. For example, a + :class:`.Case` structure will refer to a series of + :class:`_expression.ColumnElement` objects within its "whens" and "else\_" + member variables. :param obj: :class:`_expression.ClauseElement` structure to be traversed @@ -592,49 +587,17 @@ def iterate(obj, opts): empty in modern usage. """ - # fasttrack for atomic elements like columns + yield obj children = obj.get_children(**opts) if not children: - return [obj] + return - traversal = deque() - stack = deque([obj]) + stack = deque([children]) while stack: - t = stack.popleft() - traversal.append(t) - for c in t.get_children(**opts): - stack.append(c) - return iter(traversal) - - -def iterate_depthfirst(obj, opts): - """traverse the given expression structure, returning an iterator. - - traversal is configured to be depth-first. - - :param obj: :class:`_expression.ClauseElement` structure to be traversed - - :param opts: dictionary of iteration options. This dictionary is usually - empty in modern usage. - - .. seealso:: - - :func:`.visitors.iterate` - includes a general overview of iteration. - - """ - # fasttrack for atomic elements like columns - children = obj.get_children(**opts) - if not children: - return [obj] - - stack = deque([obj]) - traversal = deque() - while stack: - t = stack.pop() - traversal.appendleft(t) - for c in t.get_children(**opts): - stack.append(c) - return iter(traversal) + t_iterator = stack.popleft() + for t in t_iterator: + yield t + stack.append(t.get_children(**opts)) def traverse_using(iterator, obj, visitors): @@ -642,18 +605,16 @@ def traverse_using(iterator, obj, visitors): objects. :func:`.visitors.traverse_using` is usually called internally as the result - of the :func:`.visitors.traverse` or :func:`.visitors.traverse_depthfirst` - functions. + of the :func:`.visitors.traverse` function. :param iterator: an iterable or sequence which will yield :class:`_expression.ClauseElement` structures; the iterator is assumed to be the - product of the :func:`.visitors.iterate` or - :func:`.visitors.iterate_depthfirst` functions. + product of the :func:`.visitors.iterate` function. :param obj: the :class:`_expression.ClauseElement` that was used as the target of the - :func:`.iterate` or :func:`.iterate_depthfirst` function. + :func:`.iterate` function. :param visitors: dictionary of visit functions. See :func:`.traverse` for details on this dictionary. @@ -662,7 +623,6 @@ def traverse_using(iterator, obj, visitors): :func:`.traverse` - :func:`.traverse_depthfirst` """ for target in iterator: @@ -705,20 +665,6 @@ def traverse(obj, opts, visitors): return traverse_using(iterate(obj, opts), obj, visitors) -def traverse_depthfirst(obj, opts, visitors): - """traverse and visit the given expression structure using the - depth-first iterator. - - The iteration of objects uses the :func:`.visitors.iterate_depthfirst` - function, which does a depth-first traversal using a stack. - - Usage is the same as that of :func:`.visitors.traverse` function. - - - """ - return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) - - def cloned_traverse(obj, opts, visitors): """clone the given expression structure, allowing modifications by visitors. diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 3a6feac018..8cc7b7fb67 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -674,22 +674,16 @@ class CacheKeyFixture(object): ): assert_a_params = [] assert_b_params = [] - visitors.traverse_depthfirst( + visitors.traverse( case_a[a], {}, {"bindparam": assert_a_params.append} ) - visitors.traverse_depthfirst( + visitors.traverse( case_b[b], {}, {"bindparam": assert_b_params.append} ) # note we're asserting the order of the params as well as # if there are dupes or not. ordering has to be # deterministic and matches what a traversal would provide. - # regular traverse_depthfirst does produce dupes in cases - # like - # select([some_alias]). - # select_from(join(some_alias, other_table)) - # where a bound parameter is inside of some_alias. the - # cache key case is more minimalistic eq_( sorted(a_key.bindparams, key=lambda b: b.key), sorted( diff --git a/test/sql/test_utils.py b/test/sql/test_utils.py index a63e55c4e5..4e713dd286 100644 --- a/test/sql/test_utils.py +++ b/test/sql/test_utils.py @@ -29,7 +29,7 @@ class MiscTest(fixtures.TestBase): subset_select = select([common.c.id, common.c.data]).alias() - eq_(sql_util.find_tables(subset_select), [common]) + eq_(set(sql_util.find_tables(subset_select)), {common}) def test_find_tables_aliases(self): metadata = MetaData()