]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add result map targeting for custom compiled, text objects
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Oct 2019 21:36:27 +0000 (17:36 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 8 Oct 2019 03:06:06 +0000 (23:06 -0400)
In order for text(), custom compiled objects, etc. to be usable
by Query(), they are all targeted by object key in the result map.
As we no longer want Query to implicitly label these, as well as that
text() has no label feature, support adding entries to the result
map that have no name, key, or type, only the object itself, and
then ensure that the compiler sets up for positional targeting
when this condition is detected.

Allows for more flexible ORM query usage with custom expressions
and text() while having less special logic in query itself.

Fixes: #4887
Change-Id: Ie073da127d292d43cb132a2b31bc90af88bfe2fd

doc/build/changelog/unreleased_14/4887.rst [new file with mode: 0644]
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/ext/compiler.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/selectable.py
test/ext/test_compiler.py
test/orm/test_query.py
test/sql/test_functions.py
test/sql/test_resultset.py
test/sql/test_text.py

diff --git a/doc/build/changelog/unreleased_14/4887.rst b/doc/build/changelog/unreleased_14/4887.rst
new file mode 100644 (file)
index 0000000..ffff57f
--- /dev/null
@@ -0,0 +1,26 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 4887
+
+    Custom functions that are created as subclasses of
+    :class:`.FunctionElement` will now generate an "anonymous label" based on
+    the "name" of the function just like any other :class:`.Function` object,
+    e.g. ``"SELECT myfunc() AS myfunc_1"``. While SELECT statements no longer
+    require labels in order for the result proxy object to function, the ORM
+    still targets columns in rows by using objects as mapping keys, which works
+    more reliably when the column expressions have distinct names.  In any
+    case, the behavior is  now made consistent between functions generated by
+    :attr:`.func` and those generated as custom :class:`.FunctionElement`
+    objects.
+
+
+.. change::
+    :tags: usecase, ext
+    :tickets: 4887
+
+    Custom compiler constructs created using the :mod:`sqlalchemy.ext.compiled`
+    extension will automatically add contextual information to the compiler
+    when a custom construct is interpreted as an element in the columns
+    clause of a SELECT statement, such that the custom element will be
+    targetable as a key in result row mappings, which is the kind of targeting
+    that the ORM uses in order to match column elements into result tuples.
\ No newline at end of file
index af5303658b7b9c67f1db5306ec08b9f0dca582c2..733bd6f6ab701d1aeda17b27346ff5c93f180dea 100644 (file)
@@ -321,40 +321,41 @@ class ResultMetaData(object):
             # dupe records with "None" for index which results in
             # ambiguous column exception when accessed.
             if len(by_key) != num_ctx_cols:
-                seen = set()
+                # new in 1.4: get the complete set of all possible keys,
+                # strings, objects, whatever, that are dupes across two
+                # different records, first.
+                index_by_key = {}
+                dupes = set()
                 for metadata_entry in raw:
-                    key = metadata_entry[MD_RENDERED_NAME]
-                    if key in seen:
-                        # this is an "ambiguous" element, replacing
-                        # the full record in the map
-                        key = key.lower() if not self.case_sensitive else key
-                        by_key[key] = (None, (), key)
-                    seen.add(key)
-
-                # copy secondary elements from compiled columns
-                # into self._keymap, write in the potentially "ambiguous"
-                # element
+                    for key in (metadata_entry[MD_RENDERED_NAME],) + (
+                        metadata_entry[MD_OBJECTS] or ()
+                    ):
+                        if not self.case_sensitive and isinstance(
+                            key, util.string_types
+                        ):
+                            key = key.lower()
+                        idx = metadata_entry[MD_INDEX]
+                        # if this key has been associated with more than one
+                        # positional index, it's a dupe
+                        if index_by_key.setdefault(key, idx) != idx:
+                            dupes.add(key)
+
+                # then put everything we have into the keymap excluding only
+                # those keys that are dupes.
                 self._keymap.update(
                     [
-                        (obj_elem, by_key[metadata_entry[MD_LOOKUP_KEY]])
+                        (obj_elem, metadata_entry)
                         for metadata_entry in raw
                         if metadata_entry[MD_OBJECTS]
                         for obj_elem in metadata_entry[MD_OBJECTS]
+                        if obj_elem not in dupes
                     ]
                 )
 
-                # if we did a pure positional match, then reset the
-                # original "expression element" back to the "unambiguous"
-                # entry.  This is a new behavior in 1.1 which impacts
-                # TextualSelect but also straight compiled SQL constructs.
-                if not self.matched_on_name:
-                    self._keymap.update(
-                        [
-                            (metadata_entry[MD_OBJECTS][0], metadata_entry)
-                            for metadata_entry in raw
-                            if metadata_entry[MD_OBJECTS]
-                        ]
-                    )
+                # then for the dupe keys, put the "ambiguous column"
+                # record into by_key.
+                by_key.update({key: (None, (), key) for key in dupes})
+
             else:
                 # no dupes - copy secondary elements from compiled
                 # columns into self._keymap
@@ -502,16 +503,16 @@ class ResultMetaData(object):
                 (
                     idx,
                     obj,
-                    colname,
-                    colname,
+                    cursor_colname,
+                    cursor_colname,
                     context.get_result_processor(
-                        mapped_type, colname, coltype
+                        mapped_type, cursor_colname, coltype
                     ),
                     untranslated,
                 )
                 for (
                     idx,
-                    colname,
+                    cursor_colname,
                     mapped_type,
                     coltype,
                     obj,
@@ -592,7 +593,6 @@ class ResultMetaData(object):
             else:
                 mapped_type = sqltypes.NULLTYPE
                 obj = None
-
             yield idx, colname, mapped_type, coltype, obj, untranslated
 
     def _merge_cols_by_name(
@@ -758,7 +758,7 @@ class ResultMetaData(object):
         if index is None:
             raise exc.InvalidRequestError(
                 "Ambiguous column name '%s' in "
-                "result set column descriptions" % obj
+                "result set column descriptions" % rec[MD_LOOKUP_KEY]
             )
 
         return operator.methodcaller("_get_by_key_impl", index)
index 572c62b8e6dc5810a2cd3f85a1712b74019f1ca4..4a5a8ba9cdeba9f111cbf9f1cdd51129dd230290 100644 (file)
@@ -398,6 +398,7 @@ Example usage::
 
 """
 from .. import exc
+from ..sql import sqltypes
 from ..sql import visitors
 
 
@@ -475,4 +476,22 @@ class _dispatcher(object):
                     "compilation handler." % type(element)
                 )
 
-        return fn(element, compiler, **kw)
+        # if compilation includes add_to_result_map, collect add_to_result_map
+        # arguments from the user-defined callable, which are probably none
+        # because this is not public API.  if it wasn't called, then call it
+        # ourselves.
+        arm = kw.get("add_to_result_map", None)
+        if arm:
+            arm_collection = []
+            kw["add_to_result_map"] = lambda *args: arm_collection.append(args)
+
+        expr = fn(element, compiler, **kw)
+
+        if arm:
+            if not arm_collection:
+                arm_collection.append(
+                    (None, None, (element,), sqltypes.NULLTYPE)
+                )
+            for tup in arm_collection:
+                arm(*tup)
+        return expr
index 320c7b7821a9d69c80b52051a66be1850591a1f7..453ff56d2eaeed4e9aede886df4e0f3a0765fc5d 100644 (file)
@@ -871,12 +871,11 @@ class SQLCompiler(Compiled):
             name = self._truncated_identifier("colident", name)
 
         if add_to_result_map is not None:
-            add_to_result_map(
-                name,
-                orig_name,
-                (column, name, column.key, column._label) + result_map_targets,
-                column.type,
-            )
+            targets = (column, name, column.key) + result_map_targets
+            if column._label:
+                targets += (column._label,)
+
+            add_to_result_map(name, orig_name, targets, column.type)
 
         if is_literal:
             # note we are not currently accommodating for
@@ -925,7 +924,7 @@ class SQLCompiler(Compiled):
             text = text.replace("%", "%%")
         return text
 
-    def visit_textclause(self, textclause, **kw):
+    def visit_textclause(self, textclause, add_to_result_map=None, **kw):
         def do_bindparam(m):
             name = m.group(1)
             if name in textclause._bindparams:
@@ -936,6 +935,12 @@ class SQLCompiler(Compiled):
         if not self.stack:
             self.isplaintext = True
 
+        if add_to_result_map:
+            # text() object is present in the columns clause of a
+            # select().   Add a no-name entry to the result map so that
+            # row[text()] produces a result
+            add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
+
         # un-escape any \:params
         return BIND_PARAMS_ESC.sub(
             lambda m: m.group(1),
@@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled):
         return " AS " + alias_name_text
 
     def _add_to_result_map(self, keyname, name, objects, type_):
+        if keyname is None:
+            self._ordered_columns = False
+            self._textual_ordered_columns = True
         self._result_columns.append((keyname, name, objects, type_))
 
     def _label_select_column(
@@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled):
         column_clause_args,
         name=None,
         within_columns_clause=True,
+        column_is_repeated=False,
         need_column_expressions=False,
     ):
         """produce labeled columns present in a select()."""
@@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled):
             need_column_expressions or populate_result_map
         ):
             col_expr = impl.column_expression(column)
+        else:
+            col_expr = column
 
-            if populate_result_map:
+        if populate_result_map:
+            # pass an "add_to_result_map" callable into the compilation
+            # of embedded columns.  this collects information about the
+            # column as it will be fetched in the result and is coordinated
+            # with cursor.description when the query is executed.
+            add_to_result_map = self._add_to_result_map
+
+            # if the SELECT statement told us this column is a repeat,
+            # wrap the callable with one that prevents the addition of the
+            # targets
+            if column_is_repeated:
+                _add_to_result_map = add_to_result_map
 
                 def add_to_result_map(keyname, name, objects, type_):
-                    self._add_to_result_map(
+                    _add_to_result_map(keyname, name, (), type_)
+
+            # if we redefined col_expr for type expressions, wrap the
+            # callable with one that adds the original column to the targets
+            elif col_expr is not column:
+                _add_to_result_map = add_to_result_map
+
+                def add_to_result_map(keyname, name, objects, type_):
+                    _add_to_result_map(
                         keyname, name, (column,) + objects, type_
                     )
 
-            else:
-                add_to_result_map = None
         else:
-            col_expr = column
-            if populate_result_map:
-                add_to_result_map = self._add_to_result_map
-            else:
-                add_to_result_map = None
+            add_to_result_map = None
 
         if not within_columns_clause:
             result_expr = col_expr
@@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled):
             )
             and (
                 not hasattr(column, "name")
-                or isinstance(column, functions.Function)
+                or isinstance(column, functions.FunctionElement)
             )
         ):
             result_expr = _CompileLabel(col_expr, column.anon_label)
@@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled):
                     asfrom,
                     column_clause_args,
                     name=name,
+                    column_is_repeated=repeated,
                     need_column_expressions=need_column_expressions,
                 )
-                for name, column in select._columns_plus_names
+                for name, column, repeated in select._columns_plus_names
             ]
             if c is not None
         ]
@@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled):
 
             translate = dict(
                 zip(
-                    [name for (key, name) in select._columns_plus_names],
                     [
                         name
-                        for (key, name) in select_wraps_for._columns_plus_names
+                        for (key, name, repeated) in select._columns_plus_names
+                    ],
+                    [
+                        name
+                        for (
+                            key,
+                            name,
+                            repeated,
+                        ) in select_wraps_for._columns_plus_names
                     ],
                 )
             )
index ddbcdf91d6363f1026ff716b567762c063312b82..6282cf2ee08e2254bb7798514ea37dca746ef8e2 100644 (file)
@@ -4191,8 +4191,9 @@ class Select(
 
             def name_for_col(c):
                 if c._label is None or not c._render_label_in_columns_clause:
-                    return (None, c)
+                    return (None, c, False)
 
+                repeated = False
                 name = c._label
 
                 if name in names:
@@ -4218,19 +4219,22 @@ class Select(
                             # subsequent occurrences of the column so that the
                             # original stays non-ambiguous
                             name = c._dedupe_label_anon_label
+                            repeated = True
                         else:
                             names[name] = c
                     elif anon_for_dupe_key:
                         # same column under the same name. apply the "dedupe"
                         # label so that the original stays non-ambiguous
                         name = c._dedupe_label_anon_label
+                        repeated = True
                 else:
                     names[name] = c
-                return name, c
+                return name, c, repeated
 
             return [name_for_col(c) for c in cols]
         else:
-            return [(None, c) for c in cols]
+            # repeated name logic only for use labels at the moment
+            return [(None, c, False) for c in cols]
 
     @_memoized_property
     def _columns_plus_names(self):
@@ -4245,7 +4249,7 @@ class Select(
         keys_seen = set()
         prox = []
 
-        for name, c in self._generate_columns_plus_names(False):
+        for name, c, repeated in self._generate_columns_plus_names(False):
             if not hasattr(c, "_make_proxy"):
                 continue
             if name is None:
index ccd79f8d11088a14eb3fac9994c8fd48424904d4..f1817b94e7cb8e191ebcb1750a908bcb75426768 100644 (file)
@@ -3,6 +3,7 @@ from sqlalchemy import column
 from sqlalchemy import desc
 from sqlalchemy import exc
 from sqlalchemy import Integer
+from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import Numeric
 from sqlalchemy import select
@@ -13,11 +14,13 @@ from sqlalchemy.ext.compiler import deregister
 from sqlalchemy.schema import CreateColumn
 from sqlalchemy.schema import CreateTable
 from sqlalchemy.schema import DDLElement
+from sqlalchemy.sql.elements import ColumnElement
 from sqlalchemy.sql.expression import BindParameter
 from sqlalchemy.sql.expression import ClauseElement
 from sqlalchemy.sql.expression import ColumnClause
 from sqlalchemy.sql.expression import FunctionElement
 from sqlalchemy.sql.expression import Select
+from sqlalchemy.sql.sqltypes import NULLTYPE
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
@@ -319,7 +322,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
             dialect=mssql.dialect(),
         )
 
-    def test_subclasses_one(self):
+    def test_function_subclasses_one(self):
         class Base(FunctionElement):
             name = "base"
 
@@ -339,11 +342,11 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             select([Sub1(), Sub2()]),
-            "SELECT FOOsub1, sub2",
+            "SELECT FOOsub1 AS sub1_1, sub2 AS sub2_1",
             use_default_dialect=True,
         )
 
-    def test_subclasses_two(self):
+    def test_function_subclasses_two(self):
         class Base(FunctionElement):
             name = "base"
 
@@ -362,7 +365,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             select([Sub1(), Sub2(), SubSub1()]),
-            "SELECT sub1, sub2, subsub1",
+            "SELECT sub1 AS sub1_1, sub2 AS sub2_1, subsub1 AS subsub1_1",
             use_default_dialect=True,
         )
 
@@ -372,10 +375,51 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL):
 
         self.assert_compile(
             select([Sub1(), Sub2(), SubSub1()]),
-            "SELECT FOOsub1, sub2, FOOsubsub1",
+            "SELECT FOOsub1 AS sub1_1, sub2 AS sub2_1, "
+            "FOOsubsub1 AS subsub1_1",
             use_default_dialect=True,
         )
 
+    def _test_result_map_population(self, expression):
+        lc1 = literal_column("1")
+        lc2 = literal_column("2")
+        stmt = select([lc1, expression, lc2])
+
+        compiled = stmt.compile()
+        eq_(
+            compiled._result_columns,
+            [
+                ("1", "1", (lc1, "1", "1"), NULLTYPE),
+                (None, None, (expression,), NULLTYPE),
+                ("2", "2", (lc2, "2", "2"), NULLTYPE),
+            ],
+        )
+
+    def test_result_map_population_explicit(self):
+        class not_named_max(ColumnElement):
+            name = "not_named_max"
+
+        @compiles(not_named_max)
+        def visit_max(element, compiler, **kw):
+            # explicit add
+            kw["add_to_result_map"](None, None, (element,), NULLTYPE)
+            return "max(a)"
+
+        nnm = not_named_max()
+        self._test_result_map_population(nnm)
+
+    def test_result_map_population_implicit(self):
+        class not_named_max(ColumnElement):
+            name = "not_named_max"
+
+        @compiles(not_named_max)
+        def visit_max(element, compiler, **kw):
+            # we don't add to keymap here; compiler should be doing it
+            return "max(a)"
+
+        nnm = not_named_max()
+        self._test_result_map_population(nnm)
+
 
 class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL):
     """Test replacement of default compilation on existing constructs."""
index d5dddb3aa3988b3acfa1e542b6996249524852a2..d2c4f5bc7ba25e08dbcf1fa6328dc1aad0dfe883 100644 (file)
@@ -31,6 +31,7 @@ from sqlalchemy import Unicode
 from sqlalchemy import union
 from sqlalchemy import util
 from sqlalchemy.engine import default
+from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import attributes
 from sqlalchemy.orm import backref
@@ -1744,6 +1745,38 @@ class OperatorTest(QueryTest, AssertsCompiledSQL):
 class ExpressionTest(QueryTest, AssertsCompiledSQL):
     __dialect__ = "default"
 
+    def test_function_element_column_labels(self):
+        users = self.tables.users
+        sess = Session()
+
+        class max_(expression.FunctionElement):
+            name = "max"
+
+        @compiles(max_)
+        def visit_max(element, compiler, **kw):
+            return "max(%s)" % compiler.process(element.clauses, **kw)
+
+        q = sess.query(max_(users.c.id))
+        eq_(q.all(), [(10,)])
+
+    def test_truly_unlabeled_sql_expressions(self):
+        users = self.tables.users
+        sess = Session()
+
+        class not_named_max(expression.ColumnElement):
+            name = "not_named_max"
+
+        @compiles(not_named_max)
+        def visit_max(element, compiler, **kw):
+            return "max(id)"
+
+        # assert that there is no "AS max_" or any label of any kind.
+        eq_(str(select([not_named_max()])), "SELECT max(id)")
+
+        # ColumnElement still handles it by applying label()
+        q = sess.query(not_named_max()).select_from(users)
+        eq_(q.all(), [(10,)])
+
     def test_deferred_instances(self):
         User, addresses, Address = (
             self.classes.User,
@@ -4342,6 +4375,11 @@ class TextTest(QueryTest, AssertsCompiledSQL):
             "SELECT users.id AS users_id, users.name FROM users",
         )
 
+        eq_(
+            s.query(User.id, text("users.name")).all(),
+            [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")],
+        )
+
         eq_(
             s.query(User.id, literal_column("name")).order_by(User.id).all(),
             [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")],
index d0e68f1e33255b85d34cd6eaa55ec6d293a8e91e..a46d1af548d4bd23b165bb8d339b6994ccb71b02 100644 (file)
@@ -99,6 +99,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             select([func.foo()], use_labels=True), "SELECT foo() AS foo_1"
         )
 
+    def test_use_labels_function_element(self):
+        from sqlalchemy.ext.compiler import compiles
+
+        class max_(FunctionElement):
+            name = "max"
+
+        @compiles(max_)
+        def visit_max(element, compiler, **kw):
+            return "max(%s)" % compiler.process(element.clauses, **kw)
+
+        self.assert_compile(
+            select([max_(5, 6)], use_labels=True),
+            "SELECT max(:max_2, :max_3) AS max_1",
+        )
+
     def test_underscores(self):
         self.assert_compile(func.if_(), "if()")
 
index 35353671c59d7a051cb8d47f0315a0906bbfaffc..2563c7d0c0aee233f6d983940527620d61852019 100644 (file)
@@ -26,7 +26,10 @@ from sqlalchemy import VARCHAR
 from sqlalchemy.engine import default
 from sqlalchemy.engine import result as _result
 from sqlalchemy.engine import Row
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.sql import expression
 from sqlalchemy.sql.selectable import TextualSelect
+from sqlalchemy.sql.sqltypes import NULLTYPE
 from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import assertions
@@ -35,6 +38,7 @@ from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import in_
 from sqlalchemy.testing import is_
+from sqlalchemy.testing import is_false
 from sqlalchemy.testing import is_true
 from sqlalchemy.testing import le_
 from sqlalchemy.testing import ne_
@@ -721,6 +725,13 @@ class ResultProxyTest(fixtures.TablesTest):
             lambda: r["user_id"],
         )
 
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "Ambiguous column name",
+            result._getter,
+            "user_id",
+        )
+
         # pure positional targeting; users.c.user_id
         # and addresses.c.user_id are known!
         # works as of 1.1 issue #3501
@@ -856,7 +867,6 @@ class ResultProxyTest(fixtures.TablesTest):
         addresses = self.tables.addresses
 
         with testing.db.connect() as conn:
-            # MARKMARK
             conn.execute(users.insert(), {"user_id": 1, "user_name": "john"})
             conn.execute(
                 addresses.insert(),
@@ -1344,19 +1354,107 @@ class KeyTargetingTest(fixtures.TablesTest):
         eq_(row.keyed1_a, "a1")
         eq_(row.keyed1_c, "c1")
 
+    def _test_keyed_targeting_no_label_at_all(self, expression):
+        lt = literal_column("2")
+        stmt = select([literal_column("1"), expression, lt]).select_from(
+            self.tables.keyed1
+        )
+        row = testing.db.execute(stmt).first()
+
+        eq_(row[expression], "a1")
+        eq_(row[lt], 2)
+
+        # Postgresql for example has the key as "?column?", which dupes
+        # easily.  we get around that because we know that "2" is unique
+        eq_(row["2"], 2)
+
+    def test_keyed_targeting_no_label_at_all_one(self):
+        class not_named_max(expression.ColumnElement):
+            name = "not_named_max"
+
+        @compiles(not_named_max)
+        def visit_max(element, compiler, **kw):
+            # explicit add
+            kw["add_to_result_map"](None, None, (element,), NULLTYPE)
+            return "max(a)"
+
+        # assert that there is no "AS max_" or any label of any kind.
+        eq_(str(select([not_named_max()])), "SELECT max(a)")
+
+        nnm = not_named_max()
+        self._test_keyed_targeting_no_label_at_all(nnm)
+
+    def test_keyed_targeting_no_label_at_all_two(self):
+        class not_named_max(expression.ColumnElement):
+            name = "not_named_max"
+
+        @compiles(not_named_max)
+        def visit_max(element, compiler, **kw):
+            # we don't add to keymap here; compiler should be doing it
+            return "max(a)"
+
+        # assert that there is no "AS max_" or any label of any kind.
+        eq_(str(select([not_named_max()])), "SELECT max(a)")
+
+        nnm = not_named_max()
+        self._test_keyed_targeting_no_label_at_all(nnm)
+
+    def test_keyed_targeting_no_label_at_all_text(self):
+        t1 = text("max(a)")
+        t2 = text("min(a)")
+
+        stmt = select([t1, t2]).select_from(self.tables.keyed1)
+        row = testing.db.execute(stmt).first()
+
+        eq_(row[t1], "a1")
+        eq_(row[t2], "a1")
+
     @testing.requires.duplicate_names_in_cursor_description
     def test_keyed_accessor_composite_conflict_2(self):
         keyed1 = self.tables.keyed1
         keyed2 = self.tables.keyed2
 
         row = testing.db.execute(select([keyed1, keyed2])).first()
-        # row.b is unambiguous
-        eq_(row.b, "b2")
+
+        # column access is unambiguous
+        eq_(row[self.tables.keyed2.c.b], "b2")
+
         # row.a is ambiguous
         assert_raises_message(
             exc.InvalidRequestError, "Ambig", getattr, row, "a"
         )
 
+        # for "b" we have kind of a choice.  the name "b" is not ambiguous in
+        # cursor.description in this case.  It is however ambiguous as far as
+        # the objects we have queried against, because keyed1.c.a has key="b"
+        # and keyed1.c.b is "b".   historically this was allowed as
+        # non-ambiguous, however the column it targets changes based on
+        # whether or not the dupe is present so it's ambiguous
+        # eq_(row.b, "b2")
+        assert_raises_message(
+            exc.InvalidRequestError, "Ambig", getattr, row, "b"
+        )
+
+        # illustrate why row.b above is ambiguous, and not "b2"; because
+        # if we didn't have keyed2, now it matches row.a.  a new column
+        # shouldn't be able to grab the value from a previous column.
+        row = testing.db.execute(select([keyed1])).first()
+        eq_(row.b, "a1")
+
+    def test_keyed_accessor_composite_conflict_2_fix_w_uselabels(self):
+        keyed1 = self.tables.keyed1
+        keyed2 = self.tables.keyed2
+
+        row = testing.db.execute(
+            select([keyed1, keyed2]).apply_labels()
+        ).first()
+
+        # column access is unambiguous
+        eq_(row[self.tables.keyed2.c.b], "b2")
+
+        eq_(row["keyed2_b"], "b2")
+        eq_(row["keyed1_a"], "a1")
+
     def test_keyed_accessor_composite_names_precedent(self):
         keyed1 = self.tables.keyed1
         keyed4 = self.tables.keyed4
@@ -1374,13 +1472,13 @@ class KeyTargetingTest(fixtures.TablesTest):
 
         row = testing.db.execute(select([keyed1, keyed3])).first()
         eq_(row.q, "c1")
-        assert_raises_message(
-            exc.InvalidRequestError,
-            "Ambiguous column name 'a'",
-            getattr,
-            row,
-            "b",
-        )
+
+        # prior to 1.4 #4887, this raised an "ambiguous column name 'a'""
+        # message, because "b" is linked to "a" which is a dupe.  but we know
+        # where "b" is in the row by position.
+        eq_(row.b, "a1")
+
+        # "a" is of course ambiguous
         assert_raises_message(
             exc.InvalidRequestError,
             "Ambiguous column name 'a'",
@@ -1406,6 +1504,67 @@ class KeyTargetingTest(fixtures.TablesTest):
         assert_raises(KeyError, lambda: row["keyed2_c"])
         assert_raises(KeyError, lambda: row["keyed2_q"])
 
+    def test_keyed_accessor_column_is_repeated_multiple_times(self):
+        # test new logic added as a result of the combination of #4892 and
+        # #4887.   We allow duplicate columns, but we also have special logic
+        # to disambiguate for the same column repeated, and as #4887 adds
+        # stricter ambiguous result column logic, the compiler has to know to
+        # not add these dupe columns to the result map, else they register as
+        # ambiguous.
+
+        keyed2 = self.tables.keyed2
+        keyed3 = self.tables.keyed3
+
+        stmt = select(
+            [
+                keyed2.c.a,
+                keyed3.c.a,
+                keyed2.c.a,
+                keyed2.c.a,
+                keyed3.c.a,
+                keyed3.c.a,
+                keyed3.c.d,
+                keyed3.c.d,
+            ]
+        ).apply_labels()
+
+        result = testing.db.execute(stmt)
+        is_false(result._metadata.matched_on_name)
+
+        # ensure the result map is the same number of cols so we can
+        # use positional targeting
+        eq_(
+            [rec[0] for rec in result.context.compiled._result_columns],
+            [
+                "keyed2_a",
+                "keyed3_a",
+                "keyed2_a__1",
+                "keyed2_a__1",
+                "keyed3_a__1",
+                "keyed3_a__1",
+                "keyed3_d",
+                "keyed3_d__1",
+            ],
+        )
+        row = result.first()
+
+        # keyed access will ignore the dupe cols
+        eq_(row[keyed2.c.a], "a2")
+        eq_(row[keyed3.c.a], "a3")
+        eq_(result._getter(keyed3.c.a)(row), "a3")
+        eq_(row[keyed3.c.d], "d3")
+
+        # however we can get everything positionally
+        eq_(row, ("a2", "a3", "a2", "a2", "a3", "a3", "d3", "d3"))
+        eq_(row[0], "a2")
+        eq_(row[1], "a3")
+        eq_(row[2], "a2")
+        eq_(row[3], "a2")
+        eq_(row[4], "a3")
+        eq_(row[5], "a3")
+        eq_(row[6], "d3")
+        eq_(row[7], "d3")
+
     def test_columnclause_schema_column_one(self):
         # originally addressed by [ticket:2932], however liberalized
         # Column-targeting rules are deprecated
index 6af2cffcfaac8ad81a905cbf72bcbadcb4128992..3f12d06a90f3f13c9efbe910eddf2fd7157c5772 100644 (file)
@@ -20,6 +20,7 @@ from sqlalchemy import union
 from sqlalchemy import util
 from sqlalchemy.sql import column
 from sqlalchemy.sql import quoted_name
+from sqlalchemy.sql import sqltypes
 from sqlalchemy.sql import table
 from sqlalchemy.sql import util as sql_util
 from sqlalchemy.testing import assert_raises_message
@@ -49,6 +50,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
             "select * from foo where lala = bar",
         )
 
+    def test_text_adds_to_result_map(self):
+        t1, t2 = text("t1"), text("t2")
+
+        stmt = select([t1, t2])
+        compiled = stmt.compile()
+        eq_(
+            compiled._result_columns,
+            [
+                (None, None, (t1,), sqltypes.NULLTYPE),
+                (None, None, (t2,), sqltypes.NULLTYPE),
+            ],
+        )
+
 
 class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL):