]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Query linter option
authorAlessio Bogon <youtux@gmail.com>
Sun, 15 Sep 2019 15:12:24 +0000 (11:12 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 22 Jan 2020 16:31:23 +0000 (11:31 -0500)
Added "from linting" as a built-in feature to the SQL compiler.  This
allows the compiler to maintain graph of all the FROM clauses in a
particular SELECT statement, linked by criteria in either the WHERE
or in JOIN clauses that link these FROM clauses together.  If any two
FROM clauses have no path between them, a warning is emitted that the
query may be producing a cartesian product.   As the Core expression
language as well as the ORM are built on an "implicit FROMs" model where
a particular FROM clause is automatically added if any part of the query
refers to it, it is easy for this to happen inadvertently and it is
hoped that the new feature helps with this issue.

The original recipe is from:
https://github.com/sqlalchemy/sqlalchemy/wiki/FromLinter

The linter is now enabled for all tests in the test suite as well.
This has necessitated that a lot of the queries be adjusted to
not include cartesian products.  Part of the rationale for the
linter to not be enabled for statement compilation only was to reduce
the need for adjustment for the many test case statements throughout
the test suite that are not real-world statements.

This gerrit is adapted from Ib5946e57c9dba6da428c4d1dee6760b3e978dda0.

Fixes: #4737
Change-Id: Ic91fd9774379f895d021c3ad564db6062299211c
Closes: #4830
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/4830
Pull-request-sha: f8a21aa6262d1bcc9ff0d11a2616e41fba97a47a

25 files changed:
doc/build/changelog/migration_14.rst
doc/build/changelog/unreleased_14/4737.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/oracle/base.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/create.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/__init__.py
lib/sqlalchemy/sql/compiler.py
test/base/test_tutorials.py
test/orm/inheritance/test_abc_inheritance.py
test/orm/inheritance/test_assorted_poly.py
test/orm/inheritance/test_poly_persistence.py
test/orm/inheritance/test_polymorphic_rel.py
test/orm/inheritance/test_relationship.py
test/orm/inheritance/test_single.py
test/orm/test_deprecations.py
test/orm/test_eager_relations.py
test/orm/test_froms.py
test/orm/test_pickled.py
test/orm/test_query.py
test/profiles.txt
test/sql/test_defaults.py
test/sql/test_from_linter.py [new file with mode: 0644]
test/sql/test_resultset.py

index 3a57191e356dc89e47f49bfbb42c2672a91ce255..63b59841f14e2175fa48255320858c87ac4d088d 100644 (file)
@@ -756,6 +756,136 @@ the cascade settings for a viewonly relationship.
 :ticket:`4993`
 :ticket:`4994`
 
+New Features - Core
+====================
+
+.. _change_4737:
+
+
+Built-in FROM linting will warn for any potential cartesian products in a SELECT statement
+------------------------------------------------------------------------------------------
+
+As the Core expression language as well as the ORM are built on an "implicit
+FROMs" model where a particular FROM clause is automatically added if any part
+of the query refers to it, a common issue is the case where a SELECT statement,
+either a top level statement or an embedded subquery, contains FROM elements
+that are not joined to the rest of the FROM elements in the query, causing
+what's referred to as a "cartesian product" in the result set, i.e. every
+possible combination of rows from each FROM element not otherwise joined.  In
+relational databases, this is nearly always an undesirable outcome as it
+produces an enormous result set full of duplicated, uncorrelated data.
+
+SQLAlchemy, for all of its great features, is particularly prone to this sort
+of issue happening as a SELECT statement will have elements added to its FROM
+clause automatically from any table seen in the other clauses. A typical
+scenario looks like the following, where two tables are JOINed together,
+however an additional entry in the WHERE clause that perhaps inadvertently does
+not line up with these two tables will create an additional FROM entry::
+
+    address_alias = aliased(Address)
+
+    q = session.query(User).\
+        join(address_alias, User.addresses).\
+        filter(Address.email_address == 'foo')
+
+The above query selects from a JOIN of ``User`` and ``address_alias``, the
+latter of which is an alias of the ``Address`` entity.  However, the
+``Address`` entity is used within the WHERE clause directly, so the above would
+result in the SQL::
+
+    SELECT
+        users.id AS users_id, users.name AS users_name,
+        users.fullname AS users_fullname,
+        users.nickname AS users_nickname
+    FROM addresses, users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id
+    WHERE addresses.email_address = :email_address_1
+
+In the above SQL, we can see what SQLAlchemy developers term "the dreaded
+comma", as we see "FROM addresses, users JOIN addresses" in the FROM clause
+which is the classic sign of a cartesian product; where a query is making use
+of JOIN in order to join FROM clauses together, however because one of them is
+not joined, it uses a comma.      The above query will return a full set of
+rows that join the "user" and "addresses" table together on the "id / user_id"
+column, and will then apply all those rows into a cartesian product against
+every row in the "addresses" table directly.   That is, if there are ten user
+rows and 100 rows in addresses, the above query will return its expected result
+rows, likely to be 100 as all address rows would be selected, multiplied by 100
+again, so that the total result size would be 10000 rows.
+
+The "table1, table2 JOIN table3" pattern is one that also occurs quite
+frequently within the SQLAlchemy ORM due to either subtle mis-application of
+ORM features particularly those related to joined eager loading or joined table
+inheritance, as well as a result of SQLAlchemy ORM bugs within those same
+systems.   Similar issues apply to SELECT statements that use "implicit joins",
+where the JOIN keyword is not used and instead each FROM element is linked with
+another one via the WHERE clause.
+
+For some years there has been a recipe on the Wiki that applies a graph
+algorithm to a :func:`.select` construct at query execution time and inspects
+the structure of the query for these un-linked FROM clauses, parsing through
+the WHERE clause and all JOIN clauses to determine how FROM elements are linked
+together and ensuring that all the FROM elements are connected in a single
+graph. This recipe has now been adapted to be part of the :class:`.SQLCompiler`
+itself where it now optionally emits a warning for a statement if this
+condition is detected.   The warning is enabled using the
+:paramref:`.create_engine.enable_from_linting` flag and is enabled by default.
+The computational overhead of the linter is very low, and additionally it only
+occurs during statement compilation which means for a cached SQL statement it
+only occurs once.
+
+Using this feature, our ORM query above will emit a warning::
+
+    >>> q.all()
+    SAWarning: SELECT statement has a cartesian product between FROM
+    element(s) "addresses_1", "users" and FROM element "addresses".
+    Apply join condition(s) between each element to resolve.
+
+The linter feature accommodates not just for tables linked together through the
+JOIN clauses but also through the WHERE clause  Above, we can add a WHERE
+clause to link the new ``Address`` entity with the previous ``address_alias``
+entity and that will remove the warning::
+
+    q = session.query(User).\
+        join(address_alias, User.addresses).\
+        filter(Address.email_address == 'foo').\
+        filter(Address.id == address_alias.id)  # resolve cartesian products,
+                                                # will no longer warn
+
+The cartesian product warning considers **any** kind of link between two
+FROM clauses to be a resolution, even if the end result set is still
+wasteful, as the linter is intended only to detect the common case of a
+FROM clause that is completely unexpected.  If the FROM clause is referred
+to explicitly elsewhere and linked to the other FROMs, no warning is emitted::
+
+    q = session.query(User).\
+        join(address_alias, User.addresses).\
+        filter(Address.email_address == 'foo').\
+        filter(Address.id > address_alias.id)  # will generate a lot of rows,
+                                               # but no warning
+
+Full cartesian products are also allowed if they are explicitly stated; if we
+wanted for example the cartesian product of ``User`` and ``Address``, we can
+JOIN on :func:`.true` so that every row will match with every other; the
+following query will return all rows and produce no warnings::
+
+    from sqlalchemy import true
+
+    # intentional cartesian product
+    q = session.query(User).join(Address, true())  # intentional cartesian product
+
+The warning is only generated by default when the statement is compiled by the
+:class:`.Connection` for execution; calling the :meth:`.ClauseElement.compile`
+method will not emit a warning unless the linting flag is supplied::
+
+    >>> from sqlalchemy.sql import FROM_LINTING
+    >>> print(q.statement.compile(linting=FROM_LINTING))
+    SAWarning: SELECT statement has a cartesian product between FROM element(s) "addresses" and FROM element "users".  Apply join condition(s) between each element to resolve.
+    SELECT users.id, users.name, users.fullname, users.nickname
+    FROM addresses, users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id
+    WHERE addresses.email_address = :email_address_1
+
+:ticket:`4737`
+
 
 
 Behavior Changes - Core
diff --git a/doc/build/changelog/unreleased_14/4737.rst b/doc/build/changelog/unreleased_14/4737.rst
new file mode 100644 (file)
index 0000000..072788e
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: feature,sql
+    :tickets: 4737
+
+    Added "from linting" as a built-in feature to the SQL compiler.  This
+    allows the compiler to maintain graph of all the FROM clauses in a
+    particular SELECT statement, linked by criteria in either the WHERE
+    or in JOIN clauses that link these FROM clauses together.  If any two
+    FROM clauses have no path between them, a warning is emitted that the
+    query may be producing a cartesian product.   As the Core expression
+    language as well as the ORM are built on an "implicit FROMs" model where
+    a particular FROM clause is automatically added if any part of the query
+    refers to it, it is easy for this to happen inadvertently and it is
+    hoped that the new feature helps with this issue.
+
+    .. seealso::
+
+        :ref:`change_4737`
index 8241d951b35106ebd19a853036d1078abd7b7b6f..6e84c9da1d1bbe31b9cc1ac33a8f7270a8e03253 100644 (file)
@@ -1434,7 +1434,10 @@ class MySQLCompiler(compiler.SQLCompiler):
         else:
             return ""
 
-    def visit_join(self, join, asfrom=False, **kwargs):
+    def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+        if from_linter:
+            from_linter.edges.add((join.left, join.right))
+
         if join.full:
             join_type = " FULL OUTER JOIN "
         elif join.isouter:
@@ -1444,11 +1447,15 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return "".join(
             (
-                self.process(join.left, asfrom=True, **kwargs),
+                self.process(
+                    join.left, asfrom=True, from_linter=from_linter, **kwargs
+                ),
                 join_type,
-                self.process(join.right, asfrom=True, **kwargs),
+                self.process(
+                    join.right, asfrom=True, from_linter=from_linter, **kwargs
+                ),
                 " ON ",
-                self.process(join.onclause, **kwargs),
+                self.process(join.onclause, from_linter=from_linter, **kwargs),
             )
         )
 
index 9cb25b934fd17f2a676494c8700bf9309dc0e83d..87e0baa58a408195699149a52d74a05a38be6e67 100644 (file)
@@ -829,19 +829,24 @@ class OracleCompiler(compiler.SQLCompiler):
 
         return " FROM DUAL"
 
-    def visit_join(self, join, **kwargs):
+    def visit_join(self, join, from_linter=None, **kwargs):
         if self.dialect.use_ansi:
-            return compiler.SQLCompiler.visit_join(self, join, **kwargs)
+            return compiler.SQLCompiler.visit_join(
+                self, join, from_linter=from_linter, **kwargs
+            )
         else:
+            if from_linter:
+                from_linter.edges.add((join.left, join.right))
+
             kwargs["asfrom"] = True
             if isinstance(join.right, expression.FromGrouping):
                 right = join.right.element
             else:
                 right = join.right
             return (
-                self.process(join.left, **kwargs)
+                self.process(join.left, from_linter=from_linter, **kwargs)
                 + ", "
-                + self.process(right, **kwargs)
+                + self.process(right, from_linter=from_linter, **kwargs)
             )
 
     def _get_nonansi_join_whereclause(self, froms):
index 88558df5d717db80b5be55043eac53dab18b4742..462e5f9ec47f2b61c1ce9f1743c666ee286e7f09 100644 (file)
@@ -16,6 +16,7 @@ from .. import exc
 from .. import inspection
 from .. import log
 from .. import util
+from ..sql import compiler
 from ..sql import schema
 from ..sql import util as sql_util
 
@@ -1083,6 +1084,8 @@ class Connection(Connectable):
                     schema_translate_map=self.schema_for_object
                     if not self.schema_for_object.is_default
                     else None,
+                    linting=self.dialect.compiler_linting
+                    | compiler.WARN_LINTING,
                 )
                 self._execution_options["compiled_cache"][key] = compiled_sql
         else:
