]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Streamline visitors.iterate
authorMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 May 2020 20:08:33 +0000 (16:08 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 18 May 2020 20:21:54 +0000 (16:21 -0400)
This method might be used more significantly in the
ORM refactor, so further refine it.

* all get_children() methods now work entirely based on iterators.
  Basically only select() was sensitive to this anymore and it now
  chains the iterators together

* remove all kinds of flags like column_collections, schema_visitor
  that apparently aren't used anymore.

* remove the "depthfirst" visitors as these don't seem to be
  used either.

* make sure select() yields its columns first as these will be used
  to determine the current mapper.

Change-Id: I05273a2d5306a57c2d1b0979050748cf3ac964bf

doc/build/changelog/unreleased_14/removed_depthfirst.rst [new file with mode: 0644]
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/schema.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/sql/util.py
lib/sqlalchemy/sql/visitors.py
test/sql/test_compare.py
test/sql/test_utils.py

diff --git a/doc/build/changelog/unreleased_14/removed_depthfirst.rst b/doc/build/changelog/unreleased_14/removed_depthfirst.rst
new file mode 100644 (file)
index 0000000..147b2c9
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: change, sql
+
+    Removed the ``sqlalchemy.sql.visitors.iterate_depthfirst`` and
+    ``sqlalchemy.sql.visitors.traverse_depthfirst`` functions.  These functions
+    were unused by any part of SQLAlchemy.  The
+    :func:`_sa.sql.visitors.iterate` and :func:`_sa.sql.visitors.traverse`
+    functions are commonly used for these functions.  Also removed unused
+    options from the remaining functions including "column_collections",
+    "schema_visitor".
+
index 6e22a6904485a5da5e08476088906510a2a9d014..db1fbea2c148bceca9a2ff45a1f3af1d6fe28efe 100644 (file)
@@ -4660,10 +4660,7 @@ class _ColumnEntity(_QueryEntity):
             return self.mapper
 
         else:
-            for obj in visitors.iterate(
-                self.column,
-                {"column_tables": True, "column_collections": False},
-            ):
+            for obj in visitors.iterate(self.column, {"column_tables": True},):
                 if "parententity" in obj._annotations:
                     return obj._annotations["parententity"]
                 elif "deepentity" in obj._annotations:
index e7c1f3f77ebb5453b6059a16bc421f55fd476ec6..43115f117ef013231b259b89d1f5896738cabd78 100644 (file)
@@ -388,19 +388,18 @@ class ClauseElement(
         clause-level).
 
         """
-        result = []
         try:
             traverse_internals = self._traverse_internals
         except AttributeError:
-            return result
+            return []
 
-        for attrname, obj, meth in _get_children.run_generated_dispatch(
-            self, traverse_internals, "_generated_get_children_traversal"
-        ):
-            if obj is None or attrname in omit_attrs:
-                continue
-            result.extend(meth(obj, **kw))
-        return result
+        return itertools.chain.from_iterable(
+            meth(obj, **kw)
+            for attrname, obj, meth in _get_children.run_generated_dispatch(
+                self, traverse_internals, "_generated_get_children_traversal"
+            )
+            if attrname not in omit_attrs and obj is not None
+        )
 
     def self_group(self, against=None):
         # type: (Optional[Any]) -> ClauseElement
@@ -4302,8 +4301,14 @@ class ColumnClause(
 
     def get_children(self, column_tables=False, **kw):
         if column_tables and self.table is not None:
+            # TODO: this is only used by ORM query deep_entity_zero.
+            # this is being removed in a later release so remove
+            # column_tables also at that time.
             return [self.table]
         else:
+            # override base get_children() to not return the Table
+            # or selectable that is parent to this column.  Traversals
+            # expect the columns of tables and subqueries to be leaf nodes.
             return []
 
     @HasMemoized.memoized_attribute
index 08dc487d498c3ab3206b7da19ec1a39f8e146f3f..8d28d6309b789d9a3c25950c7b97969f40c94ae9 100644 (file)
@@ -117,10 +117,6 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
                 else:
                     spwd(self)
 
-    def get_children(self, **kwargs):
-        """used to allow SchemaVisitor access"""
-        return []
-
     def __repr__(self):
         return util.generic_repr(self, omit_kwarg=["info"])
 
@@ -820,21 +816,6 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
         metadata._add_table(self.name, self.schema, self)
         self.metadata = metadata
 
-    def get_children(
-        self, column_collections=True, schema_visitor=False, **kw
-    ):
-        # TODO: consider that we probably don't need column_collections=True
-        # at all, it does not seem to impact anything
-        if not schema_visitor:
-            return TableClause.get_children(
-                self, column_collections=column_collections, **kw
-            )
-        else:
-            if column_collections:
-                return list(self.columns)
-            else:
-                return []
-
     @util.deprecated(
         "1.4",
         "The :meth:`_schema.Table.exists` method is deprecated and will be "
@@ -1656,16 +1637,6 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
             selectable.foreign_keys.update(fk)
         return c.key, c
 
-    def get_children(self, schema_visitor=False, **kwargs):
-        if schema_visitor:
-            return (
-                [x for x in (self.default, self.onupdate) if x is not None]
-                + list(self.foreign_keys)
-                + list(self.constraints)
-            )
-        else:
-            return ColumnClause.get_children(self, **kwargs)
-
 
 class ForeignKey(DialectKWArgs, SchemaItem):
     """Defines a dependency between two columns.
index 27b9425ec00c23d06582100037ed45b1b819281e..0f3d2419166e3516790ed55e48d75469ca7b3f22 100644 (file)
@@ -3454,8 +3454,8 @@ class Select(
 
     _traverse_internals = (
         [
-            ("_from_obj", InternalTraversal.dp_clauseelement_list),
             ("_raw_columns", InternalTraversal.dp_clauseelement_list),
+            ("_from_obj", InternalTraversal.dp_clauseelement_list),
             ("_where_criteria", InternalTraversal.dp_clauseelement_list),
             ("_having_criteria", InternalTraversal.dp_clauseelement_list),
             ("_order_by_clauses", InternalTraversal.dp_clauseelement_list,),
@@ -3944,10 +3944,11 @@ class Select(
         self._assert_no_memoizations()
 
     def get_children(self, **kwargs):
-        return list(set(self._iterate_from_elements())) + super(
-            Select, self
-        ).get_children(
-            omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
+        return itertools.chain(
+            super(Select, self).get_children(
+                omit_attrs=["_from_obj", "_correlate", "_correlate_except"]
+            ),
+            self._iterate_from_elements(),
         )
 
     @_generative
index 4a135538e9298d5420377b5d8a7e424b9f7c5057..8c63fcba148841509b55780cd2b42e978a9984ab 100644 (file)
@@ -1,5 +1,6 @@
 from collections import deque
 from collections import namedtuple
+import itertools
 import operator
 
 from . import operators
@@ -589,28 +590,22 @@ class _GetChildren(InternalTraversal):
         return (element,)
 
     def visit_clauseelement_list(self, element, **kw):
-        return tuple(element)
+        return element
 
     def visit_clauseelement_tuples(self, element, **kw):
-        tup = ()
-        for elem in element:
-            tup += elem
-        return tup
+        return itertools.chain.from_iterable(element)
 
     def visit_fromclause_canonical_column_collection(self, element, **kw):
-        if kw.get("column_collections", False):
-            return tuple(element)
-        else:
-            return ()
+        return ()
 
     def visit_string_clauseelement_dict(self, element, **kw):
-        return tuple(element.values())
+        return element.values()
 
     def visit_fromclause_ordered_set(self, element, **kw):
-        return tuple(element)
+        return element
 
     def visit_clauseelement_unordered_set(self, element, **kw):
-        return tuple(element)
+        return element
 
     def visit_dml_ordered_values(self, element, **kw):
         for k, v in element:
index ca5bde091fc50b1197c31827e87556a9a1af994e..0a67ff9bf26c724d524482b16322fb2b21b4478e 100644 (file)
@@ -267,7 +267,7 @@ def find_tables(
 
     _visitors["table"] = tables.append
 
-    visitors.traverse(clause, {"column_collections": False}, _visitors)
+    visitors.traverse(clause, {}, _visitors)
     return tables
 
 
index 8f6bb2333d487a158eb58197576edf615e3b6927..574896cc7cf0e00ba6a4289166de3fe806e97f16 100644 (file)
@@ -32,10 +32,8 @@ from ..util import symbol
 
 __all__ = [
     "iterate",
-    "iterate_depthfirst",
     "traverse_using",
     "traverse",
-    "traverse_depthfirst",
     "cloned_traverse",
     "replacement_traverse",
     "Traversible",
@@ -568,23 +566,20 @@ CloningVisitor = CloningExternalTraversal
 ReplacingCloningVisitor = ReplacingExternalTraversal
 
 
-def iterate(obj, opts):
+def iterate(obj, opts=util.immutabledict()):
     r"""traverse the given expression structure, returning an iterator.
 
     traversal is configured to be breadth-first.
 
-    The central API feature used by the :func:`.visitors.iterate` and
-    :func:`.visitors.iterate_depthfirst` functions is the
+    The central API feature used by the :func:`.visitors.iterate`
+    function is the
     :meth:`_expression.ClauseElement.get_children` method of
-    :class:`_expression.ClauseElement`
-    objects.  This method should return all the
-    :class:`_expression.ClauseElement` objects
-    which are associated with a particular :class:`_expression.ClauseElement`
-    object.
-    For example, a :class:`.Case` structure will refer to a series of
-    :class:`_expression.ColumnElement`
-    objects within its "whens" and "else\_" member
-    variables.
+    :class:`_expression.ClauseElement` objects.  This method should return all
+    the :class:`_expression.ClauseElement` objects which are associated with a
+    particular :class:`_expression.ClauseElement` object. For example, a
+    :class:`.Case` structure will refer to a series of
+    :class:`_expression.ColumnElement` objects within its "whens" and "else\_"
+    member variables.
 
     :param obj: :class:`_expression.ClauseElement` structure to be traversed
 
@@ -592,49 +587,17 @@ def iterate(obj, opts):
      empty in modern usage.
 
     """
-    # fasttrack for atomic elements like columns
+    yield obj
     children = obj.get_children(**opts)
     if not children:
-        return [obj]
+        return
 
-    traversal = deque()
-    stack = deque([obj])
+    stack = deque([children])
     while stack:
-        t = stack.popleft()
-        traversal.append(t)
-        for c in t.get_children(**opts):
-            stack.append(c)
-    return iter(traversal)
-
-
-def iterate_depthfirst(obj, opts):
-    """traverse the given expression structure, returning an iterator.
-
-    traversal is configured to be depth-first.
-
-    :param obj: :class:`_expression.ClauseElement` structure to be traversed
-
-    :param opts: dictionary of iteration options.   This dictionary is usually
-     empty in modern usage.
-
-    .. seealso::
-
-        :func:`.visitors.iterate` - includes a general overview of iteration.
-
-    """
-    # fasttrack for atomic elements like columns
-    children = obj.get_children(**opts)
-    if not children:
-        return [obj]
-
-    stack = deque([obj])
-    traversal = deque()
-    while stack:
-        t = stack.pop()
-        traversal.appendleft(t)
-        for c in t.get_children(**opts):
-            stack.append(c)
-    return iter(traversal)
+        t_iterator = stack.popleft()
+        for t in t_iterator:
+            yield t
+            stack.append(t.get_children(**opts))
 
 
 def traverse_using(iterator, obj, visitors):
@@ -642,18 +605,16 @@ def traverse_using(iterator, obj, visitors):
     objects.
 
     :func:`.visitors.traverse_using` is usually called internally as the result
-    of the :func:`.visitors.traverse` or :func:`.visitors.traverse_depthfirst`
-    functions.
+    of the :func:`.visitors.traverse` function.
 
     :param iterator: an iterable or sequence which will yield
      :class:`_expression.ClauseElement`
      structures; the iterator is assumed to be the
-     product of the :func:`.visitors.iterate` or
-     :func:`.visitors.iterate_depthfirst` functions.
+     product of the :func:`.visitors.iterate` function.
 
     :param obj: the :class:`_expression.ClauseElement`
      that was used as the target of the
-     :func:`.iterate` or :func:`.iterate_depthfirst` function.
+     :func:`.iterate` function.
 
     :param visitors: dictionary of visit functions.  See :func:`.traverse`
      for details on this dictionary.
@@ -662,7 +623,6 @@ def traverse_using(iterator, obj, visitors):
 
         :func:`.traverse`
 
-        :func:`.traverse_depthfirst`
 
     """
     for target in iterator:
@@ -705,20 +665,6 @@ def traverse(obj, opts, visitors):
     return traverse_using(iterate(obj, opts), obj, visitors)
 
 
-def traverse_depthfirst(obj, opts, visitors):
-    """traverse and visit the given expression structure using the
-    depth-first iterator.
-
-    The iteration of objects uses the :func:`.visitors.iterate_depthfirst`
-    function, which does a depth-first traversal using a stack.
-
-    Usage is the same as that of :func:`.visitors.traverse` function.
-
-
-    """
-    return traverse_using(iterate_depthfirst(obj, opts), obj, visitors)
-
-
 def cloned_traverse(obj, opts, visitors):
     """clone the given expression structure, allowing modifications by
     visitors.
index 3a6feac018e48baf5e642c6ae108ddf4da706852..8cc7b7fb67d2e108fc866d4100c11c9a19eeee13 100644 (file)
@@ -674,22 +674,16 @@ class CacheKeyFixture(object):
             ):
                 assert_a_params = []
                 assert_b_params = []
-                visitors.traverse_depthfirst(
+                visitors.traverse(
                     case_a[a], {}, {"bindparam": assert_a_params.append}
                 )
-                visitors.traverse_depthfirst(
+                visitors.traverse(
                     case_b[b], {}, {"bindparam": assert_b_params.append}
                 )
 
                 # note we're asserting the order of the params as well as
                 # if there are dupes or not.  ordering has to be
                 # deterministic and matches what a traversal would provide.
-                # regular traverse_depthfirst does produce dupes in cases
-                # like
-                # select([some_alias]).
-                #    select_from(join(some_alias, other_table))
-                # where a bound parameter is inside of some_alias.  the
-                # cache key case is more minimalistic
                 eq_(
                     sorted(a_key.bindparams, key=lambda b: b.key),
                     sorted(
index a63e55c4e57469f84cbe3dfcdadd6a2f77e3a29a..4e713dd28671ffd8eb0973c17c6d8462af5bffe8 100644 (file)
@@ -29,7 +29,7 @@ class MiscTest(fixtures.TestBase):
 
         subset_select = select([common.c.id, common.c.data]).alias()
 
-        eq_(sql_util.find_tables(subset_select), [common])
+        eq_(set(sql_util.find_tables(subset_select)), {common})
 
     def test_find_tables_aliases(self):
         metadata = MetaData()