]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Ensure iterable passed to Select is not a mapped class
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Apr 2021 20:38:03 +0000 (16:38 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Apr 2021 20:41:28 +0000 (16:41 -0400)
Fixed regression caused by :ticket:`5395` where tuning back the check for
sequences in :func:`_sql.select` now caused failures when doing 2.0-style
querying with a mapped class that also happens to have an ``__iter__()``
method. Tuned the check some more to accommodate this as well as some other
interesting ``__iter__()`` scenarios.

Fixes: #6300
Change-Id: Idf1983fd764b91a7d5fa8117aee8a3def3cfe5ff

doc/build/changelog/unreleased_14/6300.rst [new file with mode: 0644]
lib/sqlalchemy/sql/selectable.py
test/sql/test_select.py

diff --git a/doc/build/changelog/unreleased_14/6300.rst b/doc/build/changelog/unreleased_14/6300.rst
new file mode 100644 (file)
index 0000000..30711a6
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, regression, sql
+    :tickets: 6300
+
+    Fixed regression caused by :ticket:`5395` where tuning back the check for
+    sequences in :func:`_sql.select` now caused failures when doing 2.0-style
+    querying with a mapped class that also happens to have an ``__iter__()``
+    method. Tuned the check some more to accommodate this as well as some other
+    interesting ``__iter__()`` scenarios.
+
index ff830dbf6c65723153e1a2cb9a3a0b0064fc6b6c..43ba0da4cbd91483dfdf53fb12ed08d9555c2ed6 100644 (file)
@@ -56,6 +56,7 @@ from .elements import UnaryExpression
 from .visitors import InternalTraversal
 from .. import exc
 from .. import util
+from ..inspection import inspect
 
 if util.TYPE_CHECKING:
     from typing import Any
@@ -4959,8 +4960,17 @@ class Select(
         """
         if (
             args
-            and hasattr(args[0], "__iter__")
-            and not isinstance(args[0], util.string_types + (ClauseElement,))
+            and (
+                isinstance(args[0], list)
+                or (
+                    hasattr(args[0], "__iter__")
+                    and not isinstance(
+                        args[0], util.string_types + (ClauseElement,)
+                    )
+                    and inspect(args[0], raiseerr=False) is None
+                    and not hasattr(args[0], "__clause_element__")
+                )
+            )
         ) or kw:
             return cls.create_legacy_select(*args, **kw)
         else:
index 96c6abd0733a20f4de32fc321f57d868aac29254..1dfb4cd19e96147cb561bd934a42163a2bdd8123 100644 (file)
@@ -82,6 +82,54 @@ class FutureSelectTest(fixtures.TestBase, AssertsCompiledSQL):
             "WHERE mytable.myid = myothertable.otherid",
         )
 
+    def test_new_calling_style_clauseelement_thing_that_has_iter(self):
+        class Thing(object):
+            def __clause_element__(self):
+                return table1
+
+            def __iter__(self):
+                return iter(["a", "b", "c"])
+
+        stmt = select(Thing())
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.myid, mytable.name, "
+            "mytable.description FROM mytable",
+        )
+
+    def test_new_calling_style_inspectable_ce_thing_that_has_iter(self):
+        class Thing(object):
+            def __iter__(self):
+                return iter(["a", "b", "c"])
+
+        class InspectedThing(object):
+            def __clause_element__(self):
+                return table1
+
+        from sqlalchemy.inspection import _inspects
+
+        @_inspects(Thing)
+        def _ce(thing):
+            return InspectedThing()
+
+        stmt = select(Thing())
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.myid, mytable.name, "
+            "mytable.description FROM mytable",
+        )
+
+    def test_new_calling_style_thing_ok_actually_use_iter(self):
+        class Thing(object):
+            def __iter__(self):
+                return iter([table1.c.name, table1.c.description])
+
+        stmt = select(Thing())
+        self.assert_compile(
+            stmt,
+            "SELECT mytable.name, mytable.description FROM mytable",
+        )
+
     def test_kw_triggers_old_style(self):
 
         assert_raises_message(