@@ -1093,6 +1096,7 @@ class Connection(Connectable):
                 schema_translate_map=self.schema_for_object
                 if not self.schema_for_object.is_default
                 else None,
+                linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
             )
 
         ret = self._execute_context(
index 58fe91c7e5e45de72f9ee4bbbdf35bb6d949b96a..5198c8cd600c057b6b4b5f19dca76bb317669183 100644 (file)
@@ -13,6 +13,7 @@ from .. import event
 from .. import exc
 from .. import pool as poollib
 from .. import util
+from ..sql import compiler
 
 
 @util.deprecated_params(
@@ -142,6 +143,16 @@ def create_engine(url, **kwargs):
     :param empty_in_strategy:   No longer used; SQLAlchemy now uses
         "empty set" behavior for IN in all cases.
 
+    :param enable_from_linting: defaults to True.  Will emit a warning
+        if a given SELECT statement is found to have un-linked FROM elements
+        which would cause a cartesian product.
+
+        .. versionadded:: 1.4
+
+        .. seealso::
+
+            :ref:`change_4737`
+
     :param encoding: Defaults to ``utf-8``.  This is the string
         encoding used by SQLAlchemy for string encode/decode
         operations which occur within SQLAlchemy, **outside of
@@ -446,6 +457,11 @@ def create_engine(url, **kwargs):
 
     dialect_args["dbapi"] = dbapi
 
+    dialect_args.setdefault("compiler_linting", compiler.NO_LINTING)
+    enable_from_linting = kwargs.pop("enable_from_linting", True)
+    if enable_from_linting:
+        dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS
+
     for plugin in plugins:
         plugin.handle_dialect_kwargs(dialect_cls, dialect_args)
 
index 1c995f05fec5d434c37b9aaedc2700300494512d..3788904449bc5bc9300f0ff0b4de880f968c2950 100644 (file)
@@ -31,7 +31,6 @@ from ..sql import expression
 from ..sql import schema
 from ..sql.elements import quoted_name
 
-
 AUTOCOMMIT_REGEXP = re.compile(
     r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE
 )
@@ -214,6 +213,9 @@ class DefaultDialect(interfaces.Dialect):
         supports_native_boolean=None,
         max_identifier_length=None,
         label_length=None,
+        # int() is because the @deprecated_params decorator cannot accommodate
+        # the direct reference to the "NO_LINTING" object
+        compiler_linting=int(compiler.NO_LINTING),
         **kwargs
     ):
 
@@ -249,7 +251,7 @@ class DefaultDialect(interfaces.Dialect):
                 self._user_defined_max_identifier_length
             )
         self.label_length = label_length
