From: Mike Bayer Date: Thu, 3 Oct 2019 21:36:27 +0000 (-0400) Subject: Add result map targeting for custom compiled, text objects X-Git-Tag: rel_1_4_0b1~690^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=65aee6cce57fd1cca3a95814feff3ed99a5a51ee;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add result map targeting for custom compiled, text objects In order for text(), custom compiled objects, etc. to be usable by Query(), they are all targeted by object key in the result map. As we no longer want Query to implicitly label these, as well as that text() has no label feature, support adding entries to the result map that have no name, key, or type, only the object itself, and then ensure that the compiler sets up for positional targeting when this condition is detected. Allows for more flexible ORM query usage with custom expressions and text() while having less special logic in query itself. Fixes: #4887 Change-Id: Ie073da127d292d43cb132a2b31bc90af88bfe2fd --- diff --git a/doc/build/changelog/unreleased_14/4887.rst b/doc/build/changelog/unreleased_14/4887.rst new file mode 100644 index 0000000000..ffff57f470 --- /dev/null +++ b/doc/build/changelog/unreleased_14/4887.rst @@ -0,0 +1,26 @@ +.. change:: + :tags: bug, sql + :tickets: 4887 + + Custom functions that are created as subclasses of + :class:`.FunctionElement` will now generate an "anonymous label" based on + the "name" of the function just like any other :class:`.Function` object, + e.g. ``"SELECT myfunc() AS myfunc_1"``. While SELECT statements no longer + require labels in order for the result proxy object to function, the ORM + still targets columns in rows by using objects as mapping keys, which works + more reliably when the column expressions have distinct names. In any + case, the behavior is now made consistent between functions generated by + :attr:`.func` and those generated as custom :class:`.FunctionElement` + objects. + + +.. change:: + :tags: usecase, ext + :tickets: 4887 + + Custom compiler constructs created using the :mod:`sqlalchemy.ext.compiled` + extension will automatically add contextual information to the compiler + when a custom construct is interpreted as an element in the columns + clause of a SELECT statement, such that the custom element will be + targetable as a key in result row mappings, which is the kind of targeting + that the ORM uses in order to match column elements into result tuples. \ No newline at end of file diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index af5303658b..733bd6f6ab 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -321,40 +321,41 @@ class ResultMetaData(object): # dupe records with "None" for index which results in # ambiguous column exception when accessed. if len(by_key) != num_ctx_cols: - seen = set() + # new in 1.4: get the complete set of all possible keys, + # strings, objects, whatever, that are dupes across two + # different records, first. + index_by_key = {} + dupes = set() for metadata_entry in raw: - key = metadata_entry[MD_RENDERED_NAME] - if key in seen: - # this is an "ambiguous" element, replacing - # the full record in the map - key = key.lower() if not self.case_sensitive else key - by_key[key] = (None, (), key) - seen.add(key) - - # copy secondary elements from compiled columns - # into self._keymap, write in the potentially "ambiguous" - # element + for key in (metadata_entry[MD_RENDERED_NAME],) + ( + metadata_entry[MD_OBJECTS] or () + ): + if not self.case_sensitive and isinstance( + key, util.string_types + ): + key = key.lower() + idx = metadata_entry[MD_INDEX] + # if this key has been associated with more than one + # positional index, it's a dupe + if index_by_key.setdefault(key, idx) != idx: + dupes.add(key) + + # then put everything we have into the keymap excluding only + # those keys that are dupes. self._keymap.update( [ - (obj_elem, by_key[metadata_entry[MD_LOOKUP_KEY]]) + (obj_elem, metadata_entry) for metadata_entry in raw if metadata_entry[MD_OBJECTS] for obj_elem in metadata_entry[MD_OBJECTS] + if obj_elem not in dupes ] ) - # if we did a pure positional match, then reset the - # original "expression element" back to the "unambiguous" - # entry. This is a new behavior in 1.1 which impacts - # TextualSelect but also straight compiled SQL constructs. - if not self.matched_on_name: - self._keymap.update( - [ - (metadata_entry[MD_OBJECTS][0], metadata_entry) - for metadata_entry in raw - if metadata_entry[MD_OBJECTS] - ] - ) + # then for the dupe keys, put the "ambiguous column" + # record into by_key. + by_key.update({key: (None, (), key) for key in dupes}) + else: # no dupes - copy secondary elements from compiled # columns into self._keymap @@ -502,16 +503,16 @@ class ResultMetaData(object): ( idx, obj, - colname, - colname, + cursor_colname, + cursor_colname, context.get_result_processor( - mapped_type, colname, coltype + mapped_type, cursor_colname, coltype ), untranslated, ) for ( idx, - colname, + cursor_colname, mapped_type, coltype, obj, @@ -592,7 +593,6 @@ class ResultMetaData(object): else: mapped_type = sqltypes.NULLTYPE obj = None - yield idx, colname, mapped_type, coltype, obj, untranslated def _merge_cols_by_name( @@ -758,7 +758,7 @@ class ResultMetaData(object): if index is None: raise exc.InvalidRequestError( "Ambiguous column name '%s' in " - "result set column descriptions" % obj + "result set column descriptions" % rec[MD_LOOKUP_KEY] ) return operator.methodcaller("_get_by_key_impl", index) diff --git a/lib/sqlalchemy/ext/compiler.py b/lib/sqlalchemy/ext/compiler.py index 572c62b8e6..4a5a8ba9cd 100644 --- a/lib/sqlalchemy/ext/compiler.py +++ b/lib/sqlalchemy/ext/compiler.py @@ -398,6 +398,7 @@ Example usage:: """ from .. import exc +from ..sql import sqltypes from ..sql import visitors @@ -475,4 +476,22 @@ class _dispatcher(object): "compilation handler." % type(element) ) - return fn(element, compiler, **kw) + # if compilation includes add_to_result_map, collect add_to_result_map + # arguments from the user-defined callable, which are probably none + # because this is not public API. if it wasn't called, then call it + # ourselves. + arm = kw.get("add_to_result_map", None) + if arm: + arm_collection = [] + kw["add_to_result_map"] = lambda *args: arm_collection.append(args) + + expr = fn(element, compiler, **kw) + + if arm: + if not arm_collection: + arm_collection.append( + (None, None, (element,), sqltypes.NULLTYPE) + ) + for tup in arm_collection: + arm(*tup) + return expr diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 320c7b7821..453ff56d2e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -871,12 +871,11 @@ class SQLCompiler(Compiled): name = self._truncated_identifier("colident", name) if add_to_result_map is not None: - add_to_result_map( - name, - orig_name, - (column, name, column.key, column._label) + result_map_targets, - column.type, - ) + targets = (column, name, column.key) + result_map_targets + if column._label: + targets += (column._label,) + + add_to_result_map(name, orig_name, targets, column.type) if is_literal: # note we are not currently accommodating for @@ -925,7 +924,7 @@ class SQLCompiler(Compiled): text = text.replace("%", "%%") return text - def visit_textclause(self, textclause, **kw): + def visit_textclause(self, textclause, add_to_result_map=None, **kw): def do_bindparam(m): name = m.group(1) if name in textclause._bindparams: @@ -936,6 +935,12 @@ class SQLCompiler(Compiled): if not self.stack: self.isplaintext = True + if add_to_result_map: + # text() object is present in the columns clause of a + # select(). Add a no-name entry to the result map so that + # row[text()] produces a result + add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE) + # un-escape any \:params return BIND_PARAMS_ESC.sub( lambda m: m.group(1), @@ -1938,6 +1943,9 @@ class SQLCompiler(Compiled): return " AS " + alias_name_text def _add_to_result_map(self, keyname, name, objects, type_): + if keyname is None: + self._ordered_columns = False + self._textual_ordered_columns = True self._result_columns.append((keyname, name, objects, type_)) def _label_select_column( @@ -1949,6 +1957,7 @@ class SQLCompiler(Compiled): column_clause_args, name=None, within_columns_clause=True, + column_is_repeated=False, need_column_expressions=False, ): """produce labeled columns present in a select().""" @@ -1959,22 +1968,37 @@ class SQLCompiler(Compiled): need_column_expressions or populate_result_map ): col_expr = impl.column_expression(column) + else: + col_expr = column - if populate_result_map: + if populate_result_map: + # pass an "add_to_result_map" callable into the compilation + # of embedded columns. this collects information about the + # column as it will be fetched in the result and is coordinated + # with cursor.description when the query is executed. + add_to_result_map = self._add_to_result_map + + # if the SELECT statement told us this column is a repeat, + # wrap the callable with one that prevents the addition of the + # targets + if column_is_repeated: + _add_to_result_map = add_to_result_map def add_to_result_map(keyname, name, objects, type_): - self._add_to_result_map( + _add_to_result_map(keyname, name, (), type_) + + # if we redefined col_expr for type expressions, wrap the + # callable with one that adds the original column to the targets + elif col_expr is not column: + _add_to_result_map = add_to_result_map + + def add_to_result_map(keyname, name, objects, type_): + _add_to_result_map( keyname, name, (column,) + objects, type_ ) - else: - add_to_result_map = None else: - col_expr = column - if populate_result_map: - add_to_result_map = self._add_to_result_map - else: - add_to_result_map = None + add_to_result_map = None if not within_columns_clause: result_expr = col_expr @@ -2010,7 +2034,7 @@ class SQLCompiler(Compiled): ) and ( not hasattr(column, "name") - or isinstance(column, functions.Function) + or isinstance(column, functions.FunctionElement) ) ): result_expr = _CompileLabel(col_expr, column.anon_label) @@ -2138,9 +2162,10 @@ class SQLCompiler(Compiled): asfrom, column_clause_args, name=name, + column_is_repeated=repeated, need_column_expressions=need_column_expressions, ) - for name, column in select._columns_plus_names + for name, column, repeated in select._columns_plus_names ] if c is not None ] @@ -2151,10 +2176,17 @@ class SQLCompiler(Compiled): translate = dict( zip( - [name for (key, name) in select._columns_plus_names], [ name - for (key, name) in select_wraps_for._columns_plus_names + for (key, name, repeated) in select._columns_plus_names + ], + [ + name + for ( + key, + name, + repeated, + ) in select_wraps_for._columns_plus_names ], ) ) diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index ddbcdf91d6..6282cf2ee0 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -4191,8 +4191,9 @@ class Select( def name_for_col(c): if c._label is None or not c._render_label_in_columns_clause: - return (None, c) + return (None, c, False) + repeated = False name = c._label if name in names: @@ -4218,19 +4219,22 @@ class Select( # subsequent occurrences of the column so that the # original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c elif anon_for_dupe_key: # same column under the same name. apply the "dedupe" # label so that the original stays non-ambiguous name = c._dedupe_label_anon_label + repeated = True else: names[name] = c - return name, c + return name, c, repeated return [name_for_col(c) for c in cols] else: - return [(None, c) for c in cols] + # repeated name logic only for use labels at the moment + return [(None, c, False) for c in cols] @_memoized_property def _columns_plus_names(self): @@ -4245,7 +4249,7 @@ class Select( keys_seen = set() prox = [] - for name, c in self._generate_columns_plus_names(False): + for name, c, repeated in self._generate_columns_plus_names(False): if not hasattr(c, "_make_proxy"): continue if name is None: diff --git a/test/ext/test_compiler.py b/test/ext/test_compiler.py index ccd79f8d11..f1817b94e7 100644 --- a/test/ext/test_compiler.py +++ b/test/ext/test_compiler.py @@ -3,6 +3,7 @@ from sqlalchemy import column from sqlalchemy import desc from sqlalchemy import exc from sqlalchemy import Integer +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import Numeric from sqlalchemy import select @@ -13,11 +14,13 @@ from sqlalchemy.ext.compiler import deregister from sqlalchemy.schema import CreateColumn from sqlalchemy.schema import CreateTable from sqlalchemy.schema import DDLElement +from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.expression import BindParameter from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ColumnClause from sqlalchemy.sql.expression import FunctionElement from sqlalchemy.sql.expression import Select +from sqlalchemy.sql.sqltypes import NULLTYPE from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ @@ -319,7 +322,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): dialect=mssql.dialect(), ) - def test_subclasses_one(self): + def test_function_subclasses_one(self): class Base(FunctionElement): name = "base" @@ -339,11 +342,11 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select([Sub1(), Sub2()]), - "SELECT FOOsub1, sub2", + "SELECT FOOsub1 AS sub1_1, sub2 AS sub2_1", use_default_dialect=True, ) - def test_subclasses_two(self): + def test_function_subclasses_two(self): class Base(FunctionElement): name = "base" @@ -362,7 +365,7 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select([Sub1(), Sub2(), SubSub1()]), - "SELECT sub1, sub2, subsub1", + "SELECT sub1 AS sub1_1, sub2 AS sub2_1, subsub1 AS subsub1_1", use_default_dialect=True, ) @@ -372,10 +375,51 @@ class UserDefinedTest(fixtures.TestBase, AssertsCompiledSQL): self.assert_compile( select([Sub1(), Sub2(), SubSub1()]), - "SELECT FOOsub1, sub2, FOOsubsub1", + "SELECT FOOsub1 AS sub1_1, sub2 AS sub2_1, " + "FOOsubsub1 AS subsub1_1", use_default_dialect=True, ) + def _test_result_map_population(self, expression): + lc1 = literal_column("1") + lc2 = literal_column("2") + stmt = select([lc1, expression, lc2]) + + compiled = stmt.compile() + eq_( + compiled._result_columns, + [ + ("1", "1", (lc1, "1", "1"), NULLTYPE), + (None, None, (expression,), NULLTYPE), + ("2", "2", (lc2, "2", "2"), NULLTYPE), + ], + ) + + def test_result_map_population_explicit(self): + class not_named_max(ColumnElement): + name = "not_named_max" + + @compiles(not_named_max) + def visit_max(element, compiler, **kw): + # explicit add + kw["add_to_result_map"](None, None, (element,), NULLTYPE) + return "max(a)" + + nnm = not_named_max() + self._test_result_map_population(nnm) + + def test_result_map_population_implicit(self): + class not_named_max(ColumnElement): + name = "not_named_max" + + @compiles(not_named_max) + def visit_max(element, compiler, **kw): + # we don't add to keymap here; compiler should be doing it + return "max(a)" + + nnm = not_named_max() + self._test_result_map_population(nnm) + class DefaultOnExistingTest(fixtures.TestBase, AssertsCompiledSQL): """Test replacement of default compilation on existing constructs.""" diff --git a/test/orm/test_query.py b/test/orm/test_query.py index d5dddb3aa3..d2c4f5bc7b 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -31,6 +31,7 @@ from sqlalchemy import Unicode from sqlalchemy import union from sqlalchemy import util from sqlalchemy.engine import default +from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import aliased from sqlalchemy.orm import attributes from sqlalchemy.orm import backref @@ -1744,6 +1745,38 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): class ExpressionTest(QueryTest, AssertsCompiledSQL): __dialect__ = "default" + def test_function_element_column_labels(self): + users = self.tables.users + sess = Session() + + class max_(expression.FunctionElement): + name = "max" + + @compiles(max_) + def visit_max(element, compiler, **kw): + return "max(%s)" % compiler.process(element.clauses, **kw) + + q = sess.query(max_(users.c.id)) + eq_(q.all(), [(10,)]) + + def test_truly_unlabeled_sql_expressions(self): + users = self.tables.users + sess = Session() + + class not_named_max(expression.ColumnElement): + name = "not_named_max" + + @compiles(not_named_max) + def visit_max(element, compiler, **kw): + return "max(id)" + + # assert that there is no "AS max_" or any label of any kind. + eq_(str(select([not_named_max()])), "SELECT max(id)") + + # ColumnElement still handles it by applying label() + q = sess.query(not_named_max()).select_from(users) + eq_(q.all(), [(10,)]) + def test_deferred_instances(self): User, addresses, Address = ( self.classes.User, @@ -4342,6 +4375,11 @@ class TextTest(QueryTest, AssertsCompiledSQL): "SELECT users.id AS users_id, users.name FROM users", ) + eq_( + s.query(User.id, text("users.name")).all(), + [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")], + ) + eq_( s.query(User.id, literal_column("name")).order_by(User.id).all(), [(7, "jack"), (8, "ed"), (9, "fred"), (10, "chuck")], diff --git a/test/sql/test_functions.py b/test/sql/test_functions.py index d0e68f1e33..a46d1af548 100644 --- a/test/sql/test_functions.py +++ b/test/sql/test_functions.py @@ -99,6 +99,21 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): select([func.foo()], use_labels=True), "SELECT foo() AS foo_1" ) + def test_use_labels_function_element(self): + from sqlalchemy.ext.compiler import compiles + + class max_(FunctionElement): + name = "max" + + @compiles(max_) + def visit_max(element, compiler, **kw): + return "max(%s)" % compiler.process(element.clauses, **kw) + + self.assert_compile( + select([max_(5, 6)], use_labels=True), + "SELECT max(:max_2, :max_3) AS max_1", + ) + def test_underscores(self): self.assert_compile(func.if_(), "if()") diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 35353671c5..2563c7d0c0 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -26,7 +26,10 @@ from sqlalchemy import VARCHAR from sqlalchemy.engine import default from sqlalchemy.engine import result as _result from sqlalchemy.engine import Row +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import expression from sqlalchemy.sql.selectable import TextualSelect +from sqlalchemy.sql.sqltypes import NULLTYPE from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import assertions @@ -35,6 +38,7 @@ from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures from sqlalchemy.testing import in_ from sqlalchemy.testing import is_ +from sqlalchemy.testing import is_false from sqlalchemy.testing import is_true from sqlalchemy.testing import le_ from sqlalchemy.testing import ne_ @@ -721,6 +725,13 @@ class ResultProxyTest(fixtures.TablesTest): lambda: r["user_id"], ) + assert_raises_message( + exc.InvalidRequestError, + "Ambiguous column name", + result._getter, + "user_id", + ) + # pure positional targeting; users.c.user_id # and addresses.c.user_id are known! # works as of 1.1 issue #3501 @@ -856,7 +867,6 @@ class ResultProxyTest(fixtures.TablesTest): addresses = self.tables.addresses with testing.db.connect() as conn: - # MARKMARK conn.execute(users.insert(), {"user_id": 1, "user_name": "john"}) conn.execute( addresses.insert(), @@ -1344,19 +1354,107 @@ class KeyTargetingTest(fixtures.TablesTest): eq_(row.keyed1_a, "a1") eq_(row.keyed1_c, "c1") + def _test_keyed_targeting_no_label_at_all(self, expression): + lt = literal_column("2") + stmt = select([literal_column("1"), expression, lt]).select_from( + self.tables.keyed1 + ) + row = testing.db.execute(stmt).first() + + eq_(row[expression], "a1") + eq_(row[lt], 2) + + # Postgresql for example has the key as "?column?", which dupes + # easily. we get around that because we know that "2" is unique + eq_(row["2"], 2) + + def test_keyed_targeting_no_label_at_all_one(self): + class not_named_max(expression.ColumnElement): + name = "not_named_max" + + @compiles(not_named_max) + def visit_max(element, compiler, **kw): + # explicit add + kw["add_to_result_map"](None, None, (element,), NULLTYPE) + return "max(a)" + + # assert that there is no "AS max_" or any label of any kind. + eq_(str(select([not_named_max()])), "SELECT max(a)") + + nnm = not_named_max() + self._test_keyed_targeting_no_label_at_all(nnm) + + def test_keyed_targeting_no_label_at_all_two(self): + class not_named_max(expression.ColumnElement): + name = "not_named_max" + + @compiles(not_named_max) + def visit_max(element, compiler, **kw): + # we don't add to keymap here; compiler should be doing it + return "max(a)" + + # assert that there is no "AS max_" or any label of any kind. + eq_(str(select([not_named_max()])), "SELECT max(a)") + + nnm = not_named_max() + self._test_keyed_targeting_no_label_at_all(nnm) + + def test_keyed_targeting_no_label_at_all_text(self): + t1 = text("max(a)") + t2 = text("min(a)") + + stmt = select([t1, t2]).select_from(self.tables.keyed1) + row = testing.db.execute(stmt).first() + + eq_(row[t1], "a1") + eq_(row[t2], "a1") + @testing.requires.duplicate_names_in_cursor_description def test_keyed_accessor_composite_conflict_2(self): keyed1 = self.tables.keyed1 keyed2 = self.tables.keyed2 row = testing.db.execute(select([keyed1, keyed2])).first() - # row.b is unambiguous - eq_(row.b, "b2") + + # column access is unambiguous + eq_(row[self.tables.keyed2.c.b], "b2") + # row.a is ambiguous assert_raises_message( exc.InvalidRequestError, "Ambig", getattr, row, "a" ) + # for "b" we have kind of a choice. the name "b" is not ambiguous in + # cursor.description in this case. It is however ambiguous as far as + # the objects we have queried against, because keyed1.c.a has key="b" + # and keyed1.c.b is "b". historically this was allowed as + # non-ambiguous, however the column it targets changes based on + # whether or not the dupe is present so it's ambiguous + # eq_(row.b, "b2") + assert_raises_message( + exc.InvalidRequestError, "Ambig", getattr, row, "b" + ) + + # illustrate why row.b above is ambiguous, and not "b2"; because + # if we didn't have keyed2, now it matches row.a. a new column + # shouldn't be able to grab the value from a previous column. + row = testing.db.execute(select([keyed1])).first() + eq_(row.b, "a1") + + def test_keyed_accessor_composite_conflict_2_fix_w_uselabels(self): + keyed1 = self.tables.keyed1 + keyed2 = self.tables.keyed2 + + row = testing.db.execute( + select([keyed1, keyed2]).apply_labels() + ).first() + + # column access is unambiguous + eq_(row[self.tables.keyed2.c.b], "b2") + + eq_(row["keyed2_b"], "b2") + eq_(row["keyed1_a"], "a1") + def test_keyed_accessor_composite_names_precedent(self): keyed1 = self.tables.keyed1 keyed4 = self.tables.keyed4 @@ -1374,13 +1472,13 @@ class KeyTargetingTest(fixtures.TablesTest): row = testing.db.execute(select([keyed1, keyed3])).first() eq_(row.q, "c1") - assert_raises_message( - exc.InvalidRequestError, - "Ambiguous column name 'a'", - getattr, - row, - "b", - ) + + # prior to 1.4 #4887, this raised an "ambiguous column name 'a'"" + # message, because "b" is linked to "a" which is a dupe. but we know + # where "b" is in the row by position. + eq_(row.b, "a1") + + # "a" is of course ambiguous assert_raises_message( exc.InvalidRequestError, "Ambiguous column name 'a'", @@ -1406,6 +1504,67 @@ class KeyTargetingTest(fixtures.TablesTest): assert_raises(KeyError, lambda: row["keyed2_c"]) assert_raises(KeyError, lambda: row["keyed2_q"]) + def test_keyed_accessor_column_is_repeated_multiple_times(self): + # test new logic added as a result of the combination of #4892 and + # #4887. We allow duplicate columns, but we also have special logic + # to disambiguate for the same column repeated, and as #4887 adds + # stricter ambiguous result column logic, the compiler has to know to + # not add these dupe columns to the result map, else they register as + # ambiguous. + + keyed2 = self.tables.keyed2 + keyed3 = self.tables.keyed3 + + stmt = select( + [ + keyed2.c.a, + keyed3.c.a, + keyed2.c.a, + keyed2.c.a, + keyed3.c.a, + keyed3.c.a, + keyed3.c.d, + keyed3.c.d, + ] + ).apply_labels() + + result = testing.db.execute(stmt) + is_false(result._metadata.matched_on_name) + + # ensure the result map is the same number of cols so we can + # use positional targeting + eq_( + [rec[0] for rec in result.context.compiled._result_columns], + [ + "keyed2_a", + "keyed3_a", + "keyed2_a__1", + "keyed2_a__1", + "keyed3_a__1", + "keyed3_a__1", + "keyed3_d", + "keyed3_d__1", + ], + ) + row = result.first() + + # keyed access will ignore the dupe cols + eq_(row[keyed2.c.a], "a2") + eq_(row[keyed3.c.a], "a3") + eq_(result._getter(keyed3.c.a)(row), "a3") + eq_(row[keyed3.c.d], "d3") + + # however we can get everything positionally + eq_(row, ("a2", "a3", "a2", "a2", "a3", "a3", "d3", "d3")) + eq_(row[0], "a2") + eq_(row[1], "a3") + eq_(row[2], "a2") + eq_(row[3], "a2") + eq_(row[4], "a3") + eq_(row[5], "a3") + eq_(row[6], "d3") + eq_(row[7], "d3") + def test_columnclause_schema_column_one(self): # originally addressed by [ticket:2932], however liberalized # Column-targeting rules are deprecated diff --git a/test/sql/test_text.py b/test/sql/test_text.py index 6af2cffcfa..3f12d06a90 100644 --- a/test/sql/test_text.py +++ b/test/sql/test_text.py @@ -20,6 +20,7 @@ from sqlalchemy import union from sqlalchemy import util from sqlalchemy.sql import column from sqlalchemy.sql import quoted_name +from sqlalchemy.sql import sqltypes from sqlalchemy.sql import table from sqlalchemy.sql import util as sql_util from sqlalchemy.testing import assert_raises_message @@ -49,6 +50,19 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "select * from foo where lala = bar", ) + def test_text_adds_to_result_map(self): + t1, t2 = text("t1"), text("t2") + + stmt = select([t1, t2]) + compiled = stmt.compile() + eq_( + compiled._result_columns, + [ + (None, None, (t1,), sqltypes.NULLTYPE), + (None, None, (t2,), sqltypes.NULLTYPE), + ], + ) + class SelectCompositionTest(fixtures.TestBase, AssertsCompiledSQL):