-
+        self.compiler_linting = compiler_linting
         if self.description_encoding == "use_encoding":
             self._description_decoder = (
                 processors.to_unicode_processor_factory
index 6554faaa0853085bc9ec59b71a72c60d5f29ec28..488717041dadb38b06317d41d6893b4785deac75 100644 (file)
@@ -5,6 +5,10 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
+from .compiler import COLLECT_CARTESIAN_PRODUCTS  # noqa
+from .compiler import FROM_LINTING  # noqa
+from .compiler import NO_LINTING  # noqa
+from .compiler import WARN_LINTING  # noqa
 from .expression import Alias  # noqa
 from .expression import alias  # noqa
 from .expression import all_  # noqa
index 8499484f3341b8023439e951a8f0fceefdc9992a..ed463ebe37a1fa41936d56513ab346b64b3e4a20 100644 (file)
@@ -41,7 +41,6 @@ from .base import NO_ARG
 from .. import exc
 from .. import util
 
-
 RESERVED_WORDS = set(
     [
         "all",
@@ -270,6 +269,89 @@ ExpandedState = collections.namedtuple(
 )
 
 
+NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0)
+
+COLLECT_CARTESIAN_PRODUCTS = util.symbol(
+    "COLLECT_CARTESIAN_PRODUCTS",
+    "Collect data on FROMs and cartesian products and gather "
+    "into 'self.from_linter'",
+    canonical=1,
+)
+
+WARN_LINTING = util.symbol(
+    "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2
+)
+
+FROM_LINTING = util.symbol(
+    "FROM_LINTING",
+    "Warn for cartesian products; "
+    "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
+    canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
+)
+
+
+class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
+    def lint(self, start=None):
+        froms = self.froms
+        if not froms:
+            return None, None
+
+        edges = set(self.edges)
+        the_rest = set(froms)
+
+        if start is not None:
+            start_with = start
+            the_rest.remove(start_with)
+        else:
+            start_with = the_rest.pop()
+
+        stack = collections.deque([start_with])
+
+        while stack and the_rest:
+            node = stack.popleft()
+            the_rest.discard(node)
+
+            # comparison of nodes in edges here is based on hash equality, as
+            # there are "annotated" elements that match the non-annotated ones.
+            #   to remove the need for in-python hash() calls, use native
+            # containment routines (e.g. "node in edge", "edge.index(node)")
+            to_remove = {edge for edge in edges if node in edge}
+
+            # appendleft the node in each edge that is not
+            # the one that matched.
+            stack.extendleft(edge[not edge.index(node)] for edge in to_remove)
+            edges.difference_update(to_remove)
+
+        # FROMS left over?  boom
+        if the_rest:
+            return the_rest, start_with
+        else:
+            return None, None
+
+    def warn(self):
+        the_rest, start_with = self.lint()
+
+        # FROMS left over?  boom
+        if the_rest:
+
+            froms = the_rest
+            if froms:
+                template = (
+                    "SELECT statement has a cartesian product between "
+                    "FROM element(s) {froms} and "
+                    'FROM element "{start}".  Apply join condition(s) '
+                    "between each element to resolve."
+                )
+                froms_str = ", ".join(
+                    '"{elem}"'.format(elem=self.froms[from_])
+                    for from_ in froms
+                )
+                message = template.format(
+                    froms=froms_str, start=self.froms[start_with]
+                )
+                util.warn(message)
+
+
 class Compiled(object):
 
     """Represent a compiled SQL or DDL expression.
@@ -568,7 +650,13 @@ class SQLCompiler(Compiled):
     insert_prefetch = update_prefetch = ()
 
     def __init__(
-        self, dialect, statement, column_keys=None, inline=False, **kwargs
+        self,
+        dialect,
+        statement,
+        column_keys=None,
+        inline=False,
+        linting=NO_LINTING,
+        **kwargs
     ):
         """Construct a new :class:`.SQLCompiler` object.
 
@@ -592,6 +680,8 @@ class SQLCompiler(Compiled):
         # execute)
         self.inline = inline or getattr(statement, "inline", False)
 
+        self.linting = linting
+
         # a dictionary of bind parameter keys to BindParameter
         # instances.
         self.binds = {}
@@ -1547,9 +1637,21 @@ class SQLCompiler(Compiled):
         return to_update, replacement_expression
 
     def visit_binary(
-        self, binary, override_operator=None, eager_grouping=False, **kw
+        self,
+        binary,
+        override_operator=None,
+        eager_grouping=False,
+        from_linter=None,
+        **kw
     ):
 
+        if from_linter and operators.is_comparison(binary.operator):
+            from_linter.edges.update(
+                itertools.product(
+                    binary.left._from_objects, binary.right._from_objects
+                )
+            )
+
         # don't allow "? = ?" to render
         if (
             self.ansi_bind_rules
@@ -1568,7 +1670,9 @@ class SQLCompiler(Compiled):
             except KeyError:
                 raise exc.UnsupportedCompilationError(self, operator_)
             else:
-                return self._generate_generic_binary(binary, opstring, **kw)
+                return self._generate_generic_binary(
+                    binary, opstring, from_linter=from_linter, **kw
+                )
 
     def visit_function_as_comparison_op_binary(self, element, operator, **kw):
         return self.process(element.sql_function, **kw)
@@ -1916,6 +2020,7 @@ class SQLCompiler(Compiled):
         ashint=False,
         fromhints=None,
         visiting_cte=None,
+        from_linter=None,
         **kwargs
     ):
         self._init_cte_state()
@@ -2021,6 +2126,9 @@ class SQLCompiler(Compiled):
                 self.ctes[cte] = text
 
         if asfrom:
+            if from_linter:
+                from_linter.froms[cte] = cte_name
+
             if not is_new_cte and embedded_in_current_named_cte:
                 return self.preparer.format_alias(cte, cte_name)
 
@@ -2043,6 +2151,7 @@ class SQLCompiler(Compiled):
         subquery=False,
         lateral=False,
         enclosing_alias=None,
+        from_linter=None,
         **kwargs
     ):
         if enclosing_alias is not None and enclosing_alias.element is alias:
@@ -2071,6 +2180,9 @@ class SQLCompiler(Compiled):
         if ashint:
             return self.preparer.format_alias(alias, alias_name)
         elif asfrom:
+            if from_linter:
+                from_linter.froms[alias] = alias_name
+
             inner = alias.element._compiler_dispatch(
                 self, asfrom=True, lateral=lateral, **kwargs
             )
@@ -2284,6 +2396,7 @@ class SQLCompiler(Compiled):
         compound_index=0,
         select_wraps_for=None,
         lateral=False,
+        from_linter=None,
         **kwargs
     ):
 
@@ -2373,7 +2486,7 @@ class SQLCompiler(Compiled):
             ]
 
         text = self._compose_select_body(
-            text, select, inner_columns, froms, byfrom, kwargs
+            text, select, inner_columns, froms, byfrom, toplevel, kwargs
         )
 
         if select._statement_hints:
@@ -2465,10 +2578,17 @@ class SQLCompiler(Compiled):
         return froms
 
     def _compose_select_body(
-        self, text, select, inner_columns, froms, byfrom, kwargs
+        self, text, select, inner_columns, froms, byfrom, toplevel, kwargs
     ):
         text += ", ".join(inner_columns)
 
+        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
+            from_linter = FromLinter({}, set())
+            if toplevel:
+                self.from_linter = from_linter
+        else:
+            from_linter = None
+
         if froms:
             text += " \nFROM "
 
@@ -2476,7 +2596,11 @@ class SQLCompiler(Compiled):
                 text += ", ".join(
                     [
                         f._compiler_dispatch(
-                            self, asfrom=True, fromhints=byfrom, **kwargs
+                            self,
+                            asfrom=True,
+                            fromhints=byfrom,
+                            from_linter=from_linter,
+                            **kwargs
                         )
                         for f in froms
                     ]
@@ -2484,7 +2608,12 @@ class SQLCompiler(Compiled):
             else:
                 text += ", ".join(
                     [
-                        f._compiler_dispatch(self, asfrom=True, **kwargs)
+                        f._compiler_dispatch(
+                            self,
+                            asfrom=True,
+                            from_linter=from_linter,
+                            **kwargs
+                        )
                         for f in froms
                     ]
                 )
@@ -2492,10 +2621,18 @@ class SQLCompiler(Compiled):
             text += self.default_from()
 
         if select._whereclause is not None:
-            t = select._whereclause._compiler_dispatch(self, **kwargs)
+            t = select._whereclause._compiler_dispatch(
+                self, from_linter=from_linter, **kwargs
+            )
             if t:
                 text += " \nWHERE " + t
 
+        if (
+            self.linting & COLLECT_CARTESIAN_PRODUCTS
+            and self.linting & WARN_LINTING
+        ):
+            from_linter.warn()
+
         if select._group_by_clause.clauses:
             text += self.group_by_clause(select, **kwargs)
 
@@ -2597,8 +2734,12 @@ class SQLCompiler(Compiled):
         ashint=False,
         fromhints=None,
         use_schema=True,
+        from_linter=None,
         **kwargs
     ):
+        if from_linter:
+            from_linter.froms[table] = table.fullname
+
         if asfrom or ashint:
             effective_schema = self.preparer.schema_for_object(table)
 
@@ -2618,7 +2759,10 @@ class SQLCompiler(Compiled):
         else:
             return ""
 
-    def visit_join(self, join, asfrom=False, **kwargs):
+    def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
+        if from_linter:
+            from_linter.edges.add((join.left, join.right))
+
         if join.full:
             join_type = " FULL OUTER JOIN "
         elif join.isouter:
@@ -2626,12 +2770,18 @@ class SQLCompiler(Compiled):
         else:
             join_type = " JOIN "
         return (
-            join.left._compiler_dispatch(self, asfrom=True, **kwargs)
+            join.left._compiler_dispatch(
+                self, asfrom=True, from_linter=from_linter, **kwargs
+            )
             + join_type
-            + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
+            + join.right._compiler_dispatch(
+                self, asfrom=True, from_linter=from_linter, **kwargs
+            )
             + " ON "
             # TODO: likely need asfrom=True here?
-            + join.onclause._compiler_dispatch(self, **kwargs)
+            + join.onclause._compiler_dispatch(
+                self, from_linter=from_linter, **kwargs
+            )
         )
 
     def _setup_crud_hints(self, stmt, table_text):
index b8322db0a9553b17f6ec3261f4e24cca65a33f57..2c1058b9a37eff6c3a08e74155ab9fee0a6cd80e 100644 (file)
@@ -6,6 +6,7 @@ import os
 import re
 import sys
 
+from sqlalchemy import testing
 from sqlalchemy.testing import config
 from sqlalchemy.testing import fixtures
 
@@ -86,6 +87,7 @@ class DocTest(fixtures.TestBase):
     def test_orm(self):
         self._run_doctest("orm/tutorial.rst")
 
+    @testing.emits_warning("SELECT statement has a cartesian")
     def test_core(self):
         self._run_doctest("core/tutorial.rst")
 
index e2c3a3aa04c14e4645c17a89ecf2a45e33b6a79f..b5bff78d24db83a44cba6f2516b793aaf28e6b40 100644 (file)
@@ -210,6 +210,7 @@ def produce_test(parent, child, direction):
                 C,
                 tc,
                 polymorphic_identity="c",
+                with_polymorphic=("*", tc.join(tb, btoc).join(ta, atob)),
                 inherits=B,
                 inherit_condition=btoc,
             )
index ecab0a497d0fbf14d0274e91cb1327f9643b3590..616ed79a6cd6b55b0739930872d769241f489e24 100644 (file)
@@ -597,12 +597,14 @@ class RelationshipTest4(fixtures.MappedTest):
         mapper(
             Engineer,
             engineers,
+            with_polymorphic=([Engineer], people.join(engineers)),
             inherits=person_mapper,
             polymorphic_identity="engineer",
         )
         mapper(
             Manager,
             managers,
+            with_polymorphic=([Manager], people.join(managers)),
             inherits=person_mapper,
             polymorphic_identity="manager",
         )
@@ -1239,12 +1241,14 @@ class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults):
         mapper(
             Engineer,
             engineers,
+            with_polymorphic=([Engineer], people.join(engineers)),
             inherits=person_mapper,
             polymorphic_identity="engineer",
         )
         mapper(
             Manager,
             managers,
+            with_polymorphic=([Manager], people.join(managers)),
             inherits=person_mapper,
             polymorphic_identity="manager",
         )
index 508cb99657b5f43d6e76d8ce8c25fa5a29c109b2..deb33838f23679a7df89868b5fa17c00885fdb87 100644 (file)
@@ -502,12 +502,15 @@ class RoundTripTest(PolymorphTest):
         session = Session()
         dilbert = get_dilbert(session)
 
+        # this unusual test is selecting from the plain people/engineers
+        # table at the same time as the polymorphic entity
         is_(
             dilbert,
             session.query(Person)
             .filter(
                 (Engineer.engineer_name == "engineer1")
                 & (engineers.c.person_id == people.c.person_id)
+                & (people.c.person_id == Person.person_id)
             )
             .first(),
         )
index b188e598d7a485b5e0eab7f81f3848a3883d4792..b376be12af2439a49300ada5f6deb31ae588bb8a 100644 (file)
@@ -3,9 +3,11 @@ from sqlalchemy import exc as sa_exc
 from sqlalchemy import func
 from sqlalchemy import select
 from sqlalchemy import testing
+from sqlalchemy import true
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import create_session
 from sqlalchemy.orm import defaultload
+from sqlalchemy.orm import join
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import with_polymorphic
@@ -174,6 +176,7 @@ class _PolymorphicTestBase(object):
             sess.query(Company, Person, c, e)
             .join(Person, Company.employees)
             .join(e, c.employees)
+            .filter(Person.person_id != e.person_id)
             .filter(Person.name == "dilbert")
             .filter(e.name == "wally")
         )
@@ -897,15 +900,28 @@ class _PolymorphicTestBase(object):
         ]
 
         def go():
+            wp = with_polymorphic(Person, "*")
             eq_(
-                sess.query(Person)
-                .with_polymorphic("*")
-                .options(subqueryload(Engineer.machines))
-                .filter(Person.name == "dilbert")
+                sess.query(wp)
+                .options(subqueryload(wp.Engineer.machines))
+                .filter(wp.name == "dilbert")
                 .all(),
                 expected,
             )
 
+            # the old version of this test has never worked, apparently,
+            # was always spitting out a cartesian product.  Since we
+            # are getting rid of query.with_polymorphic() is it not
+            # worth fixing.
+            # eq_(
+            #    sess.query(Person)
+            #    .with_polymorphic("*")
+            #    .options(subqueryload(Engineer.machines))
+            #    .filter(Person.name == "dilbert")
+            #    .all(),
+            #    expected,
+            # )
+
         self.assert_sql_count(testing.db, go, 2)
 
     def test_query_subclass_join_to_base_relationship(self):
@@ -1393,6 +1409,7 @@ class _PolymorphicTestBase(object):
             .join(Company.employees)
             .filter(Company.name == "Elbonia, Inc.")
             .filter(palias.name == "dilbert")
+            .filter(palias.person_id != Person.person_id)
             .all(),
             expected,
         )
@@ -1420,8 +1437,10 @@ class _PolymorphicTestBase(object):
                 ),
             )
         ]
+
         eq_(
             sess.query(palias, Company.name, Person)
+            .select_from(join(palias, Company, true()))
             .join(Company.employees)
             .filter(Company.name == "Elbonia, Inc.")
             .filter(palias.name == "dilbert")
@@ -1438,6 +1457,7 @@ class _PolymorphicTestBase(object):
             .join(Company.employees)
             .filter(Company.name == "Elbonia, Inc.")
             .filter(palias.name == "dilbert")
+            .filter(palias.company_id != Person.company_id)
             .all(),
             expected,
         )
index b5a0fabec642f07c8ba785211a6c4e95e3e31743..94ab0a9943a4e15981c6690a62632ba0d9cbbfde 100644 (file)
@@ -11,6 +11,7 @@ from sqlalchemy.orm import create_session
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import subqueryload
 from sqlalchemy.orm import with_polymorphic
@@ -1302,10 +1303,12 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest):
         d = session.query(D).one()
 
         def go():
+            # NOTE: subqueryload is broken for this case, first found
+            # when cartesian product detection was added.
             for a in (
                 session.query(A)
                 .with_polymorphic([B, C])
-                .options(subqueryload(B.related), subqueryload(C.related))
+                .options(selectinload(B.related), selectinload(C.related))
             ):
                 eq_(a.related, [d])
 
index b070e2848a83ae95eb71348fc48d232d0f16ad59..ed9c781f4140bdb0f39ee4db35dd70e6759e94e2 100644 (file)
@@ -8,6 +8,7 @@ from sqlalchemy import null
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import true
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import Bundle
 from sqlalchemy.orm import class_mapper
@@ -143,29 +144,40 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest):
         session, m1, e1, e2 = self._fixture_one()
 
         ealias = aliased(Engineer)
-        eq_(session.query(Manager, ealias).all(), [(m1, e1), (m1, e2)])
+        eq_(
+            session.query(Manager, ealias).join(ealias, true()).all(),
+            [(m1, e1), (m1, e2)],
+        )
 
         eq_(session.query(Manager.name).all(), [("Tom",)])
 
         eq_(
-            session.query(Manager.name, ealias.name).all(),
+            session.query(Manager.name, ealias.name)
+            .join(ealias, true())
+            .all(),
             [("Tom", "Kurt"), ("Tom", "Ed")],
         )
 
         eq_(
-            session.query(
-                func.upper(Manager.name), func.upper(ealias.name)
-            ).all(),
+            session.query(func.upper(Manager.name), func.upper(ealias.name))
+            .join(ealias, true())
+            .all(),
             [("TOM", "KURT"), ("TOM", "ED")],
         )
 
         eq_(
-            session.query(Manager).add_entity(ealias).all(),
+            session.query(Manager)
+            .add_entity(ealias)
+            .join(ealias, true())
+            .all(),
             [(m1, e1), (m1, e2)],
         )
 
         eq_(
-            session.query(Manager.name).add_column(ealias.name).all(),
+            session.query(Manager.name)
+            .add_column(ealias.name)
+            .join(ealias, true())
+            .all(),
             [("Tom", "Kurt"), ("Tom", "Ed")],
         )
 
index 17954f3089c120a709894db41fc710302383fa10..d5a46e9ead6037425b4fdaf5cfe76261fbef6a60 100644 (file)
@@ -423,10 +423,8 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL):
         with self._expect_implicit_subquery():
             eq_(
                 s.query(User)
-                .select_from(
-                    text("select * from users").columns(
-                        id=Integer, name=String
-                    )
+                .select_entity_from(
+                    text("select * from users").columns(User.id, User.name)
                 )
                 .order_by(User.id)
                 .all(),
index 659f6e103e0a4b6747680e247cd80750fa39a65a..bf39b25a6e0effa8436ea9759166e92e17611aac 100644 (file)
@@ -5583,6 +5583,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(aa, A)
             .filter(aa.id == 1)
             .filter(A.id == 2)
+            .filter(aa.id != A.id)
             .options(joinedload("bs").joinedload("cs"))
         )
         self._run_tests(q, 1)
@@ -5595,6 +5596,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(A, aa)
             .filter(aa.id == 2)
             .filter(A.id == 1)
+            .filter(aa.id != A.id)
             .options(joinedload("bs").joinedload("cs"))
         )
         self._run_tests(q, 1)
@@ -5607,6 +5609,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(aa, A)
             .filter(aa.id == 1)
             .filter(A.id == 2)
+            .filter(aa.id != A.id)
             .options(joinedload(A.bs).joinedload(B.cs))
         )
         self._run_tests(q, 3)
@@ -5619,6 +5622,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(aa, A)
             .filter(aa.id == 1)
             .filter(A.id == 2)
+            .filter(aa.id != A.id)
             .options(defaultload(A.bs).joinedload(B.cs))
         )
         self._run_tests(q, 3)
@@ -5629,7 +5633,13 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
         aa = aliased(A)
         opt = Load(A).joinedload(A.bs).joinedload(B.cs)
 
-        q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt)
+        q = (
+            s.query(aa, A)
+            .filter(aa.id == 1)
+            .filter(A.id == 2)
+            .filter(aa.id != A.id)
+            .options(opt)
+        )
         self._run_tests(q, 3)
 
     def test_pathed_lazyload_plus_joined_aliased_abs_bcs(self):
@@ -5638,7 +5648,13 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
         aa = aliased(A)
         opt = Load(aa).defaultload(aa.bs).joinedload(B.cs)
 
-        q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt)
+        q = (
+            s.query(aa, A)
+            .filter(aa.id == 1)
+            .filter(A.id == 2)
+            .filter(aa.id != A.id)
+            .options(opt)
+        )
         self._run_tests(q, 2)
 
     def test_pathed_joinedload_aliased_abs_bcs(self):
@@ -5647,7 +5663,13 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
         aa = aliased(A)
         opt = Load(aa).joinedload(aa.bs).joinedload(B.cs)
 
-        q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt)
+        q = (
+            s.query(aa, A)
+            .filter(aa.id == 1)
+            .filter(A.id == 2)
+            .filter(aa.id != A.id)
+            .options(opt)
+        )
         self._run_tests(q, 1)
 
     def test_lazyload_plus_joined_aliased_abs_bcs(self):
@@ -5658,6 +5680,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(aa, A)
             .filter(aa.id == 1)
             .filter(A.id == 2)
+            .filter(aa.id != A.id)
             .options(defaultload(aa.bs).joinedload(B.cs))
         )
         self._run_tests(q, 2)
@@ -5670,6 +5693,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(aa, A)
             .filter(aa.id == 1)
             .filter(A.id == 2)
+            .filter(aa.id != A.id)
             .options(joinedload(aa.bs).joinedload(B.cs))
         )
         self._run_tests(q, 1)
@@ -5682,6 +5706,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(A, aa)
             .filter(aa.id == 2)
             .filter(A.id == 1)
+            .filter(aa.id != A.id)
             .options(joinedload(aa.bs).joinedload(B.cs))
         )
         self._run_tests(q, 3)
@@ -5694,6 +5719,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(A, aa)
             .filter(aa.id == 2)
             .filter(A.id == 1)
+            .filter(aa.id != A.id)
             .options(defaultload(aa.bs).joinedload(B.cs))
         )
         self._run_tests(q, 3)
@@ -5706,6 +5732,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(A, aa)
             .filter(aa.id == 2)
             .filter(A.id == 1)
+            .filter(aa.id != A.id)
             .options(defaultload(A.bs).joinedload(B.cs))
         )
         self._run_tests(q, 2)
@@ -5718,6 +5745,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest):
             s.query(A, aa)
             .filter(aa.id == 2)
             .filter(A.id == 1)
+            .filter(aa.id != A.id)
             .options(joinedload(A.bs).joinedload(B.cs))
         )
         self._run_tests(q, 1)
index 54714864dc054d586aa0265f930b754ab2d95de8..7195f53cb052f3ad46a68bfb47e2cbe7f3095270 100644 (file)
@@ -9,11 +9,13 @@ from sqlalchemy import ForeignKey
 from sqlalchemy import func
 from sqlalchemy import Integer
 from sqlalchemy import literal_column
+from sqlalchemy import or_
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import util
 from sqlalchemy.engine import default
 from sqlalchemy.orm import aliased
@@ -1481,6 +1483,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
         q2 = (
             q.select_entity_from(sel)
             .filter(u2.id > 1)
+            .filter(or_(u2.id == User.id, u2.id != User.id))
             .order_by(User.id, sel.c.id, u2.id)
             .values(User.name, sel.c.name, u2.name)
         )
@@ -1853,17 +1856,17 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL):
             .filter(Order.id > oalias.id)
             .order_by(Order.id, oalias.id),
             sess.query(Order, oalias)
+            .filter(Order.id > oalias.id)
             .from_self()
             .filter(Order.user_id == oalias.user_id)
             .filter(Order.user_id == 7)
-            .filter(Order.id > oalias.id)
             .order_by(Order.id, oalias.id),
             # same thing, but reversed.
             sess.query(oalias, Order)
+            .filter(Order.id < oalias.id)
             .from_self()
             .filter(oalias.user_id == Order.user_id)
             .filter(oalias.user_id == 7)
-            .filter(Order.id < oalias.id)
             .order_by(oalias.id, Order.id),
             # here we go....two layers of aliasing
             sess.query(Order, oalias)
@@ -3537,7 +3540,11 @@ class LabelCollideTest(fixtures.MappedTest):
 
     def test_overlap_plain(self):
         s = Session()
-        row = s.query(self.classes.Foo, self.classes.Bar).all()[0]
+        row = (
+            s.query(self.classes.Foo, self.classes.Bar)
+            .join(self.classes.Bar, true())
+            .all()[0]
+        )
 
         def go():
             eq_(row.Foo.id, 1)
@@ -3550,7 +3557,12 @@ class LabelCollideTest(fixtures.MappedTest):
 
     def test_overlap_subquery(self):
         s = Session()
-        row = s.query(self.classes.Foo, self.classes.Bar).from_self().all()[0]
+        row = (
+            s.query(self.classes.Foo, self.classes.Bar)
+            .join(self.classes.Bar, true())
+            .from_self()
+            .all()[0]
+        )
 
         def go():
             eq_(row.Foo.id, 1)
index 72a68c42bec0db967d6c5961d38255946bd576b1..4da583a0c18bc4ac3800db1e1a9fcddbc78a03c0 100644 (file)
@@ -777,7 +777,11 @@ class TupleLabelTest(_fixtures.FixtureTest):
                 eq_(row.foobar, row[1])
 
             oalias = aliased(Order)
-            for row in sess.query(User, oalias).join(User.orders).all():
+            for row in (
+                sess.query(User, oalias)
+                .join(User.orders.of_type(oalias))
+                .all()
+            ):
                 if pickled is not False:
                     row = pickle.loads(pickle.dumps(row, pickled))
                 eq_(list(row.keys()), ["User"])
index 271d85dd6968d54f3bbe445fba38ab4974c3e64e..55809ad385d3742e229c775eb4702dea3d3c098e 100644 (file)
@@ -28,6 +28,7 @@ from sqlalchemy import String
 from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import Unicode
 from sqlalchemy import union
 from sqlalchemy import util
@@ -3783,7 +3784,7 @@ class CountTest(QueryTest):
         User, Address = self.classes.User, self.classes.Address
 
         s = create_session()
-        q = s.query(User, Address)
+        q = s.query(User, Address).join(Address, true())
         eq_(q.count(), 20)  # cartesian product
 
         q = s.query(User, Address).join(User.addresses)
@@ -3793,10 +3794,10 @@ class CountTest(QueryTest):
         User, Address = self.classes.User, self.classes.Address
 
         s = create_session()
-        q = s.query(User, Address).limit(2)
+        q = s.query(User, Address).join(Address, true()).limit(2)
         eq_(q.count(), 2)
 
-        q = s.query(User, Address).limit(100)
+        q = s.query(User, Address).join(Address, true()).limit(100)
         eq_(q.count(), 20)
 
         q = s.query(User, Address).join(User.addresses).limit(100)
@@ -3818,7 +3819,7 @@ class CountTest(QueryTest):
         q = s.query(User.name)
         eq_(q.count(), 4)
 
-        q = s.query(User.name, Address)
+        q = s.query(User.name, Address).join(Address, true())
         eq_(q.count(), 20)
 
         q = s.query(Address.user_id)
@@ -3888,7 +3889,9 @@ class DistinctTest(QueryTest, AssertsCompiledSQL):
 
         q = (
             sess.query(User.id, User.name.label("foo"), Address.id)
+            .join(Address, true())
             .filter(User.name == "jack")
+            .filter(User.id + Address.user_id > 0)
             .distinct()
             .order_by(User.id, User.name, Address.email_address)
         )
@@ -4541,9 +4544,9 @@ class TextTest(QueryTest, AssertsCompiledSQL):
 
         eq_(
             s.query(User)
-            .select_from(
+            .select_entity_from(
                 text("select * from users")
-                .columns(id=Integer, name=String)
+                .columns(User.id, User.name)
                 .subquery()
             )
             .order_by(User.id)
index 5e2ca814ac2b18ac4b529e29afaad80b7432b214..41f1fab7a45c30e3fcb6524005d7afe082f090ac 100644 (file)
@@ -1,15 +1,15 @@
 # /home/classic/dev/sqlalchemy/test/profiles.txt
 # This file is written out on a per-environment basis.
-# For each test in aaa_profiling, the corresponding function and 
+# For each test in aaa_profiling, the corresponding function and
 # environment is located within this file.  If it doesn't exist,
 # the test is skipped.
-# If a callcount does exist, it is compared to what we received. 
+# If a callcount does exist, it is compared to what we received.
 # assertions are raised if the counts do not match.
-# 
-# To add a new callcount test, apply the function_call_count 
-# decorator and re-run the tests using the --write-profiles 
+#
+# To add a new callcount test, apply the function_call_count
+# decorator and re-run the tests using the --write-profiles
 # option - this file will be rewritten including the new count.
-# 
+#
 
 # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert
 
@@ -1027,14 +1027,10 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_unicode 3.7_sqlite_pysqlite
 
 # TEST: test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation
 
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6412,322,4242,12454,1244,2187,2770
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6429,322,4242,13149,1341,2187,2766
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6177,306,4162,12597,1233,2133,2650
-test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6260,306,4242,13203,1344,2151,2840
+test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6177,306,4162,12597,1233,2133,2852
+test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6260,306,4242,13203,1344,2151,3046
 
 # TEST: test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation
 
-test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 7085,420,7205,18682,1257,2846
-test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 7130,423,7229,20001,1346,2864
 test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 7090,411,7281,19190,1247,2897
 test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 7186,416,7465,20675,1350,2957
index f033abab23f141febe9345c1c1a1720cce561a93..b31b070d8ea6fdfdc490670c9f0fee719fc9aea1 100644 (file)
@@ -852,7 +852,7 @@ class CTEDefaultTest(fixtures.TablesTest):
 
             if b == "select":
                 conn.execute(p.insert().values(s=1))
-                stmt = select([p.c.s, cte.c.z])
+                stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z)
             elif b == "insert":
                 sel = select([1, cte.c.z])
                 stmt = (
diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py
new file mode 100644 (file)
index 0000000..bf2f06b
--- /dev/null
@@ -0,0 +1,277 @@
+from sqlalchemy import Integer
+from sqlalchemy import select
+from sqlalchemy import sql
+from sqlalchemy import true
+from sqlalchemy.testing import config
+from sqlalchemy.testing import engines
+from sqlalchemy.testing import expect_warnings
+from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_
+from sqlalchemy.testing.schema import Column
+from sqlalchemy.testing.schema import Table
+
+
+def find_unmatching_froms(query, start=None):
+    compiled = query.compile(linting=sql.COLLECT_CARTESIAN_PRODUCTS)
+
+    return compiled.from_linter.lint(start)
+
+
+class TestFindUnmatchingFroms(fixtures.TablesTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
+        Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+        Table("table_c", metadata, Column("col_c", Integer, primary_key=True))
+        Table("table_d", metadata, Column("col_d", Integer, primary_key=True))
+
+    def setup(self):
+        self.a = self.tables.table_a
+        self.b = self.tables.table_b
+        self.c = self.tables.table_c
+        self.d = self.tables.table_d
+
+    def test_everything_is_connected(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .select_from(self.c)
+            .select_from(self.d)
+            .where(self.d.c.col_d == self.b.c.col_b)
+            .where(self.c.c.col_c == self.d.c.col_d)
+            .where(self.c.c.col_c == 5)
+        )
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+        for start in self.a, self.b, self.c, self.d:
+            froms, start = find_unmatching_froms(query, start)
+            assert not froms
+
+    def test_plain_cartesian(self):
+        query = select([self.a]).where(self.b.c.col_b == 5)
+        froms, start = find_unmatching_froms(query, self.a)
+        assert start == self.a
+        assert froms == {self.b}
+
+        froms, start = find_unmatching_froms(query, self.b)
+        assert start == self.b
+        assert froms == {self.a}
+
+    def test_count_non_eq_comparison_operators(self):
+        query = select([self.a]).where(self.a.c.col_a > self.b.c.col_b)
+        froms, start = find_unmatching_froms(query, self.a)
+        is_(start, None)
+        is_(froms, None)
+
+    def test_dont_count_non_comparison_operators(self):
+        query = select([self.a]).where(self.a.c.col_a + self.b.c.col_b == 5)
+        froms, start = find_unmatching_froms(query, self.a)
+        assert start == self.a
+        assert froms == {self.b}
+
+    def test_disconnect_between_ab_cd(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .select_from(self.c)
+            .select_from(self.d)
+            .where(self.c.c.col_c == self.d.c.col_d)
+            .where(self.c.c.col_c == 5)
+        )
+        for start in self.a, self.b:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.c, self.d}
+        for start in self.c, self.d:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.a, self.b}
+
+    def test_c_and_d_both_disconnected(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .where(self.c.c.col_c == 5)
+            .where(self.d.c.col_d == 10)
+        )
+        for start in self.a, self.b:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.c, self.d}
+
+        froms, start = find_unmatching_froms(query, self.c)
+        assert start == self.c
+        assert froms == {self.a, self.b, self.d}
+
+        froms, start = find_unmatching_froms(query, self.d)
+        assert start == self.d
+        assert froms == {self.a, self.b, self.c}
+
+    def test_now_connected(self):
+        query = (
+            select([self.a])
+            .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b))
+            .select_from(self.c.join(self.d, self.c.c.col_c == self.d.c.col_d))
+            .where(self.c.c.col_c == self.b.c.col_b)
+            .where(self.c.c.col_c == 5)
+            .where(self.d.c.col_d == 10)
+        )
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+        for start in self.a, self.b, self.c, self.d:
+            froms, start = find_unmatching_froms(query, start)
+            assert not froms
+
+    def test_disconnected_subquery(self):
+        subq = (
+            select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
+        )
+        stmt = select([self.c]).select_from(subq)
+
+        froms, start = find_unmatching_froms(stmt, self.c)
+        assert start == self.c
+        assert froms == {subq}
+
+        froms, start = find_unmatching_froms(stmt, subq)
+        assert start == subq
+        assert froms == {self.c}
+
+    def test_now_connect_it(self):
+        subq = (
+            select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery()
+        )
+        stmt = (
+            select([self.c])
+            .select_from(subq)
+            .where(self.c.c.col_c == subq.c.col_a)
+        )
+
+        froms, start = find_unmatching_froms(stmt)
+        assert not froms
+
+        for start in self.c, subq:
+            froms, start = find_unmatching_froms(stmt, start)
+            assert not froms
+
+    def test_right_nested_join_without_issue(self):
+        query = select([self.a]).select_from(
+            self.a.join(
+                self.b.join(self.c, self.b.c.col_b == self.c.c.col_c),
+                self.a.c.col_a == self.b.c.col_b,
+            )
+        )
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+        for start in self.a, self.b, self.c:
+            froms, start = find_unmatching_froms(query, start)
+            assert not froms
+
+    def test_join_on_true(self):
+        # test that a join(a, b) counts a->b as an edge even if there isn't
+        # actually a join condition.  this essentially allows a cartesian
+        # product to be added explicitly.
+
+        query = select([self.a]).select_from(self.a.join(self.b, true()))
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+    def test_right_nested_join_with_an_issue(self):
+        query = (
+            select([self.a])
+            .select_from(
+                self.a.join(
+                    self.b.join(self.c, self.b.c.col_b == self.c.c.col_c),
+                    self.a.c.col_a == self.b.c.col_b,
+                )
+            )
+            .where(self.d.c.col_d == 5)
+        )
+
+        for start in self.a, self.b, self.c:
+            froms, start = find_unmatching_froms(query, start)
+            assert start == start
+            assert froms == {self.d}
+
+        froms, start = find_unmatching_froms(query, self.d)
+        assert start == self.d
+        assert froms == {self.a, self.b, self.c}
+
+    def test_no_froms(self):
+        query = select([1])
+
+        froms, start = find_unmatching_froms(query)
+        assert not froms
+
+
+class TestLinter(fixtures.TablesTest):
+    @classmethod
+    def define_tables(cls, metadata):
+        Table("table_a", metadata, Column("col_a", Integer, primary_key=True))
+        Table("table_b", metadata, Column("col_b", Integer, primary_key=True))
+
+    @classmethod
+    def setup_bind(cls):
+        # from linting is enabled by default
+        return config.db
+
+    def test_noop_for_unhandled_objects(self):
+        with self.bind.connect() as conn:
+            conn.execute("SELECT 1;").fetchone()
+
+    def test_does_not_modify_query(self):
+        with self.bind.connect() as conn:
+            [result] = conn.execute(select([1])).fetchone()
+            assert result == 1
+
+    def test_warn_simple(self):
+        a, b = self.tables("table_a", "table_b")
+        query = select([a.c.col_a]).where(b.c.col_b == 5)
+
+        with expect_warnings(
+            r"SELECT statement has a cartesian product between FROM "
+            r'element\(s\) "table_[ab]" '
+            r'and FROM element "table_[ba]"'
+        ):
+            with self.bind.connect() as conn:
+                conn.execute(query)
+
+    def test_warn_anon_alias(self):
+        a, b = self.tables("table_a", "table_b")
+
+        b_alias = b.alias()
+        query = select([a.c.col_a]).where(b_alias.c.col_b == 5)
+
+        with expect_warnings(
+            r"SELECT statement has a cartesian product between FROM "
+            r'element\(s\) "table_(?:a|b_1)" '
+            r'and FROM element "table_(?:a|b_1)"'
+        ):
+            with self.bind.connect() as conn:
+                conn.execute(query)
+
+    def test_warn_anon_cte(self):
+        a, b = self.tables("table_a", "table_b")
+
+        b_cte = select([b]).cte()
+        query = select([a.c.col_a]).where(b_cte.c.col_b == 5)
+
+        with expect_warnings(
+            r"SELECT statement has a cartesian product between "
+            r"FROM element\(s\) "
+            r'"(?:anon_1|table_a)" '
+            r'and FROM element "(?:anon_1|table_a)"'
+        ):
+            with self.bind.connect() as conn:
+                conn.execute(query)
+
+    def test_no_linting(self):
+        eng = engines.testing_engine(options={"enable_from_linting": False})
+        eng.pool = self.bind.pool  # needed for SQLite
+        a, b = self.tables("table_a", "table_b")
+        query = select([a.c.col_a]).where(b.c.col_b == 5)
+
+        with eng.connect() as conn:
+            conn.execute(query)
index 794508a3297f6c4482565faeb2c92772cf2b7637..8aa524d78289cb723fd47795c642a1829c8bd93d 100644 (file)
@@ -19,6 +19,7 @@ from sqlalchemy import String
 from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import true
 from sqlalchemy import type_coerce
 from sqlalchemy import TypeDecorator
 from sqlalchemy import util
@@ -771,7 +772,11 @@ class ResultProxyTest(fixtures.TablesTest):
         users.insert().execute(user_id=1, user_name="john")
         ua = users.alias()
         u2 = users.alias()
-        result = select([users.c.user_id, ua.c.user_id]).execute()
+        result = (
+            select([users.c.user_id, ua.c.user_id])
+            .select_from(users.join(ua, true()))
+            .execute()
+        )
         row = result.first()
 
         # as of 1.1 issue #3501, we use pure positional
@@ -1414,7 +1419,9 @@ class KeyTargetingTest(fixtures.TablesTest):
         keyed1 = self.tables.keyed1
         keyed2 = self.tables.keyed2
 
-        row = testing.db.execute(select([keyed1, keyed2])).first()
+        row = testing.db.execute(
+            select([keyed1, keyed2]).select_from(keyed1.join(keyed2, true()))
+        ).first()
 
         # column access is unambiguous
         eq_(row[self.tables.keyed2.c.b], "b2")
@@ -1446,7 +1453,9 @@ class KeyTargetingTest(fixtures.TablesTest):
         keyed2 = self.tables.keyed2
 
         row = testing.db.execute(
-            select([keyed1, keyed2]).apply_labels()
+            select([keyed1, keyed2])
+            .select_from(keyed1.join(keyed2, true()))
+            .apply_labels()
         ).first()
 
         # column access is unambiguous
@@ -1459,7 +1468,9 @@ class KeyTargetingTest(fixtures.TablesTest):
         keyed1 = self.tables.keyed1
         keyed4 = self.tables.keyed4
 
-        row = testing.db.execute(select([keyed1, keyed4])).first()
+        row = testing.db.execute(
+            select([keyed1, keyed4]).select_from(keyed1.join(keyed4, true()))
+        ).first()
         eq_(row.b, "b4")
         eq_(row.q, "q4")
         eq_(row.a, "a1")
@@ -1470,7 +1481,9 @@ class KeyTargetingTest(fixtures.TablesTest):
         keyed1 = self.tables.keyed1
         keyed3 = self.tables.keyed3
 
-        row = testing.db.execute(select([keyed1, keyed3])).first()
+        row = testing.db.execute(
+            select([keyed1, keyed3]).select_from(keyed1.join(keyed3, true()))
+        ).first()
         eq_(row.q, "c1")
 
         # prior to 1.4 #4887, this raised an "ambiguous column name 'a'""
@@ -1493,7 +1506,9 @@ class KeyTargetingTest(fixtures.TablesTest):
         keyed2 = self.tables.keyed2
 
         row = testing.db.execute(
-            select([keyed1, keyed2]).apply_labels()
+            select([keyed1, keyed2])
+            .select_from(keyed1.join(keyed2, true()))
+            .apply_labels()
         ).first()
         eq_(row.keyed1_b, "a1")
         eq_(row.keyed1_a, "a1")
@@ -1515,18 +1530,22 @@ class KeyTargetingTest(fixtures.TablesTest):
         keyed2 = self.tables.keyed2
         keyed3 = self.tables.keyed3
 
-        stmt = select(
-            [
-                keyed2.c.a,
-                keyed3.c.a,
-                keyed2.c.a,
-                keyed2.c.a,
-                keyed3.c.a,
-                keyed3.c.a,
-                keyed3.c.d,
-                keyed3.c.d,
-            ]
-        ).apply_labels()
+        stmt = (
+            select(
+                [
+                    keyed2.c.a,
+                    keyed3.c.a,
+                    keyed2.c.a,
+                    keyed2.c.a,
+                    keyed3.c.a,
+                    keyed3.c.a,
+                    keyed3.c.d,
+                    keyed3.c.d,
+                ]
+            )
+            .select_from(keyed2.join(keyed3, true()))
+            .apply_labels()
+        )
 
         result = testing.db.execute(stmt)
         is_false(result._metadata.matched_on_name)